Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| import requests | |
| from typing import List, Tuple, Optional | |
| from PIL import Image | |
| from io import BytesIO | |
| try: | |
| from google import genai | |
| except ImportError: | |
| genai = None | |
| try: | |
| from openai import OpenAI | |
| import gradio as gr | |
| except ImportError: | |
| OpenAI = None | |
| gr = None # Keep gr dependency localized for error handling | |
# Import internal utility (if needed, though base64 encoding logic is placed inline)
# from utils import image_to_base64
# --- API client singletons ---
# Both clients read their keys from the environment (GEMINI_API_KEY /
# OPENAI_API_KEY). Failures here are swallowed on purpose: the handlers
# below detect a None client and return fallback prompts / raise
# user-facing errors at call time instead of crashing at import time.
GEMINI_CLIENT = None
OPENAI_CLIENT = None

if genai is not None:
    try:
        GEMINI_CLIENT = genai.Client()  # uses GEMINI_API_KEY from env
    except Exception:
        GEMINI_CLIENT = None  # missing/invalid key: handled per-call

if OpenAI is not None:
    try:
        OPENAI_CLIENT = OpenAI()  # uses OPENAI_API_KEY from env
    except Exception:
        OPENAI_CLIENT = None  # missing/invalid key: handled per-call
def get_generation_prompt(
    model_choice: str,
    prompt: str,
    image_paths: List[str],
) -> str:
    """Analyze the input images and text prompt with the selected multimodal
    model and synthesize a single detailed prompt for the image generator.

    Args:
        model_choice: Analysis backend: ``'gemini-2'`` (Gemini) or
            ``'gpt image-1'`` (OpenAI chat vision via gpt-4o-mini). Any
            other value yields a generic fallback prompt.
        prompt: The user's guiding text prompt.
        image_paths: Paths of the uploaded reference images; falsy entries
            are skipped.

    Returns:
        A descriptive prompt string. On a missing client or an API error a
        fallback prompt embedding the user's text is returned (this
        function never raises for API failures).
    """
    print(f"--- Analyzing inputs using {model_choice} ---")
    analysis_prompt = (
        f"You are an expert creative director. Based on the {len(image_paths)} input images and their text prompt, "
        f"synthesize a new, single, extremely detailed, aesthetic, and descriptive prompt (max 500 characters) "
        f"suitable for a cutting-edge text-to-image generator like DALL-E 3. "
        f"The resulting image must be a 'remix' or fusion incorporating key visual, thematic, and stylistic elements "
        # NOTE: trailing space added below — previously the two sentences ran
        # together as "...prompt: 'X'.Focus on composition...".
        f"from all available images, guided by the text prompt: '{prompt}'. "
        f"Focus on composition, lighting, style, mood, and texture. Do not mention 'image', 'remix', or 'input images' in the output. "
        f"Output ONLY the final descriptive prompt text, nothing else."
    )
    # Load images as PIL objects (lazy decode; skip empty slots).
    images = [Image.open(path) for path in image_paths if path]

    # --- Gemini analysis path ---
    if model_choice == 'gemini-2':
        if not GEMINI_CLIENT:
            return f"Gemini API Key missing or client failed to initialize. Fallback prompt: Fusion of provided visual elements inspired by the prompt: {prompt}."
        try:
            # Contents are images first, then the text instruction.
            contents = images + [analysis_prompt]
            response = GEMINI_CLIENT.models.generate_content(
                model='gemini-2.0-flash-live',
                contents=contents,
            )
            expanded_prompt = response.text.strip()
            print(f"Gemini Analysis Output: {expanded_prompt}")
            return expanded_prompt
        except Exception as e:
            print(f"Gemini API Error: {e}")
            return f"Error using Gemini for analysis. Fallback prompt: Creative fusion of the three elements provided, inspired by the theme: {prompt}."

    # --- GPT analysis path (gpt-4o-mini vision) ---
    elif model_choice == 'gpt image-1':
        if not OPENAI_CLIENT:
            return f"OpenAI API Key missing or client failed to initialize. Fallback prompt: Fusion of provided visual elements inspired by the prompt: {prompt}."
        try:
            # Build the multimodal message: text first, then base64 images.
            contents = [
                {"type": "text", "text": analysis_prompt}
            ]
            for img in images:
                buffered = BytesIO()
                # JPEG keeps the payload small, but JPEG cannot encode an
                # alpha channel — convert first so RGBA/P PNG uploads don't
                # raise OSError ("cannot write mode RGBA as JPEG").
                img.convert("RGB").save(buffered, format="JPEG")
                img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                contents.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_base64}",
                        "detail": "low"  # low detail for speed/cost
                    }
                })
            response = OPENAI_CLIENT.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "user", "content": contents}
                ],
                max_tokens=500,
            )
            expanded_prompt = response.choices[0].message.content.strip()
            # Log label matches the model actually called (gpt-4o-mini).
            print(f"gpt-4o-mini Analysis Output: {expanded_prompt}")
            return expanded_prompt
        except Exception as e:
            print(f"GPT API Error: {e}")
            return f"Error using gpt-image-1-low for analysis. Fallback prompt: Creative fusion of the three elements provided, inspired by the theme: {prompt}."

    # Fallback if the model choice is unrecognized.
    return f"Creative synthesis of the visual elements provided, inspired by the prompt: {prompt}. Ensure photorealistic quality."
def _build_error_placeholder(error_msg: str) -> Image.Image:
    """Render a dark-red 1024x1024 placeholder describing a generation failure."""
    import textwrap
    from PIL import ImageDraw, ImageFont

    placeholder_img = Image.new('RGB', (1024, 1024), color='darkred')
    d = ImageDraw.Draw(placeholder_img)
    try:
        font = ImageFont.truetype("arial.ttf", 40)
    except IOError:
        font = ImageFont.load_default()
    d.text((50, 450), "GENERATION FAILED", fill=(255, 255, 255), font=font)
    # Wrap so long error messages stay inside the 1024px canvas instead of
    # running off the right edge.
    wrapped = textwrap.fill(f"Error: {error_msg}", width=45)
    d.text((50, 550), wrapped, fill=(255, 200, 200), font=font)
    return placeholder_img


def generate_remixed_image(
    model_choice: str,
    prompt: str,
    image1_path: Optional[str],
    image2_path: Optional[str],
    image3_path: Optional[str],
) -> Tuple[str, Image.Image | None]:
    """Orchestrate prompt generation (via the selected analysis model) and
    image synthesis via OpenAI's gpt-image-1.

    Args:
        model_choice: Analysis backend forwarded to get_generation_prompt.
        prompt: The user's guiding text prompt.
        image1_path: Optional path of the first uploaded image.
        image2_path: Optional path of the second uploaded image.
        image3_path: Optional path of the third uploaded image.

    Returns:
        ``(final_prompt, image)``. When generation fails inside a Gradio
        app, a placeholder error image is returned instead of raising.

    Raises:
        gr.Error: Missing client or no images, when Gradio is available.
        ValueError: Same conditions (or generation failure) without Gradio.
    """
    image_paths = [image1_path, image2_path, image3_path]
    valid_paths = [path for path in image_paths if path is not None]

    if not OPENAI_CLIENT:
        # Surface a user-facing Gradio error when possible.
        if gr:
            raise gr.Error("OpenAI client not initialized. Please set OPENAI_API_KEY environment variable.")
        else:
            raise ValueError("OpenAI client not initialized.")
    if not valid_paths:
        if gr:
            raise gr.Error("Please upload at least one image to remix.")
        else:
            raise ValueError("No images provided.")

    # 1. Build the optimized generation prompt with the selected analysis model.
    final_prompt = get_generation_prompt(model_choice, prompt, valid_paths)
    print(f"\n--- Final Prompt for gpt-image-1: {final_prompt} ---")

    # 2. Generate the image with gpt-image-1 (OpenAI Images API).
    try:
        gen_response = OPENAI_CLIENT.images.generate(
            model="gpt-image-1",
            prompt=final_prompt,
            size="1024x1024",
            quality="medium",  # valid: low | medium | high | auto
            n=1,
        )
        b64 = gen_response.data[0].b64_json
        if not b64:
            # gpt-image-1 returns base64 by default; guard against a missing
            # payload so we fail with a clear message, not a TypeError.
            raise ValueError("API response contained no image data (b64_json).")
        img_bytes = base64.b64decode(b64)
        remixed_image = Image.open(BytesIO(img_bytes)).convert("RGB")
        return final_prompt, remixed_image
    except Exception as e:
        print(f"Image Generation Error: {e}")
        error_msg = f"Image generation failed: {str(e)}"
        if gr:
            # In the Gradio UI, show a placeholder instead of crashing.
            return f"FAILED. Error: {error_msg}", _build_error_placeholder(error_msg)
        raise ValueError(error_msg)