hcitools.plot
This module contains functions for visualizing data and analysis results.
"""
This module contains functions for visualizing data and analysis results.
"""

# Imports
from rich import print

import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.express as px
import plotly.io as pio
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import textwrap
import math
import io


LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
ONE_THIRD = 1.0 / 3.0
ONE_SIXTH = 1.0 / 6.0
TWO_THIRD = 2.0 / 3.0


def set_renderer(renderer):
    """
    Set the plotly default renderer

    Parameters
    ----------
    `renderer` : str
        Renderer name, assigned to `plotly.io.renderers.default`
    """

    pio.renderers.default = renderer


class colormap:
    """
    Custom colormaps for plotly figures

    These colormaps assume the data has been scaled to the range [0, 1].

    Attributes
    ----------
    `OgBu` : list
        Seaborn diverging colorscale from blue (low) to orange (high)
    """

    OgBu = [[0.00, '#3F7F93'], [0.10, '#6296A6'], [0.20, '#85ADB9'],
            [0.30, '#A9C4CC'], [0.40, '#CDDBE0'], [0.50, '#F2F1F1'],
            [0.60, '#E9D2CD'], [0.70, '#DFB3A7'], [0.80, '#D69483'],
            [0.90, '#CC745D'], [1.00, '#C3553A']]


class LabelEncoder:
    """
    Encode target labels with values between 0 and n_classes-1

    Classes are sorted with `np.unique`, so the smallest label is encoded as
    0.0. `encode` (re)fits the encoder; `decode` uses the mapping stored by
    the most recent `encode` call.

    Attributes
    ----------
    `encoder` : dict
        dictionary mapping target labels to encodings
    `decoder` : dict
        dictionary mapping encodings to target labels
    `dtype` : np.dtype
        dtype of original labels

    Methods
    -------
    `encode(labels)`
        Encode a list of target labels
    `decode(enc_labels)`
        Decode a list of encoded labels
    """

    def encode(self, labels):
        """
        Fit the encoder to `labels` and return their encodings

        Parameters
        ----------
        `labels` : array_like
            list of target labels

        Returns
        -------
        np.array
            Encoded labels (floats in 0..n_classes-1)

        Raises
        ------
        AssertionError
            If `labels` is not 1-dimensional
        """

        labels = np.asarray(labels)
        assert labels.ndim == 1, "labels must be 1-dimensional"

        # np.unique returns the sorted classes and, for every label, the
        # index of its class -- which is exactly the encoding
        classes, codes = np.unique(labels, return_inverse=True)

        # Create and store maps
        self.encoder = {label: float(code) for code, label in enumerate(classes)}
        self.decoder = {float(code): label for code, label in enumerate(classes)}
        self.dtype = labels.dtype

        return codes.ravel().astype(float)

    def decode(self, enc_labels):
        """
        Map encoded labels back to the original labels

        Parameters
        ----------
        `enc_labels` : array_like
            list of encoded labels

        Returns
        -------
        np.array
            Decoded labels, with the dtype of the labels given to `encode`

        Raises
        ------
        AssertionError
            If `enc_labels` is not 1-dimensional
        AttributeError
            If `encode` has not been called yet
        ValueError
            If `enc_labels` contains a code that the encoder does not know
        """

        enc_labels = np.asarray(enc_labels)
        assert enc_labels.ndim == 1, "enc_labels must be 1-dimensional"
        if not hasattr(self, 'decoder'):
            raise AttributeError("encode() must be called before decode()")

        try:
            # .tolist() yields plain Python scalars; integer codes still hit
            # the float keys because 1 == 1.0 and hash(1) == hash(1.0)
            decoded = [self.decoder[code] for code in enc_labels.tolist()]
        except KeyError as err:
            raise ValueError(
                f"unknown encoded label: {err.args[0]!r}") from None

        # np.array rather than np.fromiter: fromiter only supports object
        # dtypes from NumPy 1.23 on, and pandas string columns are object
        return np.array(decoded, dtype=self.dtype)


def _make_plate(data, feature, time_col='timepoint'):
    """
    Convert a feature data frame into a plate layout.

    This assumes `data` contains the following columns: `row`, `column` and
    `time_col`. Row values are mapped to letters assuming they are the
    integers 1..r.

    Parameters
    ----------
    `data` : pd.DataFrame
        a data frame of features including certain metadata columns
    `feature` : str
        feature to populate plate with
    `time_col` : str, optional
        column that defines time points, by default 'timepoint'

    Returns
    -------
    tuple
        `plate` : np.array
            $(k \\times r \\times c)$ array; axis 0 follows the order of
            `data[time_col].unique()`
        `rows` : list
            row letters, ordered to match the plate's row axis
        `cols` : list
            column labels '1'..'c'
    """

    assert feature in data.columns, "feature must be a column in data"
    assert time_col in data.columns, f"{time_col} must be a column in data"
    assert ('row' in data.columns) and ('column' in data.columns), \
        "'row' and 'column' must be columns in data"

    # Extract time points
    times = data[time_col].unique()

    # Define row and column names
    r = len(np.unique(data['row']))
    c = len(np.unique(data['column']))
    rows = {i: x for i, x in enumerate(LETTERS[:r], 1)}
    cols = [str(i) for i in range(1, c + 1)]

    # Create plate, one 2-D layer per time point
    plate = []
    for T in times:
        # Boolean mask instead of DataFrame.query: an f-string query breaks
        # on string-valued time points and non-identifier column names
        layer = data.loc[data[time_col] == T, ['row', 'column', feature]]
        plate.append(
            layer
            .pivot(index='row', columns='column', values=feature)
            .rename(index=rows)
            .sort_index(ascending=False)
            .values
        )

    return np.array(plate), list(rows.values())[::-1], cols


def plate_heatmap(data, feature, time_col='timepoint', colorscale=colormap.OgBu):
    """
    Create an interactive plate heatmap; Including an animation for timelapses

    This function assumes that `data` contains the following columns: `row`,
    `column`, `time_col`, `compound`, `conc`. Column names are matched
    case-insensitively; `data` itself is not modified.

    Parameters
    ----------
    `data` : pd.DataFrame
        a data frame of features including certain metadata columns
    `feature` : str
        feature to populate plate with
    `time_col` : str, optional
        column that defines time points, by default 'timepoint'
        Time points must be positive.
    `colorscale` : list, optional
        Plotly-compatible colormap, by default `colormap.OgBu`
        See `colormap.OgBu` for examples.

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # Lower-case the column names on a new frame so the caller's DataFrame
    # is left untouched
    data = data.rename(columns=lambda name: name.lower())
    feature = feature.lower()
    time_col = time_col.lower()

    assert feature in data.columns, "feature must be a column in data"
    assert time_col in data.columns, f"{time_col} must be a column in data"
    assert data[time_col].min() > 0, "the first time point must be 1 not 0"
    assert ('compound' in data.columns) and ('conc' in data.columns), \
        "'compound' and 'conc' must be columns in data"

    def platemap(x, y, z, cmpd, conc):
        """
        Wrapper for go.Heatmap
        """

        # Insert line breaks in the compound names; non-string cells (e.g.
        # NaN left by the pivot) are shown blank
        labels = [
            [name.replace(' ', '<br>') if isinstance(name, str) else ''
             for name in plate_row]
            for plate_row in cmpd
        ]

        return go.Heatmap(
            x=x, y=y, z=z,
            colorscale=colorscale,
            text=labels,
            customdata=conc,
            hovertemplate=(
                '<b>Well:</b> %{y}%{x}<br>' +
                '<b>Compound:</b> %{text}<br>' +
                '<b>Concentration:</b> %{customdata}<br>' +
                '<b>Value:</b> %{z:.2f}<extra></extra>'
            ),
            texttemplate='%{text}',
        )

    # Extract time points (same order as the plate's first axis)
    times = data[time_col].unique()

    # Reformat data as plate
    plate, rows, cols = _make_plate(data, feature, time_col)

    # Create data for tooltips (taken from the first time point)
    cmpd = _make_plate(data, 'compound', time_col)[0][0, ...]
    conc = _make_plate(data, 'conc', time_col)[0][0, ...]

    # Create figure and fill in layout
    fig = {'data': [], 'layout': {}, 'frames': []}

    fig['layout'] = go.Layout(
        title=feature,
        title_x=0.5,
        xaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        yaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        hovermode='closest',
        updatemenus=[{
            "buttons": [
                {"args": [None, {"frame": {"duration": 500, "redraw": True},
                                 "fromcurrent": True,
                                 "transition": {"duration": 300,
                                                "easing": "quadratic-in-out"}}],
                 "label": "Play",
                 "method": "animate"},
                {"args": [[None], {"frame": {"duration": 0, "redraw": True},
                                   "mode": "immediate",
                                   "transition": {"duration": 0}}],
                 "label": "Pause",
                 "method": "animate"}
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }] if len(times) > 1 else None
    )

    # Add first time point plate to data
    fig['data'].append(platemap(cols, rows, plate[0, ...], cmpd, conc))

    # Create frames & animations
    if len(times) > 1:
        # Create sliders
        sliders = {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Timepoint:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": []
        }

        # `plate` is stacked in the order of `times`, so index it by position;
        # this also works when time points are not consecutive integers
        for k, t in enumerate(times):
            name = str(t)

            # New frame
            fig['frames'].append({
                "data": [platemap(cols, rows, plate[k, ...], cmpd, conc)],
                "name": name
            })

            # Corresponding slider step
            sliders['steps'].append({
                "args": [
                    [name],
                    {"frame": {"duration": 300, "redraw": True},
                     "mode": "immediate",
                     "transition": {"duration": 300}}
                ],
                "label": name,
                "method": "animate"
            })

        # Add sliders to figure layout
        fig['layout']['sliders'] = [sliders]

    return go.Figure(fig)


def pca_comps(proj, exp_var, time_col='timepoint', n_comps=4):
    """
    Plot a scatter grid of PCA components

    This function is written to use the output from `process.dim_reduction`.
    Column names are matched case-insensitively; `proj` is not modified.

    Parameters
    ----------
    `proj` : pd.DataFrame
        Data frame with pca projections, from `process.dim_reduction`
    `exp_var` : array_like
        List of explained variances for each PCA component
    `time_col` : str, optional
        Column containing time points; must be a column of `proj`;
        by default 'timepoint'
    `n_comps` : int, optional
        Number of pca components to plot, by default 4

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # Lower-case on a new frame so the caller's DataFrame is left untouched
    proj = proj.rename(columns=lambda name: name.lower())
    time_col = time_col.lower()
    assert 'variable' in proj.columns, "variable must be a column in proj"
    assert 'compound' in proj.columns, "compound must be a column in proj"
    assert time_col in proj.columns, f"{time_col} must be a column in proj"
    assert 'conc' in proj.columns, "conc must be a column in proj"

    # Prepare matrix of components as well as variables for plotting
    comp_cols = [str(x + 1) for x in range(n_comps)]
    pca = proj[proj['variable'] == 'PCA'].reset_index(drop=True)
    compounds = pca['compound']
    comps = pca[comp_cols]

    # Create labels
    labels = {str(i): f"PC {i+1} ({var:.2f}%)" for i, var in enumerate(exp_var)}

    # Create figure
    fig = px.scatter_matrix(
        comps.values,
        labels=labels,
        dimensions=range(n_comps),
        color=compounds,
        opacity=0.5,
        template='plotly_white'
    )
    fig.update_traces(diagonal_visible=False)
    fig.update_layout(paper_bgcolor='white', plot_bgcolor='white', height=500)

    return fig


def clusters(data, compound_a, compound_b, method, time_col='timepoint'):
    """
    Create clustering figures that compare 2 compounds.

    This function is written to use the output from `process.dim_reduction`.
    `data` is not modified.

    Parameters
    ----------
    `data` : pd.DataFrame
        Dimensionality-reduction results; must contain the columns
        `compound`, `conc`, `variable`, `time_col` and the component
        columns '1' and '2'
    `compound_a` : str
        Compound A (red points)
    `compound_b` : str
        Compound B (green points)
    `method` : str
        One of 'PCA', 'tSNE' or 'UMAP'; matched case-insensitively against
        `data['variable']`
    `time_col` : str, optional
        Column containing time points, by default 'timepoint'

    Returns
    -------
    go.Figure
        Plotly figure
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert 'compound' in data.columns, "'compound' must be a column in data"
    assert 'variable' in data.columns, "'variable' must be a column in data"
    assert compound_a in data['compound'].tolist(), \
        "compound_a must be present in data['compound']"
    assert compound_b in data['compound'].tolist(), \
        "compound_b must be present in data['compound']"
    method = method.lower()
    assert method in ['pca', 'tsne', 'umap'], \
        "method must be one of 'PCA', 'tSNE' or 'UMAP'"
    assert time_col in data.columns, f"{time_col} must be a column in data"

    times = data[time_col].unique()
    method = method.upper()

    # Compare case-insensitively so rows labelled e.g. 'tSNE' are still
    # selected after `method` has been upper-cased to 'TSNE'
    is_method = data['variable'].astype(str).str.upper() == method

    def create_traces(compound, colorscale, linecolor, cbar_pos):
        """
        Create one scatter trace per concentration of `compound`
        """

        # Boolean mask instead of DataFrame.query, which breaks when a
        # compound name contains a quote
        cmpd = (data[(data['compound'] == compound) & is_method]
                .reset_index(drop=True))
        if cmpd.empty:
            return []

        # Encode different sizes for each concentration
        cmpd['conc'] = cmpd['conc'].astype(float)
        concs = sorted(cmpd['conc'].unique())
        if len(concs) > 1:
            conc_map = {c: 3 * k for k, c in enumerate(concs, start=1)}
        else:
            conc_map = {concs[0]: 20}

        # Create traces
        traces = []
        for i, conc in enumerate(concs):
            mask = cmpd['conc'] == conc

            traces.append(
                go.Scatter(
                    x=cmpd.loc[mask, '1'],
                    y=cmpd.loc[mask, '2'],
                    mode='markers',
                    marker=dict(
                        color=cmpd.loc[mask, time_col],
                        size=conc_map[conc],
                        opacity=0.5,
                        colorscale=colorscale,
                        # Only the first trace of each compound draws a colorbar
                        colorbar=dict(
                            x=cbar_pos,
                            thickness=20,
                            yanchor='middle',
                            len=.7
                        ) if (i == 0) and (len(times) > 1) else None,
                        line=dict(width=1.2, color=linecolor)
                    ),
                    name=str(conc),
                    legendgroup=compound.replace(' ', '').lower(),
                    legendgrouptitle_text=(compound if i == 0 else None)
                )
            )

        return traces

    # Create figure traces
    traces = [*create_traces(compound_a, 'Reds', 'red', -0.25),
              *create_traces(compound_b, 'Greens', 'green', -0.35)]

    # Create layout
    layout = go.Layout(
        legend=dict(tracegroupgap=20, groupclick='toggleitem'),
        template='plotly_white',
        margin=dict(l=20, r=20, t=70, b=40),
        height=500,
        title=f"{compound_a} vs {compound_b}<br>({method} Clusters)",
        title_x=0.5,
        xaxis_title=f'{method} 1',
        yaxis_title=f'{method} 2',
    )

    # Create figure
    fig = go.Figure(data=traces, layout=layout)
    fig.update_yaxes(
        scaleanchor='x',
        scaleratio=1
    )

    # Annotate the colorbars
    if len(times) > 1:
        fig.add_annotation(
            xref='paper',
            yref='paper',
            x=-0.33,
            y=0.92,
            text='Time Point',
            font_size=14,
            showarrow=False
        )

    return fig


def _make_grid(items, col_wrap=2):
    """
    Split a list of items into a grid for subplots

    Parameters
    ----------
    `items` : list
        List of items for each subplot (e.g., tiles)
    `col_wrap` : int, optional
        Number of columns allowed in layout, by default 2

    Returns
    -------
    tuple
        `nrows` : int, `ncols` : int, and `positions` : dict mapping each
        item to {'x': row, 'y': column} (both 1-based)
    """

    nrows = math.ceil(len(items) / col_wrap)
    ncols = min(len(items), col_wrap)
    positions = {
        x: {'x': (i // col_wrap) + 1, 'y': (i % col_wrap) + 1}
        for i, x in enumerate(items)
    }

    return nrows, ncols, positions


def _v(m1, m2, hue):
    """
    Helper for `_hls_to_rgb`: one RGB channel for a hue in turns
    """

    hue = hue % 1.0
    if hue < ONE_SIXTH:
        return m1 + (m2 - m1) * hue * 6.0
    if hue < 0.5:
        return m2
    if hue < TWO_THIRD:
        return m1 + (m2 - m1) * (TWO_THIRD - hue) * 6.0
    return m1


def _hls_to_rgb(h, l, s):
    """
    Convert HLS (Hue, Luminance, Saturation) to RGB
    """

    if s == 0.0:
        return l, l, l
    if l <= 0.5:
        m2 = l * (1.0 + s)
    else:
        m2 = l + s - (l * s)
    m1 = 2.0 * l - m2
    return (_v(m1, m2, h + ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h - ONE_THIRD))


def _get_colors(n):
    """
    Generate n visually distinct colors.

    Hues are evenly spaced; lightness and saturation are randomly jittered,
    so results vary between calls. Returns an empty list when `n <= 0`.

    This is taken from [this](https://stackoverflow.com/a/9701141) stack
    overflow post.
    """

    if n <= 0:
        return []

    colors = []
    # k / n for k in 0..n-1; a float-step np.arange can produce an extra
    # element through rounding
    for hue in np.arange(n) / n:
        lightness = (50 + np.random.rand() * 10) / 100
        saturation = (90 + np.random.rand() * 10) / 100

        r, g, b = _hls_to_rgb(hue, lightness, saturation)
        colors.append("#%02x%02x%02x" % (int(r * 255), int(g * 255), int(b * 255)))

    return colors


def _box_hover_text(subset, tooltips):
    """
    Build per-row hover text: the `well` column (when present) plus tooltips
    """

    text = pd.Series('', index=subset.index)
    if 'well' in subset.columns:
        text = 'well: ' + subset['well'].astype(str) + '<br>'
    if tooltips is not None:
        for name, col in tooltips.items():
            text = text + f'{name}: ' + subset[col].astype(str) + '<br>'
    return text.tolist()


def distplot(data, features, group_col, tooltips=None, kind='box', col_wrap=2,
             title_len=30):
    """
    Create boxplots showing the distribution of features for different groups.

    This generates a figure with as many subplots as there are features

    Parameters
    ----------
    `data` : pd.DataFrame
        Data frame to plot
    `features` : list
        List of features to visualize
    `group_col` : str
        `data` column that contains groups of interest
    `tooltips` : dict, optional
        Dictionary that defines annotation tooltips, by default None
        Keys = Tooltip Name;
        Values = Corresponding column in `data`
        The `well` column is always shown first when `data` has one.
    `kind` : str, optional
        Type of plot to generate; one of 'box', 'bar', by default 'box'
    `col_wrap` : int, optional
        Number of columns allowed in layout, by default 2
    `title_len` : int, optional
        Wrap length for subplot titles, by default 30

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If an argument is invalid (see the checks below)
    NotImplementedError
        When `kind != 'box'`
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert kind in ['box', 'bar'], "kind must be one of 'box', 'bar'."
    assert group_col in data.columns, "group_col must be a column in data"
    for f in features:
        assert f in data.columns, "features must contain columns from data"
    if tooltips is not None:
        for col in tooltips.values():
            assert col in data.columns, \
                "Values of tooltips must be columns of data"

    # Determine grid dimensions & positions
    nrows, ncols, positions = _make_grid(features, col_wrap=col_wrap)

    # Wrap text for subplot titles
    titles = ['<br>'.join(textwrap.wrap(x, title_len)) for x in features]

    # Get list of groups & colors
    groups = data[group_col].unique()
    colors = _get_colors(len(groups))

    # Create figure
    fig = sp.make_subplots(rows=nrows, cols=ncols, subplot_titles=titles)

    # Add traces
    if kind == 'box':
        # Subset and hover text depend only on the group, so build them once.
        # The hover text must come from the subset, not the full frame, or the
        # tooltips would not line up with the plotted points.
        per_group = []
        for grp, color in zip(groups, colors):
            subset = data[data[group_col] == grp]
            per_group.append((grp, color, subset, _box_hover_text(subset, tooltips)))

        for feature, pos in positions.items():
            # Show each group in the legend only once (first subplot)
            first_subplot = pos['x'] == 1 and pos['y'] == 1
            for grp, color, subset, text in per_group:
                fig.add_trace(
                    go.Box(
                        y=subset[feature],
                        name=str(grp),
                        text=text,
                        hovertemplate='%{text}',
                        legendgroup=str(grp),
                        marker_color=color,
                        showlegend=first_subplot
                    ),
                    row=pos['x'],
                    col=pos['y']
                )
    elif kind == 'bar':
        raise NotImplementedError("Can't do that yet. Working on it.")
    else:
        raise NotImplementedError("Can't do that yet.")

    fig.update_xaxes(showticklabels=False)
    fig.update_layout(template='plotly_white')
    fig.update_annotations(font=dict(family="Helvetica", size=14))

    return fig


def textplot(text):
    """
    Create a blank figure to display some text. Serves as placeholder for
    actual figure.

    Parameters
    ----------
    `text` : str
        Message to display in figure

    Returns
    -------
    go.Figure
        Plotly figure
    """

    fig = go.Figure(
        go.Scatter(
            x=[0], y=[0], text=[text], textposition='top center',
            textfont_size=16, mode='text', hoverinfo='skip'
        )
    )
    fig.update_layout(template='simple_white', height=300)
    fig.update_xaxes(visible=False, fixedrange=True)
    fig.update_yaxes(visible=False, fixedrange=True)

    return fig


def gifify(fig, file, frame_title='Frame', fps=30) -> None:
    """
    Export a plotly animation as a gif

    Each animation frame is rendered with kaleido and shown for one second
    of gif time.

    Parameters
    ----------
    `fig` : go.Figure
        Plotly figure with animation frames
    `file` : str
        Path to file where figure gif will be stored
    `frame_title` : str, optional
        Title that describes each frame, by default 'Frame'
    `fps` : int, optional
        Frame rate, by default 30

    Raises
    ------
    AssertionError
        If `fig` is not a plotly figure, has no frames, or `file` is not .gif
    """

    assert isinstance(fig, go.Figure), \
        "This only works for plotly figures"
    assert file.endswith('.gif'), "file must be a .gif"
    assert len(fig.frames) > 0, "fig has no animation frames"

    import moviepy.editor as mpy
    from PIL import Image

    def fig2array(fig):
        """
        Render a plotly figure to an RGB numpy array (height x width x 3)
        """

        img_bytes = fig.to_image(format='jpg', engine='kaleido')
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')

        # moviepy clips expect numpy frames, not PIL images
        return np.asarray(img)

    # Remove sliders and buttons from figure layout
    exclude = ['updatemenus', 'sliders']
    layout = {k: v for k, v in fig.to_dict()['layout'].items()
              if k not in exclude}

    # Render every animation frame to an image
    frames = []
    for i, frame in enumerate(fig['frames']):
        _fig = go.Figure(data=frame['data'], layout=layout)
        _fig.update_layout(title=f"{frame_title} {i+1}", title_x=0.5)
        frames.append(fig2array(_fig))

    # One second per frame; clamp so t == duration can never overrun
    def make_frame(t):
        return frames[min(int(t), len(frames) - 1)]

    anim = mpy.VideoClip(make_frame, duration=len(frames))
    anim.write_gif(file, fps=fps, logger=None)
    print("Done :thumbsup:")


def _categorical_colorscale(labels, colors):
    """
    Encode categorical labels as 0..k-1 and build a matching colorscale

    Returns `(codes, colorscale)`: `codes` come from `LabelEncoder.encode`
    and `colorscale` places `colors[label]` at position code / (k - 1).
    A single class gets a constant two-stop scale, because plotly
    colorscales must define both 0 and 1.
    """

    encoder = LabelEncoder()
    codes = encoder.encode(labels)
    classes = [encoder.decoder[float(k)] for k in range(len(encoder.decoder))]

    if not classes:
        return codes, None
    if len(classes) == 1:
        only = colors[classes[0]]
        return codes, [[0.0, only], [1.0, only]]

    last = len(classes) - 1
    return codes, [[k / last, colors[cls]] for k, cls in enumerate(classes)]


def heatmap(data, col_groups=None, col_colors=None, col_group_names=None,
            row_groups=None, row_colors=None, row_group_names=None,
            clust_cols=True, clust_rows=True, cluster_kws=None):
    """
    Construct an interactive heatmap

    Parameters
    ----------
    data : pd.DataFrame
        Data to plot
    {row, col}_groups : dict
        Dictionary assigning groups to rows or columns.
        Keys should be the index or columns of data.
        Values should be a list of groups.
    {row, col}_group_names : list
        Names for each of the row/col groups
        Should be the same length as the lists in {row, col}_groups
    {row, col}_colors : dict
        Dictionary defining colors for each group.
        Keys = groups; Values = colors;
    clust_{rows, cols} : bool
        Should row and/or column clustering be performed
    cluster_kws : dict, optional
        kwargs for sns.clustermap, by default None (no extra kwargs)

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # Avoid a shared mutable default
    cluster_kws = {} if cluster_kws is None else cluster_kws

    # TODO: validate that every value of col_groups / row_groups has one
    # entry per group name
    if col_group_names is None:
        col_group_names = []
    if row_group_names is None:
        row_group_names = []

    # Determine the size of the subplot grid: one extra row per column group
    # and one extra column per row group; the heatmap sits at (I, J)
    n_col_grps = len(col_group_names)
    n_row_grps = len(row_group_names)
    I, J = n_col_grps + 1, n_row_grps + 1

    # Define column widths and row heights
    col_widths = [0.03] * (J - 1) + [1 - (J - 1) * 0.03]
    row_heights = [0.07] * (I - 1) + [1 - (I - 1) * 0.07]

    # Create subplot grid
    fig = sp.make_subplots(
        rows=I,
        cols=J,
        column_widths=col_widths,
        row_heights=row_heights,
        vertical_spacing=0.01,
        horizontal_spacing=0.01,
        shared_xaxes=True,
        shared_yaxes=True
    )

    # Perform clustering and extract clustered data frame from seaborn clustermap
    if clust_cols or clust_rows:
        data = sns.clustermap(data, row_cluster=clust_rows, col_cluster=clust_cols,
                              **cluster_kws).data2d
        plt.close()

    # Plot the heatmap and adjust axes
    fig.add_trace(
        go.Heatmap(
            z=data,
            x=data.columns,
            y=data.index.astype(str),
            colorscale='RdBu_r',
            hovertemplate='<b>Sample:</b> %{y}<br>' +
                          '<b>Feature:</b> %{x}<br>' +
                          '<b>Value:</b> %{z}'
                          '<extra></extra>'
        ),
        row=I, col=J
    )
    fig.update_yaxes(row=I, col=J, showticklabels=False, autorange='reversed')
    fig.update_xaxes(row=I, col=J, showticklabels=True, tickangle=270)
    fig.update_traces(row=I, col=J, colorbar_len=0.7)

    # Add row colors
    if row_groups is not None:
        for j, grp in enumerate(row_group_names):
            # Create row data
            row_data = [row_groups[r][j] for r in data.index]

            # Encode numerically so the groups can be drawn as a heatmap
            Z, colorscale = _categorical_colorscale(row_data, row_colors)

            fig.add_trace(
                go.Heatmap(
                    z=pd.DataFrame(Z),
                    y=data.index.astype(str),
                    x=[grp],
                    text=pd.DataFrame(row_data),
                    colorscale=colorscale,
                    hovertemplate='<b>Sample:</b> %{y}<br>' +
                                  f'<b>{grp}:</b> %{{text}}' +
                                  '<extra></extra>',
                    showscale=False
                ),
                row=I, col=j + 1
            )
            fig.update_yaxes(row=I, col=j + 1, showticklabels=False,
                             autorange='reversed')
            fig.update_xaxes(row=I, col=j + 1, showticklabels=True,
                             tickangle=270,
                             tickfont={'size': 15, 'family': 'Arial'})

    # Add column colors
    if col_groups is not None:
        for i, grp in enumerate(col_group_names):
            # Create column data
            col_data = [col_groups[r][i] for r in data.columns]

            # Encode numerically so the groups can be drawn as a heatmap
            Z, colorscale = _categorical_colorscale(col_data, col_colors)

            fig.add_trace(
                go.Heatmap(
                    z=pd.DataFrame(Z).T,
                    y=[grp],
                    x=data.columns,
                    text=pd.DataFrame(col_data).T,
                    colorscale=colorscale,
                    hovertemplate='<b>Feature:</b> %{x}<br>' +
                                  f'<b>{grp}:</b> %{{text}}' +
                                  '<extra></extra>',
                    showscale=False
                ),
                row=i + 1, col=J
            )
            fig.update_yaxes(row=i + 1, col=J, showticklabels=True,
                             autorange='reversed', side='right',
                             tickfont={'size': 15, 'family': 'Arial'})
            fig.update_xaxes(row=i + 1, col=J, showticklabels=False)

    # BUG: Clustering Rows makes some data disappear;
    # looks like some aggregation is happening
    # NOTE(review): one unverified candidate is duplicate row labels after
    # data.index.astype(str), which plotly would treat as one category

    return fig
def set_renderer(renderer):
    """
    Make `renderer` plotly's default renderer.

    Parameters
    ----------
    `renderer` : str
        Renderer name; stored in `plotly.io.renderers.default`
    """

    renderer_config = pio.renderers
    renderer_config.default = renderer
Set the plotly default renderer
class colormap:
    """
    Collection of custom plotly colorscales.

    Each colorscale is a list of ``[position, hex_color]`` pairs with
    positions running from 0 to 1, so plotted values should already be
    scaled into that range.

    Attributes
    ----------
    `OgBu` : list
        Seaborn diverging colorscale from blue (low) to orange (high)
    """

    OgBu = [
        [0.00, '#3F7F93'],
        [0.10, '#6296A6'],
        [0.20, '#85ADB9'],
        [0.30, '#A9C4CC'],
        [0.40, '#CDDBE0'],
        [0.50, '#F2F1F1'],
        [0.60, '#E9D2CD'],
        [0.70, '#DFB3A7'],
        [0.80, '#D69483'],
        [0.90, '#CC745D'],
        [1.00, '#C3553A'],
    ]
Custom colormaps for plotly figures
These colormaps assume the data has been scaled to the range [0, 1].
Attributes
OgBu
(list): Seaborn diverging colorscale from blue (low) to orange (high)
class LabelEncoder:
    """
    Encode target labels with values between 0 and n_classes-1

    Classes are sorted with `np.unique`, so the smallest label is encoded as
    0.0. `encode` (re)fits the encoder; `decode` uses the mapping stored by
    the most recent `encode` call.

    Attributes
    ----------
    `encoder` : dict
        dictionary mapping target labels to encodings
    `decoder` : dict
        dictionary mapping encodings to target labels
    `dtype` : np.dtype
        dtype of original labels

    Methods
    -------
    `encode(labels)`
        Encode a list of target labels
    `decode(enc_labels)`
        Decode a list of encoded labels
    """

    def encode(self, labels):
        """
        Fit the encoder to `labels` and return their encodings

        Parameters
        ----------
        `labels` : array_like
            list of target labels

        Returns
        -------
        np.array
            Encoded labels (floats in 0..n_classes-1)

        Raises
        ------
        AssertionError
            If `labels` is not 1-dimensional
        """

        labels = np.asarray(labels)
        assert labels.ndim == 1, "labels must be 1-dimensional"

        # np.unique returns the sorted classes and, for every label, the
        # index of its class -- which is exactly the encoding
        classes, codes = np.unique(labels, return_inverse=True)

        # Create and store maps
        self.encoder = {label: float(code) for code, label in enumerate(classes)}
        self.decoder = {float(code): label for code, label in enumerate(classes)}
        self.dtype = labels.dtype

        return codes.ravel().astype(float)

    def decode(self, enc_labels):
        """
        Map encoded labels back to the original labels

        Parameters
        ----------
        `enc_labels` : array_like
            list of encoded labels

        Returns
        -------
        np.array
            Decoded labels, with the dtype of the labels given to `encode`

        Raises
        ------
        AssertionError
            If `enc_labels` is not 1-dimensional
        AttributeError
            If `encode` has not been called yet
        ValueError
            If `enc_labels` contains a code that the encoder does not know
        """

        enc_labels = np.asarray(enc_labels)
        assert enc_labels.ndim == 1, "enc_labels must be 1-dimensional"
        if not hasattr(self, 'decoder'):
            raise AttributeError("encode() must be called before decode()")

        try:
            # .tolist() yields plain Python scalars; integer codes still hit
            # the float keys because 1 == 1.0 and hash(1) == hash(1.0)
            decoded = [self.decoder[code] for code in enc_labels.tolist()]
        except KeyError as err:
            raise ValueError(
                f"unknown encoded label: {err.args[0]!r}") from None

        # np.array rather than np.fromiter: fromiter only supports object
        # dtypes from NumPy 1.23 on, and pandas string columns are object
        return np.array(decoded, dtype=self.dtype)
Encode target labels with values between 0 and n_classes-1
Attributes
encoder
(dict): dictionary mapping target labels to encodings
decoder
(dict): dictionary mapping encodings to target labels
dtype
(np.dtype): dtype of original labels
Methods
encode(labels)
Encode a list of target labels
decode(enc_labels)
Decode a list of encoded labels
def encode(self, labels):
    """
    Fit the encoder to `labels` and return their encodings

    Parameters
    ----------
    `labels` : array_like
        list of target labels

    Returns
    -------
    np.array
        Encoded labels (floats in 0..n_classes-1)

    Raises
    ------
    AssertionError
        If `labels` is not 1-dimensional
    """

    labels = np.asarray(labels)
    assert labels.ndim == 1, "labels must be 1-dimensional"

    # np.unique returns the sorted classes and, for every label, the index
    # of its class -- which is exactly the encoding
    classes, codes = np.unique(labels, return_inverse=True)

    # Create and store maps
    self.encoder = {label: float(code) for code, label in enumerate(classes)}
    self.decoder = {float(code): label for code, label in enumerate(classes)}
    self.dtype = labels.dtype

    return codes.ravel().astype(float)
Parameters
labels
(array_like): list of target labels
Returns
- np.array: Encoded labels
Raises
- AssertionError: If
labels
is not 1-dimensional
def decode(self, enc_labels):
    """
    Map encoded labels back to the original labels

    Parameters
    ----------
    `enc_labels` : array_like
        list of encoded labels

    Returns
    -------
    np.array
        Decoded labels, with the dtype of the labels given to `encode`

    Raises
    ------
    AssertionError
        If `enc_labels` is not 1-dimensional
    AttributeError
        If `encode` has not been called yet
    ValueError
        If `enc_labels` contains a code that the encoder does not know
    """

    enc_labels = np.asarray(enc_labels)
    assert enc_labels.ndim == 1, "enc_labels must be 1-dimensional"
    if not hasattr(self, 'decoder'):
        raise AttributeError("encode() must be called before decode()")

    try:
        # .tolist() yields plain Python scalars; integer codes still hit the
        # float keys because 1 == 1.0 and hash(1) == hash(1.0)
        decoded = [self.decoder[code] for code in enc_labels.tolist()]
    except KeyError as err:
        raise ValueError(f"unknown encoded label: {err.args[0]!r}") from None

    # np.array rather than np.fromiter: fromiter only supports object dtypes
    # from NumPy 1.23 on, and pandas string columns are object
    return np.array(decoded, dtype=self.dtype)
Parameters
enc_labels
(array_like): list of encoded labels
Returns
- np.array: Decoded labels
Raises
- AssertionError: If
enc_labels
is not 1-dimensional
def plate_heatmap(data, feature, time_col='timepoint', colorscale=colormap.OgBu):
    """
    Create an interactive plate heatmap; Including an animation for timelapses

    This function assumes that `data` contains the following columns: `row`,
    `column`, `time_col`, `compound`, `conc`. Column names are matched
    case-insensitively; `data` itself is not modified.

    Parameters
    ----------
    `data` : pd.DataFrame
        a data frame of features including certain metadata columns
    `feature` : str
        feature to populate plate with
    `time_col` : str, optional
        column that defines time points, by default 'timepoint'
        Time points must be positive.
    `colorscale` : list, optional
        Plotly-compatible colormap, by default `colormap.OgBu`
        See `colormap.OgBu` for examples.

    Returns
    -------
    go.Figure
        Plotly figure
    """

    # Lower-case the column names on a new frame so the caller's DataFrame
    # is left untouched
    data = data.rename(columns=lambda name: name.lower())
    feature = feature.lower()
    time_col = time_col.lower()

    assert feature in data.columns, "feature must be a column in data"
    assert time_col in data.columns, f"{time_col} must be a column in data"
    assert data[time_col].min() > 0, "the first time point must be 1 not 0"
    assert ('compound' in data.columns) and ('conc' in data.columns), \
        "'compound' and 'conc' must be columns in data"

    def platemap(x, y, z, cmpd, conc):
        """
        Wrapper for go.Heatmap
        """

        # Insert line breaks in the compound names; non-string cells (e.g.
        # NaN left by the pivot) are shown blank
        labels = [
            [name.replace(' ', '<br>') if isinstance(name, str) else ''
             for name in plate_row]
            for plate_row in cmpd
        ]

        return go.Heatmap(
            x=x, y=y, z=z,
            colorscale=colorscale,
            text=labels,
            customdata=conc,
            hovertemplate=(
                '<b>Well:</b> %{y}%{x}<br>' +
                '<b>Compound:</b> %{text}<br>' +
                '<b>Concentration:</b> %{customdata}<br>' +
                '<b>Value:</b> %{z:.2f}<extra></extra>'
            ),
            texttemplate='%{text}',
        )

    # Extract time points (same order as the plate's first axis)
    times = data[time_col].unique()

    # Reformat data as plate
    plate, rows, cols = _make_plate(data, feature, time_col)

    # Create data for tooltips (taken from the first time point)
    cmpd = _make_plate(data, 'compound', time_col)[0][0, ...]
    conc = _make_plate(data, 'conc', time_col)[0][0, ...]

    # Create figure and fill in layout
    fig = {'data': [], 'layout': {}, 'frames': []}

    fig['layout'] = go.Layout(
        title=feature,
        title_x=0.5,
        xaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        yaxis={
            'showgrid': False,
            'showticklabels': True,
            'tickfont': {'size': 16, 'color': 'black'}
        },
        hovermode='closest',
        updatemenus=[{
            "buttons": [
                {"args": [None, {"frame": {"duration": 500, "redraw": True},
                                 "fromcurrent": True,
                                 "transition": {"duration": 300,
                                                "easing": "quadratic-in-out"}}],
                 "label": "Play",
                 "method": "animate"},
                {"args": [[None], {"frame": {"duration": 0, "redraw": True},
                                   "mode": "immediate",
                                   "transition": {"duration": 0}}],
                 "label": "Pause",
                 "method": "animate"}
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }] if len(times) > 1 else None
    )

    # Add first time point plate to data
    fig['data'].append(platemap(cols, rows, plate[0, ...], cmpd, conc))

    # Create frames & animations
    if len(times) > 1:
        # Create sliders
        sliders = {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Timepoint:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": []
        }

        # `plate` is stacked in the order of `times`, so index it by position;
        # this also works when time points are not consecutive integers
        for k, t in enumerate(times):
            name = str(t)

            # New frame
            fig['frames'].append({
                "data": [platemap(cols, rows, plate[k, ...], cmpd, conc)],
                "name": name
            })

            # Corresponding slider step
            sliders['steps'].append({
                "args": [
                    [name],
                    {"frame": {"duration": 300, "redraw": True},
                     "mode": "immediate",
                     "transition": {"duration": 300}}
                ],
                "label": name,
                "method": "animate"
            })

        # Add sliders to figure layout
        fig['layout']['sliders'] = [sliders]

    return go.Figure(fig)
Create an interactive plate heatmap, including an animation for time-lapse data.
This function assumes that `data` contains the following columns: `row`, `column`, `time_col`, `compound`, `conc`.
Parameters
- data (pd.DataFrame): a data frame of features including certain metadata columns
- feature (str): feature to populate plate with
- time_col (str, optional): column that defines time points, by default 'timepoint'. This assumes the first time point is 1.
- colorscale (list, optional): Plotly-compatible colormap, by default `colormap.OgBu`. See `colormap.OgBu` for examples.
Returns
- go.Figure: Plotly figure
def pca_comps(proj, exp_var, time_col='timepoint', n_comps=4):
    """
    Plot a scatter grid of PCA components

    This function is written to use the output from `process.dim_reduction`

    Parameters
    ----------
    `proj` : pd.DataFrame
        Data frame with pca projections, from `process.dim_reduction`.
        Must contain the columns `variable`, `compound`, `conc`, `time_col`
        and one column per component named '1', '2', ... (column names are
        matched case-insensitively). The caller's frame is not modified.
    `exp_var` : array_like
        List of explained variances for each PCA component
    `time_col` : str, optional
        Column containing time points; must be a column of `proj`;
        by default 'timepoint'
    `n_comps` : int, optional
        Number of pca components to plot, by default 4

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If a required column is missing from `proj`
    """

    # Lower-case column names on a copy so the caller's frame is untouched
    proj = proj.copy()
    proj.columns = [c.lower() for c in proj.columns]
    # Columns were lower-cased above, so match time_col the same way
    time_col = time_col.lower()

    assert 'variable' in proj.columns, "variable must be a column in proj"
    assert 'compound' in proj.columns, "compound must be a column in proj"
    assert time_col in proj.columns, f"{time_col} must be a column in proj"
    assert 'conc' in proj.columns, "conc must be a column in proj"

    # Keep only the PCA rows; component columns are named '1'..'n_comps'
    comp_cols = [str(k + 1) for k in range(n_comps)]
    pca_rows = proj.query("variable == 'PCA'").reset_index(drop=True)
    compounds = pca_rows['compound']
    comps = pca_rows[comp_cols]

    # Create labels
    labels = {str(i): f"PC {i+1} ({var:.2f}%)" for i, var in enumerate(exp_var)}

    # Create figure
    fig = px.scatter_matrix(
        comps.values,
        labels=labels,
        dimensions=range(n_comps),
        color=compounds,
        opacity=0.5,
        template='plotly_white'
    )
    fig.update_traces(diagonal_visible=False)
    fig.update_layout(paper_bgcolor='white', plot_bgcolor='white', height=500)

    return fig
Plot a scatter grid of PCA components
This function is written to use the output from process.dim_reduction
Parameters
- proj (pd.DataFrame): Data frame with PCA projections, from `process.dim_reduction`
- exp_var (array_like): List of explained variances for each PCA component
- time_col (str, optional): Column containing time points; must be a column of `proj`; by default 'timepoint'
- n_comps (int, optional): Number of PCA components to plot, by default 4
Returns
- go.Figure: Plotly figure
def clusters(data, compound_a, compound_b, method, time_col='timepoint'):
    """
    Create clustering figures that compare 2 compounds.

    This function is written to use the output from `process.dim_reduction`

    Parameters
    ----------
    `data` : pd.DataFrame
        Projections from `process.dim_reduction`. Must contain the columns
        `compound`, `variable`, `conc`, `time_col` and the component columns
        named '1' and '2'. The caller's frame is not modified.
    `compound_a` : str
        Compound A (red points)
    `compound_b` : str
        Compound B (green points)
    `method` : str
        One of 'PCA', 'tSNE' or 'UMAP' (case-insensitive); rows are selected
        where `variable` equals the upper-cased method name
    `time_col` : str, optional
        Column containing time points, by default 'timepoint'

    Returns
    -------
    go.Figure
        Plotly figure. Marker colour encodes the time point and marker size
        encodes the concentration.

    Raises
    ------
    AssertionError
        If an input fails validation or a compound has no rows for `method`
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert 'compound' in data.columns, "'compound' must be a column in data"
    assert compound_a in data['compound'].tolist(), \
        "compound_a must be present in data['compound']"
    assert compound_b in data['compound'].tolist(), \
        "compound_b must be present in data['compound']"
    method = method.lower()
    assert method in ['pca', 'tsne', 'umap'], \
        "method must be one of 'PCA', 'tSNE' or 'UMAP'"
    assert time_col in data.columns, f"{time_col} must be a column in data"

    times = data[time_col].unique()
    method = method.upper()
    multi_time = len(times) > 1

    # With several time points, give every trace the same colour range so a
    # time point maps to the same colour in each concentration trace and in
    # the colorbar; otherwise each trace autoscales to its own subset.
    if multi_time:
        cmin = float(data[time_col].min())
        cmax = float(data[time_col].max())
    else:
        cmin = cmax = None

    def create_traces(compound, colorscale, linecolor, cbar_pos):
        """
        Create one scatter trace per concentration of `compound`
        """

        # Boolean mask instead of an interpolated query string, so compound
        # names containing quotes are handled
        mask = (data['compound'] == compound) & (data['variable'] == method)
        cmpd = data.loc[mask].reset_index(drop=True)
        assert not cmpd.empty, \
            f"no {method} rows found for compound {compound!r}"

        # Encode different sizes for each concentration (3, 6, 9, ...)
        cmpd['conc'] = cmpd['conc'].astype(float)
        concs = sorted(cmpd['conc'].unique())
        if len(concs) > 1:
            conc_map = {c: 3 * s for s, c in enumerate(concs, start=1)}
        else:
            conc_map = {concs[0]: 20}

        traces = []
        for i, conc in enumerate(concs):
            sel = cmpd['conc'] == conc
            # Only the first trace of each compound carries the colorbar
            show_cbar = (i == 0) and multi_time

            traces.append(
                go.Scatter(
                    x=cmpd.loc[sel, '1'],
                    y=cmpd.loc[sel, '2'],
                    mode='markers',
                    marker=dict(
                        color=cmpd.loc[sel, time_col],
                        cmin=cmin,
                        cmax=cmax,
                        size=conc_map[conc],
                        opacity=0.5,
                        colorscale=colorscale,
                        colorbar=dict(
                            x=cbar_pos,
                            thickness=20,
                            yanchor='middle',
                            len=.7
                        ) if show_cbar else None,
                        line=dict(width=1.2, color=linecolor)
                    ),
                    name=str(conc),
                    legendgroup=compound.replace(' ', '').lower(),
                    legendgrouptitle_text=(compound if i == 0 else None)
                )
            )

        return traces

    # Create figure traces
    traces = [*create_traces(compound_a, 'Reds', 'red', -0.25),
              *create_traces(compound_b, 'Greens', 'green', -0.35)]

    # Create layout
    layout = go.Layout(
        legend=dict(tracegroupgap=20, groupclick='toggleitem'),
        template='plotly_white',
        margin=dict(l=20, r=20, t=70, b=40),
        height=500,
        title=f"{compound_a} vs {compound_b}<br>({method} Clusters)",
        title_x=0.5,
        xaxis_title=f'{method} 1',
        yaxis_title=f'{method} 2',
    )

    # Create figure with equal x/y scaling
    fig = go.Figure(data=traces, layout=layout)
    fig.update_yaxes(
        scaleanchor='x',
        scaleratio=1
    )

    # Annotate the colorbars
    if multi_time:
        fig.add_annotation(
            xref='paper',
            yref='paper',
            x=-0.33,
            y=0.92,
            text='Time Point',
            font_size=14,
            showarrow=False
        )

    return fig
Create clustering figures that compare 2 compounds.
This function is written to use the output from process.dim_reduction
Parameters
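- data (pd.DataFrame): Projections from `process.dim_reduction`; must contain the columns `compound`, `variable`, `conc`, the time column and the component columns `1` and `2`
- compound_a (str): Compound A (red points)
- compound_b (str): Compound B (green points)
- method (str): One of 'PCA', 'tSNE' or 'UMAP'
- time_col (str, optional): Column containing time points, by default 'timepoint'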
Returns
- go.Figure: Plotly figure
def distplot(data, features, group_col, tooltips=None, kind='box', col_wrap=2,
             title_len=30):
    """
    Create boxplots showing the distribution of features for different groups.

    This generates a figure with as many subplots as there are features

    Parameters
    ----------
    `data` : pd.DataFrame
        Data frame to plot; must also contain a `well` column, which is
        always shown in the tooltips
    `features` : list
        List of features to visualize
    `group_col` : str
        `data` column that contains groups of interest
    `tooltips` : dict, optional
        Dictionary that defines annotation tooltips, by default None
        Keys = Tooltip Name;
        Values = Corresponding column in `data`
    `kind` : str, optional
        Type of plot to generate; one of 'box', 'bar', by default 'box'
    `col_wrap` : int, optional
        Number of columns allowed in layout, by default 2
    `title_len` : int, optional
        Wrap length for subplot titles, by default 30

    Returns
    -------
    go.Figure
        Plotly figure

    Raises
    ------
    AssertionError
        If `kind` is not 'box' or 'bar', or a referenced column is missing
    NotImplementedError
        When `kind == 'bar'` (not implemented yet)
    """

    assert isinstance(data, pd.DataFrame), "data must be a data frame"
    assert kind in ['box', 'bar'], "kind must be one of 'box', 'bar'."
    assert group_col in data.columns, "group_col must be a column in data"
    for f in features:
        assert f in data.columns, "features must contain columns from data"
    if tooltips is not None:
        for col in tooltips.values():
            assert col in data.columns, \
                "Values of tooltips must be columns of data"

    # Determine grid dimensions & positions
    nrows, ncols, positions = _make_grid(features, col_wrap=col_wrap)

    # Wrap text for subplot titles
    titles = ['<br>'.join(textwrap.wrap(x, title_len)) for x in features]

    # Get list of groups & colors
    groups = data[group_col].unique()
    colors = _get_colors(len(groups))

    # Create figure
    fig = sp.make_subplots(rows=nrows, cols=ncols, subplot_titles=titles)

    if kind == 'box':
        # Subset every group once (independent of the feature) and build its
        # hover text from the SAME rows that are plotted, so each tooltip
        # describes its own point. A boolean mask, unlike an interpolated
        # query string, handles non-string group values and quotes.
        per_group = []
        for grp in groups:
            subset = data[data[group_col] == grp]
            text = 'well: ' + subset['well'].astype(str) + '<br>'
            if tooltips is not None:
                for name, col in tooltips.items():
                    text += f'{name}: ' + subset[col].astype(str) + '<br>'
            per_group.append((grp, subset, text.tolist()))

        for feature, pos in positions.items():
            for (grp, subset, text), color in zip(per_group, colors):
                fig.add_trace(
                    go.Box(
                        y=subset[feature],
                        name=grp,
                        text=text,
                        hovertemplate='%{text}',
                        legendgroup=grp,
                        marker_color=color,
                        # One legend entry per group, taken from subplot (1, 1)
                        showlegend=(pos['x'] == pos['y'] == 1)
                    ),
                    row=pos['x'],
                    col=pos['y']
                )
    elif kind == 'bar':
        raise NotImplementedError("Can't do that yet. Working on it.")
    else:
        # Only reachable when assertions are disabled (python -O)
        raise NotImplementedError("Can't do that yet.")

    fig.update_xaxes(showticklabels=False)
    fig.update_layout(template='plotly_white')
    fig.update_annotations(font=dict(family="Helvetica", size=14))

    return fig
Create boxplots showing the distribution of features for different groups.
This generates a figure with as many subplots as there are features
Parameters
- data (pd.DataFrame): Data frame to plot
- features (list): List of features to visualize
- group_col (str): `data` column that contains groups of interest
- tooltips (dict, optional): Dictionary that defines annotation tooltips, by default None. Keys = tooltip name; Values = corresponding column in `data`
- kind (str, optional): Type of plot to generate; one of 'box', 'bar', by default 'box'
- col_wrap (int, optional): Number of columns allowed in layout, by default 2
- title_len (int, optional): Wrap length for subplot titles, by default 30
Returns
- go.Figure: Plotly figure
Raises
- NotImplementedError: When `kind == 'bar'` (bar plots are not implemented yet)
def textplot(text):
    """
    Build an empty, axis-free figure that shows a single text message.

    Serves as a placeholder where an actual figure is not available.

    Parameters
    ----------
    `text` : str
        Message to display in figure

    Returns
    -------
    go.Figure
        Plotly figure
    """

    label = go.Scatter(
        x=[0],
        y=[0],
        mode='text',
        text=[text],
        textposition='top center',
        textfont_size=16,
        hoverinfo='skip',
    )
    placeholder = go.Figure(label)
    placeholder.update_layout(template='simple_white', height=300)

    # Hide both axes and disable zooming/panning along them
    hidden_axis = dict(visible=False, fixedrange=True)
    placeholder.update_xaxes(**hidden_axis)
    placeholder.update_yaxes(**hidden_axis)

    return placeholder
Create a blank figure to display some text. Serves as a placeholder for an actual figure.
Parameters
text
(str): Message to display in figure
Returns
- go.Figure: Plotly figure
def gifify(fig, file, frame_title='Frame', fps=30) -> None:
    """
    Export a plotly animation as a gif

    Every animation frame of `fig` is rendered to a still image (without the
    figure's sliders and play/pause buttons, titled "<frame_title> <n>") and
    held for one second in the resulting gif.

    Parameters
    ----------
    `fig` : go.Figure
        Plotly figure with at least one animation frame
    `file` : str
        Path to file where figure gif will be stored; must end in '.gif'
    `frame_title` : str, optional
        Title that describes each frame, by default 'Frame'
    `fps` : int, optional
        Frame rate used to sample the gif, by default 30. Each animation
        frame still lasts one second regardless of this value.

    Raises
    ------
    AssertionError
        If `fig` is not a plotly figure, has no animation frames, or `file`
        does not end in '.gif'
    """

    assert isinstance(fig, go.Figure), \
        "This only works for plotly figures"
    assert file.endswith('.gif'), "file must be a .gif"
    assert len(fig.frames) > 0, "fig must contain animation frames"

    import moviepy.editor as mpy
    from PIL import Image

    def fig2array(fig):
        """
        Render a plotly figure to an RGB numpy array
        """

        img_bytes = fig.to_image(format='jpg', engine='kaleido')
        with Image.open(io.BytesIO(img_bytes)) as img:
            # moviepy reads `.shape`/`.dtype` from each frame, so it needs an
            # ndarray rather than a PIL Image object
            return np.asarray(img.convert('RGB'))

    # Remove sliders and buttons from figure layout
    exclude = {'updatemenus', 'sliders'}
    layout = {k: v for k, v in fig.to_dict()['layout'].items()
              if k not in exclude}

    # Render each animation frame as an image
    frames = []
    for i, frame in enumerate(fig.frames, start=1):
        _fig = go.Figure(data=frame['data'], layout=layout)
        _fig.update_layout(title=f"{frame_title} {i}", title_x=0.5)
        frames.append(fig2array(_fig))

    # Frame k is shown for t in [k, k+1); the clamp protects against a sample
    # time that lands on the end of the clip
    last = len(frames) - 1

    def make_frame(t):
        return frames[min(int(t), last)]

    anim = mpy.VideoClip(make_frame, duration=len(frames))
    anim.write_gif(file, fps=fps, logger=None)
    print("Done :thumbsup:")
Export a plotly animation as a gif
Parameters
- fig (go.Figure): Plotly figure
- file (str): Path to file where figure gif will be stored
- frame_title (str, optional): Title that describes each frame, by default 'Frame'
- fps (int, optional): Frame rate, by default 30
def heatmap(data, col_groups=None, col_colors=None, col_group_names=None,
            row_groups=None, row_colors=None, row_group_names=None,
            clust_cols=True, clust_rows=True, cluster_kws=None):
    """
    Construct an interactive heatmap

    The main heatmap occupies the last row and column of a subplot grid.
    One narrow strip per row group is drawn to its left and one strip per
    column group above it, coloured by group membership.

    Parameters
    ----------
    data : pd.DataFrame
        Data to plot
    {row, col}_groups : dict
        Dictionary assigning groups to rows or columns.
        Keys should be the index or columns of data.
        Values should be a list of groups.
    {row, col}_group_names : list
        Names for each of the row/col groups.
        Should be the same length as the lists in {row, col}_groups
    {row, col}_colors : dict
        Dictionary defining colors for each group.
        Keys = groups; Values = colors;
    clust_{rows, cols} : bool
        Should row and/or column clustering be performed
    cluster_kws : dict, optional
        kwargs for sns.clustermap, by default None (no extra kwargs)

    Returns
    -------
    go.Figure
        Plotly figure
    """

    from sklearn.preprocessing import LabelEncoder as _SkLabelEncoder

    # Avoid a shared mutable default for the clustering kwargs
    if cluster_kws is None:
        cluster_kws = {}
    if col_group_names is None:
        col_group_names = []
    if row_group_names is None:
        row_group_names = []

    def group_colorscale(labels, colors):
        """
        Encode group labels as integer codes 0..k-1 and build a colorscale
        that paints code k with colors[<k-th sorted label>]
        """

        le = _SkLabelEncoder().fit(labels)
        codes = le.transform(labels)
        classes = le.classes_
        if len(classes) == 1:
            # A single group: plotly colorscales need stops at 0 and 1
            only = colors[classes[0]]
            return codes, [[0.0, only], [1.0, only]], 1
        top = len(classes) - 1
        # Index classes directly; no float round-trip that could truncate
        # to the wrong class
        scale = [[k / top, colors[c]] for k, c in enumerate(classes)]
        return codes, scale, top

    # Grid: one extra row per column group (top) and one extra column per row
    # group (left); the main heatmap sits in the last row and column
    n_rows = len(col_group_names) + 1
    n_cols = len(row_group_names) + 1

    # Define column widths and row heights
    col_widths = [0.03] * (n_cols - 1) + [1 - (n_cols - 1) * 0.03]
    row_heights = [0.07] * (n_rows - 1) + [1 - (n_rows - 1) * 0.07]

    # Create subplot grid
    fig = sp.make_subplots(
        rows=n_rows,
        cols=n_cols,
        column_widths=col_widths,
        row_heights=row_heights,
        vertical_spacing=0.01,
        horizontal_spacing=0.01,
        shared_xaxes=True,
        shared_yaxes=True
    )

    # Perform clustering and extract clustered data frame from seaborn clustermap
    if clust_cols or clust_rows:
        data = sns.clustermap(data, row_cluster=clust_rows,
                              col_cluster=clust_cols, **cluster_kws).data2d
        plt.close()

    # Plot the heatmap and adjust axes
    fig.add_trace(
        go.Heatmap(
            z=data,
            x=data.columns,
            y=data.index.astype(str),
            colorscale='RdBu_r',
            hovertemplate='<b>Sample:</b> %{y}<br>' +
                          '<b>Feature:</b> %{x}<br>' +
                          '<b>Value:</b>%{z}'
                          '<extra></extra>'
        ),
        row=n_rows, col=n_cols
    )
    fig.update_yaxes(row=n_rows, col=n_cols, showticklabels=False,
                     autorange='reversed')
    fig.update_xaxes(row=n_rows, col=n_cols, showticklabels=True,
                     tickangle=270)
    fig.update_traces(row=n_rows, col=n_cols, colorbar_len=0.7)

    # Add row colors
    if row_groups is not None:
        for j, grp in enumerate(row_group_names):
            row_data = [row_groups[r][j] for r in data.index]
            codes, colorscale, zmax = group_colorscale(row_data, row_colors)

            fig.add_trace(
                go.Heatmap(
                    z=pd.DataFrame(codes),
                    y=data.index.astype(str),
                    x=[grp],
                    text=pd.DataFrame(row_data),
                    colorscale=colorscale,
                    zmin=0,
                    zmax=zmax,
                    hovertemplate='<b>Sample:</b> %{y}<br>' +
                                  f'<b>{grp}:</b> %{{text}}' +
                                  '<extra></extra>',
                    showscale=False
                ),
                row=n_rows, col=j + 1
            )
            fig.update_yaxes(row=n_rows, col=j + 1, showticklabels=False,
                             autorange='reversed')
            fig.update_xaxes(row=n_rows, col=j + 1, showticklabels=True,
                             tickangle=270,
                             tickfont={'size': 15, 'family': 'Arial'})

    # Add column colors
    if col_groups is not None:
        for i, grp in enumerate(col_group_names):
            col_data = [col_groups[c][i] for c in data.columns]
            codes, colorscale, zmax = group_colorscale(col_data, col_colors)

            fig.add_trace(
                go.Heatmap(
                    z=pd.DataFrame(codes).T,
                    y=[grp],
                    x=data.columns,
                    text=pd.DataFrame(col_data).T,
                    colorscale=colorscale,
                    zmin=0,
                    zmax=zmax,
                    hovertemplate='<b>Feature:</b> %{x}<br>' +
                                  f'<b>{grp}:</b> %{{text}}' +
                                  '<extra></extra>',
                    showscale=False
                ),
                row=i + 1, col=n_cols
            )
            fig.update_yaxes(row=i + 1, col=n_cols, showticklabels=True,
                             autorange='reversed', side='right',
                             tickfont={'size': 15, 'family': 'Arial'})
            fig.update_xaxes(row=i + 1, col=n_cols, showticklabels=False)

    # Every axis here holds labels, not quantities. Without an explicit
    # category type plotly auto-detects numeric-looking labels (e.g. a
    # stringified integer index) as a linear axis; after row clustering
    # those values are no longer monotonic and heatmap rows overlap or
    # disappear.
    # NOTE(review): believed to be the cause of the previously noted
    # "clustering rows makes some data disappear" bug -- confirm. Duplicate
    # index labels would still collapse onto a single category row.
    fig.update_xaxes(type='category')
    fig.update_yaxes(type='category')

    return fig
Construct an interactive heatmap
Parameters
- data (pd.DataFrame): Data to plot
- {row, col}_groups (dict): Dictionary assigning groups to rows or columns. Keys should be the index or columns of data. Values should be a list of groups.
- {row, col}_group_names (list): Names for each of the row/col groups. They should be the same length as the lists in {row, col}_groups.
- {row, col}_colors (dict): Dictionary defining colors for each group. Keys = groups; Values = colors;
- clust_{rows, cols} (bool): Should row and/or column clustering be performed
- cluster_kws (dict): kwargs for sns.clustermap