import gradio as gr
import os
from typing import List

import torch
from torchvision import transforms
from PIL import Image, ImageChops

from diffusers import AutoencoderKL, StableDiffusionPipeline, EulerDiscreteScheduler

from customization import customize_vae_decoder
from attribution import MappingNetwork


PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
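# The checkpoint directory above must contain the three fine-tuned weight
# files loaded in AttributionModel.__init__: vae_decoder.pth,
# mapping_network.pth, and decoding_network.pth.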


def get_image_grid(images: List[Image.Image]) -> Image.Image:
    # Paste the images side by side on a single canvas, three per row.
    num_images = len(images)
    cols = 3
    rows = (num_images + cols - 1) // cols
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))
    for i, img in enumerate(images):
        x = i % cols
        y = i // cols
        grid_image.paste(img, (x * width, y * height))
    return grid_image


class AttributionModel:
    def __init__(self):
        is_cuda = torch.cuda.is_available()
        self.device = "cuda" if is_cuda else "cpu"

        # Base Stable Diffusion 2 pipeline; its stock VAE decoder produces the
        # "without attribution" baseline image.
        scheduler = EulerDiscreteScheduler.from_pretrained('stabilityai/stable-diffusion-2', subfolder="scheduler")
        self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2', scheduler=scheduler)
        if is_cuda:
            self.pipe = self.pipe.to("cuda")
        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)

        # Second VAE whose decoder is customized so that its weights can be
        # modulated by a user-specific key.
        self.vae = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-2', subfolder="vae"
        )
        self.vae = customize_vae_decoder(self.vae, 128, "deqkv", "all", False, 1.0)

        # Maps a 32-bit key to the decoder's weight-modulation parameters.
        self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization=False)

        # ResNet-50 head that regresses the 32-bit key from a generated image.
        from torchvision.models import resnet50, ResNet50_Weights
        self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.decoding_network.fc = torch.nn.Linear(2048, 32)

        # map_location keeps the checkpoints loadable on CPU-only machines.
        self.vae.decoder.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'vae_decoder.pth'), map_location="cpu"))
        self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth'), map_location="cpu"))
        self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth'), map_location="cpu"))

        # Inference only: freeze dropout/batch-norm behavior.
        self.vae.eval()
        self.mapping_network.eval()
        self.decoding_network.eval()

        if is_cuda:
            self.vae = self.vae.to("cuda")
            self.mapping_network = self.mapping_network.to("cuda")
            self.decoding_network = self.decoding_network.to("cuda")

        # ImageNet normalization expected by the ResNet decoder.
        self.test_norm = transforms.Compose(
            [
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    def infer(self, prompt, negative, steps, guidance_scale):
        with torch.no_grad():
            # Run the diffusion loop once, then decode the same latents twice:
            # with the key-modulated decoder and with the stock decoder.
            out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
            image_attr = self.inference_with_attribution(out_latents)
            image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])

            image_org = self.inference_without_attribution(out_latents)
            image_org_pil = self.pipe.numpy_to_pil(image_org[0])

        # Amplify the pixel-wise difference so the fingerprint becomes visible.
        diff_factor = 5
        image_diff_pil = ImageChops.difference(image_org_pil[0], image_attr_pil[0]).convert(
            "RGB", (diff_factor, 0, 0, 0, 0, diff_factor, 0, 0, 0, 0, diff_factor, 0)
        )

        return image_org_pil[0], image_attr_pil[0], image_diff_pil

    def inference_without_attribution(self, latents):
        # 0.18215 is the Stable Diffusion latent scaling factor.
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.pipe.vae.decode(latents).sample
            image = image.clamp(-1, 1)
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def get_phis(self, phi_dimension, batch_size, eps=1e-8):
        # Sample one Bernoulli(0.5) bit per key dimension; eps keeps the key
        # strictly positive for the mapping network.
        phi = torch.empty(batch_size, phi_dimension).uniform_(0, 1)
        return torch.bernoulli(phi) + eps

    def inference_with_attribution(self, latents, key=None):
        if key is None:
            key = self.get_phis(32, 1)

        # 0.18215 is the Stable Diffusion latent scaling factor.
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            # The mapping network turns the key into weight modulations for
            # the customized VAE decoder.
            image = self.vae.decode(latents, self.mapping_network(key.to(self.device))).sample
            image = image.clamp(-1, 1)
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def postprocess(self, image):
        return self.resize_transform(image)

    def detect_key(self, image):
        # `image` is expected in [-1, 1]: rescale to [0, 1], apply ImageNet
        # normalization, and regress the 32-dimensional key.
        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
        return reconstructed_keys
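
    # Illustrative sketch, not part of the original app: one plausible way to
    # verify attribution is to binarize the decoder's outputs (assumed here to
    # be logits) with a sigmoid threshold and measure bit accuracy against the
    # key embedded at generation time. The method name and thresholding scheme
    # are assumptions for illustration only.
    def key_bit_accuracy(self, reconstructed_keys, true_keys):
        pred_bits = (torch.sigmoid(reconstructed_keys) > 0.5).float()
        true_bits = true_keys.round().clamp(0, 1)  # get_phis adds a tiny eps
        return (pred_bits == true_bits).float().mean().item()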


attribution_model = AttributionModel()


def get_images(prompt, negative, steps, guidance_scale):
    x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidance_scale)
    return [x1, x2, x3]


image_examples = [
    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10],
]


with gr.Blocks() as demo:
    gr.Markdown(
        """<h1 style="text-align: center;"><b>WOUAF:
        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://wouaf.vercel.app">Project Page</a></h1>""")

    with gr.Row(elem_id="prompt-container", equal_height=True):
        with gr.Column():
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text-input",
                container=False,
            )
            negative = gr.Textbox(
                label="Enter your negative prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter a negative prompt",
                elem_id="negative-prompt-text-input",
                container=False,
            )

    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=45, maximum=55, value=50, step=1)
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
        )

    with gr.Row():
        btn = gr.Button(value="Generate Image")

    with gr.Row():
        im_2 = gr.Image(type="pil", label="without attribution")
        im_3 = gr.Image(type="pil", label="with attribution")
        im_4 = gr.Image(type="pil", label="pixel-wise difference multiplied by 5")

    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])

    gr.Examples(
        examples=image_examples,
        inputs=[text, negative, steps, guidance_scale],
        outputs=[im_2, im_3, im_4],
        fn=get_images,
        cache_examples=True,
    )

    gr.HTML(
        """
        <div class="footer">
            <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
            </p>
            <p>
            Fine-tuned by the authors for research purposes.
            </p>
        </div>
        """
    )
    with gr.Accordion(label="Ethics & Privacy", open=False):
        gr.HTML(
            """<div class="acknowledgments">
            <p><h4>Privacy</h4>
            We do not collect any images or key data. This demo exists purely for fun and to help reduce misuse of AI.
            <p><h4>Biases and content acknowledgment</h4>
            This model inherits the same biases as Stable Diffusion v2.</div>
            """
        )
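

# Minimal programmatic sketch (an assumption for illustration, not part of the
# demo UI): embed a known key, decode it back from the generated image, and
# report bit accuracy with the illustrative helper defined above.
#
#     key = attribution_model.get_phis(32, 1)
#     latents = attribution_model.pipe(["an astronaut riding a horse"], output_type="latent").images
#     image = attribution_model.inference_with_attribution(latents, key)
#     tensor = torch.from_numpy(image).permute(0, 3, 1, 2) * 2 - 1  # back to [-1, 1]
#     keys_hat = attribution_model.detect_key(tensor.to(attribution_model.device))
#     print(attribution_model.key_bit_accuracy(keys_hat, key.to(keys_hat.device)))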


if __name__ == "__main__":
    demo.launch()