Coverage for src/dataknobs_data/query_logic.py: 25%

221 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""Boolean logic support for complex queries.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from dataclasses import dataclass, field 

7from enum import Enum 

8from typing import TYPE_CHECKING, Any 

9 

10from .query import Filter, Operator, VectorQuery 

11 

12if TYPE_CHECKING: 

13 import numpy as np 

14 

15 from .query import Query 

16 from .vector.types import DistanceMetric 

17 

18 

class LogicOperator(Enum):
    """Boolean connectives that a LogicCondition uses to combine its children."""

    AND = "and"  # every child condition must match
    OR = "or"  # at least one child condition must match
    NOT = "not"  # negation of the child condition(s)

24 

25 

class Condition(ABC):
    """Interface shared by every node of a query condition tree.

    Concrete subclasses decide whether a record satisfies them and know how
    to round-trip themselves through a plain ``dict``.
    """

    @abstractmethod
    def matches(self, record: Any) -> bool:
        """Return True when *record* satisfies this condition."""

    @abstractmethod
    def to_dict(self) -> dict[str, Any]:
        """Serialize this condition into a plain dictionary."""

    @classmethod
    @abstractmethod
    def from_dict(cls, data: dict[str, Any]) -> Condition:
        """Rebuild a condition from the dictionary produced by ``to_dict``."""

44 

45 

@dataclass
class FilterCondition(Condition):
    """Leaf condition that evaluates a single ``Filter`` against a record."""

    filter: Filter

    def matches(self, record: Any) -> bool:
        """Look up the filter's field on *record* and test it with the filter.

        How the field value is obtained depends on the record type:

        * ``Record`` instances: ``record.get_value(field)``.
        * ``dict``: the field name is split on ``"."`` and each part is looked
          up in successively nested dicts; reaching a non-dict before the path
          is exhausted yields ``None``.
        * anything else: ``getattr(record, field, None)``.
        """
        from .records import Record

        field_name = self.filter.field

        if isinstance(record, Record):
            resolved = record.get_value(field_name)
        elif isinstance(record, dict):
            resolved: Any = record
            for key in field_name.split('.'):
                if not isinstance(resolved, dict):
                    # Path runs through a non-dict value: treat as missing.
                    resolved = None
                    break
                resolved = resolved.get(key)
        else:
            resolved = getattr(record, field_name, None)

        return self.filter.matches(resolved)

    def to_dict(self) -> dict[str, Any]:
        """Serialize as ``{"type": "filter", "filter": <Filter.to_dict()>}``."""
        serialized_filter = self.filter.to_dict()
        return {"type": "filter", "filter": serialized_filter}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FilterCondition:
        """Rebuild from the dict produced by ``to_dict`` (reads the "filter" key)."""
        parsed = Filter.from_dict(data["filter"])
        return cls(filter=parsed)

82 

83 

@dataclass
class LogicCondition(Condition):
    """A boolean combination (AND / OR / NOT) of child conditions."""

    operator: LogicOperator
    conditions: list[Condition] = field(default_factory=list)

    def matches(self, record: Any) -> bool:
        """Evaluate the child conditions against *record* using ``operator``.

        * AND: every child must match (an empty AND matches everything).
        * OR: at least one child must match (an empty OR matches nothing).
        * NOT: no child may match; with a single child this is plain negation.

        Raises:
            ValueError: if ``operator`` is not one of the known operators.
        """
        if self.operator == LogicOperator.AND:
            return all(cond.matches(record) for cond in self.conditions)
        if self.operator == LogicOperator.OR:
            return any(cond.matches(record) for cond in self.conditions)
        if self.operator == LogicOperator.NOT:
            # For one child, "not any" is exactly "not child"; for several it
            # means none of them may match.
            return not any(cond.matches(record) for cond in self.conditions)
        # Defensive: only reachable if a non-LogicOperator value was stored.
        raise ValueError(f"Unknown logical operator: {self.operator}")

    def to_dict(self) -> dict[str, Any]:
        """Serialize as ``{"type": "logic", "operator": ..., "conditions": [...]}``."""
        return {
            "type": "logic",
            "operator": self.operator.value,
            "conditions": [cond.to_dict() for cond in self.conditions]
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> LogicCondition:
        """Rebuild a logic condition (and its children, recursively) from a dict.

        Args:
            data: mapping with an "operator" value accepted by ``LogicOperator``
                and an optional "conditions" list of serialized children.

        Raises:
            ValueError: if a child has an unknown "type". Previously such
                children were silently dropped, which changed the meaning of
                the deserialized query; this now matches ``condition_from_dict``.
            KeyError: if a child has no "type" key.
        """
        conditions: list[Condition] = []
        for cond_data in data.get("conditions", []):
            cond_type = cond_data["type"]
            if cond_type == "filter":
                conditions.append(FilterCondition.from_dict(cond_data))
            elif cond_type == "logic":
                conditions.append(cls.from_dict(cond_data))
            else:
                raise ValueError(f"Unknown condition type: {cond_type}")

        return cls(
            operator=LogicOperator(data["operator"]),
            conditions=conditions
        )

131 

132 

def condition_from_dict(data: dict[str, Any]) -> Condition:
    """Deserialize any condition node by dispatching on its "type" key.

    "filter" maps to ``FilterCondition`` and "logic" to ``LogicCondition``.

    Raises:
        KeyError: if *data* has no "type" key.
        ValueError: if the "type" value is not recognised.
    """
    kind = data["type"]
    if kind == "filter":
        return FilterCondition.from_dict(data)
    if kind == "logic":
        return LogicCondition.from_dict(data)
    raise ValueError(f"Unknown condition type: {data['type']}")

141 

142 

class QueryBuilder:
    """Fluent builder for ``ComplexQuery`` objects with boolean logic.

    Every combinator returns ``self`` so calls can be chained. The builder
    never mutates a condition object after it has been placed in the tree;
    combining always creates a fresh ``LogicCondition``. Queries returned by
    ``build()``, and builders passed into ``and_``/``or_``/``not_``, therefore
    stay unaffected by later calls on this builder.
    """

    def __init__(self):
        """Initialize an empty builder: no condition, sorting, paging, projection or vector search."""
        self.root_condition: Condition | None = None
        self.sort_specs: list = []
        self.limit_value: int | None = None
        self.offset_value: int | None = None
        self.fields: list[str] | None = None
        self.vector_query: VectorQuery | None = None

    @staticmethod
    def _to_conditions(items: tuple[QueryBuilder | Filter | Condition, ...]) -> list[Condition]:
        """Normalize builders / filters / conditions into a list of conditions.

        Builders without a condition, and objects of any other type, are skipped.
        """
        result: list[Condition] = []
        for item in items:
            if isinstance(item, QueryBuilder):
                if item.root_condition is not None:
                    result.append(item.root_condition)
            elif isinstance(item, Filter):
                result.append(FilterCondition(item))
            elif isinstance(item, Condition):
                result.append(item)
        return result

    def _and_onto_root(self, new_conditions: list[Condition]) -> None:
        """AND *new_conditions* onto the existing (non-None) root condition.

        An existing AND root is flattened into a *new* AND node rather than
        being appended to in place; any other root becomes the first operand.
        """
        root = self.root_condition
        if isinstance(root, LogicCondition) and root.operator == LogicOperator.AND:
            operands = [*root.conditions, *new_conditions]
        else:
            operands = [root, *new_conditions]
        self.root_condition = LogicCondition(operator=LogicOperator.AND, conditions=operands)

    def where(self, field: str, operator: str | Operator, value: Any = None) -> QueryBuilder:
        """Add a filter condition, ANDed with any existing conditions.

        Args:
            field: field name to filter on.
            operator: an ``Operator`` or a string accepted by ``Operator(...)``.
            value: comparison value passed to ``Filter``.
        """
        op = Operator(operator) if isinstance(operator, str) else operator
        filter_cond = FilterCondition(Filter(field, op, value))

        if self.root_condition is None:
            self.root_condition = filter_cond
        else:
            self._and_onto_root([filter_cond])
        return self

    def and_(self, *conditions: QueryBuilder | Filter | Condition) -> QueryBuilder:
        """AND the given builders / filters / conditions onto the current root.

        Calling with no usable operands leaves the builder unchanged.
        """
        new_conditions = self._to_conditions(conditions)
        if not new_conditions:
            return self

        if self.root_condition is None:
            self.root_condition = LogicCondition(operator=LogicOperator.AND, conditions=new_conditions)
        else:
            self._and_onto_root(new_conditions)
        return self

    def or_(self, *conditions: QueryBuilder | Filter | Condition) -> QueryBuilder:
        """OR the given builders / filters / conditions with the current root.

        An existing OR root is flattened into a new OR node; any other root
        becomes the first OR operand. Calling with no usable operands leaves
        the builder unchanged (an empty OR node would match nothing).
        """
        new_conditions = self._to_conditions(conditions)
        if not new_conditions:
            return self

        root = self.root_condition
        if root is None:
            operands = new_conditions
        elif isinstance(root, LogicCondition) and root.operator == LogicOperator.OR:
            operands = [*root.conditions, *new_conditions]
        else:
            operands = [root, *new_conditions]
        self.root_condition = LogicCondition(operator=LogicOperator.OR, conditions=operands)
        return self

    def not_(self, condition: QueryBuilder | Filter | Condition) -> QueryBuilder:
        """AND a negation of *condition* onto the current root.

        A builder with no condition contributes a NOT node with no children.
        """
        if isinstance(condition, QueryBuilder):
            negated = [condition.root_condition] if condition.root_condition is not None else []
        elif isinstance(condition, Filter):
            negated = [FilterCondition(condition)]
        else:
            negated = [condition]
        not_cond = LogicCondition(operator=LogicOperator.NOT, conditions=negated)

        if self.root_condition is None:
            self.root_condition = not_cond
        else:
            self._and_onto_root([not_cond])
        return self

    def sort_by(self, field: str, order: str = "asc") -> QueryBuilder:
        """Append a sort specification.

        *order* equal to "asc" (case-insensitive) sorts ascending; any other
        value sorts descending.
        """
        from .query import SortOrder, SortSpec

        sort_order = SortOrder.ASC if order.lower() == "asc" else SortOrder.DESC
        self.sort_specs.append(SortSpec(field=field, order=sort_order))
        return self

    def limit(self, value: int) -> QueryBuilder:
        """Set the maximum number of results."""
        self.limit_value = value
        return self

    def offset(self, value: int) -> QueryBuilder:
        """Set the number of results to skip."""
        self.offset_value = value
        return self

    def select(self, *fields: str) -> QueryBuilder:
        """Set the field projection; calling with no fields clears it."""
        self.fields = list(fields) if fields else None
        return self

    def similar_to(
        self,
        vector: np.ndarray | list[float],
        field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric | str = "cosine",
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> QueryBuilder:
        """Attach a vector similarity search.

        If no limit has been set yet, the limit defaults to *k*.
        """
        self.vector_query = VectorQuery(
            vector=vector,
            field_name=field,
            k=k,
            metric=metric,
            include_source=include_source,
            score_threshold=score_threshold,
        )
        if self.limit_value is None:
            self.limit_value = k
        return self

    def build(self) -> ComplexQuery:
        """Build a ``ComplexQuery`` snapshot of the current builder state.

        The list attributes are copied so that later ``sort_by``/``select``
        calls on this builder do not leak into an already-built query.
        """
        return ComplexQuery(
            condition=self.root_condition,
            sort_specs=list(self.sort_specs),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=list(self.fields) if self.fields is not None else None,
            vector_query=self.vector_query
        )

311 

312 

@dataclass
class ComplexQuery:
    """A query whose filter logic is an arbitrary tree of conditions."""

    # All fields have defaults to avoid dataclass field-ordering issues.
    condition: Condition | None = None
    sort_specs: list = field(default_factory=list)
    limit_value: int | None = None
    offset_value: int | None = None
    fields: list[str] | None = None
    vector_query: VectorQuery | None = None  # Vector similarity search

    @staticmethod
    def _query_to_condition(query: Query) -> Condition:
        """Return a condition equivalent to the conjunction of *query*'s filters.

        A single filter becomes a bare ``FilterCondition``. Zero or several
        filters become an AND node; an empty AND matches every record.
        """
        filter_conds: list[Condition] = [FilterCondition(filter=f) for f in query.filters]
        if len(filter_conds) == 1:
            return filter_conds[0]
        return LogicCondition(operator=LogicOperator.AND, conditions=filter_conds)

    @classmethod
    def AND(cls, queries: list[Query]) -> ComplexQuery:
        """Combine *queries* so a record must satisfy every query's filters.

        Because AND is associative, all filters are flattened into one AND node.
        Items that are not ``Query`` instances are ignored, as are the
        sub-queries' sort/limit/offset settings.
        """
        from .query import Query

        conditions: list[Condition] = [
            FilterCondition(filter=f)
            for q in queries
            if isinstance(q, Query)
            for f in q.filters
        ]
        return cls(
            condition=LogicCondition(operator=LogicOperator.AND, conditions=conditions)
        )

    @classmethod
    def OR(cls, queries: list[Query]) -> ComplexQuery:
        """Combine *queries* so a record must satisfy at least one query.

        Each query's own filters stay ANDed together: ``OR([q1(a, b), q2(c)])``
        means ``(a AND b) OR c``. Previously all filters were flattened into a
        single OR (``a OR b OR c``), which broadened multi-filter queries.
        Items that are not ``Query`` instances are ignored, as are the
        sub-queries' sort/limit/offset settings.
        """
        from .query import Query

        conditions: list[Condition] = [
            cls._query_to_condition(q) for q in queries if isinstance(q, Query)
        ]
        return cls(
            condition=LogicCondition(operator=LogicOperator.OR, conditions=conditions)
        )

    def matches(self, record: Any) -> bool:
        """Return True if *record* satisfies the condition (no condition matches everything)."""
        if self.condition is None:
            return True
        return self.condition.matches(record)

    def to_simple_query(self) -> Query:
        """Convert to a simple ``Query`` when the logic is a plain AND of filters.

        Raises:
            ValueError: if the condition tree contains OR/NOT logic or nested
                logic nodes that a flat filter list cannot express.
        """
        from .query import Query

        cond = self.condition
        if cond is None:
            filters = []
        elif isinstance(cond, FilterCondition):
            filters = [cond.filter]
        elif (
            isinstance(cond, LogicCondition)
            and cond.operator == LogicOperator.AND
            and all(isinstance(c, FilterCondition) for c in cond.conditions)
        ):
            filters = [c.filter for c in cond.conditions]
        else:
            raise ValueError("Cannot convert complex boolean logic to simple Query")

        return Query(
            filters=filters,
            sort_specs=list(self.sort_specs),
            limit_value=self.limit_value,
            offset_value=self.offset_value,
            fields=list(self.fields) if self.fields is not None else None,
            vector_query=self.vector_query
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a dict, omitting keys whose values are unset/empty."""
        result: dict[str, Any] = {}

        if self.condition is not None:
            result["condition"] = self.condition.to_dict()

        if self.sort_specs:
            result["sort"] = [s.to_dict() for s in self.sort_specs]

        if self.limit_value is not None:
            result["limit"] = self.limit_value

        if self.offset_value is not None:
            result["offset"] = self.offset_value

        if self.fields is not None:
            result["fields"] = self.fields

        if self.vector_query is not None:
            result["vector_query"] = self.vector_query.to_dict()

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ComplexQuery:
        """Rebuild a query from the dict produced by ``to_dict``."""
        from .query import SortSpec

        condition = None
        if "condition" in data:
            condition = condition_from_dict(data["condition"])

        sort_specs = [SortSpec.from_dict(sort_data) for sort_data in data.get("sort", [])]

        vector_query = None
        if "vector_query" in data:
            vector_query = VectorQuery.from_dict(data["vector_query"])

        return cls(
            condition=condition,
            sort_specs=sort_specs,
            limit_value=data.get("limit"),
            offset_value=data.get("offset"),
            fields=data.get("fields"),
            vector_query=vector_query
        )