Update app.py
Browse files
app.py
CHANGED
|
@@ -6,10 +6,10 @@ import torch
|
|
| 6 |
def load_model(base_model_id, adapter_model_id=None):
|
| 7 |
if torch.cuda.is_available():
|
| 8 |
device = "cuda"
|
| 9 |
-
info = "Running on GPU (CUDA)"
|
| 10 |
else:
|
| 11 |
device = "cpu"
|
| 12 |
-
info = "Running on CPU"
|
| 13 |
|
| 14 |
# Load the base model dynamically on the correct device
|
| 15 |
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
|
|
@@ -23,24 +23,25 @@ def load_model(base_model_id, adapter_model_id=None):
|
|
| 23 |
|
| 24 |
return pipe, info
|
| 25 |
|
| 26 |
-
|
| 27 |
-
if torch.cuda.is_available():
|
| 28 |
-
device = "cuda"
|
| 29 |
-
info = "Running on GPU (CUDA) 🔥"
|
| 30 |
-
else:
|
| 31 |
-
device = "cpu"
|
| 32 |
-
info = "Running on CPU 🥶"
|
| 33 |
-
|
| 34 |
# Function for text-to-image generation with dynamic model ID and device info
|
| 35 |
def generate_image(base_model_id, adapter_model_id, prompt):
|
| 36 |
pipe, info = load_model(base_model_id, adapter_model_id)
|
| 37 |
image = pipe(prompt).images[0]
|
| 38 |
return image, info
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Create the Gradio interface
|
| 41 |
with gr.Blocks() as demo:
|
| 42 |
gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
|
| 43 |
-
gr.Markdown(f"{info}")
|
|
|
|
| 44 |
with gr.Row():
|
| 45 |
with gr.Column():
|
| 46 |
base_model_id = gr.Textbox(label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)", placeholder="Base Model ID")
|
|
|
|
| 6 |
def load_model(base_model_id, adapter_model_id=None):
|
| 7 |
if torch.cuda.is_available():
|
| 8 |
device = "cuda"
|
| 9 |
+
info = "Running on GPU (CUDA) 🔥"
|
| 10 |
else:
|
| 11 |
device = "cpu"
|
| 12 |
+
info = "Running on CPU 🥶"
|
| 13 |
|
| 14 |
# Load the base model dynamically on the correct device
|
| 15 |
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
|
|
|
|
| 23 |
|
| 24 |
return pipe, info
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Function for text-to-image generation with dynamic model ID and device info
|
| 27 |
def generate_image(base_model_id, adapter_model_id, prompt):
|
| 28 |
pipe, info = load_model(base_model_id, adapter_model_id)
|
| 29 |
image = pipe(prompt).images[0]
|
| 30 |
return image, info
|
| 31 |
|
| 32 |
+
# Check device (GPU/CPU) once at the start and show it in the UI
|
| 33 |
+
if torch.cuda.is_available():
|
| 34 |
+
device = "cuda"
|
| 35 |
+
info = "Running on GPU (CUDA) 🔥"
|
| 36 |
+
else:
|
| 37 |
+
device = "cpu"
|
| 38 |
+
info = "Running on CPU 🥶"
|
| 39 |
+
|
| 40 |
# Create the Gradio interface
|
| 41 |
with gr.Blocks() as demo:
|
| 42 |
gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
|
| 43 |
+
gr.Markdown(f"**{info}**") # Display GPU/CPU information in the UI
|
| 44 |
+
|
| 45 |
with gr.Row():
|
| 46 |
with gr.Column():
|
| 47 |
base_model_id = gr.Textbox(label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)", placeholder="Base Model ID")
|