docs for muutils v0.9.1
View Source on GitHub

muutils.misc.classes


  1from __future__ import annotations
  2
  3from typing import (
  4    Iterable,
  5    Any,
  6    Protocol,
  7    ClassVar,
  8    runtime_checkable,
  9)
 10
 11from muutils.misc.sequence import flatten
 12
 13
 14def is_abstract(cls: type) -> bool:
 15    """
 16    Returns if a class is abstract.
 17    """
 18    if not hasattr(cls, "__abstractmethods__"):
 19        return False  # an ordinary class
 20    elif len(cls.__abstractmethods__) == 0:  # type: ignore[invalid-argument-type] # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
 21        return False  # a concrete implementation of an abstract class
 22    else:
 23        return True  # an abstract class
 24
 25
 26def get_all_subclasses(class_: type, include_self=False) -> set[type]:
 27    """
 28    Returns a set containing all child classes in the subclass graph of `class_`.
 29    I.e., includes subclasses of subclasses, etc.
 30
 31    # Parameters
 32    - `include_self`: Whether to include `class_` itself in the returned set
 33    - `class_`: Superclass
 34
 35    # Development
 36    Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic.
 37    It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.
 38    """
 39    subs: set[type] = set(
 40        flatten(
 41            get_all_subclasses(sub, include_self=True)
 42            for sub in class_.__subclasses__()
 43            if sub is not None
 44        )
 45    )
 46    if include_self:
 47        subs.add(class_)
 48    return subs
 49
 50
 51def isinstance_by_type_name(o: object, type_name: str):
 52    """Behaves like stdlib `isinstance` except it accepts a string representation of the type rather than the type itself.
 53    This is a hacky function intended to circumvent the need to import a type into a module.
 54    It is susceptible to type name collisions.
 55
 56    # Parameters
 57    `o`: Object (not the type itself) whose type to interrogate
 58    `type_name`: The string returned by `type_.__name__`.
 59    Generic types are not supported, only types that would appear in `type_.__mro__`.
 60    """
 61    return type_name in {s.__name__ for s in type(o).__mro__}
 62
 63
 64# dataclass magic
 65# --------------------------------------------------------------------------------
 66
 67
 68@runtime_checkable
 69class IsDataclass(Protocol):
 70    # Generic type for any dataclass instance
 71    # https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
 72    __dataclass_fields__: ClassVar[dict[str, Any]]  # pyright: ignore[reportExplicitAny]
 73
 74
 75def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any]:  # pyright: ignore[reportExplicitAny]
 76    """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.
 77    The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical.
 78    Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.
 79    """
 80    # TYPING: ty gives @Todo here
 81    return (  # type: ignore[invalid-return-type]
 82        *(
 83            getattr(dc, fld.name)
 84            for fld in filter(lambda x: x.compare, dc.__dataclass_fields__.values())
 85        ),
 86        type(dc),
 87    )
 88
 89
 90def dataclass_set_equals(
 91    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
 92) -> bool:
 93    """Compares 2 collections of dataclass instances as if they were sets.
 94    Duplicates are ignored in the same manner as a set.
 95    Unfrozen dataclasses can't be placed in sets since they're not hashable.
 96    Collections of them may be compared using this function.
 97    """
 98
 99    return {get_hashable_eq_attrs(x) for x in coll1} == {
100        get_hashable_eq_attrs(y) for y in coll2
101    }

def is_abstract(cls: type) -> bool:
def is_abstract(cls: type) -> bool:
    """
    Report whether `cls` is an abstract class.

    A class counts as abstract only when it has a non-empty
    `__abstractmethods__` collection. Classes without that attribute are
    ordinary classes. Classes whose `__abstractmethods__` is empty are concrete
    implementations of an abstract class. Neither counts as abstract.
    """
    try:
        pending_abstract = cls.__abstractmethods__  # pyright: ignore[reportUnknownMemberType]
    except AttributeError:
        # no `__abstractmethods__` at all: an ordinary class
        return False
    # an empty collection means every abstract method has been implemented
    return len(pending_abstract) != 0  # type: ignore[invalid-argument-type] # pyright: ignore[reportUnknownArgumentType]

Returns whether a class is abstract.

def get_all_subclasses(class_: type, include_self=False) -> set[type]:
def get_all_subclasses(class_: type, include_self=False) -> set[type]:
    """
    Returns a set containing all child classes in the subclass graph of `class_`.
    I.e., includes subclasses of subclasses, etc.

    # Parameters
    - `class_`: Superclass whose descendants are collected
    - `include_self`: Whether to include `class_` itself in the returned set

    # Returns
    Every currently alive class that directly or indirectly inherits from
    `class_`, as reported by `type.__subclasses__`. `class_` itself is added
    only when `include_self` is true.

    # Raises
    - `TypeError`: if `class_` is not a class

    # Development
    The hierarchy is walked iteratively with an explicit stack, so very deep
    hierarchies do not hit the interpreter's recursion limit. Each class is
    expanded only once, even when it is reachable through several bases
    (e.g. diamond inheritance).
    """
    found: set[type] = set()
    pending: list[type] = [class_]
    while pending:
        current = pending.pop()
        # Call the unbound `type.__subclasses__` instead of
        # `current.__subclasses__()`. When `current` is a metaclass such as
        # `type` itself, the attribute lookup yields the unbound method, and
        # calling it with no argument raises TypeError.
        for sub in type.__subclasses__(current):
            # The visited check prevents re-expanding shared descendants.
            # Classes are collected explicitly rather than through a generic
            # flatten helper, so a class object is never treated as an iterable.
            if sub not in found:
                found.add(sub)
                pending.append(sub)
    if include_self:
        found.add(class_)
    return found

Returns a set containing all child classes in the subclass graph of class_. I.e., includes subclasses of subclasses, etc.

Parameters

  • include_self: Whether to include class_ itself in the returned set
  • class_: Superclass

Development

Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic. It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.

def isinstance_by_type_name(o: object, type_name: str):
def isinstance_by_type_name(o: object, type_name: str):
    """Check whether `o` is an instance of a type identified only by its name.

    A string-based stand-in for the stdlib `isinstance`. It is intended for
    cases where importing the type itself into a module is undesirable. Only
    `__name__` is compared, so distinct types that share a name cannot be told
    apart (the check is susceptible to type name collisions).

    # Parameters
    - `o`: Object (not the type itself) whose type to interrogate
    - `type_name`: The string returned by `type_.__name__`.
      Generic types are not supported; only types that can appear in
      `type_.__mro__` are.

    # Returns
    `True` if some class in `type(o).__mro__` has `__name__ == type_name`.
    """
    mro_names = {base.__name__ for base in type(o).__mro__}
    return type_name in mro_names

Behaves like stdlib isinstance except it accepts a string representation of the type rather than the type itself. This is a hacky function intended to circumvent the need to import a type into a module. It is susceptible to type name collisions.

Parameters

  • o: Object (not the type itself) whose type to interrogate
  • type_name: The string returned by type_.__name__. Generic types are not supported; only types that can appear in type_.__mro__ are.

@runtime_checkable
class IsDataclass(typing.Protocol):
@runtime_checkable
class IsDataclass(Protocol):
    """
    Structural type describing any dataclass instance.

    Dataclasses carry a `__dataclass_fields__` mapping, so requiring that
    attribute is enough to describe "some dataclass" without naming a concrete
    one. See
    https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass

    Because the protocol is `@runtime_checkable`, `isinstance(x, IsDataclass)`
    only checks that `x` has a `__dataclass_fields__` attribute. The attribute's
    type is not verified, and dataclass classes themselves (not just their
    instances) also pass the check.
    """

    __dataclass_fields__: ClassVar[dict[str, Any]]  # pyright: ignore[reportExplicitAny]

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto[T](Protocol):
    def meth(self) -> T:
        ...
IsDataclass(*args, **kwargs)
def _no_init_or_replace_init(self, *args, **kwargs):
    # NOTE(review): this listing is not muutils code. It appears to be the
    # standard library `typing` module's source, rendered by pdoc as the
    # inherited constructor of `IsDataclass` (hence the ~1945 line numbers).
    #
    # Placeholder `__init__` for Protocol classes. It refuses to instantiate a
    # protocol class itself. For a concrete subclass, it looks up the first real
    # `__init__` in the MRO, installs it on the subclass, and then calls it.
    # Raises TypeError when `type(self)._is_protocol` is true.
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        # Look only at each class's own `__dict__` (not inherited attributes),
        # treating a missing entry the same as the placeholder.
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        # (every MRO ends with `object`, whose `__dict__` defines `__init__`)
        cls.__init__ = object.__init__

    # Delegate construction to the `__init__` just installed on the class.
    cls.__init__(self, *args, **kwargs)
def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[typing.Any]:
def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any, ...]:  # pyright: ignore[reportExplicitAny]
    """Returns a tuple of the values of all fields used for equality comparison, followed by the type of the dataclass itself.

    Only genuine dataclass fields declared with `compare=True` (the default)
    are included. These are the same fields the dataclass-generated `__eq__`
    compares. `ClassVar` and `InitVar` pseudo-fields are skipped: `__eq__` does
    not compare them, and an `InitVar` without a default is not stored on the
    instance at all, so reading it would raise `AttributeError`.

    The type is appended so that instances of different dataclasses whose
    compared field values are identical still yield different tuples.
    Essentially used to generate a hashable dataclass representation for
    equality comparison even if it's not frozen.

    # Parameters
    - `dc`: a dataclass instance

    # Returns
    `(*compared_field_values, type(dc))`. The tuple is hashable only if every
    compared field value is hashable.

    # Raises
    - `TypeError`: if `dc` is not a dataclass (raised by `dataclasses.fields`)
    """
    import dataclasses

    # `dataclasses.fields` yields only real fields, in definition order,
    # excluding ClassVar/InitVar pseudo-fields (unlike `__dataclass_fields__`).
    compared_values = tuple(
        getattr(dc, fld.name) for fld in dataclasses.fields(dc) if fld.compare
    )
    return (*compared_values, type(dc))

Returns a tuple of the values of all fields used for equality comparison, followed by the type of the dataclass itself. The type is included so that instances of different dataclasses with identical field values still compare unequal. In effect, this produces a hashable representation of a dataclass instance for equality comparison, even when the dataclass is not frozen.

def dataclass_set_equals( coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]) -> bool:
 91def dataclass_set_equals(
 92    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
 93) -> bool:
 94    """Compares 2 collections of dataclass instances as if they were sets.
 95    Duplicates are ignored in the same manner as a set.
 96    Unfrozen dataclasses can't be placed in sets since they're not hashable.
 97    Collections of them may be compared using this function.
 98    """
 99
100    return {get_hashable_eq_attrs(x) for x in coll1} == {
101        get_hashable_eq_attrs(y) for y in coll2
102    }

Compares two collections of dataclass instances as if they were sets. Duplicates are ignored, just as in a set. Unfrozen dataclasses can't be placed in sets because they're not hashable, but collections of them can be compared using this function.