ArtistEmbeddingClassifier / app /view_extractor.py
iljung1106
Add list of artists and make eyes only capture one eye with square.
38ad444
from __future__ import annotations
import itertools
import sys
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import torch
def _patch_torch_load_for_old_ckpt() -> None:
"""
Matches `anime_face_eye_extract._patch_torch_load_for_old_ckpt()` to load older YOLOv5 checkpoints
on newer torch versions.
"""
import numpy as _np
try:
torch.serialization.add_safe_globals([_np.core.multiarray._reconstruct, _np.ndarray])
except Exception:
pass
_orig_load = torch.load
def _patched_load(*args, **kwargs): # noqa: ANN001
kwargs.setdefault("weights_only", False)
return _orig_load(*args, **kwargs)
torch.load = _patched_load
def _pre(gray: np.ndarray) -> np.ndarray:
    """
    Preprocess a grayscale image for cascade detection.

    Applies a 3x3 Gaussian blur, then CLAHE (clip limit 2.0, 8x8 tiles), and returns
    the equalized image.
    """
    import cv2

    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return equalizer.apply(blurred)
def _expand(box, margin: float, W: int, H: int):
x1, y1, x2, y2 = box
cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
w = (x2 - x1) * (1 + margin)
h = (y2 - y1) * (1 + margin)
nx1 = int(round(cx - w / 2))
ny1 = int(round(cy - h / 2))
nx2 = int(round(cx + w / 2))
ny2 = int(round(cy + h / 2))
nx1 = max(0, min(W, nx1))
ny1 = max(0, min(H, ny1))
nx2 = max(0, min(W, nx2))
ny2 = max(0, min(H, ny2))
return nx1, ny1, nx2, ny2
def _shrink(img: np.ndarray, limit: int):
import cv2
h, w = img.shape[:2]
m = max(h, w)
if m <= limit:
return img, 1.0
s = limit / float(m)
nh, nw = int(h * s), int(w * s)
small = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
return small, s
def _pad_to_square_rgb(img: np.ndarray) -> np.ndarray:
"""
Pad an RGB crop to a square (1:1) using edge-padding.
This guarantees 1:1 aspect ratio without stretching content.
"""
if img is None or img.size == 0:
return img
h, w = img.shape[:2]
if h == w:
return img
s = max(h, w)
pad_y = s - h
pad_x = s - w
top = pad_y // 2
bottom = pad_y - top
left = pad_x // 2
right = pad_x - left
return np.pad(img, ((top, bottom), (left, right), (0, 0)), mode="edge")
def _square_box_from_rect(rect, *, scale: float, W: int, H: int):
"""
Convert a rectangle (x1,y1,x2,y2) into a square box centered on the rect,
scaled by `scale`, clamped to image bounds.
"""
x1, y1, x2, y2 = [int(v) for v in rect]
cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
bw = max(1.0, float(x2 - x1))
bh = max(1.0, float(y2 - y1))
side = max(bw, bh) * float(scale)
nx1 = int(round(cx - side / 2.0))
ny1 = int(round(cy - side / 2.0))
nx2 = int(round(cx + side / 2.0))
ny2 = int(round(cy + side / 2.0))
nx1 = max(0, min(W, nx1))
ny1 = max(0, min(H, ny1))
nx2 = max(0, min(W, nx2))
ny2 = max(0, min(H, ny2))
if nx2 <= nx1 or ny2 <= ny1:
return None
return nx1, ny1, nx2, ny2
def _split_box_by_midline(box, mid_x: int):
"""
If a box crosses the vertical midline, split into left/right boxes.
Returns list of (tag, box).
"""
x1, y1, x2, y2 = [int(v) for v in box]
if x1 < mid_x < x2:
left = (x1, y1, mid_x, y2)
right = (mid_x, y1, x2, y2)
out = []
if left[2] > left[0]:
out.append(("left", left))
if right[2] > right[0]:
out.append(("right", right))
return out
tag = "left" if (x1 + x2) / 2.0 <= mid_x else "right"
return [(tag, (x1, y1, x2, y2))]
def _best_pair(boxes, W: int, H: int):
clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
if len(clean) < 2:
return []
def cxcy(b):
x1, y1, x2, y2 = b
return (x1 + x2) / 2.0, (y1 + y2) / 2.0
def area(b):
x1, y1, x2, y2 = b
return max(1, (x2 - x1) * (y2 - y1))
best = None
best_s = 1e9
for b1, b2 in itertools.combinations(clean, 2):
c1x, c1y = cxcy(b1)
c2x, c2y = cxcy(b2)
a1, a2 = area(b1), area(b2)
horiz = 0.0 if c1x < c2x else 0.5
y_aln = abs(c1y - c2y) / max(1.0, H)
szsim = abs(a1 - a2) / float(max(a1, a2))
gap = abs(c2x - c1x) / max(1.0, W)
if 0.05 <= gap <= 0.5:
gap_pen = 0.0
else:
gap_pen = 0.5 * ((0.5 + abs(gap - 0.05) * 10) if gap < 0.05 else (gap - 0.5) * 2.0)
mean_y = (c1y + c2y) / 2.0 / max(1.0, H)
upper = 0.3 * max(0.0, (mean_y - 0.67) * 2.0)
s = y_aln + szsim + gap_pen + upper + horiz
if s < best_s:
best_s = s
best = (b1, b2)
if best is None:
return []
b1, b2 = best
left, right = (b1, b2) if (b1[0] + b1[2]) <= (b2[0] + b2[2]) else (b2, b1)
return [("left", left), ("right", right)]
@dataclass
class ExtractorCfg:
    """
    Settings for `AnimeFaceEyeExtractor`: YOLOv5 face detection followed by
    Haar-cascade eye detection inside the detected face.
    """

    # Directory inserted into `sys.path` so that `models.experimental` / `utils.*` import.
    yolo_dir: Path
    # Checkpoint path handed to `attempt_load`.
    weights: Path
    # Cascade file loaded with `cv2.CascadeClassifier`.
    cascade: Path
    # Letterbox / inference size; rounded up to a multiple of the model stride (and
    # written back to this field) the first time the detector is initialised.
    imgsz: int = 640
    # `conf_thres` passed to `non_max_suppression`.
    conf: float = 0.5
    # `iou_thres` passed to `non_max_suppression`.
    iou: float = 0.5
    # Passed to `select_device`.
    yolo_device: str = "cpu" # "cpu" or "0"
    # Fraction of the face height, measured from the top, searched for eyes first.
    eye_roi_frac: float = 0.70
    # Lower bound (pixels) for the cascade `minSize` and for the shorter side of the
    # final eye crop.
    eye_min_size: int = 12
    # The square eye crop's side is the longer eye-box side times (1 + eye_margin).
    eye_margin: float = 0.60
    # `minNeighbors` passed to `detectMultiScale`.
    neighbors: int = 9
    # Longest-side limit applied to the upper-face ROI before eye detection.
    eye_downscale_limit_roi: int = 512
    # Longest-side limit applied to the whole face before fallback eye detection.
    eye_downscale_limit_face: int = 768
    # When no eye pair is found in the ROI, run eye detection on the whole face as well.
    eye_fallback_to_face: bool = True
class AnimeFaceEyeExtractor:
    """
    Single-image view extractor (whole -> face crop, eyes crop) based on `anime_face_eye_extract.py`.
    Designed for use in the Gradio UI: caches YOLO model + Haar cascade.

    The YOLO model is loaded lazily on the first detection and kept on the instance;
    the Haar cascade is created once per thread and kept in a `threading.local`.
    """

    def __init__(self, cfg: "ExtractorCfg"):
        self.cfg = cfg
        self._model = None
        self._device = None
        self._stride = 32
        self._tl = threading.local()

    def _init_detector(self) -> None:
        """
        Load the YOLOv5 face detector once.

        Adds `cfg.yolo_dir` to `sys.path`, patches `torch.load` for old checkpoints,
        loads the weights, and rounds `cfg.imgsz` up to a multiple of the model stride.

        Raises:
            RuntimeError: if `cfg.yolo_dir` does not exist.
        """
        if self._model is not None:
            return
        ydir = self.cfg.yolo_dir.resolve()
        if not ydir.exists():
            raise RuntimeError(f"yolov5_anime dir not found: {ydir}")
        if str(ydir) not in sys.path:
            sys.path.insert(0, str(ydir))
        _patch_torch_load_for_old_ckpt()
        from models.experimental import attempt_load
        from utils.torch_utils import select_device

        device = select_device(self.cfg.yolo_device)
        model = attempt_load(str(self.cfg.weights), map_location=device)
        model.eval()
        stride = int(model.stride.max())
        self.cfg.imgsz = int(np.ceil(int(self.cfg.imgsz) / stride) * stride)

        self._device = device
        self._stride = stride
        # Assigned last: `_model is not None` is the "ready" flag checked above, so the
        # other attributes must already be in place when it becomes visible.
        self._model = model

    def _letterbox_compat(self, img0, new_shape, stride):
        """
        Call the YOLOv5 `letterbox` helper across its differing signatures.

        Tries `(stride=..., auto=False)`, then `(auto=False)`, then positional-only,
        and returns the first element of the result (the letterboxed image).
        """
        from utils.datasets import letterbox

        try:
            lb = letterbox(img0, new_shape, stride=stride, auto=False)
        except TypeError:
            try:
                lb = letterbox(img0, new_shape, auto=False)
            except TypeError:
                lb = letterbox(img0, new_shape)
        return lb[0]

    def _detect_faces(self, rgb: np.ndarray):
        """
        Run the YOLO face detector on an RGB image.

        Returns:
            list of (x1, y1, x2, y2) integer boxes in the coordinates of `rgb`
            (empty when nothing is detected).
        """
        import cv2

        self._init_detector()
        from utils.general import non_max_suppression, scale_coords

        img0 = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        h0, w0, _ = img0.shape
        img = self._letterbox_compat(img0, self.cfg.imgsz, self._stride)
        # BGR HWC -> RGB CHW
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        im = torch.from_numpy(img).to(self._device)
        im = im.float() / 255.0
        if im.ndim == 3:
            im = im[None]
        with torch.no_grad():
            pred = self._model(im)[0]
        pred = non_max_suppression(
            pred, conf_thres=self.cfg.conf, iou_thres=self.cfg.iou, classes=None, agnostic=False
        )
        det = pred[0]
        if det is None or not len(det):
            return []
        # Map boxes back using the shape that was actually fed to the model. The last
        # letterbox fallback may not produce an exact imgsz x imgsz image, so the
        # configured size cannot be assumed here.
        det[:, :4] = scale_coords(im.shape[2:], det[:, :4], (h0, w0)).round()
        return [tuple(int(v) for v in row[:4]) for row in det.tolist()]

    def _get_cascade(self):
        """
        Return this thread's Haar cascade, loading it on first use.

        Raises:
            RuntimeError: if the cascade file cannot be loaded.
        """
        import cv2

        c = getattr(self._tl, "cascade", None)
        if c is None:
            c = cv2.CascadeClassifier(str(self.cfg.cascade))
            if c.empty():
                raise RuntimeError(f"cascade load fail: {self.cfg.cascade}")
            self._tl.cascade = c
        return c

    def _detect_eyes_in_roi(self, rgb_roi: np.ndarray):
        """
        Run the eye cascade on an RGB region.

        The minimum detection size is the largest of 8, `cfg.eye_min_size`, and 7% of
        the region's shorter side.

        Returns:
            list of (x1, y1, x2, y2) integer boxes in the coordinates of `rgb_roi`.
        """
        import cv2

        gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
        proc = _pre(gray)
        H, W = proc.shape[:2]
        min_side = max(1, min(W, H))
        dyn_min = int(0.07 * min_side)
        min_sz = max(8, int(self.cfg.eye_min_size), dyn_min)
        cascade = self._get_cascade()
        raw = cascade.detectMultiScale(
            proc,
            scaleFactor=1.15,
            minNeighbors=int(self.cfg.neighbors),
            minSize=(min_sz, min_sz),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        # A tuple result is either empty (no detections) or carries the rectangles as
        # its first element.
        if isinstance(raw, tuple):
            raw = raw[0] if raw else ()
        arr = np.asarray(raw)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = []
        for r in arr:
            x, y, w, h = [int(v) for v in r[:4]]
            if w <= 0 or h <= 0:
                continue
            boxes.append((x, y, x + w, y + h))
        return boxes

    @staticmethod
    def _pick_best_face(boxes):
        """Return the largest-area (x1, y1, x2, y2) box, or None when `boxes` is empty."""
        if not boxes:
            return None

        def area(b):
            x1, y1, x2, y2 = b
            return max(1, (x2 - x1) * (y2 - y1))

        return max(boxes, key=area)

    @staticmethod
    def _rescale_boxes(boxes, s: float):
        """Map boxes found on an image downscaled by factor `s` back to full resolution."""
        return [(int(a / s), int(b / s), int(c / s), int(d / s)) for (a, b, c, d) in boxes]

    def _find_eye_boxes(self, face: np.ndarray, roi: np.ndarray):
        """
        Locate eye boxes inside a face crop.

        Looks for an eye pair in `roi` (the upper part of `face`) first. If none is
        found and `cfg.eye_fallback_to_face` is set, the whole face is searched. When no
        pair can be formed, whatever single detections exist are used.

        Returns:
            (labelled, origin): `labelled` is a list of (tag, box); `origin` is "roi" or
            "face" and names the image the boxes belong to, or None when nothing was found.
        """
        H, W = face.shape[:2]

        roi_small, s_roi = _shrink(roi, int(self.cfg.eye_downscale_limit_roi))
        eyes_roi = self._rescale_boxes(self._detect_eyes_in_roi(roi_small), s_roi)
        labs = _best_pair(eyes_roi, W, roi.shape[0])
        if labs:
            return labs, "roi"

        eyes_full = []
        if self.cfg.eye_fallback_to_face:
            # The downscaled face is only needed on this fallback path.
            face_small, s_face = _shrink(face, int(self.cfg.eye_downscale_limit_face))
            eyes_full = self._rescale_boxes(self._detect_eyes_in_roi(face_small), s_face)
            if len(eyes_full) >= 2:
                labs = _best_pair(eyes_full, W, H)
                if labs:
                    return labs, "face"

        # No pair: prefer whole-face detections when there are any, else ROI detections.
        cand, cand_origin = (eyes_full, "face") if eyes_full else (eyes_roi, "roi")
        if len(cand) >= 2:
            top2 = sorted(cand, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)[:2]
            top2 = sorted(top2, key=lambda b: (b[0] + b[2]))
            return [("left", top2[0]), ("right", top2[1])], cand_origin
        if len(cand) == 1:
            return [("left", cand[0])], cand_origin
        return [], None

    def _crop_single_eye(self, src_img: np.ndarray, labs):
        """
        Cut one square eye crop out of `src_img`.

        Boxes crossing the vertical midline are split. The largest box tagged "left"
        (center on or left of the midline) is chosen, falling back to the largest box
        overall. The crop is a square around that box scaled by (1 + cfg.eye_margin),
        clamped to the image and edge-padded back to 1:1.

        Returns:
            RGB crop, or None when the crop is empty or smaller than `cfg.eye_min_size`.
        """
        bound_h, W = src_img.shape[:2]
        mid_x = int(round(W / 2.0))

        # Build candidate eye boxes; split any box that crosses the midline
        candidates = []
        for _tag, b in labs:
            candidates.extend(_split_box_by_midline(b, mid_x))

        # Deterministically choose the LEFT eye if present; otherwise fall back to largest
        left_boxes = [b for (t, b) in candidates if t == "left"]
        pick_from = left_boxes if left_boxes else [b for (_, b) in candidates]
        chosen = max(pick_from, key=lambda bb: max(1, (bb[2] - bb[0]) * (bb[3] - bb[1])))

        # Square crop around the chosen eye (no stretching); pad to square to guarantee 1:1.
        scale = 1.0 + float(self.cfg.eye_margin)
        sq = _square_box_from_rect(chosen, scale=scale, W=W, H=bound_h)
        if sq is None:
            return None
        ex1, ey1, ex2, ey2 = sq
        crop = src_img[ey1:ey2, ex1:ex2]
        if crop.size > 0 and min(crop.shape[0], crop.shape[1]) >= int(self.cfg.eye_min_size):
            return _pad_to_square_rgb(crop.copy())
        return None

    def extract(self, whole_rgb: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Args:
            whole_rgb: HWC RGB uint8
        Returns:
            (face_rgb, eye_rgb) as RGB uint8 crops (or None if not found).
            `face_rgb` is the largest detected face; `eye_rgb` is a single square eye crop.
        """
        boxes = self._detect_faces(whole_rgb)
        face_box = self._pick_best_face(boxes)
        if face_box is None:
            return None, None

        x1, y1, x2, y2 = face_box
        H0, W0 = whole_rgb.shape[:2]
        x1 = max(0, min(W0, x1))
        x2 = max(0, min(W0, x2))
        y1 = max(0, min(H0, y1))
        y2 = max(0, min(H0, y2))
        if x2 <= x1 or y2 <= y1:
            return None, None
        face = whole_rgb[y1:y2, x1:x2].copy()

        # eye detection on upper ROI (always at least one row high)
        roi_h = int(face.shape[0] * float(self.cfg.eye_roi_frac))
        roi = face[0: max(1, roi_h), :]

        labs, origin = self._find_eye_boxes(face, roi)
        eye_crop = None
        if labs:
            # Boxes found in the ROI are cropped from (and clamped to) the ROI.
            src_img = roi if origin == "roi" else face
            eye_crop = self._crop_single_eye(src_img, labs)
        return face, eye_crop