Coverage for src/dataknobs_data/query.py: 30%
314 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:15 -0600
1from __future__ import annotations
3from dataclasses import dataclass, field
4from enum import Enum
5from typing import TYPE_CHECKING, Any
7if TYPE_CHECKING:
8 from collections.abc import Callable
10 import numpy as np
12 from .query_logic import ComplexQuery
13 from .vector.types import DistanceMetric
class Operator(Enum):
    """Query operators for filtering."""

    EQ = "="  # Equal
    NEQ = "!="  # Not equal
    GT = ">"  # Greater than
    GTE = ">="  # Greater than or equal
    LT = "<"  # Less than
    LTE = "<="  # Less than or equal
    IN = "in"  # In list
    NOT_IN = "not_in"  # Not in list
    LIKE = "like"  # SQL LIKE pattern: % = any run of chars, _ = exactly one char
    NOT_LIKE = "not_like"  # Negated SQL LIKE pattern
    REGEX = "regex"  # Regular expression matching (re.search semantics)
    EXISTS = "exists"  # Field exists (value is not None)
    NOT_EXISTS = "not_exists"  # Field does not exist (value is None)
    BETWEEN = "between"  # Value between two bounds (inclusive)
    NOT_BETWEEN = "not_between"  # Value not between two bounds


class SortOrder(Enum):
    """Sort order for query results."""

    ASC = "asc"
    DESC = "desc"


@dataclass
class Filter:
    """Represents a single filter condition: ``<field> <operator> <value>``."""

    field: str  # Name of the record field the condition applies to
    operator: Operator  # Comparison to perform
    # Operand: a 2-item list/tuple for BETWEEN/NOT_BETWEEN, a container for
    # IN/NOT_IN, a SQL LIKE pattern for LIKE/NOT_LIKE, a regex for REGEX.
    value: Any = None

    def matches(self, record_value: Any) -> bool:
        """Check whether ``record_value`` satisfies this filter.

        ``None`` stands for "field missing". It satisfies ``NOT_EXISTS`` and
        fails every other operator.

        Range operators (GT/GTE/LT/LTE/BETWEEN/NOT_BETWEEN) go through
        :meth:`_compare_values`. That method parses ISO-8601 strings when they
        are compared with dates, and treats incomparable types as a non-match.

        LIKE/NOT_LIKE treat every character except ``%`` and ``_`` literally.
        The pattern must match the whole string.

        Args:
            record_value: The value of ``self.field`` taken from a record.

        Returns:
            True if the value satisfies the condition.

        Raises:
            ValueError: If ``self.operator`` is not a known operator.
        """
        op = self.operator
        if op == Operator.EXISTS:
            return record_value is not None
        if op == Operator.NOT_EXISTS:
            return record_value is None
        if record_value is None:
            # A missing value never satisfies a value-based condition.
            return False

        if op == Operator.EQ:
            return record_value == self.value
        if op == Operator.NEQ:
            return record_value != self.value
        if op == Operator.GT:
            return self._compare_values(record_value, self.value, lambda a, b: a > b)
        if op == Operator.GTE:
            return self._compare_values(record_value, self.value, lambda a, b: a >= b)
        if op == Operator.LT:
            return self._compare_values(record_value, self.value, lambda a, b: a < b)
        if op == Operator.LTE:
            return self._compare_values(record_value, self.value, lambda a, b: a <= b)
        if op == Operator.IN:
            return record_value in self.value
        if op == Operator.NOT_IN:
            return record_value not in self.value
        if op in (Operator.BETWEEN, Operator.NOT_BETWEEN):
            if not isinstance(self.value, (list, tuple)) or len(self.value) != 2:
                # Malformed bounds: nothing counts as "between" them.
                return op == Operator.NOT_BETWEEN
            lower, upper = self.value
            inside = self._compare_values(
                record_value, lower, lambda a, b: a >= b
            ) and self._compare_values(record_value, upper, lambda a, b: a <= b)
            return inside if op == Operator.BETWEEN else not inside
        if op in (Operator.LIKE, Operator.NOT_LIKE):
            if not isinstance(record_value, str):
                return False
            hit = self._like_to_regex(self.value).fullmatch(record_value) is not None
            return hit if op == Operator.LIKE else not hit
        if op == Operator.REGEX:
            if not isinstance(record_value, str):
                return False
            import re

            return re.search(self.value, record_value) is not None
        # Unreachable while every Operator member is handled above.
        raise ValueError(f"Unknown operator: {self.operator}")

    @staticmethod
    def _like_to_regex(pattern: str):
        """Compile a SQL LIKE pattern into an equivalent regular expression.

        ``%`` becomes ``.*`` and ``_`` becomes ``.``. Every other character is
        escaped so that regex metacharacters in the pattern (``.``, ``(``,
        ``+``...) match literally. ``re.DOTALL`` lets wildcards span newlines,
        as in SQL.

        Returns:
            A compiled ``re.Pattern``; use ``fullmatch`` for whole-string matching.
        """
        import re

        parts = []
        for ch in pattern:
            if ch == "%":
                parts.append(".*")
            elif ch == "_":
                parts.append(".")
            else:
                parts.append(re.escape(ch))
        return re.compile("".join(parts), re.DOTALL)

    def _compare_values(
        self, a: Any, b: Any, comparator: Callable[[Any, Any], bool]
    ) -> bool:
        """Compare two values with type awareness.

        Special cases:

        - A string compared with a ``date``/``datetime`` is parsed as ISO-8601.
          If parsing fails, the comparison is False.
        - Two strings are compared as datetimes only if ``a`` looks like a date
          (contains "T" or "-") AND both parse. Otherwise they are compared as
          plain, case-sensitive strings.
        - If one side is a ``datetime`` and the other a plain ``date``, the
          comparison is done at day granularity. The two types are not
          mutually comparable in Python.
        - Incomparable types (``TypeError``) compare as False.
        """
        from datetime import date, datetime

        def parse_iso(text: str) -> datetime:
            # fromisoformat() accepts a trailing "Z" only from Python 3.11 on.
            return datetime.fromisoformat(text.replace("Z", "+00:00"))

        if isinstance(a, str) and isinstance(b, date):
            try:
                a = parse_iso(a)
            except ValueError:
                return False
        elif isinstance(b, str) and isinstance(a, date):
            try:
                b = parse_iso(b)
            except ValueError:
                return False
        elif isinstance(a, str) and isinstance(b, str) and ("T" in a or "-" in a):
            try:
                parsed_a, parsed_b = parse_iso(a), parse_iso(b)
            except ValueError:
                pass  # Not both ISO dates: keep comparing the original strings.
            else:
                a, b = parsed_a, parsed_b

        # datetime subclasses date, yet comparing the two raises TypeError.
        if isinstance(a, datetime) and isinstance(b, date) and not isinstance(b, datetime):
            a = a.date()
        elif isinstance(b, datetime) and isinstance(a, date) and not isinstance(a, datetime):
            b = b.date()

        try:
            return comparator(a, b)
        except TypeError:
            # Types not comparable
            return False

    def to_dict(self) -> dict[str, Any]:
        """Convert filter to dictionary representation."""
        return {"field": self.field, "operator": self.operator.value, "value": self.value}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Filter:
        """Create filter from dictionary representation (inverse of :meth:`to_dict`)."""
        return cls(
            field=data["field"], operator=Operator(data["operator"]), value=data.get("value")
        )
@dataclass
class SortSpec:
    """One sort term: the field to order by and the direction to order in."""

    field: str  # Name of the field to sort on
    order: SortOrder = SortOrder.ASC  # Direction; ascending unless specified

    def to_dict(self) -> dict[str, str]:
        """Serialize as ``{"field": <name>, "order": "asc" | "desc"}``."""
        direction = self.order.value
        return {"field": self.field, "order": direction}

    @classmethod
    def from_dict(cls, data: dict[str, str]) -> SortSpec:
        """Rebuild a sort spec from :meth:`to_dict` output.

        A missing ``"order"`` key means ascending.
        """
        name = data["field"]
        direction = SortOrder(data.get("order", "asc"))
        return cls(field=name, order=direction)
@dataclass
class VectorQuery:
    """Represents a vector similarity search query.

    Bundles the query vector, the field to search, the distance metric, and
    search options (top-k, score threshold, reranking, free-form metadata).
    """

    vector: np.ndarray | list[float]  # Query vector or embeddings
    field_name: str = "embedding"  # Vector field name to search
    k: int = 10  # Number of results (top-k)
    metric: DistanceMetric | str = "cosine"  # Distance metric (enum or raw string)
    include_source: bool = True  # Include source text in results
    score_threshold: float | None = None  # Minimum similarity score
    rerank: bool = False  # Whether to rerank results
    rerank_k: int | None = None  # Number of results to rerank (default: 2*k)
    metadata: dict[str, Any] = field(default_factory=dict)  # Additional metadata

    def to_dict(self) -> dict[str, Any]:
        """Convert vector query to a plain dictionary representation.

        NumPy vectors become lists. Enum metrics are replaced by their
        ``.value``. Optional keys (``score_threshold``, ``rerank``,
        ``rerank_k``, ``metadata``) appear only when set. The vector list and
        the metadata dict are copies, so mutating the result does not mutate
        this query.
        """
        import numpy as np

        vector_data = self.vector
        if isinstance(vector_data, np.ndarray):
            vector_data = vector_data.tolist()
        elif isinstance(vector_data, list):
            vector_data = list(vector_data)  # avoid aliasing the caller's list

        # DistanceMetric enum -> its value; raw strings pass through unchanged.
        metric_value = getattr(self.metric, "value", self.metric)

        result: dict[str, Any] = {
            "vector": vector_data,
            "field": self.field_name,
            "k": self.k,
            "metric": metric_value,
            "include_source": self.include_source,
        }

        if self.score_threshold is not None:
            result["score_threshold"] = self.score_threshold
        if self.rerank:
            result["rerank"] = self.rerank
        if self.rerank_k is not None:
            result["rerank_k"] = self.rerank_k
        if self.metadata:
            result["metadata"] = dict(self.metadata)

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> VectorQuery:
        """Create vector query from dictionary representation.

        The vector is converted to a ``float32`` NumPy array unless it already
        is an ndarray. ``metric`` becomes a ``DistanceMetric`` when it names a
        valid member; otherwise it is kept as a string. ``metadata`` is copied,
        and a missing or ``None`` value becomes an empty dict.
        """
        import numpy as np

        from .vector.types import DistanceMetric

        vector_data = data["vector"]
        if not isinstance(vector_data, np.ndarray):
            vector_data = np.array(vector_data, dtype=np.float32)

        metric_value = data.get("metric", "cosine")
        if isinstance(metric_value, str):
            try:
                metric_value = DistanceMetric(metric_value)
            except ValueError:
                # Keep as string if not a valid enum value
                pass

        return cls(
            vector=vector_data,
            field_name=data.get("field", "embedding"),
            k=data.get("k", 10),
            metric=metric_value,
            include_source=data.get("include_source", True),
            score_threshold=data.get("score_threshold"),
            rerank=data.get("rerank", False),
            rerank_k=data.get("rerank_k"),
            metadata=dict(data.get("metadata") or {}),
        )
@dataclass
class Query:
    """Represents a database query with filters, sorting, pagination, and vector search.

    All entries of ``filters`` are combined with AND. Most builder methods
    mutate the query and return ``self`` for chaining. :meth:`or_` and
    :meth:`not_` instead return a new ``ComplexQuery``.
    """

    filters: list[Filter] = field(default_factory=list)  # ANDed filter conditions
    sort_specs: list[SortSpec] = field(default_factory=list)  # Sort terms, in order
    limit_value: int | None = None  # Maximum number of results; None = unlimited
    offset_value: int | None = None  # Number of results to skip; None = none
    fields: list[str] | None = None  # Field projection; None = all fields
    vector_query: VectorQuery | None = None  # Vector similarity search

    @property
    def sort_property(self) -> list[SortSpec]:
        """Get sort specifications (backward compatibility)."""
        return self.sort_specs

    @property
    def limit_property(self) -> int | None:
        """Get limit value (backward compatibility)."""
        return self.limit_value

    @property
    def offset_property(self) -> int | None:
        """Get offset value (backward compatibility)."""
        return self.offset_value

    @staticmethod
    def _resolve_operator(operator: str) -> Operator:
        """Map a user-supplied operator string to an :class:`Operator`.

        Accepts any ``Operator`` value (``"="``, ``">="``, ``"not_like"``, ...),
        the alias ``"=="``, and case/whitespace variants such as ``"NOT IN"``
        or ``"Not Like"``. Unrecognized strings fall back to ``Operator.EQ``,
        which preserves the historical behavior of :meth:`filter`.
        """
        if operator == "==":
            return Operator.EQ
        # Try the string verbatim, then lower-cased with whitespace runs -> "_".
        for candidate in (operator, "_".join(operator.lower().split())):
            try:
                return Operator(candidate)
            except ValueError:
                continue
        # NOTE(review): treating an unknown operator as equality keeps old
        # behavior but can hide typos; consider raising ValueError instead.
        return Operator.EQ

    def filter(self, field: str, operator: str | Operator, value: Any = None) -> Query:
        """Add a filter to the query (fluent interface).

        Args:
            field: The field name to filter on
            operator: The operator, as an Operator enum or a string. Strings may
                be any Operator value (e.g. ``">="``, ``"not_like"``), ``"=="``,
                or a case/space variant such as ``"NOT IN"``.
            value: The value to compare against

        Returns:
            Self for method chaining
        """
        if isinstance(operator, str):
            operator = self._resolve_operator(operator)

        self.filters.append(Filter(field=field, operator=operator, value=value))
        return self

    def sort_by(self, field: str, order: str | SortOrder = "asc") -> Query:
        """Add a sort specification to the query (fluent interface).

        Args:
            field: The field name to sort by
            order: A SortOrder enum, or a string. ``"asc"`` (case-insensitive,
                surrounding whitespace ignored) means ascending; any other
                string means descending.

        Returns:
            Self for method chaining
        """
        if isinstance(order, str):
            order = SortOrder.ASC if order.strip().lower() == "asc" else SortOrder.DESC

        self.sort_specs.append(SortSpec(field=field, order=order))
        return self

    def sort(self, field: str, order: str | SortOrder = "asc") -> Query:
        """Add sorting (fluent interface); alias of :meth:`sort_by`."""
        return self.sort_by(field, order)

    def set_limit(self, limit: int) -> Query:
        """Set the result limit (fluent interface).

        Args:
            limit: Maximum number of results

        Returns:
            Self for method chaining
        """
        self.limit_value = limit
        return self

    def limit(self, value: int) -> Query:
        """Set limit (fluent interface); alias of :meth:`set_limit`."""
        return self.set_limit(value)

    def set_offset(self, offset: int) -> Query:
        """Set the result offset (fluent interface).

        Args:
            offset: Number of results to skip

        Returns:
            Self for method chaining
        """
        self.offset_value = offset
        return self

    def offset(self, value: int) -> Query:
        """Set offset (fluent interface); alias of :meth:`set_offset`."""
        return self.set_offset(value)

    def select(self, *fields: str) -> Query:
        """Set field projection (fluent interface).

        Args:
            fields: Field names to include in results. Passing none clears the
                projection, so all fields are returned.

        Returns:
            Self for method chaining
        """
        self.fields = list(fields) if fields else None
        return self

    def clear_filters(self) -> Query:
        """Clear all filters (fluent interface)."""
        self.filters = []
        return self

    def clear_sort(self) -> Query:
        """Clear all sort specifications (fluent interface)."""
        self.sort_specs = []
        return self

    def similar_to(
        self,
        vector: np.ndarray | list[float],
        field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> Query:
        """Add vector similarity search to the query.

        Replaces any existing vector query, and always sets the result limit
        to ``k``.

        Args:
            vector: Query vector to search for similar vectors
            field: Vector field name to search (default: "embedding")
            k: Number of results to return (default: 10)
            metric: Distance metric to use (default: "cosine")
            include_source: Whether to include source text in results (default: True)
            score_threshold: Minimum similarity score threshold (optional)

        Returns:
            Self for method chaining
        """
        self.vector_query = VectorQuery(
            vector=vector,
            field_name=field,
            k=k,
            metric=metric,
            include_source=include_source,
            score_threshold=score_threshold,
        )
        # Always update limit to match k
        self.limit_value = k
        return self

    def near_text(
        self,
        text: str,
        embedding_fn: Callable[[str], np.ndarray],
        field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> Query:
        """Add text-based vector similarity search to the query.

        Embeds ``text`` with ``embedding_fn`` and delegates to :meth:`similar_to`.

        Args:
            text: Text to convert to vector for similarity search
            embedding_fn: Function to convert text to vector
            field: Vector field name to search (default: "embedding")
            k: Number of results to return (default: 10)
            metric: Distance metric to use (default: "cosine")
            include_source: Whether to include source text in results (default: True)
            score_threshold: Minimum similarity score threshold (optional)

        Returns:
            Self for method chaining
        """
        vector = embedding_fn(text)
        return self.similar_to(
            vector=vector,
            field=field,
            k=k,
            metric=metric,
            include_source=include_source,
            score_threshold=score_threshold,
        )

    def hybrid(
        self,
        text_query: str | None = None,
        vector: np.ndarray | list[float] | None = None,
        text_field: str = "content",
        vector_field: str = "embedding",
        alpha: float = 0.5,
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
    ) -> Query:
        """Create a hybrid query combining text and vector search.

        The text part is a LIKE filter ``%<text_query>%`` on ``text_field``.
        Any ``%`` or ``_`` inside ``text_query`` therefore acts as a wildcard.
        The vector part replaces any existing vector query and records
        ``alpha`` in its metadata under ``"hybrid_alpha"`` for the backend.

        Args:
            text_query: Text to search for (optional)
            vector: Vector for similarity search (optional)
            text_field: Field for text search (default: "content")
            vector_field: Field for vector search (default: "embedding")
            alpha: Weight balance between text (0.0) and vector (1.0) search (default: 0.5)
            k: Number of results to return (default: 10)
            metric: Distance metric for vector search (default: "cosine")

        Returns:
            Self for method chaining

        Note:
            - alpha=0.0 gives full weight to text search
            - alpha=1.0 gives full weight to vector search
            - alpha=0.5 gives equal weight to both
        """
        if text_query:
            self.filter(text_field, Operator.LIKE, f"%{text_query}%")

        if vector is not None:
            self.vector_query = VectorQuery(
                vector=vector,
                field_name=vector_field,
                k=k,
                metric=metric,
                include_source=True,
            )
            # Store alpha in vector query metadata for backend to use
            self.vector_query.metadata = {"hybrid_alpha": alpha}

        # Unlike similar_to(), an explicit limit is not overridden.
        if self.limit_value is None:
            self.limit_value = k

        return self

    def with_reranking(self, rerank_k: int | None = None) -> Query:
        """Enable result reranking for vector queries.

        Does nothing if no vector query is set.

        Args:
            rerank_k: Number of results to rerank. ``None`` (or 0) means
                ``2 * k`` from the vector query.

        Returns:
            Self for method chaining
        """
        if self.vector_query:
            self.vector_query.rerank = True
            self.vector_query.rerank_k = rerank_k or (self.vector_query.k * 2)
        return self

    def clear_vector(self) -> Query:
        """Clear vector search from the query (fluent interface)."""
        self.vector_query = None
        return self

    def to_dict(self) -> dict[str, Any]:
        """Convert query to dictionary representation.

        ``filters`` and ``sort`` are always present. ``limit``, ``offset``,
        ``fields`` and ``vector_query`` are included only when set.
        """
        result: dict[str, Any] = {
            "filters": [f.to_dict() for f in self.filters],
            "sort": [s.to_dict() for s in self.sort_specs],
        }
        if self.limit_value is not None:
            result["limit"] = self.limit_value
        if self.offset_value is not None:
            result["offset"] = self.offset_value
        if self.fields is not None:
            result["fields"] = self.fields
        if self.vector_query is not None:
            result["vector_query"] = self.vector_query.to_dict()
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Query:
        """Create query from dictionary representation (inverse of :meth:`to_dict`)."""
        query = cls()

        for filter_data in data.get("filters", []):
            query.filters.append(Filter.from_dict(filter_data))

        for sort_data in data.get("sort", []):
            query.sort_specs.append(SortSpec.from_dict(sort_data))

        query.limit_value = data.get("limit")
        query.offset_value = data.get("offset")
        query.fields = data.get("fields")

        if "vector_query" in data:
            query.vector_query = VectorQuery.from_dict(data["vector_query"])

        return query

    def copy(self) -> Query:
        """Create a copy of the query.

        Filters, sort specs and the vector query are deep-copied. The field
        projection list is shallow-copied, and an empty projection becomes
        ``None``.
        """
        import copy as copy_module

        return Query(
            filters=copy_module.deepcopy(self.filters),
            sort_specs=copy_module.deepcopy(self.sort_specs),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=self.fields.copy() if self.fields else None,
            vector_query=copy_module.deepcopy(self.vector_query) if self.vector_query else None,
        )

    @staticmethod
    def _and_condition(filters: list[Filter]) -> Any:
        """Combine ``filters`` into one query_logic condition joined with AND.

        Returns ``None`` for no filters, a bare ``FilterCondition`` for one
        filter, and an AND ``LogicCondition`` otherwise.
        """
        from .query_logic import FilterCondition, LogicCondition, LogicOperator

        if not filters:
            return None
        if len(filters) == 1:
            return FilterCondition(filters[0])
        return LogicCondition(
            operator=LogicOperator.AND,
            conditions=[FilterCondition(f) for f in filters],
        )

    def _to_complex(self, condition: Any) -> ComplexQuery:
        """Wrap ``condition`` in a ComplexQuery carrying this query's sort/paging/projection."""
        from .query_logic import ComplexQuery

        return ComplexQuery(
            condition=condition,
            sort_specs=self.sort_specs.copy(),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=self.fields.copy() if self.fields else None,
        )

    def or_(self, *filters: Filter | Query) -> ComplexQuery:
        """Create a ComplexQuery with OR logic.

        The current query's filters become an AND group, combined with OR conditions.
        Example: Query with filters [A, B] calling or_(C, D) creates: (A AND B) AND (C OR D)

        A Query argument contributes the AND of its own filters. Query
        arguments without filters, and arguments of any other type, are
        ignored. The vector query is not carried over.

        Args:
            filters: Filter objects or Query objects to OR together

        Returns:
            ComplexQuery with OR logic
        """
        from .query_logic import LogicCondition, LogicOperator

        or_conditions = []
        for item in filters:
            if isinstance(item, Filter):
                or_conditions.append(self._and_condition([item]))
            elif isinstance(item, Query):
                cond = self._and_condition(item.filters)
                if cond is not None:
                    or_conditions.append(cond)

        if not or_conditions:
            or_group = None
        elif len(or_conditions) == 1:
            or_group = or_conditions[0]
        else:
            or_group = LogicCondition(operator=LogicOperator.OR, conditions=or_conditions)

        existing = self._and_condition(self.filters)
        if existing is None:
            root_condition = or_group
        elif or_group is None:
            root_condition = existing
        else:
            root_condition = LogicCondition(
                operator=LogicOperator.AND, conditions=[existing, or_group]
            )

        return self._to_complex(root_condition)

    def and_(self, *filters: Filter | Query) -> Query:
        """Add more filters with AND logic (convenience method).

        Filter arguments are appended. Query arguments contribute all their
        filters. Arguments of any other type are ignored.

        Args:
            filters: Filter objects or Query objects to AND together

        Returns:
            Self for chaining
        """
        for item in filters:
            if isinstance(item, Filter):
                self.filters.append(item)
            elif isinstance(item, Query):
                self.filters.extend(item.filters)
        return self

    def not_(self, filter: Filter) -> ComplexQuery:
        """Create a ComplexQuery with NOT logic.

        The result is ``(existing filters ANDed) AND NOT filter``, or just
        ``NOT filter`` when the query has no filters.

        Args:
            filter: Filter to negate

        Returns:
            ComplexQuery with NOT logic
        """
        from .query_logic import FilterCondition, LogicCondition, LogicOperator

        negated = LogicCondition(
            operator=LogicOperator.NOT, conditions=[FilterCondition(filter)]
        )
        existing = self._and_condition(self.filters)
        if existing is None:
            root_condition = negated
        else:
            root_condition = LogicCondition(
                operator=LogicOperator.AND, conditions=[existing, negated]
            )

        return self._to_complex(root_condition)