Coverage for src/usaspending/queries/transactions_search.py: 60%

83 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-03 17:15 -0700

1"""Transactions search query builder for USASpending data.""" 

2 

3from __future__ import annotations 

4 

5from typing import Any, Dict, TYPE_CHECKING, Iterator 

6from datetime import datetime 

7 

8from ..exceptions import ValidationError 

9from ..models.transaction import Transaction 

10from .query_builder import QueryBuilder 

11from ..logging_config import USASpendingLogger 

12 

13if TYPE_CHECKING: 

14 from ..client import USASpending 

15 

logger = USASpendingLogger.get_logger(__name__)


class TransactionsSearch(QueryBuilder["Transaction"]):
    """
    Builds and executes a transactions search query, allowing for filtering
    on transaction data. This class follows a fluent interface pattern.

    Every filter method returns a new instance (via ``_clone``) and leaves the
    receiver untouched. An award ID set with :meth:`for_award` is required
    before the query can execute. The :meth:`since` / :meth:`until` date
    filters are evaluated client-side after results are fetched.
    """

    # Format accepted by since()/until() and used to store client filters.
    _DATE_FORMAT = "%Y-%m-%d"

    def __init__(self, client: "USASpending"):
        """
        Initializes the TransactionsSearch query builder.

        Args:
            client: The USASpending client instance.
        """
        super().__init__(client)
        # Award whose transactions are listed; must be set via for_award().
        self._award_id: str | None = None
        # Client-side filters (not supported by API), keyed by "since_date" /
        # "until_date" with YYYY-MM-DD string values.
        self._client_filters: Dict[str, str] = {}

    @property
    def _endpoint(self) -> str:
        """The API endpoint for this query."""
        return "/transactions/"

    def _clone(self) -> TransactionsSearch:
        """Creates an immutable copy of the query builder."""
        clone = super()._clone()
        clone._filter_objects = self._filter_objects.copy()
        clone._award_id = self._award_id
        # Copy so filters added to the clone never leak back into self.
        clone._client_filters = self._client_filters.copy()
        return clone

    def _require_award_id(self) -> str:
        """
        Return the configured award ID.

        Raises:
            ValidationError: If no award ID has been set with .for_award().
        """
        if not self._award_id:
            raise ValidationError(
                "An award_id is required. Use the .for_award() method."
            )
        return self._award_id

    def _build_payload(self, page: int) -> Dict[str, Any]:
        """
        Constructs the final API request payload from the filter objects.

        Args:
            page: The page number to request.

        Returns:
            The JSON-serializable request body.

        Raises:
            ValidationError: If no award ID has been set.
        """
        payload = {
            "award_id": self._require_award_id(),
            "limit": self._get_effective_page_size(),
            "page": page,
        }

        # Add any additional filters if they exist
        final_filters = self._aggregate_filters()
        if final_filters:
            payload.update(final_filters)

        return payload

    def _transform_result(self, result: Dict[str, Any]) -> Transaction:
        """Transforms a single API result item into a Transaction model."""
        return Transaction(result)

    def count(self) -> int:
        """
        Count the transactions for the configured award.

        If client-side date filters are active, the API count endpoint cannot
        apply them, so all results are fetched and the matching ones counted.
        Otherwise the dedicated count endpoint is used.

        Returns:
            The number of matching transactions.

        Raises:
            ValidationError: If no award ID has been set with .for_award().
        """
        logger.debug(f"{self.__class__.__name__}.count() called")

        # Validate up front so we never request ".../transaction/None/".
        award_id = self._require_award_id()

        # If we have client-side filters, we need to fetch all results and count
        if self._client_filters:
            logger.debug(
                "Client-side filters present, counting by iterating all results"
            )
            return sum(1 for _ in self)

        # No client-side filters, use the efficient API count endpoint
        endpoint = f"/awards/count/transaction/{award_id}/"

        from ..logging_config import log_query_execution

        log_query_execution(logger, "TransactionsSearch.count", 1, endpoint)

        # Send the request to the count endpoint
        response = self._client._make_request("GET", endpoint)

        # Extract count from the appropriate category
        total = response.get("transactions", 0)

        logger.info(
            f"{self.__class__.__name__}.count() = {total} transactions for award {award_id}"
        )
        return total

    # ==========================================================================
    # Filter Methods
    # ==========================================================================

    def for_award(self, award_id: str) -> TransactionsSearch:
        """
        Filter transactions for a specific award.

        Args:
            award_id: The unique award identifier. It is converted to ``str``
                and surrounding whitespace is stripped.

        Returns:
            A new `TransactionsSearch` instance with the award filter applied.

        Raises:
            ValidationError: If award_id is empty or only whitespace.
        """
        # Strip before checking so whitespace-only IDs are rejected here
        # instead of being stored as "".
        normalized = str(award_id).strip() if award_id else ""
        if not normalized:
            raise ValidationError("award_id cannot be empty")

        clone = self._clone()
        clone._award_id = normalized
        return clone

    def since(self, date: str) -> TransactionsSearch:
        """
        Filter transactions to those on or after the specified date.

        Note: This filter is applied client-side as the API endpoint
        doesn't support date filtering for transactions.

        Args:
            date: Date string in YYYY-MM-DD format, or a ``datetime.date`` /
                ``datetime.datetime`` object.

        Returns:
            A new TransactionsSearch instance with the date filter applied

        Raises:
            ValidationError: If the date is not a valid YYYY-MM-DD value.

        Example:
            >>> transactions = award.transactions.since("2024-01-01").all()
        """
        normalized = self._normalize_date(date)
        clone = self._clone()
        clone._client_filters["since_date"] = normalized
        return clone

    def until(self, date: str) -> TransactionsSearch:
        """
        Filter transactions to those on or before the specified date.

        Note: This filter is applied client-side as the API endpoint
        doesn't support date filtering for transactions.

        Args:
            date: Date string in YYYY-MM-DD format, or a ``datetime.date`` /
                ``datetime.datetime`` object.

        Returns:
            A new TransactionsSearch instance with the date filter applied

        Raises:
            ValidationError: If the date is not a valid YYYY-MM-DD value.

        Example:
            >>> transactions = award.transactions.until("2024-12-31").all()
        """
        normalized = self._normalize_date(date)
        clone = self._clone()
        clone._client_filters["until_date"] = normalized
        return clone

    @classmethod
    def _normalize_date(cls, value: Any) -> str:
        """
        Validate a date filter argument and return it as a YYYY-MM-DD string.

        Strings are validated against ``_DATE_FORMAT`` and returned unchanged.
        ``date`` / ``datetime`` objects are converted with ``isoformat()``,
        and any time component is dropped.

        Raises:
            ValidationError: If the value is not a parseable YYYY-MM-DD date.
        """
        from datetime import date as date_type

        if isinstance(value, datetime):
            value = value.date()
        if isinstance(value, date_type):
            return value.isoformat()
        try:
            datetime.strptime(value, cls._DATE_FORMAT)
        except (TypeError, ValueError) as err:
            # TypeError covers non-string input such as None or an int.
            raise ValidationError("Date must be in YYYY-MM-DD format") from err
        return value

    def _client_date_bounds(self) -> tuple[Any, Any]:
        """
        Parse the stored since/until filters.

        Returns:
            A ``(since_date, until_date)`` pair of ``datetime.date`` objects,
            with ``None`` for any bound that is not set.
        """

        def parse(key: str) -> Any:
            raw = self._client_filters.get(key)
            if raw is None:
                return None
            return datetime.strptime(raw, self._DATE_FORMAT).date()

        return parse("since_date"), parse("until_date")

    @staticmethod
    def _within_date_bounds(transaction: Transaction, since_date: Any, until_date: Any) -> bool:
        """
        Check a transaction's action date against inclusive date bounds.

        A transaction with a falsy ``action_date`` is never excluded. A bound
        of ``None`` is ignored.
        """
        action_date = transaction.action_date
        if not action_date:
            return True
        # Accept both datetime and plain date values for action_date.
        if isinstance(action_date, datetime):
            action_date = action_date.date()
        if since_date is not None and action_date < since_date:
            return False
        if until_date is not None and action_date > until_date:
            return False
        return True

    def _apply_client_filters(self, transaction: Transaction) -> bool:
        """
        Apply client-side filters to a transaction.

        Args:
            transaction: The transaction to filter

        Returns:
            True if transaction passes all filters, False otherwise
        """
        since_date, until_date = self._client_date_bounds()
        return self._within_date_bounds(transaction, since_date, until_date)

    def __iter__(self) -> Iterator[Transaction]:
        """
        Iterate over results, applying client-side filters.

        The filter dates are parsed once per iteration rather than once per
        transaction.
        """
        if not self._client_filters:
            yield from super().__iter__()
            return

        since_date, until_date = self._client_date_bounds()
        for transaction in super().__iter__():
            if self._within_date_bounds(transaction, since_date, until_date):
                yield transaction