Module flashy.state
Utility class for automatically handling state of solver.
The object `StateManager()` can track stateful components of the solver.
Each component should follow the PyTorch idiomatic `state_dict()` and
`load_state_dict()` methods.
TODO: support strict vs. non-strict loading. Support using only some of the states, e.g. when fine-tuning a checkpoint vs. continuing from its own checkpoint.
The StateManager itself implements the state dict protocol.
Expand source code
"""Utility class for automatically handling state of solver.
The object `StateManager()` can track stateful components of the solver.
Each component should follow the PyTorch idiomatic `state_dict()` and
`load_state_dict()` methods.
TODO: support strict vs. non strict loading. Support using only some
of the states, e.g. when fine tuning a checkpoint vs. continuing from
its own checkpoint.
The StateManager itself implements the state dict protocol.
"""
import typing as tp
# Payload type of state dicts. Kept as ``Any`` on purpose: sources are free
# to return any value, not necessarily an actual dict.
StateDict = tp.Any


@tp.runtime_checkable
class StateDictSource(tp.Protocol):
    """Structural interface for objects with PyTorch-style state methods.

    Because the protocol is ``runtime_checkable``,
    ``isinstance(obj, StateDictSource)`` only checks that ``obj`` has both a
    ``state_dict`` and a ``load_state_dict`` attribute; signatures are not
    verified.
    """

    def state_dict(self) -> StateDict:
        """Return the current state of the object."""

    def load_state_dict(self, state: StateDict):
        """Restore the object from ``state``."""
class DictWrapper:
    """Turn a dict into a `StateDictSource`, using inplace operations.

    The wrapped dict object is kept: loading mutates it in place, so every
    other reference to that same dict sees the restored content.
    """

    def __init__(self, wrapped: dict):
        self.wrapped = wrapped

    def load_state_dict(self, state: StateDict):
        """Replace the content of the wrapped dict with ``state``.

        ``state`` may be anything accepted by ``dict()`` (a mapping or an
        iterable of key/value pairs).
        """
        # Snapshot first: `state_dict()` hands out the live dict, so `state`
        # may be the wrapped dict itself. Clearing before copying would then
        # wipe the data. Copying first also leaves the dict untouched if
        # `state` cannot be converted.
        new_content = dict(state)
        self.wrapped.clear()
        self.wrapped.update(new_content)

    def state_dict(self):
        """Return the wrapped dict itself (not a copy)."""
        return self.wrapped
class ListWrapper:
    """Expose a list as a `StateDictSource`, updating it in place."""

    def __init__(self, wrapped: list):
        self.wrapped = wrapped

    def state_dict(self):
        """Return the wrapped list object (no copy is made)."""
        return self.wrapped

    def load_state_dict(self, state: StateDict):
        """Overwrite the wrapped list's elements with those of ``state``.

        Slice assignment keeps the list object's identity, so other
        references to it observe the new content.
        """
        target = self.wrapped
        target[:] = state
class AttributeWrapper:
    """Expose the single attribute ``owner.<name>`` as a `StateDictSource`."""

    def __init__(self, owner: tp.Any, name: str):
        self.owner = owner
        self.name = name

    def state_dict(self):
        """Return the attribute's current value."""
        return getattr(self.owner, self.name)

    def load_state_dict(self, state: StateDict):
        """Restore ``owner.<name>`` from ``state``.

        If the current value can itself act as a `StateDictSource` (a dict, a
        list, or an object exposing both state methods), it is updated in
        place. Otherwise the attribute is rebound to ``state``.
        """
        current = getattr(self.owner, self.name)
        try:
            inplace = as_state_dict_source(current)
        except TypeError:
            # Not updatable in place (e.g. an int or a tuple): rebind instead.
            setattr(self.owner, self.name, state)
            return
        inplace.load_state_dict(state)
def as_state_dict_source(value: tp.Any) -> StateDictSource:
    """Try to cast the given value to a StateDictSource.

    Objects already satisfying the protocol are returned unchanged. Plain
    dicts and lists are wrapped so that loading updates them in place.

    Raises:
        TypeError: if ``value`` is none of the above.
    """
    if isinstance(value, StateDictSource):
        return value
    # Checked in this order; the first matching container type wins.
    for container_type, wrapper_cls in ((dict, DictWrapper), (list, ListWrapper)):
        if isinstance(value, container_type):
            return wrapper_cls(value)
    raise TypeError(f"Given type {type(value)} cannot be made into a StateDictSource")
def attribute_as_state_dict_source(owner: tp.Any, name: str) -> StateDictSource:
    """Try to cast the given attribute as a StateDictSource, in priority
    with inplace operations like `ListWrapper`, and otherwise as `AttributeWrapper`.

    Note: when an in-place wrapper is returned, it is bound to the object the
    attribute holds *now*. Rebinding ``owner.<name>`` later is not tracked.
    """
    current = getattr(owner, name)
    try:
        direct = as_state_dict_source(current)
    except TypeError:
        # The value cannot be updated in place: track the attribute instead.
        return AttributeWrapper(owner, name)
    return direct
class WriteOnlyWrapper(StateDictSource):
    """Wrap a source so that its state is saved but never restored.

    ``state_dict()`` is forwarded to the wrapped source, while
    ``load_state_dict()`` discards its argument. `StateManager.register`
    applies this wrapper when called with ``write_only=True``.
    """

    def __init__(self, source: StateDictSource):
        self.source = source

    def state_dict(self):
        wrapped_source = self.source
        return wrapped_source.state_dict()

    def load_state_dict(self, state):
        """Deliberately a no-op: the incoming state is ignored."""
class StateManager(StateDictSource):
    """Track several named `StateDictSource` objects and expose them as one.

    The manager itself implements the state dict protocol. Its state is a
    dict mapping each registered name to that source's ``state_dict()``.
    """

    def __init__(self):
        # name -> registered source
        self.sources = {}

    def register(self, name: str, source: StateDictSource, write_only: bool = False):
        """Track ``source`` under ``name``.

        Args:
            name: key used for this source in the manager's state dict.
            source: object implementing ``state_dict``/``load_state_dict``.
            write_only: if True, the source's state is included by
                ``state_dict()`` but ignored by ``load_state_dict()``.

        Raises:
            ValueError: if ``name`` is already registered.
        """
        if name in self.sources:
            raise ValueError(f"{name} already present in sources.")
        if write_only:
            source = WriteOnlyWrapper(source)
        self.sources[name] = source

    def state_dict(self) -> StateDict:
        """Return ``{name: source.state_dict()}`` for every registered source."""
        return {
            name: source.state_dict() for name, source in self.sources.items()
        }

    def load_state_dict(self, state: StateDict):
        """Dispatch each entry of ``state`` to the source registered under it.

        Registered sources that have no entry in ``state`` are left untouched.

        Raises:
            KeyError: if ``state`` contains a name that was never registered.
                This is checked before any source is loaded, so an unknown
                name is reported without modifying any source.
        """
        # Validate up front; otherwise sources listed before the unknown
        # name would already have been overwritten when the error surfaces.
        unknown = [name for name in state if name not in self.sources]
        if unknown:
            raise KeyError(
                f"Cannot load state for unregistered source(s) {unknown}; "
                f"registered sources are {list(self.sources)}."
            )
        for name, sub_state in state.items():
            self.sources[name].load_state_dict(sub_state)
Functions
def as_state_dict_source(value: Any) ‑> StateDictSource
-
Try to cast the given value to a StateDictSource.
Expand source code
def as_state_dict_source(value: tp.Any) -> StateDictSource:
    """Try to cast the given value to a StateDictSource.

    Objects already satisfying the protocol are returned unchanged. Plain
    dicts and lists are wrapped so that loading updates them in place.

    Raises:
        TypeError: if ``value`` is none of the above.
    """
    if isinstance(value, StateDictSource):
        return value
    # Checked in this order; the first matching container type wins.
    for container_type, wrapper_cls in ((dict, DictWrapper), (list, ListWrapper)):
        if isinstance(value, container_type):
            return wrapper_cls(value)
    raise TypeError(f"Given type {type(value)} cannot be made into a StateDictSource")
def attribute_as_state_dict_source(owner: Any, name: str) ‑> StateDictSource
-
Try to cast the given attribute as a `StateDictSource`, preferring in-place wrappers like `ListWrapper` and otherwise falling back to `AttributeWrapper`.
Expand source code
def attribute_as_state_dict_source(owner: tp.Any, name: str) -> StateDictSource:
    """Try to cast the given attribute as a StateDictSource, in priority
    with inplace operations like `ListWrapper`, and otherwise as `AttributeWrapper`.

    Note: when an in-place wrapper is returned, it is bound to the object the
    attribute holds *now*. Rebinding ``owner.<name>`` later is not tracked.
    """
    current = getattr(owner, name)
    try:
        direct = as_state_dict_source(current)
    except TypeError:
        # The value cannot be updated in place: track the attribute instead.
        return AttributeWrapper(owner, name)
    return direct
Classes
class AttributeWrapper (owner: Any, name: str)
-
Turn any attribute into a `StateDictSource`.
Expand source code
class AttributeWrapper:
    """Expose the single attribute ``owner.<name>`` as a `StateDictSource`."""

    def __init__(self, owner: tp.Any, name: str):
        self.owner = owner
        self.name = name

    def state_dict(self):
        """Return the attribute's current value."""
        return getattr(self.owner, self.name)

    def load_state_dict(self, state: StateDict):
        """Restore ``owner.<name>`` from ``state``.

        If the current value can itself act as a `StateDictSource` (a dict, a
        list, or an object exposing both state methods), it is updated in
        place. Otherwise the attribute is rebound to ``state``.
        """
        current = getattr(self.owner, self.name)
        try:
            inplace = as_state_dict_source(current)
        except TypeError:
            # Not updatable in place (e.g. an int or a tuple): rebind instead.
            setattr(self.owner, self.name, state)
            return
        inplace.load_state_dict(state)
Methods
def load_state_dict(self, state: Any)
-
Expand source code
def load_state_dict(self, state: StateDict):
    """Restore ``owner.<name>`` from ``state``, in place when possible.

    If the current value can act as a `StateDictSource` (a dict, a list, or
    an object exposing both state methods), it is updated in place.
    Otherwise the attribute is rebound to ``state``.
    """
    current = getattr(self.owner, self.name)
    try:
        inplace = as_state_dict_source(current)
    except TypeError:
        # Not updatable in place: rebind the attribute instead.
        setattr(self.owner, self.name, state)
        return
    inplace.load_state_dict(state)
def state_dict(self)
-
Expand source code
def state_dict(self):
    """Return the current value of ``owner.<name>``."""
    owner, attr_name = self.owner, self.name
    return getattr(owner, attr_name)
class DictWrapper (wrapped: dict)
-
Turn a dict into a `StateDictSource`, using in-place operations.
Expand source code
class DictWrapper:
    """Turn a dict into a `StateDictSource`, using inplace operations.

    The wrapped dict object is kept: loading mutates it in place, so every
    other reference to that same dict sees the restored content.
    """

    def __init__(self, wrapped: dict):
        self.wrapped = wrapped

    def load_state_dict(self, state: StateDict):
        """Replace the content of the wrapped dict with ``state``.

        ``state`` may be anything accepted by ``dict()`` (a mapping or an
        iterable of key/value pairs).
        """
        # Snapshot first: `state_dict()` hands out the live dict, so `state`
        # may be the wrapped dict itself. Clearing before copying would then
        # wipe the data.
        new_content = dict(state)
        self.wrapped.clear()
        self.wrapped.update(new_content)

    def state_dict(self):
        """Return the wrapped dict itself (not a copy)."""
        return self.wrapped
Methods
def load_state_dict(self, state: Any)
-
Expand source code
def load_state_dict(self, state: StateDict):
    """Replace the content of the wrapped dict with ``state``, in place.

    ``state`` may be anything accepted by ``dict()`` (a mapping or an
    iterable of key/value pairs).
    """
    # Snapshot first: `state_dict()` hands out the live dict, so `state` may
    # be the wrapped dict itself. Clearing before copying would wipe it. A
    # non-convertible `state` now fails before anything is cleared.
    new_content = dict(state)
    self.wrapped.clear()
    self.wrapped.update(new_content)
def state_dict(self)
-
Expand source code
def state_dict(self):
    """Return the wrapped dict object itself (no copy is made)."""
    wrapped = self.wrapped
    return wrapped
class ListWrapper (wrapped: list)
-
Turn a list into a `StateDictSource`, using in-place operations.
Expand source code
class ListWrapper:
    """Expose a list as a `StateDictSource`, updating it in place."""

    def __init__(self, wrapped: list):
        self.wrapped = wrapped

    def state_dict(self):
        """Return the wrapped list object (no copy is made)."""
        return self.wrapped

    def load_state_dict(self, state: StateDict):
        """Overwrite the wrapped list's elements with those of ``state``.

        Slice assignment keeps the list object's identity, so other
        references to it observe the new content.
        """
        target = self.wrapped
        target[:] = state
Methods
def load_state_dict(self, state: Any)
-
Expand source code
def load_state_dict(self, state: StateDict):
    """Overwrite the wrapped list's elements with those of ``state``.

    Slice assignment keeps the list object's identity, so other references
    to it observe the new content.
    """
    target = self.wrapped
    target[:] = state
def state_dict(self)
-
Expand source code
def state_dict(self):
    """Return the wrapped list object itself (no copy is made)."""
    wrapped = self.wrapped
    return wrapped
class StateDictSource (*args, **kwargs)
-
Structural protocol for objects exposing PyTorch-style `state_dict()` and `load_state_dict()` methods (see PEP 544 for protocol classes).
Because the class is decorated with `@typing.runtime_checkable`, `isinstance(obj, StateDictSource)` only checks that `obj` has both methods; their signatures are not verified.
Expand source code
@tp.runtime_checkable
class StateDictSource(tp.Protocol):
    """Structural interface for objects with PyTorch-style state methods.

    The ``runtime_checkable`` decorator is required: `as_state_dict_source`
    calls ``isinstance(value, StateDictSource)``, which only checks that both
    methods are present (signatures are not verified).
    """

    def state_dict(self) -> StateDict:
        """Return the current state of the object."""
        ...

    def load_state_dict(self, state: StateDict):
        """Restore the object from ``state``."""
        ...
Ancestors
- typing.Protocol
- typing.Generic
Subclasses
- StateManager
- WriteOnlyWrapper
Methods
def load_state_dict(self, state: Any)
-
Expand source code
def load_state_dict(self, state: StateDict):
    """Restore the object from ``state`` (protocol stub, no implementation)."""
def state_dict(self) ‑> Any
-
Expand source code
def state_dict(self) -> StateDict:
    """Return the current state of the object (protocol stub, no implementation)."""
class StateManager
-
Track several named `StateDictSource` objects and expose them as a single one.
Its `state_dict()` is a dict mapping each registered name to that source's `state_dict()`; `load_state_dict()` dispatches each entry of the given state to the source registered under the same name.
Expand source code
class StateManager(StateDictSource):
    """Track several named `StateDictSource` objects and expose them as one.

    The manager itself implements the state dict protocol. Its state is a
    dict mapping each registered name to that source's ``state_dict()``.
    """

    def __init__(self):
        # name -> registered source
        self.sources = {}

    def register(self, name: str, source: StateDictSource, write_only: bool = False):
        """Track ``source`` under ``name``.

        Args:
            name: key used for this source in the manager's state dict.
            source: object implementing ``state_dict``/``load_state_dict``.
            write_only: if True, the source's state is included by
                ``state_dict()`` but ignored by ``load_state_dict()``.

        Raises:
            ValueError: if ``name`` is already registered.
        """
        if name in self.sources:
            raise ValueError(f"{name} already present in sources.")
        if write_only:
            source = WriteOnlyWrapper(source)
        self.sources[name] = source

    def state_dict(self) -> StateDict:
        """Return ``{name: source.state_dict()}`` for every registered source."""
        return {
            name: source.state_dict() for name, source in self.sources.items()
        }

    def load_state_dict(self, state: StateDict):
        """Dispatch each entry of ``state`` to the source registered under it.

        Registered sources that have no entry in ``state`` are left untouched.

        Raises:
            KeyError: if ``state`` contains a name that was never registered.
                This is checked before any source is loaded, so an unknown
                name is reported without modifying any source.
        """
        # Validate up front; otherwise sources listed before the unknown
        # name would already have been overwritten when the error surfaces.
        unknown = [name for name in state if name not in self.sources]
        if unknown:
            raise KeyError(
                f"Cannot load state for unregistered source(s) {unknown}; "
                f"registered sources are {list(self.sources)}."
            )
        for name, sub_state in state.items():
            self.sources[name].load_state_dict(sub_state)
Ancestors
- StateDictSource
- typing.Protocol
- typing.Generic
Methods
def load_state_dict(self, state: Any)
-
Expand source code
def load_state_dict(self, state: StateDict):
    """Dispatch each entry of ``state`` to the source registered under it.

    Registered sources that have no entry in ``state`` are left untouched.

    Raises:
        KeyError: if ``state`` contains a name that was never registered.
            This is checked before any source is loaded, so an unknown name
            is reported without modifying any source.
    """
    # Validate up front; otherwise sources listed before the unknown name
    # would already have been overwritten when the error surfaces.
    unknown = [name for name in state if name not in self.sources]
    if unknown:
        raise KeyError(
            f"Cannot load state for unregistered source(s) {unknown}; "
            f"registered sources are {list(self.sources)}."
        )
    for name, sub_state in state.items():
        self.sources[name].load_state_dict(sub_state)
def register(self, name: str, source: StateDictSource, write_only: bool = False)
-
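Register `source` under `name`; raises `ValueError` if `name` is already registered.
With `write_only=True`, the source is wrapped in `WriteOnlyWrapper`: its state is saved by `state_dict()` but ignored by `load_state_dict()`.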
Expand source code
def register(self, name: str, source: StateDictSource, write_only: bool = False):
    """Track ``source`` under ``name``.

    With ``write_only=True`` the source is wrapped so that its state is saved
    but ignored on load.

    Raises:
        ValueError: if ``name`` is already registered.
    """
    if name in self.sources:
        raise ValueError(f"{name} already present in sources.")
    self.sources[name] = WriteOnlyWrapper(source) if write_only else source
def state_dict(self) ‑> Any
-
Expand source code
def state_dict(self) -> StateDict:
    """Collect ``{name: source.state_dict()}`` over all registered sources."""
    collected = {}
    for name, source in self.sources.items():
        collected[name] = source.state_dict()
    return collected
class WriteOnlyWrapper (source: StateDictSource)
-
Wrap a `StateDictSource` so that its state is saved but never restored.
`state_dict()` is forwarded to the wrapped source, while `load_state_dict()` ignores its argument.
Expand source code
class WriteOnlyWrapper(StateDictSource):
    """Wrap a source so that its state is saved but never restored.

    ``state_dict()`` is forwarded to the wrapped source, while
    ``load_state_dict()`` discards its argument. `StateManager.register`
    applies this wrapper when called with ``write_only=True``.
    """

    def __init__(self, source: StateDictSource):
        self.source = source

    def state_dict(self):
        wrapped_source = self.source
        return wrapped_source.state_dict()

    def load_state_dict(self, state):
        """Deliberately a no-op: the incoming state is ignored."""
Ancestors
- StateDictSource
- typing.Protocol
- typing.Generic
Methods
def load_state_dict(self, state)
-
Expand source code
def load_state_dict(self, state):
    """Deliberately a no-op: a write-only source is never restored."""
def state_dict(self)
-
Expand source code
def state_dict(self):
    """Forward to the wrapped source's ``state_dict()``."""
    wrapped_source = self.source
    return wrapped_source.state_dict()