Coverage for src/dataknobs_data/migration/operations.py: 37%
106 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Reversible operations for data migration.
2"""
4from __future__ import annotations
6from abc import ABC, abstractmethod
7from dataclasses import dataclass
8from typing import Any, TYPE_CHECKING
10from dataknobs_data.records import Record
12if TYPE_CHECKING:
13 from collections.abc import Callable
14 from dataknobs_data.fields import FieldType
@dataclass
class Operation(ABC):
    """Abstract contract shared by every migration step.

    A concrete operation can move a record forward (``apply``) and take
    it back again (``reverse``), which is what enables rollback.
    """

    @abstractmethod
    def apply(self, record: Record) -> Record:
        """Run this operation in the forward direction.

        Args:
            record: The record to transform.

        Returns:
            The transformed record.
        """

    @abstractmethod
    def reverse(self, record: Record) -> Record:
        """Undo this operation.

        Args:
            record: The record to reverse-transform.

        Returns:
            The record with this operation undone.
        """

    def __repr__(self) -> str:
        """Return the concrete class name followed by ``()``."""
        return type(self).__name__ + "()"
@dataclass
class AddField(Operation):
    """Operation that introduces a field into a record.

    Attributes are the field's name, the value it is created with, and an
    optional field type passed through to ``Record.set_field``.
    """

    field_name: str
    default_value: Any = None
    field_type: FieldType | None = None

    def apply(self, record: Record) -> Record:
        """Return a new record that contains ``field_name``.

        The new record is built from a shallow copy of the input's fields
        and a copy of its metadata. The field is created with
        ``default_value`` only when no field of that name exists yet; an
        existing field is left untouched.
        """
        updated = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        if self.field_name in updated.fields:
            # Already present: nothing to add.
            return updated

        updated.set_field(
            self.field_name,
            self.default_value,
            field_type=self.field_type,
        )
        return updated

    def reverse(self, record: Record) -> Record:
        """Return a new record without ``field_name``.

        Note that the field is removed whenever it is present, including
        the case where it already existed before ``apply`` ran.
        """
        stripped = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        if self.field_name not in stripped.fields:
            return stripped

        del stripped.fields[self.field_name]
        return stripped

    def __repr__(self) -> str:
        """Show the field name and the default value."""
        return (
            f"AddField(field_name='{self.field_name}', "
            f"default_value={self.default_value})"
        )
@dataclass
class RemoveField(Operation):
    """Operation that drops a field from a record.

    When ``store_removed`` is true, the dropped value is stashed in the
    record metadata under ``_removed_<field_name>`` so that ``reverse``
    can put it back.
    """

    field_name: str
    store_removed: bool = False  # If True, store removed value in metadata

    def apply(self, record: Record) -> Record:
        """Return a new record with ``field_name`` removed.

        A record that does not have the field is copied unchanged.
        """
        pruned = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        if self.field_name not in pruned.fields:
            return pruned

        if self.store_removed:
            # Keep the value around so the removal can be rolled back.
            stash_key = f"_removed_{self.field_name}"
            pruned.metadata[stash_key] = pruned.fields[self.field_name].value
        del pruned.fields[self.field_name]
        return pruned

    def reverse(self, record: Record) -> Record:
        """Return a new record with the removed field restored if possible.

        Restoration happens only when ``store_removed`` is set and the
        stashed value is present in the metadata; the stash entry is
        deleted afterwards. Only the value is restored via ``set_field``.
        """
        restored = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        stash_key = f"_removed_{self.field_name}"
        can_restore = self.store_removed and stash_key in restored.metadata
        if not can_restore:
            return restored

        restored.set_field(self.field_name, restored.metadata[stash_key])
        del restored.metadata[stash_key]
        return restored

    def __repr__(self) -> str:
        """Show the name of the field being removed."""
        return f"RemoveField(field_name='{self.field_name}')"
@dataclass
class RenameField(Operation):
    """Rename a field.

    ``apply`` renames ``old_name`` to ``new_name``; ``reverse`` renames
    ``new_name`` back to ``old_name``. The input record is never modified:
    the renamed field is a shallow copy, so the caller's record keeps a
    field object whose ``name`` still matches the key it is stored under.
    """

    old_name: str
    new_name: str

    def apply(self, record: Record) -> Record:
        """Rename field from old_name to new_name.

        Args:
            record: Record to transform (left unmodified).

        Returns:
            A new record in which the field is stored under ``new_name``.
        """
        return self._rename(record, self.old_name, self.new_name)

    def reverse(self, record: Record) -> Record:
        """Rename field from new_name back to old_name.

        Args:
            record: Record to reverse transform (left unmodified).

        Returns:
            A new record in which the field is stored under ``old_name``.
        """
        return self._rename(record, self.new_name, self.old_name)

    @staticmethod
    def _rename(record: Record, source: str, target: str) -> Record:
        """Build a new record with field ``source`` stored as ``target``.

        All other fields are carried over as-is. If the record has no
        ``source`` field the result simply mirrors the input's fields.
        If a different field is already stored under ``target``, whichever
        of the two is visited later in iteration order ends up in the
        result.
        """
        # Local import keeps this block self-contained.
        import copy

        result = Record(
            data={},
            metadata=record.metadata.copy(),
            id=record.id,
        )

        for field_name, field in record.fields.items():
            if field_name == source:
                # Rename a copy: assigning ``name`` on the shared object
                # would silently alter the caller's input record.
                renamed = copy.copy(field)
                renamed.name = target
                result.fields[target] = renamed
            else:
                result.fields[field_name] = field

        return result

    def __repr__(self) -> str:
        """Show both the old and the new field name."""
        return f"RenameField(old_name='{self.old_name}', new_name='{self.new_name}')"
@dataclass
class TransformField(Operation):
    """Operation that rewrites one field's value through a callable.

    ``transform_fn`` is used by ``apply``; the optional ``reverse_fn`` is
    used by ``reverse``. Failures of either callable are recorded in the
    record metadata instead of being raised.
    """

    field_name: str
    transform_fn: Callable[[Any], Any]
    reverse_fn: Callable[[Any], Any] | None = None

    def apply(self, record: Record) -> Record:
        """Return a new record with ``transform_fn`` applied to the field.

        The field keeps its existing type and field metadata. If the
        field is missing the copy is returned unchanged. If anything
        raises while transforming or storing the value, the value is left
        as it was and the error text is stored under
        ``_transform_error_<field_name>`` in the metadata.
        """
        updated = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        if self.field_name not in updated.fields:
            return updated

        target = updated.fields[self.field_name]
        current_value = target.value
        try:
            updated.set_field(
                self.field_name,
                self.transform_fn(current_value),
                field_type=target.type,
                field_metadata=target.metadata,
            )
        except Exception as exc:
            # Best effort: keep the untouched value and note what failed.
            updated.metadata[f"_transform_error_{self.field_name}"] = str(exc)

        return updated

    def reverse(self, record: Record) -> Record:
        """Return a record with ``reverse_fn`` applied to the field.

        Without a ``reverse_fn`` the input record itself is returned.
        Otherwise a new record is built; a failure is stored under
        ``_reverse_error_<field_name>`` and the value is kept. Any
        ``_transform_error_<field_name>`` entry is removed from the
        returned record's metadata.
        """
        if self.reverse_fn is None:
            # Nothing to undo with: hand the record back as-is.
            return record

        restored = Record(
            data=dict(record.fields),
            metadata=record.metadata.copy(),
            id=record.id,
        )

        if self.field_name in restored.fields:
            target = restored.fields[self.field_name]
            current_value = target.value
            try:
                restored.set_field(
                    self.field_name,
                    self.reverse_fn(current_value),
                    field_type=target.type,
                    field_metadata=target.metadata,
                )
            except Exception as exc:
                # Best effort: keep the current value and note what failed.
                restored.metadata[f"_reverse_error_{self.field_name}"] = str(exc)

        # Drop the marker a failed forward transform may have left behind.
        stale_key = f"_transform_error_{self.field_name}"
        if stale_key in restored.metadata:
            del restored.metadata[stale_key]

        return restored

    def __repr__(self) -> str:
        """Show the name of the field being transformed."""
        return f"TransformField(field_name='{self.field_name}')"
@dataclass
class CompositeOperation(Operation):
    """Operation that chains several operations together."""

    operations: list[Operation]

    def apply(self, record: Record) -> Record:
        """Feed the record through every operation, first to last."""
        current = record
        for step in self.operations:
            current = step.apply(current)
        return current

    def reverse(self, record: Record) -> Record:
        """Undo every operation, last to first."""
        current = record
        for step in reversed(self.operations):
            current = step.reverse(current)
        return current

    def add(self, operation: Operation) -> CompositeOperation:
        """Append an operation and return ``self`` so calls can be chained."""
        self.operations.append(operation)
        return self

    def __repr__(self) -> str:
        """Show how many operations are chained."""
        count = len(self.operations)
        return f"CompositeOperation(operations={count})"