muutils.dbg
this code is based on an implementation of the Rust builtin dbg! for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
although it has been significantly modified
licensed under MIT:
Copyright (c) 2019 Tyler Wince
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1""" 2 3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from 4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 5although it has been significantly modified 6 7licensed under MIT: 8 9Copyright (c) 2019 Tyler Wince 10 11Permission is hereby granted, free of charge, to any person obtaining a copy 12of this software and associated documentation files (the "Software"), to deal 13in the Software without restriction, including without limitation the rights 14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15copies of the Software, and to permit persons to whom the Software is 16furnished to do so, subject to the following conditions: 17 18The above copyright notice and this permission notice shall be included in 19all copies or substantial portions of the Software. 20 21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27THE SOFTWARE. 
28 29""" 30 31from __future__ import annotations 32 33import inspect 34import sys 35import typing 36from pathlib import Path 37import re 38 39# type defs 40_ExpType = typing.TypeVar("_ExpType") 41_ExpType_dict = typing.TypeVar( 42 "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any] 43) 44_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any]) 45 46 47# TypedDict definitions for configuration dictionaries 48class DBGDictDefaultsType(typing.TypedDict): 49 key_types: bool 50 val_types: bool 51 max_len: int 52 indent: str 53 max_depth: int 54 55 56class DBGListDefaultsType(typing.TypedDict): 57 max_len: int 58 summary_show_types: bool 59 60 61class DBGTensorArraySummaryDefaultsType(typing.TypedDict): 62 fmt: typing.Literal["unicode", "latex", "ascii"] 63 precision: int 64 stats: bool 65 shape: bool 66 dtype: bool 67 device: bool 68 requires_grad: bool 69 sparkline: bool 70 sparkline_bins: int 71 sparkline_logy: typing.Union[None, bool] 72 colored: bool 73 eq_char: str 74 75 76# Sentinel type for no expression passed 77class _NoExpPassedSentinel: 78 """Unique sentinel type used to indicate that no expression was passed.""" 79 80 pass 81 82 83_NoExpPassed = _NoExpPassedSentinel() 84 85# global variables 86_CWD: Path = Path.cwd().absolute() 87_COUNTER: int = 0 88_DBG_MODULE_FILE: str = str(Path(__file__).resolve()) 89 90# configuration 91PATH_MODE: typing.Literal["relative", "absolute"] = "relative" 92DEFAULT_VAL_JOINER: str = " = " 93 94 95# path processing 96def _process_path(path: Path) -> str: 97 path_abs: Path = path.absolute() 98 fname: Path 99 if PATH_MODE == "absolute": 100 fname = path_abs 101 elif PATH_MODE == "relative": 102 try: 103 # if it's inside the cwd, print the relative path 104 fname = path.relative_to(_CWD) 105 except ValueError: 106 # if its not in the subpath, use the absolute path 107 fname = path_abs 108 else: 109 raise ValueError("PATH_MODE must be either 'relative' or 'absolute") 110 111 return fname.as_posix() 112 
# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    """
    global _COUNTER

    # get the context: walk up the stack to the first caller frame outside
    # this module whose source line mentions "dbg", and recover the argument
    # text between the outermost parentheses on that line
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        if str(Path(frame.filename).resolve()) == _DBG_MODULE_FILE:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                # call spans multiple lines; take everything after the "("
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    if current_file.startswith("/tmp/ipykernel_"):
        # running inside a Jupyter kernel: line numbers are per-cell and
        # unhelpful, so show the call chain of user functions instead
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp


# formatted `dbg_*` functions with their helpers

DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": None,  # None means auto-detect
    "colored": True,
    "eq_char": "=",
}


DBG_TENSOR_VAL_JOINER: str = ": "


def tensor_info(tensor: typing.Any) -> str:
    """Summarize a tensor/array via `muutils.tensor_info.array_summary` using module defaults."""
    from muutils.tensor_info import array_summary

    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)  # type: ignore[call-overload]


DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
    "key_types": True,
    "val_types": True,
    "max_len": 32,
    "indent": "  ",
    "max_depth": 3,
}

DBG_LIST_DEFAULTS: DBGListDefaultsType = {
    "max_len": 16,
    "summary_show_types": True,
}


def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Format a list for debug output.

    Short lists (up to `DBG_LIST_DEFAULTS["max_len"]`) are shown in full via
    `repr`; longer lists collapse to a summary of length and element types.
    """
    len_l: int = len(lst)
    output: str
    if len_l > DBG_LIST_DEFAULTS["max_len"]:
        output = f"<list of len()={len_l}"
        if DBG_LIST_DEFAULTS["summary_show_types"]:
            val_types: typing.Set[str] = set(type(x).__name__ for x in lst)
            output += f", types={{{', '.join(sorted(val_types))}}}"
        output += ">"
    else:
        output = "[" + ", ".join(repr(x) for x in lst) + "]"

    return output


# string forms of tensor types, checked via str(type(v)) so that neither
# torch nor numpy needs to be imported for the comparison
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}


def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Format a dict for debug output, recursing into nested dicts.

    Emits a summary line (length, key types, value types per
    `DBG_DICT_DEFAULTS`), then one indented line per entry while `depth` is
    below `max_depth` and the dict has fewer than `max_len` entries. Tensor
    and list values use `tensor_info`/`list_info` respectively.
    """
    len_d: int = len(d)
    indent: str = DBG_DICT_DEFAULTS["indent"]

    # summary line
    output: str = f"{indent * depth}<dict of len()={len_d}"

    if DBG_DICT_DEFAULTS["key_types"] and len_d > 0:
        key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys())
        key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}"
        output += f", key_types={key_types_str}"

    if DBG_DICT_DEFAULTS["val_types"] and len_d > 0:
        val_types: typing.Set[str] = set(type(v).__name__ for v in d.values())
        val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}"
        output += f", val_types={val_types_str}"

    output += ">"

    # keys/values if not too deep and not too many
    if depth < DBG_DICT_DEFAULTS["max_depth"]:
        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:
            for k, v in d.items():
                key_str: str = repr(k) if not isinstance(k, str) else k

                val_str: str
                val_type_str: str = str(type(v))
                if isinstance(v, dict):
                    val_str = dict_info(v, depth + 1)
                elif val_type_str in TENSOR_STR_TYPES:
                    val_str = tensor_info(v)
                elif isinstance(v, list):
                    val_str = list_info(v)
                else:
                    val_str = repr(v)

                output += (
                    f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}"
                )

    return output


def info_auto(
    obj: typing.Any,
) -> str:
    """Automatically format an object for debugging."""
    if isinstance(obj, dict):
        return dict_info(obj)
    elif isinstance(obj, list):
        return list_info(obj)
    elif str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    else:
        return repr(obj)


def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """dbg function for tensors, using tensor_info formatter."""
    return dbg(
        tensor,
        formatter=tensor_info,
        val_joiner=DBG_TENSOR_VAL_JOINER,
    )


def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """dbg function for dictionaries, using dict_info formatter."""
    return dbg(
        d,
        formatter=dict_info,
        val_joiner=DBG_TENSOR_VAL_JOINER,
    )


def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """dbg function for automatic formatting based on type."""
    return dbg(
        obj,
        formatter=info_auto,
        val_joiner=DBG_TENSOR_VAL_JOINER,
    )


def _normalize_for_loose(text: str) -> str:
    """Normalize text for loose matching by replacing non-alphanumeric chars with spaces."""
    normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text)
    return " ".join(normalized.split())


def _compile_pattern(
    pattern: str | re.Pattern[str],
    *,
    cased: bool = False,
    loose: bool = False,
) -> re.Pattern[str]:
    """Compile pattern with appropriate flags for case sensitivity and loose matching.

    Pre-compiled patterns are returned unchanged (their own flags win); string
    patterns default to case-insensitive unless `cased` is set, and are
    whitespace/punctuation-normalized when `loose` is set.
    """
    if isinstance(pattern, re.Pattern):
        return pattern

    # Start with no flags for case-insensitive default
    flags: int = 0
    if not cased:
        flags |= re.IGNORECASE

    if loose:
        pattern = _normalize_for_loose(pattern)

    return re.compile(pattern, flags)


def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    text: str = repr(obj)
    if loose:
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # wrap every match in the segment with ANSI bold+color escape codes
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None
49class DBGDictDefaultsType(typing.TypedDict): 50 key_types: bool 51 val_types: bool 52 max_len: int 53 indent: str 54 max_depth: int
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
62class DBGTensorArraySummaryDefaultsType(typing.TypedDict): 63 fmt: typing.Literal["unicode", "latex", "ascii"] 64 precision: int 65 stats: bool 66 shape: bool 67 dtype: bool 68 device: bool 69 requires_grad: bool 70 sparkline: bool 71 sparkline_bins: int 72 sparkline_logy: typing.Union[None, bool] 73 colored: bool 74 eq_char: str
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
130def dbg( 131 exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed, 132 formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, 133 val_joiner: str = DEFAULT_VAL_JOINER, 134) -> typing.Union[_ExpType, _NoExpPassedSentinel]: 135 """Call dbg with any variable or expression. 136 137 Calling dbg will print to stderr the current filename and lineno, 138 as well as the passed expression and what the expression evaluates to: 139 140 from muutils.dbg import dbg 141 142 a = 2 143 b = 5 144 145 dbg(a+b) 146 147 def square(x: int) -> int: 148 return x * x 149 150 dbg(square(a)) 151 152 """ 153 global _COUNTER 154 155 # get the context 156 line_exp: str = "unknown" 157 current_file: str = "unknown" 158 dbg_frame: typing.Optional[inspect.FrameInfo] = None 159 for frame in inspect.stack(): 160 if frame.code_context is None: 161 continue 162 if str(Path(frame.filename).resolve()) == _DBG_MODULE_FILE: 163 continue 164 line: str = frame.code_context[0] 165 if "dbg" in line: 166 current_file = _process_path(Path(frame.filename)) 167 dbg_frame = frame 168 start: int = line.find("(") + 1 169 end: int = line.rfind(")") 170 if end == -1: 171 end = len(line) 172 line_exp = line[start:end] 173 break 174 175 fname: str = "unknown" 176 if current_file.startswith("/tmp/ipykernel_"): 177 stack: list[inspect.FrameInfo] = inspect.stack() 178 filtered_functions: list[str] = [] 179 # this loop will find, in this order: 180 # - the dbg function call 181 # - the functions we care about displaying 182 # - `<module>` 183 # - a bunch of jupyter internals we don't care about 184 for frame_info in stack: 185 if _process_path(Path(frame_info.filename)) != current_file: 186 continue 187 if frame_info.function == "<module>": 188 break 189 if frame_info.function.startswith("dbg"): 190 continue 191 filtered_functions.append(frame_info.function) 192 if dbg_frame is not None: 193 filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}") 194 else: 195 
filtered_functions.append(current_file) 196 filtered_functions.reverse() 197 fname = " -> ".join(filtered_functions) 198 elif dbg_frame is not None: 199 fname = f"{current_file}:{dbg_frame.lineno}" 200 201 # assemble the message 202 msg: str 203 if exp is _NoExpPassed: 204 # if no expression is passed, just show location and counter value 205 msg = f"[ {fname} ] <dbg {_COUNTER}>" 206 _COUNTER += 1 207 else: 208 # if expression passed, format its value and show location, expr, and value 209 exp_val: str = formatter(exp) if formatter else repr(exp) 210 msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}" 211 212 # print the message 213 print( 214 msg, 215 file=sys.stderr, 216 ) 217 218 # return the expression itself 219 return exp
Call dbg with any variable or expression.
Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:
from muutils.dbg import dbg
a = 2
b = 5
dbg(a+b)
def square(x: int) -> int:
return x * x
dbg(square(a))
243def tensor_info(tensor: typing.Any) -> str: 244 from muutils.tensor_info import array_summary 245 246 # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread) 247 return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS) # type: ignore[call-overload]
264def list_info( 265 lst: typing.List[typing.Any], 266) -> str: 267 len_l: int = len(lst) 268 output: str 269 if len_l > DBG_LIST_DEFAULTS["max_len"]: 270 output = f"<list of len()={len_l}" 271 if DBG_LIST_DEFAULTS["summary_show_types"]: 272 val_types: typing.Set[str] = set(type(x).__name__ for x in lst) 273 output += f", types={{{', '.join(sorted(val_types))}}}" 274 output += ">" 275 else: 276 output = "[" + ", ".join(repr(x) for x in lst) + "]" 277 278 return output
287def dict_info( 288 d: typing.Dict[typing.Any, typing.Any], 289 depth: int = 0, 290) -> str: 291 len_d: int = len(d) 292 indent: str = DBG_DICT_DEFAULTS["indent"] 293 294 # summary line 295 output: str = f"{indent * depth}<dict of len()={len_d}" 296 297 if DBG_DICT_DEFAULTS["key_types"] and len_d > 0: 298 key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys()) 299 key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}" 300 output += f", key_types={key_types_str}" 301 302 if DBG_DICT_DEFAULTS["val_types"] and len_d > 0: 303 val_types: typing.Set[str] = set(type(v).__name__ for v in d.values()) 304 val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}" 305 output += f", val_types={val_types_str}" 306 307 output += ">" 308 309 # keys/values if not to deep and not too many 310 if depth < DBG_DICT_DEFAULTS["max_depth"]: 311 if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]: 312 for k, v in d.items(): 313 key_str: str = repr(k) if not isinstance(k, str) else k 314 315 val_str: str 316 val_type_str: str = str(type(v)) 317 if isinstance(v, dict): 318 val_str = dict_info(v, depth + 1) 319 elif val_type_str in TENSOR_STR_TYPES: 320 val_str = tensor_info(v) 321 elif isinstance(v, list): 322 val_str = list_info(v) 323 else: 324 val_str = repr(v) 325 326 output += ( 327 f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}" 328 ) 329 330 return output
333def info_auto( 334 obj: typing.Any, 335) -> str: 336 """Automatically format an object for debugging.""" 337 if isinstance(obj, dict): 338 return dict_info(obj) 339 elif isinstance(obj, list): 340 return list_info(obj) 341 elif str(type(obj)) in TENSOR_STR_TYPES: 342 return tensor_info(obj) 343 else: 344 return repr(obj)
Automatically format an object for debugging.
347def dbg_tensor( 348 tensor: _ExpType, # numpy array or torch tensor 349) -> _ExpType: 350 """dbg function for tensors, using tensor_info formatter.""" 351 return dbg( 352 tensor, 353 formatter=tensor_info, 354 val_joiner=DBG_TENSOR_VAL_JOINER, 355 )
dbg function for tensors, using tensor_info formatter.
358def dbg_dict( 359 d: _ExpType_dict, 360) -> _ExpType_dict: 361 """dbg function for dictionaries, using dict_info formatter.""" 362 return dbg( 363 d, 364 formatter=dict_info, 365 val_joiner=DBG_TENSOR_VAL_JOINER, 366 )
dbg function for dictionaries, using dict_info formatter.
369def dbg_auto( 370 obj: _ExpType, 371) -> _ExpType: 372 """dbg function for automatic formatting based on type.""" 373 return dbg( 374 obj, 375 formatter=info_auto, 376 val_joiner=DBG_TENSOR_VAL_JOINER, 377 )
dbg function for automatic formatting based on type.
407def grep_repr( 408 obj: typing.Any, 409 pattern: str | re.Pattern[str], 410 *, 411 char_context: int | None = 20, 412 line_context: int | None = None, 413 before_context: int = 0, 414 after_context: int = 0, 415 context: int | None = None, 416 max_count: int | None = None, 417 cased: bool = False, 418 loose: bool = False, 419 line_numbers: bool = False, 420 highlight: bool = True, 421 color: str = "31", 422 separator: str = "--", 423 quiet: bool = False, 424) -> typing.List[str] | None: 425 """grep-like search on ``repr(obj)`` with improved grep-style options. 426 427 By default, string patterns are case-insensitive. Pre-compiled regex 428 patterns use their own flags. 429 430 Parameters: 431 - obj: Object to search (its repr() string is scanned) 432 - pattern: Regular expression pattern (string or pre-compiled) 433 - char_context: Characters of context before/after each match (default: 20) 434 - line_context: Lines of context before/after; overrides char_context 435 - before_context: Lines of context before match (like grep -B) 436 - after_context: Lines of context after match (like grep -A) 437 - context: Lines of context before AND after (like grep -C) 438 - max_count: Stop after this many matches 439 - cased: Force case-sensitive search for string patterns 440 - loose: Normalize spaces/punctuation for flexible matching 441 - line_numbers: Show line numbers in output 442 - highlight: Wrap matches with ANSI color codes 443 - color: ANSI color code (default: "31" for red) 444 - separator: Separator between multiple matches 445 - quiet: Return results instead of printing 446 447 Returns: 448 - None if quiet=False (prints to stdout) 449 - List[str] if quiet=True (returns formatted output lines) 450 """ 451 # Handle context parameter shortcuts 452 if context is not None: 453 before_context = after_context = context 454 455 # Prepare text and pattern 456 text: str = repr(obj) 457 if loose: 458 text = _normalize_for_loose(text) 459 460 regex: re.Pattern[str] = 
_compile_pattern(pattern, cased=cased, loose=loose) 461 462 def _color_match(segment: str) -> str: 463 if not highlight: 464 return segment 465 return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment) 466 467 output_lines: list[str] = [] 468 match_count: int = 0 469 470 # Determine if we're using line-based context 471 using_line_context = ( 472 line_context is not None or before_context > 0 or after_context > 0 473 ) 474 475 if using_line_context: 476 lines: list[str] = text.splitlines() 477 line_starts: list[int] = [] 478 pos: int = 0 479 for line in lines: 480 line_starts.append(pos) 481 pos += len(line) + 1 # +1 for newline 482 483 processed_lines: set[int] = set() 484 485 for match in regex.finditer(text): 486 if max_count is not None and match_count >= max_count: 487 break 488 489 # Find which line contains this match 490 match_line = max( 491 i for i, start in enumerate(line_starts) if start <= match.start() 492 ) 493 494 # Calculate context range 495 ctx_before: int 496 ctx_after: int 497 if line_context is not None: 498 ctx_before = ctx_after = line_context 499 else: 500 ctx_before, ctx_after = before_context, after_context 501 502 start_line: int = max(0, match_line - ctx_before) 503 end_line: int = min(len(lines), match_line + ctx_after + 1) 504 505 # Avoid duplicate output for overlapping contexts 506 line_range: set[int] = set(range(start_line, end_line)) 507 if line_range & processed_lines: 508 continue 509 processed_lines.update(line_range) 510 511 # Format the context block 512 context_lines: list[str] = [] 513 for i in range(start_line, end_line): 514 line_text = lines[i] 515 if line_numbers: 516 line_prefix = f"{i + 1}:" 517 line_text = f"{line_prefix}{line_text}" 518 context_lines.append(_color_match(line_text)) 519 520 if output_lines and separator: 521 output_lines.append(separator) 522 output_lines.extend(context_lines) 523 match_count += 1 524 525 else: 526 # Character-based context 527 ctx: int = 0 if char_context is None 
else char_context 528 529 for match in regex.finditer(text): 530 if max_count is not None and match_count >= max_count: 531 break 532 533 start: int = max(0, match.start() - ctx) 534 end: int = min(len(text), match.end() + ctx) 535 snippet: str = text[start:end] 536 537 if output_lines and separator: 538 output_lines.append(separator) 539 output_lines.append(_color_match(snippet)) 540 match_count += 1 541 542 if quiet: 543 return output_lines 544 else: 545 for line in output_lines: 546 print(line) 547 return None
grep-like search on repr(obj) with improved grep-style options.
By default, string patterns are case-insensitive. Pre-compiled regex patterns use their own flags.
Parameters:
- obj: Object to search (its repr() string is scanned)
- pattern: Regular expression pattern (string or pre-compiled)
- char_context: Characters of context before/after each match (default: 20)
- line_context: Lines of context before/after; overrides char_context
- before_context: Lines of context before match (like grep -B)
- after_context: Lines of context after match (like grep -A)
- context: Lines of context before AND after (like grep -C)
- max_count: Stop after this many matches
- cased: Force case-sensitive search for string patterns
- loose: Normalize spaces/punctuation for flexible matching
- line_numbers: Show line numbers in output
- highlight: Wrap matches with ANSI color codes
- color: ANSI color code (default: "31" for red)
- separator: Separator between multiple matches
- quiet: Return results instead of printing
Returns:
- None if quiet=False (prints to stdout)
- List[str] if quiet=True (returns formatted output lines)