Coverage for src / autoencodix / utils / _traindynamics.py: 76%

76 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from dataclasses import dataclass, field 

2from typing import Dict, Optional, Union, Any 

3 

4import numpy as np 

5 

6 

7# internal check done 

8# write tests: done 

@dataclass
class TrainingDynamics:
    """Structure to store training dynamics in the form epoch -> split -> data.

    Attributes:
        _data: Nested dictionary ``{epoch: {split: value}}`` where every value
            is either a numpy array or a dictionary.
    """

    _data: Dict[int, Dict[str, Union[np.ndarray, Dict]]] = field(
        default_factory=dict, repr=False
    )

    # Deliberately left un-annotated so that ``@dataclass`` treats it as a plain
    # class attribute and not as an init field.
    _VALID_SPLITS = ("train", "valid", "test")

    @staticmethod
    def _deep_equal(left: Any, right: Any) -> bool:
        """Compare two payloads, tolerating numpy arrays nested in containers.

        Plain ``==`` on containers that hold multi-element numpy arrays raises
        ``ValueError`` (ambiguous truth value), so arrays are compared with
        ``np.array_equal`` and dicts/lists/tuples are walked recursively.

        Args:
            left: First value.
            right: Second value.
        Returns:
            True if both values are considered equal, False otherwise.
        """
        if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
            return bool(np.array_equal(left, right))
        if isinstance(left, dict) and isinstance(right, dict):
            return left.keys() == right.keys() and all(
                TrainingDynamics._deep_equal(value, right[key])
                for key, value in left.items()
            )
        if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
            return (
                isinstance(left, list) == isinstance(right, list)
                and len(left) == len(right)
                and all(
                    TrainingDynamics._deep_equal(a, b) for a, b in zip(left, right)
                )
            )
        return bool(left == right)

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` stores the same epochs, splits and values.

        Two entries match when both are None, both are numpy arrays with equal
        content, or both are dicts with equal (possibly array-valued) content.
        Any other combination of types is treated as unequal.

        Args:
            other: Object to compare against.
        Returns:
            True if both objects hold equal training dynamics, False otherwise.
        """
        if not isinstance(other, TrainingDynamics):
            return False
        if self._data.keys() != other._data.keys():
            return False
        for epoch, splits in self._data.items():
            other_splits = other._data.get(epoch, {})
            if splits.keys() != other_splits.keys():
                return False
            for split, data in splits.items():
                other_data = other_splits.get(split)
                if data is None and other_data is None:
                    continue
                both_arrays = isinstance(data, np.ndarray) and isinstance(
                    other_data, np.ndarray
                )
                both_dicts = isinstance(data, dict) and isinstance(other_data, dict)
                if not (both_arrays or both_dicts):
                    # One side is None or the payload types differ.
                    return False
                if not self._deep_equal(data, other_data):
                    return False
        return True

    def add(
        self,
        epoch: int,
        data: Optional[Union[float, np.ndarray, Dict]],
        split: str = "train",
    ) -> None:
        """Store a value for a specific epoch and split.

        Python and numpy scalars are wrapped into a 0-d numpy array before they
        are stored. An existing entry for the same epoch and split is
        overwritten.

        Args:
            epoch: The epoch number.
            data: The value to store (scalar, numpy array or dict). If None,
                the call is a no-op and nothing is validated.
            split: The data split, one of 'train', 'valid' or 'test'
                (default: 'train').
        Raises:
            KeyError: If ``split`` is not a supported split name.
            TypeError: If ``data`` is not a scalar, numpy array or dict.
        """
        if data is None:
            return
        if split not in self._VALID_SPLITS:
            raise KeyError(
                f"Invalid split type: {split}, we only support 'train', 'valid', and 'test' splits."
            )
        # np.number covers numpy scalars such as np.float32 that are not
        # subclasses of Python's int/float.
        if isinstance(data, (int, float, np.number)):
            data = np.array(data)

        if not isinstance(data, (np.ndarray, dict)):
            raise TypeError(
                f"Expected value to be of type numpy.ndarray or Dict, got {type(data)}."
            )

        self._data.setdefault(epoch, {})[split] = data

    def get(self, epoch: Optional[int] = None, split: Optional[str] = None) -> Union[
        np.ndarray,
        Dict[str, np.ndarray],
        Dict[int, Dict[str, np.ndarray]],
        Dict[Any, Any],
    ]:
        """Retrieve stored values with flexible filtering.

        Args:
            epoch: Specific epoch to retrieve. If None, returns data for all epochs.
                Negative values count back from the highest recorded epoch
                (-1 is the highest, -2 the second highest, ...). The exception
                is a negative epoch that is itself a stored key while
                ``split == "test"``: it is then used as-is, because predict
                outputs are saved under epoch -1.
            split: Specific split to retrieve (e.g., 'train', 'valid', 'test').
                If None, returns data for all splits.

        Returns:
            - If nothing has been stored yet:
                Returns an empty dictionary, regardless of the arguments
            - If epoch is None and split is None:
                Returns complete data dictionary {epoch: {split: data}}
            - If epoch is None and split is provided:
                Returns numpy array of values for the specified split across all epochs
            - If epoch is provided and split is None:
                Returns dictionary of all splits for that epoch {split: data}
            - If both epoch and split are provided:
                Returns numpy array for specific epoch and split
            - If the requested epoch (or split within it) is not stored:
                Returns an empty array when split is given, else an empty dictionary

        Raises:
            KeyError: If ``split`` is neither None nor a supported split name.

        Examples:
            >>> dynamics = TrainingDynamics()
            >>> dynamics.add(0, np.array([0.1, 0.2]), "train")
            >>> dynamics.add(1, np.array([0.2, 0.3]), "train")
            >>>
            >>> # Get all data
            >>> dynamics.get()  # Returns {0: {"train": array([0.1, 0.2])}, 1: {"train": array([0.2, 0.3])}}
            >>>
            >>> # Get train split across all epochs
            >>> dynamics.get(split="train")  # Returns array([[0.1, 0.2], [0.2, 0.3]])
            >>>
            >>> # Get specific epoch
            >>> dynamics.get(epoch=0)  # Returns {"train": array([0.1, 0.2])}
        """
        if split is not None and split not in self._VALID_SPLITS:
            raise KeyError(
                f"Invalid split type: {split}, we only support 'train', 'valid', and 'test' splits."
            )
        if not self._data:
            return {}

        # Case 1: No epoch specified
        if epoch is None:
            # Case 1b: No split specified - return complete data dictionary
            if split is None:
                return self._data
            # Case 1a: Split specified - return array of values across epochs
            per_epoch = [
                self._data[e][split] for e in sorted(self._data) if split in self._data[e]
            ]
            return np.array(per_epoch) if per_epoch else np.array([])

        if epoch < 0:
            # Predict outputs are saved under epoch -1, so a stored negative key
            # is used directly for the test split; otherwise the value is a
            # relative index counted from the highest recorded epoch.
            stored_as_is = epoch in self._data and split == "test"
            if not stored_as_is:
                epoch = max(self._data) + (epoch + 1)

        # Covers unknown epochs as well as negative indices that reach past the
        # first recorded epoch.
        if epoch not in self._data:
            return np.array([]) if split is not None else {}

        epoch_data = self._data[epoch]
        if split is None:
            return epoch_data

        # Case: Both epoch and split specified
        return epoch_data.get(split, np.array([]))

    def __getitem__(
        self, key: Union[int, slice]
    ) -> Union[np.ndarray, Dict[int, Dict[str, np.ndarray]], Any]:
        """Allow dictionary-style and slice-based access.

        Slices follow Python conventions: ``stop`` is exclusive, a missing
        ``start`` defaults to the lowest recorded epoch and a missing ``stop``
        defaults to one past the highest recorded epoch, so ``dynamics[:]``
        covers every recorded epoch. An explicit bound of 0 is honoured.

        Args:
            key: int or slice of index to obtain
        Returns:
            For an int, the result of ``get(key)``. For a slice, a dictionary
            ``{epoch: get(epoch)}`` for every epoch in the range; epochs in the
            range that were never recorded map to an empty dictionary.
        Raises:
            KeyError: If ``key`` is neither an int nor a slice.

        Examples:
            dynamics[100]  # Get data for epoch 100.
            dynamics[50:100]  # Get data for epochs 50 up to (excluding) 100.
        """
        if isinstance(key, int):
            return self.get(key)
        if isinstance(key, slice):
            # Without stored epochs there is nothing to derive open bounds from.
            if not self._data and (key.start is None or key.stop is None):
                return {}
            start = min(self._data) if key.start is None else key.start
            stop = max(self._data) + 1 if key.stop is None else key.stop
            step = 1 if key.step is None else key.step
            return {epoch: self.get(epoch) for epoch in range(start, stop, step)}
        raise KeyError(f"Invalid key type: {type(key)}")

    def epochs(self) -> list:
        """Return all recorded epochs in ascending order."""
        return sorted(self._data)