docs for muutils v0.6.13
View Source on GitHub

muutils.nbutils.configure_notebook

shared utilities for setting up a notebook


  1"""shared utilities for setting up a notebook"""
  2
  3from __future__ import annotations
  4
  5import os
  6import typing
  7import warnings
  8
  9import matplotlib.pyplot as plt  # type: ignore[import]
 10
 11
 12class PlotlyNotInstalledWarning(UserWarning):
 13    pass
 14
 15
 16# handle plotly importing
 17PLOTLY_IMPORTED: bool
 18try:
 19    import plotly.io as pio  # type: ignore[import]
 20except ImportError:
 21    warnings.warn(
 22        "Plotly not installed. Plotly plots will not be available.",
 23        PlotlyNotInstalledWarning,
 24    )
 25    PLOTLY_IMPORTED = False
 26else:
 27    PLOTLY_IMPORTED = True
 28
 29# figure out if we're in a jupyter notebook
 30try:
 31    from IPython import get_ipython  # type: ignore[import-not-found]
 32
 33    IN_JUPYTER = get_ipython() is not None
 34except ImportError:
 35    IN_JUPYTER = False
 36
 37# muutils imports
 38from muutils.mlutils import get_device, set_reproducibility  # noqa: E402
 39
 40# handling figures
 41PlottingMode = typing.Literal["ignore", "inline", "widget", "save"]
 42PLOT_MODE: PlottingMode = "inline"
 43CONVERSION_PLOTMODE_OVERRIDE: PlottingMode | None = None
 44FIG_COUNTER: int = 0
 45FIG_OUTPUT_FMT: str | None = None
 46FIG_NUMBERED_FNAME: str = "figure-{num}"
 47FIG_CONFIG: dict | None = None
 48FIG_BASEPATH: str | None = None
 49CLOSE_AFTER_PLOTSHOW: bool = False
 50
 51MATPLOTLIB_FORMATS = ["pdf", "png", "jpg", "jpeg", "svg", "eps", "ps", "tif", "tiff"]
 52TIKZPLOTLIB_FORMATS = ["tex", "tikz"]
 53
 54
 55class UnknownFigureFormatWarning(UserWarning):
 56    pass
 57
 58
 59def universal_savefig(fname: str, fmt: str | None = None) -> None:
 60    # try to infer format from fname
 61    if fmt is None:
 62        fmt = fname.split(".")[-1]
 63
 64    if not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS):
 65        warnings.warn(
 66            f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
 67            UnknownFigureFormatWarning,
 68        )
 69        fmt = FIG_OUTPUT_FMT
 70
 71    # not sure why linting is throwing an error here
 72    if not fname.endswith(fmt):  # type: ignore[arg-type]
 73        fname += f".{fmt}"
 74
 75    if fmt in MATPLOTLIB_FORMATS:
 76        plt.savefig(fname, format=fmt, bbox_inches="tight")
 77    elif fmt in TIKZPLOTLIB_FORMATS:
 78        import tikzplotlib  # type: ignore[import]
 79
 80        tikzplotlib.save(fname)
 81    else:
 82        warnings.warn(f"Unknown format '{fmt}', going with matplotlib default")
 83        plt.savefig(fname, bbox_inches="tight")
 84
 85
 86def setup_plots(
 87    plot_mode: PlottingMode = "inline",
 88    fig_output_fmt: str | None = "pdf",
 89    fig_numbered_fname: str = "figure-{num}",
 90    fig_config: dict | None = None,
 91    fig_basepath: str | None = None,
 92    close_after_plotshow: bool = False,
 93) -> None:
 94    """Set up plot saving/rendering options"""
 95    global \
 96        PLOT_MODE, \
 97        CONVERSION_PLOTMODE_OVERRIDE, \
 98        FIG_COUNTER, \
 99        FIG_OUTPUT_FMT, \
100        FIG_NUMBERED_FNAME, \
101        FIG_CONFIG, \
102        FIG_BASEPATH, \
103        CLOSE_AFTER_PLOTSHOW
104
105    # set plot mode, handling override
106    if CONVERSION_PLOTMODE_OVERRIDE is not None:
107        # override if set
108        PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE
109    else:
110        # otherwise use the given plot mode
111        PLOT_MODE = plot_mode
112
113    FIG_COUNTER = 0
114    CLOSE_AFTER_PLOTSHOW = close_after_plotshow
115
116    if PLOT_MODE == "inline":
117        if IN_JUPYTER:
118            ipython = get_ipython()
119            ipython.magic("matplotlib inline")
120        else:
121            raise RuntimeError(
122                f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }"
123            )
124        return
125    elif PLOT_MODE == "widget":
126        if IN_JUPYTER:
127            ipython = get_ipython()
128            ipython.magic("matplotlib widget")
129        else:
130            # matplotlib outside of jupyter will bring up a new window by default
131            pass
132        return
133    elif PLOT_MODE == "ignore":
134        # disable plotting
135        plt.show = lambda: None  # type: ignore[misc]
136        return
137
138    # everything except saving handled up to this point
139    assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}"
140
141    FIG_OUTPUT_FMT = fig_output_fmt
142    FIG_NUMBERED_FNAME = fig_numbered_fname
143    FIG_CONFIG = fig_config
144
145    # set default figure format in rcParams savefig.format
146    plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT
147    if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS:
148        try:
149            import tikzplotlib  # type: ignore[import] # noqa: F401
150        except ImportError:
151            warnings.warn(
152                f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break."
153            )
154    else:
155        if FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS:
156            warnings.warn(
157                f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }'
158            )
159
160    # if base path not given, make one
161    if fig_basepath is None:
162        if fig_config is None:
163            # if no config, use the current time
164            from datetime import datetime
165
166            fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
167        else:
168            # if config given, convert to string
169            from muutils.misc import dict_to_filename
170
171            fig_basepath = f"figures/{dict_to_filename(fig_config)}"
172
173    FIG_BASEPATH = fig_basepath
174    os.makedirs(fig_basepath, exist_ok=True)
175
176    # if config given, serialize and save that config
177    if fig_config is not None:
178        import json
179
180        from muutils.json_serialize import json_serialize
181
182        with open(f"{fig_basepath}/config.json", "w") as f:
183            json.dump(
184                json_serialize(fig_config),
185                f,
186                indent="\t",
187            )
188
189    print(f"Figures will be saved to: '{fig_basepath}'")
190
191
def configure_notebook(
    *args,
    seed: int = 42,
    device: typing.Any = None,  # this can be a string, torch.device, or None
    dark_mode: bool = True,
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> "torch.device|None":  # type: ignore[name-defined] # noqa: F821
    """Shared Jupyter notebook setup steps

    - Set random seeds and library reproducibility settings
    - Set device based on availability
    - Set module reloading before code execution
    - Set plot formatting
    - Set plot saving/rendering options

    # Parameters:
     - `*args`
       accepted but ignored
     - `seed : int`
       random seed across libraries including torch, numpy, and random
       (defaults to `42`)
     - `device : typing.Any`
       pytorch device to use
       (defaults to `None`)
     - `dark_mode : bool`
       when an IPython shell is active, apply matplotlib's `dark_background`
       style and, if plotly is installed, the `plotly_dark` template
       (defaults to `True`)
     - `plot_mode : PlottingMode`
       how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]`
       (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
       format for saving figures
       (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
       filename template for figures saved without an explicit name;
       `{num}` is replaced by the figure counter
       (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
       metadata to save with the figures
       (defaults to `None`)
     - `fig_basepath : str | None`
       base path for saving figures
       (defaults to `None`)
     - `close_after_plotshow : bool`
       close figures after showing them
       (defaults to `False`)

    # Returns:
     - `torch.device|None`
       the device set, or `None` if torch is not installed
    """

    # set some globals related to plotting
    setup_plots(
        plot_mode=plot_mode,
        fig_output_fmt=fig_output_fmt,
        fig_numbered_fname=fig_numbered_fname,
        fig_config=fig_config,
        fig_basepath=fig_basepath,
        close_after_plotshow=close_after_plotshow,
    )

    # the plotting globals are only read here, so no `global` declaration is needed
    print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }")

    # Set seeds and other reproducibility-related library options
    set_reproducibility(seed)

    if IN_JUPYTER:
        ipython = get_ipython()
        # Reload modules before executing user code
        # (`run_line_magic` replaces the deprecated `InteractiveShell.magic`)
        if "IPython.extensions.autoreload" not in ipython.extension_manager.loaded:
            ipython.run_line_magic("load_ext", "autoreload")
            ipython.run_line_magic("autoreload", "2")

        # matplotlib dark styling must not depend on plotly being installed
        if dark_mode:
            plt.style.use("dark_background")

        # Specify plotly renderer for vscode
        if PLOTLY_IMPORTED:
            pio.renderers.default = "notebook_connected"

            if dark_mode:
                pio.templates.default = "plotly_dark"

    try:
        # Set device
        device = get_device(device)
        return device
    except ImportError:
        warnings.warn("Torch not installed. Cannot get/set device.")
        return None
285
286
def plotshow(
    fname: str | None = None,
    plot_mode: PlottingMode | None = None,
    fmt: str | None = None,
):
    """Show the active plot, depending on global configs

    Increments `FIG_COUNTER`, then handles the current figure according to
    `plot_mode` (falling back to the global `PLOT_MODE` when `None`):
     - `"save"`: write it under `FIG_BASEPATH` via `universal_savefig`, named
       `fname` or, when not given, `FIG_NUMBERED_FNAME` filled with the counter
     - `"inline"` / `"widget"`: call `plt.show()`
     - `"ignore"`: do nothing
    Any other mode only emits a warning. Afterwards the figure is closed if
    `CLOSE_AFTER_PLOTSHOW` is set.
    """
    global FIG_COUNTER
    FIG_COUNTER += 1

    mode = PLOT_MODE if plot_mode is None else plot_mode

    if mode in ("inline", "widget"):
        plt.show()
    elif mode == "save":
        target = (
            fname
            if fname is not None
            else FIG_NUMBERED_FNAME.format(num=FIG_COUNTER)
        )
        assert FIG_BASEPATH is not None
        universal_savefig(os.path.join(FIG_BASEPATH, target), fmt=fmt)
    elif mode != "ignore":
        warnings.warn(f"Invalid plot mode: {mode}")

    if CLOSE_AFTER_PLOTSHOW:
        plt.close()

class PlotlyNotInstalledWarning(builtins.UserWarning):
class PlotlyNotInstalledWarning(UserWarning):
    """Emitted at import time when `plotly` cannot be imported.

    When this warning fires, plotly-based plots are unavailable.
    """

Warning emitted at import time when plotly is not installed; plotly plots will not be available.

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
PLOTLY_IMPORTED: bool = True
PlottingMode = typing.Literal['ignore', 'inline', 'widget', 'save']
PLOT_MODE: Literal['ignore', 'inline', 'widget', 'save'] = 'inline'
CONVERSION_PLOTMODE_OVERRIDE: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None
FIG_COUNTER: int = 0
FIG_OUTPUT_FMT: str | None = None
FIG_NUMBERED_FNAME: str = 'figure-{num}'
FIG_CONFIG: dict | None = None
FIG_BASEPATH: str | None = None
CLOSE_AFTER_PLOTSHOW: bool = False
MATPLOTLIB_FORMATS = ['pdf', 'png', 'jpg', 'jpeg', 'svg', 'eps', 'ps', 'tif', 'tiff']
TIKZPLOTLIB_FORMATS = ['tex', 'tikz']
class UnknownFigureFormatWarning(builtins.UserWarning):
class UnknownFigureFormatWarning(UserWarning):
    """Emitted by `universal_savefig` when a figure format is not recognized.

    After emitting it, the save falls back to the configured default format.
    """

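Warning emitted by universal_savefig when a figure format is not recognized; saving then falls back to the configured default format (FIG_OUTPUT_FMT).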

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def universal_savefig(fname: str, fmt: str | None = None) -> None:
60def universal_savefig(fname: str, fmt: str | None = None) -> None:
61    # try to infer format from fname
62    if fmt is None:
63        fmt = fname.split(".")[-1]
64
65    if not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS):
66        warnings.warn(
67            f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
68            UnknownFigureFormatWarning,
69        )
70        fmt = FIG_OUTPUT_FMT
71
72    # not sure why linting is throwing an error here
73    if not fname.endswith(fmt):  # type: ignore[arg-type]
74        fname += f".{fmt}"
75
76    if fmt in MATPLOTLIB_FORMATS:
77        plt.savefig(fname, format=fmt, bbox_inches="tight")
78    elif fmt in TIKZPLOTLIB_FORMATS:
79        import tikzplotlib  # type: ignore[import]
80
81        tikzplotlib.save(fname)
82    else:
83        warnings.warn(f"Unknown format '{fmt}', going with matplotlib default")
84        plt.savefig(fname, bbox_inches="tight")
def setup_plots( plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline', fig_output_fmt: str | None = 'pdf', fig_numbered_fname: str = 'figure-{num}', fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False) -> None:
 87def setup_plots(
 88    plot_mode: PlottingMode = "inline",
 89    fig_output_fmt: str | None = "pdf",
 90    fig_numbered_fname: str = "figure-{num}",
 91    fig_config: dict | None = None,
 92    fig_basepath: str | None = None,
 93    close_after_plotshow: bool = False,
 94) -> None:
 95    """Set up plot saving/rendering options"""
 96    global \
 97        PLOT_MODE, \
 98        CONVERSION_PLOTMODE_OVERRIDE, \
 99        FIG_COUNTER, \
100        FIG_OUTPUT_FMT, \
101        FIG_NUMBERED_FNAME, \
102        FIG_CONFIG, \
103        FIG_BASEPATH, \
104        CLOSE_AFTER_PLOTSHOW
105
106    # set plot mode, handling override
107    if CONVERSION_PLOTMODE_OVERRIDE is not None:
108        # override if set
109        PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE
110    else:
111        # otherwise use the given plot mode
112        PLOT_MODE = plot_mode
113
114    FIG_COUNTER = 0
115    CLOSE_AFTER_PLOTSHOW = close_after_plotshow
116
117    if PLOT_MODE == "inline":
118        if IN_JUPYTER:
119            ipython = get_ipython()
120            ipython.magic("matplotlib inline")
121        else:
122            raise RuntimeError(
123                f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }"
124            )
125        return
126    elif PLOT_MODE == "widget":
127        if IN_JUPYTER:
128            ipython = get_ipython()
129            ipython.magic("matplotlib widget")
130        else:
131            # matplotlib outside of jupyter will bring up a new window by default
132            pass
133        return
134    elif PLOT_MODE == "ignore":
135        # disable plotting
136        plt.show = lambda: None  # type: ignore[misc]
137        return
138
139    # everything except saving handled up to this point
140    assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}"
141
142    FIG_OUTPUT_FMT = fig_output_fmt
143    FIG_NUMBERED_FNAME = fig_numbered_fname
144    FIG_CONFIG = fig_config
145
146    # set default figure format in rcParams savefig.format
147    plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT
148    if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS:
149        try:
150            import tikzplotlib  # type: ignore[import] # noqa: F401
151        except ImportError:
152            warnings.warn(
153                f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break."
154            )
155    else:
156        if FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS:
157            warnings.warn(
158                f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }'
159            )
160
161    # if base path not given, make one
162    if fig_basepath is None:
163        if fig_config is None:
164            # if no config, use the current time
165            from datetime import datetime
166
167            fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
168        else:
169            # if config given, convert to string
170            from muutils.misc import dict_to_filename
171
172            fig_basepath = f"figures/{dict_to_filename(fig_config)}"
173
174    FIG_BASEPATH = fig_basepath
175    os.makedirs(fig_basepath, exist_ok=True)
176
177    # if config given, serialize and save that config
178    if fig_config is not None:
179        import json
180
181        from muutils.json_serialize import json_serialize
182
183        with open(f"{fig_basepath}/config.json", "w") as f:
184            json.dump(
185                json_serialize(fig_config),
186                f,
187                indent="\t",
188            )
189
190    print(f"Figures will be saved to: '{fig_basepath}'")

Set up plot saving/rendering options

def configure_notebook( *args, seed: int = 42, device: Any = None, dark_mode: bool = True, plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline', fig_output_fmt: str | None = 'pdf', fig_numbered_fname: str = 'figure-{num}', fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False) -> torch.device | None:
def configure_notebook(
    *args,
    seed: int = 42,
    device: typing.Any = None,  # this can be a string, torch.device, or None
    dark_mode: bool = True,
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> "torch.device|None":  # type: ignore[name-defined] # noqa: F821
    """Shared Jupyter notebook setup steps

    - Set random seeds and library reproducibility settings
    - Set device based on availability
    - Set module reloading before code execution
    - Set plot formatting
    - Set plot saving/rendering options

    # Parameters:
     - `*args`
       accepted but ignored
     - `seed : int`
       random seed across libraries including torch, numpy, and random
       (defaults to `42`)
     - `device : typing.Any`
       pytorch device to use
       (defaults to `None`)
     - `dark_mode : bool`
       when an IPython shell is active, apply matplotlib's `dark_background`
       style and, if plotly is installed, the `plotly_dark` template
       (defaults to `True`)
     - `plot_mode : PlottingMode`
       how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]`
       (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
       format for saving figures
       (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
       filename template for figures saved without an explicit name;
       `{num}` is replaced by the figure counter
       (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
       metadata to save with the figures
       (defaults to `None`)
     - `fig_basepath : str | None`
       base path for saving figures
       (defaults to `None`)
     - `close_after_plotshow : bool`
       close figures after showing them
       (defaults to `False`)

    # Returns:
     - `torch.device|None`
       the device set, or `None` if torch is not installed
    """

    # set some globals related to plotting
    setup_plots(
        plot_mode=plot_mode,
        fig_output_fmt=fig_output_fmt,
        fig_numbered_fname=fig_numbered_fname,
        fig_config=fig_config,
        fig_basepath=fig_basepath,
        close_after_plotshow=close_after_plotshow,
    )

    # the plotting globals are only read here, so no `global` declaration is needed
    print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }")

    # Set seeds and other reproducibility-related library options
    set_reproducibility(seed)

    if IN_JUPYTER:
        ipython = get_ipython()
        # Reload modules before executing user code
        # (`run_line_magic` replaces the deprecated `InteractiveShell.magic`)
        if "IPython.extensions.autoreload" not in ipython.extension_manager.loaded:
            ipython.run_line_magic("load_ext", "autoreload")
            ipython.run_line_magic("autoreload", "2")

        # matplotlib dark styling must not depend on plotly being installed
        if dark_mode:
            plt.style.use("dark_background")

        # Specify plotly renderer for vscode
        if PLOTLY_IMPORTED:
            pio.renderers.default = "notebook_connected"

            if dark_mode:
                pio.templates.default = "plotly_dark"

    try:
        # Set device
        device = get_device(device)
        return device
    except ImportError:
        warnings.warn("Torch not installed. Cannot get/set device.")
        return None

Shared Jupyter notebook setup steps

  • Set random seeds and library reproducibility settings
  • Set device based on availability
  • Set module reloading before code execution
  • Set plot formatting
  • Set plot saving/rendering options

Parameters:

  • seed : int random seed across libraries including torch, numpy, and random (defaults to 42)
  • device : typing.Any pytorch device to use (defaults to None)
  • dark_mode : bool figures in dark mode (defaults to True)
  • plot_mode : PlottingMode how to display plots, one of PlottingMode or ["ignore", "inline", "widget", "save"] (defaults to "inline")
  • fig_output_fmt : str | None format for saving figures (defaults to "pdf")
  • fig_numbered_fname : str filename template for figures saved without an explicit name; {num} is replaced by the figure counter (defaults to "figure-{num}")
  • fig_config : dict | None metadata to save with the figures (defaults to None)
  • fig_basepath : str | None base path for saving figures (defaults to None)
  • close_after_plotshow : bool close figures after showing them (defaults to False)

Returns:

  • torch.device|None the device set, if torch is installed
def plotshow( fname: str | None = None, plot_mode: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None, fmt: str | None = None):
def plotshow(
    fname: str | None = None,
    plot_mode: PlottingMode | None = None,
    fmt: str | None = None,
):
    """Show the active plot, depending on global configs

    Increments `FIG_COUNTER`, then handles the current figure according to
    `plot_mode` (falling back to the global `PLOT_MODE` when `None`):
     - `"save"`: write it under `FIG_BASEPATH` via `universal_savefig`, named
       `fname` or, when not given, `FIG_NUMBERED_FNAME` filled with the counter
     - `"inline"` / `"widget"`: call `plt.show()`
     - `"ignore"`: do nothing
    Any other mode only emits a warning. Afterwards the figure is closed if
    `CLOSE_AFTER_PLOTSHOW` is set.
    """
    global FIG_COUNTER
    FIG_COUNTER += 1

    mode = PLOT_MODE if plot_mode is None else plot_mode

    if mode in ("inline", "widget"):
        plt.show()
    elif mode == "save":
        target = (
            fname
            if fname is not None
            else FIG_NUMBERED_FNAME.format(num=FIG_COUNTER)
        )
        assert FIG_BASEPATH is not None
        universal_savefig(os.path.join(FIG_BASEPATH, target), fmt=fmt)
    elif mode != "ignore":
        warnings.warn(f"Invalid plot mode: {mode}")

    if CLOSE_AFTER_PLOTSHOW:
        plt.close()

Show the active plot, depending on global configs