"""Gradio demo for browsing the COTS YOLO dataset.

Each sample is downloaded on demand from the Hugging Face Hub, run through a
YOLO model, and displayed with the model predictions (red) and the ground
truth boxes (green) drawn on top.
"""

import logging
import typing as tp

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import hf_hub_download, list_repo_files
from torch import nn
from torchvision import transforms
from ultralytics import YOLO

logger = logging.getLogger(__name__)

# ---------------------------------
# 0. Get dataset file names
# ---------------------------------
repo_type = "dataset"
repo_id = "eloise54/cots_yolo_dataset"
files = list_repo_files(repo_id, repo_type=repo_type)
# Set view of the repo listing: lets us know whether a label file exists
# without attempting (and failing) a download.
_repo_files = frozenset(files)


def get_dataset_splits(files):
    """Partition the repo file listing into per-split image and label paths.

    Args:
        files: Iterable of repo-relative file paths.

    Returns:
        A 6-tuple ``(train_images, val_images, test_images, train_labels,
        val_labels, test_labels)``. Label paths are derived from the image
        paths (``images/`` -> ``labels/``, ``.jpg`` -> ``.txt``); a derived
        label path is not guaranteed to exist in the repo.
    """
    split_names = ("train", "val", "test")
    images: tp.Dict[str, tp.List[str]] = {name: [] for name in split_names}
    labels: tp.Dict[str, tp.List[str]] = {name: [] for name in split_names}

    for path in files:
        if not path.endswith(".jpg"):
            continue
        # Swap only the trailing extension so a ".jpg" elsewhere in the
        # path is left untouched.
        label_path = path.replace("images/", "labels/")[: -len(".jpg")] + ".txt"
        # First matching split wins (train, then val, then test).
        for name in split_names:
            if f"{name}/" in path:
                images[name].append(path)
                labels[name].append(label_path)
                break

    return (
        images["train"],
        images["val"],
        images["test"],
        labels["train"],
        labels["val"],
        labels["test"],
    )


train_images, val_images, test_images, train_labels, val_labels, test_labels = get_dataset_splits(files)

# Lookup table from the dropdown value to the matching (images, labels) lists.
_SPLITS = {
    "train": (train_images, train_labels),
    "val": (val_images, val_labels),
    "test": (test_images, test_labels),
}

# ---------------------------------
# 1. Load model
# ---------------------------------
model = YOLO('runs/detect/yolov11m_1920p/weights/best.pt').to("cpu")
model.eval()


# ---------------------------------
# 2. Define function to read labels and draw boxes
# ---------------------------------
def read_ground_truth(label_file_path, img_width, img_height):
    """Read a YOLO-format label file and return boxes in pixel coordinates.

    Each valid line is ``class xc yc w h`` with the box given as normalized
    centre/size values; it is converted to absolute ``[x0, y0, x1, y1]``.

    Args:
        label_file_path: Path to the ``.txt`` label file.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        A list of ``{"class_id": int, "box": [x0, y0, x1, y1]}`` dicts.
        Empty when the file is missing or unreadable (no label file means
        no COTS in the image).
    """
    ground_truth_boxes = []
    try:
        with open(label_file_path, encoding="utf-8") as f:
            lines = f.read().splitlines()
    except OSError:
        # No label txt file means no COTS in image.
        return ground_truth_boxes

    for line_no, line in enumerate(lines, start=1):
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        try:
            cls, xc, yc, w, h = map(float, fields)
        except ValueError:
            # Skip the bad line instead of discarding the rest of the file.
            logger.warning(
                "Skipping malformed label line %d in %s: %r",
                line_no, label_file_path, line,
            )
            continue

        # De-normalize to pixels.
        xc *= img_width
        yc *= img_height
        w *= img_width
        h *= img_height
        ground_truth_boxes.append({
            "class_id": int(cls),
            "box": [xc - 0.5 * w, yc - 0.5 * h, xc + 0.5 * w, yc + 0.5 * h],
        })
    return ground_truth_boxes


def draw_rectangle(img, box, color, thickness):
    """Return a copy of ``img`` with a half-transparent rectangle outline.

    Args:
        img: Image array as used by OpenCV.
        box: Sequence ``[x0, y0, x1, y1]``; values are truncated to int.
        color: Colour tuple passed to ``cv2.rectangle``.
        thickness: Outline thickness passed to ``cv2.rectangle``.

    Returns:
        The blended image; the input array is not modified.
    """
    start_point = (int(box[0]), int(box[1]))
    end_point = (int(box[2]), int(box[3]))
    alpha = 0.5
    # Draw on a copy, then blend it back so the outline is semi-transparent.
    overlay = cv2.rectangle(img.copy(), start_point, end_point, color, thickness)
    return cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)


# ---------------------------------
# 3. Prediction function
# ---------------------------------
def get_sample(index: int, dataset_choice: str):
    """Download one sample, run the model on it and draw both box sets.

    Args:
        index: Requested sample index; clamped to the valid range of the
            split. ``None`` (an emptied number field) is treated as 0.
        dataset_choice: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        ``(image_rgb, index, index, dataset_choice)`` where ``index`` is the
        clamped index (returned twice: once for the state, once for display).

    Raises:
        gr.Error: If the split is unknown/empty or the image cannot be read.
    """
    images, labels = _SPLITS.get(dataset_choice, ([], []))
    if not images:
        raise gr.Error(f"No images available for dataset split {dataset_choice!r}")

    index = 0 if index is None else int(index)
    index = max(0, min(index, len(images) - 1))  # clamp index

    image_path = hf_hub_download(
        repo_id=repo_id, repo_type=repo_type, filename=images[index], local_dir="."
    )

    # Fall back to the local relative path; read_ground_truth returns an
    # empty list if nothing is there.
    label_path = labels[index]
    if labels[index] in _repo_files:
        try:
            label_path = hf_hub_download(
                repo_id=repo_id, repo_type=repo_type, filename=labels[index], local_dir="."
            )
        except Exception:  # best effort: still show the image without labels
            logger.warning("Could not download label file %s", labels[index], exc_info=True)
    # else: no label txt file means no COTS in image

    pred_color = (0, 0, 255)  # red in BGR
    gt_color = (0, 255, 0)  # green in BGR
    thickness = 15

    img = cv2.imread(image_path)
    if img is None:
        raise gr.Error(f"Could not read image {images[index]!r}")

    with torch.no_grad():
        results = model(image_path, imgsz=1920)

    gt = read_ground_truth(label_path, img.shape[1], img.shape[0])

    for res in results:
        for box in res.boxes.xyxy.tolist():
            img = draw_rectangle(img, box, pred_color, thickness)

    for box_dict in gt:
        img = draw_rectangle(img, box_dict['box'], gt_color, thickness)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR to RGB
    return img, index, index, dataset_choice


# ---------------------------------
# 4. Navigation functions
# ---------------------------------
def next_sample(index: int, dataset_choice: str):
    """Show the sample after ``index`` (clamped at the end of the split)."""
    return get_sample(index + 1, dataset_choice)


def prev_sample(index: int, dataset_choice: str):
    """Show the sample before ``index`` (clamped at the start of the split)."""
    return get_sample(index - 1, dataset_choice)


# ---------------------------------
# 5. UI elements
# ---------------------------------
dataset_information = """
## Dataset overview

[![Hugging Face Dataset](https://img.shields.io/badge/huggingface-dataset-blue?logo=huggingface)](https://huggingface.co/datasets/eloise54/cots_yolo_dataset)

This dataset is a **modified version** of the [CSIRO COTS and COTS Scars Dataset](https://data.csiro.au/collection/csiro:64235), originally released under the [Creative Commons Attribution 4.0 License (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/).

The original dataset contains images and annotations for **Crown-of-Thorns Starfish (COTS)** and **COTS scars**, collected to support coral reef monitoring and control efforts on the Great Barrier Reef (GBR). These starfish are coral predators, and their outbreaks can severely damage reef ecosystems.

**CSIRO COTS and COTS Scars Dataset reference:**

```bibtex
@dataset{csiro_cots_2024,
  author = {Armin, Ali and Bainbridge, Scott and Page, Geoff and Tychsen-Smith, Lachlan and Coleman, Greg and Oorloff, Jeremy and Harvey, De'vereux and Do, Brendan and Marsh, Benjamin and Lawrence, Emma and Kusy, Brano and Hayder, Zeeshan and Bonin, Mary},
  title = {COTS and COTS scar dataset},
  year = {2024},
  publisher = {CSIRO},
  version = {v1},
  doi = {10.25919/03a7-hn83},
  url = {https://data.csiro.au/collection/csiro:64235}
}
```
"""

with gr.Blocks() as demo:
    gr.Markdown("## 🪸 Crown of thorns starfish detection - protect the great barrier reef")
    gr.Markdown("Use **Next** or **Previous** to browse samples and see model predictions vs ground truth.")

    state = gr.State(0)  # holds current index

    with gr.Row():
        dropdown = gr.Dropdown(
            ["train", "val", "test"], label="Dataset split to use", value="train")
        # Shows which split the displayed image actually came from.
        dataset_choice = gr.Text(label="Using Dataset")

    with gr.Row(equal_height=True):
        index_input = gr.Number(label="Enter image number to display: ", value=0, precision=0)
        go_btn = gr.Button("Apply")

    with gr.Row():
        image_output = gr.Image(label="Image")

    with gr.Row():
        gr.Markdown("Green is ground truth, Red is model prediction")

    with gr.Row():
        index = gr.Text(label="Current Image Number", interactive=False)

    with gr.Row():
        prev_btn = gr.Button("⬅️ Prev image")
        next_btn = gr.Button("Next image➡️")

    with gr.Row():
        gr.Markdown(dataset_information)

    # Connect navigation
    prev_btn.click(fn=prev_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
    next_btn.click(fn=next_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
    go_btn.click(fn=get_sample, inputs=[index_input, dropdown], outputs=[image_output, state, index, dataset_choice])

    # Load initial image
    demo.load(fn=get_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])

# ---------------------------------
# 6. Run
# ---------------------------------
if __name__ == "__main__":
    demo.launch(show_api=False)