| |
|
|
| from itertools import cycle |
| from time import time |
|
|
| import gradio as gr |
| import matplotlib.colors as colors |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from joblib import cpu_count |
| from sklearn.cluster import Birch, MiniBatchKMeans |
| from sklearn.datasets import make_blobs |
|
|
# Select matplotlib's non-interactive "agg" backend: in this file figures are
# only handed back to the Gradio UI, never shown in a window.
plt.switch_backend("agg")
|
|
|
|
def _scatter_clusters(ax, X, labels, color_cycle, title, centers=None):
    """Draw each cluster of ``X`` on ``ax`` in its own edge colour.

    Args:
        ax: matplotlib axes to draw on.
        X: 2-column array of points (column 0 is x, column 1 is y).
        labels: per-point cluster label, same length as ``X``.
        color_cycle: endless iterator of colour names, shared across panels.
        title: title to put on the axes.
        centers: optional array of cluster centres indexed by label. When
            given, one panel entry is drawn per centre and each centre is
            marked with a black "+". When ``None``, one entry is drawn per
            distinct label and no centres are marked.
    """
    cluster_ids = np.unique(labels) if centers is None else range(len(centers))
    # Keep the endless colour cycle as the last zip argument so that no colour
    # is consumed once the cluster ids run out.
    for k, col in zip(cluster_ids, color_cycle):
        mask = labels == k
        ax.scatter(
            X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5
        )
        if centers is not None:
            ax.scatter(centers[k][0], centers[k][1], marker="+", c="k", s=25)
    ax.set_ylim([-25, 25])
    ax.set_xlim([-25, 25])
    ax.set_autoscaley_on(False)
    ax.set_title(title)


def do_submit(n_samples, birch_threshold, birch_n_clusters):
    """Cluster synthetic blobs with BIRCH and MiniBatchKMeans and plot them.

    Args:
        n_samples: number of points to generate (converted with ``int``).
        birch_threshold: ``threshold`` passed to both BIRCH models
            (converted with ``float``).
        birch_n_clusters: ``n_clusters`` for the BIRCH model that runs the
            final global clustering step (converted with ``int``). Values
            below 1 are not valid cluster counts, so they disable the global
            step (``n_clusters=None``).

    Returns:
        A ``(fig, result)`` tuple: a matplotlib figure with three panels
        (BIRCH without global clustering, BIRCH with global clustering,
        MiniBatchKMeans) and a multi-line text report of timings and cluster
        counts.
    """
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    birch_n_clusters = int(birch_n_clusters)
    global_n_clusters = birch_n_clusters if birch_n_clusters >= 1 else None

    # Blob centres laid out on a regular 10 x 10 grid.
    grid = np.linspace(-22, 22, 10)
    xx, yy = np.meshgrid(grid, grid)
    n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

    X, _ = make_blobs(n_samples=n_samples, centers=n_centers, random_state=0)

    # One shared colour cycle, so panels continue where the previous one stopped.
    colors_ = cycle(colors.cnames)
    report = []

    fig = plt.figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

    # Same threshold for both models; only the final global step differs.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=global_n_clusters),
    ]

    for ind, birch_model in enumerate(birch_models):
        has_global_step = birch_model.n_clusters is not None
        info = (
            "with global clustering" if has_global_step else "without global clustering"
        )

        t = time()
        birch_model.fit(X)
        elapsed = time() - t
        report.append(f"BIRCH {info} as the final step took {elapsed:.2f} seconds")

        labels = birch_model.labels_
        report.append("n_clusters : %d" % np.unique(labels).size)

        ax = fig.add_subplot(1, 3, ind + 1)
        # Subcluster centres are only marked when there is no global step,
        # because only then do the labels index ``subcluster_centers_``.
        _scatter_clusters(
            ax,
            X,
            labels,
            colors_,
            "BIRCH %s" % info,
            centers=None if has_global_step else birch_model.subcluster_centers_,
        )

    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    report.append(f"Time taken to run MiniBatchKMeans {t_mini_batch:.2f} seconds")

    # Iterate over MiniBatchKMeans' own centres, not the BIRCH cluster count.
    ax = fig.add_subplot(1, 3, 3)
    _scatter_clusters(
        ax, X, mbk.labels_, colors_, "MiniBatchKMeans", centers=mbk.cluster_centers_
    )

    # Drop pyplot's reference so figures do not pile up across requests; the
    # Figure object itself stays valid for the caller to render.
    plt.close(fig)

    return fig, "\n".join(report) + "\n"
|
|
|
|
| |
# Visual theme for the demo: monochrome base with custom hues, small corner
# radius, and a font stack that falls back to generic sans-serif families.
theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
|
|
title = "Compare BIRCH and MiniBatchKMeans"

# Page layout: heading, explanatory text, three input sliders, the plot and
# text outputs, and a button that runs do_submit with the slider values.
with gr.Blocks(theme=theme, title=title) as demo:
    gr.Markdown("## " + title)
    gr.Markdown(
        "This is an interactive demo for this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)."
    )

    # Long description assembled from adjacent string literals.
    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global clustering step) and "
        "MiniBatchKMeans on a synthetic dataset having 25,000 samples and 2 features generated using make_blobs."
        "\n Both MiniBatchKMeans and BIRCH are very scalable algorithms and could run efficiently on hundreds of thousands or "
        "even millions of datapoints. We chose to limit the dataset size of this example in the interest of keeping our "
        "Continuous Integration resource usage reasonable but the interested reader might enjoy editing this script to "
        "rerun it with a larger value for n_samples."
        "\n\n"
        "If n_clusters is set to None, the data is reduced from 25,000 samples to a set of 158 clusters. This can be viewed as a preprocessing step before the final (global) clustering step that further reduces these 158 clusters to 100 clusters."
    )

    # Inputs.
    n_samples = gr.Slider(
        label="Number of samples",
        minimum=20000,
        maximum=80000,
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        label="Birch Threshold",
        minimum=0.5,
        maximum=2.0,
        step=0.1,
        value=1.7,
    )
    birch_n_clusters = gr.Slider(
        label="Birch number of clusters",
        minimum=0,
        maximum=100,
        step=1,
        value=100,
    )

    # Outputs.
    plt_out = gr.Plot()
    output = gr.Textbox(label="Output", multiline=True)

    # Trigger.
    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )
|
|
|
|
# Launch the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()
|
|
|
|