Coverage for src/dataknobs_data/query.py: 29%
307 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
1from __future__ import annotations
3from dataclasses import dataclass, field
4from enum import Enum
5from typing import TYPE_CHECKING, Any
7if TYPE_CHECKING:
8 from collections.abc import Callable
10 import numpy as np
12 from .query_logic import ComplexQuery
13 from .vector.types import DistanceMetric
class Operator(Enum):
    """Comparison operators understood by :class:`Filter`.

    Each member's value is the token written by ``Filter.to_dict`` and
    accepted by ``Filter.from_dict``.
    """

    # Equality
    EQ = "="
    NEQ = "!="

    # Ordering
    GT = ">"
    GTE = ">="
    LT = "<"
    LTE = "<="

    # Membership in a list/container
    IN = "in"
    NOT_IN = "not_in"

    # String matching: SQL LIKE pattern, or a regular expression
    LIKE = "like"
    REGEX = "regex"

    # Presence (a value of None counts as absent)
    EXISTS = "exists"
    NOT_EXISTS = "not_exists"

    # Inclusive two-bound range
    BETWEEN = "between"
    NOT_BETWEEN = "not_between"
class SortOrder(Enum):
    """Direction in which query results are ordered."""

    ASC = "asc"  # ascending
    DESC = "desc"  # descending
@dataclass
class Filter:
    """A single filter condition: ``<field> <operator> <value>``.

    Instances can be evaluated in-process with :meth:`matches` and
    round-tripped through :meth:`to_dict` / :meth:`from_dict`.
    """

    field: str  # name of the record field the condition applies to
    operator: Operator  # comparison to perform
    value: Any = None  # right-hand operand; its expected shape depends on the operator

    def matches(self, record_value: Any) -> bool:
        """Check whether ``record_value`` satisfies this filter.

        Semantics by operator:

        - ``EXISTS`` / ``NOT_EXISTS``: ``record_value`` is / is not ``None``.
        - For every other operator a ``None`` record value never matches.
        - ``EQ`` / ``NEQ``: Python ``==`` / ``!=``.
        - ``GT`` / ``GTE`` / ``LT`` / ``LTE``: type-aware comparison via
          :meth:`_compare_values`; values Python cannot order never match.
        - ``IN`` / ``NOT_IN``: ``record_value in value`` / ``not in``.
        - ``BETWEEN`` / ``NOT_BETWEEN``: inclusive range check against a
          two-element list/tuple ``[lower, upper]``. A malformed ``value``
          makes ``BETWEEN`` fail and ``NOT_BETWEEN`` succeed.
        - ``LIKE``: SQL-style pattern over strings. ``%`` matches any run of
          characters (including none) and ``_`` exactly one character. Every
          other character is literal, and the whole string must match.
        - ``REGEX``: :func:`re.search` of ``value`` anywhere in the string.

        Args:
            record_value: The value of ``self.field`` taken from a record.

        Returns:
            True if the value satisfies the condition.

        Raises:
            ValueError: If ``self.operator`` is not a known operator.
        """
        import re

        op = self.operator
        if op == Operator.EXISTS:
            return record_value is not None
        if op == Operator.NOT_EXISTS:
            return record_value is None
        if record_value is None:
            # Every remaining operator needs a concrete value to test.
            return False

        if op == Operator.EQ:
            return record_value == self.value
        if op == Operator.NEQ:
            return record_value != self.value
        if op == Operator.GT:
            return self._compare_values(record_value, self.value, lambda a, b: a > b)
        if op == Operator.GTE:
            return self._compare_values(record_value, self.value, lambda a, b: a >= b)
        if op == Operator.LT:
            return self._compare_values(record_value, self.value, lambda a, b: a < b)
        if op == Operator.LTE:
            return self._compare_values(record_value, self.value, lambda a, b: a <= b)
        if op == Operator.IN:
            return record_value in self.value
        if op == Operator.NOT_IN:
            return record_value not in self.value
        if op in (Operator.BETWEEN, Operator.NOT_BETWEEN):
            well_formed = isinstance(self.value, (list, tuple)) and len(self.value) == 2
            if not well_formed:
                return op == Operator.NOT_BETWEEN
            lower, upper = self.value
            inside = self._compare_values(
                record_value, lower, lambda a, b: a >= b
            ) and self._compare_values(record_value, upper, lambda a, b: a <= b)
            return inside if op == Operator.BETWEEN else not inside
        if op == Operator.LIKE:
            if not isinstance(record_value, str) or not isinstance(self.value, str):
                return False
            # fullmatch (not ^...$) so a trailing newline is not silently ignored;
            # DOTALL so '%' / '_' also span newlines, as in SQL LIKE.
            regex = self._like_to_regex(self.value)
            return re.fullmatch(regex, record_value, re.DOTALL) is not None
        if op == Operator.REGEX:
            if not isinstance(record_value, str):
                return False
            return re.search(self.value, record_value) is not None

        # Unreachable while every Operator member is handled above.
        raise ValueError(f"Unknown operator: {self.operator}")

    @staticmethod
    def _like_to_regex(pattern: str) -> str:
        """Translate a SQL ``LIKE`` pattern into an equivalent regex body.

        ``%`` becomes ``.*`` and ``_`` becomes ``.``. Every other character
        is passed through :func:`re.escape`, so regex metacharacters such as
        ``.``, ``(`` or ``+`` are matched literally instead of being
        interpreted, or raising ``re.error`` for unbalanced groups.
        """
        import re

        return "".join(
            ".*" if ch == "%" else "." if ch == "_" else re.escape(ch)
            for ch in pattern
        )

    @staticmethod
    def _parse_iso(text: str, like: Any = None) -> Any:
        """Parse an ISO-8601 string, returning ``None`` if it is not one.

        A ``Z`` suffix is accepted as UTC. When ``like`` is a plain
        :class:`datetime.date` (not a datetime), the parsed value is
        truncated to a date so the two are comparable.
        """
        from datetime import date, datetime

        try:
            parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
        except (ValueError, AttributeError):
            return None
        if isinstance(like, date) and not isinstance(like, datetime):
            return parsed.date()
        return parsed

    def _compare_values(self, a: Any, b: Any, comparator: Callable[[Any, Any], bool]) -> bool:
        """Apply ``comparator(a, b)`` with light type coercion.

        - If one side is a ``str`` and the other a ``date``/``datetime``, the
          string is parsed as ISO-8601. If it cannot be parsed, the comparison
          fails (returns False).
        - If both sides are strings and ``a`` looks date-like (contains ``T``
          or ``-``), both are parsed. Only if *both* parse are they compared
          as datetimes; otherwise the original strings are compared
          lexicographically (case-sensitive).
        - Anything Python cannot order (``TypeError``) compares as False,
          e.g. ``1 > "a"`` or a naive vs. timezone-aware datetime.
        """
        from datetime import date

        if isinstance(a, str) and isinstance(b, date):
            parsed = self._parse_iso(a, like=b)
            if parsed is None:
                return False
            a = parsed
        elif isinstance(b, str) and isinstance(a, date):
            parsed = self._parse_iso(b, like=a)
            if parsed is None:
                return False
            b = parsed
        elif isinstance(a, str) and isinstance(b, str) and ("T" in a or "-" in a):
            parsed_a = self._parse_iso(a)
            parsed_b = self._parse_iso(b)
            # Swap in both parsed values or neither, so we never end up
            # comparing a datetime against a raw string.
            if parsed_a is not None and parsed_b is not None:
                a, b = parsed_a, parsed_b

        try:
            return comparator(a, b)
        except TypeError:
            # Types not comparable
            return False

    def to_dict(self) -> dict[str, Any]:
        """Convert filter to dictionary representation."""
        return {"field": self.field, "operator": self.operator.value, "value": self.value}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Filter:
        """Create filter from dictionary representation."""
        return cls(
            field=data["field"], operator=Operator(data["operator"]), value=data.get("value")
        )
@dataclass
class SortSpec:
    """One ordering term: the field to sort on and its direction."""

    field: str
    order: SortOrder = SortOrder.ASC

    def to_dict(self) -> dict[str, str]:
        """Serialize as ``{"field": <name>, "order": "asc" | "desc"}``."""
        return dict(field=self.field, order=self.order.value)

    @classmethod
    def from_dict(cls, data: dict[str, str]) -> SortSpec:
        """Rebuild a sort spec from :meth:`to_dict` output.

        A missing ``"order"`` key defaults to ascending.
        """
        name = data["field"]
        direction = SortOrder(data.get("order", "asc"))
        return cls(field=name, order=direction)
@dataclass
class VectorQuery:
    """Represents a vector similarity search query.

    This dataclass encapsulates all parameters needed for vector similarity search,
    including the query vector, distance metric, and various search options.
    """

    vector: np.ndarray | list[float]  # query vector / embedding
    field_name: str = "embedding"  # name of the vector field to search
    k: int = 10  # number of results to return (top-k)
    metric: DistanceMetric | str = "cosine"  # DistanceMetric member, or its string form
    include_source: bool = True  # include source text in results
    score_threshold: float | None = None  # minimum similarity score, if any
    rerank: bool = False  # whether results should be reranked
    rerank_k: int | None = None  # number of results to rerank (default: 2*k)
    metadata: dict[str, Any] = field(default_factory=dict)  # free-form extra parameters

    def to_dict(self) -> dict[str, Any]:
        """Convert vector query to a plain dictionary.

        Numpy vectors become lists and ``DistanceMetric`` members become their
        ``.value``. Optional keys (``score_threshold``, ``rerank``,
        ``rerank_k``, ``metadata``) are written only when set. Mutable
        containers are copied, so editing the returned dict never alters
        this query.
        """
        import numpy as np

        vector_data: Any = self.vector
        if isinstance(vector_data, np.ndarray):
            vector_data = vector_data.tolist()
        elif isinstance(vector_data, list):
            vector_data = list(vector_data)

        # DistanceMetric enum members serialize to their value; strings pass through.
        metric_value = getattr(self.metric, "value", self.metric)

        result: dict[str, Any] = {
            "vector": vector_data,
            "field": self.field_name,
            "k": self.k,
            "metric": metric_value,
            "include_source": self.include_source,
        }

        if self.score_threshold is not None:
            result["score_threshold"] = self.score_threshold
        if self.rerank:
            result["rerank"] = self.rerank
        if self.rerank_k is not None:
            result["rerank_k"] = self.rerank_k
        if self.metadata:
            result["metadata"] = dict(self.metadata)

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> VectorQuery:
        """Create a vector query from :meth:`to_dict`-style data.

        The vector is converted to a ``float32`` numpy array unless it already
        is an ndarray. A metric string naming a ``DistanceMetric`` is converted
        to the enum; unknown metric strings are kept as-is. ``metadata`` is
        copied (``None`` is treated as empty) so the new query does not share
        state with ``data``.
        """
        import numpy as np

        from .vector.types import DistanceMetric

        vector_data = data["vector"]
        if not isinstance(vector_data, np.ndarray):
            vector_data = np.array(vector_data, dtype=np.float32)

        metric_value = data.get("metric", "cosine")
        if isinstance(metric_value, str):
            try:
                metric_value = DistanceMetric(metric_value)
            except ValueError:
                # Keep as string if not a valid enum value
                pass

        return cls(
            vector=vector_data,
            field_name=data.get("field", "embedding"),
            k=data.get("k", 10),
            metric=metric_value,
            include_source=data.get("include_source", True),
            score_threshold=data.get("score_threshold"),
            rerank=data.get("rerank", False),
            rerank_k=data.get("rerank_k"),
            metadata=dict(data.get("metadata") or {}),
        )
@dataclass
class Query:
    """A database query: filters, sorting, pagination, projection and vector search.

    Most mutators return ``self`` so calls can be chained, e.g.
    ``Query().filter("age", ">", 30).sort_by("name").limit(10)``.
    The entries in ``filters`` are combined with AND; use :meth:`or_` /
    :meth:`not_` to build a ``ComplexQuery`` with richer boolean logic.
    """

    filters: list[Filter] = field(default_factory=list)  # conditions, combined with AND
    sort_specs: list[SortSpec] = field(default_factory=list)  # sort terms, in order added
    limit_value: int | None = None  # maximum number of results, if set
    offset_value: int | None = None  # number of results to skip, if set
    fields: list[str] | None = None  # field projection (None = no projection)
    vector_query: VectorQuery | None = None  # optional vector similarity search

    @property
    def sort_property(self) -> list[SortSpec]:
        """Get sort specifications (backward compatibility)."""
        return self.sort_specs

    @property
    def limit_property(self) -> int | None:
        """Get limit value (backward compatibility)."""
        return self.limit_value

    @property
    def offset_property(self) -> int | None:
        """Get offset value (backward compatibility)."""
        return self.offset_value

    # ------------------------------------------------------------------ filters

    @staticmethod
    def _resolve_operator(token: str) -> Operator:
        """Map an operator string to an :class:`Operator`.

        Accepts:

        - symbolic forms: ``=``, ``==``, ``!=``, ``<>``, ``>``, ``>=``, ``<``, ``<=``;
        - the enum values: ``in``, ``not_in``, ``like``, ``regex``, ``exists``, ...;
        - the enum names: ``EQ``, ``NOT_IN``, ...;
        - SQL spellings: ``NOT IN``, ``NOT BETWEEN``, ``NOT EXISTS``.

        Matching ignores case, surrounding whitespace, and repeated inner spaces.

        Raises:
            ValueError: If the string names no known operator. Previously such
                strings silently became ``EQ``, which returned wrong results.
        """
        normalized = " ".join(token.strip().lower().split())
        aliases = {
            "==": Operator.EQ,
            "<>": Operator.NEQ,
            "not in": Operator.NOT_IN,
            "not between": Operator.NOT_BETWEEN,
            "not exists": Operator.NOT_EXISTS,
        }
        if normalized in aliases:
            return aliases[normalized]
        for candidate in Operator:
            if normalized in (candidate.value, candidate.name.lower()):
                return candidate
        raise ValueError(f"Unknown operator: {token!r}")

    def filter(self, field: str, operator: str | Operator, value: Any = None) -> Query:
        """Add a filter to the query (fluent interface).

        Args:
            field: The field name to filter on
            operator: The operator, as an Operator enum or a string such as
                ``"="``, ``">="``, ``"in"``, ``"NOT IN"`` or ``"between"``
                (case-insensitive)
            value: The value to compare against

        Returns:
            Self for method chaining

        Raises:
            ValueError: If ``operator`` is a string naming no known operator.
        """
        if isinstance(operator, str):
            operator = self._resolve_operator(operator)

        self.filters.append(Filter(field=field, operator=operator, value=value))
        return self

    # ------------------------------------------------------------------ sorting

    @staticmethod
    def _resolve_sort_order(token: str) -> SortOrder:
        """Map ``"asc"``/``"ascending"``/``"desc"``/``"descending"`` (any case) to a SortOrder.

        Raises:
            ValueError: For any other string. Previously every string other
                than ``"asc"`` silently meant descending, so ``"ascending"``
                sorted the wrong way.
        """
        normalized = token.strip().lower()
        if normalized in ("asc", "ascending"):
            return SortOrder.ASC
        if normalized in ("desc", "descending"):
            return SortOrder.DESC
        raise ValueError(f"Unknown sort order: {token!r}")

    def sort_by(self, field: str, order: str | SortOrder = "asc") -> Query:
        """Add a sort specification to the query (fluent interface).

        Args:
            field: The field name to sort by
            order: The sort order ("asc", "desc", "ascending", "descending",
                case-insensitive, or a SortOrder enum)

        Returns:
            Self for method chaining

        Raises:
            ValueError: If ``order`` is a string naming no known sort order.
        """
        if isinstance(order, str):
            order = self._resolve_sort_order(order)

        self.sort_specs.append(SortSpec(field=field, order=order))
        return self

    def sort(self, field: str, order: str | SortOrder = "asc") -> Query:
        """Add sorting (fluent interface); alias of :meth:`sort_by`."""
        return self.sort_by(field, order)

    # --------------------------------------------------------------- pagination

    def set_limit(self, limit: int) -> Query:
        """Set the result limit (fluent interface).

        Args:
            limit: Maximum number of results

        Returns:
            Self for method chaining
        """
        self.limit_value = limit
        return self

    def limit(self, value: int) -> Query:
        """Set limit (fluent interface); alias of :meth:`set_limit`."""
        return self.set_limit(value)

    def set_offset(self, offset: int) -> Query:
        """Set the result offset (fluent interface).

        Args:
            offset: Number of results to skip

        Returns:
            Self for method chaining
        """
        self.offset_value = offset
        return self

    def offset(self, value: int) -> Query:
        """Set offset (fluent interface); alias of :meth:`set_offset`."""
        return self.set_offset(value)

    def select(self, *fields: str) -> Query:
        """Set field projection (fluent interface).

        Args:
            fields: Field names to include in results; calling with no names
                clears the projection (``self.fields`` becomes ``None``).

        Returns:
            Self for method chaining
        """
        self.fields = list(fields) if fields else None
        return self

    def clear_filters(self) -> Query:
        """Clear all filters (fluent interface)."""
        self.filters = []
        return self

    def clear_sort(self) -> Query:
        """Clear all sort specifications (fluent interface)."""
        self.sort_specs = []
        return self

    # ------------------------------------------------------------ vector search

    def similar_to(
        self,
        vector: np.ndarray | list[float],
        field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> Query:
        """Add vector similarity search to the query.

        This method sets up a vector similarity search that will find the k most
        similar vectors to the provided query vector. Any existing vector query
        is replaced, and ``limit_value`` is always overwritten with ``k``.

        Args:
            vector: Query vector to search for similar vectors
            field: Vector field name to search (default: "embedding")
            k: Number of results to return (default: 10)
            metric: Distance metric to use (default: "cosine")
            include_source: Whether to include source text in results (default: True)
            score_threshold: Minimum similarity score threshold (optional)

        Returns:
            Self for method chaining
        """
        self.vector_query = VectorQuery(
            vector=vector,
            field_name=field,
            k=k,
            metric=metric,
            include_source=include_source,
            score_threshold=score_threshold,
        )
        # Always update limit to match k
        self.limit_value = k
        return self

    def near_text(
        self,
        text: str,
        embedding_fn: Callable[[str], np.ndarray],
        field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> Query:
        """Add text-based vector similarity search to the query.

        Converts ``text`` to a vector with ``embedding_fn`` and then delegates
        to :meth:`similar_to` (so ``limit_value`` is overwritten with ``k``).

        Args:
            text: Text to convert to vector for similarity search
            embedding_fn: Function to convert text to vector
            field: Vector field name to search (default: "embedding")
            k: Number of results to return (default: 10)
            metric: Distance metric to use (default: "cosine")
            include_source: Whether to include source text in results (default: True)
            score_threshold: Minimum similarity score threshold (optional)

        Returns:
            Self for method chaining
        """
        vector = embedding_fn(text)
        return self.similar_to(
            vector=vector,
            field=field,
            k=k,
            metric=metric,
            include_source=include_source,
            score_threshold=score_threshold,
        )

    def hybrid(
        self,
        text_query: str | None = None,
        vector: np.ndarray | list[float] | None = None,
        text_field: str = "content",
        vector_field: str = "embedding",
        alpha: float = 0.5,
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
    ) -> Query:
        """Create a hybrid query combining text and vector search.

        A non-empty ``text_query`` adds a ``LIKE '%text_query%'`` filter on
        ``text_field``. Any ``%`` or ``_`` inside ``text_query`` act as LIKE
        wildcards. A ``vector`` replaces the current vector query, with
        ``alpha`` stored in its metadata under ``"hybrid_alpha"`` for the
        backend to use.

        Args:
            text_query: Text to search for (optional)
            vector: Vector for similarity search (optional)
            text_field: Field for text search (default: "content")
            vector_field: Field for vector search (default: "embedding")
            alpha: Weight balance between text (0.0) and vector (1.0) search (default: 0.5)
            k: Number of results to return (default: 10)
            metric: Distance metric for vector search (default: "cosine")

        Returns:
            Self for method chaining

        Note:
            - alpha=0.0 gives full weight to text search
            - alpha=1.0 gives full weight to vector search
            - alpha=0.5 gives equal weight to both
            - Unlike :meth:`similar_to`, an existing limit is kept; ``k`` is
              used as the limit only when none was set.
        """
        if text_query:
            self.filter(text_field, Operator.LIKE, f"%{text_query}%")

        if vector is not None:
            self.vector_query = VectorQuery(
                vector=vector,
                field_name=vector_field,
                k=k,
                metric=metric,
                include_source=True,
                metadata={"hybrid_alpha": alpha},
            )

        if self.limit_value is None:
            self.limit_value = k

        return self

    def with_reranking(self, rerank_k: int | None = None) -> Query:
        """Enable result reranking for vector queries.

        Does nothing when no vector query is set.

        Args:
            rerank_k: Number of results to rerank (default: 2*k from vector query;
                a falsy value such as 0 also selects the default)

        Returns:
            Self for method chaining
        """
        if self.vector_query:
            self.vector_query.rerank = True
            self.vector_query.rerank_k = rerank_k or (self.vector_query.k * 2)
        return self

    def clear_vector(self) -> Query:
        """Clear vector search from the query (fluent interface)."""
        self.vector_query = None
        return self

    # ------------------------------------------------------------ serialization

    def to_dict(self) -> dict[str, Any]:
        """Convert query to dictionary representation.

        ``filters`` and ``sort`` are always present; ``limit``, ``offset``,
        ``fields`` and ``vector_query`` are written only when set.
        """
        result: dict[str, Any] = {
            "filters": [f.to_dict() for f in self.filters],
            "sort": [s.to_dict() for s in self.sort_specs],
        }
        if self.limit_value is not None:
            result["limit"] = self.limit_value
        if self.offset_value is not None:
            result["offset"] = self.offset_value
        if self.fields is not None:
            result["fields"] = self.fields
        if self.vector_query is not None:
            result["vector_query"] = self.vector_query.to_dict()
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Query:
        """Create query from dictionary representation.

        Missing keys, and keys explicitly set to ``None``, fall back to the
        empty/unset defaults instead of raising.
        """
        query = cls()

        for filter_data in data.get("filters") or []:
            query.filters.append(Filter.from_dict(filter_data))

        for sort_data in data.get("sort") or []:
            query.sort_specs.append(SortSpec.from_dict(sort_data))

        query.limit_value = data.get("limit")
        query.offset_value = data.get("offset")
        query.fields = data.get("fields")

        vector_data = data.get("vector_query")
        if vector_data is not None:
            query.vector_query = VectorQuery.from_dict(vector_data)

        return query

    def copy(self) -> Query:
        """Create an independent copy of the query (filters, sorts and vector query are deep-copied)."""
        import copy

        return Query(
            filters=copy.deepcopy(self.filters),
            sort_specs=copy.deepcopy(self.sort_specs),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=self.fields.copy() if self.fields else None,
            vector_query=copy.deepcopy(self.vector_query) if self.vector_query else None,
        )

    # ------------------------------------------------------------ boolean logic

    @staticmethod
    def _and_condition(filters: list[Filter]) -> Any:
        """Fold ``filters`` into one condition: None if empty, the bare filter if single, else AND."""
        from .query_logic import FilterCondition, LogicCondition, LogicOperator

        if not filters:
            return None
        if len(filters) == 1:
            return FilterCondition(filters[0])
        return LogicCondition(
            operator=LogicOperator.AND,
            conditions=[FilterCondition(f) for f in filters],
        )

    def _to_complex(self, condition: Any) -> ComplexQuery:
        """Wrap ``condition`` in a ComplexQuery carrying this query's sorting, paging and projection."""
        from .query_logic import ComplexQuery

        # NOTE(review): vector_query is not carried over to the ComplexQuery -- confirm intended.
        return ComplexQuery(
            condition=condition,
            sort_specs=self.sort_specs.copy(),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=self.fields.copy() if self.fields else None,
        )

    def or_(self, *filters: Filter | Query) -> ComplexQuery:
        """Create a ComplexQuery with OR logic.

        The current query's filters become an AND group, combined with OR conditions.
        Example: Query with filters [A, B] calling or_(C, D) creates: (A AND B) AND (C OR D)

        A Query argument contributes its own filters as an AND group; a Query
        with no filters contributes nothing.

        Args:
            filters: Filter objects or Query objects to OR together

        Returns:
            ComplexQuery with OR logic
        """
        from .query_logic import FilterCondition, LogicCondition, LogicOperator

        or_conditions = []
        for item in filters:
            if isinstance(item, Filter):
                or_conditions.append(FilterCondition(item))
            elif isinstance(item, Query):
                grouped = self._and_condition(item.filters)
                if grouped is not None:
                    or_conditions.append(grouped)

        if not or_conditions:
            or_group = None
        elif len(or_conditions) == 1:
            or_group = or_conditions[0]
        else:
            or_group = LogicCondition(operator=LogicOperator.OR, conditions=or_conditions)

        existing = self._and_condition(self.filters)
        if existing is None:
            root_condition = or_group
        elif or_group is None:
            root_condition = existing
        else:
            root_condition = LogicCondition(
                operator=LogicOperator.AND,
                conditions=[existing, or_group],
            )

        return self._to_complex(root_condition)

    def and_(self, *filters: Filter | Query) -> Query:
        """Add more filters with AND logic (convenience method).

        Filter arguments are appended; Query arguments contribute all of their
        filters. This query is modified in place.

        Args:
            filters: Filter objects or Query objects to AND together

        Returns:
            Self for chaining
        """
        for item in filters:
            if isinstance(item, Filter):
                self.filters.append(item)
            elif isinstance(item, Query):
                self.filters.extend(item.filters)
        return self

    def not_(self, filter: Filter) -> ComplexQuery:
        """Create a ComplexQuery with NOT logic.

        The result is ``(current filters ANDed) AND NOT filter``, or just
        ``NOT filter`` when this query has no filters.

        Args:
            filter: Filter to negate

        Returns:
            ComplexQuery with NOT logic
        """
        from .query_logic import FilterCondition, LogicCondition, LogicOperator

        negated = LogicCondition(
            operator=LogicOperator.NOT,
            conditions=[FilterCondition(filter)],
        )

        existing = self._and_condition(self.filters)
        if existing is None:
            root_condition = negated
        else:
            root_condition = LogicCondition(
                operator=LogicOperator.AND,
                conditions=[existing, negated],
            )

        return self._to_complex(root_condition)