Coverage for src/pydantic_typer/main.py: 99%
116 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-08-04 21:00 +0200
1from __future__ import annotations
3import inspect
4from functools import wraps
5from typing import Any, Callable
7import click
8import pydantic
9from typer import BadParameter, Option
10from typer import Typer as TyperBase
11from typer.main import CommandFunctionType, get_click_param, get_params_from_function, lenient_issubclass
12from typer.models import OptionInfo, ParameterInfo
13from typer.utils import (
14 AnnotatedParamWithDefaultValueError,
15 DefaultFactoryAndDefaultValueError,
16 MixedAnnotatedAndDefaultStyleError,
17 MultipleTyperAnnotationsError,
18 _split_annotation_from_typer_annotations,
19)
20from typing_extensions import Annotated
22from pydantic_typer.utils import copy_type, deep_update, inspect_signature
# Joins nested field names into CLI option names, e.g. ``--cat.owner.name``.
PYDANTIC_FIELD_SEPARATOR = "."


def _get_field_default(field: pydantic.fields.FieldInfo) -> Any:
    """Return the command-line default for a pydantic field.

    Required fields map to ``...`` so typer marks the parameter as required.
    Fields with a ``default_factory`` are optional: typer needs a concrete default,
    so the factory is called once here, when the command is built. Factories that
    take the already-validated data as an argument cannot be called at this point,
    so those fields stay required on the command line.
    """
    if field.is_required():
        return ...
    if field.default_factory is None:
        return field.default
    if getattr(field, "default_factory_takes_data", False):
        # No validated data exists yet at decoration time.
        return ...
    return field.default_factory()  # type: ignore[call-arg]


def _flatten_pydantic_model(
    model: type[pydantic.BaseModel], ancestors: list[str], ancestor_typer_param=None
) -> dict[str, inspect.Parameter]:
    """Flatten the fields of a pydantic model class into keyword-only parameters.

    Nested models are flattened recursively, so every leaf field becomes exactly
    one parameter.

    Args:
        model: The pydantic model class whose fields are flattened.
        ancestors: Names leading to ``model``. The first entry is the name of the
            command-function parameter annotated with the root model.
        ancestor_typer_param: ``typer.Option``/``typer.Argument`` annotated on the
            root parameter. It is used for leaf fields without a typer annotation
            of their own. It is not forwarded to nested models yet.

    Returns:
        Mapping from a synthetic parameter name (``_pydantic_<ancestors>_<field>``)
        to an ``inspect.Parameter`` annotated with
        ``Annotated[<field type>, <typer param>, <qualifier list>]``. The qualifier
        list records the path of field names from the root parameter to the leaf.
    """
    import copy

    pydantic_parameters = {}
    for field_name, field in model.model_fields.items():
        qualifier = [*ancestors, field_name]
        sub_name = f"_pydantic_{'_'.join(qualifier)}"
        if lenient_issubclass(field.annotation, pydantic.BaseModel):
            # TODO: pass ancestor_typer_param
            params = _flatten_pydantic_model(field.annotation, qualifier)  # type: ignore
            pydantic_parameters.update(params)
            continue

        default = _get_field_default(field)
        # Pydantic stores annotations in field.metadata.
        # If the field is already annotated with a typer.Option or typer.Argument, use that.
        existing_typer_params = [meta for meta in field.metadata if isinstance(meta, ParameterInfo)]
        if existing_typer_params:
            typer_param = existing_typer_params[0]
            if isinstance(typer_param, OptionInfo) and not typer_param.param_decls:
                # If the option was not named manually, use the default naming scheme.
                # Copy first: the OptionInfo lives in the model's field metadata and is
                # shared by every use of this model. Naming it in place would make the
                # first qualifier stick, e.g. a model nested twice would get the same
                # option name twice.
                typer_param = copy.copy(typer_param)
                typer_param.param_decls = (f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}",)
        elif ancestor_typer_param:
            typer_param = ancestor_typer_param
        else:
            typer_param = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}")
        pydantic_parameters[sub_name] = inspect.Parameter(
            sub_name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=Annotated[field.annotation, typer_param, qualifier],
            default=default,
        )
    return pydantic_parameters
def enable_pydantic(callback: CommandFunctionType) -> CommandFunctionType:
    """Let ``callback`` accept pydantic models as parameters.

    Every parameter whose annotation is a pydantic model (optionally wrapped as
    ``Annotated[Model, typer.Option(...)]``) is replaced in the exposed signature
    by one keyword-only parameter per (nested) model field. Typer therefore
    generates one CLI parameter per field. At call time the flattened values are
    reassembled into nested dicts, and the root models are instantiated before
    ``callback`` is invoked.

    Returns:
        A wrapper exposing the extended signature (and matching
        ``__annotations__``) that returns whatever ``callback`` returns.
    """
    original_signature = inspect_signature(callback)

    pydantic_parameters: dict[str, inspect.Parameter] = {}
    pydantic_roots: dict[str, Any] = {}
    other_parameters: dict[str, inspect.Parameter] = {}
    for name, parameter in original_signature.parameters.items():
        base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation)
        if not lenient_issubclass(base_annotation, pydantic.BaseModel):
            other_parameters[name] = parameter
            continue
        typer_param = typer_annotations[0] if typer_annotations else None
        # Flatten the bare model class, not the possibly ``Annotated[...]`` annotation,
        # so field lookup does not depend on typing aliases forwarding attributes.
        pydantic_parameters.update(_flatten_pydantic_model(base_annotation, [name], typer_param))
        pydantic_roots[name] = base_annotation

    extended_signature = inspect.Signature(
        [*other_parameters.values(), *pydantic_parameters.values()],
        return_annotation=original_signature.return_annotation,
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        converted_kwargs = kwargs.copy()
        raw_pydantic_objects: dict[str, Any] = {}
        for kwarg_name, kwarg_value in kwargs.items():
            if kwarg_name not in pydantic_parameters:
                continue
            converted_kwargs.pop(kwarg_name)
            # The qualifier stored by _flatten_pydantic_model is the path from the
            # root parameter to this field, e.g. ["cat", "owner", "name"].
            _, qualifier = pydantic_parameters[kwarg_name].annotation.__metadata__
            nested_value = kwarg_value
            for part in reversed(qualifier):
                nested_value = {part: nested_value}
            raw_pydantic_objects = deep_update(raw_pydantic_objects, nested_value)
        for root_name, value in raw_pydantic_objects.items():
            converted_kwargs[root_name] = pydantic_roots[root_name](**value)
        return callback(*args, **converted_kwargs)

    wrapper.__signature__ = extended_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()}
    return wrapper
def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFunctionType:
    """Let pydantic parse and validate the parameters of ``callback``.

    Parameters whose annotation typer cannot turn into a click parameter are
    re-annotated as ``str`` in the exposed signature, so typer hands over the raw
    command-line string. At call time every bound argument is validated against
    its ORIGINAL annotation with a ``pydantic.TypeAdapter``. A validation failure
    is reported as ``typer.BadParameter``.

    Returns:
        A wrapper exposing the adjusted signature (and matching
        ``__annotations__``) that returns whatever ``callback`` returns.
    """
    original_signature = inspect_signature(callback)

    # Change the annotation of unsupported types to str to be parsed by pydantic.
    # Adapted from https://github.com/tiangolo/typer/blob/95b767e38a98ee287a7a0e28176284836e1188c2/typer/main.py#L543
    # TODO: it's not ideal to call get_params_from_function and get_click_param here,
    # because it will be called in typer again, but the annotations supported by typer are quite dynamic.
    try:
        parameters = get_params_from_function(callback)
    except (
        AnnotatedParamWithDefaultValueError,
        DefaultFactoryAndDefaultValueError,
        MixedAnnotatedAndDefaultStyleError,
        MultipleTyperAnnotationsError,
    ):
        # We can't raise now. Typer will raise in the right moment.
        parameters = {}

    updated_parameters = dict(original_signature.parameters)
    for param_name, param in parameters.items():
        original_parameter = original_signature.parameters[param_name]
        if lenient_issubclass(param.annotation, click.Context):
            # click.Context should not be modified
            continue
        # We don't know whether to use pydantic or typer to parse a param without checking if typer supports it.
        try:
            get_click_param(param)
        except click.ClickException:
            # We can't raise now. Typer will raise in the right moment.
            pass
        except RuntimeError:
            # Typer raises RuntimeError for annotations it does not support:
            # let typer pass the raw string and have pydantic parse it later.
            # TODO: don't use raw str, but copy other annotations
            updated_parameters[param_name] = inspect.Parameter(
                param_name, kind=original_parameter.kind, default=original_parameter.default, annotation=str
            )

    new_signature = inspect.Signature(
        parameters=list(updated_parameters.values()), return_annotation=original_signature.return_annotation
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        bound_params = original_signature.bind(*args, **kwargs)
        for name, value in bound_params.arguments.items():
            try:
                type_adapter = pydantic.TypeAdapter(original_signature.parameters[name].annotation)
            except pydantic.PydanticSchemaGenerationError:
                # Annotations pydantic cannot build a schema for are passed through unvalidated.
                continue
            try:
                bound_params.arguments[name] = type_adapter.validate_python(value)
            except pydantic.ValidationError as e:
                raise BadParameter(message=e.errors()[0]["msg"], param_hint=name) from e
        # Propagate the command's result (previously it was silently dropped), so
        # callers using standalone_mode=False and outer wrappers receive it.
        return callback(*bound_params.args, **bound_params.kwargs)

    wrapper.__signature__ = new_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()}
    return wrapper
class Typer(TyperBase):
    """A ``typer.Typer`` whose commands support pydantic models and pydantic-parsed types."""

    @copy_type(TyperBase.command)
    def command(self, *args, **kwargs):
        """Same as ``typer.Typer.command``, but first applies the pydantic wrappers to the function."""
        register_command = super().command(*args, **kwargs)

        def decorator_override(f: CommandFunctionType) -> CommandFunctionType:
            # Flatten model parameters first, then add per-parameter pydantic validation.
            with_models = enable_pydantic(f)
            validated = enable_pydantic_type_validation(with_models)
            return register_command(validated)

        return decorator_override
def run(function: Callable[..., Any]) -> None:
    """Expose ``function`` as a single-command CLI (without shell completion) and run it."""
    cli = Typer(add_completion=False)
    register = cli.command()
    register(function)
    cli()