Coverage for src/dataknobs_data/schema.py: 35%
91 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
1"""Database schema definitions for field structures."""
3from __future__ import annotations
5from dataclasses import dataclass, field
6from typing import Any
8from .fields import FieldType
@dataclass
class FieldSchema:
    """Schema definition for a single field: its name, type and options, but no data.

    Vector fields (``FieldType.VECTOR`` and ``FieldType.SPARSE_VECTOR``) may record
    their size and origin in ``metadata`` under the ``"dimensions"`` and
    ``"source_field"`` keys.
    """

    # Field name.
    name: str
    # Field type; serialized by to_dict() as ``type.value``.
    type: FieldType
    # Free-form per-field options (e.g. vector "dimensions" / "source_field").
    metadata: dict[str, Any] = field(default_factory=dict)
    # Whether the field is declared as required.
    required: bool = False
    # Default value recorded in the schema for this field.
    default: Any = None

    def is_vector_field(self) -> bool:
        """Return True if this field is a dense or sparse vector field."""
        return self.type in (FieldType.VECTOR, FieldType.SPARSE_VECTOR)

    def get_dimensions(self) -> int | None:
        """Return ``metadata["dimensions"]`` for a vector field.

        Returns None for non-vector fields, and for vector fields that have no
        ``"dimensions"`` entry.
        """
        if self.is_vector_field():
            return self.metadata.get("dimensions")
        return None

    def get_source_field(self) -> str | None:
        """Return ``metadata["source_field"]`` for a derived vector field.

        Returns None for non-vector fields, and for vector fields that have no
        ``"source_field"`` entry.
        """
        if self.is_vector_field():
            return self.metadata.get("source_field")
        return None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dictionary accepted by ``from_dict``.

        ``metadata`` is shallow-copied so that mutating the returned dict does
        not silently modify this schema.
        """
        return {
            "name": self.name,
            "type": self.type.value,
            "metadata": dict(self.metadata),
            "required": self.required,
            "default": self.default,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FieldSchema:
        """Create a FieldSchema from its dictionary representation.

        Args:
            data: Must contain ``"name"`` and ``"type"`` (a value accepted by
                ``FieldType``). ``"metadata"``, ``"required"`` and ``"default"``
                are optional and default to ``{}``, ``False`` and ``None``.

        Returns:
            A new FieldSchema. A missing or ``None`` ``"metadata"`` becomes an
            empty dict; otherwise it is shallow-copied rather than shared with
            ``data``.

        Raises:
            KeyError: If ``"name"`` or ``"type"`` is missing. An invalid
                ``"type"`` value is rejected by ``FieldType`` itself.
        """
        raw_metadata = data.get("metadata")
        return cls(
            name=data["name"],
            type=FieldType(data["type"]),
            # Copy so the schema never shares a mutable dict with the caller;
            # an explicit None (e.g. JSON null) is treated as "no metadata".
            metadata=dict(raw_metadata) if raw_metadata else {},
            required=data.get("required", False),
            default=data.get("default"),
        )
@dataclass
class DatabaseSchema:
    """Schema definition for a database.

    A mapping of field names to ``FieldSchema`` objects plus schema-level
    metadata. Fields can be registered one at a time (``add_field`` and the
    typed ``add_*_field`` helpers) or built in bulk via ``create`` or
    ``from_dict``.
    """

    # Field name -> FieldSchema. A plain dict, so iteration follows insertion order.
    fields: dict[str, FieldSchema] = field(default_factory=dict)
    # Free-form schema-level metadata; serialized by to_dict() and read back by from_dict().
    metadata: dict[str, Any] = field(default_factory=dict)
66 @classmethod
67 def create(cls, **field_definitions) -> DatabaseSchema:
68 """Create a schema from keyword arguments.
70 Examples:
71 schema = DatabaseSchema.create(
72 content=FieldType.TEXT,
73 embedding=(FieldType.VECTOR, {"dimensions": 384, "source_field": "content"}),
74 title=FieldType.TEXT,
75 score=(FieldType.FLOAT, {"required": True})
76 )
77 """
78 schema = cls()
79 for name, definition in field_definitions.items():
80 if isinstance(definition, FieldType):
81 # Simple field type
82 schema.add_field(FieldSchema(name=name, type=definition))
83 elif isinstance(definition, tuple):
84 # Field type with metadata/options
85 field_type, options = definition
86 field_metadata = options.get("metadata", {})
87 if "dimensions" in options:
88 field_metadata["dimensions"] = options["dimensions"]
89 if "source_field" in options:
90 field_metadata["source_field"] = options["source_field"]
92 schema.add_field(FieldSchema(
93 name=name,
94 type=field_type,
95 metadata=field_metadata,
96 required=options.get("required", False),
97 default=options.get("default")
98 ))
99 else:
100 raise ValueError(f"Invalid field definition for {name}: {definition}")
101 return schema
103 def add_field(self, field_schema: FieldSchema) -> DatabaseSchema:
104 """Add a field to the schema.
106 Returns self for chaining.
107 """
108 self.fields[field_schema.name] = field_schema
109 return self
111 def add_text_field(self, name: str, required: bool = False) -> DatabaseSchema:
112 """Add a text field to the schema."""
113 return self.add_field(FieldSchema(name=name, type=FieldType.TEXT, required=required))
115 def add_vector_field(
116 self,
117 name: str,
118 dimensions: int,
119 source_field: str | None = None,
120 required: bool = False
121 ) -> DatabaseSchema:
122 """Add a vector field to the schema."""
123 return self.add_field(FieldSchema(
124 name=name,
125 type=FieldType.VECTOR,
126 metadata={"dimensions": dimensions, "source_field": source_field},
127 required=required
128 ))
130 def remove_field(self, name: str) -> bool:
131 """Remove a field from the schema."""
132 if name in self.fields:
133 del self.fields[name]
134 return True
135 return False
137 def get_vector_fields(self) -> dict[str, FieldSchema]:
138 """Get all vector fields in the schema."""
139 return {
140 name: field
141 for name, field in self.fields.items()
142 if field.is_vector_field()
143 }
145 def get_source_fields(self) -> dict[str, list[str]]:
146 """Get mapping of source fields to their dependent vector fields."""
147 source_map = {}
148 for name, field_obj in self.fields.items():
149 if field_obj.is_vector_field():
150 source = field_obj.get_source_field()
151 if source:
152 if source not in source_map:
153 source_map[source] = []
154 source_map[source].append(name)
155 return source_map
157 def to_dict(self) -> dict[str, Any]:
158 """Convert to dictionary representation."""
159 return {
160 "fields": {name: f.to_dict() for name, f in self.fields.items()},
161 "metadata": self.metadata,
162 }
164 @classmethod
165 def from_dict(cls, data: dict[str, Any]) -> DatabaseSchema:
166 """Create from dictionary representation.
168 Supports multiple formats:
169 1. Full format with FieldSchema dicts
170 2. Simple format with just field types
171 3. Mixed format
173 Examples:
174 # Simple format
175 {"fields": {"content": "text", "score": "float"}}
177 # Full format
178 {"fields": {"content": {"type": "text", "required": true}}}
180 # Vector fields
181 {"fields": {"embedding": {"type": "vector", "dimensions": 384}}}
182 """
183 schema = cls(metadata=data.get("metadata", {}))
185 for name, field_data in data.get("fields", {}).items():
186 if isinstance(field_data, str):
187 # Simple string type
188 schema.fields[name] = FieldSchema(
189 name=name,
190 type=FieldType(field_data)
191 )
192 elif isinstance(field_data, dict):
193 if "type" in field_data:
194 # Full field schema dict
195 field_type = FieldType(field_data["type"])
196 metadata = {}
198 # Handle vector-specific fields
199 if "dimensions" in field_data:
200 metadata["dimensions"] = field_data["dimensions"]
201 if "source_field" in field_data:
202 metadata["source_field"] = field_data["source_field"]
204 # Merge with explicit metadata
205 if "metadata" in field_data:
206 metadata.update(field_data["metadata"])
208 schema.fields[name] = FieldSchema(
209 name=name,
210 type=field_type,
211 metadata=metadata,
212 required=field_data.get("required", False),
213 default=field_data.get("default")
214 )
215 else:
216 # Try to parse as FieldSchema dict
217 schema.fields[name] = FieldSchema.from_dict(field_data)
219 return schema