Coverage for tortoise_serializer/serializers.py: 88%

347 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-17 19:39 +0200

1import asyncio 

2import inspect 

3import logging 

4from collections.abc import Awaitable, Callable 

5from enum import Enum 

6from functools import lru_cache, wraps 

7from inspect import iscoroutinefunction 

8from typing import ( 

9 Any, 

10 Generator, 

11 Generic, 

12 Self, 

13 Sequence, 

14 Type, 

15 get_args, 

16 override, 

17) 

18 

19from frozendict import frozendict 

20from pydantic import BaseModel, ValidationError 

21from pydantic.main import IncEx 

22from structlog import get_logger 

23from tortoise import Model, fields 

24from tortoise.fields.relational import ( 

25 BackwardFKRelation, 

26 ForeignKeyFieldInstance, 

27 ManyToManyFieldInstance, 

28 ManyToManyRelation, 

29 _NoneAwaitable, 

30) 

31from tortoise.queryset import QuerySet 

32from typing_extensions import deprecated 

33 

34from tortoise_serializer.exceptions import ( 

35 TortoiseSerializerClassMethodException, 

36 TortoiseSerializerException, 

37) 

38from tortoise_serializer.types import MODEL, ContextType, T, Unset, UnsetType 

39 

# Module-wide structlog logger used throughout the serializers for
# debug/error reporting (keyword-style structured events).
logger = get_logger()
# NOTE(review): forcing the stdlib logger level at import time overrides the
# host application's logging configuration for this module — confirm this
# side effect is intentional for a library module.
log_level = logging.INFO
logging.getLogger(__name__).setLevel(log_level)

43 

44 

@deprecated("use require_condition_or_unset instead")
def require_permission_or_unset(
    permission_checker: Callable[[MODEL, ContextType], bool],
):
    """Gate a resolver behind a permission check.

    When ``permission_checker(instance, context)`` evaluates falsy, the
    decorated resolver is never invoked and ``Unset`` is returned instead,
    which removes the field from the serialized output.

    Works with both sync and async resolvers.

    :example:
    ```python
    def is_owner(instance: Model, context: ContextType) -> bool:
        return instance.created_by == context.get("user", None)

    @require_permission_or_unset(is_owner)
    def resolve_secret_value(cls, instance: User, context) -> str:
        return "It's secret!"
    ```
    """

    def decorator(func: Callable[..., T]):
        # Pick the matching wrapper up front so the returned callable keeps
        # the sync/async nature of the decorated resolver.
        if iscoroutinefunction(func):

            @wraps(func)
            async def a_wrapper(
                cls, instance: MODEL, context: ContextType
            ) -> T | UnsetType:
                if permission_checker(instance, context):
                    return await func(cls, instance, context)
                return Unset

            return a_wrapper

        @wraps(func)
        def wrapper(
            cls, instance: MODEL, context: ContextType
        ) -> T | UnsetType:
            if permission_checker(instance, context):
                return func(cls, instance, context)
            return Unset

        return wrapper

    return decorator

84 

85 

def require_condition_or_unset(
    condition_checker: Callable[[MODEL, ContextType], bool],
) -> Callable[[Callable[..., T]], Callable[..., T | UnsetType]]:
    """Gate a resolver behind an arbitrary condition.

    When ``condition_checker(instance, context)`` evaluates falsy, the
    decorated resolver is skipped and ``Unset`` is returned, which removes
    the field from the serialized output. This is the generic counterpart
    of ``require_permission_or_unset`` and works for any condition, not
    just permissions.

    :example:
    ```python
    def is_visible(instance: Model, context: ContextType) -> bool:
        return instance.is_public or context.get("user") == instance.owner

    @require_condition_or_unset(is_visible)
    def resolve_content(cls, instance: Post, context) -> str:
        return instance.content

    def is_valid_time(instance: Model, context: ContextType) -> bool:
        return datetime.now() >= instance.publish_time

    @require_condition_or_unset(is_valid_time)
    def resolve_premium_content(cls, instance: Article, context) -> str:
        return instance.premium_content
    ```
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T | UnsetType]:
        # Select the wrapper matching the resolver's sync/async nature so
        # callers never have to await a sync resolver (or vice versa).
        if iscoroutinefunction(func):

            @wraps(func)
            async def a_wrapper(
                cls, instance: MODEL, context: ContextType
            ) -> T | UnsetType:
                if condition_checker(instance, context):
                    return await func(cls, instance, context)
                return Unset

            return a_wrapper

        @wraps(func)
        def wrapper(
            cls, instance: MODEL, context: ContextType
        ) -> T | UnsetType:
            if condition_checker(instance, context):
                return func(cls, instance, context)
            return Unset

        return wrapper

    return decorator

133 

134 

class Serializer(BaseModel):
    """
    Serializer of tortoise orm models.

    Resolvers:
        Functions (sync or async) whose name starts with ``resolve_``.
        If a field exists on the serializer but not on the ``instance``,
        the serializer looks for a resolver before complaining.

        Resolvers override ``computed_fields`` entries with the same name
        since they are technically computed fields.

    Priority order:
        computed_fields > foreign keys > model_fields
    """

    @classmethod
    async def from_tortoise_orm(
        cls,
        instance: Model,
        computed_fields: dict[str, Callable[[Model, Any], Awaitable[Any]]]
        | None = None,
        context: dict[str, Any] | ContextType | None = None,
    ) -> Self:
        """Build an instance of this serializer from a tortoise ``Model``.

        :param instance: the tortoise model instance to serialize
        :param computed_fields: optional mapping of field name to resolver;
            class-level ``resolve_*`` methods are merged in and take
            precedence over entries with the same name
        :param context: read-only data made available to all resolvers
        :raises pydantic.ValidationError: if the resolved values do not
            validate against this serializer (logged before re-raising)
        """
        if computed_fields is None:
            computed_fields = {}
        computed_fields |= cls._collect_resolvers()

        # using a frozendict to allow caching when context is involved
        # also prevent misuses of the context: it must be considered as
        # read only
        frozen_context = frozendict(context or {})

        # fetch related fields before calling concurrent resolvers
        # so all of them are guaranteed to have the model populated properly
        await cls._fetch_related_fields(instance)

        (
            models_fields,
            fk_fields,
            computed_fields_values,
        ) = await asyncio.gather(
            cls._resolve_model_fields(instance),
            cls._resolve_foreignkeys(
                instance, frozen_context, computed_fields
            ),
            cls._resolve_computed_fields(
                instance, frozen_context, computed_fields
            ),
        )

        # later dicts win on key collision: computed fields override foreign
        # keys, which override plain model fields (see class docstring)
        fields_values = models_fields | fk_fields | computed_fields_values
        cls._remove_unsets(fields_values)
        try:
            return cls.model_validate(fields_values)
        except ValidationError:
            # log the full resolution state to ease debugging, then let the
            # original validation error propagate unchanged
            logger.error(
                "Failed to validate with model",
                model=cls.__name__,
                data=fields_values,
                instance=instance,
                context=frozen_context,
                models_fields=models_fields,
                fk_fields=fk_fields,
                computed_fields_values=computed_fields_values,
                computed_fields=computed_fields,
            )
            raise

    @classmethod
    async def from_tortoise_instances(
        cls, instances: Sequence[Model], **kwargs
    ) -> list[Self]:
        """Return a list of Self (Serializer) for the given sequence of
        tortoise instances.

        All instances are serialized concurrently; ``kwargs`` are forwarded
        to ``from_tortoise_orm``.
        """
        return await asyncio.gather(
            *[
                cls.from_tortoise_orm(instance, **kwargs)
                for instance in instances
            ]
        )

    @classmethod
    async def _fetch_related_fields(cls, instance: Model) -> None:
        """Fetch any not-yet-fetched relations that this serializer needs.

        Emits a debug log suggesting ``prefetch_related`` since fetching
        here causes an extra query per instance.
        """
        fetch_related_fields = cls._get_non_fetched_related_field_names(
            instance
        )
        if not fetch_related_fields:
            return

        logger.debug(
            "Fetching related fields, consider using prefetch_related",
            serializer=cls,
            instance=instance,
            fields=fetch_related_fields,
        )

        # Fetch all the related fields
        await instance.fetch_related(*fetch_related_fields)

    @staticmethod
    def _remove_unsets(data: dict[str, Any]) -> None:
        """Remove any Unset items from the given dictionary (in place)."""
        # collect first, then delete: avoids mutating the dict while
        # iterating over it
        fields_to_remove = [
            field_name
            for field_name, field_value in data.items()
            if field_value is Unset
        ]
        for field in fields_to_remove:
            data.pop(field, None)

    @classmethod
    async def _resolve_model_fields(cls, instance: Model) -> dict[str, Any]:
        """Collect plain (non-relational, non-computed) field values from
        the instance for every field declared on this serializer.
        """
        data = {}
        for field_name in cls.model_fields.keys():
            if hasattr(instance, field_name):
                field_value = getattr(instance, field_name)

                # ignore, this is a job for _resolve_foreignkeys
                if isinstance(field_value, Model):
                    continue
                # ignore, this is a job for _resolve_computed_fields
                if hasattr(cls, f"resolve_{field_name}"):
                    continue

                # unpack enum values
                if isinstance(field_value, Enum):
                    field_value = field_value.value

                data[field_name] = field_value
        return data

    @classmethod
    def _get_non_fetched_related_field_names(
        cls, instance: Model
    ) -> list[str]:
        """Returns the list of all fields that need to be fetched
        to represent the current `cls` instance.
        Note this won't fetch nested serializers' field names.
        """
        fetch_related_fields = []
        for field_name in cls.model_fields:
            # if a resolver already exists we use it instead of trying to
            # resolve it as a foreign key
            if hasattr(cls, f"resolve_{field_name}"):
                continue

            relational_instance = getattr(instance, field_name, None)

            # if the instance has been already fetched we don't add the field
            # to the list
            if isinstance(relational_instance, Model):
                continue

            # if the item is None we output the value as None to see if the
            # serializer can allow it
            if relational_instance is None:
                continue
            elif isinstance(relational_instance, _NoneAwaitable):
                # nullable FK with no target: nothing to fetch
                continue
            elif isinstance(relational_instance, ManyToManyRelation):
                # _fetched is tortoise's internal "already loaded" flag
                if not relational_instance._fetched:
                    fetch_related_fields.append(field_name)
            elif isinstance(relational_instance, fields.ReverseRelation):
                if not relational_instance._fetched:
                    fetch_related_fields.append(field_name)
            else:
                # single FK not yet awaited shows up as a QuerySet
                if isinstance(relational_instance, QuerySet):
                    fetch_related_fields.append(field_name)
        return fetch_related_fields

    @classmethod
    async def _resolve_foreignkeys(
        cls,
        instance: Model,
        context: ContextType,
        computed_fields: dict[str, Callable[[Model, Any], Awaitable[Any]]],
    ) -> dict[str, Any]:
        """Serialize every relational field through its nested serializer.

        Handles None, many-to-many, reverse relations and single foreign
        keys; resolvers named ``resolve_<field>`` take precedence and the
        field is skipped here.
        """
        data = {}
        for field_name, serializers in cls._get_nested_serializers().items():
            # resolvers have higher priority
            if hasattr(cls, f"resolve_{field_name}"):
                continue

            # for now: we only support one nested serializer
            if not len(serializers) == 1:
                raise ValueError(
                    "Cannot use more than one serialzier for each nested relation"
                )
            (serializer,) = serializers

            relational_instance = getattr(instance, field_name, None)

            # if the item is None we output the value as None to see if the
            # serializer can allow it
            if relational_instance is None or isinstance(
                relational_instance, _NoneAwaitable
            ):
                value = None
            # handling many to many relationships
            elif isinstance(relational_instance, ManyToManyRelation):
                value = await serializer.from_tortoise_instances(
                    relational_instance.related_objects, context=context
                )

            # handle reverse relations
            elif isinstance(relational_instance, fields.ReverseRelation):
                # NOTE(review): computed_fields values are resolvers, so
                # computed_fields.get(field_name) forwards a resolver (or
                # None) as the nested `computed_fields` mapping — confirm
                # nested computed fields are meant to be passed this way
                tasks = [
                    serializer.from_tortoise_orm(
                        instance,
                        context=context,
                        computed_fields=computed_fields.get(field_name, None),
                    )
                    for instance in relational_instance.related_objects
                ]
                value = await asyncio.gather(*tasks)

            # validating the nested relationship with a from_tortoise_orm call
            # to the nested serializer
            else:
                value = await serializers[0].from_tortoise_orm(
                    relational_instance,
                    context=context,
                    computed_fields=computed_fields.get(field_name, None),
                )
            data[field_name] = value
        return data

    @classmethod
    async def _resolve_computed_fields(
        cls,
        instance: Model,
        context: ContextType,
        computed_fields: dict[str, Callable[[Model, Any], Awaitable[Any]]]
        | None = None,
    ) -> dict[str, Any]:
        """Resolve all values for computed fields.
        Note that async functions will be called in an asyncio.TaskGroup.
        """
        if not computed_fields:
            return {}
        data = {}
        async with asyncio.TaskGroup() as tg:
            for field_name, field_resolver in computed_fields.items():
                # every resolver must be a bound (class)method
                if not inspect.ismethod(field_resolver):
                    raise TortoiseSerializerClassMethodException(
                        cls, field_name
                    )

                # ignore any nested serializers, it will be a job for the
                # foreign key resolver
                # NOTE(review): a dict (and the raw-value branch below) can
                # never pass the ismethod check above, so these branches look
                # unreachable — confirm intended
                if isinstance(
                    field_resolver, dict
                ) and cls._is_nested_serializer(field_name):
                    continue

                # add tasks to the taskgroup
                elif iscoroutinefunction(field_resolver):
                    data[field_name] = tg.create_task(
                        field_resolver(instance, context)
                    )

                # get the output values of sync resolvers
                elif callable(field_resolver):
                    data[field_name] = field_resolver(instance, context)

                # copy raw values
                else:
                    data[field_name] = field_resolver

        # we unpack the Task results for finished tasks (the TaskGroup has
        # awaited them all by now)
        for field_name, field_value in data.items():
            if isinstance(field_value, asyncio.Task):
                data[field_name] = field_value.result()

        return data

    @classmethod
    def _is_nested_serializer(cls, field_name: str) -> bool:
        """
        Check if the given field name corresponds to a nested serializer.
        """
        # Ensure the field exists in the annotations
        if field_name not in cls.__annotations__:
            return False

        # Get the type annotation for the field
        field_type = cls.__annotations__[field_name]

        # Check if the field type corresponds to a nested serializer
        # (either directly, or as a type argument of a generic such as
        # list[SomeSerializer] or SomeSerializer | None)
        args = get_args(field_type)
        if args:
            return any(
                isinstance(arg, type) and issubclass(arg, Serializer)
                for arg in args
            )
        return isinstance(field_type, type) and issubclass(
            field_type, Serializer
        )

    @classmethod
    def _get_nested_serializers_for_field(
        cls, field_name: str
    ) -> list["Serializer"]:
        """
        Get a list of nested serializers for the given field, if any.

        Args:
            field_name: The name of the field to check for nested serializers

        Returns:
            A list of nested Serializer classes found in the field's type hints.
            Returns an empty list if no nested serializers are found or if the field
            doesn't exist.
        """
        if (
            not hasattr(cls, "model_fields")
            or field_name not in cls.model_fields
        ):
            return []

        field_annotation = cls.model_fields[field_name].annotation
        if not field_annotation:
            return []

        # Handle generic types (like list[Serializer])
        type_args = get_args(field_annotation)
        if type_args:
            return [
                arg
                for arg in type_args
                if isinstance(arg, type) and issubclass(arg, Serializer)
            ]

        # Handle direct Serializer type
        if isinstance(field_annotation, type) and issubclass(
            field_annotation, Serializer
        ):
            return [field_annotation]

        return []

    @classmethod
    @lru_cache()
    def _get_nested_serializers(cls) -> dict[str, list["Serializer"]]:
        """Map each relational field name to its nested serializer classes.

        Cached per serializer class via lru_cache.
        """
        serializers = {}
        for field_name in cls.model_fields.keys():
            field_serializers = cls._get_nested_serializers_for_field(
                field_name
            )
            if field_serializers:
                serializers[field_name] = field_serializers
            elif cls._is_nested_serializer(field_name):
                # fall back to the raw annotation when the pydantic field
                # lookup found nothing but the annotation is a Serializer
                serializers[field_name] = [
                    cls.model_fields[field_name].annotation
                ]
        return serializers

    @classmethod
    async def from_queryset(
        cls, queryset: QuerySet, *args, **kwargs
    ) -> list[Self]:
        """
        Return a list of Self (Serializer) from the given queryset;
        all instances are serialized concurrently using asyncio.

        Parameters:
            - `queryset`: The QuerySet instance to serialize from
        any *args, **kwargs will be passed to `from_tortoise_orm` method.
        """

        tasks = [
            cls.from_tortoise_orm(instance, *args, **kwargs)
            async for instance in queryset
        ]
        return await asyncio.gather(*tasks)

    @classmethod
    def _collect_resolvers(
        cls,
    ) -> dict[str, Callable[[Model, Any], Awaitable[Any]]]:
        """Collect all resolvers defined in the class, both method-based and decorator-based."""
        # local name shadows the `fields` module imported from tortoise
        # within this method only
        fields = {}

        # Collect method-based resolvers (starting with resolve_)
        for method in dir(cls):
            if method.startswith("resolve_") and callable(
                getattr(cls, method)
            ):
                fields[method.removeprefix("resolve_")] = getattr(cls, method)

        # Collect decorator-based resolvers (tagged with _resolver_fields)
        for attr_name in dir(cls):
            attr = getattr(cls, attr_name)
            if callable(attr) and hasattr(attr, "_resolver_fields"):
                for field_name in attr._resolver_fields:
                    fields[field_name] = attr

        return fields

    def partial_update_tortoise_instance(self, model: Model, **kwargs) -> bool:
        """Update instance of `model` with the current serializer instance fields.
        Return `True` if the instance had been changed, `False` otherwise.

        Only fields explicitly set on this serializer are considered
        (``exclude_unset=True``); ``kwargs`` are forwarded to ``model_dump``.
        The instance is mutated in memory only — no ``save()`` is issued here.
        """
        updater = self.model_dump(exclude_unset=True, **kwargs)
        if not updater:
            logger.debug(
                "No fields to update", model=model, fields_to_update=updater
            )
            return False
        values_changed: bool = False
        for field, value in updater.items():
            if hasattr(model, field):
                if getattr(model, field) == value:
                    logger.debug(
                        "Value remains the same", model=model, field_name=field
                    )
                else:
                    setattr(model, field, value)
                    logger.debug(
                        "Updated Field", model=model, field_name=field
                    )
                    values_changed = True
        return values_changed

    async def create_tortoise_instance(
        self,
        model: Type[MODEL],
        *,
        _exclude: IncEx | None = None,
        _context: ContextType | None = None,
        **kwargs,
    ) -> MODEL:
        """Create a `model` row from this serializer's data.

        ``kwargs`` override the dumped serializer values on collision.
        ``_context`` is unused here; it is accepted so the ModelSerializer
        override can share the same call shape.
        """
        model_data = self.model_dump(exclude=_exclude)
        return await model.create(**(model_data | kwargs))

    def has_been_set(self, field_name: str) -> bool:
        """Return True if `field_name` has been set, otherwise False."""
        data = self.model_dump(include={field_name}, exclude_unset=True)
        return field_name in data

    @classmethod
    def get_prefetch_fields_generator(
        cls, prefix: str = ""
    ) -> Generator[str, None, None]:
        """
        Generate prefetch fields for all nested serializers.

        Yields tortoise-style double-underscore paths (e.g. ``a__b``),
        recursing into nested serializers.
        """
        if prefix:
            prefix = prefix + "__"

        for field_name in cls.model_fields.keys():
            field_serializers = cls._get_nested_serializers_for_field(
                field_name
            )

            # If no nested serializers are found, skip this field
            if not field_serializers:
                continue

            # check if the serializer needs to be filtered out
            if not cls._filter_nested_serializer(
                field_name, field_serializers
            ):
                continue

            # Field is a nested serializer
            yield prefix + field_name

            # Recursively get prefetch fields from nested serializers
            for nested_serializer in field_serializers:
                yield from nested_serializer.get_prefetch_fields(
                    prefix + field_name
                )

    @classmethod
    def _filter_nested_serializer(
        cls, field_name: str, serializers: Sequence["Serializer"]
    ) -> bool:
        """Override to filter out serializers from the prefetch fields.

        Returning False excludes `field_name` from prefetch paths; the base
        implementation keeps everything.
        """
        return True

    @classmethod
    def get_prefetch_fields(cls, prefix: str = "") -> list[str]:
        """
        Generate prefetch fields for all nested serializers.
        The concept is to pass the output of that function to
        `Model.fetch_related()` or `QuerySet[Model].prefetch_related()`.
        """
        return list(cls.get_prefetch_fields_generator(prefix))

626 

627 

class ModelSerializer(Serializer, Generic[MODEL]):
    """Serializer bound to a concrete tortoise model via a generic
    parameter (``class UserSerializer(ModelSerializer[User])``), enabling
    model-aware creation, prefetch and field introspection.
    """

    @classmethod
    @lru_cache()
    def get_model_class(cls) -> Type[MODEL]:
        """
        Retrieve the model class associated with the current ModelSerializer
        subclass.

        This method iterates through the class hierarchy to find the first
        class that inherits from pydantic.BaseModel and has a
        "__pydantic_generic_metadata__" attribute with an origin set.
        It then extracts the model class from the "args" of the
        "__pydantic_generic_metadata__" attribute.

        If no such class is found, a TortoiseSerializerException is raised.

        Returns:
            Type[MODEL]: The model class associated with the current
            ModelSerializer subclass.
        """
        for parent_class in cls.__mro__:
            if issubclass(parent_class, BaseModel) and hasattr(
                parent_class, "__pydantic_generic_metadata__"
            ):
                parent_meta = parent_class.__pydantic_generic_metadata__
                origin = parent_meta.get("origin", None)
                if origin:
                    # the generic parameter is the tortoise model class
                    args = parent_meta.get("args", None)
                    return args[0]

        raise TortoiseSerializerException(
            f"Bad configuration for ModelSerializer {cls}"
        )

    @override
    async def create_tortoise_instance(
        self, *, _exclude=None, _context: ContextType | None = None, **kwargs
    ) -> MODEL:
        """Creates the tortoise instance of this serializer and its nested
        relations.
        It's highly recommended to use this inside a `transaction` context.

        `_context` will be passed to any nested ModelSerializer as it is.
        Per-relation creation kwargs can be supplied as
        ``kwargs[field_name]`` dicts; remaining kwargs override the dumped
        serializer values for the root instance.
        """
        creation_kwargs = {}
        exclude = set()
        many_to_manys: dict[str, list[Model]] = {}
        backward_fks: dict[str, list[ModelSerializer]] = {}
        model_class = self.get_model_class()

        # as tempting as it might be, don't try to put that into a concurrent
        # task like asyncio.gather: here we are probably in a transaction
        # context and tortoise will complain if we have 2 concurrent operations
        for field_name, serializers in self._get_nested_serializers().items():
            serialized_value = getattr(self, field_name)

            # allow nones to be passed if the model allows them
            if serialized_value is None:
                continue

            serializer_class = serializers[0]
            if not issubclass(serializer_class, ModelSerializer):
                raise TortoiseSerializerException(
                    f"Bad configuration for field {field_name}:"
                    " this must inherit from ModelSerializer"
                )
            relation = model_class._meta.fields_map[field_name]
            if isinstance(relation, ManyToManyFieldInstance):
                # create each related row now; the m2m links are added after
                # the root instance exists
                for serializer in [
                    serializer_class.model_validate(item)
                    for item in serialized_value
                ]:
                    instance = await serializer.create_tortoise_instance(
                        **kwargs.get(field_name, {}),
                        _context=_context,
                    )
                    many_to_manys[field_name] = many_to_manys.get(
                        field_name, []
                    ) + [instance]
                exclude.add(field_name)

            # backward foreign keys: defer creation until the root instance
            # exists (the children need its id)
            elif isinstance(relation, BackwardFKRelation):
                for serializer in [
                    serializer_class.model_validate(item)
                    for item in serialized_value
                ]:
                    backward_fks[field_name] = backward_fks.get(
                        field_name, []
                    ) + [serializer]
                exclude.add(field_name)

            elif isinstance(relation, ForeignKeyFieldInstance):
                serializer = serializer_class.model_validate(serialized_value)
                relation_instance = await serializer.create_tortoise_instance(
                    **kwargs.get(field_name, {}),
                    _context=_context,
                )

                # assign both `field_name_id` and `field_name` to have them
                # in the instance available (for external use) and avoid to
                # have to re-fetch them
                creation_kwargs[field_name + "_id"] = relation_instance.id
                creation_kwargs[field_name] = relation_instance
                exclude.add(field_name)

        # caller kwargs win over the FK-derived creation kwargs
        merged_kwargs = creation_kwargs | kwargs
        if _exclude:
            exclude = exclude | set(_exclude)
        instance = await super().create_tortoise_instance(
            model_class,
            _exclude=exclude,
            _context=_context,
            **merged_kwargs,
        )
        for field_name, instances in many_to_manys.items():
            await getattr(instance, field_name).add(*instances)

        await self._create_backward_fks(
            model_class, instance, backward_fks, _context, _exclude or set()
        )
        return instance

    async def _create_backward_fks(
        self,
        serializer_model_class: Type[Model],
        instance: MODEL,
        backward_fks: dict[str, list[Self]],
        _context: ContextType | None,
        _exclude: set[str],
    ) -> None:
        """Creates the backward ForeignKeys for a given instance of
        self.get_model_class.

        Each child serializer is created with the FK column
        (``relation_field``) pointed at ``instance.id``; fields listed in
        ``_exclude`` are skipped.
        """

        for field_name, serializers in backward_fks.items():
            if field_name in _exclude:
                continue
            field: fields.ReverseRelation = (
                serializer_model_class._meta.fields_map[field_name]
            )
            backward_key = field.relation_field
            for serializer in serializers:
                await serializer.create_tortoise_instance(
                    _context=_context,
                    **{backward_key: instance.id},
                )

    @classmethod
    @lru_cache()
    def get_model_fields(
        cls, prefix: str | None = None, max_depth: int = 3
    ) -> set[str]:
        """Return the set of fields that are common to the model and this serializer,
        including nested serializer fields up to the specified max_depth.

        Args:
            prefix (str | None): A string prefix to prepend to nested fields.
            max_depth (int): Maximum depth for nested field exploration.

        Returns:
            Set[str]: A set of field names including nested fields, with prefixes applied.
        """
        model_fields: set[str] = set(cls.get_model_class()._meta.fields)
        serializer_fields: set[str] = set(cls.model_fields.keys())
        common_fields = model_fields.intersection(serializer_fields)

        # Prepare prefix if not provided
        prefix = prefix or ""

        if max_depth > 0:
            # iterate over a copy: common_fields is mutated in the loop
            for field_name in common_fields.copy():
                # Get nested serializers for this field
                serializers = cls._get_nested_serializers_for_field(field_name)
                if not serializers:
                    continue

                serializer_class = serializers[0]
                if not issubclass(serializer_class, ModelSerializer):
                    raise TortoiseSerializerException(
                        f"Bad configuration for field {field_name}:"
                        f" this must inherit from ModelSerializer ({serializer_class})"
                    )

                # Recursive call to get nested fields
                nested_fields = serializer_class.get_model_fields(
                    prefix=f"{prefix}{field_name}__",
                    max_depth=max_depth - 1,
                )
                # Merge nested fields into the common fields
                common_fields.update(nested_fields)

        # Add prefix to all fields
        return {f"{prefix}{field}" for field in common_fields}

    @classmethod
    def _filter_nested_serializer(
        cls, field_name: str, serializers: Sequence["Serializer"]
    ) -> bool:
        # on ModelSerializer we can check if the nested serializer exists
        # in the model so we avoid to return wrong fields in the prefetch
        # requests
        return field_name in cls.get_model_fields()

    @classmethod
    def get_only_fetch_fields(cls, path: str | None = None) -> list[str]:
        """
        Get the list of fields that should be fetched from the database.

        This method recursively traverses the serializer's fields and nested
        serializers to build a list of database fields that need to be fetched.
        It handles both direct model fields and nested relationships.

        Args:
            path (str | None): Optional path prefix for nested fields. Used
                internally for recursion.

        Returns:
            list[str]: List of field paths that should be fetched from the
                database.

        Raises:
            TortoiseSerializerException: If a nested serializer is not properly
                configured to inherit from ModelSerializer.
        """
        fields = []
        model = cls.get_model_class()
        for field_name in cls.model_fields.keys():
            # Skip computed fields that don't exist in the model
            if field_name not in model._meta.fields_map.keys():
                continue

            if cls._is_nested_serializer(field_name):
                args = get_args(cls.__annotations__[field_name])
                serializers = list(
                    [
                        arg
                        for arg in args
                        if (
                            isinstance(arg, type)
                            and issubclass(arg, ModelSerializer)
                        )
                    ]
                )
                # only the first nested serializer is followed
                serializer = serializers[0]
                nested_fields = serializer.get_only_fetch_fields(
                    path=f"{path or ''}{field_name}__"
                )
                fields.extend(nested_fields)
            else:
                fields.append(f"{path or ''}{field_name}")

        return fields

    @classmethod
    async def from_queryset(
        cls,
        queryset: QuerySet,
        *args,
        prefetch: bool = False,
        select_only: bool = False,
        **kwargs,
    ) -> list[Self]:
        """
        Return a list of Self (ModelSerializer) from the given queryset.
        All instances are fetched in concurrency using asyncio.

        Parameters:
            - `queryset`: The QuerySet instance to serialize from
            - `prefetch`: If True, prefetch the related fields
            - `select_only`: If True, only fetch the fields that are needed to serialize the model
                             Note that only the fields defined in the serializer
                             and its nested serializers are considered, be careful
                             with the resolvers needs
        any *args, **kwargs will be passed to `Serializer.from_queryset` method."""
        # NOTE(review): assert is stripped under `python -O`; consider a
        # ValueError if this guard must hold in production
        assert not (
            prefetch and select_only
        ), "prefetch and select_only cannot be true at the same time"
        if prefetch:
            queryset = queryset.prefetch_related(*cls.get_prefetch_fields())
        elif select_only:
            queryset = queryset.only(*cls.get_only_fetch_fields())

        return await super().from_queryset(queryset, *args, **kwargs)