erukude commited on
Commit
177588a
·
verified ·
1 Parent(s): f2ac517

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,3 +1,111 @@
1
  ---
2
- license: cc-by-4.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: en
3
+ license: mit
4
+ tags:
5
+ - gan
6
+ - cgan
7
+ - keras
8
+ - tensorflow
9
+ - computer-vision
10
+ - image-processing
11
+ - astronomy
12
+ - galaxy-morphology
13
+ - image-segmentation
14
+ pipeline_tag: image-to-image
15
+ library_name: keras
16
+ datasets:
17
+ - desi-legacy-survey
18
  ---
19
+
20
+ # Galaxy Image Simplification using Generative AI
21
+
22
+ This repository hosts the pretrained models for **Galaxy Image Simplification using Generative AI**, a pipeline that converts complex galaxy images into simplified, skeletonized representations suitable for quantitative morphology analysis.
23
+
24
+ The pipeline combines:
25
+
26
+ - A **ResNet-based classifier** to select **spiral galaxies**
27
+ - A **conditional GAN (cGAN)** to produce initial arm masks
28
+ - A **post-processing cGAN** to smooth and connect broken arm segments
29
+
30
+ These models were trained on images from the **DESI Legacy Survey** with manually annotated spiral arms.
31
+
32
+ ---
33
+
34
+ ## Model Sources
35
+
36
+ - **Code & full project:**
37
+ https://github.com/SaiTeja-Erukude/galaxy-image-simplification-using-genai
38
+
39
+ ---
40
+
41
+ ## Files in this repository
42
+
43
+ | File name | Type | Description |
44
+ |----------------------------------|---------------|-------------------------------------------------------------------|
45
+ | `models/galaxy_classifier_resnet50.h5` | Keras model | ResNet-based binary classifier: spiral vs. non-spiral galaxy |
46
+ | `models/galaxy_simplifier_cgan.h5` | Keras model | Conditional GAN: galaxy RGB image ➜ initial arm-highlighted image |
47
+ | `models/postprocess_cgan.h5` | Keras model | Conditional GAN: initial mask ➜ refined, smooth/connected mask |
48
+ | `predict.py` | Python script | Full inference pipeline (classification ➜ simplifier cGAN ➜ post-cGAN) |
49
+ | `graphical_abstract.jpg` | Image | Graphical abstract / high-level overview of the Galaxy Simplifier pipeline |
50
+ | `requirements.txt` | Text file | Python dependencies needed for running inference |
51
+ | `README.md` | Markdown | Model card and usage instructions (this file) |
52
+
53
+ ---
54
+
55
+ ## Intended Use
56
+
57
+ ### What this model does
58
+
59
+ Given an optical galaxy image (RGB, 256×256):
60
+
61
+ 1. **ResNet classifier (`galaxy_classifier_resnet50.h5`)**
62
+ - Predicts whether the galaxy is a **spiral**.
63
+ - Outputs a 2-class softmax:
64
+ - class `0` – non-spiral / other
65
+ - class `1` – spiral
66
+ - Typical usage: apply a confidence threshold on the spiral class (e.g. `p_spiral > 0.65`) before running the GAN pipeline.
67
+
68
+ 2. **Skeletonization cGAN (`galaxy_simplifier_cgan.h5`)**
69
+ - Input: original RGB galaxy image (normalized to `[-1, 1]`).
70
+ - Output: image where **white lines** track the spiral arms (initial skeleton-like mask).
71
+
72
+ 3. **Post-processing cGAN (`postprocess_cgan.h5`)**
73
+ - Input: initial cGAN output.
74
+ - Output: refined mask with **smoother and better-connected arm structures**.
75
+ - This can be further processed with classical image processing (thresholding, skeletonization, dilation) to produce final binary masks.
76
+
77
+ ### Primary use cases
78
+
79
+ - Large-scale **spiral galaxy selection** and morphology analysis
80
+ - Measuring arm geometry, pitch angles, and other structural properties
81
+ - Building catalogs of simplified galaxy images from wide-field surveys
82
+
83
+ ### Not intended for
84
+
85
+ - General-purpose image generation outside the astronomy domain
86
+ - High-fidelity photometric modeling or pixel-perfect reconstruction of galaxies
87
+
88
+ ---
89
+
90
+ ## How to use
91
+
92
+ You can either:
93
+
94
+ - use your **own inference script**, or
95
+ - use the provided minimalistic `inference.py`.
96
+
97
+ ---
98
+
99
+ ## Citation
100
+ If you use this code, models, or catalog in your research, please cite:
101
+
102
+ ```bibtex
103
+ @article{erukude2025galaxy,
104
+ title={Galaxy image simplification using Generative AI},
105
+ author={Erukude, Sai Teja and Shamir, Lior},
106
+ journal={Astronomy and Computing},
107
+ pages={100990},
108
+ year={2025},
109
+ publisher={Elsevier}
110
+ }
111
+ ```
graphical_abstract.jpg ADDED
inference.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from keras.models import load_model
4
+ from keras.preprocessing.image import load_img, img_to_array
5
+ from matplotlib import pyplot
6
+
7
+
8
+ ######################
9
+ # Configuration
10
+ ######################
11
+ RESNET_PATH = "path_to_resnet50_model.h5"
12
+ CGAN_PATH = "path_to_cgan_model.h5"
13
+ POST_CGAN_PATH = "path_to_postprocess_cgan_model.h5" # <--- NEW
14
+
15
+ DATA_PATH = "path_to_test_dir"
16
+ OUTPUT_PATH = "path_to_output_dir"
17
+
18
+ HEIGHT, WIDTH = 256, 256
19
+ TARGET_SIZE = (HEIGHT, WIDTH)
20
+ BATCH_SIZE = 32
21
+
22
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
23
+
24
+
25
+ # Load the models
26
+ resnet_model = load_model(RESNET_PATH)
27
+ print("Resnet50 loaded successfully!")
28
+
29
+ cgan_model = load_model(CGAN_PATH)
30
+ print("cGAN loaded successfully!")
31
+
32
+ post_cgan_model = load_model(POST_CGAN_PATH)
33
+ print("Post-processing cGAN loaded successfully!")
34
+
35
+
36
+ ######################
37
+ #
38
+ ######################
39
+ def load_and_preprocess(img_path: str, model: str = "resnet") -> np.ndarray:
40
+ """
41
+ Desc:
42
+ Load an image from disk and preprocess it for input into a deep learning model.
43
+ Args:
44
+ img_path (str): Path to the image file.
45
+ model (str): The model type to preprocess for.
46
+ "resnet" uses scaling to [0,1], other models use [-1,1] normalization.
47
+ Returns:
48
+ np.ndarray: Preprocessed image ready for model input.
49
+ """
50
+ img = load_img(img_path, target_size=TARGET_SIZE)
51
+ img_array = img_to_array(img)
52
+ img_array = np.expand_dims(img_array, axis=0)
53
+
54
+ if model == "resnet":
55
+ return img_array / 255.0
56
+ # for "cgan" and "post_cgan" we assume [-1, 1] normalization
57
+ return (img_array - 127.5) / 127.5
58
+
59
+
60
+ ######################
61
+ #
62
+ ######################
63
+ def plot_generated_image(gen_image: np.ndarray, filename: str) -> None:
64
+ """
65
+ Save a generated image to disk after rescaling it from [-1, 1] to [0, 1].
66
+ Args:
67
+ gen_image (np.ndarray): The generated image array, expected shape (1, H, W, C).
68
+ filename (str): The filename to save the image as (including extension, e.g., "image.png").
69
+ Returns:
70
+ None
71
+ """
72
+ # Scale from [-1,1] to [0,1]
73
+ gen_image = (gen_image + 1) / 2.0
74
+
75
+ # Save the generated image
76
+ output_filename = os.path.join(OUTPUT_PATH, filename)
77
+ pyplot.imsave(output_filename, gen_image[0])
78
+
79
+
80
+ all_ctr = 0
81
+ spiral_ctr = 0
82
+
83
+ # === Loop through images ===
84
+ for filename in os.listdir(DATA_PATH):
85
+ if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
86
+ continue
87
+
88
+ img_path = os.path.join(DATA_PATH, filename)
89
+ all_ctr += 1
90
+
91
+ # Step 1: Classify with ResNet50
92
+ resnet_input = load_and_preprocess(img_path, model="resnet")
93
+ resnet_preds = resnet_model.predict(resnet_input, verbose=0)
94
+
95
+ predicted_class = np.argmax(resnet_preds, axis=1)[0]
96
+ if predicted_class == 1: # Spiral galaxy
97
+
98
+ if resnet_preds[0][1] > 0.65: # Confidence threshold
99
+
100
+ # Step 2: Process with first cGAN (skeletonization)
101
+ cgan_input = load_and_preprocess(img_path, model="cgan")
102
+ cgan_output = cgan_model.predict(cgan_input, verbose=0)
103
+
104
+ # Step 3: Post-process with second cGAN (smoothing/connecting lines)
105
+ post_output = post_cgan_model.predict(cgan_output, verbose=0)
106
+
107
+ # Step 4: Save final post-processed output
108
+ plot_generated_image(post_output, filename)
109
+ spiral_ctr += 1
110
+
111
+ print(f"Found '{spiral_ctr}' spiral galaxies in '{all_ctr}' images.")
models/galaxy_classifier_resnet50.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:289c82af19f09fd2d88b051baa3653cabeb27b042887f0498cf517d254aa833a
3
+ size 228803360
models/galaxy_simplifier_cgan.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6896bbb5bf076155f347816d55526c462e821148e507899ef691efc01fa9e3f
3
+ size 217868656
models/postprocess_cgan.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc6ce957ede6517225b6db3e49f76cad03f148f00960e9bc0bd874b824661bb7
3
+ size 217868656
requirements.txt ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.2.2
2
+ albucore==0.0.17
3
+ albumentations==1.4.18
4
+ annotated-types==0.7.0
5
+ astunparse==1.6.3
6
+ cachetools==5.5.2
7
+ certifi==2025.1.31
8
+ charset-normalizer==3.4.1
9
+ colorama==0.4.6
10
+ contourpy==1.1.1
11
+ cycler==0.12.1
12
+ eval_type_backport==0.2.2
13
+ flatbuffers==25.2.10
14
+ fonttools==4.57.0
15
+ gast==0.4.0
16
+ google-auth==2.38.0
17
+ google-auth-oauthlib==1.0.0
18
+ google-pasta==0.2.0
19
+ grpcio==1.70.0
20
+ h5py==3.11.0
21
+ idna==3.10
22
+ imageio==2.35.1
23
+ importlib_metadata==8.5.0
24
+ importlib_resources==6.4.5
25
+ keras==2.13.1
26
+ kiwisolver==1.4.7
27
+ lazy_loader==0.4
28
+ libclang==18.1.1
29
+ Markdown==3.7
30
+ MarkupSafe==2.1.5
31
+ matplotlib==3.7.5
32
+ networkx==3.1
33
+ numpy==1.24.4
34
+ oauthlib==3.2.2
35
+ opencv-python==4.11.0.86
36
+ opencv-python-headless==4.11.0.86
37
+ opt_einsum==3.4.0
38
+ packaging==24.2
39
+ pillow==10.4.0
40
+ protobuf==4.25.6
41
+ pyasn1==0.6.1
42
+ pyasn1_modules==0.4.2
43
+ pydantic==2.10.6
44
+ pydantic_core==2.27.2
45
+ pyparsing==3.1.4
46
+ python-dateutil==2.9.0.post0
47
+ PyWavelets==1.4.1
48
+ PyYAML==6.0.2
49
+ requests==2.32.3
50
+ requests-oauthlib==2.0.0
51
+ rsa==4.9
52
+ scikit-image==0.21.0
53
+ scipy==1.10.1
54
+ six==1.17.0
55
+ tensorboard==2.13.0
56
+ tensorboard-data-server==0.7.2
57
+ tensorflow==2.13.0
58
+ tensorflow-estimator==2.13.0
59
+ tensorflow-intel==2.13.0
60
+ tensorflow-io-gcs-filesystem==0.31.0
61
+ tensorflow_keras==0.1
62
+ termcolor==2.4.0
63
+ tifffile==2023.7.10
64
+ tqdm==4.67.1
65
+ typing_extensions==4.13.1
66
+ urllib3==2.2.3
67
+ Werkzeug==3.0.6
68
+ wrapt==1.17.2
69
+ zipp==3.20.2