# Source code for figrecipe._composition._compose

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Main composition logic for combining multiple figures.

Supports two composition modes:
1. Grid-based: layout=(nrows, ncols) with sources={(row, col): path}
2. Mm-based: canvas_size_mm=(w, h) with sources={path: {"xy_mm": ..., "size_mm": ...}}

All layouts maintain matplotlib editability - no PIL image pasting.
"""

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from numpy.typing import NDArray

from .._recorder import FigureRecord
from .._serializer import load_recipe
from .._wrappers import RecordingAxes, RecordingFigure

# Supported image file extensions for raw image composition
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp", ".gif", ".webp"}
# SVG requires special handling (vector format)
VECTOR_EXTENSIONS = {".svg"}

# Default DPI for mm-based composition
DEFAULT_DPI = 300


def _is_image_file(path: Path) -> bool:
    """Return True when ``path`` has a supported image extension.

    The suffix is lower-cased and looked up in ``IMAGE_EXTENSIONS`` (raster
    formats) and ``VECTOR_EXTENSIONS`` (SVG).
    """
    extension = path.suffix.lower()
    if extension in IMAGE_EXTENSIONS:
        return True
    return extension in VECTOR_EXTENSIONS


def _is_mm_based_sources(sources: Dict) -> bool:
    """Check if sources dict uses mm-based positioning."""
    if not sources:
        return False
    first_key = next(iter(sources.keys()))
    # Grid-based uses tuple keys like (0, 0)
    if isinstance(first_key, tuple):
        return False
    # Mm-based uses string/Path keys with dict values containing xy_mm
    first_value = sources[first_key]
    return isinstance(first_value, dict) and "xy_mm" in first_value


def _mm_to_inch(mm: float) -> float:
    """Convert millimeters to inches."""
    return mm / 25.4


def _create_image_record(image_path: Path) -> FigureRecord:
    """Create a FigureRecord from a raw image file.

    The image is decoded into an RGB/RGBA pixel array and wrapped in a
    single axes record (stored under ``"ax_0_0"``) that holds one ``imshow``
    call plus an ``axis("off")`` decoration.

    Parameters
    ----------
    image_path : Path
        Path to the image. ``.svg`` files are first rasterized to PNG with
        ``cairosvg``; every other file is opened directly with PIL.

    Returns
    -------
    FigureRecord
        Record with ``dpi=100`` and a ``figsize`` equal to the pixel
        dimensions divided by that dpi.

    Raises
    ------
    ImportError
        If ``image_path`` is an SVG and ``cairosvg`` is not installed.
    """
    from datetime import datetime

    import matplotlib
    import numpy as np
    from PIL import Image

    from .._recorder import AxesRecord, CallRecord

    suffix = image_path.suffix.lower()
    if suffix == ".svg":
        # Keep the try body minimal: only the optional dependency import
        # should be translated into the "please install" message.
        try:
            import cairosvg
        except ImportError as err:
            raise ImportError(
                "cairosvg is required for SVG support. Install with: pip install cairosvg"
            ) from err

        import io

        png_data = cairosvg.svg2png(url=str(image_path))
        image_source = io.BytesIO(png_data)
    else:
        image_source = image_path

    # Use a context manager so the underlying file handle is released as soon
    # as the pixel data has been copied into the array.
    with Image.open(image_source) as opened:
        if opened.mode in ("RGB", "RGBA"):
            img = opened
        else:
            img = opened.convert("RGBA")
        img_array = np.array(img)

    imshow_call = CallRecord(
        id=f"imshow_{image_path.stem}",
        function="imshow",
        args=[{"name": "X", "dtype": "ndarray", "data": img_array.tolist()}],
        kwargs={"aspect": "equal"},
        timestamp=datetime.now().isoformat(),
        ax_position=(0, 0),
    )

    axis_off_call = CallRecord(
        id="axis_off",
        function="axis",
        args=[{"name": "arg0", "data": "off"}],
        kwargs={},
        timestamp=datetime.now().isoformat(),
        ax_position=(0, 0),
    )

    ax_record = AxesRecord(
        position=(0, 0),
        calls=[imshow_call],
        decorations=[axis_off_call],
    )

    # Size the figure so that one image pixel maps to one figure pixel at
    # the nominal dpi used for image records.
    height, width = img_array.shape[:2]
    dpi = 100
    figsize = (width / dpi, height / dpi)

    record = FigureRecord(
        figsize=figsize,
        dpi=dpi,
        matplotlib_version=matplotlib.__version__,
    )
    record.axes["ax_0_0"] = ax_record

    return record


def compose(
    sources: Dict[Any, Any],
    layout: Optional[Tuple[int, int]] = None,
    canvas_size_mm: Optional[Tuple[float, float]] = None,
    gap_mm: float = 5.0,
    dpi: int = DEFAULT_DPI,
    panel_labels: bool = False,
    label_style: str = "uppercase",
    **kwargs,
) -> Tuple[RecordingFigure, Union[RecordingAxes, NDArray, List[RecordingAxes]]]:
    """Compose a new figure from multiple sources (recipes or raw images).

    The composition mode is detected from the shape of ``sources``:

    1. Grid-based: ``sources={(row, col): path}``, placed on a subplot grid
       described by ``layout=(nrows, ncols)``.
    2. Mm-based: ``sources={path: {"xy_mm": (x, y), "size_mm": (w, h)}}``,
       placed at exact positions on a canvas of ``canvas_size_mm``.

    Parameters
    ----------
    sources : dict
        Either

        - grid-based: ``{(row, col): source_path}`` mapping positions to
          sources, or
        - mm-based: ``{source_path: {"xy_mm": (x, y), "size_mm": (w, h)}}``.
    layout : tuple, optional
        ``(nrows, ncols)`` for grid-based composition. Only forwarded in
        grid-based mode.
    canvas_size_mm : tuple, optional
        ``(width_mm, height_mm)`` for mm-based composition. Only forwarded
        in mm-based mode.
    gap_mm : float
        Gap between panels in mm (for auto-layout modes like 'horizontal').
        This function does not forward it to either composition routine.
    dpi : int
        DPI for the output figure. Only forwarded in mm-based mode.
    panel_labels : bool
        If True, add panel labels (A, B, C...) to each panel.
    label_style : str
        'uppercase', 'lowercase', or 'numeric'.
    **kwargs
        Additional arguments forwarded to the selected composition routine.

    Returns
    -------
    fig : RecordingFigure
        Composed figure (editable, recordable).
    axes : RecordingAxes, ndarray, or list
        Axes of the composed figure.

    Examples
    --------
    Grid-based composition:

    >>> fig, axes = fr.compose(
    ...     layout=(1, 2),
    ...     sources={
    ...         (0, 0): "panel_a.yaml",
    ...         (0, 1): "panel_b.yaml",
    ...     }
    ... )

    Mm-based free-form composition:

    >>> fig, axes = fr.compose(
    ...     canvas_size_mm=(180, 120),
    ...     sources={
    ...         "panel_a.yaml": {"xy_mm": (0, 0), "size_mm": (85, 55)},
    ...         "panel_b.yaml": {"xy_mm": (90, 0), "size_mm": (85, 55)},
    ...         "panel_c.yaml": {"xy_mm": (0, 60), "size_mm": (175, 55)},
    ...     }
    ... )
    """
    # Anything that is not recognisably mm-based is handled as a grid.
    if not _is_mm_based_sources(sources):
        return _compose_grid_based(
            sources, layout, panel_labels, label_style, **kwargs
        )

    return _compose_mm_based(
        sources, canvas_size_mm, dpi, panel_labels, label_style, **kwargs
    )


def _compose_grid_based(
    sources: Dict[Tuple[int, int], Any],
    layout: Optional[Tuple[int, int]],
    panel_labels: bool,
    label_style: str,
    **kwargs,
) -> Tuple[RecordingFigure, Union[RecordingAxes, NDArray]]:
    """Grid-based composition using subplots.

    Parameters
    ----------
    sources : dict
        ``{(row, col): source_spec}``; each spec is resolved by
        ``_parse_source_spec_with_path``.
    layout : tuple or None
        ``(nrows, ncols)``. When None it is derived from the largest row and
        column index present in ``sources``.
    panel_labels : bool
        If True, label every cell of the grid.
    label_style : str
        Passed to ``_get_panel_labels``.
    **kwargs
        Forwarded to ``subplots``.

    Raises
    ------
    ValueError
        If ``layout`` is None and ``sources`` is empty, or if the requested
        axes key is missing from a source record.
    """
    from .. import subplots

    # Auto-detect layout from source positions
    if layout is None:
        if not sources:
            raise ValueError("sources cannot be empty")
        max_row = max(pos[0] for pos in sources) + 1
        max_col = max(pos[1] for pos in sources) + 1
        layout = (max_row, max_col)

    nrows, ncols = layout
    fig, axes = subplots(nrows=nrows, ncols=ncols, **kwargs)

    source_data_dirs = {}
    for (row, col), source_spec in sources.items():
        source_record, ax_key, source_path = _parse_source_spec_with_path(source_spec)
        ax_record = source_record.axes.get(ax_key)
        if ax_record is None:
            available = list(source_record.axes.keys())
            raise ValueError(
                f"Axes '{ax_key}' not found in source. Available: {available}"
            )

        target_ax = _get_axes_at(axes, row, col, nrows, ncols)
        _replay_axes_record(target_ax, ax_record, fig.record, row, col)

        # Remember the sibling "<stem>_data" directory of file-based sources.
        if source_path is not None:
            data_dir = source_path.parent / f"{source_path.stem}_data"
            if data_dir.exists():
                source_data_dirs[f"ax_{row}_{col}"] = data_dir

    if source_data_dirs:
        fig.record.source_data_dirs = source_data_dirs

    # Add panel labels if requested
    if panel_labels:
        _add_panel_labels_grid(axes, nrows, ncols, label_style)

    return fig, axes


def _compose_mm_based(
    sources: Dict[str, Dict[str, Any]],
    canvas_size_mm: Optional[Tuple[float, float]],
    dpi: int,
    panel_labels: bool,
    label_style: str,
    **kwargs,
) -> Tuple[RecordingFigure, List[RecordingAxes]]:
    """Mm-based composition using fig.add_axes() for precise positioning.

    Parameters
    ----------
    sources : dict
        ``{source_spec: {"xy_mm": (x, y), "size_mm": (w, h)}}`` with
        positions measured from the top-left corner of the canvas.
    canvas_size_mm : tuple or None
        ``(width_mm, height_mm)``. When None, the canvas is the bounding
        extent of all panels plus a 5 mm margin on each axis.
    dpi : int
        DPI of the created figure.
    panel_labels : bool
        If True, add a text label next to each panel.
    label_style : str
        Passed to ``_get_panel_labels``.
    **kwargs
        Accepted for signature parity with the grid variant; not used here.

    Raises
    ------
    ValueError
        If the requested axes key is missing from a source record.
    """
    # Import the submodule explicitly: a bare ``import matplotlib`` does not
    # guarantee that ``matplotlib.pyplot`` has been loaded.
    import matplotlib.pyplot as plt

    from .._recorder import Recorder
    from .._wrappers import RecordingAxes as RA
    from .._wrappers import RecordingFigure as RF

    if canvas_size_mm is None:
        # Auto-calculate canvas size from panel positions
        max_x = 0
        max_y = 0
        for spec in sources.values():
            xy = spec["xy_mm"]
            size = spec["size_mm"]
            max_x = max(max_x, xy[0] + size[0])
            max_y = max(max_y, xy[1] + size[1])
        canvas_size_mm = (max_x + 5, max_y + 5)  # Add margin

    # Convert canvas size to inches
    width_inch = _mm_to_inch(canvas_size_mm[0])
    height_inch = _mm_to_inch(canvas_size_mm[1])

    # Create figure with specified size
    mpl_fig = plt.figure(figsize=(width_inch, height_inch), dpi=dpi)

    # Create recorder for tracking
    recorder = Recorder()
    recorder.start_figure(figsize=(width_inch, height_inch), dpi=dpi)

    # Store mm composition metadata
    recorder.figure_record.composition_mode = "mm"
    recorder.figure_record.canvas_size_mm = canvas_size_mm

    axes_list = []
    source_data_dirs = {}

    for idx, (source_path, spec) in enumerate(sources.items()):
        xy_mm = spec["xy_mm"]
        size_mm = spec["size_mm"]

        # Convert mm to figure fraction (0-1)
        # Note: matplotlib uses bottom-left origin, our mm uses top-left
        left = xy_mm[0] / canvas_size_mm[0]
        # Flip y: top-left origin -> bottom-left origin
        bottom = 1.0 - (xy_mm[1] + size_mm[1]) / canvas_size_mm[1]
        width = size_mm[0] / canvas_size_mm[0]
        height = size_mm[1] / canvas_size_mm[1]

        # Create axes at precise position
        mpl_ax = mpl_fig.add_axes([left, bottom, width, height])

        # Load and replay source
        source_record, ax_key, path = _parse_source_spec_with_path(source_path)
        ax_record = source_record.axes.get(ax_key)
        if ax_record is None:
            available = list(source_record.axes.keys())
            raise ValueError(
                f"Axes '{ax_key}' not found in source. Available: {available}"
            )

        # Wrap as RecordingAxes
        target_ax = RA(mpl_ax, recorder, position=(0, idx))
        axes_list.append(target_ax)

        # Replay the source onto this axes
        _replay_axes_record_mm(mpl_ax, ax_record, recorder.figure_record, idx, spec)

        # Track source data directory
        if path is not None:
            data_dir = path.parent / f"{path.stem}_data"
            if data_dir.exists():
                source_data_dirs[f"ax_mm_{idx}"] = data_dir

    # Wrap as RecordingFigure
    fig = RF(mpl_fig, recorder, axes_list)

    if source_data_dirs:
        fig.record.source_data_dirs = source_data_dirs

    # Add panel labels if requested
    if panel_labels:
        _add_panel_labels_mm(mpl_fig, sources, canvas_size_mm, label_style)

    return fig, axes_list


def _replay_calls(mpl_ax, ax_record) -> None:
    """Replay ``ax_record.calls`` then ``ax_record.decorations`` onto ``mpl_ax``.

    Non-None results are cached by call id and handed to later replays, so a
    decoration can refer to the result of an earlier plotting call.
    """
    from .._reproducer._core import _replay_call

    result_cache: Dict[str, Any] = {}
    for group in (ax_record.calls, ax_record.decorations):
        for call in group:
            result = _replay_call(mpl_ax, call, result_cache)
            if result is not None:
                result_cache[call.id] = result


def _replay_axes_record_mm(
    mpl_ax,
    ax_record,
    fig_record: FigureRecord,
    idx: int,
    spec: Dict[str, Any],
) -> None:
    """Replay axes record for mm-based composition.

    After replaying, a shallow copy of ``ax_record`` tagged with
    ``mm_position = spec`` is stored in ``fig_record.axes`` under
    ``"ax_mm_{idx}"``.
    """
    import copy

    _replay_calls(mpl_ax, ax_record)

    # Store with mm position info. Copy first so the caller's source record
    # is not mutated and a record reused for several panels keeps a distinct
    # position per panel.
    ax_record_copy = copy.copy(ax_record)
    ax_record_copy.mm_position = spec
    fig_record.axes[f"ax_mm_{idx}"] = ax_record_copy


def _add_panel_labels_grid(axes, nrows: int, ncols: int, style: str) -> None:
    """Add panel labels to grid-based composition.

    Every cell of the ``nrows`` x ``ncols`` grid is labelled in row-major
    order, just outside the top-left corner of its axes.
    """
    labels = _get_panel_labels(nrows * ncols, style)
    for row in range(nrows):
        for col in range(ncols):
            ax = _get_axes_at(axes, row, col, nrows, ncols)
            # Unwrap a recording wrapper when one is present.
            mpl_ax = ax._ax if hasattr(ax, "_ax") else ax
            mpl_ax.text(
                -0.1,
                1.1,
                labels[row * ncols + col],
                transform=mpl_ax.transAxes,
                fontsize=10,
                fontweight="bold",
                va="top",
                ha="right",
            )


def _add_panel_labels_mm(fig, sources: Dict, canvas_size_mm: Tuple, style: str) -> None:
    """Add panel labels to mm-based composition.

    Labels are drawn in figure coordinates, slightly up and to the left of
    each panel's top-left corner, in the iteration order of ``sources``.
    """
    labels = _get_panel_labels(len(sources), style)
    for label, spec in zip(labels, sources.values()):
        xy_mm = spec["xy_mm"]
        # Position label at top-left of panel
        x_frac = xy_mm[0] / canvas_size_mm[0]
        y_frac = 1.0 - xy_mm[1] / canvas_size_mm[1]
        fig.text(
            x_frac - 0.02,
            y_frac + 0.02,
            label,
            fontsize=10,
            fontweight="bold",
            va="bottom",
            ha="right",
        )


def _alpha_label(index: int, first: str) -> str:
    """Return a spreadsheet-style label: 0 -> A, 25 -> Z, 26 -> AA, ..."""
    base = ord(first)
    chars = []
    remaining = index + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        chars.append(chr(base + offset))
    return "".join(reversed(chars))


def _get_panel_labels(n: int, style: str) -> List[str]:
    """Generate panel labels based on style.

    ``"uppercase"`` gives A, B, ..., Z, AA, AB, ...; ``"lowercase"`` gives
    the same sequence in lower case; any other style gives 1, 2, 3, ....
    """
    if style == "uppercase":
        return [_alpha_label(i, "A") for i in range(n)]
    if style == "lowercase":
        return [_alpha_label(i, "a") for i in range(n)]
    # numeric
    return [str(i + 1) for i in range(n)]


def _parse_source_spec(
    spec: Union[str, Path, FigureRecord, Tuple],
) -> Tuple[FigureRecord, str]:
    """Parse source specification into (FigureRecord, ax_key)."""
    record, ax_key, _ = _parse_source_spec_with_path(spec)
    return record, ax_key


def _parse_source_spec_with_path(
    spec: Union[str, Path, FigureRecord, Tuple],
) -> Tuple[FigureRecord, str, Optional[Path]]:
    """Parse source specification into (FigureRecord, ax_key, source_path).

    Accepted forms:

    - ``str`` / ``Path``: an image file or a recipe file; axes key
      ``"ax_0_0"``.
    - ``FigureRecord``: used as is; axes key ``"ax_0_0"``; path is None.
    - ``(source, ax_key)``: as above with an explicit axes key. For image
      files the key is always ``"ax_0_0"`` because image records contain
      only that axes.

    Raises
    ------
    TypeError
        If ``spec`` (or the source inside a 2-tuple) has an unsupported type.
    """
    if isinstance(spec, (str, Path)):
        path = Path(spec)
        if _is_image_file(path):
            return _create_image_record(path), "ax_0_0", path
        return load_recipe(path), "ax_0_0", path

    if isinstance(spec, FigureRecord):
        return spec, "ax_0_0", None

    if isinstance(spec, tuple) and len(spec) == 2:
        source, ax_key = spec
        if isinstance(source, (str, Path)):
            path = Path(source)
            if _is_image_file(path):
                return _create_image_record(path), "ax_0_0", path
            return load_recipe(path), ax_key, path
        if isinstance(source, FigureRecord):
            return source, ax_key, None
        raise TypeError(f"Invalid source in tuple: {type(source)}")

    raise TypeError(f"Invalid source spec type: {type(spec)}")


def _get_axes_at(
    axes: Union[RecordingAxes, NDArray],
    row: int,
    col: int,
    nrows: int,
    ncols: int,
) -> RecordingAxes:
    """Get axes at position, handling different array shapes.

    A 1x1 grid is treated as a single axes object, a single row or column as
    a 1-D sequence, and anything else as a 2-D array indexed by
    ``[row, col]``.
    """
    if nrows == 1 and ncols == 1:
        return axes
    if nrows == 1:
        return axes[col]
    if ncols == 1:
        return axes[row]
    return axes[row, col]


def _replay_axes_record(
    target_ax: RecordingAxes,
    ax_record,
    fig_record: FigureRecord,
    row: int,
    col: int,
) -> None:
    """Replay all calls from ax_record onto target axes.

    The record is then stored in ``fig_record.axes`` under
    ``"ax_{row}_{col}"``.
    """
    # Unwrap a recording wrapper when one is present.
    mpl_ax = target_ax._ax if hasattr(target_ax, "_ax") else target_ax
    _replay_calls(mpl_ax, ax_record)
    fig_record.axes[f"ax_{row}_{col}"] = ax_record


__all__ = ["compose"]