Coverage for src/dataknobs_data/migration_old_backup/schema_evolution.py: 0%

205 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-15 12:32 -0500

1"""Schema evolution and versioning utilities.""" 

2 

3import json 

4import logging 

5from dataclasses import dataclass, field 

6from datetime import datetime 

7from enum import Enum 

8from typing import Any, Callable, Dict, List, Optional, Type, Union 

9 

10from dataknobs_data.fields import Field, FieldType 

11from dataknobs_data.records import Record 

12 

13logger = logging.getLogger(__name__) 

14 

15 

class MigrationType(Enum):
    """Enumerates the kinds of change a schema migration can describe.

    Every member's value is a lowercase string tag, so a member can be
    recovered from its tag with ``MigrationType(tag)``.
    """

    # Field-level structural changes.
    ADD_FIELD = "add_field"
    REMOVE_FIELD = "remove_field"
    RENAME_FIELD = "rename_field"
    CHANGE_TYPE = "change_type"

    # Constraint changes.
    ADD_CONSTRAINT = "add_constraint"
    REMOVE_CONSTRAINT = "remove_constraint"

    # Anything that does not fit the categories above.
    CUSTOM = "custom"

25 

26 

@dataclass
class SchemaField:
    """Describes one field inside a schema version.

    Only ``name`` and ``type`` are mandatory; the remaining attributes fall
    back to "optional, no default, no metadata".
    """

    # Field identifier.
    name: str
    # Either a FieldType member or a free-form type string.
    type: Union[str, FieldType]
    # Whether the field is flagged as required.
    required: bool = False
    # Default value associated with the field (None when unspecified).
    default: Any = None
    # Free-form extra information; each instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)

35 

36 

@dataclass
class SchemaVersion:
    """A named snapshot of a schema: its fields plus descriptive metadata.

    Instances convert to and from plain dictionaries (``to_dict`` /
    ``from_dict``) so they can be stored as JSON.
    """

    # Version identifier.
    version: str
    # Creation timestamp; defaults to the (naive) local time at construction.
    created_at: datetime = field(default_factory=datetime.now)
    # Optional human-readable summary of the version.
    description: str = ""
    # Field definitions keyed by field name.
    fields: Dict[str, SchemaField] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this version.

        Field types that expose a ``value`` attribute (enum members) are
        stored as that value; anything else is stored as ``str(type)``.
        ``created_at`` is stored in ISO-8601 form.
        """
        serialized_fields = {}
        # NOTE: the loop variable is deliberately not called ``field`` so it
        # does not shadow ``dataclasses.field``.
        for name, spec in self.fields.items():
            type_label = spec.type.value if hasattr(spec.type, 'value') else str(spec.type)
            serialized_fields[name] = {
                'type': type_label,
                'required': spec.required,
                'default': spec.default,
                'metadata': spec.metadata,
            }

        return {
            'version': self.version,
            'created_at': self.created_at.isoformat(),
            'description': self.description,
            'fields': serialized_fields,
        }

    @staticmethod
    def _parse_field_type(raw: Any) -> Any:
        """Resolve a stored type label to a ``FieldType`` member when possible.

        Lookup order: member name (case-insensitive), then member value.
        The value lookup is what makes ``from_dict(to_dict())`` round-trip,
        because ``to_dict`` stores the member *value*, not its name.
        Non-string input and unresolvable strings are returned unchanged.
        """
        if not isinstance(raw, str):
            return raw
        by_name = FieldType.__members__.get(raw.upper())
        if by_name is not None:
            return by_name
        try:
            return FieldType(raw)
        except ValueError:
            return raw

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SchemaVersion":
        """Build a ``SchemaVersion`` from the structure produced by ``to_dict``.

        Args:
            data: Mapping with a mandatory ``'version'`` key and optional
                ``'created_at'`` (ISO-8601 string), ``'description'`` and
                ``'fields'`` keys.

        Returns:
            The reconstructed schema version. A missing ``'created_at'``
            becomes the current time; a missing field ``'type'`` becomes
            the string ``'str'`` unless it resolves to a ``FieldType``.

        Raises:
            KeyError: If ``'version'`` is missing.
        """
        fields = {}
        for name, field_data in data.get('fields', {}).items():
            fields[name] = SchemaField(
                name=name,
                type=cls._parse_field_type(field_data.get('type', 'str')),
                required=field_data.get('required', False),
                default=field_data.get('default'),
                metadata=field_data.get('metadata', {}),
            )

        if 'created_at' in data:
            created_at = datetime.fromisoformat(data['created_at'])
        else:
            created_at = datetime.now()

        return cls(
            version=data['version'],
            created_at=created_at,
            description=data.get('description', ''),
            fields=fields,
        )

85 

86 

@dataclass
class Migration:
    """A single migration step between two schema versions.

    A migration is driven either by explicit callables (``up_function`` /
    ``down_function``) or, when the relevant callable is not set, by the
    declarative ``operations`` list. Each operation is a dict whose
    ``'type'`` entry is a ``MigrationType`` value.
    """

    # Version the migration starts from.
    from_version: str
    # Version the migration produces.
    to_version: str
    # Overall category of the migration.
    migration_type: MigrationType
    # Optional human-readable summary.
    description: str = ""
    # Declarative operations, applied in order (reverse order when going back).
    operations: List[Dict[str, Any]] = field(default_factory=list)
    # When set, replaces ``operations`` for the forward direction.
    up_function: Optional[Callable[[Record], Record]] = None
    # When set, replaces ``operations`` for the backward direction.
    down_function: Optional[Callable[[Record], Record]] = None

    def apply_forward(self, record: Record) -> Record:
        """Migrate ``record`` from ``from_version`` to ``to_version``.

        Uses ``up_function`` when provided; otherwise applies every entry of
        ``operations`` in order.
        """
        if self.up_function:
            return self.up_function(record)

        for operation in self.operations:
            record = self._apply_operation(record, operation, forward=True)
        return record

    def apply_backward(self, record: Record) -> Record:
        """Migrate ``record`` from ``to_version`` back to ``from_version``.

        Uses ``down_function`` when provided; otherwise undoes ``operations``
        in reverse order.
        """
        if self.down_function:
            return self.down_function(record)

        for operation in reversed(self.operations):
            record = self._apply_operation(record, operation, forward=False)
        return record

    @staticmethod
    def _coerce_field_type(raw: Any) -> Optional[FieldType]:
        """Return the ``FieldType`` that ``raw`` denotes, or None if unknown.

        Accepted spellings, tried in this order:
        * a ``FieldType`` member (returned unchanged);
        * a member name, case-insensitive, optionally prefixed with
          ``"FieldType."`` (the form ``str()`` yields for an enum member);
        * the short aliases ``str`` / ``int`` / ``float`` / ``bool``;
        * a member value.
        """
        if isinstance(raw, FieldType):
            return raw
        if not isinstance(raw, str):
            return None

        prefix = f"{FieldType.__name__}."
        label = raw[len(prefix):] if raw.startswith(prefix) else raw

        member = FieldType.__members__.get(label.upper())
        if member is not None:
            return member

        aliases = {
            'str': FieldType.STRING,
            'int': FieldType.INTEGER,
            'float': FieldType.FLOAT,
            'bool': FieldType.BOOLEAN,
        }
        member = aliases.get(label.lower())
        if member is not None:
            return member

        try:
            return FieldType(raw)
        except ValueError:
            return None

    @classmethod
    def _new_field_type(cls, operation: Dict[str, Any]) -> Any:
        """Pick the type for a field that an operation (re)creates.

        Reads ``operation['field_type']`` (default ``'str'``). Unknown
        strings fall back to ``FieldType.STRING``; non-string values are
        passed through untouched.
        """
        raw = operation.get('field_type', 'str')
        if not isinstance(raw, str):
            return raw
        resolved = cls._coerce_field_type(raw)
        return FieldType.STRING if resolved is None else resolved

    def _apply_operation(self, record: Record, operation: Dict[str, Any], forward: bool) -> Record:
        """Apply one declarative operation to ``record``.

        Args:
            record: Record to modify; its ``fields`` mapping is changed in
                place.
            operation: Operation dict. ``'type'`` is mandatory; the other
                keys depend on the type (``field_name``, ``field_type``,
                ``default_value``, ``old_name``/``new_name``,
                ``old_type``/``new_type``, ``converter``/
                ``reverse_converter``, ``forward``/``backward``).
            forward: True to apply the operation, False to undo it.

        Returns:
            The record (a custom operation may return a different object).

        Raises:
            KeyError: If a key required by the operation type is missing.
            ValueError: If ``operation['type']`` is not a ``MigrationType``
                value.

        ADD_CONSTRAINT and REMOVE_CONSTRAINT have no branch here and leave
        the record untouched.
        """
        op_type = MigrationType(operation['type'])

        if op_type in (MigrationType.ADD_FIELD, MigrationType.REMOVE_FIELD):
            field_name = operation['field_name']
            is_add = op_type == MigrationType.ADD_FIELD
            # Applying an ADD and undoing a REMOVE both create the field;
            # the two remaining combinations both delete it.
            if forward == is_add:
                if field_name not in record.fields:
                    # Undoing a REMOVE cannot restore the old value, so the
                    # field comes back empty.
                    value = operation.get('default_value') if is_add else None
                    record.fields[field_name] = Field(
                        name=field_name,
                        value=value,
                        type=self._new_field_type(operation),
                    )
            elif field_name in record.fields:
                del record.fields[field_name]

        elif op_type == MigrationType.RENAME_FIELD:
            old_name = operation['old_name']
            new_name = operation['new_name']
            source, destination = (old_name, new_name) if forward else (new_name, old_name)
            if source in record.fields:
                record.fields[destination] = record.fields.pop(source)

        elif op_type == MigrationType.CHANGE_TYPE:
            field_name = operation['field_name']
            if forward:
                new_type = operation['new_type']
                converter = operation.get('converter')
            else:
                new_type = operation['old_type']
                converter = operation.get('reverse_converter')

            if field_name in record.fields:
                target = record.fields[field_name]
                if converter:
                    target.value = converter(target.value)
                # Store a real FieldType when the label can be resolved so
                # the field does not end up with a bare string as its type;
                # unresolvable values are kept exactly as given.
                resolved = self._coerce_field_type(new_type)
                target.type = new_type if resolved is None else resolved

        elif op_type == MigrationType.CUSTOM:
            custom_func = operation.get('forward' if forward else 'backward')
            if custom_func:
                record = custom_func(record)

        return record

205 

206 

class SchemaEvolution:
    """Registry of schema versions and the migrations connecting them.

    Keeps the known versions, the registered migrations and a "current"
    version, can migrate records between any two connected versions, and
    can persist itself to / restore itself from a JSON file.
    """

    def __init__(self):
        """Create an empty manager: no versions, no migrations, no current version."""
        self.versions: Dict[str, SchemaVersion] = {}
        self.migrations: List[Migration] = []
        self.current_version: Optional[str] = None

    def add_version(self, version: "SchemaVersion") -> None:
        """Register ``version``; the first one registered becomes current."""
        self.versions[version.version] = version
        if not self.current_version:
            self.current_version = version.version
        logger.info("Added schema version: %s", version.version)

    def add_migration(self, migration: "Migration") -> None:
        """Register a migration between two already-registered versions.

        Raises:
            ValueError: If either endpoint version is unknown.
        """
        if migration.from_version not in self.versions:
            raise ValueError(f"Unknown source version: {migration.from_version}")
        if migration.to_version not in self.versions:
            raise ValueError(f"Unknown target version: {migration.to_version}")

        self.migrations.append(migration)
        logger.info("Added migration: %s -> %s", migration.from_version, migration.to_version)

    def set_current_version(self, version: str) -> None:
        """Mark a registered version as current.

        Raises:
            ValueError: If ``version`` has not been registered.
        """
        if version not in self.versions:
            raise ValueError(f"Unknown version: {version}")
        self.current_version = version

    def get_migration_path(self, from_version: str, to_version: str) -> "List[Migration]":
        """Return the migrations to traverse from ``from_version`` to ``to_version``.

        Each migration is usable in both directions, so the search is a
        breadth-first walk over an undirected graph of versions. At every
        version, forward edges are explored before backward ones. The walk
        visits each version at most once, so it always terminates, even when
        the target is unreachable.

        Returns:
            Migrations in traversal order (empty when both versions are
            equal). A migration whose ``from_version`` is the version
            reached so far is to be applied forward, otherwise backward.

        Raises:
            ValueError: If the two versions are not connected.
        """
        if from_version == to_version:
            return []

        # version -> (previous version, migration used to get here);
        # the start version maps to None.
        reached = {from_version: None}
        frontier = [from_version]
        while frontier and to_version not in reached:
            next_frontier = []
            for version in frontier:
                forward_edges = [
                    (m.to_version, m) for m in self.migrations if m.from_version == version
                ]
                backward_edges = [
                    (m.from_version, m) for m in self.migrations if m.to_version == version
                ]
                for neighbour, migration in forward_edges + backward_edges:
                    if neighbour not in reached:
                        reached[neighbour] = (version, migration)
                        next_frontier.append(neighbour)
            frontier = next_frontier

        if to_version not in reached:
            raise ValueError(f"No migration path from {from_version} to {to_version}")

        # Walk the back-pointers from the target to the start, then flip.
        path = []
        step = reached[to_version]
        while step is not None:
            previous_version, migration = step
            path.append(migration)
            step = reached[previous_version]
        path.reverse()
        return path

    def migrate_record(
        self,
        record: "Record",
        from_version: str,
        to_version: str
    ) -> "Record":
        """Migrate ``record`` from ``from_version`` to ``to_version``.

        Follows the path from ``get_migration_path`` and, on success, stores
        ``to_version`` under ``record.metadata['schema_version']``.

        Raises:
            ValueError: If no migration path connects the two versions.
        """
        # Direction is decided against the version reached so far, not the
        # starting version; otherwise every step after the first one of a
        # multi-step forward path would be applied backward.
        reached_version = from_version
        for migration in self.get_migration_path(from_version, to_version):
            if migration.from_version == reached_version:
                record = migration.apply_forward(record)
                reached_version = migration.to_version
            else:
                record = migration.apply_backward(record)
                reached_version = migration.from_version

        # Update record metadata
        if not record.metadata:
            record.metadata = {}
        record.metadata['schema_version'] = to_version

        return record

    @staticmethod
    def _type_label(field_type: Any) -> str:
        """Return a string label for a field type.

        Enum members are labelled with their member name, which
        ``EnumClass[label.upper()]`` turns back into the member as long as
        names are upper-case. ``str()`` of a member usually yields
        ``"ClassName.MEMBER"``, which such a lookup cannot resolve.
        Anything else is labelled with ``str()``.
        """
        if isinstance(field_type, Enum):
            return field_type.name
        return str(field_type)

    def auto_detect_changes(
        self,
        old_version: "SchemaVersion",
        new_version: "SchemaVersion"
    ) -> "Migration":
        """Derive a migration from the difference between two schema versions.

        Detects added fields, removed fields and type changes of fields
        present in both versions. Renames are not detected; they show up as
        one removal plus one addition. Field names are processed in sorted
        order so the generated operation list is deterministic.

        Returns:
            A ``Migration`` of type ``CUSTOM`` from ``old_version`` to
            ``new_version`` carrying the detected operations. It is not
            registered with this manager.
        """
        operations = []

        old_names = set(old_version.fields)
        new_names = set(new_version.fields)

        # Detect added fields
        for field_name in sorted(new_names - old_names):
            added = new_version.fields[field_name]
            operations.append({
                'type': MigrationType.ADD_FIELD.value,
                'field_name': field_name,
                'field_type': self._type_label(added.type),
                'default_value': added.default
            })

        # Detect removed fields
        for field_name in sorted(old_names - new_names):
            removed = old_version.fields[field_name]
            operations.append({
                'type': MigrationType.REMOVE_FIELD.value,
                'field_name': field_name,
                'field_type': self._type_label(removed.type)
            })

        # Detect type changes
        for field_name in sorted(old_names & new_names):
            old_field = old_version.fields[field_name]
            new_field = new_version.fields[field_name]

            if old_field.type != new_field.type:
                operations.append({
                    'type': MigrationType.CHANGE_TYPE.value,
                    'field_name': field_name,
                    'old_type': self._type_label(old_field.type),
                    'new_type': self._type_label(new_field.type)
                })

        return Migration(
            from_version=old_version.version,
            to_version=new_version.version,
            migration_type=MigrationType.CUSTOM,
            description=f"Auto-detected migration from {old_version.version} to {new_version.version}",
            operations=operations
        )

    def save_to_file(self, filepath: str) -> None:
        """Write versions, migrations and the current version to a JSON file.

        ``up_function`` / ``down_function`` callables are not persisted.

        Raises:
            TypeError: If any value (for example a callable stored inside
                an operation) is not JSON-serializable. The target file is
                not touched in that case.
        """
        data = {
            'current_version': self.current_version,
            'versions': {
                version_id: version.to_dict()
                for version_id, version in self.versions.items()
            },
            'migrations': [
                {
                    'from_version': m.from_version,
                    'to_version': m.to_version,
                    'type': m.migration_type.value,
                    'description': m.description,
                    'operations': m.operations
                }
                for m in self.migrations
            ]
        }

        # Serialize before opening the file: opening in 'w' mode truncates
        # it, so a serialization error would otherwise destroy the previous
        # contents.
        serialized = json.dumps(data, indent=2)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(serialized)

    @classmethod
    def load_from_file(cls, filepath: str) -> "SchemaEvolution":
        """Restore a manager from a file written by ``save_to_file``.

        Versions and migrations are inserted directly, without the
        validation done by ``add_version`` / ``add_migration``.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        evolution = cls()
        evolution.current_version = data.get('current_version')

        # Load versions
        for version_id, version_data in data.get('versions', {}).items():
            evolution.versions[version_id] = SchemaVersion.from_dict(version_data)

        # Load migrations
        for migration_data in data.get('migrations', []):
            evolution.migrations.append(
                Migration(
                    from_version=migration_data['from_version'],
                    to_version=migration_data['to_version'],
                    migration_type=MigrationType(migration_data['type']),
                    description=migration_data.get('description', ''),
                    operations=migration_data.get('operations', [])
                )
            )

        return evolution