muutils.misc.func
1from __future__ import annotations 2import functools 3import sys 4from types import CodeType 5import warnings 6from typing import Any, Callable, Tuple, cast, TypeVar 7 8# TODO: we do a lot of type weirdness here that basedpyright doesn't like 9# pyright: reportInvalidTypeForm=false 10 11try: 12 if sys.version_info >= (3, 11): 13 # 3.11+ 14 from typing import Unpack, TypeVarTuple, ParamSpec 15 else: 16 # 3.9+ 17 from typing_extensions import Unpack, TypeVarTuple, ParamSpec # type: ignore[assignment] 18except ImportError: 19 warnings.warn( 20 "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work" 21 ) 22 ParamSpec = TypeVar # type: ignore 23 Unpack = Any # type: ignore 24 TypeVarTuple = TypeVar # type: ignore 25 26 27from muutils.errormode import ErrorMode 28 29warnings.warn("muutils.misc.func is experimental, use with caution") 30 31ReturnType = TypeVar("ReturnType") 32T_kwarg = TypeVar("T_kwarg") 33T_process_in = TypeVar("T_process_in") 34T_process_out = TypeVar("T_process_out") 35 36FuncParams = ParamSpec("FuncParams") 37FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap") 38 39 40def process_kwarg( 41 kwarg_name: str, 42 processor: Callable[[T_process_in], T_process_out], 43) -> Callable[ 44 [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType] 45]: 46 """Decorator that applies a processor to a keyword argument. 47 48 The underlying function is expected to have a keyword argument 49 (with name `kwarg_name`) of type `T_out`, but the caller provides 50 a value of type `T_in` that is converted via `processor`. 51 52 # Parameters: 53 - `kwarg_name : str` 54 The name of the keyword argument to process. 55 - `processor : Callable[[T_in], T_out]` 56 A callable that converts the input value (`T_in`) into the 57 type expected by the function (`T_out`). 
58 59 # Returns: 60 - A decorator that converts a function of type 61 `Callable[OutputParams, ReturnType]` (expecting `kwarg_name` of type `T_out`) 62 into one of type `Callable[InputParams, ReturnType]` (accepting `kwarg_name` of type `T_in`). 63 """ 64 65 def decorator( 66 func: Callable[FuncParamsPreWrap, ReturnType], 67 ) -> Callable[FuncParams, ReturnType]: 68 @functools.wraps(func) 69 def wrapper(*args: Any, **kwargs: Any) -> ReturnType: 70 if kwarg_name in kwargs: 71 # Convert the caller’s value (of type T_in) to T_out 72 kwargs[kwarg_name] = processor(kwargs[kwarg_name]) 73 return func(*args, **kwargs) # type: ignore[arg-type] 74 75 return cast(Callable[FuncParams, ReturnType], wrapper) # ty: ignore[invalid-type-form] 76 77 return decorator 78 79 80# TYPING: error: Argument of type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" cannot be assigned to parameter of type "() -> ReturnType@process_kwarg" 81# Type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" is not assignable to type "() -> ReturnType@process_kwarg" 82# Extra parameter "kwarg_name" 83# Extra parameter "validator" (reportArgumentType) 84@process_kwarg("action", ErrorMode.from_any) # pyright: ignore[reportArgumentType] 85def validate_kwarg( 86 kwarg_name: str, 87 validator: Callable[[T_kwarg], bool], 88 description: str | None = None, 89 action: ErrorMode = ErrorMode.EXCEPT, 90) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]: 91 """Decorator that validates a specific keyword argument. 92 93 # Parameters: 94 - `kwarg_name : str` 95 The name of the keyword argument to validate. 
96 - `validator : Callable[[Any], bool]` 97 A callable that returns True if the keyword argument is valid. 98 - `description : str | None` 99 A message template if validation fails. 100 - `action : str` 101 Either `"raise"` (default) or `"warn"`. 102 103 # Returns: 104 - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]` 105 A decorator that validates the keyword argument. 106 107 # Modifies: 108 - If validation fails and `action=="warn"`, emits a warning. 109 Otherwise, raises a ValueError. 110 111 # Usage: 112 113 ```python 114 @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}") 115 def my_func(x: int) -> int: 116 return x 117 118 assert my_func(x=1) == 1 119 ``` 120 121 # Raises: 122 - `ValueError` if validation fails and `action == "raise"`. 123 """ 124 125 def decorator( 126 func: Callable[FuncParams, ReturnType], 127 ) -> Callable[FuncParams, ReturnType]: 128 @functools.wraps(func) 129 def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: # pyright: ignore[reportUnknownParameterType] 130 if kwarg_name in kwargs: 131 value: Any = kwargs[kwarg_name] 132 if not validator(value): # ty: ignore[invalid-argument-type] 133 msg: str = ( 134 description.format(kwarg_name=kwarg_name, value=value) 135 if description 136 else f"Validation failed for keyword '{kwarg_name}' with value {value}" 137 ) 138 if action == "warn": 139 warnings.warn(msg, UserWarning) 140 else: 141 raise ValueError(msg) 142 return func(*args, **kwargs) 143 144 return cast(Callable[FuncParams, ReturnType], wrapper) 145 146 return decorator 147 148 149def replace_kwarg( 150 kwarg_name: str, 151 check: Callable[[T_kwarg], bool], 152 replacement_value: T_kwarg, 153 replace_if_missing: bool = False, 154) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]: 155 """Decorator that replaces a specific keyword argument value by identity comparison. 
156 157 # Parameters: 158 - `kwarg_name : str` 159 The name of the keyword argument to replace. 160 - `check : Callable[[T_kwarg], bool]` 161 A callable that returns True if the keyword argument should be replaced. 162 - `replacement_value : T_kwarg` 163 The value to replace with. 164 - `replace_if_missing : bool` 165 If True, replaces the keyword argument even if it's missing. 166 167 # Returns: 168 - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]` 169 A decorator that replaces the keyword argument value. 170 171 # Modifies: 172 - Updates `kwargs[kwarg_name]` if its value is `default_value`. 173 174 # Usage: 175 176 ```python 177 @replace_kwarg("x", None, "default_string") 178 def my_func(*, x: str | None = None) -> str: 179 return x 180 181 assert my_func(x=None) == "default_string" 182 ``` 183 """ 184 185 def decorator( 186 func: Callable[FuncParams, ReturnType], 187 ) -> Callable[FuncParams, ReturnType]: 188 @functools.wraps(func) 189 def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: # pyright: ignore[reportUnknownParameterType] 190 if kwarg_name in kwargs: 191 # TODO: no way to type hint this, I think 192 if check(kwargs[kwarg_name]): # type: ignore[arg-type] 193 kwargs[kwarg_name] = replacement_value # ty: ignore[invalid-assignment] 194 elif replace_if_missing and kwarg_name not in kwargs: 195 kwargs[kwarg_name] = replacement_value # ty: ignore[invalid-assignment] 196 return func(*args, **kwargs) 197 198 return cast(Callable[FuncParams, ReturnType], wrapper) 199 200 return decorator 201 202 203def is_none(value: Any) -> bool: 204 return value is None 205 206 207def always_true(value: Any) -> bool: 208 return True 209 210 211def always_false(value: Any) -> bool: 212 return False 213 214 215def format_docstring( 216 **fmt_kwargs: Any, 217) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]: 218 """Decorator that formats a function's docstring with the provided keyword 
arguments.""" 219 220 def decorator( 221 func: Callable[FuncParams, ReturnType], 222 ) -> Callable[FuncParams, ReturnType]: 223 if func.__doc__ is not None: 224 func.__doc__ = func.__doc__.format(**fmt_kwargs) 225 return func 226 227 return decorator 228 229 230# TODO: no way to make the type system understand this afaik 231LambdaArgs = TypeVarTuple("LambdaArgs") 232LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...]) 233 234 235def typed_lambda( # pyright: ignore[reportUnknownParameterType] 236 fn: Callable[[Unpack[LambdaArgs]], ReturnType], 237 in_types: LambdaArgsTypes, # pyright: ignore[reportInvalidTypeVarUse] 238 out_type: type[ReturnType], 239) -> Callable[[Unpack[LambdaArgs]], ReturnType]: 240 """Wraps a lambda function with type hints. 241 242 # Parameters: 243 - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]` 244 The lambda function to wrap. 245 - `in_types : tuple[type, ...]` 246 Tuple of input types. 247 - `out_type : type[ReturnType]` 248 The output type. 249 250 # Returns: 251 - `Callable[..., ReturnType]` 252 A new function with annotations matching the given signature. 253 254 # Usage: 255 256 ```python 257 add = typed_lambda(lambda x, y: x + y, (int, int), int) 258 assert add(1, 2) == 3 259 ``` 260 261 # Raises: 262 - `ValueError` if the number of input types doesn't match the lambda's parameters. 263 """ 264 # it will just error here if fn.__code__ doesn't exist 265 code: CodeType = fn.__code__ # type: ignore[unresolved-attribute] 266 n_params: int = code.co_argcount 267 268 if len(in_types) != n_params: 269 raise ValueError( 270 f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})" 271 ) 272 273 param_names: tuple[str, ...] 
= code.co_varnames[:n_params] 274 annotations: dict[str, type] = { # type: ignore[var-annotated] 275 name: typ 276 for name, typ in zip(param_names, in_types) # type: ignore[arg-type] 277 } 278 annotations["return"] = out_type 279 280 @functools.wraps(fn) 281 def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType: # pyright: ignore[reportUnknownParameterType] 282 return fn(*args) 283 284 wrapped.__annotations__ = annotations 285 return wrapped
FuncParams =
~FuncParams
FuncParamsPreWrap =
~FuncParamsPreWrap
def
process_kwarg( kwarg_name: str, processor: Callable[[~T_process_in], ~T_process_out]) -> Callable[[Callable[~FuncParamsPreWrap, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Decorator that applies a processor to a keyword argument.

    The underlying function is expected to have a keyword argument
    (with name `kwarg_name`) of type `T_process_out`. The caller instead provides
    a value of type `T_process_in`, which is converted via `processor`.

    Only values passed *by keyword* are converted. A value supplied
    positionally is passed through untouched, and so is an omitted keyword
    (the wrapped function's own default is then used).

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to process.
     - `processor : Callable[[T_process_in], T_process_out]`
        A callable that converts the input value (`T_process_in`) into the
        type expected by the function (`T_process_out`).

    # Returns:
     - A decorator that converts a function of type
       `Callable[FuncParamsPreWrap, ReturnType]` (expecting `kwarg_name` of type `T_process_out`)
       into one of type `Callable[FuncParams, ReturnType]` (accepting `kwarg_name` of type `T_process_in`).
    """

    def decorator(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            if kwarg_name in kwargs:
                # Convert the caller's value (of type T_process_in) to T_process_out
                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
            return func(*args, **kwargs)  # type: ignore[arg-type]

        # The cast exists only for type checkers. The string form avoids
        # building a ParamSpec-subscripted `Callable` at runtime on every
        # decoration, which can fail when ParamSpec fell back to TypeVar.
        return cast("Callable[FuncParams, ReturnType]", wrapper)  # ty: ignore[invalid-type-form]

    return decorator
Decorator that applies a processor to a keyword argument.
The underlying function is expected to have a keyword argument
(with name kwarg_name) of type T_out, but the caller provides
a value of type T_in that is converted via processor.
Parameters:
kwarg_name : str — The name of the keyword argument to process. processor : Callable[[T_in], T_out] — A callable that converts the input value (T_in) into the type expected by the function (T_out).
Returns:
- A decorator that converts a function of type
Callable[OutputParams, ReturnType] (expecting kwarg_name of type T_out) into one of type Callable[InputParams, ReturnType] (accepting kwarg_name of type T_in).
@process_kwarg('action', ErrorMode.from_any)
def
validate_kwarg( kwarg_name: str, validator: Callable[[~T_kwarg], bool], description: str | None = None, action: muutils.errormode.ErrorMode = ErrorMode.Except) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
@process_kwarg("action", ErrorMode.from_any)  # pyright: ignore[reportArgumentType]
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[Any], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template used if validation fails. It is formatted with
        `kwarg_name` and `value`.
     - `action : ErrorMode | str`
        What to do on failure. `ErrorMode.WARN` (or `"warn"`) emits a
        `UserWarning`; any other value (default `ErrorMode.EXCEPT`) raises
        `ValueError`. When passed by keyword, the value is converted with
        `ErrorMode.from_any`.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails in warn mode, emits a warning.
       Otherwise, raises a ValueError.

    Only values passed *by keyword* to the decorated function are validated.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` if validation fails and `action` is not warn mode.
    """
    # `action` is an `ErrorMode` when passed by keyword (converted by
    # `process_kwarg`), but it stays a raw string when passed positionally.
    # Comparing an `ErrorMode` to "warn" alone is never true, so accept both.
    warn_on_failure: bool = action in (ErrorMode.WARN, "warn")

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):  # ty: ignore[invalid-argument-type]
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    if warn_on_failure:
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        # string-form cast: typing-only, avoids runtime subscripting of Callable
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
Decorator that validates a specific keyword argument.
Parameters:
kwarg_name : str — The name of the keyword argument to validate. validator : Callable[[Any], bool] — A callable that returns True if the keyword argument is valid. description : str | None — A message template if validation fails. action : str — Either "raise" (default) or "warn".
Returns:
Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]] — A decorator that validates the keyword argument.
Modifies:
- If validation fails and
action == "warn", emits a warning. Otherwise, raises a ValueError.
Usage:
@validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
def my_func(x: int) -> int:
    return x
assert my_func(x=1) == 1
Raises:
ValueError if validation fails and action == "raise".
def
replace_kwarg( kwarg_name: str, check: Callable[[~T_kwarg], bool], replacement_value: ~T_kwarg, replace_if_missing: bool = False) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a keyword argument's value when `check` approves it.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        Called with the passed value. If it returns True, the value is replaced.
        `is_none`, `always_true` and `always_false` are ready-made checks.
     - `replacement_value : T_kwarg`
        The value to substitute.
     - `replace_if_missing : bool`
        If True, also injects `replacement_value` when the keyword is not
        passed at all. Defaults to False.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name] = replacement_value` before calling the wrapped
       function in either of two cases: `check(kwargs[kwarg_name])` is True, or
       `replace_if_missing` is set and the keyword is absent. Values passed
       positionally are not inspected.

    # Usage:

    ```python
    @replace_kwarg("x", is_none, "default_string")
    def my_func(*, x: str | None = None) -> str | None:
        return x

    assert my_func(x=None) == "default_string"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
            elif replace_if_missing:
                # reaching the elif already implies the keyword is absent
                kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
            return func(*args, **kwargs)

        # string-form cast: typing-only, avoids runtime subscripting of Callable
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
Decorator that replaces a specific keyword argument's value whenever the given check returns True for it.
Parameters:
kwarg_name : str — The name of the keyword argument to replace. check : Callable[[T_kwarg], bool] — A callable that returns True if the keyword argument should be replaced. replacement_value : T_kwarg — The value to replace with. replace_if_missing : bool — If True, replaces the keyword argument even if it's missing.
Returns:
Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]] — A decorator that replaces the keyword argument value.
Modifies:
- Updates
kwargs[kwarg_name] to replacement_value if check returns True for its value (or, when replace_if_missing is True, if the keyword was not passed).
Usage:
@replace_kwarg("x", is_none, "default_string")
def my_func(*, x: str | None = None) -> str:
    return x
assert my_func(x=None) == "default_string"
def
is_none(value: Any) -> bool:
def
always_true(value: Any) -> bool:
def
always_false(value: Any) -> bool:
def
format_docstring( **fmt_kwargs: Any) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that formats a function's docstring with the provided keyword arguments.

    The decorated function's docstring is treated as a `str.format` template
    and rewritten in place with `fmt_kwargs`. The same function object is
    returned, with no wrapper. A function without a docstring is returned as-is.
    """

    def apply_format(
        target: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        template: str | None = target.__doc__
        # nothing to format: leave the function untouched
        if template is None:
            return target
        target.__doc__ = template.format(**fmt_kwargs)
        return target

    return apply_format
Decorator that formats a function's docstring with the provided keyword arguments.
LambdaArgs =
LambdaArgs
def
typed_lambda( fn: Callable[[Unpack[LambdaArgs]], ~ReturnType], in_types: ~LambdaArgsTypes, out_type: type[~ReturnType]) -> Callable[[Unpack[LambdaArgs]], ~ReturnType]:
def typed_lambda(  # pyright: ignore[reportUnknownParameterType]
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,  # pyright: ignore[reportInvalidTypeVarUse]
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda (or any plain Python function) to wrap. It must have a `__code__` attribute.
     - `in_types : tuple[type, ...]`
        One type per positional parameter of `fn`, in declaration order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[..., ReturnType]`
        A wrapper that forwards positional and keyword arguments to `fn`.
        Both its `__annotations__` and its `inspect.signature` carry the
        given input and output types.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    assert add(x=1, y=2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the number of
       positional parameters of `fn`.
    """
    import inspect  # only needed here, to attach an annotated signature

    # it will just error here if fn.__code__ doesn't exist
    code: CodeType = fn.__code__  # type: ignore[unresolved-attribute]
    # co_argcount counts positional parameters (incl. positional-only); it
    # excludes keyword-only parameters and *args / **kwargs
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # the first co_argcount entries of co_varnames are those positional names
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = dict(zip(param_names, in_types))  # type: ignore[arg-type]
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
        return fn(*args, **kwargs)

    wrapped.__annotations__ = annotations

    # `functools.wraps` sets `wrapped.__wrapped__`, so without this step
    # `inspect.signature` reports the lambda's un-annotated signature. An
    # explicit `__signature__` on the wrapper takes precedence.
    fn_sig: inspect.Signature = inspect.signature(fn, follow_wrapped=False)
    typed_sig: inspect.Signature = fn_sig.replace(
        parameters=[
            param.replace(annotation=annotations[param.name])
            if param.name in annotations
            else param
            for param in fn_sig.parameters.values()
        ],
        return_annotation=out_type,
    )
    wrapped.__signature__ = typed_sig  # type: ignore[attr-defined]
    return wrapped
Wraps a lambda function with type hints.
Parameters:
fn : Callable[[Unpack[LambdaArgs]], ReturnType] — The lambda function to wrap. in_types : tuple[type, ...] — Tuple of input types. out_type : type[ReturnType] — The output type.
Returns:
Callable[..., ReturnType] — A new function with annotations matching the given signature.
Usage:
add = typed_lambda(lambda x, y: x + y, (int, int), int)
assert add(1, 2) == 3
Raises:
ValueError if the number of input types doesn't match the lambda's parameters.