hcitools.plot

This module contains functions for visualizing data and analysis results.

  1"""
  2This module contains functions for visualizing data and analysis results.
  3"""
  4
  5# Imports
  6from rich import print
  7
  8import plotly.graph_objects as go
  9import plotly.subplots as sp
 10import plotly.express as px
 11import plotly.io as pio
 12import matplotlib.pyplot as plt
 13import seaborn as sns
 14import pandas as pd
 15import numpy as np
 16import textwrap
 17import math
 18import io
 19
 20
# Uppercase alphabet; sliced to build plate row labels (A, B, C, ...)
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Hue fractions used by the HLS -> RGB conversion helpers
ONE_THIRD = 1.0 / 3.0
ONE_SIXTH = 1.0 / 6.0
TWO_THIRD = 2.0 / 3.0
 25
 26
def set_renderer(renderer):
    """
    Set the plotly default renderer

    Parameters
    ----------
    `renderer` : str
        Value assigned to `plotly.io.renderers.default`; presumably the name
        of a registered plotly renderer (e.g. 'browser') -- confirm against
        the plotly renderer documentation.
    """

    pio.renderers.default = renderer
 33
 34
 35class colormap:
 36    """
 37    Custom colormaps for plotly figures
 38
 39    These colormaps assume the data has been scaled to between 0 and 1.
 40
 41    Attributes
 42    ----------
 43    `OgBu` : list
 44        Seaborn diverging colorscale from blue (low) to orange (high)
 45    """
 46
 47    OgBu = [[0.00, '#3F7F93'], [0.10, '#6296A6'], [0.20, '#85ADB9'], 
 48            [0.30, '#A9C4CC'], [0.40, '#CDDBE0'], [0.50, '#F2F1F1'], 
 49            [0.60, '#E9D2CD'], [0.70, '#DFB3A7'], [0.80, '#D69483'], 
 50            [0.90, '#CC745D'], [1.00, '#C3553A']]
 51
 53
 54
 55class LabelEncoder:
 56    """
 57    Encode target labels with values between 0 and n_classes-1
 58
 59    Attributes
 60    ----------
 61    `encoder` : dict
 62        dictionary mapping target labels to encodings
 63    `decoder` : dict
 64        dictionary mapping encodings to target labels
 65    `dtype` : np.dtype
 66        dtype of original labels
 67
 68    Methods
 69    -------
 70    `encode(labels)`
 71        Encode a list of target labels
 72    `decode(enc_labels)`
 73        Decode a list of encoded labels
 74    """
 75
 76    def encode(self, labels):
 77        """
 78        Parameters
 79        ----------
 80        `labels` : array_like
 81            list of target labels
 82
 83        Returns
 84        -------
 85        np.array
 86            Encoded labels
 87
 88        Raises
 89        ------
 90        AssertionError
 91            If `enc_labels` is not 1-dimensional
 92        """
 93
 94        labels = np.asarray(labels)
 95        assert labels.ndim == 1, "labels must be 1-dimensional"
 96
 97        # Get unique classes
 98        classes = np.unique(labels)
 99
100        # Create and store maps
101        self.encoder = {l: float(e) for e, l in enumerate(classes)}
102        self.decoder = {float(e): l for e, l in enumerate(classes)}
103        self.dtype = labels.dtype
104
105        return np.fromiter(map(self.encoder.get, labels), dtype=float)
106
107    def decode(self, enc_labels):
108        """
109        Parameters
110        ----------
111        `enc_labels` : array_like
112            list of encoded lavels
113
114        Returns
115        -------
116        np.array
117            Decoded labels
118
119        Raises
120        ------
121        AssertionError
122            If `enc_labels` is not 1-dimensional
123        """
124
125        enc_labels = np.asarray(enc_labels)
126        assert enc_labels.ndim == 1, "enc_labels must be 1-dimensional"
127
128        return np.fromiter(map(self.decoder.get, enc_labels), dtype=self.dtype)
129
130
def _make_plate(data, feature, time_col='timepoint'):
    r"""
    Convert a feature data frame into a plate layout.

    This assumes `data` contains the following columns: `row`, `column` and
    `time_col`. Row values are relabelled with letters by mapping the
    integers 1..r to 'A', 'B', ...; this assumes rows are numbered from 1.

    Parameters
    ----------
    `data` : pd.DataFrame
        a data frame of features including certain metadata columns
    `feature` : str
        feature to populate plate with
    `time_col` : str, optional
        column that defines time points, by default 'timepoint'

    Returns
    -------
    plate : np.array
        $(k \times r \times c)$ array where `k` = timepoint, in order of
        first appearance of each time point in `data`
    rows : list
        Row letters, in reverse order to match the row order of `plate`
    cols : list
        Column labels '1'..'c' as strings
    """

    assert feature in data.columns, "feature must be a column in data"
    assert time_col in data.columns, f"{time_col} must be a column in data"
    assert ('row' in data.columns) and ('column' in data.columns), \
        "'row' and 'column' must be columns in data"

    # Extract time points
    times = data[time_col].unique()

    # Define row and column names
    r = len(np.unique(data['row']))
    c = len(np.unique(data['column']))
    rows = {i: x for i, x in enumerate(LETTERS[:r], 1)}
    cols = [str(i) for i in range(1, c+1)]

    # Create plate
    plate = []
    for T in times:
        # Boolean mask instead of a string-built query, so string time points
        # and column names that are not valid identifiers also work
        plate.append(
            data
                .loc[data[time_col] == T, ['row', 'column', feature]]
                .pivot(index='row', columns='column', values=feature)
                .rename(index=rows)
                .sort_index(ascending=False)
                .values
        )

    return np.array(plate), list(rows.values())[::-1], cols
181
182
def plate_heatmap(data, feature, time_col='timepoint', colorscale=colormap.OgBu):
    """
    Create an interactive plate heatmap; Including an animation for timelapses

    This function assumes that `data` contains the following columns: `row`,
    `column`, `time_col`, `compound`, `conc`. Column names are matched
    case-insensitively; `data` itself is not modified.

    Parameters
    ----------
    `data` : pd.DataFrame
        a data frame of features including certain metadata columns
    `feature` : str
        feature to populate plate with
    `time_col` : str, optional
        column that defines time points, by default 'timepoint'
        All time points must be greater than 0.
    `colorscale` : list, optional
        Plotly-compatible colormap, by default `colormap.OgBu`
        See `colormap.OgBu` for examples.

    Returns
    -------
    go.Figure
        Plotly figure. When more than one time point is present the figure
        has one frame per time point plus play/pause buttons and a slider.
    """

    # Work on a renamed copy so the caller's data frame keeps its columns
    data = data.rename(columns=str.lower)
    feature = feature.lower()
    time_col = time_col.lower()

    assert feature in data.columns, "feature must be a column in data"
    assert time_col in data.columns, f"{time_col} must be a column in data"
    assert data[time_col].min() > 0, "the first time point must be 1 not 0"
    assert ('compound' in data.columns) and ('conc' in data.columns), \
        "'compound' and 'conc' must be columns in data"

    def platemap(x, y, z, cmpd, conc):
        """
        Wrapper for go.Heatmap
        """

        # Insert line breaks in the compound names; non-string entries
        # (e.g. missing values) are shown as empty text
        labels = [
            [name.replace(' ', '<br>') if isinstance(name, str) else ''
             for name in sub]
            for sub in cmpd
        ]

        return go.Heatmap(
            x=x, y=y, z=z,
            colorscale=colorscale,
            text=labels,
            customdata=conc,
            hovertemplate=(
                '<b>Well:</b> %{y}%{x}<br>' +
                '<b>Compound:</b> %{text}<br>' +
                '<b>Concentration:</b> %{customdata}<br>' +
                '<b>Value:</b> %{z:.2f}<extra></extra>'
            ),
            texttemplate='%{text}',
        )

    # Extract time points (same order as the first axis of `plate`)
    times = data[time_col].unique()

    # Reformat data as plate
    plate, rows, cols = _make_plate(data, feature, time_col)

    # Create data for tooltips, taken from the first time point only
    cmpd = _make_plate(data, 'compound', time_col)[0][0,...]
    conc = _make_plate(data, 'conc', time_col)[0][0,...]

    # Create figure and fill in layout
    fig = {'data': [], 'layout': {}, 'frames': []}

    fig['layout'] = go.Layout(
        title=feature,
        title_x=0.5,
        xaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        yaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        hovermode='closest',
        # Play / Pause buttons are only needed for timelapses
        updatemenus=[{
            "buttons": [
                {"args": [None, {"frame": {"duration": 500, "redraw": True},
                                "fromcurrent": True,
                                "transition": {"duration": 300,
                                                "easing": "quadratic-in-out"}}],
                "label": "Play",
                "method": "animate"},
                {"args": [[None], {"frame": {"duration": 0, "redraw": True},
                                "mode": "immediate",
                                "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"}
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }] if len(times) > 1 else None
    )

    # Add the plate of the first time point to data
    fig['data'].append(platemap(cols, rows, plate[0,...], cmpd, conc))

    # Create frames & animations
    if len(times) > 1:
        # Create sliders
        sliders = {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Timepoint:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": []
        }

        # Create frames. `plate` is stacked in the order of `times`, so index
        # it by position rather than by the time point's numeric value
        for k, time in enumerate(times):
            label = str(time)

            # New frame
            fig['frames'].append({
                "data": [platemap(cols, rows, plate[k,...], cmpd, conc)],
                "name": label
            })

            # Corresponding slider step
            sliders['steps'].append({
                "args": [
                    [label],
                    {"frame": {"duration": 300, "redraw": True},
                    "mode": "immediate",
                    "transition": {"duration": 300}}
                ],
                "label": label,
                "method": "animate"
            })

        # Add sliders to figure layout
        fig['layout']['sliders'] = [sliders]

    return go.Figure(fig)
345
346
def pca_comps(proj, exp_var, time_col='timepoint', n_comps=4):
    """
    Plot a scatter grid of PCA components

    This function is written to use the output from `process.dim_reduction`.
    Column names are matched case-insensitively; `proj` is not modified.

    Parameters
    ----------
    `proj` : pd.DataFrame
        Data frame with pca projections, from `process.dim_reduction`.
        Must contain the columns `variable`, `compound`, `conc`, `time_col`
        and the component columns '1' .. str(`n_comps`).
    `exp_var` : array_like
        List of explained variances for each PCA component; shown in the
        axis labels followed by a '%' sign
    `time_col` : str, optional
        Column containing time points, by default 'timepoint'
    `n_comps` : int, optional
        Number of pca components to plot, by default 4

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # Work on a renamed copy so the caller's data frame keeps its columns
    proj = proj.rename(columns=str.lower)
    time_col = time_col.lower()

    assert 'variable' in proj.columns, "variable must be a column in proj"
    assert 'compound' in proj.columns, "compound must be a column in proj"
    assert time_col in proj.columns, f"{time_col} must be a column in proj"
    assert 'conc' in proj.columns, "conc must be a column in proj"

    # Prepare matrix of components as well as variables for plotting
    comp_cols = [str(x+1) for x in range(n_comps)]
    pca = proj.query("variable == 'PCA'").reset_index(drop=True)
    compounds = pca['compound']
    comps = pca[comp_cols]

    # Create labels
    labels = {str(i): f"PC {i+1} ({var:.2f}%)" for i, var in enumerate(exp_var)}

    # Create figure
    fig = px.scatter_matrix(
        comps.values,
        labels=labels,
        dimensions=range(n_comps),
        color=compounds,
        opacity=0.5,
        template='plotly_white'
    )
    fig.update_traces(diagonal_visible=False)
    fig.update_layout(paper_bgcolor='white', plot_bgcolor='white', height=500)

    return fig
403
404
def clusters(data, compound_a, compound_b, method, time_col='timepoint'):
    """
    Create clustering figures that compare 2 compounds.

    This function is written to use the output from `process.dim_reduction`.
    `data` is not modified.

    Parameters
    ----------
    `data` : pd.DataFrame
        Projections to plot. Must contain the columns `compound`,
        `variable`, `conc`, `time_col` and the component columns '1' and '2'.
    `compound_a` : str
        Compound A (red points)
    `compound_b` : str
        Compound B (green points)
    `method` : str
        One of 'PCA', 'tSNE' or 'UMAP' (case-insensitive)
    `time_col` : str, optional
        Column containing time points, by default 'timepoint'

    Returns
    -------
    go.Figure
        Plotly figure with one trace per compound and concentration. Marker
        size encodes concentration and marker color encodes the time point.
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert 'compound' in data.columns, "'compound' must be a column in data"
    assert compound_a in data['compound'].tolist(), \
        "compound_a must be present in data['compound']"
    assert compound_b in data['compound'].tolist(), \
        "compound_b must be present in data['compound']"
    method = method.lower()
    assert method in ['pca', 'tsne', 'umap'], \
        "method must be one of 'PCA', 'tSNE' or 'UMAP'"
    assert time_col in data.columns, f"{time_col} must be a column in data"

    times = data[time_col].unique()
    method = method.upper()

    # Compare the method case-insensitively so that a mixed-case value such
    # as 'tSNE' in `variable` still matches the upper-cased `method`
    is_method = data['variable'].astype(str).str.upper() == method

    def create_traces(compound, colorscale, linecolor, cbar_pos):
        """
        Create traces for a particular compound
        """

        # Subset data for compound (boolean mask, so compound names with
        # quotes or other special characters are handled)
        cmpd = (data
            .loc[(data['compound'] == compound) & is_method]
            .reset_index(drop=True))

        # Encode different sizes for each concentration
        cmpd['conc'] = cmpd['conc'].astype(float)
        concs = sorted(cmpd['conc'].unique())
        if len(concs) > 1:
            conc_map = {c: s*3 for s, c in enumerate(concs, 1)}
        else:
            # Zero or one concentration: use a single fixed marker size
            conc_map = {c: 20 for c in concs}

        # Create traces
        traces = []

        for i, conc in enumerate(concs):
            # Create mask to subset data
            I = (cmpd['conc'] == conc)

            traces.append(
                go.Scatter(
                    x=cmpd.loc[I, '1'],
                    y=cmpd.loc[I, '2'],
                    mode='markers',
                    marker=dict(
                        color=cmpd.loc[I, time_col],
                        size=conc_map[conc],
                        opacity=0.5,
                        colorscale=colorscale,
                        # Only the first trace of a compound carries the
                        # colorbar, and only for timelapse data
                        colorbar=dict(
                            x=cbar_pos,
                            thickness=20,
                            yanchor='middle',
                            len=.7
                        ) if (i == 0) and (len(times) > 1) else None,
                        line=dict(width=1.2, color=linecolor)
                    ),
                    name=str(conc),
                    legendgroup=compound.replace(' ', '').lower(),
                    legendgrouptitle_text=(compound if i == 0 else None)
                )
            )

        return traces

    # Create figure traces
    traces = [*create_traces(compound_a, 'Reds', 'red', -0.25),
            *create_traces(compound_b, 'Greens', 'green', -0.35)]

    # Create layout
    layout = go.Layout(
        legend=dict(tracegroupgap=20, groupclick='toggleitem'),
        template='plotly_white',
        margin=dict(l=20, r=20, t=70, b=40),
        height=500,
        title=f"{compound_a} vs {compound_b}<br>({method} Clusters)",
        title_x=0.5,
        xaxis_title=f'{method} 1',
        yaxis_title=f'{method} 2',
    )

    # Create figure with equally scaled axes
    fig = go.Figure(data=traces, layout=layout)
    fig.update_yaxes(
        scaleanchor='x',
        scaleratio=1
    )

    # Annotate the colorbars
    if len(times) > 1:
        fig.add_annotation(
            xref='paper',
            yref='paper',
            x=-0.33,
            y= 0.92,
            text='Time Point',
            font_size=14,
            showarrow=False
        )

    return fig
539
540
541def _make_grid(items, col_wrap=2):
542    """
543    Split a list of items into a grid for subplots
544
545    Parameters
546    ----------
547    `items` : list
548        List of items for each subplot (e.g., tiles)
549    `col_wrap` : int, optional
550        Number of columns allowed in layout, by default 2
551
552    Returns
553    -------
554    go.Figure
555        Plotly figure
556    """
557
558    def grid_dims(n, col_wrap):
559        """
560        Determine grid dimensions
561        """
562
563        nrows = math.ceil(n / col_wrap)
564        ncols = col_wrap if n > col_wrap else n
565
566        return nrows, ncols
567
568    nrows, ncols = grid_dims(len(items), col_wrap)
569    positions = {
570        x: {'x': (i // col_wrap) + 1, 'y': (i % col_wrap) + 1} 
571        for i, x in enumerate(items)
572    }
573
574    return nrows, ncols, positions
575
576
def _v(m1, m2, hue):
    """
    Piecewise-linear interpolation between `m1` and `m2` as a function of
    `hue` (wrapped into [0, 1)); used per channel by the HLS -> RGB conversion
    """

    wrapped = hue % 1.0

    if wrapped < ONE_SIXTH:
        # Rising edge
        result = m1 + (m2-m1)*wrapped*6.0
    elif wrapped < 0.5:
        # Plateau at the upper value
        result = m2
    elif wrapped < TWO_THIRD:
        # Falling edge
        result = m1 + (m2-m1)*(TWO_THIRD-wrapped)*6.0
    else:
        # Plateau at the lower value
        result = m1

    return result
586
587
588def _hls_to_rgb(h, l, s):
589    """
590    Convert HLS (Hue, Luminance, Saturation) to RGB
591    """
592
593    if s == 0.0:
594        return l, l, l
595    if l <= 0.5:
596        m2 = l * (1.0+s)
597    else:
598        m2 = l+s-(l*s)
599    m1 = 2.0*l - m2
600    return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD))
601
602
def _get_colors(n):
    """
    Generate n visually distinct colors.

    Hues are spaced evenly around the color wheel; lightness and saturation
    are drawn at random (via `np.random`) from narrow ranges, so the result
    differs between calls.

    This is taken from [this](https://stackoverflow.com/a/9701141) stack
    overflow post.

    Parameters
    ----------
    `n` : int
        Number of colors to generate; 0 yields an empty list

    Returns
    -------
    list
        `n` hex color strings of the form '#rrggbb'
    """

    colors = []
    # Integer stepping guarantees exactly n hues (a float-step np.arange may
    # not) and avoids a division by zero for n == 0
    for k in range(n):
        hue = k / n
        lightness = (50 + np.random.rand() * 10)/100
        saturation = (90 + np.random.rand() * 10)/100

        r, g, b = _hls_to_rgb(hue, lightness, saturation)
        colors.append( "#%02x%02x%02x" % (int(r*255), int(g*255), int(b*255)) )

    return colors
621
622
def distplot(data, features, group_col, tooltips=None, kind='box', col_wrap=2,
             title_len=30):
    """
    Create boxplots showing the distribution of features for different groups.

    This generates a figure with as many subplots as there are features.
    `data` must contain a `well` column, which is shown in the hover text.

    Parameters
    ----------
    `data` : pd.DataFrame
        Data frame to plot
    `features` : list
        List of features to visualize
    `group_col` : str
        `data` column that contains groups of interest
    `tooltips` : dict, optional
        Dictionary that defines annotation tooltips, by default None
        Keys = Tooltip Name;
        Values = Corresponding column in `data`
    `kind` : str, optional
        Type of plot to generate; one of 'box', 'bar', by default 'box'
    `col_wrap` : int, optional
        Number of columns allowed in layout, by default 2
    `title_len` : int, optional
        Wrap length for subplot titles, by default 30

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If an input fails validation
    NotImplementedError
        When `kind != 'box'`
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert kind in ['box', 'bar'], "kind must be one of 'box', 'bar'."
    assert group_col in data.columns, "group_col must be a column in data"
    assert 'well' in data.columns, "'well' must be a column in data"
    for f in features:
        assert f in data.columns, "features must contain columns from data"
    if tooltips is not None:
        for col in tooltips.values():
            assert col in data.columns, \
                "Values of tooltips must be columns of data"

    if kind == 'bar':
        raise NotImplementedError("Can't do that yet. Working on it.")
    elif kind != 'box':
        raise NotImplementedError("Can't do that yet.")

    # Determine grid dimensions & positions
    nrows, ncols, positions = _make_grid(features, col_wrap=col_wrap)

    # Wrap text for subplot titles
    titles = ['<br>'.join(textwrap.wrap(x, title_len)) for x in features]

    # Get list of groups & colors
    groups = data[group_col].unique()
    colors = _get_colors(len(groups))

    # Subset the data and build the hover text once per group. The text must
    # come from the group's own rows so it lines up with the plotted points.
    per_group = []
    for grp, color in zip(groups, colors):
        _data = data.loc[data[group_col] == grp]
        text = 'well: ' + _data['well'].astype(str) + '<br>'
        if tooltips is not None:
            for name, col in tooltips.items():
                text += f'{name}: ' + _data[col].astype(str) + '<br>'
        per_group.append((str(grp), color, _data, text.tolist()))

    # Create figure
    fig = sp.make_subplots(rows=nrows, cols=ncols, subplot_titles=titles)

    # Add traces
    for feature, pos in positions.items():
        # Legend entries are only shown once, for the first subplot
        first_subplot = (pos['x'] == 1) and (pos['y'] == 1)

        for label, color, _data, text in per_group:
            fig.add_trace(
                go.Box(
                    y=_data[feature],
                    name=label,
                    text=text,
                    hovertemplate='%{text}',
                    legendgroup=label,
                    marker_color=color,
                    showlegend=first_subplot
                ),
                row=pos['x'],
                col=pos['y']
            )

    fig.update_xaxes(showticklabels=False)
    fig.update_layout(template='plotly_white')
    fig.update_annotations(font=dict(family="Helvetica", size=14))

    return fig
717
718
def textplot(text):
    """
    Build an otherwise empty figure that shows a text message. Intended as a
    placeholder where an actual figure is not available.

    Parameters
    ----------
    `text` : str
        Message to display in figure

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # A single text-only point carries the message
    message = go.Scatter(
        x=[0],
        y=[0],
        text=[text],
        textposition='top center',
        textfont_size=16,
        mode='text',
        hoverinfo='skip',
    )

    fig = go.Figure(message)
    fig.update_layout(template='simple_white', height=300)

    # Hide both axes and disable zooming on them
    axis_style = dict(visible=False, fixedrange=True)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig
746
747
def gifify(fig, file, frame_title='Frame', fps=30) -> None:
    """
    Export a plotly animation as a gif

    Every animation frame of `fig` is rendered to an image (using kaleido)
    and shown for one second of the resulting clip.

    Parameters
    ----------
    `fig` : go.Figure
        Plotly figure; must contain at least one animation frame
    `file` : str
        Path to file where figure gif will be stored
    `frame_title` : str, optional
        Title that describes each frame, by default 'Frame'
    `fps` : int, optional
        Frame rate, by default 30

    Raises
    ------
    AssertionError
        If `fig` is not a plotly figure, has no frames, or `file` does not
        end with '.gif'
    """

    assert isinstance(fig, go.Figure), \
        "This only works for plotly figures"
    assert file.endswith('.gif'), "file must be a .gif"
    assert len(fig['frames']) > 0, "fig must contain animation frames"

    import moviepy.editor as mpy
    from PIL import Image

    def fig2array(fig):
        """
        Convert a plotly figure to a numpy array
        """

        raw = fig.to_image(format='jpg', engine='kaleido')
        buffer = io.BytesIO(raw)
        img = Image.open(buffer)

        # moviepy expects frames as (height, width, 3) arrays, not PIL images
        return np.asarray(img.convert('RGB'))

    # Remove sliders and buttons from figure layout
    exclude = ['updatemenus', 'sliders']
    layout = fig.to_dict()['layout']
    layout = {k: v for k, v in layout.items() if k not in exclude}

    # Create list to store frames (as images)
    frames = []
    for i, frame in enumerate(fig['frames']):
        _fig = go.Figure(data=frame['data'], layout=layout)
        _fig.update_layout(title=f"{frame_title} {i+1}", title_x=0.5)
        frames.append( fig2array(_fig) )

    # Create animation: the frame shown at time t (in seconds) is frame
    # int(t), clamped so t == duration cannot index past the end
    last = len(frames) - 1
    make_frame = lambda t: frames[min(int(t), last)]
    anim = mpy.VideoClip(make_frame, duration=len(frames))
    anim.write_gif(file, fps=fps, logger=None)
    print("Done :thumbsup:")
799
800
def _encode_groups(values, colors):
    """
    Encode group labels as integer codes and build the matching colorscale

    Parameters
    ----------
    values : list
        Group label of every row (or column)
    colors : dict
        Keys = groups;  Values = colors;

    Returns
    -------
    codes : np.array
        Integer code per value; codes follow the sorted unique labels
    colorscale : list
        Plotly colorscale with one `[position, color]` entry per label
    """

    classes, codes = np.unique(np.asarray(values), return_inverse=True)
    n = len(classes)

    if n == 1:
        # A single group cannot be normalised to [0, 1]; use a flat scale
        color = colors[classes[0]]
        colorscale = [[0.0, color], [1.0, color]]
    else:
        # Position k/(n-1) is where code k lands once plotly normalises z
        colorscale = [[k / (n - 1), colors[c]] for k, c in enumerate(classes)]

    return codes.reshape(-1), colorscale


def heatmap(data, col_groups=None, col_colors=None, col_group_names=None,
            row_groups=None, row_colors=None, row_group_names=None,
            clust_cols=True, clust_rows=True, cluster_kws=None):
    """
    Construct an interactive heatmap

    The heatmap occupies the bottom-right cell of a subplot grid; optional
    color strips annotating row groups are drawn to its left and strips
    annotating column groups above it.

    Parameters
    ----------
    data : pd.DataFrame
        Data to plot
    {row, col}_groups : dict
        Dictionary assigning groups to rows or columns.
        Keys should be the index or columns of data.
        Values should be a list of groups.
    {row, col}_group_names : list
        Names for each of the row/col groups
        Should be the same length as the lists in {row, col}_groups
    {row, col}_colors : dict
        Dictionary defining colors for each group.
        Keys = groups;  Values = colors;
        Required when the corresponding groups and group names are given.
    clust_{rows, cols} : bool
        Should row and/or column clustering be performed
    cluster_kws : dict
        kwargs for sns.clustermap

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # Check inputs
    ## col_groups and row_groups: each value should be of the same length
    if col_group_names is None:
        col_group_names = []
    if row_group_names is None:
        row_group_names = []
    if cluster_kws is None:
        cluster_kws = {}
    if row_groups is not None and len(row_group_names) > 0:
        assert row_colors is not None, \
            "row_colors must be given with row_groups"
    if col_groups is not None and len(col_group_names) > 0:
        assert col_colors is not None, \
            "col_colors must be given with col_groups"

    # Determine the size of the subplot grid
    n_col_grps = len(col_group_names)
    n_row_grps = len(row_group_names)
    I, J = n_col_grps+1, n_row_grps+1

    # Define column widths and row heights; every annotation strip gets a
    # fixed share and the heatmap takes what is left
    col_widths = [0.03 for _ in range(J-1)] + [1-(J-1)*0.03]
    row_heights = [0.07 for _ in range(I-1)] + [1-(I-1)*0.07]

    # Create subplot grid
    fig = sp.make_subplots(
        rows=I,
        cols=J,
        column_widths=col_widths,
        row_heights=row_heights,
        vertical_spacing=0.01,
        horizontal_spacing=0.01,
        shared_xaxes=True,
        shared_yaxes=True
    )

    # Perform clustering and extract clustered data frame from seaborn clustermap
    if clust_cols or clust_rows:
        data = sns.clustermap(data, row_cluster=clust_rows, col_cluster=clust_cols,
                            **cluster_kws).data2d
        # Only the reordered data is needed, not the matplotlib figure
        plt.close()

    # Plot the heatmap and adjust axes
    fig.add_trace(
        go.Heatmap(
            z=data,
            x=data.columns,
            y=data.index.astype(str),
            colorscale='RdBu_r',
            hovertemplate='<b>Sample:</b> %{y}<br>'+
                        '<b>Feature:</b> %{x}<br>'+
                        '<b>Value:</b> %{z}'
                        '<extra></extra>'
        ),
        row=I, col=J
    )
    fig.update_yaxes(row=I, col=J, showticklabels=False, autorange='reversed')
    fig.update_xaxes(row=I, col=J, showticklabels=True, tickangle=270)
    fig.update_traces(row=I, col=J, colorbar_len=0.7)

    # Add row colors
    if row_groups is not None:
        for j, grp in enumerate(row_group_names):
            # Create row data
            row_data = [row_groups[r][j] for r in data.index]

            # Encode row data numerically so that heatmap can be plotted
            Z, colorscale = _encode_groups(row_data, row_colors)

            fig.add_trace(
                go.Heatmap(
                    z=pd.DataFrame(Z),
                    y=data.index.astype(str),
                    x=[grp],
                    text=pd.DataFrame(row_data),
                    colorscale=colorscale,
                    hovertemplate='<b>Sample:</b> %{y}<br>'+
                                f'<b>{grp}:</b> %{{text}}'+
                                '<extra></extra>',
                    showscale=False
                ),
                row=I, col=j+1
            )
            fig.update_yaxes(row=I, col=j+1, showticklabels=False,
                             autorange='reversed')
            fig.update_xaxes(row=I, col=j+1, showticklabels=True,
                             tickangle=270,
                             tickfont={'size': 15, 'family': 'Arial'})

    # Add column colors
    if col_groups is not None:
        for i, grp in enumerate(col_group_names):
            # Create column data
            col_data = [col_groups[r][i] for r in data.columns]

            # Encode col data numerically so that heatmap can be plotted
            Z, colorscale = _encode_groups(col_data, col_colors)

            fig.add_trace(
                go.Heatmap(
                    z=pd.DataFrame(Z).T,
                    y=[grp],
                    x=data.columns,
                    text=pd.DataFrame(col_data).T,
                    colorscale=colorscale,
                    hovertemplate='<b>Feature:</b> %{x}<br>'+
                                f'<b>{grp}:</b> %{{text}}'+
                                '<extra></extra>',
                    showscale=False
                ),
                row=i+1, col=J
            )
            fig.update_yaxes(row=i+1, col=J, showticklabels=True,
                             autorange='reversed', side='right',
                             tickfont={'size': 15, 'family': 'Arial'})
            fig.update_xaxes(row=i+1, col=J, showticklabels=False)

    # BUG: Clustering Rows makes some data disappear;
    #      looks like some aggregation is happening
    # NOTE(review): possibly duplicate index labels collapsing onto a single
    #      categorical y position -- unconfirmed, needs checking

    return fig
def set_renderer(renderer):
def set_renderer(renderer):
    """
    Choose the renderer plotly uses by default

    Parameters
    ----------
    `renderer` : object
        Value assigned, unchanged, to `pio.renderers.default`
    """

    renderers = pio.renderers
    renderers.default = renderer

Set the plotly default renderer

class colormap:
class colormap:
    """
    Custom colormaps for plotly figures

    Every colormap is a list of `[position, hex color]` pairs whose positions
    run from 0 to 1, so the data is assumed to be scaled to between 0 and 1.

    Attributes
    ----------
    `OgBu` : list
        Seaborn diverging colorscale from blue (low) to orange (high)
    """

    OgBu = [
        [0.00, '#3F7F93'],
        [0.10, '#6296A6'],
        [0.20, '#85ADB9'],
        [0.30, '#A9C4CC'],
        [0.40, '#CDDBE0'],
        [0.50, '#F2F1F1'],
        [0.60, '#E9D2CD'],
        [0.70, '#DFB3A7'],
        [0.80, '#D69483'],
        [0.90, '#CC745D'],
        [1.00, '#C3553A'],
    ]

Custom colormaps for plotly figures

These colormaps assume the data has been scaled to between 0 and 1.

Attributes
  • OgBu (list): Seaborn diverging colorscale from blue (low) to orange (high)
colormap()
OgBu = [[0.0, '#3F7F93'], [0.1, '#6296A6'], [0.2, '#85ADB9'], [0.3, '#A9C4CC'], [0.4, '#CDDBE0'], [0.5, '#F2F1F1'], [0.6, '#E9D2CD'], [0.7, '#DFB3A7'], [0.8, '#D69483'], [0.9, '#CC745D'], [1.0, '#C3553A']]
class LabelEncoder:
 56class LabelEncoder:
 57    """
 58    Encode target labels with values between 0 and n_classes-1
 59
 60    Attributes
 61    ----------
 62    `encoder` : dict
 63        dictionary mapping target labels to encodings
 64    `decoder` : dict
 65        dictionary mapping encodings to target labels
 66    `dtype` : np.dtype
 67        dtype of original labels
 68
 69    Methods
 70    -------
 71    `encode(labels)`
 72        Encode a list of target labels
 73    `decode(enc_labels)`
 74        Decode a list of encoded labels
 75    """
 76
 77    def encode(self, labels):
 78        """
 79        Parameters
 80        ----------
 81        `labels` : array_like
 82            list of target labels
 83
 84        Returns
 85        -------
 86        np.array
 87            Encoded labels
 88
 89        Raises
 90        ------
 91        AssertionError
 92            If `enc_labels` is not 1-dimensional
 93        """
 94
 95        labels = np.asarray(labels)
 96        assert labels.ndim == 1, "labels must be 1-dimensional"
 97
 98        # Get unique classes
 99        classes = np.unique(labels)
100
101        # Create and store maps
102        self.encoder = {l: float(e) for e, l in enumerate(classes)}
103        self.decoder = {float(e): l for e, l in enumerate(classes)}
104        self.dtype = labels.dtype
105
106        return np.fromiter(map(self.encoder.get, labels), dtype=float)
107
108    def decode(self, enc_labels):
109        """
110        Parameters
111        ----------
112        `enc_labels` : array_like
113            list of encoded lavels
114
115        Returns
116        -------
117        np.array
118            Decoded labels
119
120        Raises
121        ------
122        AssertionError
123            If `enc_labels` is not 1-dimensional
124        """
125
126        enc_labels = np.asarray(enc_labels)
127        assert enc_labels.ndim == 1, "enc_labels must be 1-dimensional"
128
129        return np.fromiter(map(self.decoder.get, enc_labels), dtype=self.dtype)

Encode target labels with values between 0 and n_classes-1

Attributes
  • encoder (dict): dictionary mapping target labels to encodings
  • decoder (dict): dictionary mapping encodings to target labels
  • dtype (np.dtype): dtype of original labels
Methods

encode(labels): Encode a list of target labels. decode(enc_labels): Decode a list of encoded labels.

LabelEncoder()
def encode(self, labels):
 77    def encode(self, labels):
 78        """
 79        Parameters
 80        ----------
 81        `labels` : array_like
 82            list of target labels
 83
 84        Returns
 85        -------
 86        np.array
 87            Encoded labels
 88
 89        Raises
 90        ------
 91        AssertionError
 92            If `enc_labels` is not 1-dimensional
 93        """
 94
 95        labels = np.asarray(labels)
 96        assert labels.ndim == 1, "labels must be 1-dimensional"
 97
 98        # Get unique classes
 99        classes = np.unique(labels)
100
101        # Create and store maps
102        self.encoder = {l: float(e) for e, l in enumerate(classes)}
103        self.decoder = {float(e): l for e, l in enumerate(classes)}
104        self.dtype = labels.dtype
105
106        return np.fromiter(map(self.encoder.get, labels), dtype=float)
Parameters
  • labels (array_like): list of target labels
Returns
  • np.array: Encoded labels
Raises
  • AssertionError: If labels is not 1-dimensional
def decode(self, enc_labels):
108    def decode(self, enc_labels):
109        """
110        Parameters
111        ----------
112        `enc_labels` : array_like
113            list of encoded lavels
114
115        Returns
116        -------
117        np.array
118            Decoded labels
119
120        Raises
121        ------
122        AssertionError
123            If `enc_labels` is not 1-dimensional
124        """
125
126        enc_labels = np.asarray(enc_labels)
127        assert enc_labels.ndim == 1, "enc_labels must be 1-dimensional"
128
129        return np.fromiter(map(self.decoder.get, enc_labels), dtype=self.dtype)
Parameters
  • enc_labels (array_like): list of encoded labels
Returns
  • np.array: Decoded labels
Raises
  • AssertionError: If enc_labels is not 1-dimensional
def plate_heatmap( data, feature, time_col='timepoint', colorscale=[[0.0, '#3F7F93'], [0.1, '#6296A6'], [0.2, '#85ADB9'], [0.3, '#A9C4CC'], [0.4, '#CDDBE0'], [0.5, '#F2F1F1'], [0.6, '#E9D2CD'], [0.7, '#DFB3A7'], [0.8, '#D69483'], [0.9, '#CC745D'], [1.0, '#C3553A']]):
def plate_heatmap(data, feature, time_col='timepoint', colorscale=colormap.OgBu):
    """
    Create an interactive plate heatmap; Including an animation for timelapses

    This function assumes that `data` contains the following columns: `row`, 
    `column`, `time_col`, `compound`, `conc`. Column names, `feature` and
    `time_col` are matched case-insensitively (everything is lower-cased on a
    copy; the caller's data frame is left unchanged).

    Parameters
    ----------
    `data` : pd.DataFrame
        a data frame of features including certain metadata columns
    `feature` : str
        feature to populate plate with
    `time_col` : str, optional
        column that defines time points, by default 'timepoint'
        This assumes the first time point is 1.
    `colorscale` : list, optional
        Plotly-compatible colormap, by default `colormap.OgBu`
        See `colormap.OgBu` for examples.

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If a required column is missing or the smallest time point is not
        greater than 0
    """

    # Lower-case the column names on a renamed copy so the caller's frame is
    # not modified as a side effect
    data = data.rename(columns=lambda name: name.lower())
    feature = feature.lower()
    time_col = time_col.lower()

    assert feature in data.columns, "feature must be a column in data"
    assert time_col in data.columns, f"{time_col} must be a column in data"
    assert data[time_col].min() > 0, "the first time point must be 1 not 0"
    assert ('compound' in data.columns) and ('conc' in data.columns), \
        "'compound' and 'conc' must be columns in data"

    def platemap(x, y, z, cmpd, conc):
        """
        Wrapper for go.Heatmap
        """

        # Insert line breaks in the compound names; non-string entries
        # (e.g. missing values) are shown as empty text
        cmpd = [
            [name.replace(' ', '<br>') if isinstance(name, str) else ''
             for name in sub]
            for sub in cmpd
        ]

        return go.Heatmap(
            x=x, y=y, z=z,
            colorscale=colorscale,
            text=cmpd,
            customdata=conc,
            hovertemplate=(
                '<b>Well:</b> %{y}%{x}<br>' +
                '<b>Compound:</b> %{text}<br>' +
                '<b>Concentration:</b> %{customdata}<br>' +
                '<b>Value:</b> %{z:.2f}<extra></extra>'
            ),
            texttemplate='%{text}',
            # textfont_size=8.5
        )

    # Extract time points
    times = data[time_col].unique()
    animated = len(times) > 1

    # Reformat data as plate
    # NOTE(review): _make_plate is defined elsewhere; from its use here it
    # appears to return (plate, rows, cols) with time as plate's first axis
    plate, rows, cols = _make_plate(data, feature, time_col)

    # Create data for tooltips (taken from the first time point only)
    cmpd = _make_plate(data, 'compound', time_col)[0][0,...]
    conc = _make_plate(data, 'conc', time_col)[0][0,...]

    # Create figure and fill in layout
    fig = {'data': [], 'layout': {}, 'frames': []}

    fig['layout'] = go.Layout(
        title=feature,
        title_x=0.5,
        xaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        yaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        hovermode='closest',
        # Play/Pause buttons are only needed when there is something to animate
        updatemenus=[{
            "buttons": [
                {"args": [None, {"frame": {"duration": 500, "redraw": True},
                                "fromcurrent": True, 
                                "transition": {"duration": 300, 
                                                "easing": "quadratic-in-out"}}],
                "label": "Play",
                "method": "animate"},
                {"args": [[None], {"frame": {"duration": 0, "redraw": True},
                                "mode": "immediate", 
                                "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"}
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"   
        }] if animated else None
    )

    # Add the first plate to data
    fig['data'].append(platemap(cols, rows, plate[0,...], cmpd, conc))

    # Create frames & animations
    if animated:
        # Create sliders
        sliders = {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Timepoint:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": []
        }

        # Create frames
        for time in times.astype(str):
            # New frame; time point t is read from plate index t-1, which is
            # why the first time point must be 1
            fig['frames'].append({
                "data": platemap(cols, rows, plate[int(time)-1,...], cmpd, conc),
                "name": str(time)
            })

            # Corresponding slider step
            sliders['steps'].append({
                "args": [ 
                    [time], 
                    {"frame": {"duration": 300, "redraw": True},
                    "mode": "immediate",
                    "transition": {"duration": 300}}
                ],
                "label": time,
                "method": "animate"
            })

        # Add sliders to figure layout
        fig['layout']['sliders'] = [sliders]

    return go.Figure(fig)

Create an interactive plate heatmap; Including an animation for timelapses

This function assumes that data contains the following columns: row, column, time_col, compound, conc.

Parameters
  • data (pd.DataFrame): a data frame of features including certain metadata columns
  • feature (str): feature to populate plate with
  • time_col (str, optional): column that defines time points, by default 'timepoint' This assumes the first time point is 1.
  • colorscale (list, optional): Plotly-compatible colormap, by default colormap.OgBu See colormap.OgBu for examples.
Returns
  • go.Figure: Plotly figure
def pca_comps(proj, exp_var, time_col='timepoint', n_comps=4):
def pca_comps(proj, exp_var, time_col='timepoint', n_comps=4):
    """
    Plot a scatter grid of PCA components

    This function is written to use the output from `process.dim_reduction`.
    Column names and `time_col` are matched case-insensitively (lower-cased on
    a copy; the caller's data frame is left unchanged).

    Parameters
    ----------
    `proj` : pd.DataFrame
        Data frame with pca projections, from `process.dim_reduction`.
        Must contain the columns `variable`, `compound`, `conc`, `time_col`
        and the component columns `'1'` ... `str(n_comps)`.
    `exp_var` : array_like
        List of explained variances for each PCA component
    `time_col` : str, optional
        Column containing time points; must be a column of `proj`;
        by default 'timepoint'
    `n_comps` : int, optional
        Number of pca components to plot, by default 4

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If one of the required columns is missing from `proj`
    """

    # Lower-case the column names on a renamed copy so the caller's frame is
    # not modified as a side effect
    proj = proj.rename(columns=lambda name: name.lower())
    # The columns were lower-cased, so the lookup key has to be as well
    time_col = time_col.lower()

    assert 'variable' in proj.columns, "variable must be a column in proj"
    assert 'compound' in proj.columns, "compound must be a column in proj"
    assert time_col in proj.columns, f"{time_col} must be a column in proj"
    assert 'conc' in proj.columns, "conc must be a column in proj"

    # Prepare matrix of components as well as variables for plotting
    comp_cols = [str(x+1) for x in range(n_comps)]
    pca = proj.query("variable == 'PCA'").reset_index(drop=True)
    compounds = pca['compound']
    comps = pca[comp_cols]

    # Create labels; keys are the 0-based positions of the components in the
    # matrix handed to plotly
    labels = {str(i): f"PC {i+1} ({var:.2f}%)" for i, var in enumerate(exp_var)}

    # Create figure
    fig = px.scatter_matrix(
        comps.values,
        labels=labels,
        dimensions=range(n_comps),
        color=compounds,
        opacity=0.5,
        template='plotly_white'
    )
    fig.update_traces(diagonal_visible=False)
    fig.update_layout(paper_bgcolor='white', plot_bgcolor='white', height=500)

    return fig

Plot a scatter grid of PCA components

This function is written to use the output from process.dim_reduction

Parameters
  • proj (pd.DataFrame): Data frame with pca projections, from process.dim_reduction
  • exp_var (array_like): List of explained variances for each PCA component
  • time_col (str, optional): Column containing time points; must be a column of proj; by default 'timepoint'
  • n_comps (int, optional): Number of pca components to plot, by default 4
Returns
  • go.Figure: Plotly figure
def clusters(data, compound_a, compound_b, method, time_col='timepoint'):
def clusters(data, compound_a, compound_b, method, time_col='timepoint'):
    """
    Create clustering figures that compare 2 compounds.

    This function is written to use the output from `process.dim_reduction`.
    `data` is only read, never modified.

    Parameters
    ----------
    `data` : pd.DataFrame
        Data frame with the columns `compound`, `conc`, `variable`,
        `time_col` and the first two components in columns `'1'` and `'2'`
    `compound_a` : str
        Compound A (red points)
    `compound_b` : str
        Compound B (green points)
    `method` : str
        One of 'PCA', 'tSNE' or 'UMAP' (case-insensitive); selects the rows
        whose `variable` matches, ignoring case
    `time_col` : str, optional
        Column containing time points, by default 'timepoint'

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If `data` is not a data frame, a required column or compound is
        missing, or `method` is not supported
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert 'compound' in data.columns, "'compound' must be a column in data"
    assert compound_a in data['compound'].tolist(), \
        "compound_a must be present in data['compound']"
    assert compound_b in data['compound'].tolist(), \
        "compound_b must be present in data['compound']"
    method = method.lower()
    assert method in ['pca', 'tsne', 'umap'], \
        "method must be one of 'PCA', 'tSNE' or 'UMAP'"
    assert time_col in data.columns, f"{time_col} must be a column in data"

    times = data[time_col].unique()
    method = method.upper()

    # Compare case-insensitively so that e.g. 'tSNE' rows match 'TSNE'
    is_method = data['variable'].astype(str).str.upper() == method

    def create_traces(compound, colorscale, linecolor, cbar_pos):
        """
        Create traces (one per concentration) for a particular compound
        """

        # Subset data for compound; boolean masks (rather than a query
        # string) keep compound names containing quotes working
        cmpd = data[(data['compound'] == compound) & is_method]
        cmpd = cmpd.reset_index(drop=True)
        cmpd = cmpd.assign(conc=cmpd['conc'].astype(float))

        # Encode different sizes for each concentration: growing marker
        # sizes when there are several, a fixed size of 20 for a single one.
        # No concentrations (no rows for this method) yields no traces.
        concs = sorted(cmpd['conc'].unique())
        if len(concs) > 1:
            conc_map = {c: 3*rank for rank, c in enumerate(concs, start=1)}
        else:
            conc_map = {c: 20 for c in concs}

        # Create traces
        traces = []

        for i, conc in enumerate(concs):
            # Create mask to subset data
            I = (cmpd['conc'] == conc)

            traces.append(
                go.Scatter(
                    x=cmpd.loc[I, '1'],
                    y=cmpd.loc[I, '2'],
                    mode='markers',
                    marker=dict(
                        color=cmpd.loc[I, time_col],
                        size=conc_map[conc],
                        opacity=0.5,
                        colorscale=colorscale,
                        # One colorbar per compound, and only when there
                        # is more than one time point to distinguish
                        colorbar=dict(
                            x=cbar_pos, 
                            thickness=20, 
                            yanchor='middle', 
                            len=.7
                        ) if (i == 0) and (len(times) > 1) else None,
                        line=dict(width=1.2, color=linecolor)
                    ),
                    name=str(conc),
                    legendgroup=compound.replace(' ', '').lower(),
                    legendgrouptitle_text=(compound if i == 0 else None)
                )
            )

        return traces

    # Create figure traces
    traces = [*create_traces(compound_a, 'Reds', 'red', -0.25),
            *create_traces(compound_b, 'Greens', 'green', -0.35)]

    # Create layout
    layout = go.Layout(
        legend=dict(tracegroupgap=20, groupclick='toggleitem'),
        template='plotly_white',
        margin=dict(l=20, r=20, t=70, b=40),
        height=500,
        title=f"{compound_a} vs {compound_b}<br>({method} Clusters)",
        title_x=0.5,
        xaxis_title=f'{method} 1',
        yaxis_title=f'{method} 2',
    )

    # Create figure
    fig = go.Figure(data=traces, layout=layout)
    fig.update_yaxes(
        scaleanchor='x',
        scaleratio=1
    )

    # Annotate the colorbars
    if len(times) > 1:
        fig.add_annotation(
            xref='paper', 
            yref='paper', 
            x=-0.33, 
            y= 0.92, 
            text='Time Point',
            font_size=14,
            showarrow=False
        )

    return fig

Create clustering figures that compare 2 compounds.

This function is written to use the output from process.dim_reduction

Parameters
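  • data (pd.DataFrame): Data frame with the columns compound, conc, variable, time_col and the first two components in columns '1' and '2'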
  • compound_a (str): Compound A (red points)
  • compound_b (str): Compound B (green points)
  • method (str): One of 'PCA', 'tSNE' or 'UMAP'
  • time_col (str, optional): Column containing time points, by default 'timepoint'
Returns
  • go.Figure: Plotly figure
def distplot( data, features, group_col, tooltips=None, kind='box', col_wrap=2, title_len=30):
def distplot(data, features, group_col, tooltips=None, kind='box', col_wrap=2,
             title_len=30):
    """
    Create boxplots showing the distribution of features for different groups.

    This generates a figure with as many subplots as there are features.
    `data` must also contain a `well` column, which is used in the hover text.

    Parameters
    ----------
    `data` : pd.DataFrame
        Data frame to plot
    `features` : list
        List of features to visualize
    `group_col` : str
        `data` column that contains groups of interest
    `tooltips` : dict, optional
        Dictionary that defines annotation tooltips, by default None
        Keys = Tooltip Name;  
        Values = Corresponding column in `data`
    `kind` : str, optional
        Type of plot to generate; one of 'box', 'bar', by default 'box'
    `col_wrap` : int, optional
        Number of columns allowed in layout, by default 2
    `title_len` : int, optional
        Wrap length for subplot titles, by default 30

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If `data` is not a data frame, `kind` is unknown, or `group_col`,
        `features` or the values of `tooltips` are not columns of `data`
    NotImplementedError
        When `kind == 'bar'`
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert kind in ['box', 'bar'], "kind must be one of 'box', 'bar'."
    assert group_col in data.columns, "group_col must be a column in data"
    for f in features:
        assert f in data.columns, "features must contain columns from data"
    if tooltips is not None:
        for col in tooltips.values():
            assert col in data.columns, \
                "Values of tooltips must be columns of data"

    # Determine grid dimensions & positions
    nrows, ncols, positions = _make_grid(features, col_wrap=col_wrap)

    # Wrap text for subplot titles
    titles = ['<br>'.join(textwrap.wrap(x, title_len)) for x in features]

    # Get list of groups & colors
    groups = data[group_col].unique()
    colors = _get_colors(len(groups))

    # Create figure
    fig = sp.make_subplots(rows=nrows, cols=ncols, subplot_titles=titles)

    # Add traces
    if kind == 'box':
        # The subset and hover text of a group do not depend on the feature,
        # so build them once per group instead of once per subplot
        per_group = []
        for grp, color in zip(groups, colors):
            # A boolean mask works for any group dtype and column name
            subset = data[data[group_col] == grp]

            # Hover text must come from the same rows as the plotted values,
            # otherwise the tooltips do not line up with the points
            text = 'well: ' + subset['well'] + '<br>'
            if tooltips is not None:
                for name, col in tooltips.items():
                    text += f'{name}: ' + subset[col].astype(str) + '<br>'

            per_group.append((str(grp), color, subset, text.tolist()))

        for feature, pos in positions.items():
            # Only the first subplot contributes legend entries
            first_subplot = bool(pos['x'] == pos['y'] == 1)

            for label, color, subset, text in per_group:
                fig.add_trace(
                    go.Box(
                        y=subset[feature],
                        name=label,
                        text=text,
                        hovertemplate='%{text}',
                        legendgroup=label,
                        marker_color=color,
                        showlegend=first_subplot
                    ),
                    row=pos['x'], 
                    col=pos['y']
                )
    elif kind == 'bar':
        raise NotImplementedError("Can't do that yet. Working on it.")
    else:
        raise NotImplementedError("Can't do that yet.")

    fig.update_xaxes(showticklabels=False)
    fig.update_layout(template='plotly_white')
    fig.update_annotations(font=dict(family="Helvetica", size=14))

    return fig

Create boxplots showing the distribution of features for different groups.

This generates a figure with as many subplots as there are features

Parameters
  • data (pd.DataFrame): Data frame to plot
  • features (list): List of features to visualize
  • group_col (str): data column that contains groups of interest
  • tooltips (dict, optional): Dictionary that defines annotation tooltips, by default None Keys = Tooltip Name;
    Values = Corresponding column in data
  • kind (str, optional): Type of plot to generate; one of 'box', 'bar', by default 'box'
  • col_wrap (int, optional): Number of columns allowed in layout, by default 2
  • title_len (int, optional): Wrap length for subplot titles, by default 30
Returns
  • go.Figure: Plotly figure
Raises
  • NotImplementedError: When kind == 'bar' (only 'box' is implemented so far)
def textplot(text):
def textplot(text):
    """
    Build an otherwise empty figure that only shows a message. Intended as a
    stand-in wherever a real figure is not available.

    Parameters
    ----------
    `text` : str
        Message to display in figure

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # A single text-only point at the origin carries the message
    message = go.Scatter(
        x=[0],
        y=[0],
        mode='text',
        text=[text],
        textposition='top center',
        textfont_size=16,
        hoverinfo='skip'
    )

    fig = go.Figure(message)
    fig.update_layout(template='simple_white', height=300)

    # Hide both axes and lock them against zooming/panning
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(visible=False, fixedrange=True)

    return fig

Create a blank figure to display some text. Serves as placeholder for actual figure.

Parameters
  • text (str): Message to display in figure
Returns
  • go.Figure: Plotly figure

def gifify(fig, file, frame_title='Frame', fps=30) -> None:
def gifify(fig, file, frame_title='Frame', fps=30) -> None:
    """
    Export a plotly animation as a gif

    Every animation frame of `fig` is rendered to an image (sliders and
    buttons removed) and shown for one second of the resulting clip.

    Parameters
    ----------
    `fig` : go.Figure
        Plotly figure
    `file` : str
        Path to file where figure gif will be stored
    `frame_title` : str, optional
        Title that describes each frame, by default 'Frame'
        Frames are titled '<frame_title> 1', '<frame_title> 2', ...
    `fps` : int, optional
        Frame rate, by default 30

    Raises
    ------
    AssertionError
        If `fig` is not a plotly figure or `file` does not end with '.gif'
    ValueError
        If `fig` has no animation frames
    """

    assert isinstance(fig, go.Figure), \
        "This only works for plotly figures"
    assert file.endswith('.gif'), "file must be a .gif"

    if not fig['frames']:
        raise ValueError("fig has no animation frames to export")

    # Optional dependencies, imported lazily so the module loads without them
    import moviepy.editor as mpy
    from PIL import Image

    def fig2array(fig):
        """
        Convert a plotly figure to a numpy array (height x width x RGB)
        """

        raw = fig.to_image(format='jpg', engine='kaleido')
        with Image.open(io.BytesIO(raw)) as img:
            # The clip's frame function has to return an array, not an Image
            return np.asarray(img.convert('RGB'))

    # Remove sliders and buttons from figure layout
    exclude = ['updatemenus', 'sliders']
    layout = fig.to_dict()['layout']
    layout = {k: v for k, v in layout.items() if k not in exclude}

    # Create list to store frames (as images)
    frames = []
    for i, frame in enumerate(fig['frames']):
        _fig = go.Figure(data=frame['data'], layout=layout)
        _fig.update_layout(title=f"{frame_title} {i+1}", title_x=0.5)
        frames.append(fig2array(_fig))

    # Create animation: the frame at index int(t) is shown at time t, with
    # the index clamped so t == duration cannot run past the last frame
    last = len(frames) - 1

    def make_frame(t):
        return frames[min(int(t), last)]

    anim = mpy.VideoClip(make_frame, duration=len(frames))
    anim.write_gif(file, fps=fps, logger=None)
    print("Done :thumbsup:")

Export a plotly animation as a gif

Parameters
  • fig (go.Figure): Plotly figure
  • file (str): Path to file where figure gif will be stored
  • frame_title (str, optional): Title that describes each frame, by default 'Frame'
  • fps (int, optional): Frame rate, by default 30
def heatmap( data, col_groups=None, col_colors=None, col_group_names=None, row_groups=None, row_colors=None, row_group_names=None, clust_cols=True, clust_rows=True, cluster_kws={}):
def heatmap(data, col_groups=None, col_colors=None, col_group_names=None,
            row_groups=None, row_colors=None, row_group_names=None,
            clust_cols=True, clust_rows=True, cluster_kws=dict()):
    """
    Construct an interactive heatmap

    The main heatmap is placed in the bottom-right cell of a subplot grid.
    One narrow strip is added to its left for every entry of
    `row_group_names` and one above it for every entry of
    `col_group_names`. A strip is only drawn when both the groups and the
    group names for that axis are given.

    Parameters
    ----------
    data : pd.DataFrame
        Data to plot
    {row, col}_groups : dict
        Dictionary assigning groups to rows or columns.
        Keys should be the index or columns of data.
        Values should be a list of groups.
    {row, col}_group_names : list
        Names for each of the row/col groups.
        Should be the same length as the lists in {row, col}_groups
    {row, col}_colors : dict
        Dictionary defining colors for each group.
        Keys = groups;  Values = colors;
    clust_{rows, cols} : bool
        Should row and/or column clustering be performed
    cluster_kws : dict
        kwargs for sns.clustermap (only unpacked, never modified)

    Returns
    -------
    go.Figure
        Figure containing the heatmap and any group strips

    Raises
    ------
    ValueError
        If group strips are requested for an axis but the matching
        {row, col}_colors dictionary is missing
    """

    from sklearn.preprocessing import LabelEncoder

    def encode_groups(labels, colors):
        """
        Encode labels as integers and build the matching discrete colorscale
        """

        le = LabelEncoder().fit(labels)
        codes = le.transform(labels)
        classes = le.classes_

        # A single class has no range to normalise over; plotly still needs
        # a colorscale that spans 0 to 1, so repeat the one color
        if len(classes) == 1:
            color = colors[classes[0]]
            return codes, [[0.0, color], [1.0, color]]

        # Class k is encoded as k, which plotly maps to k / (n_classes - 1)
        top = len(classes) - 1
        scale = [[k / top, colors[cls]] for k, cls in enumerate(classes)]
        return codes, scale

    # Check inputs
    ## col_groups and row_groups: each value should be of the same length
    if col_group_names is None:
        col_group_names = []
    if row_group_names is None:
        row_group_names = []
    if row_groups is not None and row_group_names and row_colors is None:
        raise ValueError("row_colors is required when row_groups is given")
    if col_groups is not None and col_group_names and col_colors is None:
        raise ValueError("col_colors is required when col_groups is given")

    # Determine the size of the subplot grid
    n_col_grps = len(col_group_names)
    n_row_grps = len(row_group_names)
    I, J = n_col_grps+1, n_row_grps+1

    # Define column widths and row heights
    col_widths = [0.03 for _ in range(J-1)] + [1-(J-1)*0.03]
    row_heights = [0.07 for _ in range(I-1)] + [1-(I-1)*0.07]

    # Create subplot grid
    fig = sp.make_subplots(
        rows=I,
        cols=J,
        column_widths=col_widths,
        row_heights=row_heights,
        vertical_spacing=0.01,
        horizontal_spacing=0.01,
        shared_xaxes=True,
        shared_yaxes=True
    )

    # Perform clustering and extract clustered data frame from seaborn clustermap
    if clust_cols or clust_rows:
        data = sns.clustermap(data, row_cluster=clust_rows, col_cluster=clust_cols,
                              **cluster_kws).data2d
        # Only the reordered data is needed; discard the matplotlib figure
        plt.close()

    # Plot the heatmap and adjust axes
    fig.append_trace(
        go.Heatmap(
            z=data,
            x=data.columns,
            y=data.index.astype(str),
            colorscale='RdBu_r',
            hovertemplate='<b>Sample:</b> %{y}<br>'
                          '<b>Feature:</b> %{x}<br>'
                          '<b>Value:</b>%{z}'
                          '<extra></extra>'
        ),
        row=I, col=J
    )
    fig.update_yaxes(row=I, col=J, showticklabels=False, autorange='reversed')
    fig.update_xaxes(row=I, col=J, showticklabels=True, tickangle=270)
    fig.update_traces(row=I, col=J, colorbar_len=0.7)

    # Add row colors
    if row_groups is not None:
        for j, grp in enumerate(row_group_names):
            # Create row data
            row_data = [row_groups[r][j] for r in data.index]

            # Encode row data numerically so that heatmap can be plotted
            Z, colorscale = encode_groups(row_data, row_colors)

            fig.append_trace(
                go.Heatmap(
                    z=pd.DataFrame(Z),
                    y=data.index.astype(str),
                    x=[grp],
                    text=pd.DataFrame(row_data),
                    colorscale=colorscale,
                    hovertemplate='<b>Sample:</b> %{y}<br>' +
                                  f'<b>{grp}:</b> %{{text}}' +
                                  '<extra></extra>',
                    showscale=False
                ),
                row=I, col=j+1
            )
            fig.update_yaxes(row=I, col=j+1, showticklabels=False,
                             autorange='reversed')
            fig.update_xaxes(row=I, col=j+1, showticklabels=True,
                             tickangle=270,
                             tickfont={'size': 15, 'family': 'Arial'})

    # Add column colors
    if col_groups is not None:
        for i, grp in enumerate(col_group_names):
            # Create column data
            col_data = [col_groups[c][i] for c in data.columns]

            # Encode col data numerically so that heatmap can be plotted
            Z, colorscale = encode_groups(col_data, col_colors)

            fig.append_trace(
                go.Heatmap(
                    z=pd.DataFrame(Z).T,
                    y=[grp],
                    x=data.columns,
                    text=pd.DataFrame(col_data).T,
                    colorscale=colorscale,
                    hovertemplate='<b>Feature:</b> %{x}<br>' +
                                  f'<b>{grp}:</b> %{{text}}' +
                                  '<extra></extra>',
                    showscale=False
                ),
                row=i+1, col=J
            )
            fig.update_yaxes(row=i+1, col=J, showticklabels=True,
                             autorange='reversed', side='right',
                             tickfont={'size': 15, 'family': 'Arial'})
            fig.update_xaxes(row=i+1, col=J, showticklabels=False)

    # BUG: Clustering Rows makes some data disappear;
    #      looks like some aggregation is happening
    # NOTE(review): possibly duplicate index labels collapsing on the
    #      categorical y axis after astype(str) -- TODO confirm

    return fig

Construct an interactive heatmap

Parameters
  • data (pd.DataFrame): Data to plot
  • {row, col}_groups (dict): Dictionary assigning groups to rows or columns. Keys should be the index or columns of data. Values should be a list of groups.
  • {row, col}_group_names (list): Names for each of the row/col groups. Should be the same length as the lists in {row, col}_groups
  • {row, col}_colors (dict): Dictionary defining colors for each group. Keys = groups; Values = colors;
  • clust_{rows, cols} (bool): Should row and/or column clustering be performed
  • cluster_kws (dict): kwargs for sns.clustermap