docs for muutils v0.9.1
View Source on GitHub

muutils.dbg

this code is based on an implementation of the Rust builtin dbg! for Python, originally from https://github.com/tylerwince/pydbg/blob/master/pydbg.py although it has been significantly modified

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


  1"""
  2
  3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
  4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
  5although it has been significantly modified
  6
  7licensed under MIT:
  8
  9Copyright (c) 2019 Tyler Wince
 10
 11Permission is hereby granted, free of charge, to any person obtaining a copy
 12of this software and associated documentation files (the "Software"), to deal
 13in the Software without restriction, including without limitation the rights
 14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 15copies of the Software, and to permit persons to whom the Software is
 16furnished to do so, subject to the following conditions:
 17
 18The above copyright notice and this permission notice shall be included in
 19all copies or substantial portions of the Software.
 20
 21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 27THE SOFTWARE.
 28
 29"""
 30
 31from __future__ import annotations
 32
 33import inspect
 34import sys
 35import typing
 36from pathlib import Path
 37import re
 38
# type defs
# generic typevar: `dbg` returns its argument unchanged, so callers keep the type
_ExpType = typing.TypeVar("_ExpType")
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
# NOTE(review): _ExpType_list appears unused in this module -- confirm before removing
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])
 45
 46
# TypedDict definitions for configuration dictionaries
class DBGDictDefaultsType(typing.TypedDict):
    """Schema of `DBG_DICT_DEFAULTS`: options controlling `dict_info` output."""

    key_types: bool  # include the set of key type names in the summary line
    val_types: bool  # include the set of value type names in the summary line
    max_len: int  # expand entries only for dicts with fewer items than this
    indent: str  # indentation unit for nested entries
    max_depth: int  # recurse into nested dicts at most this deep
 55
class DBGListDefaultsType(typing.TypedDict):
    """Schema of `DBG_LIST_DEFAULTS`: options controlling `list_info` output."""

    max_len: int  # lists longer than this are summarized rather than shown in full
    summary_show_types: bool  # include element type names in the summary form
 60
class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
    """Schema of `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`, forwarded as keyword
    arguments to `muutils.tensor_info.array_summary` by `tensor_info`."""

    fmt: typing.Literal["unicode", "latex", "ascii"]  # output character set
    precision: int
    stats: bool
    shape: bool
    dtype: bool
    device: bool
    requires_grad: bool
    sparkline: bool
    sparkline_bins: int
    sparkline_logy: typing.Union[None, bool]  # None means auto-detect
    colored: bool
    eq_char: str
 75
 76# Sentinel type for no expression passed
 77class _NoExpPassedSentinel:
 78    """Unique sentinel type used to indicate that no expression was passed."""
 79
 80    pass
 81
 82
 83_NoExpPassed = _NoExpPassedSentinel()
 84
 85# global variables
 86_CWD: Path = Path.cwd().absolute()
 87_COUNTER: int = 0
 88_DBG_MODULE_FILE: str = str(Path(__file__).resolve())
 89
 90# configuration
 91PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
 92DEFAULT_VAL_JOINER: str = " = "
 93
 94
 95# path processing
 96def _process_path(path: Path) -> str:
 97    path_abs: Path = path.absolute()
 98    fname: Path
 99    if PATH_MODE == "absolute":
100        fname = path_abs
101    elif PATH_MODE == "relative":
102        try:
103            # if it's inside the cwd, print the relative path
104            fname = path.relative_to(_CWD)
105        except ValueError:
106            # if its not in the subpath, use the absolute path
107            fname = path_abs
108    else:
109        raise ValueError("PATH_MODE must be either 'relative' or 'absolute")
110
111    return fname.as_posix()
112
113
# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

            from muutils.dbg import dbg

            a = 2
            b = 5

            dbg(a+b)

            def square(x: int) -> int:
                    return x * x

            dbg(square(a))

    # Parameters:
    - `exp` : expression to debug; when omitted, only the location and a
      global call counter are printed
    - `formatter` : optional callable used to render the value
      (defaults to `repr`)
    - `val_joiner` : string placed between the source expression text and
      its rendered value

    # Returns:
    - the value of `exp` unchanged, so `dbg(...)` can be wrapped around
      sub-expressions in place
    """
    global _COUNTER

    # get the context: walk up the call stack to the first frame outside
    # this module whose source line mentions "dbg" -- that is the call site
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        if str(Path(frame.filename).resolve()) == _DBG_MODULE_FILE:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            # extract the text between the call's parentheses on this line
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                # no closing paren on this line (multi-line call): take the rest
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    # NOTE(review): ipykernel detection keys off the tempfile path prefix
    # Jupyter uses on Linux -- confirm behavior on other platforms
    if current_file.startswith("/tmp/ipykernel_"):
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        # stack was walked innermost-first; reverse for caller -> callee order
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp
219
220
# formatted `dbg_*` functions with their helpers

# default keyword arguments forwarded to `muutils.tensor_info.array_summary`
# by `tensor_info`; see `DBGTensorArraySummaryDefaultsType` for the schema
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": None,  # None means auto-detect
    "colored": True,
    "eq_char": "=",
}


# joiner used by the `dbg_*` helpers between the expression text and its
# formatted value (instead of `DEFAULT_VAL_JOINER`)
DBG_TENSOR_VAL_JOINER: str = ": "
240
241
def tensor_info(tensor: typing.Any) -> str:
    """Format a torch tensor / numpy array as a one-line summary string.

    Delegates to `muutils.tensor_info.array_summary` with the module-level
    defaults from `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`.
    """
    # local import so importing this module does not pull in tensor_info's deps
    from muutils.tensor_info import array_summary

    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)  # type: ignore[call-overload]
247
248
# formatting options for `dict_info`; see `DBGDictDefaultsType` for field meanings
DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
    "key_types": True,
    "val_types": True,
    "max_len": 32,
    "indent": "  ",
    "max_depth": 3,
}
256
# formatting options for `list_info`; see `DBGListDefaultsType` for field meanings
DBG_LIST_DEFAULTS: DBGListDefaultsType = {
    "max_len": 16,
    "summary_show_types": True,
}


def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Render a list: full `repr`-style contents when short enough, otherwise
    a summary with its length and (optionally) the set of element type names."""
    length: int = len(lst)
    if length <= DBG_LIST_DEFAULTS["max_len"]:
        return "[" + ", ".join(repr(item) for item in lst) + "]"

    summary: str = f"<list of len()={length}"
    if DBG_LIST_DEFAULTS["summary_show_types"]:
        type_names: typing.Set[str] = {type(item).__name__ for item in lst}
        summary += f", types={{{', '.join(sorted(type_names))}}}"
    return summary + ">"
278
279
# stringified types matched against `str(type(obj))`, so tensors/arrays can be
# detected without importing torch or numpy in this module
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}
284
285
def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Recursively summarize a dict: a header with its length and key/value
    type sets, then (when small and shallow enough) one indented line per entry.

    Nested dicts recurse up to `DBG_DICT_DEFAULTS["max_depth"]`; tensor-like
    values go through `tensor_info` and lists through `list_info`.
    """
    n_items: int = len(d)
    indent_unit: str = DBG_DICT_DEFAULTS["indent"]

    # header line, e.g. `<dict of len()=2, key_types={str}, val_types={int}>`
    output: str = f"{indent_unit * depth}<dict of len()={n_items}"

    if n_items > 0 and DBG_DICT_DEFAULTS["key_types"]:
        key_names = sorted({type(k).__name__ for k in d.keys()})
        output += f", key_types={{{', '.join(key_names)}}}"

    if n_items > 0 and DBG_DICT_DEFAULTS["val_types"]:
        val_names = sorted({type(v).__name__ for v in d.values()})
        output += f", val_types={{{', '.join(val_names)}}}"

    output += ">"

    # expand individual entries only when not too deep and not too many
    if depth < DBG_DICT_DEFAULTS["max_depth"] and 0 < n_items < DBG_DICT_DEFAULTS["max_len"]:
        child_indent: str = indent_unit * (depth + 1)
        for key, val in d.items():
            key_repr: str = key if isinstance(key, str) else repr(key)

            val_repr: str
            if isinstance(val, dict):
                val_repr = dict_info(val, depth + 1)
            elif str(type(val)) in TENSOR_STR_TYPES:
                # string comparison avoids importing torch/numpy here
                val_repr = tensor_info(val)
            elif isinstance(val, list):
                val_repr = list_info(val)
            else:
                val_repr = repr(val)

            output += f"\n{child_indent}{key_repr}{DBG_TENSOR_VAL_JOINER}{val_repr}"

    return output
330
331
def info_auto(
    obj: typing.Any,
) -> str:
    """Automatically format an object for debugging.

    Dispatches to `dict_info`, `list_info`, or `tensor_info` by type,
    falling back to plain `repr` for everything else.
    """
    if isinstance(obj, dict):
        return dict_info(obj)
    if isinstance(obj, list):
        return list_info(obj)
    if str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    return repr(obj)
344
345
def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """Debug-print a tensor/array via the `tensor_info` formatter, returning it unchanged."""
    return dbg(tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER)
355
356
def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """Debug-print a dictionary via the `dict_info` formatter, returning it unchanged."""
    return dbg(d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER)
366
367
def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """Debug-print any object, picking a formatter by type via `info_auto`; returns it unchanged."""
    return dbg(obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER)
377
378
379def _normalize_for_loose(text: str) -> str:
380    """Normalize text for loose matching by replacing non-alphanumeric chars with spaces."""
381    normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text)
382    return " ".join(normalized.split())
383
384
385def _compile_pattern(
386    pattern: str | re.Pattern[str],
387    *,
388    cased: bool = False,
389    loose: bool = False,
390) -> re.Pattern[str]:
391    """Compile pattern with appropriate flags for case sensitivity and loose matching."""
392    if isinstance(pattern, re.Pattern):
393        return pattern
394
395    # Start with no flags for case-insensitive default
396    flags: int = 0
397    if not cased:
398        flags |= re.IGNORECASE
399
400    if loose:
401        pattern = _normalize_for_loose(pattern)
402
403    return re.compile(pattern, flags)
404
405
def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    text: str = repr(obj)
    if loose:
        # both text and (inside _compile_pattern) the pattern are normalized
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # wrap every occurrence of the pattern in ANSI bold+color escapes
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        # precompute starting character offset of each line for offset->line lookup
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            # NOTE(review): a match whose context overlaps an earlier block is
            # skipped WITHOUT counting toward max_count -- confirm intended
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    # 1-based line numbers, matching grep's -n convention
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context: a fixed window of characters around each match
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # clamp the window to the text bounds
            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None

class DBGDictDefaultsType(typing.TypedDict):
49class DBGDictDefaultsType(typing.TypedDict):
50    key_types: bool
51    val_types: bool
52    max_len: int
53    indent: str
54    max_depth: int
key_types: bool
val_types: bool
max_len: int
indent: str
max_depth: int
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
class DBGListDefaultsType(typing.TypedDict):
57class DBGListDefaultsType(typing.TypedDict):
58    max_len: int
59    summary_show_types: bool
max_len: int
summary_show_types: bool
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
62class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
63    fmt: typing.Literal["unicode", "latex", "ascii"]
64    precision: int
65    stats: bool
66    shape: bool
67    dtype: bool
68    device: bool
69    requires_grad: bool
70    sparkline: bool
71    sparkline_bins: int
72    sparkline_logy: typing.Union[None, bool]
73    colored: bool
74    eq_char: str
fmt: Literal['unicode', 'latex', 'ascii']
precision: int
stats: bool
shape: bool
dtype: bool
device: bool
requires_grad: bool
sparkline: bool
sparkline_bins: int
sparkline_logy: Optional[bool]
colored: bool
eq_char: str
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
PATH_MODE: Literal['relative', 'absolute'] = 'relative'
DEFAULT_VAL_JOINER: str = ' = '
def dbg( exp: Union[~_ExpType, muutils.dbg._NoExpPassedSentinel] = <muutils.dbg._NoExpPassedSentinel object>, formatter: Optional[Callable[[Any], str]] = None, val_joiner: str = ' = ') -> Union[~_ExpType, muutils.dbg._NoExpPassedSentinel]:
130def dbg(
131    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
132    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
133    val_joiner: str = DEFAULT_VAL_JOINER,
134) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
135    """Call dbg with any variable or expression.
136
137    Calling dbg will print to stderr the current filename and lineno,
138    as well as the passed expression and what the expression evaluates to:
139
140            from muutils.dbg import dbg
141
142            a = 2
143            b = 5
144
145            dbg(a+b)
146
147            def square(x: int) -> int:
148                    return x * x
149
150            dbg(square(a))
151
152    """
153    global _COUNTER
154
155    # get the context
156    line_exp: str = "unknown"
157    current_file: str = "unknown"
158    dbg_frame: typing.Optional[inspect.FrameInfo] = None
159    for frame in inspect.stack():
160        if frame.code_context is None:
161            continue
162        if str(Path(frame.filename).resolve()) == _DBG_MODULE_FILE:
163            continue
164        line: str = frame.code_context[0]
165        if "dbg" in line:
166            current_file = _process_path(Path(frame.filename))
167            dbg_frame = frame
168            start: int = line.find("(") + 1
169            end: int = line.rfind(")")
170            if end == -1:
171                end = len(line)
172            line_exp = line[start:end]
173            break
174
175    fname: str = "unknown"
176    if current_file.startswith("/tmp/ipykernel_"):
177        stack: list[inspect.FrameInfo] = inspect.stack()
178        filtered_functions: list[str] = []
179        # this loop will find, in this order:
180        # - the dbg function call
181        # - the functions we care about displaying
182        # - `<module>`
183        # - a bunch of jupyter internals we don't care about
184        for frame_info in stack:
185            if _process_path(Path(frame_info.filename)) != current_file:
186                continue
187            if frame_info.function == "<module>":
188                break
189            if frame_info.function.startswith("dbg"):
190                continue
191            filtered_functions.append(frame_info.function)
192        if dbg_frame is not None:
193            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
194        else:
195            filtered_functions.append(current_file)
196        filtered_functions.reverse()
197        fname = " -> ".join(filtered_functions)
198    elif dbg_frame is not None:
199        fname = f"{current_file}:{dbg_frame.lineno}"
200
201    # assemble the message
202    msg: str
203    if exp is _NoExpPassed:
204        # if no expression is passed, just show location and counter value
205        msg = f"[ {fname} ] <dbg {_COUNTER}>"
206        _COUNTER += 1
207    else:
208        # if expression passed, format its value and show location, expr, and value
209        exp_val: str = formatter(exp) if formatter else repr(exp)
210        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"
211
212    # print the message
213    print(
214        msg,
215        file=sys.stderr,
216    )
217
218    # return the expression itself
219    return exp

Call dbg with any variable or expression.

Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:

    from muutils.dbg import dbg

    a = 2
    b = 5

    dbg(a+b)

    def square(x: int) -> int:
            return x * x

    dbg(square(a))
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': True, 'sparkline_bins': 7, 'sparkline_logy': None, 'colored': True, 'eq_char': '='}
DBG_TENSOR_VAL_JOINER: str = ': '
def tensor_info(tensor: Any) -> str:
243def tensor_info(tensor: typing.Any) -> str:
244    from muutils.tensor_info import array_summary
245
246    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
247    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)  # type: ignore[call-overload]
DBG_DICT_DEFAULTS: DBGDictDefaultsType = {'key_types': True, 'val_types': True, 'max_len': 32, 'indent': '  ', 'max_depth': 3}
DBG_LIST_DEFAULTS: DBGListDefaultsType = {'max_len': 16, 'summary_show_types': True}
def list_info(lst: List[Any]) -> str:
264def list_info(
265    lst: typing.List[typing.Any],
266) -> str:
267    len_l: int = len(lst)
268    output: str
269    if len_l > DBG_LIST_DEFAULTS["max_len"]:
270        output = f"<list of len()={len_l}"
271        if DBG_LIST_DEFAULTS["summary_show_types"]:
272            val_types: typing.Set[str] = set(type(x).__name__ for x in lst)
273            output += f", types={{{', '.join(sorted(val_types))}}}"
274        output += ">"
275    else:
276        output = "[" + ", ".join(repr(x) for x in lst) + "]"
277
278    return output
TENSOR_STR_TYPES: Set[str] = {"<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"}
def dict_info(d: Dict[Any, Any], depth: int = 0) -> str:
287def dict_info(
288    d: typing.Dict[typing.Any, typing.Any],
289    depth: int = 0,
290) -> str:
291    len_d: int = len(d)
292    indent: str = DBG_DICT_DEFAULTS["indent"]
293
294    # summary line
295    output: str = f"{indent * depth}<dict of len()={len_d}"
296
297    if DBG_DICT_DEFAULTS["key_types"] and len_d > 0:
298        key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys())
299        key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}"
300        output += f", key_types={key_types_str}"
301
302    if DBG_DICT_DEFAULTS["val_types"] and len_d > 0:
303        val_types: typing.Set[str] = set(type(v).__name__ for v in d.values())
304        val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}"
305        output += f", val_types={val_types_str}"
306
307    output += ">"
308
 309    # keys/values if not too deep and not too many
310    if depth < DBG_DICT_DEFAULTS["max_depth"]:
311        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:
312            for k, v in d.items():
313                key_str: str = repr(k) if not isinstance(k, str) else k
314
315                val_str: str
316                val_type_str: str = str(type(v))
317                if isinstance(v, dict):
318                    val_str = dict_info(v, depth + 1)
319                elif val_type_str in TENSOR_STR_TYPES:
320                    val_str = tensor_info(v)
321                elif isinstance(v, list):
322                    val_str = list_info(v)
323                else:
324                    val_str = repr(v)
325
326                output += (
327                    f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}"
328                )
329
330    return output
def info_auto(obj: Any) -> str:
333def info_auto(
334    obj: typing.Any,
335) -> str:
336    """Automatically format an object for debugging."""
337    if isinstance(obj, dict):
338        return dict_info(obj)
339    elif isinstance(obj, list):
340        return list_info(obj)
341    elif str(type(obj)) in TENSOR_STR_TYPES:
342        return tensor_info(obj)
343    else:
344        return repr(obj)

Automatically format an object for debugging.

def dbg_tensor(tensor: ~_ExpType) -> ~_ExpType:
347def dbg_tensor(
348    tensor: _ExpType,  # numpy array or torch tensor
349) -> _ExpType:
350    """dbg function for tensors, using tensor_info formatter."""
351    return dbg(
352        tensor,
353        formatter=tensor_info,
354        val_joiner=DBG_TENSOR_VAL_JOINER,
355    )

dbg function for tensors, using tensor_info formatter.

def dbg_dict(d: ~_ExpType_dict) -> ~_ExpType_dict:
358def dbg_dict(
359    d: _ExpType_dict,
360) -> _ExpType_dict:
361    """dbg function for dictionaries, using dict_info formatter."""
362    return dbg(
363        d,
364        formatter=dict_info,
365        val_joiner=DBG_TENSOR_VAL_JOINER,
366    )

dbg function for dictionaries, using dict_info formatter.

def dbg_auto(obj: ~_ExpType) -> ~_ExpType:
369def dbg_auto(
370    obj: _ExpType,
371) -> _ExpType:
372    """dbg function for automatic formatting based on type."""
373    return dbg(
374        obj,
375        formatter=info_auto,
376        val_joiner=DBG_TENSOR_VAL_JOINER,
377    )

dbg function for automatic formatting based on type.

def grep_repr( obj: Any, pattern: str | re.Pattern[str], *, char_context: int | None = 20, line_context: int | None = None, before_context: int = 0, after_context: int = 0, context: int | None = None, max_count: int | None = None, cased: bool = False, loose: bool = False, line_numbers: bool = False, highlight: bool = True, color: str = '31', separator: str = '--', quiet: bool = False) -> Optional[List[str]]:
def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # local import: only needed for the line-index lookup in line-context mode
    import bisect

    # Handle context parameter shortcuts (-C implies both -B and -A)
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    text: str = repr(obj)
    if loose:
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # Re-run the regex over the segment to wrap each hit in ANSI codes.
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        # Offsets of each line's first character, ascending; used to map a
        # match offset back to its line index.
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match. line_starts is sorted and
            # begins at 0, so bisect gives the rightmost line whose start is
            # <= match.start() in O(log n) — equivalent to (but faster than)
            # scanning every line per match.
            match_line = bisect.bisect_right(line_starts, match.start()) - 1

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None

grep-like search on repr(obj) with improved grep-style options.

By default, string patterns are case-insensitive. Pre-compiled regex patterns use their own flags.

Parameters:

  • obj: Object to search (its repr() string is scanned)
  • pattern: Regular expression pattern (string or pre-compiled)
  • char_context: Characters of context before/after each match (default: 20)
  • line_context: Lines of context before/after; overrides char_context
  • before_context: Lines of context before match (like grep -B)
  • after_context: Lines of context after match (like grep -A)
  • context: Lines of context before AND after (like grep -C)
  • max_count: Stop after this many matches
  • cased: Force case-sensitive search for string patterns
  • loose: Normalize spaces/punctuation for flexible matching
  • line_numbers: Show line numbers in output
  • highlight: Wrap matches with ANSI color codes
  • color: ANSI color code (default: "31" for red)
  • separator: Separator between multiple matches
  • quiet: Return results instead of printing

Returns:

  • None if quiet=False (prints to stdout)
  • List[str] if quiet=True (returns formatted output lines)