#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Schematic diagram builder for rich scientific diagrams."""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import FancyArrowPatch, FancyBboxPatch
from .._diagram._styles_native import get_edge_style, get_emphasis_style
from .._utils._units import mm_to_pt
# Anchor point definitions (relative to box: 0-1 range)
ANCHOR_POINTS = {
"center": (0.5, 0.5),
"top": (0.5, 1.0),
"bottom": (0.5, 0.0),
"left": (0.0, 0.5),
"right": (1.0, 0.5),
"top-left": (0.0, 1.0),
"top-right": (1.0, 1.0),
"bottom-left": (0.0, 0.0),
"bottom-right": (1.0, 0.0),
}
@dataclass
class BoxSpec:
"""Specification for a rich text box."""
id: str
title: str
subtitle: Optional[str] = None
content: List[Dict] = field(default_factory=list)
emphasis: str = "normal"
shape: str = "rounded"
fill_color: Optional[str] = None
border_color: Optional[str] = None
title_color: Optional[str] = None
padding_mm: float = 5.0 # Inner spacing from box edge to text (mm)
margin_mm: float = 0.0 # Outer spacing for collision detection (mm)
@dataclass
class ArrowSpec:
"""Specification for an arrow."""
id: Optional[str] = None
source: str = ""
target: str = ""
source_anchor: str = "auto"
target_anchor: str = "auto"
label: Optional[str] = None
style: str = "solid"
color: Optional[str] = None
curve: float = 0.0 # Dimensionless curve parameter
linewidth_mm: float = 0.5 # Line width in mm
label_offset_mm: Optional[Tuple[float, float]] = None # Manual (dx, dy) nudge
@dataclass
class PositionSpec:
"""Position and size specification in mm."""
x_mm: float
y_mm: float
width_mm: float
height_mm: float
[docs]
class Schematic:
"""Builder for rich schematic diagrams."""
[docs]
def __init__(
self,
title: Optional[str] = None,
width_mm: float = 170.0,
height_mm: float = 120.0,
):
import warnings
if width_mm > 185.0:
warnings.warn(
f"Schematic width {width_mm}mm exceeds 185mm (double-column max).",
stacklevel=2,
)
self.title = title
self.width_mm = width_mm
self.height_mm = height_mm
# Compute figsize and limits (1 data unit = 1 mm)
self.figsize = (width_mm / 25.4, height_mm / 25.4)
self.xlim = (0, width_mm)
self.ylim = (0, height_mm)
self._boxes: Dict[str, BoxSpec] = {}
self._containers: Dict[str, Dict] = {}
self._arrows: List[ArrowSpec] = []
self._positions: Dict[str, PositionSpec] = {}
self._render_info: Dict[str, Dict[str, Any]] = {}
[docs]
def add_box(
self,
id: str,
title: str,
subtitle: Optional[str] = None,
content: Optional[List] = None,
emphasis: str = "normal",
shape: str = "rounded",
position_mm: Optional[Tuple[float, float]] = None,
size_mm: Optional[Tuple[float, float]] = None,
fill_color: Optional[str] = None,
border_color: Optional[str] = None,
title_color: Optional[str] = None,
padding_mm: float = 5.0,
margin_mm: float = 0.0,
) -> "Schematic":
"""Add a rich text box."""
self._boxes[id] = BoxSpec(
id=id,
title=title,
subtitle=subtitle,
content=content or [],
emphasis=emphasis,
shape=shape,
fill_color=fill_color,
border_color=border_color,
title_color=title_color,
padding_mm=padding_mm,
margin_mm=margin_mm,
)
if position_mm and size_mm:
self._positions[id] = PositionSpec(
x_mm=position_mm[0],
y_mm=position_mm[1],
width_mm=size_mm[0],
height_mm=size_mm[1],
)
return self
[docs]
def add_container(
self,
id: str,
title: Optional[str] = None,
children: Optional[List[str]] = None,
emphasis: str = "muted",
position_mm: Optional[Tuple[float, float]] = None,
size_mm: Optional[Tuple[float, float]] = None,
fill_color: Optional[str] = None,
border_color: Optional[str] = None,
) -> "Schematic":
"""Add a container that groups other boxes."""
self._containers[id] = {
"title": title,
"children": children or [],
"emphasis": emphasis,
"fill_color": fill_color,
"border_color": border_color,
}
if position_mm and size_mm:
self._positions[id] = PositionSpec(
x_mm=position_mm[0],
y_mm=position_mm[1],
width_mm=size_mm[0],
height_mm=size_mm[1],
)
return self
[docs]
def add_arrow(
self,
source: str,
target: str,
source_anchor: str = "auto",
target_anchor: str = "auto",
label: Optional[str] = None,
style: str = "solid",
color: Optional[str] = None,
curve: float = 0.0,
linewidth_mm: float = 0.5,
label_offset_mm: Optional[Tuple[float, float]] = None,
) -> "Schematic":
"""Add an arrow connecting two boxes."""
auto_id = f"arrow:{source}->{target}"
self._arrows.append(
ArrowSpec(
id=auto_id,
source=source,
target=target,
source_anchor=source_anchor,
target_anchor=target_anchor,
label=label,
style=style,
color=color,
curve=curve,
linewidth_mm=linewidth_mm,
label_offset_mm=label_offset_mm,
)
)
return self
[docs]
def validate_containers(self) -> None:
"""Check every container fully encloses its declared children."""
from ._schematic_validate import validate_containers
validate_containers(self)
[docs]
def validate_no_overlap(self) -> None:
"""Check that no two boxes overlap each other."""
from ._schematic_validate import validate_no_overlap
validate_no_overlap(self)
[docs]
def auto_layout(
self,
layout: str = "lr",
margin_mm: float = 15.0,
box_size_mm: Optional[Tuple[float, float]] = None,
gap_mm: float = 10.0,
avoid_overlap: bool = True,
justify: str = "space-between",
align_items: str = "center",
) -> "Schematic":
"""Automatically position boxes. See _schematic_layout for details."""
from ._schematic_layout import auto_layout
auto_layout(
self,
layout=layout,
margin_mm=margin_mm,
box_size_mm=box_size_mm,
gap_mm=gap_mm,
avoid_overlap=avoid_overlap,
justify=justify,
align_items=align_items,
)
return self
[docs]
def _get_anchor(self, pos: PositionSpec, anchor: str) -> Tuple[float, float]:
"""Get absolute position of an anchor point on the visual box edge."""
if anchor not in ANCHOR_POINTS:
anchor = "center"
rx, ry = ANCHOR_POINTS[anchor]
x = pos.x_mm - pos.width_mm / 2 + rx * pos.width_mm
y = pos.y_mm - pos.height_mm / 2 + ry * pos.height_mm
return x, y
[docs]
def _auto_anchor(self, src: PositionSpec, tgt: PositionSpec) -> Tuple[str, str]:
"""Determine best anchor points automatically."""
dx, dy = tgt.x_mm - src.x_mm, tgt.y_mm - src.y_mm
if abs(dx) > abs(dy):
return ("right", "left") if dx > 0 else ("left", "right")
return ("top", "bottom") if dy > 0 else ("bottom", "top")
[docs]
def render(self, ax: Optional[Axes] = None) -> Tuple[Figure, Axes]:
"""Render the schematic."""
figsize = (
(self.xlim[1] - self.xlim[0]) / 25.4,
(self.ylim[1] - self.ylim[0]) / 25.4,
)
owns_fig = ax is None
if owns_fig:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
ax.set_xlim(self.xlim)
ax.set_ylim(self.ylim)
ax.set_aspect("equal")
ax.axis("off")
from . import _schematic_validate as _sv
# Render everything first (so errored figures can be inspected)
for cid, container in self._containers.items():
if cid in self._positions:
self._render_container(ax, cid, container)
for bid, box in self._boxes.items():
if bid in self._positions:
self._render_box(ax, bid, box)
for arrow in self._arrows:
self._render_arrow(ax, arrow)
if self.title:
ax.text(
(self.xlim[0] + self.xlim[1]) / 2,
self.ylim[1] - 5.0,
self.title,
ha="center",
va="top",
fontsize=16,
fontweight="bold",
)
# Validate when we own the figure; when ax is external,
# the caller (schematic_plot) handles validation with error capture
if owns_fig:
_sv.validate_all(self, fig=fig, ax=ax)
return fig, ax
[docs]
def _render_container(self, ax: Axes, cid: str, container: Dict) -> None:
"""Render a container box."""
pos = self._positions[cid]
colors = get_emphasis_style(container["emphasis"])
fill = container.get("fill_color") or colors["fill"]
border = container.get("border_color") or colors["stroke"]
box = FancyBboxPatch(
(pos.x_mm - pos.width_mm / 2, pos.y_mm - pos.height_mm / 2),
pos.width_mm,
pos.height_mm,
boxstyle="round,pad=1.0,rounding_size=3.0",
facecolor=fill,
edgecolor=border,
linewidth=2.5,
zorder=1,
)
ax.add_patch(box)
if container.get("title"):
_ct_bg = dict(facecolor=fill, edgecolor="none", pad=1.0, alpha=0.85)
ax.text(
pos.x_mm,
pos.y_mm + pos.height_mm / 2 - 1.5,
container["title"],
ha="center",
va="top",
fontsize=11,
fontweight="bold",
color=colors["text"],
zorder=7,
bbox=_ct_bg,
)
self._render_info[cid] = {"pos": pos}
# Aesthetic pad for FancyBboxPatch rounding (does NOT affect layout) - now 1mm
_aesthetic_pad = 1.0
[docs]
def _render_box(self, ax: Axes, bid: str, box: BoxSpec) -> None:
"""Render a rich text box."""
pos = self._positions[bid]
colors = get_emphasis_style(box.emphasis)
fill = box.fill_color or colors["fill"]
border = box.border_color or colors["stroke"]
title_color = box.title_color or colors["text"]
pad = self._aesthetic_pad
shape_styles = {
"box": f"square,pad={pad}",
"rounded": f"round,pad={pad},rounding_size=2.0",
"stadium": f"round,pad={pad},rounding_size=5.0",
}
boxstyle = shape_styles.get(box.shape, shape_styles["rounded"])
patch = FancyBboxPatch(
(pos.x_mm - pos.width_mm / 2, pos.y_mm - pos.height_mm / 2),
pos.width_mm,
pos.height_mm,
boxstyle=boxstyle,
facecolor=fill,
edgecolor=border,
linewidth=2,
zorder=2,
)
ax.add_patch(patch)
# Build text items: list of (text, fontsize, fontweight, color)
items = [(box.title, 11, "bold", title_color)]
if box.subtitle:
items.append((box.subtitle, 9, "normal", colors["text"]))
for line in box.content:
if isinstance(line, dict):
items.append(
(
line.get("text", ""),
line.get("fontsize", 8),
line.get("fontweight", "normal"),
line.get("color", colors["text"]),
)
)
else:
items.append((str(line), 8, "normal", colors["text"]))
# Text area = PositionSpec minus padding on all sides
inner_h = pos.height_mm - 2 * box.padding_mm
n = len(items)
gap = min(inner_h / max(n, 1) * 0.85, 6.0) if n > 1 else 0
block_h = gap * (n - 1)
top_y = pos.y_mm + block_h / 2
_txt_bg = dict(facecolor=fill, edgecolor="none", pad=0.5, alpha=0.85)
for i, (text, fsize, fweight, fcolor) in enumerate(items):
ax.text(
pos.x_mm,
top_y - i * gap,
text,
ha="center",
va="center",
fontsize=fsize,
fontweight=fweight,
color=fcolor,
fontstyle="normal",
zorder=7,
bbox=_txt_bg,
)
self._render_info[bid] = {"pos": pos}
[docs]
def _render_arrow(self, ax: Axes, arrow: ArrowSpec) -> None:
"""Render an arrow."""
if arrow.source not in self._positions or arrow.target not in self._positions:
return
src_pos = self._positions[arrow.source]
tgt_pos = self._positions[arrow.target]
# Determine anchors
if arrow.source_anchor == "auto" or arrow.target_anchor == "auto":
auto_src, auto_tgt = self._auto_anchor(src_pos, tgt_pos)
src_anc = auto_src if arrow.source_anchor == "auto" else arrow.source_anchor
tgt_anc = auto_tgt if arrow.target_anchor == "auto" else arrow.target_anchor
else:
src_anc, tgt_anc = arrow.source_anchor, arrow.target_anchor
start = self._get_anchor(src_pos, src_anc)
end = self._get_anchor(tgt_pos, tgt_anc)
style = get_edge_style(arrow.style)
from .._diagram._styles_native import COLORS as _COLORS
color = _COLORS.get(arrow.color, arrow.color) if arrow.color else style["color"]
conn = f"arc3,rad={arrow.curve}" if arrow.curve else "arc3,rad=0"
# Convert linewidth from mm to pt
linewidth_pt = mm_to_pt(arrow.linewidth_mm)
patch = FancyArrowPatch(
posA=start,
posB=end,
arrowstyle="-|>",
color=color,
linestyle=style["linestyle"],
linewidth=linewidth_pt,
connectionstyle=conn,
mutation_scale=15,
zorder=5,
)
ax.add_patch(patch)
if arrow.label:
from ._schematic_validate import compute_arrow_label_position
lx, ly = compute_arrow_label_position(
start, end, arrow.curve, arrow.label_offset_mm
)
_label_bg = dict(facecolor="white", edgecolor="none", pad=1.0, alpha=0.85)
ax.text(
lx,
ly,
arrow.label,
ha="center",
va="bottom",
fontsize=8,
color=color,
fontstyle="italic",
zorder=6,
bbox=_label_bg,
)
[docs]
def render_to_file(self, path: Union[str, Path], dpi: int = 200) -> Path:
"""Render and save. On validation failure, saves as *_FAILED.png."""
path = Path(path)
try:
fig, ax = self.render()
except ValueError:
fig = plt.gcf()
failed = path.with_stem(f"{path.stem}_FAILED")
fig.savefig(failed, dpi=dpi, bbox_inches="tight", facecolor="white")
plt.close(fig)
raise
fig.savefig(path, dpi=dpi, bbox_inches="tight", facecolor="white")
plt.close(fig)
return path
[docs]
def to_dict(self) -> Dict[str, Any]:
"""Convert schematic to dictionary for serialization."""
from ._schematic_io import schematic_to_dict
return schematic_to_dict(self)
[docs]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Schematic":
"""Create Schematic from dictionary (recipe reproduction)."""
from ._schematic_io import schematic_from_dict
return schematic_from_dict(data)
__all__ = ["Schematic", "ArrowSpec", "BoxSpec", "PositionSpec"]