Module flashy.state

Utility class for automatically handling the state of a solver. A StateManager object can track the stateful components of the solver. Each component should implement the idiomatic PyTorch state_dict() and load_state_dict() methods.

TODO: support strict vs. non-strict loading. Support using only some of the states, e.g. when fine-tuning from a checkpoint vs. continuing from its own checkpoint.

The StateManager itself implements the state dict protocol.

Expand source code
"""Utility class for automatically handling state of solver.
The object `StateManager()` can track stateful components of the solver.
Each component should follow the PyTorch idiomatic `state_dict()` and
`load_state_dict()` methods.

TODO: support strict vs. non strict loading. Support using only some
of the states, e.g. when fine tuning a checkpoint vs. continuing from
its own checkpoint.

The StateManager itself implements the state dict protocol.
"""

import typing as tp


# Type alias for the value produced by `state_dict()` and consumed by
# `load_state_dict()`. It is `tp.Any` on purpose: the wrappers in this module
# hand back plain dicts, lists or arbitrary attribute values as-is.
StateDict = tp.Any  # we don't really care if those are really dicts or not


@tp.runtime_checkable
class StateDictSource(tp.Protocol):
    """Structural interface for anything whose state can be saved and restored.

    The protocol is runtime-checkable, so `isinstance(obj, StateDictSource)`
    holds for any object exposing both methods, whether or not it inherits
    from this class (only the presence of the methods is checked).
    """

    def state_dict(self) -> StateDict:
        """Return the current state of the object."""

    def load_state_dict(self, state: StateDict):
        """Restore the object from `state`, as produced by `state_dict()`."""


class DictWrapper:
    """Turn a dict into a `StateDictSource`, using inplace operations.

    The wrapped dict object itself is mutated on load, so every other
    reference to it sees the restored content.
    """
    def __init__(self, wrapped: dict):
        self.wrapped = wrapped

    def load_state_dict(self, state: StateDict):
        """Replace the content of the wrapped dict with `state`, in place.

        `state` may be the wrapped dict itself, because `state_dict()` returns
        it without copying. It is therefore snapshotted before `clear()`;
        otherwise reloading one's own state would silently empty the dict.
        Building the snapshot first also means an invalid `state` raises
        before the current content is discarded.
        """
        new_content = dict(state)
        self.wrapped.clear()
        self.wrapped.update(new_content)

    def state_dict(self):
        """Return the wrapped dict itself (not a copy)."""
        return self.wrapped


class ListWrapper:
    """Expose a list as a `StateDictSource`, mutating it in place on load."""

    def __init__(self, wrapped: list):
        self.wrapped = wrapped

    def state_dict(self):
        """Return the wrapped list object itself (not a copy)."""
        return self.wrapped

    def load_state_dict(self, state: StateDict):
        """Overwrite the content of the wrapped list with the items of `state`."""
        # Full-slice assignment keeps the list object's identity, so existing
        # references observe the new content; it also copes with `state`
        # being the wrapped list itself.
        target = self.wrapped
        target[:] = state


class AttributeWrapper:
    """Expose the single attribute `owner.<name>` as a `StateDictSource`."""

    def __init__(self, owner: tp.Any, name: str):
        self.owner = owner
        self.name = name

    def state_dict(self):
        """Return the attribute's current value."""
        return getattr(self.owner, self.name)

    def load_state_dict(self, state: StateDict):
        """Restore the attribute from `state`.

        When the current value can itself act as a `StateDictSource` (a dict,
        a list, or an object with `state_dict`/`load_state_dict`), it is
        updated in place; otherwise the attribute is simply rebound to `state`.
        """
        value = getattr(self.owner, self.name)
        try:
            inplace = as_state_dict_source(value)
        except TypeError:
            # Not adaptable: fall back to replacing the attribute outright.
            setattr(self.owner, self.name, state)
            return
        inplace.load_state_dict(state)


def as_state_dict_source(value: tp.Any) -> StateDictSource:
    """Adapt `value` to the `StateDictSource` protocol.

    Objects already implementing the protocol are returned unchanged; plain
    dicts and lists get an in-place wrapper.

    Raises:
        TypeError: if `value` is none of the above.
    """
    if isinstance(value, StateDictSource):
        return value
    # Checked in order, after the protocol check, so that an object providing
    # its own state_dict/load_state_dict always wins over a generic wrapper.
    for container_type, wrapper_cls in ((dict, DictWrapper), (list, ListWrapper)):
        if isinstance(value, container_type):
            return wrapper_cls(value)
    raise TypeError(f"Given type {type(value)} cannot be made into a StateDictSource")


def attribute_as_state_dict_source(owner: tp.Any, name: str) -> StateDictSource:
    """Adapt the attribute `owner.<name>` to the `StateDictSource` protocol.

    In-place adapters (see `as_state_dict_source`) are preferred, since they
    keep the attribute bound to the same object. Any other value falls back
    to an `AttributeWrapper`, which rebinds the attribute on load.
    """
    current = getattr(owner, name)
    try:
        adapted = as_state_dict_source(current)
    except TypeError:
        adapted = AttributeWrapper(owner, name)
    return adapted


class WriteOnlyWrapper(StateDictSource):
    """Wrap a source so that it is saved by `state_dict()` but never restored.

    `load_state_dict()` is a no-op: the given state is discarded and the
    wrapped source keeps its current state.
    """
    def __init__(self, source: StateDictSource):
        self.source = source

    def state_dict(self):
        """Delegate to the wrapped source."""
        return self.source.state_dict()

    def load_state_dict(self, state):
        """Ignore `state` entirely."""
        return None


class StateManager(StateDictSource):
    """Aggregate several named `StateDictSource`s behind a single state dict.

    `state_dict()` returns `{name: source.state_dict()}` for every registered
    source, and `load_state_dict()` dispatches each entry back to the source
    registered under the same name.
    """
    def __init__(self):
        self.sources = {}

    def register(self, name: str, source: StateDictSource, write_only: bool = False):
        """Track `source` under `name`.

        Args:
            name: key used for this source in the aggregated state dict.
            source: object implementing `state_dict()`/`load_state_dict()`.
            write_only: if True, the source is included in `state_dict()` but
                its entry is ignored by `load_state_dict()`.

        Raises:
            ValueError: if `name` is already registered.
        """
        if name in self.sources:
            raise ValueError(f"{name} already present in sources.")
        if write_only:
            source = WriteOnlyWrapper(source)
        self.sources[name] = source

    def state_dict(self) -> StateDict:
        """Collect the state of every registered source, keyed by name."""
        return {
            name: source.state_dict() for name, source in self.sources.items()
        }

    def load_state_dict(self, state: StateDict):
        """Dispatch each entry of `state` to the source registered under its name.

        Registered sources that have no entry in `state` are left untouched.

        Raises:
            KeyError: if `state` contains a name that was never registered.
                The check runs before anything is loaded, so a mismatching
                checkpoint does not leave the sources partially restored.
        """
        unknown = [name for name in state if name not in self.sources]
        if unknown:
            raise KeyError(
                f"State contains unregistered sources {unknown}; "
                f"registered sources are {list(self.sources)}.")
        for name, sub_state in state.items():
            self.sources[name].load_state_dict(sub_state)

Functions

def as_state_dict_source(value: Any) -> StateDictSource

Try to cast the given value to a StateDictSource.

Expand source code
def as_state_dict_source(value: tp.Any) -> StateDictSource:
    """Adapt `value` to the `StateDictSource` protocol.

    Objects already implementing the protocol are returned unchanged; plain
    dicts and lists get an in-place wrapper.

    Raises:
        TypeError: if `value` is none of the above.
    """
    if isinstance(value, StateDictSource):
        return value
    # Checked in order, after the protocol check, so that an object providing
    # its own state_dict/load_state_dict always wins over a generic wrapper.
    for container_type, wrapper_cls in ((dict, DictWrapper), (list, ListWrapper)):
        if isinstance(value, container_type):
            return wrapper_cls(value)
    raise TypeError(f"Given type {type(value)} cannot be made into a StateDictSource")
def attribute_as_state_dict_source(owner: Any, name: str) -> StateDictSource

Try to cast the given attribute to a StateDictSource, preferring in-place wrappers such as ListWrapper and otherwise falling back to AttributeWrapper.

Expand source code
def attribute_as_state_dict_source(owner: tp.Any, name: str) -> StateDictSource:
    """Adapt the attribute `owner.<name>` to the `StateDictSource` protocol.

    In-place adapters (see `as_state_dict_source`) are preferred, since they
    keep the attribute bound to the same object. Any other value falls back
    to an `AttributeWrapper`, which rebinds the attribute on load.
    """
    current = getattr(owner, name)
    try:
        adapted = as_state_dict_source(current)
    except TypeError:
        adapted = AttributeWrapper(owner, name)
    return adapted

Classes

class AttributeWrapper (owner: Any, name: str)

Turn any attribute into a StateDictSource.

Expand source code
class AttributeWrapper:
    """Expose the single attribute `owner.<name>` as a `StateDictSource`."""

    def __init__(self, owner: tp.Any, name: str):
        self.owner = owner
        self.name = name

    def state_dict(self):
        """Return the attribute's current value."""
        return getattr(self.owner, self.name)

    def load_state_dict(self, state: StateDict):
        """Restore the attribute from `state`.

        When the current value can itself act as a `StateDictSource` (a dict,
        a list, or an object with `state_dict`/`load_state_dict`), it is
        updated in place; otherwise the attribute is simply rebound to `state`.
        """
        value = getattr(self.owner, self.name)
        try:
            inplace = as_state_dict_source(value)
        except TypeError:
            # Not adaptable: fall back to replacing the attribute outright.
            setattr(self.owner, self.name, state)
            return
        inplace.load_state_dict(state)

Methods

def load_state_dict(self, state: Any)
Expand source code
def load_state_dict(self, state: StateDict):
    """Restore the attribute `self.owner.<self.name>` from `state`.

    When the current value can itself act as a `StateDictSource` (a dict,
    a list, or an object with `state_dict`/`load_state_dict`), it is
    updated in place; otherwise the attribute is simply rebound to `state`.
    """
    value = getattr(self.owner, self.name)
    try:
        inplace = as_state_dict_source(value)
    except TypeError:
        # Not adaptable: fall back to replacing the attribute outright.
        setattr(self.owner, self.name, state)
        return
    inplace.load_state_dict(state)
def state_dict(self)
Expand source code
def state_dict(self):
    """Return the current value of the attribute `self.owner.<self.name>`."""
    owner, attr_name = self.owner, self.name
    return getattr(owner, attr_name)
class DictWrapper (wrapped: dict)

Turn a dict into a StateDictSource, using inplace operations.

Expand source code
class DictWrapper:
    """Turn a dict into a `StateDictSource`, using inplace operations.

    The wrapped dict object itself is mutated on load, so every other
    reference to it sees the restored content.
    """
    def __init__(self, wrapped: dict):
        self.wrapped = wrapped

    def load_state_dict(self, state: StateDict):
        """Replace the content of the wrapped dict with `state`, in place.

        `state` may be the wrapped dict itself, because `state_dict()` returns
        it without copying. It is therefore snapshotted before `clear()`;
        otherwise reloading one's own state would silently empty the dict.
        Building the snapshot first also means an invalid `state` raises
        before the current content is discarded.
        """
        new_content = dict(state)
        self.wrapped.clear()
        self.wrapped.update(new_content)

    def state_dict(self):
        """Return the wrapped dict itself (not a copy)."""
        return self.wrapped

Methods

def load_state_dict(self, state: Any)
Expand source code
def load_state_dict(self, state: StateDict):
    """Replace the content of the wrapped dict with `state`, in place.

    `state` may be the wrapped dict itself (`state_dict()` returns it without
    copying), so it is snapshotted before `clear()`; otherwise reloading
    one's own state would silently empty the dict. Building the snapshot
    first also means an invalid `state` raises before anything is discarded.
    """
    new_content = dict(state)
    self.wrapped.clear()
    self.wrapped.update(new_content)
def state_dict(self)
Expand source code
def state_dict(self):
    """Return the wrapped dict itself, not a copy."""
    content = self.wrapped
    return content
class ListWrapper (wrapped: list)

Turn a list into a StateDictSource, using inplace operations.

Expand source code
class ListWrapper:
    """Expose a list as a `StateDictSource`, mutating it in place on load."""

    def __init__(self, wrapped: list):
        self.wrapped = wrapped

    def state_dict(self):
        """Return the wrapped list object itself (not a copy)."""
        return self.wrapped

    def load_state_dict(self, state: StateDict):
        """Overwrite the content of the wrapped list with the items of `state`."""
        # Full-slice assignment keeps the list object's identity, so existing
        # references observe the new content; it also copes with `state`
        # being the wrapped list itself.
        target = self.wrapped
        target[:] = state

Methods

def load_state_dict(self, state: Any)
Expand source code
def load_state_dict(self, state: StateDict):
    """Overwrite the wrapped list's content with the items of `state`.

    Full-slice assignment keeps the list object's identity and also copes
    with `state` being the wrapped list itself.
    """
    target = self.wrapped
    target[:] = state
def state_dict(self)
Expand source code
def state_dict(self):
    """Return the wrapped list itself, not a copy."""
    content = self.wrapped
    return content
class StateDictSource (*args, **kwargs)

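Protocol for stateful objects exposing state_dict() and load_state_dict(); being runtime-checkable, isinstance() only checks that both methods are present. The class has no docstring of its own, so the generic typing.Protocol docstring is rendered below: Base class for protocol classes.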

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
Expand source code
class StateDictSource(tp.Protocol):
    """Structural interface for anything whose state can be saved and restored."""

    def state_dict(self) -> StateDict:
        """Return the current state of the object."""

    def load_state_dict(self, state: StateDict):
        """Restore the object from `state`, as produced by `state_dict()`."""

Ancestors

  • typing.Protocol
  • typing.Generic

Subclasses

  • StateManager
  • WriteOnlyWrapper

Methods

def load_state_dict(self, state: Any)
Expand source code
def load_state_dict(self, state: StateDict):
    """Restore the object from `state`, as produced by `state_dict()`."""
def state_dict(self) -> Any
Expand source code
def state_dict(self) -> StateDict:
    """Return the current state of the object."""
class StateManager

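Tracks named StateDictSource components: state_dict() returns {name: source.state_dict()} and load_state_dict() dispatches each entry back to the source registered under that name. The class has no docstring of its own, so the generic typing.Protocol docstring is rendered below: Base class for protocol classes.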

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
Expand source code
class StateManager(StateDictSource):
    """Aggregate several named `StateDictSource`s behind a single state dict.

    `state_dict()` returns `{name: source.state_dict()}` for every registered
    source, and `load_state_dict()` dispatches each entry back to the source
    registered under the same name.
    """
    def __init__(self):
        self.sources = {}

    def register(self, name: str, source: StateDictSource, write_only: bool = False):
        """Track `source` under `name`.

        Args:
            name: key used for this source in the aggregated state dict.
            source: object implementing `state_dict()`/`load_state_dict()`.
            write_only: if True, the source is included in `state_dict()` but
                its entry is ignored by `load_state_dict()`.

        Raises:
            ValueError: if `name` is already registered.
        """
        if name in self.sources:
            raise ValueError(f"{name} already present in sources.")
        if write_only:
            source = WriteOnlyWrapper(source)
        self.sources[name] = source

    def state_dict(self) -> StateDict:
        """Collect the state of every registered source, keyed by name."""
        return {
            name: source.state_dict() for name, source in self.sources.items()
        }

    def load_state_dict(self, state: StateDict):
        """Dispatch each entry of `state` to the source registered under its name.

        Registered sources that have no entry in `state` are left untouched.

        Raises:
            KeyError: if `state` contains a name that was never registered.
                The check runs before anything is loaded, so a mismatching
                checkpoint does not leave the sources partially restored.
        """
        unknown = [name for name in state if name not in self.sources]
        if unknown:
            raise KeyError(
                f"State contains unregistered sources {unknown}; "
                f"registered sources are {list(self.sources)}.")
        for name, sub_state in state.items():
            self.sources[name].load_state_dict(sub_state)

Ancestors

  • StateDictSource
  • typing.Protocol
  • typing.Generic

Methods

def load_state_dict(self, state: Any)
Expand source code
def load_state_dict(self, state: StateDict):
    """Dispatch each entry of `state` to the source registered under its name.

    Registered sources that have no entry in `state` are left untouched.

    Raises:
        KeyError: if `state` contains a name that was never registered. The
            check runs before anything is loaded, so a mismatching checkpoint
            does not leave the sources partially restored.
    """
    unknown = [name for name in state if name not in self.sources]
    if unknown:
        raise KeyError(
            f"State contains unregistered sources {unknown}; "
            f"registered sources are {list(self.sources)}.")
    for name, sub_state in state.items():
        self.sources[name].load_state_dict(sub_state)
def register(self, name: str, source: StateDictSource, write_only: bool = False)

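Register source under the unique key name; raises ValueError if name is already registered. With write_only=True the source is wrapped in WriteOnlyWrapper, so it is included in state_dict() but ignored by load_state_dict().

(The text previously rendered here, "Register a virtual subclass of an ABC. Returns the subclass, to allow usage as a class decorator.", was inherited from ABCMeta.register and does not describe this method.)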

Expand source code
def register(self, name: str, source: StateDictSource, write_only: bool = False):
    """Track `source` under the unique key `name`.

    When `write_only` is True the source is wrapped in `WriteOnlyWrapper`, so
    it is saved by `state_dict()` but ignored on load.

    Raises:
        ValueError: if `name` is already registered.
    """
    if name in self.sources:
        raise ValueError(f"{name} already present in sources.")
    self.sources[name] = WriteOnlyWrapper(source) if write_only else source
def state_dict(self) -> Any
Expand source code
def state_dict(self) -> StateDict:
    """Collect every registered source's state, keyed by registration name."""
    collected = {}
    for name, source in self.sources.items():
        collected[name] = source.state_dict()
    return collected
class WriteOnlyWrapper (source: StateDictSource)

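Wraps a source so that it is saved by state_dict() but never restored: load_state_dict() ignores its argument. The class has no docstring of its own, so the generic typing.Protocol docstring is rendered below: Base class for protocol classes.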

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
Expand source code
class WriteOnlyWrapper(StateDictSource):
    """Wrap a source so that it is saved by `state_dict()` but never restored.

    `load_state_dict()` is a no-op: the given state is discarded and the
    wrapped source keeps its current state.
    """
    def __init__(self, source: StateDictSource):
        self.source = source

    def state_dict(self):
        """Delegate to the wrapped source."""
        return self.source.state_dict()

    def load_state_dict(self, state):
        """Ignore `state` entirely."""
        return None

Ancestors

  • StateDictSource
  • typing.Protocol
  • typing.Generic

Methods

def load_state_dict(self, state)
Expand source code
def load_state_dict(self, state):
    """Discard `state`: a write-only source is never restored."""
    return None
def state_dict(self)
Expand source code
def state_dict(self):
    """Delegate to the wrapped source's own `state_dict()`."""
    wrapped_source = self.source
    return wrapped_source.state_dict()