|
|
import datetime
import json
import os
import shutil
import subprocess
import sys
import threading
import time
import urllib.request
from pathlib import Path
from queue import Queue

import cv2
import gradio as gr
import numpy as np
import pytz
import torch
from torchvision.utils import save_image
|
|
|
|
|
|
|
|
# Put the current working directory on the import path; the project-local
# `test_code` imports that follow are resolved after this append.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
|
|
from test_code.inference import super_resolve_img |
|
|
from test_code.test_utils import load_grl, load_rrdb, load_dat |
|
|
|
|
|
|
|
|
# Where results and bookkeeping files are written (relative to the CWD).
OUTPUT_DIR = "outputs"
HISTORY_FILE = "history.json"
# NOTE(review): not referenced anywhere in the visible part of this file.
VIDEO_QUEUE_FILE = "video_queue.json"

# Pending video jobs, consumed by the background worker thread.
video_queue = Queue()
# task_id -> dict with at least a "status" key; updated as a job progresses.
processing_status = {}

# Make sure the output tree exists before any handler tries to write to it.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
Path(OUTPUT_DIR, "images").mkdir(parents=True, exist_ok=True)
Path(OUTPUT_DIR, "videos").mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
# Release-asset URL for every checkpoint this app knows how to fetch,
# keyed by the local path the checkpoint is stored at.
_APISR_WEIGHT_URLS = {
    "pretrained/4x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_GRL_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
    "pretrained/2x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_DAT_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth",
}


def auto_download_if_needed(weight_path):
    """Download a known pretrained checkpoint if it is not on disk yet.

    Args:
        weight_path: Local checkpoint path, e.g.
            ``"pretrained/4x_APISR_GRL_GAN_generator.pth"``.

    Returns:
        None. If ``weight_path`` already exists nothing happens. Otherwise the
        ``pretrained`` directory is created and, when ``weight_path`` is one of
        the known checkpoints, the file is downloaded to that path. Unknown
        paths are left alone.

    Raises:
        urllib.error.URLError / OSError: if the download or the final rename
            fails (previously a failed ``wget`` went unnoticed).
    """
    if os.path.exists(weight_path):
        return

    os.makedirs("pretrained", exist_ok=True)

    url = _APISR_WEIGHT_URLS.get(weight_path)
    if url is None:
        return

    # Download next to the target and rename at the end, so an interrupted
    # transfer never leaves a truncated file under the real checkpoint name
    # (which the exists() check above would then accept as complete).
    partial_path = weight_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, weight_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
|
|
|
|
|
|
|
|
def load_history(path=None):
    """Load the processing history from a JSON file.

    Args:
        path: File to read. Defaults to the module-level ``HISTORY_FILE``.

    Returns:
        The list stored in the file, or an empty list when the file does not
        exist, cannot be parsed, or does not contain a JSON list.
    """
    history_path = HISTORY_FILE if path is None else path
    try:
        with open(history_path, "r", encoding="utf-8") as f:
            history = json.load(f)
    except FileNotFoundError:
        # No history has been written yet (or it was just cleared).
        return []
    except ValueError as error:
        # Covers json.JSONDecodeError and UnicodeDecodeError: a damaged file
        # must not take down every handler that records or shows history.
        print(f"Ignoring unreadable history file {history_path}: {error}")
        return []

    if not isinstance(history, list):
        # Callers insert into and slice the result, so only a list is usable.
        print(f"Ignoring history file {history_path}: expected a JSON list")
        return []
    return history
|
|
|
|
|
|
|
|
def save_history(history, path=None):
    """Write the processing history to a JSON file.

    The data is serialised to a scratch file first and then renamed over the
    target, so a reader never sees a half-written file and a failed dump
    leaves the previous history untouched.

    Args:
        history: JSON-serialisable list of history records.
        path: File to write. Defaults to the module-level ``HISTORY_FILE``.

    Raises:
        TypeError: if ``history`` contains values ``json`` cannot serialise.
        OSError: if the file cannot be written or renamed.
    """
    history_path = HISTORY_FILE if path is None else path
    # Per-process / per-thread scratch name so concurrent writers do not
    # clobber each other's temporary file.
    scratch_path = f"{history_path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(scratch_path, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
        os.replace(scratch_path, history_path)
    finally:
        # Only still present when the dump or the rename failed.
        if os.path.exists(scratch_path):
            os.remove(scratch_path)
|
|
|
|
|
|
|
|
# Serialises the load -> insert -> save cycle below: it is reached both from
# the background video worker thread and from the UI request handlers.
_HISTORY_LOCK = threading.Lock()


def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Prepend one record to the persisted processing history.

    Args:
        input_path: Path of the file that was processed.
        output_path: Path of the produced result.
        model_name: Name of the model used (e.g. ``"4xGRL"``).
        process_type: Kind of job; callers in this file pass ``"image"`` or
            ``"video"``.
        status: Outcome label stored with the record.
    """
    record = {
        # Local wall-clock time, ISO 8601, without timezone information.
        "timestamp": datetime.datetime.now().isoformat(),
        "input_path": input_path,
        "output_path": output_path,
        "model_name": model_name,
        "process_type": process_type,
        "status": status,
    }
    # Without the lock two concurrent callers could both load the same list
    # and the later save would silently drop the other caller's record.
    with _HISTORY_LOCK:
        history = load_history()
        history.insert(0, record)  # newest first
        save_history(history)
|
|
|
|
|
|
|
|
def load_generator(model_name):
    """Build the generator network for ``model_name`` and place it on the CPU.

    Supported names are ``"4xGRL"``, ``"4xRRDB"``, ``"2xRRDB"`` and
    ``"4xDAT"``. The matching checkpoint is fetched first via
    ``auto_download_if_needed``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # model name -> (checkpoint path, deferred loader call). The lambdas keep
    # the loader lookup lazy, so nothing is touched for an unknown name.
    recipes = {
        "4xGRL": (
            "pretrained/4x_APISR_GRL_GAN_generator.pth",
            lambda checkpoint: load_grl(checkpoint, scale=4),
        ),
        "4xRRDB": (
            "pretrained/4x_APISR_RRDB_GAN_generator.pth",
            lambda checkpoint: load_rrdb(checkpoint, scale=4),
        ),
        "2xRRDB": (
            "pretrained/2x_APISR_RRDB_GAN_generator.pth",
            lambda checkpoint: load_rrdb(checkpoint, scale=2),
        ),
        "4xDAT": (
            "pretrained/4x_APISR_DAT_GAN_generator.pth",
            lambda checkpoint: load_dat(checkpoint, scale=4),
        ),
    }

    recipe = recipes.get(model_name)
    if recipe is None:
        raise ValueError(f"Model {model_name} not supported")

    checkpoint_path, build_network = recipe
    auto_download_if_needed(checkpoint_path)
    network = build_network(checkpoint_path)
    return network.to(device='cpu')
|
|
|
|
|
|
|
|
def inference_image(img_path, model_name):
    """Super-resolve one image and store the result under ``OUTPUT_DIR/images``.

    Args:
        img_path: Path of the uploaded image, or ``None`` when nothing was
            uploaded.
        model_name: Model to use; passed to ``load_generator``.

    Returns:
        A ``(image, message)`` tuple. On success ``image`` is the saved result
        re-read as an RGB array and ``message`` names the output file. On any
        failure ``image`` is ``None`` and ``message`` describes the problem.
    """
    if img_path is None:
        return None, "❌ Please upload an image first"

    try:
        generator = load_generator(model_name)

        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # NOTE(review): the UI text says inputs whose short side exceeds 720px
        # are downsampled; presumably that is what downsample_threshold
        # controls -- confirm in test_code.inference.
        super_resolved_img = super_resolve_img(
            generator, img_path, output_path=None,
            downsample_threshold=720, crop_for_4x=True
        )

        # Millisecond timestamp keeps output names unique enough per request.
        timestamp = int(time.time() * 1000)
        output_name = f"image_{timestamp}.png"
        output_path = os.path.join(OUTPUT_DIR, "images", output_name)
        save_image(super_resolved_img, output_path)

        # Read the saved file back to hand the UI an array; OpenCV decodes
        # as BGR, the image component is fed RGB.
        outputs = cv2.imread(output_path)
        if outputs is None:
            # imread signals failure by returning None instead of raising.
            raise RuntimeError(f"Could not read back saved image: {output_path}")
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)

        add_to_history(img_path, output_path, model_name, "image")

        return outputs, f"✅ Saved to: {output_path}"

    except Exception as error:
        # UI boundary: report the failure in the status box instead of
        # raising, but keep a trace of it in the server log.
        print(f"Image processing failed: {error}")
        return None, f"❌ Error: {str(error)}"
|
|
|
|
|
|
|
|
def process_video_frame_by_frame(video_path, model_name, task_id):
    """Super-resolve a video frame by frame and encode the result with ffmpeg.

    Progress and the final outcome are published in
    ``processing_status[task_id]``; nothing is returned and no exception
    escapes (failures are recorded as ``{"status": "error", ...}``).

    Args:
        video_path: Path of the input video.
        model_name: Model to use; passed to ``load_generator``.
        task_id: Key under which status updates are stored.
    """
    cap = None
    temp_dir = None
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}

        generator = load_generator(model_name)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        timestamp = int(time.time() * 1000)
        output_name = f"video_{timestamp}.mp4"
        output_path = os.path.join(OUTPUT_DIR, "videos", output_name)

        # Scratch directory for the extracted and the upscaled frames.
        temp_dir = f"temp_frames_{timestamp}"
        os.makedirs(temp_dir, exist_ok=True)

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # super_resolve_img takes a file path, so each frame goes to disk.
            temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            cv2.imwrite(temp_frame_path, frame)

            super_resolved_img = super_resolve_img(
                generator, temp_frame_path, output_path=None,
                downsample_threshold=720, crop_for_4x=True
            )

            output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
            save_image(super_resolved_img, output_frame_path)

            frame_count += 1
            # The reported frame count can be 0 or too small; avoid dividing
            # by zero and never report more than 100%.
            if total_frames > 0:
                progress = min(100, int((frame_count / total_frames) * 100))
            else:
                progress = 0
            processing_status[task_id] = {"status": "processing", "progress": progress}

            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")

        cap.release()
        cap = None

        if frame_count == 0:
            raise ValueError("No frames could be read from video file")

        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}

        # Argument list (no shell) and check=True: a missing or failing ffmpeg
        # now surfaces as an error instead of a "completed" task without an
        # output file. -y keeps ffmpeg from waiting on an overwrite prompt.
        subprocess.run(
            [
                "ffmpeg", "-y",
                "-framerate", str(fps),
                "-i", f"{temp_dir}/output_%06d.png",
                "-c:v", "libx264",
                "-pix_fmt", "yuv420p",
                output_path,
            ],
            check=True,
        )

        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")

        print(f"Task {task_id}: Completed! Output: {output_path}")

    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")

    finally:
        # Runs on success and on failure so neither the capture handle nor
        # the (potentially large) frame directory is leaked.
        if cap is not None:
            cap.release()
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
def video_queue_worker():
    """Consume ``video_queue`` forever, processing one video job at a time.

    Each queue item is a ``(task_id, video_path, model_name)`` tuple; a
    ``None`` item stops the loop. Every fetched item is acknowledged with
    ``task_done()`` whether or not processing succeeded.
    """
    print("Video queue worker started...")
    keep_running = True
    while keep_running:
        try:
            job = video_queue.get()
            if job is None:
                # Sentinel: leave the loop after acknowledging it below.
                keep_running = False
            else:
                job_id, source_path, chosen_model = job
                print(f"Starting task {job_id}...")
                process_video_frame_by_frame(source_path, chosen_model, job_id)
        except Exception as exc:
            # A bad job must not kill the worker thread.
            print(f"Worker error: {exc}")
        finally:
            video_queue.task_done()
|
|
|
|
|
|
|
|
def submit_video(video_path, model_name):
    """Queue a video for background processing.

    Args:
        video_path: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        model_name: Model to use for the job.

    Returns:
        A ``(None, message)`` tuple: the first element is the value sent back
        to the video input component, the second is the text for the
        submission status box.
    """
    if video_path is None:
        return None, "❌ Please upload a video first"

    task_id = f"task_{int(time.time() * 1000)}"

    # Publish the "queued" state BEFORE enqueuing. The worker may pick the job
    # up immediately; writing "queued" afterwards could overwrite its
    # "processing" update.
    processing_status[task_id] = {"status": "queued", "progress": 0}
    video_queue.put((task_id, video_path, model_name))

    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."
|
|
|
|
|
|
|
|
def get_queue_status():
    """Return a text summary of the queue length and every known task.

    Returns:
        A string with the number of queued videos followed by one entry per
        task in ``processing_status`` (status, progress and, when present,
        output path or error message), or "No active tasks" when there are
        none.
    """
    parts = [
        "📊 **Queue Status**\n\n",
        f"Videos in queue: {video_queue.qsize()}\n\n",
    ]

    # Snapshot first: other threads add tasks while this runs, and iterating
    # the live dict could raise "dictionary changed size during iteration".
    tasks = list(processing_status.items())

    if tasks:
        parts.append("**Active Tasks:**\n")
        for task_id, status in tasks:
            parts.append(f"\n🎬 {task_id}:\n")
            parts.append(f" Status: {status['status']}\n")
            parts.append(f" Progress: {status.get('progress', 0)}%\n")
            if 'output' in status:
                parts.append(f" Output: {status['output']}\n")
            if 'error' in status:
                parts.append(f" Error: {status['error']}\n")
    else:
        parts.append("No active tasks")

    return "".join(parts)
|
|
|
|
|
|
|
|
def get_history_display():
    """Format the most recent history records for display.

    Returns:
        "No history available" when there is no history; otherwise a text
        block listing up to the first 50 records returned by
        ``load_history()`` with type, timestamp, model, status and output
        path.
    """
    history = load_history()
    if not history:
        return "No history available"

    parts = ["📜 **Processing History**\n\n"]
    for position, record in enumerate(history[:50], start=1):
        # .get(): the file is plain JSON on disk, so a record with a missing
        # field should still be shown rather than break the whole listing.
        process_type = str(record.get('process_type', 'unknown')).upper()
        parts.append(f"**{position}. {process_type}** - {record.get('timestamp', 'unknown')}\n")
        parts.append(f" Model: {record.get('model_name', 'unknown')}\n")
        parts.append(f" Status: {record.get('status', 'unknown')}\n")
        parts.append(f" Output: {record.get('output_path', 'unknown')}\n\n")

    return "".join(parts)
|
|
|
|
|
|
|
|
def clear_history():
    """Delete the history file and return the refreshed display.

    Returns:
        A ``(message, history_text)`` tuple: a confirmation message and the
        output of ``get_history_display()`` after the deletion.
    """
    try:
        os.remove(HISTORY_FILE)
    except FileNotFoundError:
        # Already absent (never written, or removed by a concurrent request).
        pass
    return "✅ History cleared!", get_history_display()
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Background consumer for submitted videos. daemon=True: the thread does
    # not keep the process alive once the main thread exits.
    worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
    worker_thread.start()

    # Intro text shown at the top of the page.
    MARKDOWN = (
        "\n"
        "# APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)\n"
        "\n"
        "[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)\n"
        "\n"
        "APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.\n"
        "\n"
        "### ⚠️ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 → 1280x720)\n"
        "### 📹 New: Video processing runs in background queue - you can close the browser and it continues!\n"
    )

    # Shared by the image and the video model selectors.
    MODEL_CHOICES = ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"]

    with gr.Blocks(title="APISR - Anime Super Resolution") as demo:

        gr.Markdown(MARKDOWN)

        with gr.Tabs():

            # ---- Single-image processing (runs synchronously) ----
            with gr.Tab("🖼️ Image Processing"):
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(type="filepath", label="Input Image")
                        image_model = gr.Dropdown(
                            choices=list(MODEL_CHOICES),
                            value="4xGRL",
                            label="Model"
                        )
                        image_btn = gr.Button("🚀 Process Image", variant="primary")

                    with gr.Column(scale=3):
                        output_image = gr.Image(type="numpy", label="Output Image")
                        image_status = gr.Textbox(label="Status", lines=2)

                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["__assets__/lr_inputs/image-00277.png"],
                            ["__assets__/lr_inputs/image-00542.png"],
                            ["__assets__/lr_inputs/41.png"],
                            ["__assets__/lr_inputs/f91.jpg"],
                        ],
                        inputs=[input_image],
                    )

                image_btn.click(
                    fn=inference_image,
                    inputs=[input_image, image_model],
                    outputs=[output_image, image_status]
                )

            # ---- Video processing (queued, handled by the worker thread) ----
            with gr.Tab("🎬 Video Processing"):
                gr.Markdown(
                    "\n"
                    "### Video Processing Queue\n"
                    "Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!\n"
                )

                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Input Video")
                        video_model = gr.Dropdown(
                            choices=list(MODEL_CHOICES),
                            value="4xGRL",
                            label="Model"
                        )
                        video_btn = gr.Button("📤 Submit to Queue", variant="primary")
                        video_status = gr.Textbox(label="Submission Status", lines=3)

                    with gr.Column():
                        gr.Markdown("### 📊 Queue Monitor")
                        queue_status = gr.Textbox(label="Queue Status", lines=15, interactive=False)
                        refresh_btn = gr.Button("🔄 Refresh Status")

                # submit_video returns None for the video input, which resets
                # the upload widget after a submission.
                video_btn.click(
                    fn=submit_video,
                    inputs=[input_video, video_model],
                    outputs=[input_video, video_status]
                )

                refresh_btn.click(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

            # ---- Persisted history of finished jobs ----
            with gr.Tab("📜 History"):
                gr.Markdown("### Processing History")

                with gr.Row():
                    refresh_history_btn = gr.Button("🔄 Refresh History")
                    clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")

                history_display = gr.Textbox(label="History", lines=20, interactive=False)
                clear_status = gr.Textbox(label="Status", lines=1, visible=True)

                refresh_history_btn.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )

                clear_history_btn.click(
                    fn=clear_history,
                    outputs=[clear_status, history_display]
                )

    demo.queue(max_size=20)
    # NOTE(review): whether launch() accepts `theme` and `show_api` depends on
    # the installed Gradio release -- confirm against the pinned version.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
        theme=gr.themes.Soft()
    )