docs for muutils v0.9.1
View Source on GitHub

muutils.tensor_info

get metadata about a tensor, mostly for muutils.dbg


  1"get metadata about a tensor, mostly for `muutils.dbg`"
  2
  3from __future__ import annotations
  4
  5import numpy as np
  6from typing import Union, Any, Literal, List, Dict, overload, Optional, TYPE_CHECKING
  7
  8if TYPE_CHECKING:
  9    from typing import TypedDict
 10else:
 11    try:
 12        from typing import TypedDict
 13    except ImportError:
 14        from typing_extensions import TypedDict
 15
 16# Global color definitions
 17COLORS: Dict[str, Dict[str, str]] = {
 18    "latex": {
 19        "range": r"\textcolor{purple}",
 20        "mean": r"\textcolor{teal}",
 21        "std": r"\textcolor{orange}",
 22        "median": r"\textcolor{green}",
 23        "warning": r"\textcolor{red}",
 24        "shape": r"\textcolor{magenta}",
 25        "dtype": r"\textcolor{gray}",
 26        "device": r"\textcolor{gray}",
 27        "requires_grad": r"\textcolor{gray}",
 28        "sparkline": r"\textcolor{blue}",
 29        "torch": r"\textcolor{orange}",
 30        "dtype_bool": r"\textcolor{gray}",
 31        "dtype_int": r"\textcolor{blue}",
 32        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
 33        "dtype_str": r"\textcolor{red}",
 34        "device_cuda": r"\textcolor{green}",
 35        "reset": "",
 36    },
 37    "terminal": {
 38        "range": "\033[35m",  # purple
 39        "mean": "\033[36m",  # cyan/teal
 40        "std": "\033[33m",  # yellow/orange
 41        "median": "\033[32m",  # green
 42        "warning": "\033[31m",  # red
 43        "shape": "\033[95m",  # bright magenta
 44        "dtype": "\033[90m",  # gray
 45        "device": "\033[90m",  # gray
 46        "requires_grad": "\033[90m",  # gray
 47        "sparkline": "\033[34m",  # blue
 48        "torch": "\033[38;5;208m",  # bright orange
 49        "dtype_bool": "\033[38;5;245m",  # medium grey
 50        "dtype_int": "\033[38;5;39m",  # bright blue
 51        "dtype_float": "\033[38;5;167m",  # softer red/coral
 52        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
 53        "reset": "\033[0m",
 54    },
 55    "none": {
 56        "range": "",
 57        "mean": "",
 58        "std": "",
 59        "median": "",
 60        "warning": "",
 61        "shape": "",
 62        "dtype": "",
 63        "device": "",
 64        "requires_grad": "",
 65        "sparkline": "",
 66        "torch": "",
 67        "dtype_bool": "",
 68        "dtype_int": "",
 69        "dtype_float": "",
 70        "dtype_str": "",
 71        "device_cuda": "",
 72        "reset": "",
 73    },
 74}
 75
 76OutputFormat = Literal["unicode", "latex", "ascii"]
 77
 78
 79class ArraySummarySettings(TypedDict):
 80    """Type definition for array_summary default settings."""
 81
 82    fmt: OutputFormat
 83    precision: int
 84    stats: bool
 85    shape: bool
 86    dtype: bool
 87    device: bool
 88    requires_grad: bool
 89    sparkline: bool
 90    sparkline_bins: int
 91    sparkline_logy: Optional[bool]
 92    colored: bool
 93    as_list: bool
 94    eq_char: str
 95
 96
 97SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
 98    "latex": {
 99        "range": r"\mathcal{R}",
100        "mean": r"\mu",
101        "std": r"\sigma",
102        "median": r"\tilde{x}",
103        "distribution": r"\mathbb{P}",
104        "distribution_log": r"\mathbb{P}_L",
105        "nan_values": r"\text{NANvals}",
106        "warning": "!!!",
107        "requires_grad": r"\nabla",
108        "true": r"\checkmark",
109        "false": r"\times",
110    },
111    "unicode": {
112        "range": "R",
113        "mean": "μ",
114        "std": "σ",
115        "median": "x̃",
116        "distribution": "ℙ",
117        "distribution_log": "ℙ˪",
118        "nan_values": "NANvals",
119        "warning": "🚨",
120        "requires_grad": "∇",
121        "true": "✓",
122        "false": "✗",
123    },
124    "ascii": {
125        "range": "range",
126        "mean": "mean",
127        "std": "std",
128        "median": "med",
129        "distribution": "dist",
130        "distribution_log": "dist_log",
131        "nan_values": "NANvals",
132        "warning": "!!!",
133        "requires_grad": "requires_grad",
134        "true": "1",
135        "false": "0",
136    },
137}
138"Symbols for different formats"
139
140SPARK_CHARS: Dict[OutputFormat, List[str]] = {
141    "unicode": list(" ▁▂▃▄▅▆▇█"),
142    "ascii": list(" _.-~=#"),
143    "latex": list(" ▁▂▃▄▅▆▇█"),
144}
145"characters for sparklines in different formats"
146
147
def _array_info_to_numpy(A: Any, is_tensor: bool) -> np.ndarray:
    """Best-effort conversion of `A` to a numpy array; an empty array signals failure."""
    try:
        if not is_tensor:
            # numpy arrays and any other array-like object
            return np.asarray(A)

        # Tensor path: everything is duck-typed through `getattr` so that
        # torch never has to be imported here.
        is_cuda: bool = False
        try:
            is_cuda = bool(getattr(A, "is_cuda", False))
        except Exception:
            pass  # an unreadable `is_cuda` is treated as "not on GPU"

        # GPU tensors are moved to the CPU first, then detached and converted
        cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
        detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
        return getattr(detached, "numpy", lambda: np.array([]))()
    except Exception:
        return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    This function never raises for bad input: every step is best-effort, and
    fields that could not be computed are left as `None`.

    # Parameters:
     - `A : array-like`
            Array to analyze (numpy array or torch tensor). An object is
            treated as a tensor when its class is named `"Tensor"`.
     - `hist_bins : int`
            Number of bins passed to `np.histogram` for the `"histogram"` /
            `"bins"` entries (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
            Dictionary containing raw statistical information with numeric values.
            Keys: `is_tensor`, `device`, `requires_grad`, `shape`, `dtype`, `size`,
            `has_nans`, `nan_count`, `nan_percent`, `min`, `max`, `range`, `mean`,
            `std`, `median`, `histogram`, `bins`, `status`.
            `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or
            `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    is_tensor: bool = type(A).__name__ == "Tensor"
    result["is_tensor"] = is_tensor

    # Try to get device information if it's a tensor
    if is_tensor:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations (empty array if conversion fails)
    A_np = _array_info_to_numpy(A, is_tensor)

    # Get basic information
    try:
        result["shape"] = A_np.shape
        # for tensors report the tensor's own dtype rather than the numpy one
        result["dtype"] = str(A.dtype if is_tensor else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    # TODO: type checks fail on 3.10, see https://github.com/mivanit/muutils/actions/runs/18883100459/job/53891346225
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()  # type: ignore[assignment]
        else:
            A_flat = A_np  # type: ignore[assignment]
    except Exception:
        A_flat = A_np  # type: ignore[assignment]

    # Check for NaN values; dtypes that `np.isnan` rejects leave these fields as None
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
        # plain Python int/bool instead of numpy scalars
        nan_count: int = int(np.sum(nan_mask))
        result["nan_count"] = nan_count
        result["has_nans"] = nan_count > 0
        result_size: int = result["size"]  # ty: ignore[invalid-assignment]
        if result_size > 0:
            result["nan_percent"] = (nan_count / result_size) * 100
    except Exception:
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        use_nan_aware: bool = bool(result["has_nans"])
        # NaN-ignoring reducers are only needed when NaNs are actually present
        reducers = {
            "min": np.nanmin if use_nan_aware else np.min,
            "max": np.nanmax if use_nan_aware else np.max,
            "mean": np.nanmean if use_nan_aware else np.mean,
            "std": np.nanstd if use_nan_aware else np.std,
            "median": np.nanmedian if use_nan_aware else np.median,
        }
        for stat_key, reducer in reducers.items():
            result[stat_key] = float(reducer(A_flat))
        result["range"] = (result["min"], result["max"])

        # Remove NaNs for histogram (nan_mask is always set when has_nans is truthy)
        A_hist = A_flat[~nan_mask] if use_nan_aware else A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # np.histogram converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; convert explicitly to stay quiet
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
305
306
307SparklineFormat = Literal["unicode", "latex", "ascii"]
308
309
def generate_sparkline(
    histogram: np.ndarray,
    format: SparklineFormat = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Render histogram counts as a one-line sparkline.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram counts, one entry per bin
    - `format : Literal["unicode", "latex", "ascii"]`
        Selects the glyph set from `SPARK_CHARS`; unknown values fall back to
        the `"ascii"` glyphs (defaults to `"unicode"`)
    - `log_y : bool|None`
        Whether to apply `log1p` to the counts before scaling. `None` picks
        the scale automatically
        (defaults to `None`)

    # Returns:
    - `tuple[str, bool]`
        The sparkline (one glyph per bin) and whether the log scale was used.
        `("", False)` for a missing or empty histogram.
    """
    # nothing to draw
    if histogram is None or len(histogram) == 0:
        return "", False

    glyphs: List[str] = (
        SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]
    )
    top_level: int = len(glyphs) - 1

    if log_y is None:
        # Bin the counts themselves into as many levels as there are glyphs.
        # If only the lowest and highest levels are populated, a linear scale
        # would hide the small bins, so switch to the log scale.
        level_counts = np.histogram(histogram, bins=len(glyphs))[0]
        log_y = not (level_counts[1:-1].max() > 0)

    # log1p keeps empty bins at zero instead of hitting log(0)
    heights = np.log1p(histogram) if log_y else histogram

    # scale so that the tallest bin maps onto the last glyph
    peak = heights.max()
    if peak > 0:
        levels = heights / peak * top_level
    else:
        levels = np.zeros_like(heights)

    return "".join(glyphs[round(level)] for level in levels), log_y
371
372
373DEFAULT_SETTINGS: ArraySummarySettings = {
374    "fmt": "unicode",
375    "precision": 2,
376    "stats": True,
377    "shape": True,
378    "dtype": True,
379    "device": True,
380    "requires_grad": True,
381    "sparkline": False,
382    "sparkline_bins": 5,
383    "sparkline_logy": None,
384    "colored": False,
385    "as_list": False,
386    "eq_char": "=",
387}
388
389
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color stored under `color_key` of `colors`.

    # Parameters:
     - `text : str`
            Text to color
     - `color_key : str`
            Key into `colors`; a missing key raises `KeyError`
     - `colors : Dict[str, str]`
            Color table; must also contain `"reset"` when `using_tex` is false
            and the selected color is non-empty
     - `using_tex : bool`
            If true, emit `<color>{<text>}`; otherwise emit `<color><text><reset>`

    # Returns:
     - `str`
            The colored text, or `text` unchanged when the color entry is empty
    """
    code: str = colors[color_key]
    if not code:
        # empty entry means "no coloring" (e.g. the "none" scheme)
        return text
    if using_tex:
        return f"{code}{{{text}}}"
    return f"{code}{text}{colors['reset']}"
399
400
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Color a dtype string, treating a `torch.` prefix and the type name separately.

    The type name is colored by family: `dtype_bool`, `dtype_int` or
    `dtype_float` when the (lower-cased) dtype string contains `"bool"`,
    `"int"` or `"float"` respectively (checked in that order), otherwise the
    generic `dtype` color. When the string contains exactly one `torch.`, the
    word `torch` is colored with the `torch` key and re-joined with a dot.
    """
    # Split off the torch prefix; anything other than exactly one occurrence
    # is left untouched and colored as a whole.
    pieces = dtype_str.split("torch.")
    torch_colored: Optional[str]
    if len(pieces) == 2:
        torch_colored = apply_color("torch", "torch", colors, using_tex)
        type_name = pieces[1]
    else:
        torch_colored = None
        type_name = dtype_str

    # Pick the color family from the full dtype string; first match wins.
    lowered = dtype_str.lower()
    families = (
        ("bool", "dtype_bool"),
        ("int", "dtype_int"),
        ("float", "dtype_float"),
    )
    family_key: str = next(
        (key for needle, key in families if needle in lowered), "dtype"
    )

    type_colored: str = apply_color(type_name, family_key, colors, using_tex)
    if torch_colored:
        return f"{torch_colored}.{type_colored}"
    return type_colored
428
429
def format_shape_colored(
    shape_val: Any, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
     - `shape_val : Any`
            Sized, indexable sequence of dimensions (e.g. a shape tuple)
     - `colors : Dict[str, str]`
            Color table; the `"shape"` entry is used for every dimension
     - `using_tex : bool`
            Passed through to `apply_color` to select TeX or terminal wrapping

    # Returns:
     - `str`
            The single colored dimension for a length-1 shape, otherwise the
            colored dimensions joined by commas inside parentheses
    """
    # Coloring is delegated to the module-level `apply_color` rather than a
    # local copy of the same logic.
    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value
        return apply_color(str(shape_val[0]), "shape", colors, using_tex)

    # For multi-D arrays, color each dimension individually so the
    # parentheses and commas stay uncolored
    colored_dims = (
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    )
    return "(" + ",".join(colored_dims) + ")"
451
452
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    # Parameters:
     - `device_str : str`
            Device description, e.g. the `str()` of a tensor's device
     - `colors : Dict[str, str]`
            Color table; uses `"device_cuda"` when `device_str` contains
            `"cuda"` (case-insensitive), otherwise `"device"`
     - `using_tex : bool`
            Passed through to `apply_color` to select TeX or terminal wrapping

    # Returns:
     - `str`
            The colored device string
    """
    # Coloring is delegated to the module-level `apply_color` rather than a
    # local copy of the same logic.
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)
472
473
class _UseDefaultType:
    """Type of the `_USE_DEFAULT` sentinel.

    A dedicated sentinel is needed because `None` is itself a meaningful value
    for some `array_summary` arguments (e.g. `sparkline_logy`), so it cannot
    double as the "argument not given" marker.
    """

    def __repr__(self) -> str:
        # readable in signatures and debug output instead of "<... object at 0x...>"
        return "_USE_DEFAULT"


# the single sentinel instance used as the default of `array_summary` keyword arguments
_USE_DEFAULT = _UseDefaultType()
479
480
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    *,
    as_list: Literal[True],
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: Literal[False] = ...,
) -> str: ...
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: bool = ...,
) -> Union[str, List[str]]: ...
def array_summary(
    array: Any,
    fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT,
    precision: Union[int, _UseDefaultType] = _USE_DEFAULT,
    stats: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    shape: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    device: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT,
    sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT,
    colored: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT,
    as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT,
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument that is not passed is looked up in `DEFAULT_SETTINGS` at call
    time, so changing that dict changes the defaults listed below.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃, range) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
            Whether to use logarithmic y-scale for sparkline, `None` for
            automatic detection (defaults to `None`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
     - `eq_char : str`
            Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings
    """
    # Resolve "argument not given" sentinels against the module-level defaults.
    # Kept as explicit isinstance checks so type checkers narrow each variable.
    if isinstance(fmt, _UseDefaultType):
        fmt = DEFAULT_SETTINGS["fmt"]
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_SETTINGS["precision"]
    if isinstance(stats, _UseDefaultType):
        stats = DEFAULT_SETTINGS["stats"]
    if isinstance(shape, _UseDefaultType):
        shape = DEFAULT_SETTINGS["shape"]
    if isinstance(dtype, _UseDefaultType):
        dtype = DEFAULT_SETTINGS["dtype"]
    if isinstance(device, _UseDefaultType):
        device = DEFAULT_SETTINGS["device"]
    if isinstance(requires_grad, _UseDefaultType):
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if isinstance(sparkline, _UseDefaultType):
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if isinstance(sparkline_bins, _UseDefaultType):
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if isinstance(sparkline_logy, _UseDefaultType):
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if isinstance(colored, _UseDefaultType):
        colored = DEFAULT_SETTINGS["colored"]
    if isinstance(as_list, _UseDefaultType):
        as_list = DEFAULT_SETTINGS["as_list"]
    if isinstance(eq_char, _UseDefaultType):
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text with the active scheme
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer type.
    # `array_info` always provides the "dtype" key but its value may be None,
    # so fall back to "" explicitly instead of relying on `.get`'s default.
    dtype_str: str = array_data.get("dtype") or ""
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    # integer-like dtypes are shown without decimals
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether the log scale was used so that the
        # matching distribution symbol can be chosen
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)

COLORS: Dict[str, Dict[str, str]] = {'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'torch': '\\textcolor{orange}', 'dtype_bool': '\\textcolor{gray}', 'dtype_int': '\\textcolor{blue}', 'dtype_float': '\\textcolor{red!70}', 'dtype_str': '\\textcolor{red}', 'device_cuda': '\\textcolor{green}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'torch': '\x1b[38;5;208m', 'dtype_bool': '\x1b[38;5;245m', 'dtype_int': '\x1b[38;5;39m', 'dtype_float': '\x1b[38;5;167m', 'device_cuda': '\x1b[38;5;76m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'torch': '', 'dtype_bool': '', 'dtype_int': '', 'dtype_float': '', 'dtype_str': '', 'device_cuda': '', 'reset': ''}}
OutputFormat = typing.Literal['unicode', 'latex', 'ascii']
class ArraySummarySettings(typing.TypedDict):
80class ArraySummarySettings(TypedDict):
81    """Type definition for array_summary default settings."""
82
83    fmt: OutputFormat
84    precision: int
85    stats: bool
86    shape: bool
87    dtype: bool
88    device: bool
89    requires_grad: bool
90    sparkline: bool
91    sparkline_bins: int
92    sparkline_logy: Optional[bool]
93    colored: bool
94    as_list: bool
95    eq_char: str

Type definition for array_summary default settings.

fmt: Literal['unicode', 'latex', 'ascii']
precision: int
stats: bool
shape: bool
dtype: bool
device: bool
requires_grad: bool
sparkline: bool
sparkline_bins: int
sparkline_logy: Optional[bool]
colored: bool
as_list: bool
eq_char: str
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] = {'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'distribution_log': '\\mathbb{P}_L', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'distribution_log': 'ℙ˪', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'distribution_log': 'dist_log', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}

Symbols for different formats

SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] = {'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}

characters for sparklines in different formats

def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
def _array_info_to_numpy(A: Any, is_tensor: bool) -> np.ndarray:
    """Best-effort conversion of `A` to a numpy array; returns an empty array on any failure."""
    try:
        if not is_tensor:
            # numpy arrays and other array-like objects
            return np.asarray(A)

        # a failing `is_cuda` lookup is treated as "not on GPU"
        try:
            is_cuda = bool(getattr(A, "is_cuda", False))
        except Exception:
            is_cuda = False

        # GPU tensors have to be moved to the CPU before `.numpy()` can work
        cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
        detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
        return getattr(detached, "numpy", lambda: np.array([]))()
    except Exception:
        return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Never raises for bad input: every step is best-effort, and fields that
    could not be computed are left as `None`.

    # Parameters:
     - `A : array-like`
            Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
            Number of bins passed to `np.histogram` for the `"histogram"` / `"bins"` entries
            (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
            Dictionary containing raw statistical information with numeric values.
            Keys: `is_tensor`, `device`, `requires_grad`, `shape`, `dtype`, `size`,
            `has_nans`, `nan_count`, `nan_percent`, `min`, `max`, `range`, `mean`,
            `std`, `median`, `histogram`, `bins`, `status`.
            `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    is_tensor: bool = type(A).__name__ == "Tensor"
    result["is_tensor"] = is_tensor

    # Try to get device information if it's a tensor
    if is_tensor:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations (empty array if conversion fails)
    A_np = _array_info_to_numpy(A, is_tensor)

    # Get basic information
    try:
        result["shape"] = A_np.shape
        # for tensors report the tensor's own dtype, not the converted numpy one
        result["dtype"] = str(A.dtype if is_tensor else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Work on a 1D view for the statistics
    try:
        A_flat = A_np.ravel()
    except Exception:
        A_flat = A_np

    # Check for NaN values; `np.isnan` raises for non-numeric dtypes,
    # in which case the NaN fields stay `None`
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
        # store plain python numbers rather than numpy scalars
        nan_count: int = int(np.sum(nan_mask))
        result["nan_count"] = nan_count
        result["has_nans"] = nan_count > 0
        if result["size"]:
            result["nan_percent"] = (nan_count / result["size"]) * 100
    except Exception:
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram (`nan_mask` is always set on this branch)
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # np.histogram converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; do the same cast explicitly
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # a failed histogram only disables the sparkline
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result

Extract statistical information from an array-like object.

Parameters:

  • A : array-like Array to analyze (numpy array or torch tensor)
  • hist_bins : int Number of histogram bins (defaults to 5)

Returns:

  • Dict[str, Any] Dictionary containing raw statistical information with numeric values
SparklineFormat = typing.Literal['unicode', 'latex', 'ascii']
def generate_sparkline( histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: Optional[bool] = None) -> tuple[str, bool]:
def generate_sparkline(
    histogram: np.ndarray,
    format: SparklineFormat = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data (any array-like of bin counts is accepted)
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format; unknown values fall back to the `"ascii"` character set
        (defaults to `"unicode"`)
    - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
    - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used.
        `("", False)` if `histogram` is `None` or empty
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # accept plain sequences as well as arrays; `.max()` below needs an ndarray
    counts = np.asarray(histogram)

    # Get the appropriate character set
    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])
    top_level: int = len(chars) - 1

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        level_counts = np.histogram(counts, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale
        log_y = not bool(level_counts[1:-1].max() > 0)

    # log1p(x) = log(1 + x), so zero counts map to zero instead of -inf
    heights = np.log1p(counts) if log_y else counts

    # Normalize to character set range
    peak = heights.max()
    if peak > 0:
        normalized = heights / peak * top_level
    else:
        normalized = np.zeros_like(heights)

    # Convert to characters; int() guarantees a valid index type
    spark: str = "".join(chars[int(round(val))] for val in normalized)

    return spark, log_y

Generate a sparkline visualization of the histogram.

Parameters:

  • histogram : np.ndarray Histogram data
  • format : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • log_y : bool|None Whether to use logarithmic y-scale. None for automatic detection (defaults to None)

Returns:

  • tuple[str, bool] Sparkline visualization and whether log scale was used
DEFAULT_SETTINGS: ArraySummarySettings = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': None, 'colored': False, 'as_list': False, 'eq_char': '='}
def apply_color( text: str, color_key: str, colors: Dict[str, str], using_tex: bool) -> str:
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color stored under `color_key` in `colors`.

    # Parameters:
     - `text : str`
            Text to color
     - `color_key : str`
            Key into `colors`; must be present
     - `colors : Dict[str, str]`
            Color scheme; an empty string for a key means "no color"
     - `using_tex : bool`
            `True` emits `<color>{<text>}`, `False` emits `<color><text><reset>`
            using the `"reset"` entry of `colors`

    # Returns:
     - `str`
            Colored text, or `text` unchanged if the color entry is empty
    """
    code: str = colors[color_key]

    # empty color code -> leave the text untouched
    if not code:
        return text

    if using_tex:
        return f"{code}{{{text}}}"

    return f"{code}{text}{colors['reset']}"
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    # Parameters:
     - `dtype_str : str`
            dtype as a string, e.g. `"float32"` or `"torch.float32"`
     - `colors : Dict[str, str]`
            Color scheme, passed through to `apply_color`
     - `using_tex : bool`
            Whether to emit LaTeX-style coloring, passed through to `apply_color`

    # Returns:
     - `str`
            The colored dtype. A leading `torch.` is colored separately with the
            `"torch"` color; the type name uses `"dtype_bool"`, `"dtype_int"` or
            `"dtype_float"` when it contains that word, else the generic `"dtype"`
    """
    torch_prefix: str = "torch."

    # Handle type coloring (checked in this order, so e.g. "uint8" counts as int)
    dtype_lower: str = dtype_str.lower()
    color_key: str = "dtype"
    for kind in ("bool", "int", "float"):
        if kind in dtype_lower:
            color_key = f"dtype_{kind}"
            break

    # Handle torch prefix: only a *leading* "torch." is a prefix. Splitting on
    # the substring anywhere would silently drop whatever text precedes it.
    if dtype_str.startswith(torch_prefix):
        prefix_colored: str = apply_color("torch", "torch", colors, using_tex)
        type_colored: str = apply_color(
            dtype_str[len(torch_prefix) :], color_key, colors, using_tex
        )
        return f"{prefix_colored}.{type_colored}"

    return apply_color(dtype_str, color_key, colors, using_tex)

Colorize dtype string with specific colors for torch and type names.

def format_shape_colored(shape_val: Any, colors: Dict[str, str], using_tex: bool) -> str:
def format_shape_colored(
    shape_val: Any, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
     - `shape_val : Any`
            Shape as a sized iterable of dimensions (e.g. a tuple)
     - `colors : Dict[str, str]`
            Color scheme; the `"shape"` entry is used
     - `using_tex : bool`
            Whether to emit LaTeX-style coloring

    # Returns:
     - `str`
            A single colored number for 1D shapes, otherwise the individually
            colored dimensions joined by `,` inside parentheses
    """
    # use the module-level `apply_color` instead of a shadowing local copy
    dims_colored: List[str] = [
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    ]

    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value (no parentheses)
        return dims_colored[0]

    # For multi-D arrays, color each dimension
    return "(" + ",".join(dims_colored) + ")"

Format shape with proper coloring for both 1D and multi-D arrays.

def format_device_colored(device_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    # Parameters:
     - `device_str : str`
            Device name, e.g. `"cpu"` or `"cuda:0"`
     - `colors : Dict[str, str]`
            Color scheme; uses `"device_cuda"` if the name contains "cuda"
            (case-insensitive), otherwise `"device"`
     - `using_tex : bool`
            Whether to emit LaTeX-style coloring

    # Returns:
     - `str`
            The colored device string
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    # use the module-level `apply_color` instead of a shadowing local copy
    return apply_color(device_str, color_key, colors, using_tex)

Format device string with CUDA highlighting.

def array_summary( array: Any, fmt: Union[Literal['unicode', 'latex', 'ascii'], muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, precision: Union[int, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, stats: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, shape: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, dtype: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, device: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, requires_grad: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, sparkline: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, sparkline_bins: Union[int, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, sparkline_logy: Union[bool, NoneType, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, colored: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, eq_char: Union[str, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, as_list: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>) -> Union[str, List[str]]:
def array_summary(
    array: Any,
    fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT,
    precision: Union[int, _UseDefaultType] = _USE_DEFAULT,
    stats: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    shape: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    device: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT,
    sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT,
    colored: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT,
    as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT,
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every option left at its sentinel default is read from `DEFAULT_SETTINGS`
    at call time; the defaults quoted below are the values that dict is
    initialized with.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃, range) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins, i.e. characters, in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
            Whether to use logarithmic y-scale for sparkline,
            `None` for automatic detection (defaults to `None`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
     - `eq_char : str`
            Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings
    """
    # resolve sentinel defaults (explicit isinstance checks keep type narrowing intact)
    if isinstance(fmt, _UseDefaultType):
        fmt = DEFAULT_SETTINGS["fmt"]
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_SETTINGS["precision"]
    if isinstance(stats, _UseDefaultType):
        stats = DEFAULT_SETTINGS["stats"]
    if isinstance(shape, _UseDefaultType):
        shape = DEFAULT_SETTINGS["shape"]
    if isinstance(dtype, _UseDefaultType):
        dtype = DEFAULT_SETTINGS["dtype"]
    if isinstance(device, _UseDefaultType):
        device = DEFAULT_SETTINGS["device"]
    if isinstance(requires_grad, _UseDefaultType):
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if isinstance(sparkline, _UseDefaultType):
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if isinstance(sparkline_bins, _UseDefaultType):
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if isinstance(sparkline_logy, _UseDefaultType):
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if isinstance(colored, _UseDefaultType):
        colored = DEFAULT_SETTINGS["colored"]
    if isinstance(as_list, _UseDefaultType):
        as_list = DEFAULT_SETTINGS["as_list"]
    if isinstance(eq_char, _UseDefaultType):
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text with the active scheme
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer type ("uint*" is covered by "int").
    # `array_info` leaves "dtype" as None when it cannot be determined,
    # so `.get(..., "")` alone would not protect the `.lower()` call.
    dtype_str: str = array_data.get("dtype") or ""
    is_int_dtype: bool = any(kind in dtype_str.lower() for kind in ("int", "bool"))

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    # honor `eq_char` here like every other part does
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    # integer-like dtypes are shown without decimals
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether log scale was used, which picks the symbol
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested (an empty shape `()` is skipped)
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)

Format array information into a readable summary.

Parameters:

  • array array-like object (numpy array or torch tensor)
  • precision : int Decimal places (defaults to 2)
  • fmt : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • stats : bool Whether to include statistical info (μ, σ, x̃) (defaults to True)
  • shape : bool Whether to include shape info (defaults to True)
  • dtype : bool Whether to include dtype info (defaults to True)
  • device : bool Whether to include device info for torch tensors (defaults to True)
  • requires_grad : bool Whether to include requires_grad info for torch tensors (defaults to True)
  • sparkline : bool Whether to include a sparkline visualization (defaults to False)
  • sparkline_bins : int Number of histogram bins used for the sparkline (defaults to 5)
  • sparkline_logy : bool|None Whether to use logarithmic y-scale for sparkline (defaults to None)
  • colored : bool Whether to add color to output (defaults to False)
  • as_list : bool Whether to return as list of strings instead of joined string (defaults to False)

Returns:

  • Union[str, List[str]] Formatted statistical summary, either as string or list of strings