Coverage for src/dataknobs_data/migration_v2/operations.py: 94%

106 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-15 12:29 -0500

1""" 

2Reversible operations for data migration. 

3""" 

4 

import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

from dataknobs_data.records import Record
from dataknobs_data.fields import FieldType

11 

12 

@dataclass
class Operation(ABC):
    """
    Abstract base for migration steps that can be undone.

    Concrete subclasses supply ``apply`` for the forward direction and
    ``reverse`` for the rollback direction.
    """

    @abstractmethod
    def apply(self, record: Record) -> Record:
        """
        Run the forward transformation on a record.

        Args:
            record: Record to transform

        Returns:
            Transformed record
        """
        ...

    @abstractmethod
    def reverse(self, record: Record) -> Record:
        """
        Undo this operation on a record.

        Args:
            record: Record to reverse transform

        Returns:
            Record with operation reversed
        """
        ...

    def __repr__(self) -> str:
        """Return the concrete class name followed by empty parentheses."""
        return type(self).__name__ + "()"

50 

51 

@dataclass
class AddField(Operation):
    """
    Add a new field to records.

    Attributes are documented inline below. ``apply`` leaves records that
    already contain ``field_name`` untouched; ``reverse`` drops the field
    whenever it is present.
    """

    # Name of the field to create.
    field_name: str
    # Value given to the new field; deep-copied for every record.
    default_value: Any = None
    # Optional field type forwarded to ``Record.set_field``.
    field_type: Optional[FieldType] = None

    def apply(self, record: Record) -> Record:
        """
        Add the field with its default value when it is not already present.

        Args:
            record: Record to transform

        Returns:
            A new record built from the input's fields, metadata and id,
            with the field added if it was missing.
        """
        result = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        # Only add if field doesn't exist
        if self.field_name not in result.fields:
            # Deep-copy so a mutable default (list, dict, ...) is not shared
            # by every record this operation touches.
            result.set_field(
                self.field_name,
                copy.deepcopy(self.default_value),
                field_type=self.field_type,
            )

        return result

    def reverse(self, record: Record) -> Record:
        """
        Remove the added field.

        Args:
            record: Record to reverse transform

        Returns:
            A new record without ``field_name``. The field is removed
            whenever it is present, regardless of how it got there.
        """
        result = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        if self.field_name in result.fields:
            del result.fields[self.field_name]

        return result

    def __repr__(self) -> str:
        """Return a short description with the field name and default value."""
        return f"AddField(field_name='{self.field_name}', default_value={self.default_value})"

93 

94 

@dataclass
class RemoveField(Operation):
    """
    Drop a field from records, optionally keeping its value for rollback.

    With ``store_removed`` enabled the removed value is stashed in the
    record metadata under ``_removed_<field_name>`` so that ``reverse``
    can put it back.
    """

    field_name: str
    store_removed: bool = False  # If True, store removed value in metadata

    def apply(self, record: Record) -> Record:
        """
        Return a copy of the record without ``field_name``.

        Args:
            record: Record to transform

        Returns:
            A new record; unchanged in content when the field is absent.
        """
        stripped = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id
        )

        # Nothing to remove: hand back the plain copy.
        if self.field_name not in stripped.fields:
            return stripped

        if self.store_removed:
            # Keep the value around so reverse() can restore it later.
            backup_key = f"_removed_{self.field_name}"
            stripped.metadata[backup_key] = stripped.fields[self.field_name].value

        del stripped.fields[self.field_name]
        return stripped

    def reverse(self, record: Record) -> Record:
        """
        Put the removed field back when a stored value is available.

        Only the value is restored (via ``set_field`` with name and value);
        the backup entry is deleted from metadata afterwards.

        Args:
            record: Record to reverse transform

        Returns:
            A new record, with the field restored if possible.
        """
        restored = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id
        )

        backup_key = f"_removed_{self.field_name}"
        can_restore = self.store_removed and backup_key in restored.metadata
        if can_restore:
            restored.set_field(self.field_name, restored.metadata[backup_key])
            del restored.metadata[backup_key]

        return restored

    def __repr__(self) -> str:
        """Return a short description containing the field name."""
        return f"RemoveField(field_name='{self.field_name}')"

136 

137 

@dataclass
class RenameField(Operation):
    """
    Rename a field.

    ``apply`` renames ``old_name`` to ``new_name``; ``reverse`` renames
    ``new_name`` back to ``old_name``. The input record is never modified:
    the renamed field is a shallow copy of the original field object.
    """

    # Field name before the migration.
    old_name: str
    # Field name after the migration.
    new_name: str

    def apply(self, record: Record) -> Record:
        """
        Rename field from old_name to new_name.

        Args:
            record: Record to transform

        Returns:
            A new record with the field renamed (if it was present).
        """
        return self._rename(record, self.old_name, self.new_name)

    def reverse(self, record: Record) -> Record:
        """
        Rename field from new_name back to old_name.

        Args:
            record: Record to reverse transform

        Returns:
            A new record with the rename undone (if the field was present).
        """
        return self._rename(record, self.new_name, self.old_name)

    @staticmethod
    def _rename(record: Record, source: str, target: str) -> Record:
        """Copy ``record`` with the field ``source`` stored under ``target``."""
        result = Record(
            data={},
            metadata=record.metadata.copy(),
            id=record.id,
        )

        for field_name, field in record.fields.items():
            if field_name == source:
                # Rename a shallow copy: setting ``name`` on the original
                # object would leave the input record holding a field whose
                # name no longer matches its key.
                renamed = copy.copy(field)
                renamed.name = target
                result.fields[target] = renamed
            else:
                # Untouched fields are carried over as the same objects.
                result.fields[field_name] = field

        return result

    def __repr__(self) -> str:
        """Return a short description with both field names."""
        return f"RenameField(old_name='{self.old_name}', new_name='{self.new_name}')"

185 

186 

@dataclass
class TransformField(Operation):
    """
    Rewrite one field's value through a caller-supplied function.

    ``transform_fn`` drives ``apply``; the optional ``reverse_fn`` drives
    ``reverse``. Failures are recorded in the record metadata instead of
    being raised, and the field keeps its previous value.
    """

    field_name: str
    transform_fn: Callable[[Any], Any]
    reverse_fn: Optional[Callable[[Any], Any]] = None

    def _convert(self, record: Record, fn: Callable[[Any], Any], error_key: str) -> Record:
        """Copy ``record`` and run ``fn`` over the field, logging failures under ``error_key``."""
        converted = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id
        )

        if self.field_name not in converted.fields:
            return converted

        current = converted.fields[self.field_name]
        previous_value = current.value
        try:
            # Type and field metadata are carried over from the existing field.
            converted.set_field(
                self.field_name,
                fn(previous_value),
                field_type=current.type,
                field_metadata=current.metadata
            )
        except Exception as exc:
            # Best effort: the value stays as it was and the error text
            # is stored in the record metadata.
            converted.metadata[error_key] = str(exc)

        return converted

    def apply(self, record: Record) -> Record:
        """
        Run ``transform_fn`` over the field value.

        On failure the error message is stored under
        ``_transform_error_<field_name>`` in the metadata.

        Args:
            record: Record to transform

        Returns:
            A new record with the transformed value, or the original value
            plus an error entry if the transformation failed.
        """
        return self._convert(
            record,
            self.transform_fn,
            f"_transform_error_{self.field_name}",
        )

    def reverse(self, record: Record) -> Record:
        """
        Run ``reverse_fn`` over the field value, when one was provided.

        Without a ``reverse_fn`` the input record itself is returned
        unchanged (not a copy). On failure the error message is stored under
        ``_reverse_error_<field_name>``. Any ``_transform_error_<field_name>``
        entry is dropped from the result.

        Args:
            record: Record to reverse transform

        Returns:
            The reversed record, or ``record`` itself if no reverse
            function exists.
        """
        if self.reverse_fn is None:
            # Can't reverse without reverse function
            return record

        reverted = self._convert(
            record,
            self.reverse_fn,
            f"_reverse_error_{self.field_name}",
        )

        # Clean up any transform error metadata
        stale_key = f"_transform_error_{self.field_name}"
        if stale_key in reverted.metadata:
            del reverted.metadata[stale_key]

        return reverted

    def __repr__(self) -> str:
        """Return a short description containing the field name."""
        return f"TransformField(field_name='{self.field_name}')"

255 

256 

@dataclass
class CompositeOperation(Operation):
    """
    Bundle several operations so they run as a single step.

    ``apply`` runs the operations first-to-last; ``reverse`` undoes them
    last-to-first.
    """

    operations: List[Operation]

    def apply(self, record: Record) -> Record:
        """
        Feed the record through every operation in list order.

        Args:
            record: Record to transform

        Returns:
            Output of the last operation, or ``record`` itself when the
            list is empty.
        """
        current = record
        for step in self.operations:
            current = step.apply(current)
        return current

    def reverse(self, record: Record) -> Record:
        """
        Undo every operation, starting from the last one.

        Args:
            record: Record to reverse transform

        Returns:
            Output of the final reversal, or ``record`` itself when the
            list is empty.
        """
        current = record
        for step in reversed(self.operations):
            current = step.reverse(current)
        return current

    def add(self, operation: Operation) -> 'CompositeOperation':
        """Append an operation to the list and return ``self`` for chaining."""
        self.operations.append(operation)
        return self

    def __repr__(self) -> str:
        """Return a short description with the number of operations."""
        count = len(self.operations)
        return f"CompositeOperation(operations={count})"