Coverage for src/usaspending/queries/agencies_search.py: 94%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-03 17:15 -0700

1"""Agencies search query implementation for funding agency/office autocomplete.""" 

2 

3from __future__ import annotations 

4from typing import Dict, Any, TYPE_CHECKING 

5from ..exceptions import ValidationError 

6from ..models.agency import Agency 

7from ..models.subtier_agency import SubTierAgency 

8from .query_builder import QueryBuilder 

9from ..logging_config import USASpendingLogger 

10 

11if TYPE_CHECKING: 

12 from ..client import USASpending 

13 

# Module-level logger obtained from the package's logging configuration
# helper; used by AgenciesSearch.count() below.
logger = USASpendingLogger.get_logger(__name__)

15 

16 

class AgenciesSearch(QueryBuilder[Agency]):
    """Search for funding agencies and offices by name using autocomplete.

    This query builder uses the /v2/autocomplete/funding_agency_office/ endpoint
    to search for agencies by name. Results can be filtered by type (toptier,
    subtier, or office).

    The endpoint returns a single, unpaginated response whose ``results``
    object holds one array per match type; this class flattens those arrays
    into the single results list that ``QueryBuilder`` consumes.
    """

    # (result type, key of that type's array in the API response), in the
    # order entries are emitted when no type filter is applied.
    _RESPONSE_KEYS_BY_TYPE = (
        ("toptier", "toptier_agency"),
        ("subtier", "subtier_agency"),
        ("office", "office"),
    )

    def __init__(self, client: USASpending):
        """Initialize AgenciesSearch with client.

        Args:
            client: Client used to issue the autocomplete request.
        """
        super().__init__(client)
        self._search_text = ""
        self._limit = 100  # Default limit sent in the request payload
        # Filter: None (all types), 'toptier', 'subtier', or 'office'
        self._result_type: str | None = None

    @property
    def _endpoint(self) -> str:
        """API endpoint for agency autocomplete."""
        # This class is concrete (its fluent methods return AgenciesSearch
        # instances), so it must supply the endpoint itself rather than
        # deferring to a subclass.
        return "/v2/autocomplete/funding_agency_office/"

    def _require_search_text(self) -> None:
        """Raise ValidationError if no search text has been set."""
        if not self._search_text:
            raise ValidationError(
                "search_text is required. Use with_search_text() method."
            )

    def _build_payload(self, page: int) -> Dict[str, Any]:
        """Build request payload.

        Args:
            page: Unused; the endpoint is not paginated.

        Returns:
            Payload with the search text and result limit.

        Raises:
            ValidationError: If no search text has been set.
        """
        self._require_search_text()
        return {"search_text": self._search_text, "limit": self._limit}

    def _execute_query(self, page: int) -> Dict[str, Any]:
        """Execute the autocomplete query.

        This endpoint doesn't support pagination, so only page 1 returns
        results; any later page returns an empty result set.

        Args:
            page: 1-based page number requested by QueryBuilder.

        Returns:
            Dict with a flat ``results`` list of ``{"type": ..., "data": ...}``
            entries and ``page_metadata`` with ``hasNext`` False. For page 1 it
            also carries the API's ``messages`` (default ``[]``).

        Raises:
            ValidationError: If no search text has been set (page 1 only).
        """
        # No pagination - return empty after first page
        if page > 1:
            return {"results": [], "page_metadata": {"hasNext": False}}

        payload = self._build_payload(1)
        response = self._client._make_request("POST", self._endpoint, json=payload)

        # The response's ``results`` is an object holding one array per type;
        # QueryBuilder needs a single flat list.
        return {
            "results": self._flatten_results(response.get("results")),
            "page_metadata": {"hasNext": False},
            "messages": response.get("messages", []),
        }

    def _flatten_results(
        self, results_obj: Dict[str, Any] | None
    ) -> list[Dict[str, Any]]:
        """Flatten the per-type result arrays into one tagged list.

        Entries are emitted in toptier, subtier, office order, restricted to
        the active type filter when one is set. A missing or null ``results``
        object, or a missing or null per-type array, is treated as empty.

        Args:
            results_obj: The ``results`` value from the API response.

        Returns:
            List of ``{"type": <result type>, "data": <raw item>}`` dicts.
        """
        results_obj = results_obj or {}
        flat_results: list[Dict[str, Any]] = []
        for result_type, response_key in self._RESPONSE_KEYS_BY_TYPE:
            if self._result_type is not None and self._result_type != result_type:
                continue
            # ``or []`` also covers an explicit JSON null for the array.
            flat_results.extend(
                {"type": result_type, "data": item}
                for item in results_obj.get(response_key) or []
            )
        return flat_results

    def _transform_result(self, result: Dict[str, Any]) -> Agency | None:
        """Transform a flattened result entry into a model object.

        Args:
            result: Entry produced by ``_flatten_results``, i.e.
                ``{"type": ..., "data": ...}``.

        Returns:
            ``Agency`` for toptier entries, ``SubTierAgency`` for subtier and
            office entries, or ``None`` for an empty or unrecognised entry.
        """
        if not result:
            return None

        result_type = result.get("type")
        data = result.get("data") or {}

        if result_type == "toptier":
            # Copy the autocomplete "code" into both "code" and "toptier_code".
            agency_data = {
                "code": data.get("code"),
                "toptier_code": data.get("code"),
                "name": data.get("name"),
                "abbreviation": data.get("abbreviation"),
            }
            return Agency(agency_data, self._client)

        if result_type == "subtier":
            return SubTierAgency(data, self._client)

        if result_type == "office":
            # NOTE(review): offices are wrapped in SubTierAgency; no dedicated
            # office model is visible here - confirm this is intended.
            return SubTierAgency(data, self._client)

        return None

    def count(self) -> int:
        """Get total count of matching agencies/offices.

        Issues the autocomplete request and counts the flattened results,
        honouring the active type filter.

        Returns:
            Total number of matching results

        Raises:
            ValidationError: If no search text has been set.
        """
        logger.debug(f"{self.__class__.__name__}.count() called")

        self._require_search_text()

        # The endpoint is unpaginated, so page 1 holds every returned match.
        count = len(self._execute_query(1)["results"])

        logger.info(
            f"{self.__class__.__name__}.count() = {count} results "
            f"for search text '{self._search_text}'"
        )
        return count

    def with_search_text(self, search_text: str) -> AgenciesSearch:
        """Set the search text for the query.

        Args:
            search_text: Text to search for in agency names

        Returns:
            New AgenciesSearch instance with search text set
        """
        clone = self._clone()
        clone._search_text = search_text
        return clone

    def _with_result_type(self, result_type: str) -> AgenciesSearch:
        """Return a clone whose results are restricted to ``result_type``."""
        clone = self._clone()
        clone._result_type = result_type
        return clone

    def toptier(self) -> AgenciesSearch:
        """Filter to only return toptier agency matches.

        Returns:
            New AgenciesSearch instance filtered to toptier agencies
        """
        return self._with_result_type("toptier")

    def subtier(self) -> AgenciesSearch:
        """Filter to only return subtier agency matches.

        Returns:
            New AgenciesSearch instance filtered to subtier agencies
        """
        return self._with_result_type("subtier")

    def office(self) -> AgenciesSearch:
        """Filter to only return office matches.

        Returns:
            New AgenciesSearch instance filtered to offices
        """
        return self._with_result_type("office")

    def _clone(self) -> AgenciesSearch:
        """Create immutable copy for chaining.

        Every AgenciesSearch-specific attribute is copied explicitly. Chained
        calls such as ``.toptier().with_search_text("x")`` then keep the type
        filter and payload limit, whatever ``QueryBuilder._clone()`` carries
        over on its own.
        """
        clone = super()._clone()
        clone._search_text = self._search_text
        clone._limit = self._limit
        clone._result_type = self._result_type
        return clone