docs for muutils v0.9.1
View Source on GitHub

muutils.misc.func


  1from __future__ import annotations
  2import functools
  3import sys
  4from types import CodeType
  5import warnings
  6from typing import Any, Callable, Tuple, cast, TypeVar
  7
  8# TODO: we do a lot of type weirdness here that basedpyright doesn't like
  9# pyright: reportInvalidTypeForm=false
 10
 11try:
 12    if sys.version_info >= (3, 11):
 13        # 3.11+
 14        from typing import Unpack, TypeVarTuple, ParamSpec
 15    else:
 16        # 3.9+
 17        from typing_extensions import Unpack, TypeVarTuple, ParamSpec  # type: ignore[assignment]
 18except ImportError:
 19    warnings.warn(
 20        "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work"
 21    )
 22    ParamSpec = TypeVar  # type: ignore
 23    Unpack = Any  # type: ignore
 24    TypeVarTuple = TypeVar  # type: ignore
 25
 26
 27from muutils.errormode import ErrorMode
 28
 29warnings.warn("muutils.misc.func is experimental, use with caution")
 30
 31ReturnType = TypeVar("ReturnType")
 32T_kwarg = TypeVar("T_kwarg")
 33T_process_in = TypeVar("T_process_in")
 34T_process_out = TypeVar("T_process_out")
 35
 36FuncParams = ParamSpec("FuncParams")
 37FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap")
 38
 39
 40def process_kwarg(
 41    kwarg_name: str,
 42    processor: Callable[[T_process_in], T_process_out],
 43) -> Callable[
 44    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
 45]:
 46    """Decorator that applies a processor to a keyword argument.
 47
 48    The underlying function is expected to have a keyword argument
 49    (with name `kwarg_name`) of type `T_out`, but the caller provides
 50    a value of type `T_in` that is converted via `processor`.
 51
 52    # Parameters:
 53     - `kwarg_name : str`
 54        The name of the keyword argument to process.
 55     - `processor : Callable[[T_in], T_out]`
 56        A callable that converts the input value (`T_in`) into the
 57        type expected by the function (`T_out`).
 58
 59    # Returns:
 60     - A decorator that converts a function of type
 61       `Callable[OutputParams, ReturnType]` (expecting `kwarg_name` of type `T_out`)
 62       into one of type `Callable[InputParams, ReturnType]` (accepting `kwarg_name` of type `T_in`).
 63    """
 64
 65    def decorator(
 66        func: Callable[FuncParamsPreWrap, ReturnType],
 67    ) -> Callable[FuncParams, ReturnType]:
 68        @functools.wraps(func)
 69        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
 70            if kwarg_name in kwargs:
 71                # Convert the caller’s value (of type T_in) to T_out
 72                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
 73            return func(*args, **kwargs)  # type: ignore[arg-type]
 74
 75        return cast(Callable[FuncParams, ReturnType], wrapper)  # ty: ignore[invalid-type-form]
 76
 77    return decorator
 78
 79
 80# TYPING: error: Argument of type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" cannot be assigned to parameter of type "() -> ReturnType@process_kwarg"
 81# Type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" is not assignable to type "() -> ReturnType@process_kwarg"
 82#   Extra parameter "kwarg_name"
 83#   Extra parameter "validator" (reportArgumentType)
 84@process_kwarg("action", ErrorMode.from_any)  # pyright: ignore[reportArgumentType]
 85def validate_kwarg(
 86    kwarg_name: str,
 87    validator: Callable[[T_kwarg], bool],
 88    description: str | None = None,
 89    action: ErrorMode = ErrorMode.EXCEPT,
 90) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
 91    """Decorator that validates a specific keyword argument.
 92
 93    # Parameters:
 94     - `kwarg_name : str`
 95        The name of the keyword argument to validate.
 96     - `validator : Callable[[Any], bool]`
 97        A callable that returns True if the keyword argument is valid.
 98     - `description : str | None`
 99        A message template if validation fails.
100     - `action : str`
101        Either `"raise"` (default) or `"warn"`.
102
103    # Returns:
104     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
105        A decorator that validates the keyword argument.
106
107    # Modifies:
108     - If validation fails and `action=="warn"`, emits a warning.
109       Otherwise, raises a ValueError.
110
111    # Usage:
112
113    ```python
114    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
115    def my_func(x: int) -> int:
116        return x
117
118    assert my_func(x=1) == 1
119    ```
120
121    # Raises:
122     - `ValueError` if validation fails and `action == "raise"`.
123    """
124
125    def decorator(
126        func: Callable[FuncParams, ReturnType],
127    ) -> Callable[FuncParams, ReturnType]:
128        @functools.wraps(func)
129        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
130            if kwarg_name in kwargs:
131                value: Any = kwargs[kwarg_name]
132                if not validator(value):  # ty: ignore[invalid-argument-type]
133                    msg: str = (
134                        description.format(kwarg_name=kwarg_name, value=value)
135                        if description
136                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
137                    )
138                    if action == "warn":
139                        warnings.warn(msg, UserWarning)
140                    else:
141                        raise ValueError(msg)
142            return func(*args, **kwargs)
143
144        return cast(Callable[FuncParams, ReturnType], wrapper)
145
146    return decorator
147
148
149def replace_kwarg(
150    kwarg_name: str,
151    check: Callable[[T_kwarg], bool],
152    replacement_value: T_kwarg,
153    replace_if_missing: bool = False,
154) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
155    """Decorator that replaces a specific keyword argument value by identity comparison.
156
157    # Parameters:
158     - `kwarg_name : str`
159        The name of the keyword argument to replace.
160     - `check : Callable[[T_kwarg], bool]`
161        A callable that returns True if the keyword argument should be replaced.
162     - `replacement_value : T_kwarg`
163        The value to replace with.
164     - `replace_if_missing : bool`
165        If True, replaces the keyword argument even if it's missing.
166
167    # Returns:
168     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
169        A decorator that replaces the keyword argument value.
170
171    # Modifies:
172     - Updates `kwargs[kwarg_name]` if its value is `default_value`.
173
174    # Usage:
175
176    ```python
177    @replace_kwarg("x", None, "default_string")
178    def my_func(*, x: str | None = None) -> str:
179        return x
180
181    assert my_func(x=None) == "default_string"
182    ```
183    """
184
185    def decorator(
186        func: Callable[FuncParams, ReturnType],
187    ) -> Callable[FuncParams, ReturnType]:
188        @functools.wraps(func)
189        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
190            if kwarg_name in kwargs:
191                # TODO: no way to type hint this, I think
192                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
193                    kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
194            elif replace_if_missing and kwarg_name not in kwargs:
195                kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
196            return func(*args, **kwargs)
197
198        return cast(Callable[FuncParams, ReturnType], wrapper)
199
200    return decorator
201
202
203def is_none(value: Any) -> bool:
204    return value is None
205
206
207def always_true(value: Any) -> bool:
208    return True
209
210
211def always_false(value: Any) -> bool:
212    return False
213
214
215def format_docstring(
216    **fmt_kwargs: Any,
217) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
218    """Decorator that formats a function's docstring with the provided keyword arguments."""
219
220    def decorator(
221        func: Callable[FuncParams, ReturnType],
222    ) -> Callable[FuncParams, ReturnType]:
223        if func.__doc__ is not None:
224            func.__doc__ = func.__doc__.format(**fmt_kwargs)
225        return func
226
227    return decorator
228
229
230# TODO: no way to make the type system understand this afaik
231LambdaArgs = TypeVarTuple("LambdaArgs")
232LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...])
233
234
235def typed_lambda(  # pyright: ignore[reportUnknownParameterType]
236    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
237    in_types: LambdaArgsTypes,  # pyright: ignore[reportInvalidTypeVarUse]
238    out_type: type[ReturnType],
239) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
240    """Wraps a lambda function with type hints.
241
242    # Parameters:
243     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
244        The lambda function to wrap.
245     - `in_types : tuple[type, ...]`
246        Tuple of input types.
247     - `out_type : type[ReturnType]`
248        The output type.
249
250    # Returns:
251     - `Callable[..., ReturnType]`
252        A new function with annotations matching the given signature.
253
254    # Usage:
255
256    ```python
257    add = typed_lambda(lambda x, y: x + y, (int, int), int)
258    assert add(1, 2) == 3
259    ```
260
261    # Raises:
262     - `ValueError` if the number of input types doesn't match the lambda's parameters.
263    """
264    # it will just error here if fn.__code__ doesn't exist
265    code: CodeType = fn.__code__  # type: ignore[unresolved-attribute]
266    n_params: int = code.co_argcount
267
268    if len(in_types) != n_params:
269        raise ValueError(
270            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
271        )
272
273    param_names: tuple[str, ...] = code.co_varnames[:n_params]
274    annotations: dict[str, type] = {  # type: ignore[var-annotated]
275        name: typ
276        for name, typ in zip(param_names, in_types)  # type: ignore[arg-type]
277    }
278    annotations["return"] = out_type
279
280    @functools.wraps(fn)
281    def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
282        return fn(*args)
283
284    wrapped.__annotations__ = annotations
285    return wrapped

FuncParams = ~FuncParams
FuncParamsPreWrap = ~FuncParamsPreWrap
def process_kwarg( kwarg_name: str, processor: Callable[[~T_process_in], ~T_process_out]) -> Callable[[Callable[~FuncParamsPreWrap, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
41def process_kwarg(
42    kwarg_name: str,
43    processor: Callable[[T_process_in], T_process_out],
44) -> Callable[
45    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
46]:
47    """Decorator that applies a processor to a keyword argument.
48
49    The underlying function is expected to have a keyword argument
50    (with name `kwarg_name`) of type `T_out`, but the caller provides
51    a value of type `T_in` that is converted via `processor`.
52
53    # Parameters:
54     - `kwarg_name : str`
55        The name of the keyword argument to process.
56     - `processor : Callable[[T_in], T_out]`
57        A callable that converts the input value (`T_in`) into the
58        type expected by the function (`T_out`).
59
60    # Returns:
61     - A decorator that converts a function of type
62       `Callable[OutputParams, ReturnType]` (expecting `kwarg_name` of type `T_out`)
63       into one of type `Callable[InputParams, ReturnType]` (accepting `kwarg_name` of type `T_in`).
64    """
65
66    def decorator(
67        func: Callable[FuncParamsPreWrap, ReturnType],
68    ) -> Callable[FuncParams, ReturnType]:
69        @functools.wraps(func)
70        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
71            if kwarg_name in kwargs:
72                # Convert the caller’s value (of type T_in) to T_out
73                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
74            return func(*args, **kwargs)  # type: ignore[arg-type]
75
76        return cast(Callable[FuncParams, ReturnType], wrapper)  # ty: ignore[invalid-type-form]
77
78    return decorator

Decorator that applies a processor to a keyword argument.

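The underlying function is expected to have a keyword argument (with name kwarg_name) of type T_out, but the caller provides a value of type T_in that is converted via processor. The conversion is applied only when kwarg_name is passed explicitly as a keyword argument; positional values and the wrapped function's own defaults are passed through unchanged.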

Parameters:

  • kwarg_name : str The name of the keyword argument to process.
  • processor : Callable[[T_in], T_out] A callable that converts the input value (T_in) into the type expected by the function (T_out).

Returns:

  • A decorator that converts a function of type Callable[FuncParamsPreWrap, ReturnType] (expecting kwarg_name of type T_out) into one of type Callable[FuncParams, ReturnType] (accepting kwarg_name of type T_in).
@process_kwarg('action', ErrorMode.from_any)
def validate_kwarg( kwarg_name: str, validator: Callable[[~T_kwarg], bool], description: str | None = None, action: muutils.errormode.ErrorMode = ErrorMode.Except) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
 85@process_kwarg("action", ErrorMode.from_any)  # pyright: ignore[reportArgumentType]
 86def validate_kwarg(
 87    kwarg_name: str,
 88    validator: Callable[[T_kwarg], bool],
 89    description: str | None = None,
 90    action: ErrorMode = ErrorMode.EXCEPT,
 91) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
 92    """Decorator that validates a specific keyword argument.
 93
 94    # Parameters:
 95     - `kwarg_name : str`
 96        The name of the keyword argument to validate.
 97     - `validator : Callable[[Any], bool]`
 98        A callable that returns True if the keyword argument is valid.
 99     - `description : str | None`
100        A message template if validation fails.
101     - `action : str`
102        Either `"raise"` (default) or `"warn"`.
103
104    # Returns:
105     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
106        A decorator that validates the keyword argument.
107
108    # Modifies:
109     - If validation fails and `action=="warn"`, emits a warning.
110       Otherwise, raises a ValueError.
111
112    # Usage:
113
114    ```python
115    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
116    def my_func(x: int) -> int:
117        return x
118
119    assert my_func(x=1) == 1
120    ```
121
122    # Raises:
123     - `ValueError` if validation fails and `action == "raise"`.
124    """
125
126    def decorator(
127        func: Callable[FuncParams, ReturnType],
128    ) -> Callable[FuncParams, ReturnType]:
129        @functools.wraps(func)
130        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
131            if kwarg_name in kwargs:
132                value: Any = kwargs[kwarg_name]
133                if not validator(value):  # ty: ignore[invalid-argument-type]
134                    msg: str = (
135                        description.format(kwarg_name=kwarg_name, value=value)
136                        if description
137                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
138                    )
139                    if action == "warn":
140                        warnings.warn(msg, UserWarning)
141                    else:
142                        raise ValueError(msg)
143            return func(*args, **kwargs)
144
145        return cast(Callable[FuncParams, ReturnType], wrapper)
146
147    return decorator

Decorator that validates a specific keyword argument.

Parameters:

  • kwarg_name : str The name of the keyword argument to validate.
  • validator : Callable[[Any], bool] A callable that returns True if the keyword argument is valid.
  • description : str | None A message template if validation fails.
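  • action : ErrorMode How to respond when validation fails; defaults to ErrorMode.EXCEPT, which raises. When passed as a keyword, the value is first converted with ErrorMode.from_any. A warning is emitted instead of an exception only when action == "warn" holds.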

Returns:

  • Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]] A decorator that validates the keyword argument.

Modifies:

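  • If validation fails and action == "warn", emits a UserWarning; if validation fails with any other action, raises a ValueError. Arguments that pass validation, or that are not passed as keyword arguments, are left untouched.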

Usage:

@validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
def my_func(x: int) -> int:
    return x

assert my_func(x=1) == 1

Raises:

  • ValueError if validation fails and action is anything other than "warn" (including the default).
def replace_kwarg( kwarg_name: str, check: Callable[[~T_kwarg], bool], replacement_value: ~T_kwarg, replace_if_missing: bool = False) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
150def replace_kwarg(
151    kwarg_name: str,
152    check: Callable[[T_kwarg], bool],
153    replacement_value: T_kwarg,
154    replace_if_missing: bool = False,
155) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
156    """Decorator that replaces a specific keyword argument value by identity comparison.
157
158    # Parameters:
159     - `kwarg_name : str`
160        The name of the keyword argument to replace.
161     - `check : Callable[[T_kwarg], bool]`
162        A callable that returns True if the keyword argument should be replaced.
163     - `replacement_value : T_kwarg`
164        The value to replace with.
165     - `replace_if_missing : bool`
166        If True, replaces the keyword argument even if it's missing.
167
168    # Returns:
169     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
170        A decorator that replaces the keyword argument value.
171
172    # Modifies:
173     - Updates `kwargs[kwarg_name]` if its value is `default_value`.
174
175    # Usage:
176
177    ```python
178    @replace_kwarg("x", None, "default_string")
179    def my_func(*, x: str | None = None) -> str:
180        return x
181
182    assert my_func(x=None) == "default_string"
183    ```
184    """
185
186    def decorator(
187        func: Callable[FuncParams, ReturnType],
188    ) -> Callable[FuncParams, ReturnType]:
189        @functools.wraps(func)
190        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
191            if kwarg_name in kwargs:
192                # TODO: no way to type hint this, I think
193                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
194                    kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
195            elif replace_if_missing and kwarg_name not in kwargs:
196                kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
197            return func(*args, **kwargs)
198
199        return cast(Callable[FuncParams, ReturnType], wrapper)
200
201    return decorator

Decorator that replaces a specific keyword argument's value whenever the check predicate returns True for it.

Parameters:

  • kwarg_name : str The name of the keyword argument to replace.
  • check : Callable[[T_kwarg], bool] A callable that returns True if the keyword argument should be replaced.
  • replacement_value : T_kwarg The value to replace with.
  • replace_if_missing : bool If True, also sets the keyword argument to replacement_value when the caller does not pass it at all.

Returns:

  • Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]] A decorator that replaces the keyword argument value.

Modifies:

  • Sets kwargs[kwarg_name] to replacement_value if check(kwargs[kwarg_name]) returns True, or if the argument is missing and replace_if_missing is True.

Usage:

@replace_kwarg("x", is_none, "default_string")
def my_func(*, x: str | None = None) -> str:
    return x

assert my_func(x=None) == "default_string"
def is_none(value: Any) -> bool:
204def is_none(value: Any) -> bool:
205    return value is None
def always_true(value: Any) -> bool:
208def always_true(value: Any) -> bool:
209    return True
def always_false(value: Any) -> bool:
212def always_false(value: Any) -> bool:
213    return False
def format_docstring( **fmt_kwargs: Any) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
216def format_docstring(
217    **fmt_kwargs: Any,
218) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
219    """Decorator that formats a function's docstring with the provided keyword arguments."""
220
221    def decorator(
222        func: Callable[FuncParams, ReturnType],
223    ) -> Callable[FuncParams, ReturnType]:
224        if func.__doc__ is not None:
225            func.__doc__ = func.__doc__.format(**fmt_kwargs)
226        return func
227
228    return decorator

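Decorator that formats a function's docstring with the provided keyword arguments (via str.format, so literal braces in the docstring must be doubled). The original function object is returned, not a wrapper; functions without a docstring are returned unchanged.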

LambdaArgs = LambdaArgs
def typed_lambda( fn: Callable[[Unpack[LambdaArgs]], ~ReturnType], in_types: ~LambdaArgsTypes, out_type: type[~ReturnType]) -> Callable[[Unpack[LambdaArgs]], ~ReturnType]:
236def typed_lambda(  # pyright: ignore[reportUnknownParameterType]
237    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
238    in_types: LambdaArgsTypes,  # pyright: ignore[reportInvalidTypeVarUse]
239    out_type: type[ReturnType],
240) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
241    """Wraps a lambda function with type hints.
242
243    # Parameters:
244     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
245        The lambda function to wrap.
246     - `in_types : tuple[type, ...]`
247        Tuple of input types.
248     - `out_type : type[ReturnType]`
249        The output type.
250
251    # Returns:
252     - `Callable[..., ReturnType]`
253        A new function with annotations matching the given signature.
254
255    # Usage:
256
257    ```python
258    add = typed_lambda(lambda x, y: x + y, (int, int), int)
259    assert add(1, 2) == 3
260    ```
261
262    # Raises:
263     - `ValueError` if the number of input types doesn't match the lambda's parameters.
264    """
265    # it will just error here if fn.__code__ doesn't exist
266    code: CodeType = fn.__code__  # type: ignore[unresolved-attribute]
267    n_params: int = code.co_argcount
268
269    if len(in_types) != n_params:
270        raise ValueError(
271            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
272        )
273
274    param_names: tuple[str, ...] = code.co_varnames[:n_params]
275    annotations: dict[str, type] = {  # type: ignore[var-annotated]
276        name: typ
277        for name, typ in zip(param_names, in_types)  # type: ignore[arg-type]
278    }
279    annotations["return"] = out_type
280
281    @functools.wraps(fn)
282    def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
283        return fn(*args)
284
285    wrapped.__annotations__ = annotations
286    return wrapped

Wraps a lambda function with type hints.

Parameters:

  • fn : Callable[[Unpack[LambdaArgs]], ReturnType] The lambda function to wrap.
  • in_types : tuple[type, ...] Tuple of input types.
  • out_type : type[ReturnType] The output type.

Returns:

  • Callable[..., ReturnType] A new function with annotations matching the given signature.

Usage:

add = typed_lambda(lambda x, y: x + y, (int, int), int)
assert add(1, 2) == 3

Raises:

  • ValueError if the number of input types doesn't match the lambda's parameters.