#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Wrapped Axes that records all plotting calls."""
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import numpy as np
from matplotlib.axes import Axes
from ._axes_methods import RecordingAxesMethods
from ._axes_schematic import SchematicMixin
from ._axes_style_mixin import AxesStyleMixin
if TYPE_CHECKING:
from .._recorder import Recorder
[docs]
class RecordingAxes(RecordingAxesMethods, AxesStyleMixin, SchematicMixin):
"""Wrapper around matplotlib Axes that records all calls.
This wrapper intercepts calls to plotting methods and records them
for later reproduction.
Parameters
----------
ax : matplotlib.axes.Axes
The underlying matplotlib axes.
recorder : Recorder
The recorder instance to log calls to.
position : tuple
(row, col) position in the figure grid.
Examples
--------
>>> import figrecipe as ps
>>> fig, ax = ps.subplots()
>>> ax.plot([1, 2, 3], [4, 5, 6], color='red', id='my_line')
>>> # The call is recorded automatically
"""
# Internal: Methods whose results can be referenced by other methods
_RESULT_REFERENCEABLE_METHODS = {"contour", "contourf"}
# Internal: Methods that take results from other methods as arguments
_RESULT_REFERENCING_METHODS = {"clabel"}
[docs]
def __init__(
self,
ax: Axes,
recorder: "Recorder",
position: Tuple[int, int] = (0, 0),
):
self._ax = ax
self._recorder = recorder
self._position = position
self._track = True
# Map matplotlib result objects (by id) to their source call_id
self._result_refs: Dict[int, str] = {}
@property
def ax(self) -> Axes:
"""Get the underlying matplotlib axes."""
return self._ax
@property
def position(self) -> Tuple[int, int]:
"""Get axes position in grid."""
return self._position
[docs]
def __getattr__(self, name: str) -> Any:
"""Intercept attribute access to wrap methods."""
attr = getattr(self._ax, name)
# Use custom wrappers for methods with special styling
if callable(attr) and name == "bar":
return self._create_bar_wrapper()
# Route boxplot to wrapper that sets patch_artist=True
if callable(attr) and name == "boxplot":
return self._create_boxplot_wrapper()
# Route legend to wrapper that applies frame styling
if callable(attr) and name == "legend":
return self._create_legend_wrapper()
# Route stem to wrapper that handles color kwarg
if callable(attr) and name == "stem":
return self._create_stem_wrapper()
# If it's a plotting or decoration method, wrap it
if callable(attr) and name in (
self._recorder.PLOTTING_METHODS | self._recorder.DECORATION_METHODS
):
return self._create_recording_wrapper(name, attr)
# For other methods/attributes, return as-is
return attr
[docs]
def __dir__(self):
"""Return list of attributes for tab completion.
Exposes all matplotlib plotting and decoration methods alongside
figrecipe's custom methods and properties.
"""
# Get base attributes (excluding private)
base_attrs = [a for a in super().__dir__() if not a.startswith("_")]
# Add all matplotlib plotting methods
from .._params import DECORATION_METHODS, PLOTTING_METHODS
matplotlib_methods = sorted(PLOTTING_METHODS | DECORATION_METHODS)
# Combine and deduplicate
return sorted(set(base_attrs + matplotlib_methods))
[docs]
def _create_recording_wrapper(self, method_name: str, method: callable):
"""Create a wrapper function that records the call."""
from ._axes_helpers import (
inject_clip_on_from_style,
inject_method_defaults,
record_call_with_color_capture,
)
def wrapper(
*args,
id: Optional[str] = None,
track: bool = True,
stats: Optional[Dict[str, Any]] = None,
**kwargs,
):
from ..styles._internal import resolve_colors_in_kwargs
kwargs = resolve_colors_in_kwargs(kwargs)
kwargs = inject_clip_on_from_style(kwargs, method_name)
kwargs = inject_method_defaults(kwargs, method_name)
result = method(*args, **kwargs)
if self._track and track:
record_kwargs = kwargs.copy()
if stats is not None:
record_kwargs["stats"] = stats
record_call_with_color_capture(
self._recorder,
self._position,
method_name,
args,
record_kwargs,
result,
id,
self._result_refs,
self._RESULT_REFERENCING_METHODS,
self._RESULT_REFERENCEABLE_METHODS,
)
return result
return wrapper
[docs]
def _create_bar_wrapper(self):
"""Create wrapper for bar() with SCITEX error bar styling."""
from ._axes_plots import bar_plot
def wrapper(
*args,
id: Optional[str] = None,
track: bool = True,
stats: Optional[Dict[str, Any]] = None,
**kwargs,
):
if stats is not None:
kwargs["stats"] = stats
return bar_plot(
self._ax,
args,
kwargs,
self._recorder,
self._position,
track=self._track and track,
call_id=id,
)
return wrapper
[docs]
def _create_boxplot_wrapper(self):
"""Create wrapper for boxplot() with patch_artist=True default."""
from ._boxplot import boxplot_plot
def wrapper(
*args,
id: Optional[str] = None,
track: bool = True,
**kwargs,
):
# Handle positional x argument
x = args[0] if args else kwargs.pop("x", None)
if x is None:
raise ValueError("boxplot requires data argument")
return boxplot_plot(
self._ax,
x,
self._recorder,
self._position,
track=self._track and track,
call_id=id,
**kwargs,
)
return wrapper
[docs]
def _create_stem_wrapper(self):
"""Create wrapper for stem() that accepts color kwarg."""
original_stem = self._ax.stem
def wrapper(
*args,
id: Optional[str] = None,
track: bool = True,
color=None,
**kwargs,
):
# Call original stem
container = original_stem(*args, **kwargs)
# Apply color if provided (stem doesn't accept color kwarg natively)
if color is not None:
import matplotlib.colors as mcolors
color_val = mcolors.to_rgba(color)
container.markerline.set_color(color_val)
container.stemlines.set_color(color_val)
# Record the call with color
if self._track and track:
record_kwargs = kwargs.copy()
# Capture the actual color (either provided or from cycle)
if color is not None:
import matplotlib.colors as mcolors
record_kwargs["color"] = mcolors.to_hex(color)
else:
# Capture from result
try:
import matplotlib.colors as mcolors
c = container.markerline.get_color()
record_kwargs["color"] = mcolors.to_hex(c)
except Exception:
pass
self._recorder.record_call(
ax_position=self._position,
method_name="stem",
args=args,
kwargs=record_kwargs,
call_id=id,
)
return container
return wrapper
[docs]
def _create_legend_wrapper(self):
"""Create wrapper for legend() that applies frame styling and records the call."""
from ..styles import load_style
original_legend = self._ax.legend
def wrapper(
*args,
id: Optional[str] = None,
track: bool = True,
**kwargs,
):
legend = original_legend(*args, **kwargs)
# Apply SCITEX style frame settings
if legend is not None:
style = load_style()
legend_config = style.get("legend", {})
frameon = legend_config.get("frameon", True)
edge_mm = legend_config.get("edge_mm", 0.2)
edgecolor = legend_config.get("edgecolor", "black")
if frameon and edge_mm:
frame = legend.get_frame()
frame.set_linewidth(edge_mm * 72 / 25.4) # mm to points
if edgecolor:
frame.set_edgecolor(edgecolor)
# Record the legend call for reproduction
if self._track and track:
record_kwargs = kwargs.copy()
# Handle custom handles - extract color/label info for serialization
if "handles" in record_kwargs:
handles = record_kwargs.pop("handles")
handle_specs = []
for h in handles:
spec = {"label": h.get_label()}
if hasattr(h, "get_facecolor"):
spec["facecolor"] = list(h.get_facecolor())
if hasattr(h, "get_edgecolor"):
spec["edgecolor"] = list(h.get_edgecolor())
handle_specs.append(spec)
record_kwargs["_handle_specs"] = handle_specs
self._recorder.record_call(
ax_position=self._position,
method_name="legend",
args=args,
kwargs=record_kwargs,
call_id=id,
)
return legend
return wrapper
[docs]
def set_caption(self, caption: str) -> "RecordingAxes":
"""Set panel caption metadata (not rendered, stored in recipe)."""
ax_record = self._recorder.figure_record.get_or_create_axes(*self._position)
ax_record.caption = caption
return self
@property
def panel_caption(self) -> Optional[str]:
"""Get the panel caption metadata."""
ax_record = self._recorder.figure_record.get_or_create_axes(*self._position)
return ax_record.caption
[docs]
def set_stats(self, stats: Dict[str, Any]) -> "RecordingAxes":
"""Set panel-level statistics metadata (not rendered, stored in recipe)."""
ax_record = self._recorder.figure_record.get_or_create_axes(*self._position)
ax_record.stats = stats
return self
@property
def stats(self) -> Optional[Dict[str, Any]]:
"""Get the panel-level statistics metadata."""
ax_record = self._recorder.figure_record.get_or_create_axes(*self._position)
return ax_record.stats
[docs]
def _no_record(self):
"""Context manager to temporarily disable recording (internal)."""
return _NoRecordContext(self)
[docs]
def _record_seaborn_call(
self,
func_name: str,
args: tuple,
kwargs: Dict[str, Any],
data_arrays: Dict[str, np.ndarray],
call_id: Optional[str] = None,
) -> None:
"""Record a seaborn plotting call."""
if not self._track:
return
from ._axes_seaborn import record_seaborn_call
record_seaborn_call(
self._recorder,
self._position,
func_name,
args,
kwargs,
data_arrays,
call_id,
)
# Expose common properties directly
@property
def figure(self):
return self._ax.figure
@property
def xaxis(self):
return self._ax.xaxis
@property
def yaxis(self):
return self._ax.yaxis
# Methods that should not be recorded
[docs]
def get_xlim(self):
return self._ax.get_xlim()
[docs]
def get_ylim(self):
return self._ax.get_ylim()
[docs]
def get_xlabel(self):
return self._ax.get_xlabel()
[docs]
def get_ylabel(self):
return self._ax.get_ylabel()
[docs]
def get_title(self):
return self._ax.get_title()
@property
def caption(self) -> Optional[str]:
"""Get the panel caption metadata."""
ax_record = self._recorder.figure_record.get_or_create_axes(*self._position)
return ax_record.caption
[docs]
def generate_panel_caption(
self, label: Optional[str] = None, style: str = "publication"
) -> str:
"""Generate a caption for this panel from stats metadata."""
from ._caption_generator import generate_panel_caption
return generate_panel_caption(label=label, stats=self.stats, style=style)
class _NoRecordContext:
"""Context manager to temporarily disable recording."""
def __init__(self, axes: RecordingAxes):
self._axes = axes
self._original_track = axes._track
def __enter__(self):
self._axes._track = False
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._axes._track = self._original_track
return False