Coverage for src/usaspending/queries/query_builder.py: 94%
175 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
1from abc import ABC, abstractmethod
2from typing import (
3 Iterator,
4 List,
5 Dict,
6 Any,
7 Optional,
8 TypeVar,
9 Generic,
10 TYPE_CHECKING,
11 Union,
12)
14# Import exceptions for use by all query builders
16from .filters import BaseFilter
18from ..logging_config import USASpendingLogger, log_query_execution
# Type variable for the model type a concrete query builder yields.
T = TypeVar("T")

# Needed only for the quoted "USASpending" annotation, so it is imported
# solely while type checking and never at runtime.
if TYPE_CHECKING:
    from ..client import USASpending

# Module-level logger used by QueryBuilder for pagination and request tracing.
logger = USASpendingLogger.get_logger(__name__)
class QueryBuilder(ABC, Generic[T]):
    """Base query builder with automatic pagination support.

    Provides transparent pagination handling for USASpending API queries.
    - Use limit() to set the total number of items to retrieve across all pages
    - Use page_size() to control how many items are fetched per API request
    - Use max_pages() to limit the number of API requests made

    The builder methods (limit, page_size, max_pages, order_by) never mutate
    the receiver; each one returns a modified clone, so calls can be chained
    and a partially configured builder can be reused.

    Subclasses must implement ``count``, ``_endpoint``, ``_build_payload`` and
    ``_transform_result``.
    """

    def __init__(self, client: "USASpending"):
        """Create an unconfigured builder bound to ``client``.

        Args:
            client: Object whose ``_make_request`` method performs the request.
        """
        self._client = client
        self._filters: Dict[str, Any] = {}
        self._filter_objects: list[BaseFilter] = []
        self._page_size = 100  # Items per page (max 100 per USASpending API)
        self._total_limit: Optional[int] = None  # Total items across all pages
        self._max_pages: Optional[int] = None  # Limit total pages fetched
        self._order_by: Optional[str] = None
        self._order_direction = "desc"

    def limit(self, num: int) -> "QueryBuilder[T]":
        """Return a clone capped at ``num`` total items across all pages."""
        clone = self._clone()
        clone._total_limit = num
        return clone

    def page_size(self, num: int) -> "QueryBuilder[T]":
        """Return a clone that requests ``num`` items per page.

        The value is clamped to the range 1..100: 100 is the maximum the
        USASpending API accepts, and a page size below 1 would make the
        page arithmetic in ``__getitem__`` divide by zero.
        """
        clone = self._clone()
        clone._page_size = max(1, min(num, 100))
        return clone

    def max_pages(self, num: int) -> "QueryBuilder[T]":
        """Return a clone that fetches at most ``num`` pages while iterating."""
        clone = self._clone()
        clone._max_pages = num
        return clone

    def order_by(self, field: str, direction: str = "desc") -> "QueryBuilder[T]":
        """Return a clone sorted by ``field`` in the given ``direction``."""
        clone = self._clone()
        clone._order_by = field
        clone._order_direction = direction
        return clone

    def __iter__(self) -> Iterator[T]:
        """Iterate over all results, handling pagination automatically.

        Pages are requested lazily, one at a time, starting at page 1.
        Iteration stops when the total limit or the page limit is reached,
        when a page comes back empty, or when the response's
        ``page_metadata.hasNext`` is false or missing.

        Yields:
            Each raw result passed through ``_transform_result``.
        """
        page = 1
        pages_fetched = 0
        items_yielded = 0

        query_type = self.__class__.__name__
        effective_page_size = self._get_effective_page_size()
        logger.info(
            f"Starting {query_type} iteration with page_size={effective_page_size}, "
            f"total_limit={self._total_limit}, max_pages={self._max_pages}"
        )

        while True:
            # Check if we've reached the total limit
            if self._total_limit is not None and items_yielded >= self._total_limit:
                logger.debug(f"Total limit of {self._total_limit} items reached")
                break

            # Check if we've reached the max pages limit. Compared against
            # None (not truthiness) so max_pages(0) means "fetch nothing",
            # mirroring how limit(0) behaves.
            if self._max_pages is not None and pages_fetched >= self._max_pages:
                logger.debug(f"Max pages limit ({self._max_pages}) reached")
                break

            response = self._execute_query(page)
            results = response.get("results", [])
            has_next = response.get("page_metadata", {}).get("hasNext", False)

            logger.debug(f"Page {page}: {len(results)} results, hasNext={has_next}")

            # Empty page means no more data
            if not results:
                logger.debug("Empty page returned")
                break

            for item in results:
                # Check limit before each yield to handle mid-page limits
                if (
                    self._total_limit is not None
                    and items_yielded >= self._total_limit
                ):
                    logger.debug(f"Stopping mid-page at item {items_yielded}")
                    return

                yield self._transform_result(item)
                items_yielded += 1

            # API indicates no more pages
            if not has_next:
                logger.debug("Last page reached (hasNext=false)")
                break

            page += 1
            pages_fetched += 1

    def first(self) -> Optional[T]:
        """Return the first result, or None when the query yields nothing."""
        logger.debug(f"{self.__class__.__name__}.first() called")
        for result in self.limit(1):
            return result
        return None

    def all(self) -> List[T]:
        """Return every result as a list, honouring limit() and max_pages()."""
        logger.debug(f"{self.__class__.__name__}.all() called")
        # list() asks its argument for a length hint, which would invoke
        # __len__ -> count() (an extra API request) and pre-size the list to
        # the unlimited total. Passing the bare generator avoids both.
        results = list(iter(self))
        logger.info(f"{self.__class__.__name__}.all() returned {len(results)} results")
        return results

    def __len__(self) -> int:
        """Return the total number of items (delegates to count())."""
        return self.count()

    def __getitem__(self, key: Union[int, slice]) -> Union[T, List[T]]:
        """Support list-like indexing and slicing.

        Only the pages that contain the requested items are fetched. Every
        call invokes ``count()`` once to resolve negative indices and bounds.

        NOTE(review): the index-to-page arithmetic uses ``self._page_size``
        and therefore assumes ``_build_payload`` requests exactly that many
        items per page -- confirm against the concrete subclasses.

        Args:
            key: Integer index or slice object

        Returns:
            Single item for integer index, list of items for slice

        Raises:
            IndexError: If index is out of bounds
            TypeError: If key is not int or slice
        """
        if isinstance(key, int):
            return self._get_single_item(key)
        if isinstance(key, slice):
            return self._get_slice_items(key)
        raise TypeError(
            f"indices must be integers or slices, not {type(key).__name__}"
        )

    def _get_single_item(self, key: int) -> T:
        """Fetch the one page holding index ``key`` and return that item."""
        total_count = self.count()

        # Convert negative index to positive, keeping ``key`` for the message
        index = key + total_count if key < 0 else key

        if index < 0 or index >= total_count:
            raise IndexError(
                f"Index {key} out of range for query with {total_count} items"
            )

        # Pages are 1-based; divmod yields a 0-based page and the offset in it
        page_index, offset_in_page = divmod(index, self._page_size)
        page_num = page_index + 1

        logger.debug(f"Fetching page {page_num} to get item at index {index}")
        results = self._fetch_page(page_num)

        if offset_in_page < len(results):
            return self._transform_result(results[offset_in_page])
        raise IndexError(f"Index {index} not found in results")

    def _get_slice_items(self, key: slice) -> List[T]:
        """Resolve ``key`` against count() and return the selected items."""
        total_count = self.count()
        start, stop, step = key.indices(total_count)

        # range() applies the same start/stop/step semantics as list slicing,
        # including negative steps (where start > stop).
        wanted = range(start, stop, step)
        if not wanted:
            return []

        page_size = self._page_size
        items: List[T] = []

        if step != 1:
            # Indices are monotonic, so page numbers are too: remembering
            # only the most recent page fetches each needed page exactly
            # once and skips pages that contain no selected index.
            loaded_page = None
            page_results: List[Dict[str, Any]] = []
            for index in wanted:
                page_index, offset = divmod(index, page_size)
                if page_index != loaded_page:
                    page_results = self._fetch_page(page_index + 1)
                    loaded_page = page_index
                if offset >= len(page_results):
                    raise IndexError(f"Index {index} not found in results")
                items.append(self._transform_result(page_results[offset]))
            return items

        # Contiguous slice: fetch only the pages overlapping [start, stop)
        start_page = (start // page_size) + 1
        end_page = ((stop - 1) // page_size) + 1

        logger.debug(
            f"Fetching pages {start_page} to {end_page} for slice [{start}:{stop}]"
        )

        for page in range(start_page, end_page + 1):
            results = self._fetch_page(page)

            # Absolute index of the first item on this page
            page_start_idx = (page - 1) * page_size

            # Overlap of this page with the requested slice; an empty
            # overlap produces an empty Python slice below.
            take_start = max(0, start - page_start_idx)
            take_end = min(len(results), stop - page_start_idx)
            items.extend(
                self._transform_result(raw) for raw in results[take_start:take_end]
            )

            # Stop if we've collected all requested items
            if len(items) >= (stop - start):
                break

        return items

    @abstractmethod
    def count(self) -> int:
        """Get total count without fetching all results."""

    @property
    @abstractmethod
    def _endpoint(self) -> str:
        """API endpoint for this query."""

    @abstractmethod
    def _build_payload(self, page: int) -> Dict[str, Any]:
        """Build the request payload for the given 1-based page number."""

    def _get_effective_page_size(self) -> int:
        """Return the page size, capped by the total limit when one is set."""
        if self._total_limit is not None:
            return min(self._page_size, self._total_limit)
        return self._page_size

    @abstractmethod
    def _transform_result(self, data: Dict[str, Any]) -> T:
        """Transform raw result to model instance."""

    def _aggregate_filters(self) -> dict[str, Any]:
        """Merge every filter object's ``to_dict()`` into one payload dict.

        Empty values are skipped. When a key already holds a list, later
        values are merged into it; otherwise a later value replaces the
        earlier one. Lists are copied before being stored so merging never
        mutates data owned by a filter object.
        """
        final_filters: dict[str, Any] = {}

        for f in self._filter_objects:
            for key, value in f.to_dict().items():
                # Skip keys with empty values to keep payload clean
                if not value:
                    continue

                existing = final_filters.get(key)
                if isinstance(existing, list):
                    if isinstance(value, (list, tuple, set, frozenset)):
                        existing.extend(value)
                    else:
                        # A scalar (e.g. a string) is one entry, not a
                        # sequence of characters to splice in.
                        existing.append(value)
                elif isinstance(value, list):
                    final_filters[key] = list(value)
                else:
                    final_filters[key] = value

        logger.debug(f"Applied {len(self._filter_objects)} filters to query")

        return final_filters

    def _fetch_page(self, page: int) -> List[Dict[str, Any]]:
        """Fetch a single page and return its raw ``results`` list."""
        response = self._execute_query(page)
        return response.get("results", [])

    def _execute_query(self, page: int) -> Dict[str, Any]:
        """Build the payload for ``page``, POST it, and return the raw response."""
        query_type = self.__class__.__name__
        # NOTE(review): this counts ``_filters``, while _aggregate_filters()
        # reads ``_filter_objects`` -- confirm which one should be logged.
        filters_count = len(self._filters)
        endpoint = self._endpoint

        log_query_execution(logger, query_type, filters_count, endpoint, page)

        payload = self._build_payload(page)
        logger.debug(f"Query payload: {payload}")

        response = self._client._make_request("POST", endpoint, json=payload)

        if "page_metadata" in response:
            metadata = response["page_metadata"]
            logger.debug(
                f"Page metadata: page={metadata.get('page')}, "
                f"total={metadata.get('total')}, hasNext={metadata.get('hasNext')}"
            )

        return response

    def _clone(self) -> "QueryBuilder[T]":
        """Create a copy for method chaining.

        The clone is built by calling the concrete class with the client
        only, then copying the builder state; the filter containers are
        shallow-copied so the clone's collections can change independently.
        """
        clone = self.__class__(self._client)
        clone._filters = self._filters.copy()
        clone._filter_objects = self._filter_objects.copy()
        clone._page_size = self._page_size
        clone._total_limit = self._total_limit
        clone._max_pages = self._max_pages
        clone._order_by = self._order_by
        clone._order_direction = self._order_direction
        return clone