# Source code for locpix.preprocessing.datastruc

"""Datastruc module.

This module contains definitions of the datastructure the
SMLM dataitem will be parsed as.

"""

import numpy as np
import napari
import matplotlib.pyplot as plt
import polars as pl
import pyarrow.parquet as pq
import ast
import os
import json

# Lookup table of intensity-scaling functions, keyed by the name of the
# interpolation: each entry maps an input (scalar or array) to its scaled value.
_interpolate = dict(
    log2=np.log2,
    log10=np.log10,
    linear=lambda d: d,
)


class item:
    """SMLM datastructure.

    This is the basic datastructure which will contain all the information
    from a point set that is needed.

    Attributes:
        name (string): Contains the name of the item
        df (polars dataframe): Dataframe with the data contained, containing
            columns: 'channel' 'frame' 'x' 'y' 'z'.
            If manual annotation is done an additional column 'gt_label'
            will be present
        dim (int): Dimensions of data
        channels (list): List of ints, representing channels user wants to
            consider in the original data
        channel_label (list of strings): The label for each channel
            i.e. ['egfr', 'ereg', 'unk'] means channel 0 is egfr protein,
            channel 1 is ereg proteins and channel 2 is unknown
        histo (dict): Dictionary of 2D or 3D arrays. Each key corresponds
            to the channel for which the histogram contains the relevant
            binned data, in form [X,Y,Z]
            i.e. histo[1] = histogram of channel 1 localisations.
            Note that if considering an image, need to transpose the
            histogram to follow image conventions.
        histo_edges (tuple of lists; each list contains floats):
            Tuple containing (x_edges, y_edges) or (x_edges, y_edges, z_edges)
            where x/y/z_edges are list of floats, each representing the
            edge of the bin in the original space.
            e.g. ([0,10,20],[13,25,20],[2,3,4])
        histo_mask (numpy array): Array containing integers where each
            should represent a different label of the MANUAL segmentation.
            0 is reserved for background, is of form [X,Y,Z]
        bin_sizes (tuple of floats): Size of bins of the histogram
            e.g. (23.2, 34.5, 21.3)
        gt_label_map (dict): Dictionary with integer keys representing the
            gt labels for each localisation with value being a string,
            representing the real concept e.g. 0:'dog', 1:'cat'
    """

    # Sentinel meaning "argument not supplied".  A sentinel (rather than None)
    # is used because None is itself a value callers pass explicitly, e.g.
    # load_from_parquet passes gt_label_map=None when no map was stored.
    _UNSET = object()

    def __init__(
        self,
        name,
        df,
        dim,
        channels,
        channel_label,
        histo=_UNSET,
        histo_edges=None,
        histo_mask=None,
        bin_sizes=None,
        gt_label_map=_UNSET,
    ):
        """Initialises item.

        Args:
            name (string): Name of the item
            df (polars dataframe): Localisation data
            dim (int): Dimensions of the data
            channels (list): Channels to consider
            channel_label (list of strings): Label for each channel
            histo (dict): Histogram per channel; a fresh empty dict is
                created for every instance when not supplied
            histo_edges (tuple): Edges of the histogram bins
            histo_mask (numpy array): Manual segmentation mask
            bin_sizes (tuple of floats): Size of the histogram bins
            gt_label_map (dict): Map from integer gt label to its meaning;
                a fresh empty dict is created for every instance when not
                supplied

        Raises:
            ValueError: If only one of channels/channel_label is given, or
                if they differ in length
        """
        self.name = name
        self.df = df
        # Each instance gets its own dict: coord_2_histo writes into
        # self.histo, so a shared default would leak histograms between items.
        self.histo = {} if histo is self._UNSET else histo
        self.histo_edges = histo_edges
        self.histo_mask = histo_mask
        self.dim = dim
        self.bin_sizes = bin_sizes
        self.channels = channels
        self.channel_label = channel_label
        self.gt_label_map = {} if gt_label_map is self._UNSET else gt_label_map

        # channel labels and channel choice need to be same in length
        if channels is not None or channel_label is not None:
            if channels is None or channel_label is None:
                raise ValueError(
                    "channels and channel_label must either both be "
                    "specified or both be None"
                )
            if len(channels) != len(channel_label):
                raise ValueError(
                    f"Labels for each channel is length {len(channel_label)}\n"
                    f"Channels is length {len(channels)}\n"
                    f"These must be the same length!"
                )

    def chan_2_label(self, chan):
        """Returns the label associated with the channel specified

        Args:
            chan (int): Integer representing the channel

        Returns:
            The entry of channel_label at index chan
        """
        return self.channel_label[chan]

    def label_2_chan(self, label):
        """Returns the channel associated with the channel label specified

        Args:
            label (string): String representing the label you want to find
                the channel for

        Returns:
            int: Index of the label within channel_label

        Raises:
            ValueError: If the label is not one of the channel labels
        """
        try:
            return self.channel_label.index(label)
        except ValueError:
            raise ValueError(
                "The label specified is not present in the channel labels"
            ) from None

    def coord_2_histo(
        self,
        histo_size,
    ):
        """Converts localisations into histogram of desired size.

        Sets self.bin_sizes and self.histo_edges, fills self.histo with one
        histogram per channel in self.channels and adds the pixel index
        columns (x_pixel, y_pixel[, z_pixel]) to self.df.

        Args:
            histo_size (tuple): Tuple representing number of
                bins/pixels in x,y,z

        Raises:
            ValueError: If self.dim is not 2 or 3, or if histo_size does not
                have one entry per dimension
        """
        if self.dim not in (2, 3):
            raise ValueError(f"dim must be 2 or 3, not {self.dim}")
        if len(histo_size) != self.dim:
            raise ValueError(
                f"histo_size must have {self.dim} entries, "
                f"but has {len(histo_size)}"
            )
        axes = ("x", "y", "z")[: self.dim]

        # get max and min x/y/(z) values
        df_max = self.df.max()
        df_min = self.df.min()

        # NOTE: if instead a desired bin size (e.g. 50nm) were wanted, the
        # number of bins would be int((max - min) / bin_size) per axis; this
        # would need checking against np.digitize and the handling of the
        # last localisation below.
        bin_sizes = []
        edges = []
        for axis, n_bins in zip(axes, histo_size):
            lower = df_min[axis][0]
            upper = df_max[axis][0]
            # need to increase bin size very marginally to include
            # last localisation
            width = (upper - lower) / n_bins * 1.001
            bin_sizes.append(width)
            # location of edges of histogram in the original space
            edges.append([lower + width * i for i in range(n_bins + 1)])

        self.bin_sizes = tuple(bin_sizes)
        self.histo_edges = tuple(edges)

        print("-- Bin sizes -- ")
        print(self.bin_sizes)

        # histogram for every channel, assigned to self.histo (dict)
        for chan in self.channels:
            chan_df = self.df.filter(pl.col("channel") == chan)
            # (N, D) where D is self.dim and N is number of localisations
            sample = np.column_stack([chan_df[axis].to_numpy() for axis in axes])
            self.histo[chan], _ = np.histogramdd(sample, bins=self.histo_edges)

        plt.close()

        # work out pixel for each localisations
        self._coord_2_pixel()

    def _coord_2_pixel(self):
        """Calculate the pixels corresponding to each localisation.

        Adds integer columns x_pixel, y_pixel (and z_pixel in 3D) to
        self.df, computed from self.bin_sizes and the minimum coordinate
        along each axis.
        """
        # drop pixel columns if already present
        for col in ("x_pixel", "y_pixel", "z_pixel"):
            if col in self.df.columns:
                self.df = self.df.drop(col)

        df_min = self.df.min()
        axes = ("x", "y", "z")[: self.dim]

        # (coordinate - minimum) / bin width, truncated to an integer index.
        # Values are non-negative so truncation is a floor.  Bin widths were
        # enlarged marginally in coord_2_histo so that the last localisation
        # falls inside the final pixel rather than outside the histogram.
        self.df = self.df.with_columns(
            [
                ((pl.col(axis) - df_min[axis][0]) / width)
                .cast(pl.Int64, strict=True)
                .alias(f"{axis}_pixel")
                for axis, width in zip(axes, self.bin_sizes)
            ]
        )

    def manual_segment(
        self,
        cmap=("green", "red", "blue", "bop purple"),
        relabel=False,
        markers_loc=None,
    ):
        """Manually segment the image (histogram.T).

        Opens a napari viewer on the histograms, stores the drawn labels in
        self.histo_mask and adds a 'gt_label' column to the dataframe.
        0 should be reserved for background. The viewer is reopened until a
        'Labels' layer has been added.

        Args:
            cmap (list of strings): Colourmaps napari uses to
                plot the histograms
            relabel (bool): If True the item is allowed to already have
                gt labels, which are rendered into the viewer as the
                'Labels' layer
            markers_loc (string): If given, markers are loaded from this
                file with np.load and added as the 'Points' layer

        Returns:
            markers (list): Coordinates of markers if added, otherwise None

        Raises:
            ValueError: If the item already has gt labels and relabel is
                False
            NotImplementedError: If the data is not 2D
        """
        # if already has gt label raise error
        if "gt_label" in self.df.columns and not relabel:
            raise ValueError(
                "Manual segment cannot be called on a file which "
                "already has gt labels in it"
            )

        # only 2D is implemented; fail fast instead of looping forever
        if self.dim != 2:
            raise NotImplementedError(
                "Manual segmentation is only implemented for 2D data"
            )

        while True:
            viewer = self._open_segmentation_viewer(cmap, relabel, markers_loc)

            try:
                points = viewer.layers["Points"].data
                markers = [tuple(int(float(j)) for j in i) for i in points]
            except KeyError:
                print("You must add points!")
                markers = None

            # histogram mask should be assigned to GUI output
            try:
                self.histo_mask = viewer.layers["Labels"].data.T
                break
            except KeyError:
                print("You must add labels!")

        # segment the coordinates
        self._manual_seg_pixel_2_coord()

        return markers

    def _open_segmentation_viewer(self, cmap, relabel, markers_loc):
        """Show the 2D histograms in napari and return the viewer.

        Calls napari.run() before returning, so the returned viewer holds
        whatever layers were present when napari.run() returned.

        Args:
            cmap (list of strings): Colourmap per channel (only used when
                there is more than one channel)
            relabel (bool): If True add the existing segmentation as the
                'Labels' layer
            markers_loc (string): Optional location of saved markers

        Returns:
            viewer: The napari viewer
        """
        first = self.channels[0]
        # note image shape when plotted: [x, y]
        if len(self.channels) != 1:
            # overlay all channels: first channel creates the viewer,
            # remaining channels are added to it
            viewer = napari.view_image(
                self.histo[first].T,
                name=f"Channel {first}/{self.chan_2_label(first)}",
                rgb=False,
                blending="additive",
                colormap=cmap[0],
                gamma=2,
                contrast_limits=[0, 30],
            )
            for index, chan in enumerate(self.channels[1:], start=1):
                viewer.add_image(
                    self.histo[chan].T,
                    name=f"Channel {chan}/{self.chan_2_label(chan)}",
                    rgb=False,
                    blending="additive",
                    colormap=cmap[index],
                    gamma=2,
                    contrast_limits=[0, 30],
                )
        else:
            # only one channel
            viewer = napari.view_image(
                self.histo[first].T,
                name=f"Channel {first}/{self.chan_2_label(first)}",
                rgb=False,
                gamma=2,
                contrast_limits=[0, 30],
            )

        # add labels if present
        if relabel:
            # note this has to be called after coord_2_histo to be in the
            # correct shape
            viewer.add_labels(self.render_seg().T, name="Labels")

        if markers_loc is not None:
            # NOTE(review): allow_pickle=True can execute arbitrary code when
            # loading; only use with trusted marker files.
            markers = np.load(markers_loc, allow_pickle=True)
            # presumably guards against a file in which None was saved
            # (no markers were added) -- TODO confirm
            if markers.any() is not None:
                viewer.add_points(markers, name="Points")

        napari.run()
        return viewer

    @staticmethod
    def _mask_pixel_columns(histo_mask):
        """Flatten a 2D mask in histogram space into aligned columns.

        Args:
            histo_mask (numpy array): Mask of shape [X, Y]

        Returns:
            tuple: (x_pixel, y_pixel, values) 1D arrays where
                values[k] == histo_mask[x_pixel[k], y_pixel[k]]
        """
        histo_mask = np.asarray(histo_mask)
        # np.indices uses the same (row-major) ordering as ravel, so the
        # coordinates stay aligned with the values for non-square masks too
        x_pixel, y_pixel = np.indices(histo_mask.shape)
        return x_pixel.ravel(), y_pixel.ravel(), histo_mask.ravel()

    def _mask_to_df(self, histo_mask, value_col, value_dtype):
        """Convert a 2D mask in histogram space into a dataframe.

        Args:
            histo_mask (numpy array): Mask of shape [X, Y]
            value_col (string): Name of the column holding the mask values
            value_dtype: Polars dtype of the value column

        Returns:
            polars dataframe: Columns x_pixel, y_pixel and value_col,
                sorted by x_pixel then y_pixel
        """
        x_pixel, y_pixel, values = self._mask_pixel_columns(histo_mask)
        return pl.DataFrame(
            {"x_pixel": x_pixel, "y_pixel": y_pixel, value_col: values},
            schema=[
                ("x_pixel", pl.Int64),
                ("y_pixel", pl.Int64),
                (value_col, value_dtype),
            ],
        ).sort(["x_pixel", "y_pixel"])

    def _manual_seg_pixel_2_coord(self):
        """Get the localisations associated with manual annotation.

        Joins self.histo_mask onto self.df as the 'gt_label' column.
        Each integer should represent a different label, where 0 is
        reserved for background.
        """
        if self.dim == 2:
            mask_df = self._mask_to_df(self.histo_mask, "gt_label", pl.Int64)

            # drop gt_label if already present
            if "gt_label" in self.df.columns:
                self.df = self.df.drop("gt_label")

            # join mask dataframe
            self.df = self.df.join(mask_df, how="inner", on=["x_pixel", "y_pixel"])

        elif self.dim == 3:
            print("segment the 3d coords")

    def mask_pixel_2_coord(self, img_mask: "np.ndarray") -> "pl.DataFrame":
        """For a given mask over the image (value at each pixel normally
        representing a label), return the dataframe with a column giving
        the value for each localisation.

        Note that it is assumed that the img_mask is a mask of the image,
        therefore have to transpose img_mask for it to be in the same
        configuration as the histogram.

        Note we also use this for labels and when the img_mask represents
        probabilities.

        Args:
            img_mask (np.ndarray): Mask over the image - to reiterate, to
                convert this to histogram space need to transpose it

        Returns:
            df (polars dataframe): Original dataframe with additional
                column 'pred_label' with the predicted label. For 3D data
                nothing is returned (not implemented).
        """
        if self.dim == 2:
            # transpose the image mask to histogram space
            mask_df = self._mask_to_df(img_mask.T, "pred_label", pl.Float64)

            # join mask dataframe
            return self.df.join(mask_df, how="inner", on=["x_pixel", "y_pixel"])

        elif self.dim == 3:
            print("segment the 3d coords")

    def _frame_for_saving(self, drop_zero_label, drop_pixel_col):
        """Return self.df with the requested columns/rows removed.

        Args:
            drop_zero_label (bool): If True drop rows whose gt_label is 0
            drop_pixel_col (bool): If True drop the x,y,(z) pixel columns

        Returns:
            polars dataframe: The dataframe to be written out
        """
        save_df = self.df

        if drop_pixel_col:
            # don't want to save x,y,(z) pixel
            save_df = save_df.drop("x_pixel")
            save_df = save_df.drop("y_pixel")
            if self.dim == 3:
                save_df = save_df.drop("z_pixel")

        # drop rows with zero label
        if drop_zero_label:
            save_df = save_df.filter(pl.col("gt_label") != 0)

        return save_df

    def save_df_to_csv(
        self, csv_loc, drop_zero_label=False, drop_pixel_col=True, save_chan_label=True
    ):
        """Save the dataframe to a .csv with option to:
                drop positions which are background
                drop the column containing pixel information
                save additional column with labels for each
                    localisation

        Args:
            csv_loc (String): Save the csv to this location
            drop_zero_label (bool): If True then only non zero
                label positions are saved to csv
            drop_pixel_col (bool): If True then don't save
                the column with x,y,z pixel
            save_chan_label (bool): If True then save an
                additional column for each localisation
                containing the label for each channel

        Returns:
            None
        """
        save_df = self._frame_for_saving(drop_zero_label, drop_pixel_col)

        # save channel label as well
        if save_chan_label:
            # row number of each label is the channel it belongs to
            label_df = pl.DataFrame({"chan_label": self.channel_label}).with_row_count(
                "channel"
            )
            label_df = label_df.with_columns(pl.col("channel").cast(pl.Int64))
            save_df = save_df.join(label_df, on="channel", how="inner")

        # save to location (comma separated, which is the default)
        save_df.write_csv(csv_loc)

    def save_to_parquet(
        self,
        save_folder,
        drop_zero_label=False,
        drop_pixel_col=False,
        gt_label_map=None,
        overwrite=False,
    ):
        """Save the dataframe to a parquet with option to drop positions which
        are background and can drop the column containing pixel information.

        The file is written to <save_folder>/<self.name>.parquet with name,
        dim, channels, channel_label, gt_label_map and bin_sizes stored in
        the schema metadata. Note the gt_label_map stored is the one passed
        as argument, not self.gt_label_map.

        Args:
            save_folder (String): Save the df to this folder
            drop_zero_label (bool): If True then only non zero
                label positions are saved to parquet
            drop_pixel_col (bool): If True then don't save
                the column with x,y,z pixel
            gt_label_map (dict): Dictionary with integer keys
                representing the gt labels for each localisation
                with value being a string, representing the
                real concept e.g. 0:'dog', 1:'cat'
            overwrite (bool): Whether to overwrite

        Returns:
            None

        Raises:
            ValueError: If the file already exists and overwrite is False
        """
        save_df = self._frame_for_saving(drop_zero_label, drop_pixel_col)

        # convert to arrow + add in metadata if doesn't exist
        arrow_table = save_df.to_arrow()
        old_metadata = arrow_table.schema.metadata

        meta_data = {
            "name": self.name,
            "dim": str(self.dim),
            "channels": str(self.channels),
            "channel_label": str(self.channel_label),
            # gt label map is stored as json encoded bytes
            "gt_label_map": json.dumps(gt_label_map).encode("utf-8"),
            "bin_sizes": str(self.bin_sizes),
        }

        # merge existing with new meta data
        merged_metadata = {**meta_data, **(old_metadata or {})}
        arrow_table = arrow_table.replace_schema_metadata(merged_metadata)
        save_loc = os.path.join(
            save_folder, self.name + ".parquet"
        )  # note if change this need to adjust annotate.py
        if os.path.exists(save_loc) and not overwrite:
            raise ValueError(
                f"Cannot overwrite file at {save_loc}. "
                "If you want to overwrite please set overwrite==True"
            )
        pq.write_table(arrow_table, save_loc)

        # To access metadata write
        #   parquet_table = pq.read_table(file_path)
        #   parquet_table.schema.metadata ==> metadata
        # note if accessing keys need
        #   parquet_table.schema.metadata[b'key_name']
        # note that b is bytes

    def load_from_parquet(self, input_file):
        """Loads item saved as .parquet file.

        Re-initialises this item from the dataframe and the schema metadata
        of the file; histo, histo_edges and histo_mask are reset.

        Args:
            input_file (string): Location of the .parquet file to
                load dataitem from
        """
        # read in parquet file
        arrow_table = pq.read_table(input_file)
        metadata = arrow_table.schema.metadata

        # metadata values are bytes
        name = metadata[b"name"].decode("utf-8")
        gt_label_map = json.loads(metadata[b"gt_label_map"].decode("utf-8"))
        if gt_label_map is not None:
            # json keys are strings: convert to int keys for the mapping
            gt_label_map = {int(key): value for key, value in gt_label_map.items()}
        dim = int(metadata[b"dim"])
        channels = ast.literal_eval(metadata[b"channels"].decode("utf-8"))
        channel_label = ast.literal_eval(metadata[b"channel_label"].decode("utf-8"))
        bin_sizes = ast.literal_eval(metadata[b"bin_sizes"].decode("utf-8"))
        df = pl.from_arrow(arrow_table)

        self.__init__(
            name=name,
            df=df,
            dim=dim,
            channels=channels,
            channel_label=channel_label,
            gt_label_map=gt_label_map,
            bin_sizes=bin_sizes,
        )

    def render_histo(self, labels=None):
        """Render the histogram from the .parquet file

        If labels are specified then the histogram is rendered in the
        order of these labels.
        If not specified defaults to rendering in the channels specified
        in order by user e.g. [0,3,1,2]

        Assumes localisations have associated x_pixel and y_pixel already.

        Args:
            labels (list): Order of labels to stack histograms in
                e.g. labels=['egfr','ereg'] means all images will
                be returned with egfr in channel 0 and ereg in
                channel 1

        Returns:
            histo (np.histogram): Histogram of the localisation data
            channel_map (list): List where the first value is the
                channel in the first axis of the histogram, second value
                is the channel in the second axis of the histogram etc.
                e.g. [1,3] : 1st channel is in 1st axis, 3rd channel in 2nd axis
            label_map (list): List where the first value is the
                label in the first axis of the histogram, second value
                is the channel in the second axis of the histogram etc.
                e.g. ['egfr','ereg'] : egfr is in 1st axis, ereg in 2nd axis
        """
        # histogram size is taken from the largest pixel index present
        df_max = self.df.max()
        x_bins = df_max["x_pixel"][0] + 1
        y_bins = df_max["y_pixel"][0] + 1

        if labels is None:
            channels = self.channels
            label_map = [self.chan_2_label(chan) for chan in channels]
        else:
            channels = [self.label_2_chan(label) for label in labels]
            label_map = labels

        histos = []
        for chan in channels:
            df = self.df.filter(pl.col("channel") == chan)

            # number of localisations in each occupied pixel
            histo = np.zeros((x_bins, y_bins))
            df = df.groupby(by=["x_pixel", "y_pixel"]).count()
            x_pixels = df["x_pixel"].to_numpy()
            y_pixels = df["y_pixel"].to_numpy()
            counts = df["count"].to_numpy()
            histo[x_pixels, y_pixels] = counts

            histos.append(histo)

        histo = np.stack(histos)
        channel_map = channels

        return histo, channel_map, label_map

    def render_seg(self):
        """Render the segmentation of the histogram

        Returns:
            histo (numpy array): Integer array of shape [X, Y] holding the
                gt_label of the localisations in each pixel; pixels without
                localisations are 0
        """
        labels = self.df["gt_label"].to_numpy()
        x_pixels = self.df["x_pixel"].to_numpy()
        y_pixels = self.df["y_pixel"].to_numpy()

        histo_width = np.max(x_pixels) + 1
        histo_height = np.max(y_pixels) + 1

        histo = np.zeros((histo_width, histo_height), dtype=np.int64)
        histo[x_pixels, y_pixels] = labels

        return histo