Coverage for src/dataknobs_data/query_logic.py: 25%
221 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Boolean logic support for complex queries."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from dataclasses import dataclass, field
7from enum import Enum
8from typing import TYPE_CHECKING, Any
10from .query import Filter, Operator, VectorQuery
12if TYPE_CHECKING:
13 import numpy as np
15 from .query import Query
16 from .vector.types import DistanceMetric
class LogicOperator(Enum):
    """Boolean connectives used to combine query conditions.

    Each member's value is the lowercase string used when a condition is
    serialized to a dictionary.
    """

    # Every child condition must match.
    AND = "and"
    # At least one child condition must match.
    OR = "or"
    # No child condition may match.
    NOT = "not"
class Condition(ABC):
    """Interface shared by every node of a query condition tree.

    Concrete subclasses must be able to test a record and to round-trip
    themselves through a plain dictionary.
    """

    @abstractmethod
    def matches(self, record: Any) -> bool:
        """Return whether ``record`` satisfies this condition."""

    @abstractmethod
    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary representation of this condition."""

    @classmethod
    @abstractmethod
    def from_dict(cls, data: dict[str, Any]) -> Condition:
        """Build a condition from its dictionary representation."""
@dataclass
class FilterCondition(Condition):
    """Leaf condition that wraps exactly one ``Filter``."""

    # The filter whose ``field`` names the value to read and whose
    # ``matches`` method decides the outcome.
    filter: Filter

    def matches(self, record: Any) -> bool:
        """Extract the filtered field from ``record`` and test it.

        ``Record`` instances are read through ``get_value``; dictionaries
        are walked along the dot-separated field path; any other object is
        read with ``getattr`` using the whole field name. A value that
        cannot be found is passed to the filter as ``None``.
        """
        from .records import Record

        field_name = self.filter.field

        if isinstance(record, Record):
            found = record.get_value(field_name)
        elif isinstance(record, dict):
            # Descend one level per path segment; give up as soon as the
            # current value is not a dict any more.
            found = record
            for segment in field_name.split('.'):
                if not isinstance(found, dict):
                    found = None
                    break
                found = found.get(segment)
        else:
            found = getattr(record, field_name, None)

        return self.filter.matches(found)

    def to_dict(self) -> dict[str, Any]:
        """Serialize as ``{"type": "filter", "filter": <filter dict>}``."""
        serialized_filter = self.filter.to_dict()
        return {"type": "filter", "filter": serialized_filter}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FilterCondition:
        """Rebuild the condition from the dictionary made by ``to_dict``."""
        parsed_filter = Filter.from_dict(data["filter"])
        return cls(filter=parsed_filter)
@dataclass
class LogicCondition(Condition):
    """A boolean combination (AND / OR / NOT) of child conditions."""

    # How the child conditions are combined.
    operator: LogicOperator
    # Child conditions; each may itself be a LogicCondition (nesting).
    conditions: list[Condition] = field(default_factory=list)

    def matches(self, record: Any) -> bool:
        """Evaluate the combination against ``record``.

        * AND: every child matches (an empty list matches everything).
        * OR: at least one child matches (an empty list matches nothing).
        * NOT: no child matches; with a single child this is plain negation.

        Raises:
            ValueError: If ``operator`` is not a known ``LogicOperator``.
        """
        if self.operator == LogicOperator.AND:
            return all(cond.matches(record) for cond in self.conditions)
        if self.operator == LogicOperator.OR:
            return any(cond.matches(record) for cond in self.conditions)
        if self.operator == LogicOperator.NOT:
            # "not any" covers both the single-child negation and the
            # multi-child "none of these" case.
            return not any(cond.matches(record) for cond in self.conditions)
        # Defensive: only reachable if operator is not a LogicOperator member.
        raise ValueError(f"Unknown logical operator: {self.operator}")

    def to_dict(self) -> dict[str, Any]:
        """Serialize as ``{"type": "logic", "operator": ..., "conditions": [...]}``."""
        return {
            "type": "logic",
            "operator": self.operator.value,
            "conditions": [cond.to_dict() for cond in self.conditions],
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> LogicCondition:
        """Rebuild the condition from the dictionary made by ``to_dict``.

        A missing ``"conditions"`` key is treated as an empty list.

        Raises:
            KeyError: If ``"operator"`` or a child's ``"type"`` is missing.
            ValueError: If the operator string is not a valid
                ``LogicOperator`` value, or a child has an unknown type.
        """
        children: list[Condition] = []
        for child_data in data.get("conditions", []):
            child_type = child_data["type"]
            if child_type == "filter":
                children.append(FilterCondition.from_dict(child_data))
            elif child_type == "logic":
                children.append(LogicCondition.from_dict(child_data))
            else:
                # Dropping an unrecognized child would silently change what
                # the query matches, so fail loudly instead.
                raise ValueError(f"Unknown condition type: {child_type}")

        return cls(
            operator=LogicOperator(data["operator"]),
            conditions=children,
        )
def condition_from_dict(data: dict[str, Any]) -> Condition:
    """Build the right ``Condition`` subclass for a serialized condition.

    The ``"type"`` entry selects the class: ``"filter"`` produces a
    ``FilterCondition`` and ``"logic"`` a ``LogicCondition``.

    Raises:
        KeyError: If ``data`` has no ``"type"`` entry.
        ValueError: If the type is neither ``"filter"`` nor ``"logic"``.
    """
    if data["type"] == "filter":
        return FilterCondition.from_dict(data)
    if data["type"] == "logic":
        return LogicCondition.from_dict(data)
    raise ValueError(f"Unknown condition type: {data['type']}")
class QueryBuilder:
    """Fluent builder for queries that combine filters with boolean logic.

    Every mutator returns ``self`` so calls can be chained. The condition
    tree is never modified in place: each call that adds a condition
    installs a fresh root node, so queries already produced by ``build()``
    (and other builders that embedded this builder's root) are unaffected
    by later calls.
    """

    def __init__(self):
        """Initialize empty query builder."""
        # Root of the condition tree; None means no filtering at all.
        self.root_condition: Condition | None = None
        # Sort specifications in the order they were added.
        self.sort_specs: list[Any] = []
        self.limit_value: int | None = None
        self.offset_value: int | None = None
        # Projected field names; None means no projection was requested.
        self.fields: list[str] | None = None
        self.vector_query: VectorQuery | None = None

    @staticmethod
    def _coerce_conditions(
        items: tuple[QueryBuilder | Filter | Condition, ...],
    ) -> list[Condition]:
        """Turn builders, filters and conditions into a list of conditions.

        Builders without a root condition and arguments of any other type
        contribute nothing.
        """
        collected: list[Condition] = []
        for item in items:
            if isinstance(item, QueryBuilder):
                if item.root_condition is not None:
                    collected.append(item.root_condition)
            elif isinstance(item, Filter):
                collected.append(FilterCondition(item))
            elif isinstance(item, Condition):
                collected.append(item)
        return collected

    def _and_with_root(self, new_conditions: list[Condition], nested: Condition) -> None:
        """AND new material onto the current root without mutating it.

        ``new_conditions`` are merged into the root when the root is already
        an AND node; otherwise ``nested`` is placed next to the old root
        under a new AND node. With no root yet, ``nested`` becomes the root.
        """
        root = self.root_condition
        if root is None:
            self.root_condition = nested
        elif isinstance(root, LogicCondition) and root.operator == LogicOperator.AND:
            # Build a new node rather than appending to root.conditions,
            # because the old node may already be referenced elsewhere.
            self.root_condition = LogicCondition(
                operator=LogicOperator.AND,
                conditions=[*root.conditions, *new_conditions],
            )
        else:
            self.root_condition = LogicCondition(
                operator=LogicOperator.AND,
                conditions=[root, nested],
            )

    def where(self, field: str, operator: str | Operator, value: Any = None) -> QueryBuilder:
        """Add a filter condition (defaults to AND with existing conditions).

        Args:
            field: Name of the field to filter on.
            operator: An ``Operator`` member, or a string that is passed to
                ``Operator(...)`` to obtain one.
            value: Value handed to the filter.

        Returns:
            This builder, for chaining.
        """
        op = Operator(operator) if isinstance(operator, str) else operator
        filter_cond = FilterCondition(Filter(field, op, value))
        self._and_with_root([filter_cond], filter_cond)
        return self

    def and_(self, *conditions: QueryBuilder | Filter | Condition) -> QueryBuilder:
        """Add AND conditions.

        Each argument may be another builder (its root condition is used),
        a ``Filter`` or a ``Condition``. If the current root is already an
        AND node the new conditions are merged into it; otherwise the old
        root and an AND node holding the new conditions are combined under
        a new AND node.

        Returns:
            This builder, for chaining.
        """
        added = self._coerce_conditions(conditions)
        group = LogicCondition(operator=LogicOperator.AND, conditions=added)
        self._and_with_root(added, group)
        return self

    def or_(self, *conditions: QueryBuilder | Filter | Condition) -> QueryBuilder:
        """Add OR conditions.

        Accepts the same argument kinds as ``and_``. The current root (if
        any) becomes one alternative alongside the new conditions; an
        existing OR root is extended instead of being nested.

        Returns:
            This builder, for chaining.
        """
        added = self._coerce_conditions(conditions)
        root = self.root_condition

        if root is None:
            alternatives = added
        elif isinstance(root, LogicCondition) and root.operator == LogicOperator.OR:
            alternatives = [*root.conditions, *added]
        else:
            alternatives = [root, *added]

        # Always install a fresh OR node so earlier roots stay untouched.
        self.root_condition = LogicCondition(
            operator=LogicOperator.OR,
            conditions=alternatives,
        )
        return self

    def not_(self, condition: QueryBuilder | Filter | Condition) -> QueryBuilder:
        """Add NOT condition.

        The negated condition is ANDed with whatever the builder already
        holds. A builder argument without a root condition yields a NOT
        node with no children.

        Returns:
            This builder, for chaining.
        """
        if isinstance(condition, QueryBuilder):
            inner = [condition.root_condition] if condition.root_condition is not None else []
        elif isinstance(condition, Filter):
            inner = [FilterCondition(condition)]
        else:
            inner = [condition]

        not_cond = LogicCondition(operator=LogicOperator.NOT, conditions=inner)
        self._and_with_root([not_cond], not_cond)
        return self

    def sort_by(self, field: str, order: str = "asc") -> QueryBuilder:
        """Add sort specification.

        Args:
            field: Name of the field to sort on.
            order: ``"asc"`` (case-insensitive) for ascending; any other
                value is treated as descending.

        Returns:
            This builder, for chaining.
        """
        from .query import SortOrder, SortSpec

        sort_order = SortOrder.ASC if order.lower() == "asc" else SortOrder.DESC
        self.sort_specs.append(SortSpec(field=field, order=sort_order))
        return self

    def limit(self, value: int) -> QueryBuilder:
        """Set result limit."""
        self.limit_value = value
        return self

    def offset(self, value: int) -> QueryBuilder:
        """Set result offset."""
        self.offset_value = value
        return self

    def select(self, *fields: str) -> QueryBuilder:
        """Set field projection; calling with no names clears it."""
        self.fields = list(fields) if fields else None
        return self

    def similar_to(
        self,
        vector: np.ndarray | list[float],
        field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> QueryBuilder:
        """Add vector similarity search.

        All arguments are forwarded to ``VectorQuery`` (``field`` is passed
        as ``field_name``). If no limit has been set yet, ``k`` is also
        used as the result limit.

        Returns:
            This builder, for chaining.
        """
        self.vector_query = VectorQuery(
            vector=vector,
            field_name=field,
            k=k,
            metric=metric,
            include_source=include_source,
            score_threshold=score_threshold,
        )
        # If limit is not set, use k as the limit
        if self.limit_value is None:
            self.limit_value = k
        return self

    def build(self) -> ComplexQuery:
        """Build the final query.

        The sort and field lists are copied so that further use of this
        builder does not alter the returned query.
        """
        return ComplexQuery(
            condition=self.root_condition,
            sort_specs=list(self.sort_specs),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=list(self.fields) if self.fields is not None else None,
            vector_query=self.vector_query,
        )
@dataclass
class ComplexQuery:
    """A query with complex boolean logic support.

    Bundles an optional condition tree with sorting, paging, projection
    and an optional vector similarity search.
    """

    # All fields have defaults to avoid ordering issues
    # Root of the condition tree; None matches every record.
    condition: Condition | None = None
    sort_specs: list = field(default_factory=list)
    limit_value: int | None = None
    offset_value: int | None = None
    fields: list[str] | None = None
    vector_query: VectorQuery | None = None  # Vector similarity search

    @classmethod
    def AND(cls, queries: list[Query]) -> ComplexQuery:
        """Create a complex query with AND logic.

        The filters of every ``Query`` in ``queries`` are collected into a
        single AND condition. Items that are not ``Query`` instances are
        skipped. Only filters are taken from the queries.
        """
        from .query import Query

        conditions: list[Condition] = []
        for q in queries:
            if isinstance(q, Query):
                conditions.extend(FilterCondition(filter=f) for f in q.filters)

        return cls(
            condition=LogicCondition(operator=LogicOperator.AND, conditions=conditions)
        )

    @classmethod
    def OR(cls, queries: list[Query]) -> ComplexQuery:
        """Create a complex query with OR logic.

        Each ``Query`` becomes one alternative. A query with several
        filters contributes them as an AND group, so the result is
        ``(q1 filters ANDed) OR (q2 filters ANDed) ...``. Items that are
        not ``Query`` instances, and queries without filters, are skipped.
        Only filters are taken from the queries.
        """
        from .query import Query

        alternatives: list[Condition] = []
        for q in queries:
            if not isinstance(q, Query):
                continue
            per_query: list[Condition] = [FilterCondition(filter=f) for f in q.filters]
            if len(per_query) == 1:
                alternatives.append(per_query[0])
            elif per_query:
                # Filters inside one query apply together; flattening them
                # into the OR would wrongly turn "a AND b" into "a OR b".
                alternatives.append(
                    LogicCondition(operator=LogicOperator.AND, conditions=per_query)
                )

        return cls(
            condition=LogicCondition(operator=LogicOperator.OR, conditions=alternatives)
        )

    def matches(self, record: Any) -> bool:
        """Check if a record matches this query (always true without a condition)."""
        if self.condition is None:
            return True
        return self.condition.matches(record)

    def to_simple_query(self) -> Query:
        """Convert to simple Query if possible (AND filters only).

        Conversion succeeds when there is no condition, a single filter
        condition, or an AND node whose children are all filter conditions.

        Raises:
            ValueError: If the condition tree uses any other shape.
        """
        from .query import Query

        filters = []
        cond = self.condition

        if cond is None:
            pass
        elif isinstance(cond, FilterCondition):
            filters.append(cond.filter)
        elif isinstance(cond, LogicCondition) and cond.operator == LogicOperator.AND:
            for child in cond.conditions:
                if not isinstance(child, FilterCondition):
                    # Can't convert complex logic to simple query
                    raise ValueError("Cannot convert complex boolean logic to simple Query")
                filters.append(child.filter)
        else:
            raise ValueError("Cannot convert complex boolean logic to simple Query")

        return Query(
            filters=filters,
            sort_specs=self.sort_specs,
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=self.fields,
            vector_query=self.vector_query,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation.

        Only parts that are set appear in the result: ``condition``,
        ``sort`` (omitted when empty), ``limit``, ``offset``, ``fields``
        and ``vector_query``.
        """
        result: dict[str, Any] = {}

        # Test against None, like matches(), so a condition object is never
        # dropped merely because it evaluates as falsy.
        if self.condition is not None:
            result["condition"] = self.condition.to_dict()

        if self.sort_specs:
            result["sort"] = [s.to_dict() for s in self.sort_specs]

        if self.limit_value is not None:
            result["limit"] = self.limit_value

        if self.offset_value is not None:
            result["offset"] = self.offset_value

        if self.fields is not None:
            result["fields"] = self.fields

        if self.vector_query is not None:
            result["vector_query"] = self.vector_query.to_dict()

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ComplexQuery:
        """Create from dictionary representation (the format made by ``to_dict``).

        Missing keys fall back to: no condition, no sort specs, no limit,
        no offset, no field projection, no vector query.
        """
        from .query import SortSpec

        condition = None
        if "condition" in data:
            condition = condition_from_dict(data["condition"])

        sort_specs = [SortSpec.from_dict(sort_data) for sort_data in data.get("sort", [])]

        vector_query = None
        if "vector_query" in data:
            vector_query = VectorQuery.from_dict(data["vector_query"])

        return cls(
            condition=condition,
            sort_specs=sort_specs,
            limit_value=data.get("limit"),
            offset_value=data.get("offset"),
            fields=data.get("fields"),
            vector_query=vector_query,
        )