from __future__ import annotations

import itertools
import sys
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch

def _patch_torch_load_for_old_ckpt() -> None:
    """
    Matches `anime_face_eye_extract._patch_torch_load_for_old_ckpt()` to load older
    YOLOv5 checkpoints on newer torch versions.
    """
    import numpy as _np

    # Allowlist the numpy globals that old pickled checkpoints reference.
    try:
        torch.serialization.add_safe_globals([_np.core.multiarray._reconstruct, _np.ndarray])
    except Exception:
        pass

    _orig_load = torch.load

    def _patched_load(*args, **kwargs):  # noqa: ANN001
        # Old checkpoints are full pickles; default to weights_only=False.
        kwargs.setdefault("weights_only", False)
        return _orig_load(*args, **kwargs)

    torch.load = _patched_load

def _pre(gray: np.ndarray) -> np.ndarray:
    """Denoise and boost local contrast before Haar detection."""
    import cv2

    gray = cv2.GaussianBlur(gray, (3, 3), 0)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(gray)

def _expand(box, margin: float, W: int, H: int):
    """Grow `box` by `margin` around its center, clamped to the image bounds."""
    x1, y1, x2, y2 = box
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    w = (x2 - x1) * (1 + margin)
    h = (y2 - y1) * (1 + margin)
    nx1 = int(round(cx - w / 2))
    ny1 = int(round(cy - h / 2))
    nx2 = int(round(cx + w / 2))
    ny2 = int(round(cy + h / 2))
    nx1 = max(0, min(W, nx1))
    ny1 = max(0, min(H, ny1))
    nx2 = max(0, min(W, nx2))
    ny2 = max(0, min(H, ny2))
    return nx1, ny1, nx2, ny2

def _shrink(img: np.ndarray, limit: int):
    """Downscale so the longest side is at most `limit`; returns (image, scale)."""
    import cv2

    h, w = img.shape[:2]
    m = max(h, w)
    if m <= limit:
        return img, 1.0
    s = limit / float(m)
    nh, nw = int(h * s), int(w * s)
    small = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    return small, s

def _pad_to_square_rgb(img: np.ndarray) -> np.ndarray:
    """
    Pad an RGB crop to a square (1:1) using edge-padding.
    This guarantees 1:1 aspect ratio without stretching content.
    """
    if img is None or img.size == 0:
        return img
    h, w = img.shape[:2]
    if h == w:
        return img
    s = max(h, w)
    pad_y = s - h
    pad_x = s - w
    top = pad_y // 2
    bottom = pad_y - top
    left = pad_x // 2
    right = pad_x - left
    return np.pad(img, ((top, bottom), (left, right), (0, 0)), mode="edge")
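
# A minimal sketch of the padding contract (assumed shape, illustration only; this
# helper is not called anywhere): a 60x100 RGB crop gains 20 edge-replicated rows
# on the top and bottom, giving a 100x100 square.
def _example_pad_to_square() -> None:
    crop = np.zeros((60, 100, 3), dtype=np.uint8)
    assert _pad_to_square_rgb(crop).shape == (100, 100, 3)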

def _square_box_from_rect(rect, *, scale: float, W: int, H: int):
    """
    Convert a rectangle (x1,y1,x2,y2) into a square box centered on the rect,
    scaled by `scale`, clamped to image bounds.
    """
    x1, y1, x2, y2 = [int(v) for v in rect]
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    bw = max(1.0, float(x2 - x1))
    bh = max(1.0, float(y2 - y1))
    side = max(bw, bh) * float(scale)
    nx1 = int(round(cx - side / 2.0))
    ny1 = int(round(cy - side / 2.0))
    nx2 = int(round(cx + side / 2.0))
    ny2 = int(round(cy + side / 2.0))
    nx1 = max(0, min(W, nx1))
    ny1 = max(0, min(H, ny1))
    nx2 = max(0, min(W, nx2))
    ny2 = max(0, min(H, ny2))
    if nx2 <= nx1 or ny2 <= ny1:
        return None
    return nx1, ny1, nx2, ny2
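
# A minimal sketch of the square-box math (assumed numbers, illustration only):
# a 20x10 rect centered at (20, 15) with scale=1.6 yields a 32 px square,
# clamped at the top edge of a 100x100 image.
def _example_square_box() -> None:
    assert _square_box_from_rect((10, 10, 30, 20), scale=1.6, W=100, H=100) == (4, 0, 36, 31)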

def _split_box_by_midline(box, mid_x: int):
    """
    If a box crosses the vertical midline, split into left/right boxes.
    Returns list of (tag, box).
    """
    x1, y1, x2, y2 = [int(v) for v in box]
    if x1 < mid_x < x2:
        left = (x1, y1, mid_x, y2)
        right = (mid_x, y1, x2, y2)
        out = []
        if left[2] > left[0]:
            out.append(("left", left))
        if right[2] > right[0]:
            out.append(("right", right))
        return out
    tag = "left" if (x1 + x2) / 2.0 <= mid_x else "right"
    return [(tag, (x1, y1, x2, y2))]
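
# A minimal sketch of midline splitting (assumed coordinates, illustration only):
def _example_split_box() -> None:
    # A box straddling x=20 is split into tagged left/right halves.
    assert _split_box_by_midline((10, 5, 30, 15), 20) == [
        ("left", (10, 5, 20, 15)),
        ("right", (20, 5, 30, 15)),
    ]
    # A box entirely right of the midline keeps a single "right" tag.
    assert _split_box_by_midline((25, 5, 35, 15), 20) == [("right", (25, 5, 35, 15))]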

def _best_pair(boxes, W: int, H: int):
    """Pick the most plausible (left, right) eye pair from candidate boxes."""
    clean = [(int(b[0]), int(b[1]), int(b[2]), int(b[3])) for b in boxes]
    if len(clean) < 2:
        return []

    def cxcy(b):
        x1, y1, x2, y2 = b
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    def area(b):
        x1, y1, x2, y2 = b
        return max(1, (x2 - x1) * (y2 - y1))

    best = None
    best_s = 1e9
    for b1, b2 in itertools.combinations(clean, 2):
        c1x, c1y = cxcy(b1)
        c2x, c2y = cxcy(b2)
        a1, a2 = area(b1), area(b2)
        horiz = 0.0 if c1x < c2x else 0.5          # prefer left-to-right ordering
        y_aln = abs(c1y - c2y) / max(1.0, H)       # eyes should be vertically aligned
        szsim = abs(a1 - a2) / float(max(a1, a2))  # ...and similarly sized
        gap = abs(c2x - c1x) / max(1.0, W)
        # Penalize horizontal gaps outside the plausible 5%-50% range of image width.
        if 0.05 <= gap <= 0.5:
            gap_pen = 0.0
        else:
            gap_pen = 0.5 * ((0.5 + abs(gap - 0.05) * 10) if gap < 0.05 else (gap - 0.5) * 2.0)
        mean_y = (c1y + c2y) / 2.0 / max(1.0, H)
        upper = 0.3 * max(0.0, (mean_y - 0.67) * 2.0)  # eyes live in the upper part of the face
        s = y_aln + szsim + gap_pen + upper + horiz    # lower score = better pair
        if s < best_s:
            best_s = s
            best = (b1, b2)
    if best is None:
        return []
    b1, b2 = best
    left, right = (b1, b2) if (b1[0] + b1[2]) <= (b2[0] + b2[2]) else (b2, b1)
    return [("left", left), ("right", right)]

@dataclass
class ExtractorCfg:
    yolo_dir: Path
    weights: Path
    cascade: Path
    imgsz: int = 640
    conf: float = 0.5
    iou: float = 0.5
    yolo_device: str = "cpu"  # "cpu" or "0"
    eye_roi_frac: float = 0.70
    eye_min_size: int = 12
    eye_margin: float = 0.60
    neighbors: int = 9
    eye_downscale_limit_roi: int = 512
    eye_downscale_limit_face: int = 768
    eye_fallback_to_face: bool = True

class AnimeFaceEyeExtractor:
    """
    Single-image view extractor (whole -> face crop, eyes crop) based on `anime_face_eye_extract.py`.
    Designed for use in the Gradio UI: caches YOLO model + Haar cascade.
    """

    def __init__(self, cfg: ExtractorCfg):
        self.cfg = cfg
        self._model = None
        self._device = None
        self._stride = 32
        self._tl = threading.local()

    def _init_detector(self) -> None:
        if self._model is not None:
            return
        ydir = self.cfg.yolo_dir.resolve()
        if not ydir.exists():
            raise RuntimeError(f"yolov5_anime dir not found: {ydir}")
        if str(ydir) not in sys.path:
            sys.path.insert(0, str(ydir))
        _patch_torch_load_for_old_ckpt()
        from models.experimental import attempt_load
        from utils.torch_utils import select_device

        self._device = select_device(self.cfg.yolo_device)
        self._model = attempt_load(str(self.cfg.weights), map_location=self._device)
        self._model.eval()
        self._stride = int(self._model.stride.max())
        # Round the inference size up to a multiple of the model stride.
        s = int(self.cfg.imgsz)
        s = int(np.ceil(s / self._stride) * self._stride)
        self.cfg.imgsz = s

    def _letterbox_compat(self, img0, new_shape, stride):
        """Call `letterbox` across YOLOv5 versions whose signatures differ."""
        from utils.datasets import letterbox

        try:
            lb = letterbox(img0, new_shape, stride=stride, auto=False)
        except TypeError:
            try:
                lb = letterbox(img0, new_shape, auto=False)
            except TypeError:
                lb = letterbox(img0, new_shape)
        return lb[0]

    def _detect_faces(self, rgb: np.ndarray):
        import cv2

        self._init_detector()
        from utils.general import non_max_suppression, scale_coords

        img0 = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        h0, w0, _ = img0.shape
        img = self._letterbox_compat(img0, self.cfg.imgsz, self._stride)
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR HWC -> RGB CHW
        img = np.ascontiguousarray(img)
        im = torch.from_numpy(img).to(self._device)
        im = im.float() / 255.0
        if im.ndim == 3:
            im = im[None]  # add batch dimension
        with torch.no_grad():
            pred = self._model(im)[0]
        pred = non_max_suppression(pred, conf_thres=self.cfg.conf, iou_thres=self.cfg.iou, classes=None, agnostic=False)
        boxes = []
        det = pred[0]
        if det is not None and len(det):
            # Map boxes from the letterboxed frame back to the original image.
            det[:, :4] = scale_coords((self.cfg.imgsz, self.cfg.imgsz), det[:, :4], (h0, w0)).round()
            for *xyxy, conf, cls in det.tolist():
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                boxes.append((x1, y1, x2, y2))
        return boxes

    def _get_cascade(self):
        import cv2

        # Cascade classifiers are not thread-safe; keep one instance per thread.
        c = getattr(self._tl, "cascade", None)
        if c is None:
            c = cv2.CascadeClassifier(str(self.cfg.cascade))
            if c.empty():
                raise RuntimeError(f"cascade load fail: {self.cfg.cascade}")
            self._tl.cascade = c
        return c

    def _detect_eyes_in_roi(self, rgb_roi: np.ndarray):
        import cv2

        gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
        proc = _pre(gray)
        H, W = proc.shape[:2]
        min_side = max(1, min(W, H))
        dyn_min = int(0.07 * min_side)  # scale the minimum eye size with the ROI
        min_sz = max(8, int(self.cfg.eye_min_size), dyn_min)
        cascade = self._get_cascade()
        raw = cascade.detectMultiScale(
            proc,
            scaleFactor=1.15,
            minNeighbors=int(self.cfg.neighbors),
            minSize=(min_sz, min_sz),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        # detectMultiScale may return an empty tuple, an (N, 4) array, or a 1-D array.
        try:
            arr = np.asarray(raw if not isinstance(raw, tuple) else raw[0])
        except Exception:
            arr = np.empty((0, 4), dtype=int)
        if arr.size == 0:
            return []
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        boxes = []
        for r in arr:
            x, y, w, h = [int(v) for v in r[:4]]
            if w <= 0 or h <= 0:
                continue
            boxes.append((x, y, x + w, y + h))  # (x, y, w, h) -> (x1, y1, x2, y2)
        return boxes

    @staticmethod
    def _pick_best_face(boxes):
        if not boxes:
            return None

        # choose largest-area face
        def area(b):
            x1, y1, x2, y2 = b
            return max(1, (x2 - x1) * (y2 - y1))

        return max(boxes, key=area)

    def extract(self, whole_rgb: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Args:
            whole_rgb: HWC RGB uint8
        Returns:
            (face_rgb, eye_rgb) as RGB uint8 crops (or None if not found)
        """
        boxes = self._detect_faces(whole_rgb)
        face_box = self._pick_best_face(boxes)
        if face_box is None:
            return None, None
        x1, y1, x2, y2 = face_box
        H0, W0 = whole_rgb.shape[:2]
        x1 = max(0, min(W0, x1))
        x2 = max(0, min(W0, x2))
        y1 = max(0, min(H0, y1))
        y2 = max(0, min(H0, y2))
        if x2 <= x1 or y2 <= y1:
            return None, None
        face = whole_rgb[y1:y2, x1:x2].copy()

        # eye detection on upper ROI
        H, W = face.shape[:2]
        roi_h = int(H * float(self.cfg.eye_roi_frac))
        roi = face[0: max(1, roi_h), :]
        roi_small, s_roi = _shrink(roi, int(self.cfg.eye_downscale_limit_roi))
        face_small, s_face = _shrink(face, int(self.cfg.eye_downscale_limit_face))
        eyes_roi = self._detect_eyes_in_roi(roi_small)
        # Map detections from the shrunken ROI back to face-crop coordinates.
        eyes_roi = [(int(a / s_roi), int(b / s_roi), int(c / s_roi), int(d / s_roi)) for (a, b, c, d) in eyes_roi]
        labs = _best_pair(eyes_roi, W, roi.shape[0])
        origin = "roi" if labs else None

        # Fallback: search the whole face crop if the upper ROI yielded no pair.
        eyes_full = []
        if self.cfg.eye_fallback_to_face and (not labs):
            eyes_full = self._detect_eyes_in_roi(face_small)
            eyes_full = [(int(a / s_face), int(b / s_face), int(c / s_face), int(d / s_face)) for (a, b, c, d) in eyes_full]
            if len(eyes_full) >= 2:
                labs = _best_pair(eyes_full, W, H)
                origin = "face" if labs else origin
        if not labs:
            # Last resort: take the two largest candidates, or a single one.
            cand = eyes_roi
            cand_origin = "roi"
            if self.cfg.eye_fallback_to_face and len(eyes_full) >= 1:
                cand = eyes_full
                cand_origin = "face"
            if len(cand) >= 2:
                top2 = sorted(cand, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)[:2]
                top2 = sorted(top2, key=lambda b: (b[0] + b[2]))
                labs = [("left", top2[0]), ("right", top2[1])]
                origin = cand_origin
            elif len(cand) == 1:
                labs = [("left", cand[0])]
                origin = cand_origin

        eye_crop = None
        if labs:
            src_img = roi if origin == "roi" else face
            bound_h = roi.shape[0] if origin == "roi" else H
            mid_x = int(round(W / 2.0))
            # Build candidate eye boxes; split any box that crosses the midline
            candidates = []
            for tag, b in labs:
                candidates.extend(_split_box_by_midline(b, mid_x))
            # Deterministically choose the LEFT eye if present; otherwise fall back to largest
            left_boxes = [b for (t, b) in candidates if t == "left"]
            pick_from = left_boxes if left_boxes else [b for (_, b) in candidates]
            chosen = max(pick_from, key=lambda bb: max(1, (bb[2] - bb[0]) * (bb[3] - bb[1])))
            # Square crop around the chosen eye (no stretching); pad to square to guarantee 1:1.
            scale = 1.0 + float(self.cfg.eye_margin)
            sq = _square_box_from_rect(chosen, scale=scale, W=W, H=bound_h)
            if sq is not None:
                ex1, ey1, ex2, ey2 = sq
                crop = src_img[ey1:ey2, ex1:ex2]
                if crop.size > 0 and min(crop.shape[0], crop.shape[1]) >= int(self.cfg.eye_min_size):
                    eye_crop = _pad_to_square_rgb(crop.copy())
        return face, eye_crop
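
# A hedged end-to-end usage sketch; every path and the sample image below are
# placeholders, not taken from the source. Assumes Pillow is available for loading.
if __name__ == "__main__":
    from PIL import Image

    cfg = ExtractorCfg(
        yolo_dir=Path("path/to/yolov5_anime"),    # placeholder: yolov5_anime checkout
        weights=Path("path/to/weights.pt"),       # placeholder: anime face YOLOv5 weights
        cascade=Path("path/to/eye_cascade.xml"),  # placeholder: OpenCV eye cascade XML
    )
    extractor = AnimeFaceEyeExtractor(cfg)
    whole = np.asarray(Image.open("sample.png").convert("RGB"))  # placeholder image
    face, eye = extractor.extract(whole)
    print("face:", None if face is None else face.shape,
          "| eye:", None if eye is None else eye.shape)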