# models.py (Space: app-gb199hq1)
# Hugging Face viewer residue preserved as comments so the module parses:
#   John6666's picture / "Update models.py" / commit ebae902 (verified) / 7.47 kB
import os
import base64
import requests
from typing import List, Tuple, Optional
from PIL import Image
from io import BytesIO
try:
from google import genai
except ImportError:
genai = None
try:
from openai import OpenAI
import gradio as gr
except ImportError:
OpenAI = None
gr = None # Keep gr dependency localized for error handling
# Import internal utility (if needed, though base64 encoding logic is placed inline)
# from utils import image_to_base64
# --- API Clients Initialization ---
# Shared singleton clients. Each stays None when its SDK import failed or the
# relevant API key is missing; callers are responsible for checking for None.
GEMINI_CLIENT = None
OPENAI_CLIENT = None

if genai is not None:
    try:
        # Reads GEMINI_API_KEY from the environment.
        GEMINI_CLIENT = genai.Client()
    except Exception:
        # Deliberate best-effort: missing/invalid key is surfaced later,
        # at call time, not at import time.
        GEMINI_CLIENT = None

if OpenAI is not None:
    try:
        # Reads OPENAI_API_KEY from the environment.
        OPENAI_CLIENT = OpenAI()
    except Exception:
        # Same best-effort policy as the Gemini client above.
        OPENAI_CLIENT = None
def get_generation_prompt(
    model_choice: str,
    prompt: str,
    image_paths: List[str]
) -> str:
    """Analyze the input images and text prompt with the selected multimodal
    model and return a single detailed prompt for the image generator.

    Args:
        model_choice: Analysis backend id ('gemini-2' or 'gpt image-1');
            any other value yields the generic fallback prompt.
        prompt: The user's guiding text prompt.
        image_paths: Filesystem paths of the uploaded images (falsy entries
            are skipped).

    Returns:
        The model-expanded prompt text, or a fallback prompt string when the
        client is unavailable, the API call fails, or the model is unknown.
        This function never raises; errors are printed and folded into the
        returned fallback text.
    """
    print(f"--- Analyzing inputs using {model_choice} ---")
    analysis_prompt = (
        f"You are an expert creative director. Based on the {len(image_paths)} input images and their text prompt, "
        f"synthesize a new, single, extremely detailed, aesthetic, and descriptive prompt (max 500 characters) "
        f"suitable for a cutting-edge text-to-image generator like DALL-E 3. "
        f"The resulting image must be a 'remix' or fusion incorporating key visual, thematic, and stylistic elements "
        f"from all available images, guided by the text prompt: '{prompt}'."
        f"Focus on composition, lighting, style, mood, and texture. Do not mention 'image', 'remix', or 'input images' in the output. "
        f"Output ONLY the final descriptive prompt text, nothing else."
    )
    # Load images as PIL objects; closed in the finally block below
    # (fix: the handles were previously leaked).
    images = [Image.open(path) for path in image_paths if path]
    try:
        # --- GEMINI Analysis Path ---
        if model_choice == 'gemini-2':
            if not GEMINI_CLIENT:
                return f"Gemini API Key missing or client failed to initialize. Fallback prompt: Fusion of provided visual elements inspired by the prompt: {prompt}."
            try:
                # Contents should be images first, then the text prompt.
                contents = images + [analysis_prompt]
                # NOTE(review): 'gemini-2.0-flash-live' is a Live-API model id;
                # generate_content normally expects e.g. 'gemini-2.0-flash' —
                # confirm against the google-genai docs before changing.
                response = GEMINI_CLIENT.models.generate_content(
                    model='gemini-2.0-flash-live',
                    contents=contents
                )
                expanded_prompt = response.text.strip()
                print(f"Gemini Analysis Output: {expanded_prompt}")
                return expanded_prompt
            except Exception as e:
                print(f"Gemini API Error: {e}")
                return f"Error using Gemini for analysis. Fallback prompt: Creative fusion of the three elements provided, inspired by the theme: {prompt}."
        # --- GPT Analysis Path ---
        elif model_choice == 'gpt image-1':
            if not OPENAI_CLIENT:
                return f"OpenAI API Key missing or client failed to initialize. Fallback prompt: Fusion of provided visual elements inspired by the prompt: {prompt}."
            try:
                # Chat-completions payload: the instruction text first, then
                # each image as a base64 data URL.
                contents = [
                    {"type": "text", "text": analysis_prompt}
                ]
                for img in images:
                    buffered = BytesIO()
                    # Fix: JPEG cannot encode alpha/palette images, so force
                    # RGB first (previously raised OSError for RGBA uploads).
                    # JPEG keeps the payload small.
                    img.convert("RGB").save(buffered, format="JPEG")
                    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                    contents.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{img_base64}",
                            "detail": "low"  # Use low detail for speed
                        }
                    })
                # Analysis is done by gpt-4o-mini (the 'gpt image-1' choice
                # names the downstream generator, not this analysis model).
                response = OPENAI_CLIENT.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[
                        {"role": "user", "content": contents}
                    ],
                    max_tokens=500
                )
                expanded_prompt = response.choices[0].message.content.strip()
                print(f"gpt-image-1-low Analysis Output: {expanded_prompt}")
                return expanded_prompt
            except Exception as e:
                print(f"GPT API Error: {e}")
                return f"Error using gpt-image-1-low for analysis. Fallback prompt: Creative fusion of the three elements provided, inspired by the theme: {prompt}."
        # Fallback if model is unrecognized
        return f"Creative synthesis of the visual elements provided, inspired by the prompt: {prompt}. Ensure photorealistic quality."
    finally:
        # Release file handles regardless of which path returned.
        for img in images:
            img.close()
def _build_error_image(error_msg: str) -> "Image.Image":
    """Render a dark-red 1024x1024 placeholder image announcing a failed
    generation, with the error message word-wrapped so long messages stay
    on-canvas (fix: the text was previously drawn as one unwrapped line)."""
    import textwrap
    from PIL import ImageDraw, ImageFont
    placeholder_img = Image.new('RGB', (1024, 1024), color='darkred')
    d = ImageDraw.Draw(placeholder_img)
    try:
        font = ImageFont.truetype("arial.ttf", 40)
    except IOError:
        # No arial.ttf on most Linux hosts; fall back to PIL's built-in font.
        font = ImageFont.load_default()
    d.text((50, 350), "GENERATION FAILED", fill=(255, 255, 255), font=font)
    y = 450
    for line in textwrap.wrap(f"Error: {error_msg}", width=45):
        d.text((50, y), line, fill=(255, 200, 200), font=font)
        y += 55
    return placeholder_img


def generate_remixed_image(
    model_choice: str,
    prompt: str,
    image1_path: Optional[str],
    image2_path: Optional[str],
    image3_path: Optional[str]
) -> Tuple[str, Image.Image | None]:
    """Orchestrate prompt generation (via the selected analysis model) and
    image synthesis via OpenAI's gpt-image-1.

    Args:
        model_choice: Analysis backend id, forwarded to get_generation_prompt.
        prompt: The user's guiding text prompt.
        image1_path / image2_path / image3_path: Optional uploaded image paths.

    Returns:
        (final_prompt, PIL image). On generation failure with gradio
        available, returns a "FAILED..." prompt string and a placeholder
        error image instead of raising.

    Raises:
        gr.Error / ValueError: when the OpenAI client is missing or no
        images were uploaded (and ValueError on generation failure when
        gradio is unavailable).
    """
    image_paths = [image1_path, image2_path, image3_path]
    valid_paths = [path for path in image_paths if path is not None]
    # Guard: synthesis always needs the OpenAI client, whichever analysis
    # model was chosen.
    if not OPENAI_CLIENT:
        if gr:
            raise gr.Error("OpenAI client not initialized. Please set OPENAI_API_KEY environment variable.")
        else:
            raise ValueError("OpenAI client not initialized.")
    if not valid_paths:
        if gr:
            raise gr.Error("Please upload at least one image to remix.")
        else:
            raise ValueError("No images provided.")
    # 1. Generate the optimized synthesis prompt using the selected analysis
    #    model (never raises; falls back to a canned prompt on error).
    final_prompt = get_generation_prompt(model_choice, prompt, valid_paths)
    print(f"\n--- Final Prompt for DALL-E 3: {final_prompt} ---")
    # 2. Generate the image. NOTE: despite the "DALL-E 3" wording in the log
    #    strings, the actual generator model is gpt-image-1, which returns
    #    base64 image data by default.
    try:
        dalle_response = OPENAI_CLIENT.images.generate(
            model="gpt-image-1",
            prompt=final_prompt,
            size="1024x1024",
            quality="medium",  # valid: low | medium | high | auto
            n=1,
        )
        b64 = dalle_response.data[0].b64_json
        img_bytes = base64.b64decode(b64)
        remixed_image = Image.open(BytesIO(img_bytes)).convert("RGB")
        return final_prompt, remixed_image
    except Exception as e:
        print(f"DALL-E 3 Generation Error: {e}")
        error_msg = f"Image generation failed: {str(e)}"
        if gr:
            # Show a readable placeholder in the UI instead of crashing.
            return f"FAILED. Error: {error_msg}", _build_error_image(error_msg)
        raise ValueError(error_msg)