Coverage for src / autoencodix / utils / _traindynamics.py: 76%
76 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from dataclasses import dataclass, field
2from typing import Dict, Optional, Union, Any
4import numpy as np
7# internal check done
8# write tests: done
@dataclass
class TrainingDynamics:
    """Store training dynamics as a nested mapping ``epoch -> split -> data``.

    Payloads are stored per epoch and per split. ``add`` converts plain
    numbers to numpy arrays, so a payload written through ``add`` is always
    either a ``numpy.ndarray`` or a ``dict``.

    Attributes:
        _data: Maps an epoch number to a dictionary that maps a split name
            to the payload recorded for that epoch and split.
    """

    _data: Dict[int, Dict[str, Union[np.ndarray, Dict]]] = field(
        default_factory=dict, repr=False
    )

    # Unannotated on purpose: this keeps it a plain class constant instead of
    # a dataclass field.
    _VALID_SPLITS = ("train", "valid", "test")

    @staticmethod
    def _payload_equal(left: Any, right: Any) -> bool:
        """Return True when two stored payloads hold the same content.

        Arrays are compared with ``np.array_equal`` and dictionaries are
        compared key by key, recursively. A plain ``dict == dict`` is not
        used because it raises ``ValueError`` as soon as a value is a
        multi-element numpy array.
        """
        if left is None or right is None:
            return left is None and right is None
        if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
            return (
                isinstance(left, np.ndarray)
                and isinstance(right, np.ndarray)
                and bool(np.array_equal(left, right))
            )
        if isinstance(left, dict) or isinstance(right, dict):
            if not (isinstance(left, dict) and isinstance(right, dict)):
                return False
            if left.keys() != right.keys():
                return False
            return all(
                TrainingDynamics._payload_equal(value, right[key])
                for key, value in left.items()
            )
        # Leaf values: identity first (the same shortcut container equality
        # uses), then ``==``. A comparison without a single truth value
        # counts as "not equal".
        if left is right:
            return True
        try:
            return bool(left == right)
        except (TypeError, ValueError):
            return False

    def __eq__(self, other: object) -> bool:
        """Compare two instances by their recorded epochs, splits and payloads.

        Returns ``NotImplemented`` for foreign types so Python can try the
        reflected comparison; ``dynamics == 5`` still evaluates to ``False``.
        """
        if not isinstance(other, TrainingDynamics):
            return NotImplemented
        return self._payload_equal(self._data, other._data)

    def add(
        self,
        epoch: int,
        data: Optional[Union[float, np.ndarray, Dict]],
        split: str = "train",
    ) -> None:
        """Store ``data`` for the given epoch and split.

        An existing entry for the same epoch and split is overwritten.

        Args:
            epoch: The epoch number.
            data: The payload to store. Python numbers and numpy numeric
                scalars are converted to a numpy array. ``None`` is ignored
                and nothing is stored.
            split: The data split, one of 'train', 'valid' or 'test'
                (default: 'train').

        Raises:
            KeyError: If ``split`` is not a supported split name.
            TypeError: If ``data`` is neither a number, a numpy array nor a
                dictionary.
        """
        # The None check runs before the split validation, so a missing
        # payload is skipped silently even for an unknown split.
        if data is None:
            return
        if split not in self._VALID_SPLITS:
            raise KeyError(
                f"Invalid split type: {split}, we only support 'train', 'valid', and 'test' splits."
            )
        # np.number covers numpy scalars such as np.float32, which are not
        # instances of the built-in float.
        if isinstance(data, (int, float, np.number)):
            data = np.array(data)
        if not isinstance(data, (np.ndarray, dict)):
            raise TypeError(
                f"Expected value to be of type numpy.ndarray or Dict, got {type(data)}."
            )
        self._data.setdefault(epoch, {})[split] = data

    def get(self, epoch: Optional[int] = None, split: Optional[str] = None) -> Union[
        np.ndarray,
        Dict[str, np.ndarray],
        Dict[int, Dict[str, np.ndarray]],
        Dict[Any, Any],
    ]:
        """Retrieve stored data with flexible filtering.

        Args:
            epoch: Specific epoch to retrieve. If None, returns data for all
                epochs. A negative value counts back from the highest
                recorded epoch (-1 is the highest, -2 the second highest),
                unless that exact negative epoch is stored and ``split`` is
                'test', in which case the stored entry is used directly.
            split: Specific split to retrieve ('train', 'valid' or 'test').
                If None, returns data for all splits.

        Returns:
            - If nothing has been recorded yet: an empty dictionary.
            - If epoch is None and split is None:
                the complete data dictionary {epoch: {split: data}}. This is
                the internal dictionary itself, not a copy.
            - If epoch is None and split is provided:
                a numpy array built from the values of that split across all
                epochs, in ascending epoch order.
            - If epoch is provided and split is None:
                the dictionary of all splits for that epoch {split: data},
                or an empty dictionary if the epoch is not recorded.
            - If both epoch and split are provided:
                the data for that epoch and split, or an empty numpy array
                if it is not recorded.

        Raises:
            KeyError: If ``split`` is not None and not a supported split name.

        Examples:
            >>> dynamics = TrainingDynamics()
            >>> dynamics.add(0, np.array([0.1, 0.2]), "train")
            >>> dynamics.add(1, np.array([0.2, 0.3]), "train")
            >>>
            >>> # Get all data
            >>> dynamics.get()  # Returns {0: {"train": array([0.1, 0.2])}, 1: {"train": array([0.2, 0.3])}}
            >>>
            >>> # Get train split across all epochs
            >>> dynamics.get(split="train")  # Returns array([[0.1, 0.2], [0.2, 0.3]])
            >>>
            >>> # Get specific epoch
            >>> dynamics.get(epoch=0)  # Returns {"train": array([0.1, 0.2])}
        """
        if split is not None and split not in self._VALID_SPLITS:
            raise KeyError(
                f"Invalid split type: {split}, we only support 'train', 'valid', and 'test' splits."
            )
        if not self._data:
            return {}

        # Case 1: no epoch specified.
        if epoch is None:
            # Case 1a: no split either - return the complete data dictionary.
            if split is None:
                return self._data
            # Case 1b: split specified - collect its values across epochs.
            collected = [
                self._data[e][split]
                for e in sorted(self._data)
                if split in self._data[e]
            ]
            return np.array(collected) if collected else np.array([])

        # Case 2: epoch specified.
        if epoch < 0:
            # According to the original comments, the predict step saves the
            # model outputs under epoch -1. Such an entry is read directly
            # for the 'test' split; every other negative epoch is an offset
            # from the highest recorded epoch.
            stored_directly = epoch in self._data and split == "test"
            if not stored_directly:
                epoch = max(self._data) + (epoch + 1)

        # Covers unknown positive epochs and negative offsets that reach
        # past the first recorded epoch.
        if epoch not in self._data:
            return np.array([]) if split is not None else {}

        epoch_data = self._data[epoch]
        if split is None:
            return epoch_data
        return epoch_data.get(split, np.array([]))

    def __getitem__(
        self, key: Union[int, slice]
    ) -> Union[np.ndarray, Dict[int, Dict[str, np.ndarray]], Any]:
        """Allow dictionary-style and slice-based access.

        Args:
            key: An epoch number, or a slice of epoch numbers. Slices follow
                the usual Python convention: ``stop`` is exclusive and
                ``step`` is honoured. An open start defaults to the lowest
                recorded epoch and an open stop reaches up to and including
                the highest recorded epoch.

        Returns:
            For an int, the result of ``get(key)``. For a slice, a dictionary
            {epoch: get(epoch)} with one entry per epoch number in the range;
            epochs that were never recorded map to an empty dictionary.

        Raises:
            KeyError: If ``key`` is neither an int nor a slice.

        Examples:
            dynamics[100]  # Get data for epoch 100.
            dynamics[50:100]  # Get data for epochs 50 to 99.
            dynamics[:]  # Get data for all epochs from the lowest to the highest.
        """
        if isinstance(key, int):
            return self.get(key)
        if isinstance(key, slice):
            # ``is None`` checks keep an explicit 0 bound; the defaults make
            # an open-ended slice of an empty store yield an empty range.
            start = key.start if key.start is not None else min(self._data, default=0)
            stop = (
                key.stop if key.stop is not None else max(self._data, default=-1) + 1
            )
            step = key.step if key.step is not None else 1
            return {epoch: self.get(epoch) for epoch in range(start, stop, step)}
        raise KeyError(f"Invalid key type: {type(key)}")

    def epochs(self) -> list:
        """Return all recorded epochs in ascending order."""
        return sorted(self._data)