Coverage for src/usaspending/queries/transactions_search.py: 60%
83 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
1"""Transactions search query builder for USASpending data."""
from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Dict, Iterator

from ..exceptions import ValidationError
from ..models.transaction import Transaction
from .query_builder import QueryBuilder
from ..logging_config import USASpendingLogger

if TYPE_CHECKING:
    from ..client import USASpending
# Module-level logger obtained from the package's logging configuration;
# TransactionsSearch.count() uses it for its debug/info messages.
logger = USASpendingLogger.get_logger(__name__)
class TransactionsSearch(QueryBuilder["Transaction"]):
    """
    Builds and executes a transactions search query, allowing for filtering
    on transaction data. This class follows a fluent interface pattern:
    every filter method returns a new instance and leaves the receiver
    untouched.

    The API request is scoped by award only (see ``for_award``). The date
    filters added by ``since``/``until`` are applied client-side while
    iterating results.
    """

    # Format accepted by since()/until() and used to parse their values.
    _DATE_FORMAT = "%Y-%m-%d"

    def __init__(self, client: "USASpending"):
        """
        Initializes the TransactionsSearch query builder.

        Args:
            client: The USASpending client instance.
        """
        super().__init__(client)
        # Set via for_award(); required before a payload or count is built.
        self._award_id: str | None = None
        # Client-side filters (not supported by API). Keys are "since_date"
        # and "until_date"; values are validated YYYY-MM-DD strings.
        self._client_filters: Dict[str, str] = {}

    @property
    def _endpoint(self) -> str:
        """The API endpoint for this query."""
        return "/transactions/"

    def _clone(self) -> TransactionsSearch:
        """Create an independent copy so fluent methods never mutate self."""
        clone = super()._clone()
        clone._filter_objects = self._filter_objects.copy()
        clone._award_id = self._award_id
        # Copy the dict so filters added on the clone don't leak back here.
        clone._client_filters = self._client_filters.copy()
        return clone

    def _require_award_id(self) -> str:
        """
        Return the configured award id.

        Raises:
            ValidationError: If for_award() has not been called.
        """
        if not self._award_id:
            raise ValidationError(
                "An award_id is required. Use the .for_award() method."
            )
        return self._award_id

    def _build_payload(self, page: int) -> Dict[str, Any]:
        """
        Construct the API request payload for one page of results.

        Args:
            page: The page number to request.

        Returns:
            The payload dict: award_id, limit, page, plus any aggregated
            filters.

        Raises:
            ValidationError: If no award_id has been set via for_award().
        """
        payload: Dict[str, Any] = {
            # Evaluated first, so a missing award id fails before anything else.
            "award_id": self._require_award_id(),
            "limit": self._get_effective_page_size(),
            "page": page,
        }

        # Add any additional filters if they exist
        final_filters = self._aggregate_filters()
        if final_filters:
            payload.update(final_filters)

        return payload

    def _transform_result(self, result: Dict[str, Any]) -> Transaction:
        """Transforms a single API result item into a Transaction model."""
        return Transaction(result)

    def count(self) -> int:
        """
        Count the transactions for the configured award.

        When client-side date filters are present, the API count cannot
        account for them, so every matching transaction is iterated and
        counted. Otherwise the dedicated award transaction-count endpoint
        is queried.

        Returns:
            The number of (matching) transactions.

        Raises:
            ValidationError: If no award_id has been set via for_award().
        """
        logger.debug(f"{self.__class__.__name__}.count() called")

        # Validate up front: without this the count endpoint would be built
        # with a literal "None" award id instead of failing clearly.
        award_id = self._require_award_id()

        # If we have client-side filters, we need to fetch all results and count
        if self._client_filters:
            logger.debug(
                "Client-side filters present, counting by iterating all results"
            )
            return sum(1 for _ in self)

        # No client-side filters, use the efficient API count endpoint
        endpoint = f"/awards/count/transaction/{award_id}/"

        from ..logging_config import log_query_execution

        log_query_execution(logger, "TransactionsSearch.count", 1, endpoint)

        # Send the request to the count endpoint
        response = self._client._make_request("GET", endpoint)

        # Extract count from the appropriate category
        total = response.get("transactions", 0)

        logger.info(
            f"{self.__class__.__name__}.count() = {total} transactions for award {award_id}"
        )
        return total

    # ==========================================================================
    # Filter Methods
    # ==========================================================================

    def for_award(self, award_id: str) -> TransactionsSearch:
        """
        Filter transactions for a specific award.

        Args:
            award_id: The unique award identifier. It is converted to ``str``
                and stripped of surrounding whitespace.

        Returns:
            A new `TransactionsSearch` instance with the award filter applied.

        Raises:
            ValidationError: If award_id is empty, falsy, or whitespace-only.
        """
        # Validate after normalizing so "   " is rejected rather than being
        # stored as an empty id that only fails later.
        normalized = str(award_id).strip() if award_id else ""
        if not normalized:
            raise ValidationError("award_id cannot be empty")

        clone = self._clone()
        clone._award_id = normalized
        return clone

    @classmethod
    def _validate_date(cls, value: str) -> str:
        """
        Return value unchanged if it parses as a YYYY-MM-DD date.

        Raises:
            ValidationError: If value is not a string in YYYY-MM-DD format.
        """
        try:
            datetime.strptime(value, cls._DATE_FORMAT)
        except (TypeError, ValueError) as err:
            # TypeError covers non-string input such as None.
            raise ValidationError("Date must be in YYYY-MM-DD format") from err
        return value

    def since(self, date: str) -> TransactionsSearch:
        """
        Filter transactions to those on or after the specified date.

        Note: This filter is applied client-side as the API endpoint
        doesn't support date filtering for transactions.

        Args:
            date: Date string in YYYY-MM-DD format

        Returns:
            A new TransactionsSearch instance with the date filter applied

        Raises:
            ValidationError: If date is not in YYYY-MM-DD format.

        Example:
            >>> transactions = award.transactions.since("2024-01-01").all()
        """
        validated = self._validate_date(date)
        clone = self._clone()
        clone._client_filters["since_date"] = validated
        return clone

    def until(self, date: str) -> TransactionsSearch:
        """
        Filter transactions to those on or before the specified date.

        Note: This filter is applied client-side as the API endpoint
        doesn't support date filtering for transactions.

        Args:
            date: Date string in YYYY-MM-DD format

        Returns:
            A new TransactionsSearch instance with the date filter applied

        Raises:
            ValidationError: If date is not in YYYY-MM-DD format.

        Example:
            >>> transactions = award.transactions.until("2024-12-31").all()
        """
        validated = self._validate_date(date)
        clone = self._clone()
        clone._client_filters["until_date"] = validated
        return clone

    def _client_date_bounds(self) -> tuple[date | None, date | None]:
        """
        Parse the since/until client filters into date objects.

        Returns:
            A ``(since_date, until_date)`` pair; each is None when unset.
        """

        def parse(key: str) -> date | None:
            raw = self._client_filters.get(key)
            if raw is None:
                return None
            return datetime.strptime(raw, self._DATE_FORMAT).date()

        return parse("since_date"), parse("until_date")

    @staticmethod
    def _within_date_bounds(
        transaction: Transaction,
        since_date: date | None,
        until_date: date | None,
    ) -> bool:
        """
        Check a transaction's action date against inclusive date bounds.

        Transactions with a falsy action_date are never excluded.
        """
        # With no bounds, never touch action_date at all.
        if since_date is None and until_date is None:
            return True

        action_date = transaction.action_date
        if not action_date:
            return True

        day = action_date.date()
        if since_date is not None and day < since_date:
            return False
        if until_date is not None and day > until_date:
            return False
        return True

    def _apply_client_filters(self, transaction: Transaction) -> bool:
        """
        Apply client-side filters to a transaction.

        Args:
            transaction: The transaction to filter

        Returns:
            True if transaction passes all filters, False otherwise
        """
        since_date, until_date = self._client_date_bounds()
        return self._within_date_bounds(transaction, since_date, until_date)

    def __iter__(self) -> Iterator[Transaction]:
        """
        Iterate over results, applying client-side filters.

        The filter dates are parsed once per iteration rather than once per
        transaction.
        """
        since_date, until_date = self._client_date_bounds()
        for transaction in super().__iter__():
            if self._within_date_bounds(transaction, since_date, until_date):
                yield transaction