# Source code for tropea_clustering._internal.plot

"""Auxiliary functions for plotting the results of onion-clustering."""

import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from matplotlib.colors import rgb2hex
from matplotlib.patches import Ellipse
from matplotlib.ticker import MaxNLocator

from tropea_clustering._internal.functions import gaussian

COLORMAP = "viridis"


def plot_output_uni(
    title: str,
    input_data: np.ndarray,
    n_windows: int,
    state_list: List,
):
    """Plots clustering output with Gaussians and thresholds.

    Left panel: all the signals vs time, with the clusters' threshold
    bands overlaid. Right panel: the cumulative histogram of the signal
    with the fitted Gaussians and the threshold lines.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    input_data : ndarray of shape (n_particles * n_windows, tau_window)
        The input data array.

    n_windows : int
        The number of windows used.

    state_list : List[StateUni]
        The list of the cluster states.
    """
    n_particles = int(input_data.shape[0] / n_windows)
    n_frames = n_windows * input_data.shape[1]
    input_data = np.reshape(input_data, (n_particles, n_frames))

    flat_m = input_data.flatten()
    counts, bins = np.histogram(flat_m, bins=100, density=True)
    # Center the bin edges and rescale the density to raw counts.
    bins -= (bins[1] - bins[0]) / 2
    counts *= flat_m.size

    fig, axes = plt.subplots(
        1,
        2,
        sharey=True,
        gridspec_kw={"width_ratios": [3, 1]},
        figsize=(9, 4.8),
    )

    # Plot the histogram of the signal on the right panel.
    axes[1].stairs(
        counts, bins, fill=True, orientation="horizontal", alpha=0.5
    )

    # One color per state; index 0 of the colormap is skipped so that
    # it remains reserved for the unclassified points.
    palette = []
    n_states = len(state_list)
    cmap = plt.get_cmap(COLORMAP, n_states + 1)
    for i in range(1, cmap.N):
        rgba = cmap(i)
        palette.append(rgb2hex(rgba))

    # Plot the individual trajectories on the left panel, subsampling
    # the particles when the dataset is very large.
    t_steps = input_data.shape[1]
    time = np.linspace(0, t_steps - 1, t_steps)
    step = 1
    if input_data.size > 1e6:
        step = 10
    for mol in input_data[::step]:
        axes[0].plot(
            time,
            mol,
            c="xkcd:black",
            ms=0.1,
            lw=0.1,
            alpha=0.5,
            rasterized=True,
        )

    # Overlay each state's fitted Gaussian on the histogram.
    for state_id, state in enumerate(state_list):
        attr = state.get_attributes()
        popt = [attr["mean"], attr["sigma"], attr["area"]]
        axes[1].plot(
            gaussian(np.linspace(bins[0], bins[-1], 1000), *popt),
            np.linspace(bins[0], bins[-1], 1000),
            color=palette[state_id],
        )

    # Line style keyed on the threshold type stored in th_inf[1].
    style_color_map = {
        0: ("--", "xkcd:black"),
        1: ("--", "xkcd:blue"),
        2: ("--", "xkcd:red"),
    }

    # Slightly wider x-range so the shaded bands extend past the data.
    time2 = np.linspace(
        time[0] - 0.05 * (time[-1] - time[0]),
        time[-1] + 0.05 * (time[-1] - time[0]),
        100,
    )
    linestyle = "-"
    for state_id, state in enumerate(state_list):
        th_inf = state.get_attributes()["th_inf"]
        th_sup = state.get_attributes()["th_sup"]
        linestyle, color = style_color_map.get(th_inf[1], ("-", "xkcd:black"))
        axes[1].hlines(
            th_inf[0],
            xmin=0.0,
            xmax=np.amax(counts),
            linestyle=linestyle,
            color=color,
        )
        axes[0].fill_between(
            time2,
            th_inf[0],
            th_sup[0],
            color=palette[state_id],
            alpha=0.25,
        )

    # The upper threshold of the last state closes the picture. Guarded
    # so that an empty state_list no longer raises IndexError.
    if state_list:
        axes[1].hlines(
            state_list[-1].get_attributes()["th_sup"][0],
            xmin=0.0,
            xmax=np.amax(counts),
            linestyle=linestyle,
            color="black",
        )

    # Set plot titles and axis labels
    axes[0].set_ylabel("Signal")
    axes[0].set_xlabel(r"Time [frame]")
    axes[1].set_xticklabels([])

    fig.savefig(title, dpi=600)
def plot_one_trj_uni(
    title: str,
    example_id: int,
    input_data: np.ndarray,
    labels: np.ndarray,
    n_windows: int,
):
    """Plots the colored trajectory of one example particle.

    The signal of the selected particle is drawn versus time, with each
    point colored according to the cluster label of its window.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    example_id : int
        The ID of the selected particle.

    input_data : ndarray of shape (n_particles * n_windows, tau_window)
        The input data array.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.

    n_windows : int
        The number of windows used.
    """
    tau_window = input_data.shape[1]
    n_particles = int(input_data.shape[0] / n_windows)
    n_frames = n_windows * tau_window
    # Reshape to (n_particles, n_frames) and expand the per-window labels
    # so that every frame carries the label of its window.
    input_data = np.reshape(input_data, (n_particles, n_frames))
    labels = np.reshape(labels, (n_particles, n_windows))
    labels = np.repeat(labels, tau_window, axis=1)

    # Truncate the signal to the frames actually covered by the labels.
    signal = input_data[example_id][: labels.shape[1]]

    t_steps = labels.shape[1]
    time = np.linspace(0, t_steps - 1, t_steps)

    fig, axes = plt.subplots()
    unique_labels = np.unique(labels)
    # If there are no assigned window, we still need the "0" state
    # for consistency:
    if -1 not in unique_labels:
        unique_labels = np.insert(unique_labels, 0, -1)

    cmap = plt.get_cmap(
        COLORMAP, np.max(unique_labels) - np.min(unique_labels) + 1
    )
    # Shift by +1 so that the unclassified state (-1) maps to color 0.
    color = labels[example_id] + 1
    axes.plot(time, signal, c="black", lw=0.1)

    axes.scatter(
        time,
        signal,
        c=color,
        cmap=cmap,
        vmin=np.min(unique_labels) + 1,
        vmax=np.max(unique_labels) + 1,
        s=1.0,
    )

    # Add title and labels to the axes
    fig.suptitle(f"Example particle: ID = {example_id}")
    axes.set_xlabel("Time [frame]")
    axes.set_ylabel("Signal")

    fig.savefig(title, dpi=600)
def plot_state_populations(
    title: str,
    n_windows: int,
    labels: np.ndarray,
):
    """
    Plot the populations of states over time.

    For every state (including the unclassified one) the fraction of
    particles in that state is drawn as a function of the window index.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    n_windows : int
        The number of windows used.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.

    Notes
    -----
    If all the points are classified, we still need the "-1" state
    for consistency.
    """
    n_particles = int(labels.shape[0] / n_windows)
    frame_labels = np.reshape(labels, (n_particles, n_windows))

    unique_labels = np.unique(frame_labels)
    if -1 not in unique_labels:
        unique_labels = np.insert(unique_labels, 0, -1)

    # Fraction of particles in each state, window by window.
    populations = [
        np.sum(frame_labels == lab, axis=0) / n_particles
        for lab in unique_labels
    ]

    # One color per state, drawn from the shared colormap.
    cmap = plt.get_cmap(COLORMAP, unique_labels.size)
    palette = [rgb2hex(cmap(idx)) for idx in range(cmap.N)]

    fig, axes = plt.subplots()
    n_frames = frame_labels.shape[1]
    time_axis = np.linspace(0, n_frames - 1, n_frames)
    for env_id, population in enumerate(populations):
        axes.plot(
            time_axis, population, label=f"ENV{env_id}", color=palette[env_id]
        )
    axes.set_xlabel(r"Time [frame]")
    axes.set_ylabel(r"Population")
    axes.legend()

    fig.savefig(title, dpi=600)
def plot_medoids_uni(
    title: str,
    input_data: np.ndarray,
    labels: np.ndarray,
):
    """
    Compute and plot the average signal sequence inside each state.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    input_data : ndarray of shape (n_particles * n_windows, tau_window)
        The input data array.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.

    Notes
    -----
    - If all the points are classified, we still need the "-1" state
    for consistency.
    - Prints the output to files.
    """
    center_list = []
    std_list = []
    env0 = []

    list_of_labels = np.unique(labels)
    if -1 not in list_of_labels:
        list_of_labels = np.insert(list_of_labels, 0, -1)

    # Gather all the windows assigned to each label. Classified states
    # get a mean/std sequence; the unclassified windows (-1) are kept
    # verbatim in env0 and plotted individually later.
    for ref_label in list_of_labels:
        tmp = []
        for i, label in enumerate(labels):
            if label == ref_label:
                tmp.append(input_data[i])
        if len(tmp) > 0 and ref_label > -1:
            center_list.append(np.mean(tmp, axis=0))
            std_list.append(np.std(tmp, axis=0))
        elif len(tmp) > 0:
            env0 = tmp

    # Dump the per-state averages and standard deviations to text files
    # in the working directory.
    center_arr = np.array(center_list)
    std_arr = np.array(std_list)
    np.savetxt(
        "medoid_center.txt",
        center_arr,
        header="Signal average for each ENV",
    )
    np.savetxt(
        "medoid_stddev.txt",
        std_arr,
        header="Signal standard deviation for each ENV",
    )

    # Color 0 is reserved for the unclassified windows.
    palette = []
    cmap = plt.get_cmap(COLORMAP, list_of_labels.size)
    palette.append(rgb2hex(cmap(0)))
    for i in range(1, cmap.N):
        rgba = cmap(i)
        palette.append(rgb2hex(rgba))

    fig, axes = plt.subplots()
    time_seq = range(input_data.shape[1])
    # Mean sequence of each state with a +/- 1 std shaded band.
    for center_id, center in enumerate(center_list):
        err_inf = center - std_list[center_id]
        err_sup = center + std_list[center_id]
        axes.fill_between(
            time_seq,
            err_inf,
            err_sup,
            alpha=0.25,
            color=palette[center_id + 1],
        )
        axes.plot(
            time_seq,
            center,
            label=f"ENV{center_id + 1}",
            marker="o",
            c=palette[center_id + 1],
        )
    # Faint individual traces for the unclassified windows.
    for window in env0:
        axes.plot(
            time_seq,
            window,
            lw=0.1,
            c=palette[0],
            zorder=0,
            alpha=0.2,
        )

    fig.suptitle("Average time sequence inside each environments")
    axes.set_xlabel(r"Time [frames]")
    axes.set_ylabel(r"Signal")
    axes.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes.legend(loc="lower left")

    fig.savefig(title, dpi=600)
def plot_sankey(
    title: str,
    labels: np.ndarray,
    n_windows: int,
    tmp_frame_list: list[int],
):
    """
    Plots the Sankey diagram at the desired frames.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.

    n_windows : int
        The number of windows used.

    tmp_frame_list : List[int]
        The list of windows at which we want to plot the Sankey.

    Notes
    -----
    - If there are no assigned window, we still need the "-1" state
    for consistency
    - Requires kaleido.
    """
    n_particles = int(labels.shape[0] / n_windows)
    all_the_labels = np.reshape(labels, (n_particles, n_windows))
    frame_list = np.array(tmp_frame_list)
    unique_labels = np.unique(all_the_labels)
    if -1 not in unique_labels:
        unique_labels = np.insert(unique_labels, 0, -1)
    n_states = unique_labels.size

    source = np.empty((frame_list.size - 1) * n_states**2)
    target = np.empty((frame_list.size - 1) * n_states**2)
    value = np.empty((frame_list.size - 1) * n_states**2)

    count = 0
    tmp_label1 = []
    tmp_label2 = []

    # Loop through the frame_list and calculate the transition matrix
    # for each time window.
    for i, t_0 in enumerate(frame_list[:-1]):
        # Calculate the time jump for the current time window.
        t_jump = frame_list[i + 1] - frame_list[i]

        trans_mat = np.zeros((n_states, n_states))

        # Iterate through the current time window and increment the
        # transition counts in trans_mat. A state with label "l" is
        # stored in row/column "l + 1" (the "-1" state occupies index 0).
        for label in all_the_labels:
            trans_mat[label[t_0] + 1][label[t_0 + t_jump] + 1] += 1

        # Store the source, target, and value for the Sankey diagram
        # based on trans_mat
        for j, row in enumerate(trans_mat):
            for k, elem in enumerate(row):
                source[count] = j + i * n_states
                target[count] = k + (i + 1) * n_states
                value[count] = elem
                count += 1

        # Calculate the starting and ending fractions for each state and
        # store node labels. Bug fix: label "j" lives in row "j + 1" of
        # trans_mat, so index with j + 1 (the old trans_mat[j] wrapped
        # around at j = -1 and attached the fractions to the wrong state).
        for j in range(-1, n_states - 1):
            start_fr = np.sum(trans_mat[j + 1]) / np.sum(trans_mat)
            end_fr = np.sum(trans_mat.T[j + 1]) / np.sum(trans_mat)
            # Bug fix: the left-most column of nodes must be labeled at
            # the first window ("i == 0"); the old "i == -1" test was
            # never true, leaving the node-label array short.
            if i == 0:
                tmp_label1.append(f"State {j}: {start_fr * 100:.2f}%")
            tmp_label2.append(f"State {j}: {end_fr * 100:.2f}%")

    arr_label1 = np.array(tmp_label1)
    arr_label2 = np.array(tmp_label2).flatten()

    # Concatenate the temporary labels to create the final node labels.
    label = np.concatenate((arr_label1, arr_label2))

    # Generate a color palette for the Sankey diagram.
    palette = []
    cmap = plt.get_cmap(COLORMAP, n_states)
    for i in range(cmap.N):
        rgba = cmap(i)
        palette.append(rgb2hex(rgba))

    # Tile the color palette to match the number of frames.
    color = np.tile(palette, frame_list.size)

    # Create dictionaries to define the Sankey diagram nodes and links.
    node = {"label": label, "pad": 30, "thickness": 20, "color": color}
    link = {"source": source, "target": target, "value": value}

    # Create the Sankey diagram using Plotly.
    sankey_data = go.Sankey(link=link, node=node, arrangement="perpendicular")
    fig = go.Figure(sankey_data)

    # Add the title with the time information.
    fig.update_layout(title=f"Frames: {frame_list}")

    fig.write_image(title, scale=5.0)
def plot_time_res_analysis(
    title: str,
    tra: np.ndarray,
):
    """
    Plots the results of clustering at different time resolutions.

    Plots the number of states (including the unclassified points) and
    the fraction of unclassificed data points as a function of the time
    resolution used.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    tra : ndarray of shape (n_windows, 3)
        Contains the number of states and the population of ENV0 at
        every tau_window.
    """
    # Name the three columns of the input for readability.
    tau_values = tra[:, 0]
    n_env = tra[:, 1]
    env0_pop = tra[:, 2]

    fig, ax_states = plt.subplots()

    # Left axis: number of environments vs time resolution.
    ax_states.plot(tau_values, n_env, marker="o")
    ax_states.set_xlabel(r"Time resolution $\Delta t$ [frame]")
    ax_states.set_ylabel(r"# environments", weight="bold", c="#1f77b4")
    ax_states.set_xscale("log")
    ax_states.set_ylim(-0.2, np.max(n_env) + 0.2)
    ax_states.yaxis.set_major_locator(MaxNLocator(integer=True))

    # Right axis: population of the unclassified environment.
    ax_env0 = ax_states.twinx()
    ax_env0.plot(tau_values, env0_pop, marker="o", c="#ff7f0e")
    ax_env0.set_ylabel("Population of env 0", weight="bold", c="#ff7f0e")
    ax_env0.set_ylim(-0.02, 1.02)

    fig.savefig(title, dpi=600)
def plot_pop_fractions(
    title: str,
    list_of_pop: List[List[float]],
):
    """
    Plot, for every time resolution, the populations of the ENVs.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    list_of_pop : List[List[float]]
        For every tau_window, this is the list of the populations of all
        the states (the first one is ENV0). The input lists are left
        unmodified.

    Notes
    -----
    The bottom state is the ENV0.
    """
    # Pad every population list with zeros up to the largest number of
    # states. Bug fix: the old code appended the padding to the caller's
    # lists in place; padded copies are built instead.
    max_num_of_states = max(len(pop_list) for pop_list in list_of_pop)
    pop_array = np.array(
        [
            list(pop_list) + [0.0] * (max_num_of_states - len(pop_list))
            for pop_list in list_of_pop
        ]
    )

    fig, axes = plt.subplots()

    width = 1
    min_tau_w = 2  # smallest time resolution on the x axis
    time = range(min_tau_w, pop_array.shape[0] + min_tau_w)
    bottom = np.zeros(len(pop_array))
    # Stacked bars: each state's population sits on top of the previous
    # ones, so ENV0 (first column) ends up at the bottom.
    for state in pop_array.T:
        axes.bar(time, state, width, bottom=bottom, edgecolor="black")
        bottom += state

    axes.set_xlabel(r"Time resolution $\Delta t$ [frames]")
    axes.set_ylabel(r"Population's fractions")
    axes.set_xscale("log")

    fig.savefig(title, dpi=600)
def plot_medoids_multi(
    title: str,
    tau_window: int,
    input_data: np.ndarray,
    labels: np.ndarray,
):
    """
    Compute and plot the average signal sequence inside each state.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    tau_window : int
        The length of the signal window used.

    input_data : ndarray of shape (n_dims, n_particles, n_frames)
        The input data array.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.

    Notes
    -----
    - If there are no assigned window, we still need the "-1" state
    for consistency
    - Prints the output to file
    """
    # Only 2-dimensional signals can be drawn in the (x, y) plane.
    ndims = input_data.shape[0]
    if ndims != 2:
        print("plot_medoids_multi() does not work with 3D data.")
        return

    list_of_labels = np.unique(labels)
    if -1 not in list_of_labels:
        list_of_labels = np.insert(list_of_labels, 0, -1)

    center_list = []
    env0 = []

    # Rearrange to (n_particles, n_frames, n_dims) and expand the
    # per-window labels to one label per frame.
    reshaped_data = input_data.transpose(1, 2, 0)
    labels = np.repeat(labels, tau_window)
    reshaped_labels = np.reshape(
        labels, (input_data.shape[1], input_data.shape[2])
    )

    # Gather the windows assigned to each label: classified states get a
    # mean sequence, the unclassified ones (-1) are kept verbatim in env0.
    for ref_label in list_of_labels:
        tmp = []
        for i, mol in enumerate(reshaped_labels):
            for window, label in enumerate(mol[::tau_window]):
                if label == ref_label:
                    time_0 = window * tau_window
                    time_1 = (window + 1) * tau_window
                    tmp.append(reshaped_data[i][time_0:time_1])
        if len(tmp) > 0 and ref_label > -1:
            center_list.append(np.mean(tmp, axis=0))
        elif len(tmp) > 0:
            env0 = tmp

    # Dump the per-state average sequences to the working directory.
    center_arr = np.array(center_list)
    np.save(
        "medoid_center.npy",
        center_arr,
    )

    # Color 0 is reserved for the unclassified windows.
    palette = []
    cmap = plt.get_cmap(COLORMAP, list_of_labels.size)
    palette.append(rgb2hex(cmap(0)))
    for i in range(1, cmap.N):
        rgba = cmap(i)
        palette.append(rgb2hex(rgba))

    fig, axes = plt.subplots()
    # Mean (x, y) sequence of each classified state.
    for id_c, center in enumerate(center_list):
        sig_x = center[:, 0]
        sig_y = center[:, 1]
        axes.plot(
            sig_x,
            sig_y,
            label=f"ENV{id_c + 1}",
            marker="o",
            c=palette[id_c + 1],
        )
    # Faint individual traces for the unclassified windows.
    for win in env0:
        axes.plot(
            win.T[0],
            win.T[1],
            lw=0.1,
            c=palette[0],
            zorder=0,
            alpha=0.25,
        )

    fig.suptitle("Average time sequence inside each environments")
    axes.set_xlabel(r"Signal 1")
    axes.set_ylabel(r"Signal 2")
    axes.legend()

    fig.savefig(title, dpi=600)
def _extreme_particle_ids(m_clean: np.ndarray):
    """Return (id_min, id_max): the particles containing the global
    minimum and maximum of the dataset."""
    id_max, id_min = 0, 0
    for idx, mol in enumerate(m_clean):
        if np.max(mol) == np.max(m_clean):
            id_max = idx
        if np.min(mol) == np.min(m_clean):
            id_min = idx
    return id_min, id_max


def _plot_projection(
    axis,
    m_resized: np.ndarray,
    all_the_labels: np.ndarray,
    n_states: int,
    d_0: int,
    d_1: int,
    id_min: int,
    id_max: int,
):
    """Draw on `axis` the (d_0, d_1) projection of the trajectories,
    colored according to the cluster labels."""
    line_w = 0.05
    # Subsample very large datasets to keep the figure manageable.
    step = 5 if m_resized.size > 1000000 else 1
    for i, mol in enumerate(m_resized[::step]):
        axis.plot(
            mol.T[d_0],
            mol.T[d_1],
            c="black",
            lw=line_w,
            rasterized=True,
            zorder=0,
        )
        color_list = all_the_labels[i * step] + 1
        axis.scatter(
            mol.T[d_0],
            mol.T[d_1],
            c=color_list,
            cmap=COLORMAP,
            vmin=0,
            vmax=n_states - 1,
            s=0.5,
            rasterized=True,
        )

    # Always draw the particles holding the global min and max, so the
    # plotted range covers the full data range even when subsampling.
    for part_id in (id_min, id_max):
        color_list = all_the_labels[part_id] + 1
        axis.plot(
            m_resized[part_id].T[d_0],
            m_resized[part_id].T[d_1],
            c="black",
            lw=line_w,
            rasterized=True,
            zorder=0,
        )
        axis.scatter(
            m_resized[part_id].T[d_0],
            m_resized[part_id].T[d_1],
            c=color_list,
            cmap=COLORMAP,
            vmin=0,
            vmax=n_states - 1,
            s=0.5,
            rasterized=True,
        )


def plot_output_multi(
    title: str,
    input_data: np.ndarray,
    state_list: List,
    labels: np.ndarray,
    tau_window: int,
):
    """
    Plot a cumulative figure showing trajectories and identified states.

    For 3D data the three pairwise projections are drawn in a 2x2 grid;
    for 2D data a single panel is drawn. The fitted states are overlaid
    as ellipses.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    input_data : ndarray of shape (n_dims, n_particles, n_frames)
        The input data array.

    state_list : List[StateUni]
        The list of the cluster states.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.

    tau_window : int
        The length of the signal window used.
    """
    # +1 accounts for the unclassified state ("-1"), which maps to
    # color 0 after the "+ 1" shift applied to the labels below.
    n_states = len(state_list) + 1

    m_clean = input_data.transpose(1, 2, 0)
    n_windows = int(m_clean.shape[1] / tau_window)
    tmp_labels = labels.reshape((m_clean.shape[0], n_windows))
    all_the_labels = np.repeat(tmp_labels, tau_window, axis=1)

    # Truncate the trajectories to the frames covered by the labels.
    max_t = all_the_labels.shape[1]
    m_resized = m_clean[:, :max_t:, :]
    # Hoisted out of the per-projection loop (the old code recomputed
    # this scan three times for 3D data).
    id_min, id_max = _extreme_particle_ids(m_clean)

    if m_clean.shape[2] == 3:
        fig, ax = plt.subplots(2, 2, figsize=(6, 6))
        # The three pairwise projections (0,1), (0,2), (1,2) fill three
        # of the four panels; the fourth is switched off below.
        dir0 = [0, 0, 1]
        dir1 = [1, 2, 2]
        ax0 = [0, 0, 1]
        ax1 = [0, 1, 0]
        for k in range(3):
            d_0 = dir0[k]
            d_1 = dir1[k]
            axis = ax[ax0[k]][ax1[k]]
            _plot_projection(
                axis,
                m_resized,
                all_the_labels,
                n_states,
                d_0,
                d_1,
                id_min,
                id_max,
            )
            # Plot the Gaussian distributions of states (first panel only)
            if k == 0:
                for state in state_list:
                    att = state.get_attributes()
                    ellipse = Ellipse(
                        tuple(att["mean"]),
                        att["axis"][d_0],
                        att["axis"][d_1],
                        color="black",
                        fill=False,
                    )
                    axis.add_patch(ellipse)
            # Set plot titles and axis labels
            axis.set_xlabel(f"Signal {d_0}")
            axis.set_ylabel(f"Signal {d_1}")
        ax[1][1].axis("off")
        fig.savefig(title, dpi=600)
        plt.close(fig)
    elif m_clean.shape[2] == 2:
        fig, ax = plt.subplots(figsize=(6, 6))
        _plot_projection(
            ax, m_resized, all_the_labels, n_states, 0, 1, id_min, id_max
        )
        # Plot the Gaussian distributions of states
        for state in state_list:
            att = state.get_attributes()
            ellipse = Ellipse(
                tuple(att["mean"]),
                att["axis"][0],
                att["axis"][1],
                color="black",
                fill=False,
            )
            ax.add_patch(ellipse)
        # Set plot titles and axis labels
        ax.set_xlabel(r"$x$")
        ax.set_ylabel(r"$y$")
        fig.savefig(title, dpi=600)
def plot_one_trj_multi(
    title: str,
    example_id: int,
    tau_window: int,
    input_data: np.ndarray,
    labels: np.ndarray,
):
    """Plots the colored trajectory of an example particle.

    The (x, y) trajectory of the selected particle is drawn with each
    point colored according to the cluster label of its window.

    Parameters
    ----------
    title : str
        The path of the .png file the figure will be saved as.

    example_id : int
        The ID of the selected particle.

    tau_window : int
        The length of the signal window used.

    input_data : ndarray of shape (n_dims, n_particles, n_frames)
        The input data array.

    labels : ndarray of shape (n_particles * n_windows,)
        The output of the clustering algorithm.
    """
    m_clean = input_data.transpose(1, 2, 0)
    n_windows = int(m_clean.shape[1] / tau_window)
    window_labels = labels.reshape((m_clean.shape[0], n_windows))
    frame_labels = np.repeat(window_labels, tau_window, axis=1)

    # Signal of the example particle, truncated to the labeled frames.
    n_labeled = frame_labels.shape[1]
    sig_x = m_clean[example_id].T[0][:n_labeled]
    sig_y = m_clean[example_id].T[1][:n_labeled]

    fig, ax = plt.subplots(figsize=(6, 6))

    # One colormap entry per distinct label value.
    lab_values = np.unique(frame_labels)
    lab_min = np.min(lab_values)
    lab_max = np.max(lab_values)
    cmap = plt.get_cmap(COLORMAP, int(lab_max - lab_min + 1))

    ax.plot(sig_x, sig_y, c="black", lw=0.1)
    ax.scatter(
        sig_x,
        sig_y,
        c=frame_labels[example_id],
        cmap=cmap,
        vmin=lab_min,
        vmax=lab_max,
        s=1.0,
        zorder=10,
    )

    # Set plot titles and axis labels
    fig.suptitle(f"Example particle: ID = {example_id}")
    ax.set_xlabel(r"$x$")
    ax.set_ylabel(r"$y$")

    fig.savefig(title, dpi=600)
def color_trj_from_xyz(
    trj_path: str,
    labels: np.ndarray,
    n_particles: int,
    tau_window: int,
):
    """
    Saves a colored .xyz file ('colored_trj.xyz') in the working directory.

    In the input file, the (x, y, z) coordinates of the particles need
    to be stored in the second, third and fourth column respectively.
    The species column of the output holds "label + 1", so that the
    unclassified state (-1) becomes 0.

    Parameters
    ----------
    trj_path : str
        The path to the input .xyz trajectory.

    labels : np.ndarray (n_particles * n_windows,)
        The output of the clustering algorithm.

    n_particles : int
        The number of particles in the system.

    tau_window : int
        The length of the signal windows.

    Raises
    ------
    ValueError
        If `trj_path` does not exist.
    """
    # Guard clause instead of wrapping the whole body in an if/else.
    if not os.path.exists(trj_path):
        raise ValueError(f"{trj_path} not found.")

    with open(trj_path, "r", encoding="utf-8") as in_file:
        tmp = [line.strip().split() for line in in_file]

    # One label per frame: repeat each window's label tau_window times,
    # shifted by +1 so the unclassified state maps to 0.
    tmp_labels = labels.reshape((n_particles, -1))
    all_the_labels = np.repeat(tmp_labels, tau_window, axis=1) + 1
    total_time = int(labels.shape[0] / n_particles) * tau_window

    # Drop trailing frames not covered by the labels. Each frame block
    # is n_particles coordinate lines plus the two header lines.
    nlines = (n_particles + 2) * total_time
    tmp = tmp[:nlines]

    with open("colored_trj.xyz", "w+", encoding="utf-8") as out_file:
        i = 0
        for j in range(total_time):
            # Particle count line, then the .xyz properties header.
            print(tmp[i][0], file=out_file)
            print("Properties=species:S:1:pos:R:3", file=out_file)
            for k in range(n_particles):
                print(
                    all_the_labels[k][j],
                    tmp[i + 2 + k][1],
                    tmp[i + 2 + k][2],
                    tmp[i + 2 + k][3],
                    file=out_file,
                )
            i += n_particles + 2