muutils.tensor_info
get metadata about a tensor, mostly for muutils.dbg
"get metadata about a tensor, mostly for `muutils.dbg`"

from __future__ import annotations

import numpy as np
from typing import Union, Any, Literal, List, Dict, overload, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from typing import TypedDict
else:
    try:
        from typing import TypedDict
    except ImportError:
        from typing_extensions import TypedDict

# Global color definitions
# every scheme must expose the same set of keys, since callers index them directly
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        "dtype_str": "\033[31m",  # red (was missing; the other schemes define it)
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}

OutputFormat = Literal["unicode", "latex", "ascii"]


class ArraySummarySettings(TypedDict):
    """Type definition for array_summary default settings."""

    fmt: OutputFormat
    precision: int
    stats: bool
    shape: bool
    dtype: bool
    device: bool
    requires_grad: bool
    sparkline: bool
    sparkline_bins: int
    sparkline_logy: Optional[bool]
    colored: bool
    as_list: bool
    eq_char: str


SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "distribution_log": r"\mathbb{P}_L",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "distribution_log": "ℙ˪",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "distribution_log": "dist_log",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Never raises for ordinary failures: conversion problems yield an
    `"empty array"` status and statistic failures yield `"error: ..."`.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins passed to `np.histogram` (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or `"error: <msg>"`
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Detect tensors by class name so that torch never has to be imported
    is_tensor: bool = type(A).__name__ == "Tensor"
    result["is_tensor"] = is_tensor

    if is_tensor:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to a numpy array; any failure degrades to an empty array
    try:
        if is_tensor:
            try:
                on_gpu: bool = bool(getattr(A, "is_cuda", False))
            except Exception:
                on_gpu = False
            tensor: Any = getattr(A, "cpu", lambda: A)() if on_gpu else A
            tensor = getattr(tensor, "detach", lambda: tensor)()
            A_np = getattr(tensor, "numpy", lambda: np.array([]))()
        else:
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Basic information (dtype is reported from the original tensor when there is one)
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if is_tensor else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Statistics are computed over a flat view
    try:
        A_flat = np.ravel(A_np)
        # np.histogram warns on bool input; uint8 gives identical statistics
        if A_flat.dtype == np.bool_:
            A_flat = A_flat.astype(np.uint8)
    except Exception:
        A_flat = A_np

    # NaN bookkeeping, stored as plain Python int/bool/float
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
        nan_count: int = int(np.sum(nan_mask))
        result["nan_count"] = nan_count
        result["has_nans"] = nan_count > 0
        if result["size"]:
            result["nan_percent"] = (nan_count / result["size"]) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        if result["has_nans"]:
            f_min, f_max, f_mean, f_std, f_med = (
                np.nanmin,
                np.nanmax,
                np.nanmean,
                np.nanstd,
                np.nanmedian,
            )
        else:
            f_min, f_max, f_mean, f_std, f_med = (
                np.min,
                np.max,
                np.mean,
                np.std,
                np.median,
            )
        result["min"] = float(f_min(A_flat))
        result["max"] = float(f_max(A_flat))
        result["mean"] = float(f_mean(A_flat))
        result["std"] = float(f_std(A_flat))
        result["median"] = float(f_med(A_flat))
        result["range"] = (result["min"], result["max"])

        # NaNs must be dropped before binning
        A_hist = A_flat[~nan_mask] if result["has_nans"] else A_flat

        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # e.g. infinite values; the summary simply omits the sparkline
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {e}"

    return result


SparklineFormat = Literal["unicode", "latex", "ascii"]


def generate_sparkline(
    histogram: np.ndarray,
    format: SparklineFormat = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram data (any sequence of counts is accepted)
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; unknown values fall back to the ascii character set
        (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used
    """
    if histogram is None:
        return "", False
    counts = np.asarray(histogram)
    if counts.size == 0:
        return "", False

    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])

    if log_y is None:
        # bin the counts into as many levels as we have characters; if only the
        # lowest and highest levels are populated, a linear scale hides detail
        level_counts = np.histogram(counts, bins=len(chars))[0]
        log_y = not (level_counts[1:-1].max() > 0)

    # log1p keeps zero counts at zero
    hist_data = np.log1p(counts) if log_y else counts

    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # np.rint rounds half to even, exactly like the builtin round()
    indices = np.rint(normalized).astype(int)
    return "".join(chars[i] for i in indices), log_y


DEFAULT_SETTINGS: ArraySummarySettings = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": None,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}


def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color registered under `color_key`.

    An empty color entry returns `text` unchanged. For TeX the text is wrapped
    as `<color>{text}`; otherwise it is followed by `colors["reset"]`.
    """
    prefix: str = colors[color_key]
    if not prefix:
        return text
    if using_tex:
        return f"{prefix}{{{text}}}"
    return f"{prefix}{text}{colors['reset']}"


def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names."""

    # A single "torch." occurrence is split off and colored on its own
    type_part: str = dtype_str
    prefix_part: Optional[str] = None
    if dtype_str.count("torch.") == 1:
        prefix_part = apply_color("torch", "torch", colors, using_tex)
        type_part = dtype_str.split("torch.")[1]

    # First matching family wins; "bool" is checked before "int"
    lowered: str = dtype_str.lower()
    color_key: str = "dtype"
    for needle, key in (
        ("bool", "dtype_bool"),
        ("int", "dtype_int"),
        ("float", "dtype_float"),
    ):
        if needle in lowered:
            color_key = key
            break

    type_colored: str = apply_color(type_part, color_key, colors, using_tex)
    return f"{prefix_part}.{type_colored}" if prefix_part else type_colored


def format_shape_colored(
    shape_val: Any, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    A one-element shape is rendered as the bare dimension; anything else as a
    parenthesised, comma-separated list with each dimension colored separately.
    """
    dims: List[str] = [
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    ]
    if len(dims) == 1:
        return dims[0]
    return "(" + ",".join(dims) + ")"


def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting."""
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)


class _UseDefaultType:
    """Sentinel type: marks an argument that should fall back to `DEFAULT_SETTINGS`."""


_USE_DEFAULT = _UseDefaultType()


@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    *,
    as_list: Literal[True],
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: Literal[False] = ...,
) -> str: ...
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: bool = ...,
) -> Union[str, List[str]]: ...
def array_summary(
    array: Any,
    fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT,
    precision: Union[int, _UseDefaultType] = _USE_DEFAULT,
    stats: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    shape: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    device: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT,
    sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT,
    colored: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT,
    as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT,
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument left unspecified is read from `DEFAULT_SETTINGS` at call time,
    so changing that dict changes the defaults listed below.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃, range) (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters, in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline, `None` for automatic
        detection (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Separator placed between every label and its value (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    # resolve sentinels against the (mutable) module-level defaults
    if isinstance(fmt, _UseDefaultType):
        fmt = DEFAULT_SETTINGS["fmt"]
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_SETTINGS["precision"]
    if isinstance(stats, _UseDefaultType):
        stats = DEFAULT_SETTINGS["stats"]
    if isinstance(shape, _UseDefaultType):
        shape = DEFAULT_SETTINGS["shape"]
    if isinstance(dtype, _UseDefaultType):
        dtype = DEFAULT_SETTINGS["dtype"]
    if isinstance(device, _UseDefaultType):
        device = DEFAULT_SETTINGS["device"]
    if isinstance(requires_grad, _UseDefaultType):
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if isinstance(sparkline, _UseDefaultType):
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if isinstance(sparkline_bins, _UseDefaultType):
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if isinstance(sparkline_logy, _UseDefaultType):
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if isinstance(colored, _UseDefaultType):
        colored = DEFAULT_SETTINGS["colored"]
    if isinstance(as_list, _UseDefaultType):
        as_list = DEFAULT_SETTINGS["as_list"]
    if isinstance(eq_char, _UseDefaultType):
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        # thin wrapper so the scheme and TeX flag are not repeated at every call
        return apply_color(text, color_key, colors, using_tex)

    # Integer-like dtypes get their range printed without decimals
    dtype_str: str = str(array_data.get("dtype") or "")
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ("int", "bool")
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle degenerate arrays: only a warning is shown in place of statistics
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = str(array_data["status"])
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether it used a log scale, which
        # decides the label symbol
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested (an empty shape tuple is skipped)
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)
COLORS: Dict[str, Dict[str, str]] =
{'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'torch': '\\textcolor{orange}', 'dtype_bool': '\\textcolor{gray}', 'dtype_int': '\\textcolor{blue}', 'dtype_float': '\\textcolor{red!70}', 'dtype_str': '\\textcolor{red}', 'device_cuda': '\\textcolor{green}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'torch': '\x1b[38;5;208m', 'dtype_bool': '\x1b[38;5;245m', 'dtype_int': '\x1b[38;5;39m', 'dtype_float': '\x1b[38;5;167m', 'device_cuda': '\x1b[38;5;76m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'torch': '', 'dtype_bool': '', 'dtype_int': '', 'dtype_float': '', 'dtype_str': '', 'device_cuda': '', 'reset': ''}}
OutputFormat =
typing.Literal['unicode', 'latex', 'ascii']
class
ArraySummarySettings(typing.TypedDict):
class ArraySummarySettings(TypedDict):
    """Type definition for array_summary default settings.

    Each key mirrors the keyword argument of `array_summary` with the same name.
    """

    # output flavour: "unicode", "latex" or "ascii"
    fmt: OutputFormat
    # number of decimal places used when formatting float statistics
    precision: int
    # include mean / std / median / range
    stats: bool
    # include the array shape
    shape: bool
    # include the dtype
    dtype: bool
    # include the device (only emitted for tensors)
    device: bool
    # include the requires_grad marker (only emitted for tensors)
    requires_grad: bool
    # include a histogram sparkline
    sparkline: bool
    # number of histogram bins computed for the sparkline
    sparkline_bins: int
    # log y-scale for the sparkline; None means detect automatically
    sparkline_logy: Optional[bool]
    # wrap the parts in color codes
    colored: bool
    # return the list of parts instead of one joined string
    as_list: bool
    # separator between a label and its value
    eq_char: str
Type definition for array_summary default settings.
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] =
{'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'distribution_log': '\\mathbb{P}_L', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'distribution_log': 'ℙ˪', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'distribution_log': 'dist_log', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}
Symbols for different formats
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] =
{'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}
characters for sparklines in different formats
def
array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Never raises for ordinary failures: conversion problems yield an
    `"empty array"` status and statistic failures yield `"error: ..."`.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins passed to `np.histogram` (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or `"error: <msg>"`
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Detect tensors by class name so that torch never has to be imported
    is_tensor: bool = type(A).__name__ == "Tensor"
    result["is_tensor"] = is_tensor

    if is_tensor:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to a numpy array; any failure degrades to an empty array
    try:
        if is_tensor:
            try:
                on_gpu: bool = bool(getattr(A, "is_cuda", False))
            except Exception:
                on_gpu = False
            tensor: Any = getattr(A, "cpu", lambda: A)() if on_gpu else A
            tensor = getattr(tensor, "detach", lambda: tensor)()
            A_np = getattr(tensor, "numpy", lambda: np.array([]))()
        else:
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Basic information (dtype is reported from the original tensor when there is one)
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if is_tensor else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Statistics are computed over a flat view
    try:
        A_flat = np.ravel(A_np)
        # np.histogram warns on bool input; uint8 gives identical statistics
        if A_flat.dtype == np.bool_:
            A_flat = A_flat.astype(np.uint8)
    except Exception:
        A_flat = A_np

    # NaN bookkeeping, stored as plain Python int/bool/float
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
        nan_count: int = int(np.sum(nan_mask))
        result["nan_count"] = nan_count
        result["has_nans"] = nan_count > 0
        if result["size"]:
            result["nan_percent"] = (nan_count / result["size"]) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        if result["has_nans"]:
            f_min, f_max, f_mean, f_std, f_med = (
                np.nanmin,
                np.nanmax,
                np.nanmean,
                np.nanstd,
                np.nanmedian,
            )
        else:
            f_min, f_max, f_mean, f_std, f_med = (
                np.min,
                np.max,
                np.mean,
                np.std,
                np.median,
            )
        result["min"] = float(f_min(A_flat))
        result["max"] = float(f_max(A_flat))
        result["mean"] = float(f_mean(A_flat))
        result["std"] = float(f_std(A_flat))
        result["median"] = float(f_med(A_flat))
        result["range"] = (result["min"], result["max"])

        # NaNs must be dropped before binning
        A_hist = A_flat[~nan_mask] if result["has_nans"] else A_flat

        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # e.g. infinite values; the summary simply omits the sparkline
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {e}"

    return result
Extract statistical information from an array-like object.
Parameters:
- `A : array-like` — Array to analyze (numpy array or torch tensor)
Returns:
- `Dict[str, Any]` — Dictionary containing raw statistical information with numeric values
SparklineFormat =
typing.Literal['unicode', 'latex', 'ascii']
def
generate_sparkline( histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: Optional[bool] = None) -> tuple[str, bool]:
def generate_sparkline(
    histogram: np.ndarray,
    format: SparklineFormat = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram data (any sequence of counts is accepted)
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; unknown values fall back to the ascii character set
        (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used
    """
    if histogram is None:
        return "", False
    # plain lists previously crashed on `.max()` when the scale was linear
    counts = np.asarray(histogram)
    if counts.size == 0:
        return "", False

    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])

    if log_y is None:
        # bin the counts into as many levels as we have characters; if only the
        # lowest and highest levels are populated, a linear scale hides detail
        level_counts = np.histogram(counts, bins=len(chars))[0]
        log_y = not (level_counts[1:-1].max() > 0)

    # log1p keeps zero counts at zero
    hist_data = np.log1p(counts) if log_y else counts

    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # np.rint rounds half to even, exactly like the builtin round()
    indices = np.rint(normalized).astype(int)
    return "".join(chars[i] for i in indices), log_y
Generate a sparkline visualization of the histogram.
Parameters:
- `histogram : np.ndarray` Histogram data
- `format : Literal["unicode", "latex", "ascii"]` Output format (defaults to `"unicode"`)
- `log_y : bool|None` Whether to use logarithmic y-scale. `None` for automatic detection (defaults to `None`)
Returns:
- `tuple[str, bool]` Sparkline visualization and whether log scale was used
DEFAULT_SETTINGS: ArraySummarySettings =
{'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': None, 'colored': False, 'as_list': False, 'eq_char': '='}
def
apply_color( text: str, color_key: str, colors: Dict[str, str], using_tex: bool) -> str:
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color markup stored under `colors[color_key]`.

    # Parameters:
     - `text : str`
        text to colorize
     - `color_key : str`
        key into `colors` selecting the color prefix
     - `colors : Dict[str, str]`
        color scheme; must contain `color_key`, and `"reset"` when a non-empty
        prefix is used without TeX
     - `using_tex : bool`
        if `True`, emit `<prefix>{<text>}`; otherwise emit
        `<prefix><text><reset>`

    # Returns:
     - `str`
        the wrapped text, or `text` unchanged when the prefix is empty
    """
    prefix: str = colors[color_key]

    # an empty prefix means "no color for this key": leave the text alone
    if not prefix:
        return text

    if using_tex:
        # TeX color commands take the text as a braced argument
        return f"{prefix}{{{text}}}"

    return f"{prefix}{text}{colors['reset']}"
def
colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize a dtype string, coloring the `torch` prefix and the type name separately.

    When `"torch."` occurs exactly once in `dtype_str`, the word `torch` is
    colored with the `"torch"` key and only the text after `"torch."` is kept
    as the type name. The type name is colored with `"dtype_bool"`,
    `"dtype_int"` or `"dtype_float"` according to the first of `bool`, `int`,
    `float` found (case-insensitively) in the full `dtype_str`, and with
    `"dtype"` if none is found.
    """
    # split off the torch namespace, but only for exactly one occurrence
    # (zero or several occurrences leave the string untouched)
    torch_colored: Optional[str] = None
    type_name: str = dtype_str
    if dtype_str.count("torch.") == 1:
        torch_colored = apply_color("torch", "torch", colors, using_tex)
        type_name = dtype_str.partition("torch.")[2]

    # pick the color for the type family; order matters, first match wins
    lowered: str = dtype_str.lower()
    family_keys = (
        ("bool", "dtype_bool"),
        ("int", "dtype_int"),
        ("float", "dtype_float"),
    )
    type_key: str = next(
        (key for family, key in family_keys if family in lowered),
        "dtype",
    )

    name_colored: str = apply_color(type_name, type_key, colors, using_tex)

    if torch_colored:
        return f"{torch_colored}.{name_colored}"
    return name_colored
Colorize dtype string with specific colors for torch and type names.
def
format_shape_colored(shape_val: Any, colors: Dict[str, str], using_tex: bool) -> str:
def format_shape_colored(
    shape_val: Any, colors: Dict[str, str], using_tex: bool
) -> str:
    """Render a shape with each dimension colored using the `"shape"` color.

    A shape of length 1 is rendered as the bare (colored) dimension; any other
    length is rendered as a parenthesized, comma-separated list with no spaces.
    """

    def paint(dim: Any) -> str:
        # color a single dimension; an empty prefix means no coloring
        dim_text: str = str(dim)
        prefix: str = colors["shape"]
        if not prefix:
            return dim_text
        if using_tex:
            return f"{prefix}{{{dim_text}}}"
        return f"{prefix}{dim_text}{colors['reset']}"

    if len(shape_val) != 1:
        # multi-D (or empty) shape: color every dimension individually
        return f"({','.join(map(paint, shape_val))})"

    # 1D: no parentheses, but the dimension is still colored
    return paint(shape_val[0])
Format shape with proper coloring for both 1D and multi-D arrays.
def
format_device_colored(device_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Colorize a device string, using a distinct color for CUDA devices.

    The `"device_cuda"` color is used when `device_str` contains `cuda`
    (case-insensitively), the `"device"` color otherwise.
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    prefix: str = colors[color_key]

    # an empty prefix means no coloring for this key
    if not prefix:
        return device_str

    if using_tex:
        return f"{prefix}{{{device_str}}}"

    return f"{prefix}{device_str}{colors['reset']}"
Format device string with CUDA highlighting.
def
array_summary( array: Any, fmt: Union[Literal['unicode', 'latex', 'ascii'], muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, precision: Union[int, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, stats: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, shape: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, dtype: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, device: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, requires_grad: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, sparkline: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, sparkline_bins: Union[int, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, sparkline_logy: Union[bool, NoneType, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, colored: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, eq_char: Union[str, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>, as_list: Union[bool, muutils.tensor_info._UseDefaultType] = <muutils.tensor_info._UseDefaultType object>) -> Union[str, List[str]]:
def array_summary(
    array: Any,
    fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT,
    precision: Union[int, _UseDefaultType] = _USE_DEFAULT,
    stats: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    shape: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    device: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT,
    sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT,
    colored: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT,
    as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT,
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument left at `_USE_DEFAULT` is read from `DEFAULT_SETTINGS` at call
    time. The summary data itself comes from
    `array_info(array, hist_bins=sparkline_bins)`.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `DEFAULT_SETTINGS["fmt"]`)
     - `precision : int`
        Decimal places for non-integer numbers (defaults to `DEFAULT_SETTINGS["precision"]`)
     - `stats : bool`
        Whether to include statistical info (mean, std, median, range)
        (defaults to `DEFAULT_SETTINGS["stats"]`)
     - `shape : bool`
        Whether to include shape info (defaults to `DEFAULT_SETTINGS["shape"]`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `DEFAULT_SETTINGS["dtype"]`)
     - `device : bool`
        Whether to include device info for torch tensors
        (defaults to `DEFAULT_SETTINGS["device"]`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors
        (defaults to `DEFAULT_SETTINGS["requires_grad"]`)
     - `sparkline : bool`
        Whether to include a sparkline visualization
        (defaults to `DEFAULT_SETTINGS["sparkline"]`)
     - `sparkline_bins : int`
        Number of histogram bins requested from `array_info`
        (defaults to `DEFAULT_SETTINGS["sparkline_bins"]`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline, `None` for automatic
        detection (defaults to `DEFAULT_SETTINGS["sparkline_logy"]`)
     - `colored : bool`
        Whether to add color to output (defaults to `DEFAULT_SETTINGS["colored"]`)
     - `eq_char : str`
        String placed between each label and its value
        (defaults to `DEFAULT_SETTINGS["eq_char"]`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string
        (defaults to `DEFAULT_SETTINGS["as_list"]`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    # Resolve settings left at the sentinel. Kept as explicit `isinstance`
    # checks (rather than a loop) so type checkers can narrow each name.
    if isinstance(fmt, _UseDefaultType):
        fmt = DEFAULT_SETTINGS["fmt"]
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_SETTINGS["precision"]
    if isinstance(stats, _UseDefaultType):
        stats = DEFAULT_SETTINGS["stats"]
    if isinstance(shape, _UseDefaultType):
        shape = DEFAULT_SETTINGS["shape"]
    if isinstance(dtype, _UseDefaultType):
        dtype = DEFAULT_SETTINGS["dtype"]
    if isinstance(device, _UseDefaultType):
        device = DEFAULT_SETTINGS["device"]
    if isinstance(requires_grad, _UseDefaultType):
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if isinstance(sparkline, _UseDefaultType):
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if isinstance(sparkline_bins, _UseDefaultType):
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if isinstance(sparkline_logy, _UseDefaultType):
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if isinstance(colored, _UseDefaultType):
        colored = DEFAULT_SETTINGS["colored"]
    if isinstance(as_list, _UseDefaultType):
        as_list = DEFAULT_SETTINGS["as_list"]
    if isinstance(eq_char, _UseDefaultType):
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper to colorize text with the active scheme; delegates to the shared
    # module-level `apply_color` instead of duplicating its logic
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer-like ("uint" is covered by the "int" substring).
    # `or ""` guards against a dtype entry that is present but None.
    dtype_str: str = array_data.get("dtype") or ""
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ("int", "bool")
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            # a bare percent sign starts a comment in LaTeX, so escape it there
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    # use `eq_char` here too, consistent with the other entries
                    result_parts.append(
                        f"{symbols[stat_key]}{eq_char}{stat_colored}"
                    )

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    # integer-like dtypes are shown without decimal places
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether the log scale was used, which
        # decides which distribution symbol labels the sparkline
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)
Format array information into a readable summary.
Parameters:
- `array` array-like object (numpy array or torch tensor)
- `fmt : Literal["unicode", "latex", "ascii"]` Output format (defaults to `"unicode"`)
- `precision : int` Decimal places (defaults to `2`)
- `stats : bool` Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
- `shape : bool` Whether to include shape info (defaults to `True`)
- `dtype : bool` Whether to include dtype info (defaults to `True`)
- `device : bool` Whether to include device info for torch tensors (defaults to `True`)
- `requires_grad : bool` Whether to include requires_grad info for torch tensors (defaults to `True`)
- `sparkline : bool` Whether to include a sparkline visualization (defaults to `False`)
- `sparkline_bins : int` Number of histogram bins used for the sparkline (defaults to `5`)
- `sparkline_logy : bool|None` Whether to use logarithmic y-scale for sparkline (defaults to `None`)
- `colored : bool` Whether to add color to output (defaults to `False`)
- `eq_char : str` String placed between a label and its value (defaults to `"="`)
- `as_list : bool` Whether to return as list of strings instead of joined string (defaults to `False`)
Returns:
- `Union[str, List[str]]` Formatted statistical summary, either as string or list of strings