Source code for scitex_ml.clustering._umap

#!./env/bin/python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-09-12 05:37:55 (ywatanabe)"
# _umap_dev.py


"""
This script does XYZ.
"""

import scitex_io
import scitex_plt
from scitex_dev import try_import_optional

# `figrecipe` is the underlying impl that the umbrella exposes as
# `scitex.plt` — `ax`, `configure_mpl`, `_subplots`, etc. live there.
# Use the peer standalone directly (PA304-clean); fall back to None
# when the optional plotting layer isn't installed.
figrecipe = try_import_optional("figrecipe", extra="plt", pkg="scitex-ml")
_FIGRECIPE_AVAILABLE = figrecipe is not None


"""
Imports
"""
import sys

import matplotlib.pyplot as plt
import numpy as np
from natsort import natsorted
from sklearn.preprocessing import LabelEncoder

# `umap` (umap-learn) eagerly imports tensorflow via parametric_umap. Defer
# the import to call time so the broader scitex_ml.clustering package can
# be imported even when the env's TF is broken or unavailable.


# sys.path = ["."] + sys.path
# from scripts import utils, load

"""
Warnings
"""
# warnings.simplefilter("ignore", UserWarning)


"""
Config
"""
# CONFIG = scitex.gen.load_configs()


"""
Functions & Classes
"""


[docs] def umap( data, labels, hues=None, hues_colors=None, axes=None, axes_titles=None, supervised=False, title="UMAP Clustering", alpha=1.0, s=3, use_independent_legend=False, add_super_imposed=False, umap_model=None, ): """ Perform UMAP clustering and visualization. Parameters ---------- data_all : list List of data arrays to cluster labels_all : list List of label arrays corresponding to data_all hues_all : list, optional List of hue arrays for coloring points hues_colors_all : list, optional List of color mappings for hues axes : matplotlib.axes.Axes, optional Existing axes to plot on axes_titles : list, optional Titles for each subplot supervised : bool, optional Whether to use supervised UMAP title : str, optional Main title for the plot alpha : float, optional Transparency of points s : int, optional Size of points use_independent_legend : bool, optional Whether to create separate legend figures add_super_imposed : bool, optional Whether to add a superimposed plot umap_model : umap.UMAP, optional Pre-fitted UMAP model Returns ------- tuple Figure, legend figures (if applicable), and UMAP model """ # Renaming data_all = data labels_all = labels hues_all = hues hues_colors_all = hues_colors data_all, labels_all, hues_all, hues_colors_all = _check_input_vars( data_all, labels_all, hues_all, hues_colors_all ) # Label Encoding le = LabelEncoder() le.fit(natsorted(np.hstack(labels_all))) labels_all = [le.transform(labels) for labels in labels_all] # Running UMAP Clustering _umap = _run_umap(umap_model, data_all, labels_all, supervised, title) # Plotting fig, legend_figs = _plot( _umap, le, data_all, labels_all, hues_all, hues_colors_all, add_super_imposed, axes, title, axes_titles, use_independent_legend, s, alpha, ) return fig, legend_figs, _umap
def _plot( _umap, le, data_all, labels_all, hues_all, hues_colors_all, add_super_imposed, axes, title, axes_titles, use_independent_legend, s, alpha, ): # Plotting ncols = len(data_all) + 1 if add_super_imposed else len(data_all) share = True if ncols > 1 else False if axes is None: fig, axes = scitex_plt.subplots(ncols=ncols, sharex=share, sharey=share) else: assert len(axes) == ncols fig = ( axes[0].get_figure() # axes if isinstance(axes, (np.ndarray, figrecipe._subplots.AxesWrapper)) # axis else axes.get_figure() ) fig.supxlabel("UMAP 1") fig.supylabel("UMAP 2") fig.suptitle(title) for ii, (data, labels, hues, hues_colors) in enumerate( zip(data_all, labels_all, hues_all, hues_colors_all) ): embedding = _umap.transform(data) # ax if ncols == 1: ax = axes else: ax = axes[ii + 1] if add_super_imposed else axes[ii] _hues = le.inverse_transform(labels) if hues is None else hues for hue in np.unique(_hues): indi = hue == np.array(_hues) if hues_colors: colors = np.vstack(hues_colors)[indi] colors = [colors[ii] for ii in range(len(colors))] else: colors = None ax.scatter( x=embedding[:, 0][indi], y=embedding[:, 1][indi], label=hue, c=colors, s=s, alpha=alpha, ) ax.set_box_aspect(1) if axes_titles is not None: ax.set_title(axes_titles[ii]) # Merged axis if add_super_imposed: ax = axes[0] _hues = le.inverse_transform(labels) if hues is None else hues for hue in np.unique(_hues): indi = hue == np.array(_hues) ax.scatter( x=embedding[:, 0][indi], y=embedding[:, 1][indi], label=hue, c=np.vstack(hues_colors)[indi][0], s=s, alpha=alpha, ) ax.set_title("Superimposed") ax.set_box_aspect(1) # ax.sns_scatterplot( # x=embedding[:, 0], # y=embedding[:, 1], # hue=le.inverse_transform(labels) if hues is None else hues, # palette=hues_colors, # legend="full" if ii == 0 else False, # s=s, # alpha=alpha, # ) # Sharing was already requested at subplots() creation # (sharex=share, sharey=share); the legacy `figrecipe.ax.sharex/sharey` # helpers don't exist on the standalone figrecipe surface. if not use_independent_legend: # `axes` can be a single Axes (ncols==1) or an array (ncols>1); # np.atleast_1d makes both iterable via .flat. for ax in np.atleast_1d(axes).flat: ax.legend(loc="upper left") return fig, None elif use_independent_legend: legend_figs = [] for i, ax in enumerate(axes): legend = ax.get_legend() if legend: legend_fig = plt.figure(figsize=(3, 2)) new_legend = legend_fig.gca().legend( handles=legend.get_lines(), labels=[t.get_text() for t in legend.texts], loc="center", ) # new_legend = legend_fig.gca().legend( # handles=legend.legendHandles, # labels=legend.texts, # loc="center", # ) # legend_fig.canvas.draw() legend_figs.append(legend_fig) ax.get_legend().remove() for ax in axes: ax.legend_ = None # elif use_independent_legend: # legend_figs = [] # for i, ax in enumerate(axes): # legend = ax.get_legend() # if legend: # legend_fig = plt.figure(figsize=(3, 2)) # new_legend = legend_fig.gca().legend( # handles=legend.legendHandles, # labels=legend.texts, # loc="center", # ) # legend_fig.canvas.draw() # legend_filename = f"legend_{i}.png" # legend_fig.savefig(legend_filename, bbox_inches="tight") # legend_figs.append(legend_fig) # plt.close(legend_fig) # for ax in axes: # ax.legend_ = None return fig, legend_figs def _run_umap(umap_model, data_all, labels_all, supervised, title): # UMAP Clustering if not umap_model: # Lazy import — `umap.umap_` pulls tensorflow via parametric_umap; # we keep it out of module scope so the package imports cleanly even # when the env's TF stack is broken. import umap.umap_ as umap_orig umap_model = umap_orig.UMAP(random_state=42) supervised_label_or_none = labels_all[0] if supervised else None title = f"(Supervised) {title}" if supervised else f"(Unsupervised) {title}" _umap = umap_model.fit(data_all[0], y=supervised_label_or_none) else: _umap = umap_model return _umap def _check_input_vars(data_all, labels_all, hues_all, hues_colors_all): # Ensures input formats if hues_all is None: hues_all = [None for _ in range(len(data_all))] if hues_colors_all is None: hues_colors_all = [None for _ in range(len(data_all))] assert len(data_all) == len(labels_all) == len(hues_all) == len(hues_colors_all) assert ( isinstance(data_all, list) and isinstance(labels_all, list) and isinstance(hues_all, list) and isinstance(hues_colors_all, list) ) return data_all, labels_all, hues_all, hues_colors_all def _test(dataset_str="iris"): from sklearn.datasets import load_digits, load_iris from sklearn.model_selection import train_test_split # Load iris dataset load_dataset = {"iris": load_iris, "mnist": load_digits}[dataset_str] dataset = load_dataset() X = dataset.data y = dataset.target # Split data into two parts X1, X2, y1, y2 = train_test_split(X, y, test_size=0.5, random_state=42) # Call umap function fig, legend_figs, umap_model = umap( data=[X1, X2], labels=[y1, y2], # axes=axes, axes_titles=[f"{dataset_str} Set 1", f"{dataset_str} Set 2"], supervised=True, title=dataset_str, use_independent_legend=True, s=10, ) # plt.tight_layout() scitex_io.save(fig, f"/tmp/scitex/umap/{dataset_str}.jpg") # Save legend figures if any if legend_figs: for i, leg_fig in enumerate(legend_figs): scitex_io.save(leg_fig, f"/tmp/scitex/umap/{dataset_str}_legend_{i}.jpg") main = umap if __name__ == "__main__": import scitex # # Argument Parser # import argparse # parser = argparse.ArgumentParser(description='') # parser.add_argument('--var', '-v', type=int, default=1, help='') # parser.add_argument('--flag', '-f', action='store_true', default=False, help='') # args = parser.parse_args() # Main CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( sys, plt, verbose=False, agg=True ) _test(dataset_str="mnist") # main() scitex.session.close(CONFIG, verbose=False, notify=False) # EOF