Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import subprocess | |
| import sys | |
| import io | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import spaces | |
| import torch | |
| from diffusers import Flux2Pipeline, Flux2Transformer2DModel | |
| from diffusers import BitsAndBytesConfig as DiffBitsAndBytesConfig | |
| from optimization import optimize_pipeline_ | |
| import requests | |
| from PIL import Image | |
| import json | |
# Numeric precision used for the model weights.
dtype = torch.bfloat16

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Upper bounds exposed through the UI: the largest signed 32-bit seed and the
# largest selectable image edge, in pixels.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def remote_text_encoder(prompts, *, timeout=(10, 120)):
    """Fetch prompt embeddings from the hosted FLUX.2 text-encoder service.

    The prompt is POSTed as JSON to the remote ``/predict`` endpoint,
    authenticated with the ``HF_TOKEN`` environment variable, and the raw
    response body is deserialized with ``torch.load``.

    Args:
        prompts: Value sent under the ``"prompt"`` JSON key. Callers in this
            file pass a single string.
        timeout: ``(connect, read)`` timeout in seconds handed to
            ``requests.post``. Without one, a stalled service would block the
            caller indefinitely.

    Returns:
        The object deserialized from the response body. Callers in this file
        call ``.to("cuda")`` on it, so it is presumably a tensor.

    Raises:
        KeyError: If ``HF_TOKEN`` is not set in the environment.
        requests.RequestException: On connection failure or timeout.
        RuntimeError: If the service answers with any status other than 200.
    """
    response = requests.post(
        "https://remote-text-encoder-flux-2.huggingface.co/predict",
        json={"prompt": prompts},
        headers={
            "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
            "Content-Type": "application/json",
        },
        timeout=timeout,
    )
    # An explicit raise instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let error bodies reach torch.load.
    if response.status_code != 200:
        raise RuntimeError(f"{response.status_code=}")
    # NOTE(review): torch.load unpickles bytes received over the network.
    # Confirm the deployed torch version defaults to weights_only=True, or
    # pass it explicitly, so a compromised endpoint cannot execute code here.
    prompt_embeds = torch.load(io.BytesIO(response.content))
    return prompt_embeds
# Load model
repo_id = "black-forest-labs/FLUX.2-dev"

# The transformer is loaded separately and then injected into the pipeline.
dit = Flux2Transformer2DModel.from_pretrained(
    repo_id,
    subfolder="transformer",
    torch_dtype=dtype,
)

# No local text encoder is loaded: prompts are embedded by
# remote_text_encoder() and passed to the pipeline as prompt_embeds.
pipe = Flux2Pipeline.from_pretrained(
    repo_id,
    text_encoder=None,
    transformer=dit,
    torch_dtype=dtype,
)
pipe.to("cuda")
pipe.transformer.set_attention_backend("_flash_3_hub")


def _warmup_pipeline():
    """Run the one-step optimization pass; log and continue if it fails."""
    try:
        optimize_pipeline_(
            pipe,
            image=[Image.new("RGB", (1024, 1024))],
            prompt_embeds=remote_text_encoder("prompt").to("cuda"),
            guidance_scale=2.5,
            width=1024,
            height=1024,
            num_inference_steps=1,
        )
    except Exception as exc:
        # Best effort: startup continues without the optimization.
        print(f"Optimization failed: {exc}")


_warmup_pipeline()
def get_duration(
    prompt,
    input_images=None,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=50,
    guidance_scale=2.5,
    progress=gr.Progress(track_tqdm=True),
):
    """Estimate how long one generation request will take.

    The signature mirrors ``infer`` so both can be called with the same
    arguments; only ``input_images`` and ``num_inference_steps`` are used.

    Args:
        input_images: Optional sequence of input images; each one adds 0.7 to
            the per-step cost.
        num_inference_steps: Number of denoising steps requested.

    Returns:
        ``num_inference_steps * (1 + 0.7 * image_count) + 10``. The unit is
        not stated in this file; presumably seconds.
    """
    image_count = len(input_images) if input_images is not None else 0
    per_step_cost = 1 + 0.7 * image_count
    return num_inference_steps * per_step_cost + 10
# On ZeroGPU hardware a GPU is only attached while a spaces.GPU-decorated
# function runs; get_duration supplies the per-call time budget from the same
# arguments infer receives.
@spaces.GPU(duration=get_duration)
def infer(
    prompt,
    input_images=None,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=50,
    guidance_scale=2.5,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the FLUX.2 pipeline.

    Args:
        prompt: Text prompt; embedded remotely via ``remote_text_encoder``.
        input_images: Optional gallery value. Each item is indexed with
            ``[0]`` to obtain the image; ``None`` or an empty gallery means
            no image conditioning.
        seed: Seed for the torch generator (ignored when ``randomize_seed``).
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        progress: Gradio progress tracker used for status updates.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.

    Raises:
        gr.Error: If the remote text encoder call fails for any reason.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed only accepts ints; coerce in case the UI delivers a float.
    seed = int(seed)

    # Get prompt embeddings from remote text encoder
    progress(0.1, desc="Encoding prompt...")
    try:
        prompt_embeds = remote_text_encoder(prompt).to("cuda")
    except Exception as e:
        # Chain the cause so the original traceback stays in the server log.
        raise gr.Error(f"Remote text encoder failed: {e}") from e

    # Prepare image list (convert None or empty gallery to None)
    image_list = [item[0] for item in input_images] if input_images else None

    # Generate image
    progress(0.3, desc="Generating image...")
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        prompt_embeds=prompt_embeds,
        image=image_list,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return image, seed
# --- UI Configuration ---
# Stylesheet for the full-viewport dark theme. Selectors prefixed with "#"
# target the elem_id values assigned to the Gradio components.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap');
html, body, .gradio-container {
background-color: #000000 !important;
color: #ffffff !important;
font-family: 'Inter', sans-serif !important;
margin: 0;
padding: 0 !important;
overflow: hidden !important;
height: 100vh !important;
max-height: 100vh !important;
width: 100vw !important;
max-width: 100vw !important;
--color-background-primary: #000000;
--color-background-secondary: #050505;
--color-border-primary: #171717;
--color-text-primary: #ffffff;
--color-text-secondary: #a3a3a3;
}
footer {
display: none !important;
}
/* Layout */
#main-container {
position: fixed !important;
top: 0;
left: 0;
width: 100vw !important;
height: 100vh !important;
gap: 0 !important;
display: flex;
flex-wrap: nowrap;
overflow: hidden;
z-index: 10;
}
#right-sidebar {
background-color: #050505;
border-left: 1px solid #171717;
width: 320px !important;
max-width: 320px !important;
flex: none !important;
padding: 0 !important;
height: 100%;
overflow-y: auto;
}
#center-canvas {
background-color: #090909;
flex-grow: 1 !important;
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
padding: 20px;
background-image: radial-gradient(#151515 1px, transparent 1px);
background-size: 20px 20px;
height: 100%;
position: relative;
}
/* Components */
#generate-btn {
background: #ffffff !important;
color: #000000 !important;
border-radius: 6px !important;
font-weight: 600 !important;
text-transform: uppercase;
font-size: 11px !important;
border: none !important;
}
#prompt-input textarea {
background-color: #000000 !important;
border: 1px solid #262626 !important;
color: white !important;
border-radius: 8px !important;
}
#prompt-input span {
display: none; /* Hide default label if needed, or style it */
}
/* Accordions */
.accordion {
background: transparent !important;
border: none !important;
border-bottom: 1px solid #171717 !important;
}
.accordion-label {
font-size: 11px !important;
font-weight: 600 !important;
text-transform: uppercase;
color: #a3a3a3 !important;
}
/* Sliders */
input[type=range] {
accent-color: white !important;
}
/* Gallery in Sidebar */
#history-gallery {
flex-grow: 1;
overflow-y: auto;
padding: 10px;
}
#history-gallery .grid-wrap {
grid-template-columns: 1fr !important; /* Force list view */
}
/* Main Image */
#main-image {
background: transparent !important;
border: 1px solid #171717;
border-radius: 8px;
overflow: hidden;
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04);
}
/* Scrollbars */
::-webkit-scrollbar {
width: 6px;
height: 6px;
}
::-webkit-scrollbar-track {
background: #000000;
}
::-webkit-scrollbar-thumb {
background: #333;
border-radius: 3px;
}
"""

# Static heading rendered at the top of the control sidebar.
controls_header_html = """
<div style="padding: 20px 20px 10px 20px;">
<h2 style="font-size: 11px; font-weight: 600; text-transform: uppercase; color: #666; margin: 0;">Configuration</h2>
</div>
"""
# Two-column layout: the canvas column is created first, followed by the
# fixed-width control sidebar. Components are created in display order.
with gr.Blocks(title="FLUX.2 [dev]") as demo:
    with gr.Row(elem_id="main-container", variant="compact"):
        # --- Center Canvas ---
        with gr.Column(elem_id="center-canvas"):
            with gr.Row(elem_id="canvas-toolbar"):
                gr.Markdown("Canvas", elem_id="canvas-info")
            result_image = gr.Image(
                elem_id="main-image",
                interactive=False,
                show_label=False,
            )

        # --- Right Sidebar ---
        with gr.Column(elem_id="right-sidebar", min_width=320):
            gr.HTML(controls_header_html)

            # Prompt Section
            prompt = gr.Textbox(
                elem_id="prompt-input",
                lines=4,
                placeholder="Describe your imagination...",
                label="Prompt",
                show_label=True,
            )
            run_button = gr.Button("Generate Image", elem_id="generate-btn")

            # Settings
            input_images = gr.Gallery(
                label="Input Image(s)",
                type="pil",
                columns=3,
                rows=1,
            )

            # Width and height share the same range, step and default.
            dimension_slider = dict(
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            width = gr.Slider(label="Width", **dimension_slider)
            height = gr.Slider(label="Height", **dimension_slider)

            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=4,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=1,
                maximum=100,
                step=1,
                value=30,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    # Wiring: the button click and Enter in the prompt box both run infer.
    # The list order must match infer's positional parameters.
    generation_inputs = [
        prompt,
        input_images,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
        guidance_scale,
    ]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=generation_inputs,
        # The seed slider is also an output so it shows the seed actually used.
        outputs=[result_image, seed],
    )

demo.launch(css=css)