lodum.schema

  1# SPDX-FileCopyrightText: 2025-present Michael R. Bernstein <zopemaven@gmail.com>
  2#
  3# SPDX-License-Identifier: Apache-2.0
  4import inspect
  5import collections
  6from typing import (
  7    Any,
  8    Dict,
  9    Optional,
 10    Type,
 11    get_origin,
 12    get_args,
 13)
 14
 15from .field import Field
 16from .core import DEFAULT_MAX_DEPTH, get_context
 17
 18
 19def _sanitize_name(name: str) -> str:
 20    """Sanitizes a string to be a valid Python identifier part."""
 21    if not name:
 22        return "unknown"
 23    return "".join(c if c.isalnum() else "_" for c in name)
 24
 25
 26def generate_schema(
 27    t: Type[Any], depth: int = 0, visited: Optional[set] = None
 28) -> Dict[str, Any]:
 29    """Generates a JSON Schema for a given type."""
 30    if depth > DEFAULT_MAX_DEPTH:
 31        raise ValueError(
 32            f"Max recursion depth ({DEFAULT_MAX_DEPTH}) exceeded during schema generation"
 33        )
 34
 35    if visited is None:
 36        visited = set()
 37
 38    ctx = get_context()
 39
 40    # Direct registry lookup
 41    if t in ctx.registry._handlers:
 42        return ctx.registry._handlers[t].schema_fn(t, depth, visited)
 43
 44    origin = get_origin(t) or t
 45
 46    # Generic lookup (exact match)
 47    if origin in ctx.registry._handlers:
 48        return ctx.registry._handlers[origin].schema_fn(t, depth, visited)
 49
 50    # Inheritance lookup
 51    for super_t, h_obj in ctx.registry._handlers.items():
 52        try:
 53            if inspect.isclass(origin) and issubclass(origin, super_t):
 54                return h_obj.schema_fn(t, depth, visited)
 55        except TypeError:
 56            continue
 57
 58    if inspect.isclass(t) and getattr(t, "_lodum_enabled", False):
 59        if t in visited:
 60            # Recursive reference
 61            return {"$ref": f"#/definitions/{_sanitize_name(t.__name__)}"}
 62
 63        visited.add(t)
 64        fields: Dict[str, Field] = getattr(t, "_lodum_fields", {})
 65        properties = {}
 66        required = []
 67        for field_name, field_info in fields.items():
 68            key = field_info.rename if field_info.rename else field_info.name
 69            properties[key] = generate_schema(field_info.type, depth + 1, visited)
 70            if not field_info.has_default:
 71                required.append(key)
 72
 73        schema = {"type": "object", "properties": properties}
 74
 75        tag_name = getattr(t, "_lodum_tag", None)
 76        if tag_name:
 77            tag_value = getattr(t, "_lodum_tag_value", t.__name__)
 78            properties[tag_name] = {"const": tag_value}
 79            if tag_name not in required:
 80                required.append(tag_name)
 81
 82        if required:
 83            schema["required"] = required
 84
 85        visited.remove(t)
 86        return schema
 87
 88    return {}
 89
 90
 91def _schema_int(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
 92    return {"type": "integer"}
 93
 94
 95def _schema_str(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
 96    return {"type": "string"}
 97
 98
 99def _schema_float(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
100    return {"type": "number"}
101
102
def _schema_bool(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return the ``boolean`` schema; all three arguments are ignored."""
    schema: Dict[str, Any] = dict(type="boolean")
    return schema
105
106
def _schema_none(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return the ``null`` schema; all three arguments are ignored."""
    schema: Dict[str, Any] = dict(type="null")
    return schema
109
110
def _schema_any(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return a new empty schema; all three arguments are ignored."""
    unconstrained: Dict[str, Any] = dict()
    return unconstrained
113
114
def _schema_uuid(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return a ``string`` schema with ``format: uuid``; arguments are ignored."""
    schema: Dict[str, Any] = dict(type="string", format="uuid")
    return schema
117
118
def _schema_decimal(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return a plain ``string`` schema; all three arguments are ignored."""
    schema: Dict[str, Any] = dict(type="string")
    return schema
121
122
def _schema_path(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return a plain ``string`` schema; all three arguments are ignored."""
    schema: Dict[str, Any] = dict(type="string")
    return schema
125
126
def _schema_bytes(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return a ``string`` schema with ``contentEncoding: base64``; arguments are ignored."""
    schema: Dict[str, Any] = dict(type="string", contentEncoding="base64")
    return schema
129
130
def _schema_list(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Build an ``array`` schema for ``t``.

    The first type argument of ``t``, when present, is expanded one level
    deeper and used as the ``items`` schema; without type arguments ``items``
    is the empty schema.
    """
    type_args = get_args(t)
    if type_args:
        element_schema = generate_schema(type_args[0], depth + 1, visited)
    else:
        element_schema = {}
    return {"type": "array", "items": element_schema}
135
136
def _schema_dict(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Build an ``object`` schema whose ``additionalProperties`` describe the values.

    ``collections.Counter`` (bare or parameterized) always gets integer
    values. Otherwise the second of exactly two type arguments is expanded one
    level deeper; any other number of type arguments gives the empty schema.
    """
    if (get_origin(t) or t) is collections.Counter:
        value_schema: Dict[str, Any] = {"type": "integer"}
    else:
        type_args = get_args(t)
        if len(type_args) == 2:
            value_schema = generate_schema(type_args[1], depth + 1, visited)
        else:
            value_schema = {}
    return {"type": "object", "additionalProperties": value_schema}
147
148
def _schema_union(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Build an ``anyOf`` schema from the type arguments of ``t``.

    Each type argument is expanded one level deeper. A ``discriminator`` entry
    is added only when every argument is a class flagged with
    ``_lodum_enabled`` and all of them share the same non-``None``
    ``_lodum_tag``.
    """
    members = get_args(t)
    result: Dict[str, Any] = {
        "anyOf": [generate_schema(member, depth + 1, visited) for member in members]
    }

    # Members that are not lodum-enabled classes count as "no tag" (None).
    tags = {
        getattr(member, "_lodum_tag", None)
        if inspect.isclass(member) and getattr(member, "_lodum_enabled", False)
        else None
        for member in members
    }

    if len(tags) == 1 and None not in tags:
        (shared_tag,) = tags
        result["discriminator"] = {"propertyName": shared_tag}

    return result
167
168
def _schema_tuple(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Build an ``array`` schema for a tuple type.

    A variable-length homogeneous tuple (``Tuple[X, ...]``) is described with
    a single ``items`` schema for ``X``. Every other tuple is described
    positionally: each type argument is expanded one level deeper and placed in
    ``prefixItems``. A tuple without type arguments gets an empty
    ``prefixItems`` list.

    Args:
        t: The tuple type to describe.
        depth: Current nesting level; element types are expanded at ``depth + 1``.
        visited: Passed through unchanged to ``generate_schema``.

    Returns:
        The JSON Schema dictionary for the tuple.
    """
    args = get_args(t)

    # In ``Tuple[X, ...]`` the Ellipsis is a "repeat X" marker, not an element
    # type, so it must not be handed to generate_schema as a second position.
    if len(args) == 2 and args[1] is Ellipsis:
        return {
            "type": "array",
            "items": generate_schema(args[0], depth + 1, visited),
        }

    return {
        "type": "array",
        "prefixItems": [generate_schema(arg, depth + 1, visited) for arg in args],
    }
175
176
def _schema_set(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Build an ``array`` schema with ``uniqueItems`` set for ``t``.

    The first type argument of ``t``, when present, is expanded one level
    deeper and used as the ``items`` schema; without type arguments ``items``
    is the empty schema.
    """
    type_args = get_args(t)
    element_schema: Dict[str, Any] = {}
    if type_args:
        element_schema = generate_schema(type_args[0], depth + 1, visited)
    return {"type": "array", "items": element_schema, "uniqueItems": True}
181
182
def _schema_datetime(
    t: Type[Any], depth: int, visited: Optional[set]
) -> Dict[str, Any]:
    """Return a ``string`` schema with ``format: date-time``; arguments are ignored."""
    schema: Dict[str, Any] = dict(type="string", format="date-time")
    return schema
187
188
def _schema_enum(t: Type[Any], depth: int, visited: Optional[set]) -> Dict[str, Any]:
    """Return an ``enum`` schema listing the ``value`` of every member of ``t``.

    Members are taken in the order produced by iterating ``t``; ``depth`` and
    ``visited`` are ignored.
    """
    allowed_values = []
    for member in t:
        allowed_values.append(member.value)
    return {"enum": allowed_values}
def generate_schema(
    t: Type[Any], depth: int = 0, visited: Optional[set] = None
) -> Dict[str, Any]:
    """Generates a JSON Schema for a given type.

    Resolution order:

    1. a handler registered in the current context for ``t`` itself;
    2. a handler registered for the generic origin of ``t``;
    3. the first registered handler whose key is a superclass of the origin;
    4. for classes flagged with ``_lodum_enabled``, an object schema built
       from the class's ``_lodum_fields``.

    Anything else produces the empty schema ``{}``.

    Args:
        t: The type to describe.
        depth: Current nesting level; each recursive call passes ``depth + 1``.
        visited: Lodum-enabled classes currently being expanded, used to
            detect recursive references. A fresh set is created when omitted.

    Returns:
        A dictionary holding the JSON Schema for ``t``. A class that is
        already being expanded higher up the call chain is emitted as a
        ``$ref`` to ``#/definitions/<sanitized class name>``.

    Raises:
        ValueError: If ``depth`` exceeds ``DEFAULT_MAX_DEPTH``.
    """
    if depth > DEFAULT_MAX_DEPTH:
        raise ValueError(
            f"Max recursion depth ({DEFAULT_MAX_DEPTH}) exceeded during schema generation"
        )

    if visited is None:
        visited = set()

    handlers = get_context().registry._handlers

    # Direct registry lookup
    if t in handlers:
        return handlers[t].schema_fn(t, depth, visited)

    origin = get_origin(t) or t

    # Generic lookup (exact match)
    if origin in handlers:
        return handlers[origin].schema_fn(t, depth, visited)

    # Inheritance lookup. Only classes can be subclasses of a registered key,
    # so the (loop-invariant) class check is done once up front.
    if inspect.isclass(origin):
        for super_t, h_obj in handlers.items():
            try:
                # Registry keys that are not classes make issubclass() raise
                # TypeError; those entries are skipped. NOTE: a TypeError
                # raised by the handler itself is skipped the same way.
                if issubclass(origin, super_t):
                    return h_obj.schema_fn(t, depth, visited)
            except TypeError:
                continue

    if inspect.isclass(t) and getattr(t, "_lodum_enabled", False):
        if t in visited:
            # Recursive reference
            return {"$ref": f"#/definitions/{_sanitize_name(t.__name__)}"}

        visited.add(t)
        try:
            fields: Dict[str, Field] = getattr(t, "_lodum_fields", {})
            properties = {}
            required = []
            for field_info in fields.values():
                key = field_info.rename if field_info.rename else field_info.name
                properties[key] = generate_schema(field_info.type, depth + 1, visited)
                if not field_info.has_default:
                    required.append(key)

            schema = {"type": "object", "properties": properties}

            tag_name = getattr(t, "_lodum_tag", None)
            if tag_name:
                # Tagged classes carry a constant discriminator property.
                tag_value = getattr(t, "_lodum_tag_value", t.__name__)
                properties[tag_name] = {"const": tag_value}
                if tag_name not in required:
                    required.append(tag_name)

            if required:
                schema["required"] = required

            return schema
        finally:
            # Always unmark the class, even when a nested field fails, so a
            # caller-supplied ``visited`` set is never left in a stale state.
            visited.discard(t)

    return {}

Generates a JSON Schema for a given type.