Coverage for src/usaspending/queries/agencies_search.py: 94%
87 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
1"""Agencies search query implementation for funding agency/office autocomplete."""
3from __future__ import annotations
4from typing import Dict, Any, TYPE_CHECKING
5from ..exceptions import ValidationError
6from ..models.agency import Agency
7from ..models.subtier_agency import SubTierAgency
8from .query_builder import QueryBuilder
9from ..logging_config import USASpendingLogger
11if TYPE_CHECKING:
12 from ..client import USASpending
# Module-level logger obtained from the package's USASpendingLogger helper,
# named after this module so its records can be filtered by module path.
logger = USASpendingLogger.get_logger(__name__)
class AgenciesSearch(QueryBuilder[Agency]):
    """Search for funding agencies and offices by name using autocomplete.

    This query builder uses the /v2/autocomplete/funding_agency_office/ endpoint
    to search for agencies by name. Results can be filtered by type (toptier,
    subtier, or office).

    The endpoint answers with a single, unpaginated response whose ``results``
    object holds one array per result type (``toptier_agency``,
    ``subtier_agency`` and ``office``). ``_execute_query`` flattens those arrays
    into one list of ``{"type": ..., "data": ...}`` records, which
    ``_transform_result`` then turns into model objects.
    """

    # (filter value, key inside the response's "results" object), listed in
    # the order the flattened results are emitted.
    _RESULT_GROUPS = (
        ("toptier", "toptier_agency"),
        ("subtier", "subtier_agency"),
        ("office", "office"),
    )

    def __init__(self, client: USASpending):
        """Initialize AgenciesSearch with client.

        Args:
            client: USASpending client used to issue the autocomplete request.
        """
        super().__init__(client)
        self._search_text = ""
        self._limit = 100  # Default limit sent in the request payload
        # Filter: None (all types), 'toptier', 'subtier' or 'office'
        self._result_type = None

    @property
    def _endpoint(self) -> str:
        """API endpoint for agency autocomplete."""
        return "/v2/autocomplete/funding_agency_office/"

    def _build_payload(self, page: int) -> Dict[str, Any]:
        """Build request payload.

        Args:
            page: Ignored; the endpoint is not paginated.

        Returns:
            The JSON body containing the search text and result limit.

        Raises:
            ValidationError: If no search text has been set.
        """
        if not self._search_text:
            raise ValidationError("search_text is required. Use with_search_text() method.")

        return {
            "search_text": self._search_text,
            "limit": self._limit,
        }

    def _execute_query(self, page: int) -> Dict[str, Any]:
        """Execute the autocomplete query.

        This endpoint doesn't support pagination, so only return results on page 1.

        Args:
            page: 1-based page number requested by QueryBuilder.

        Returns:
            A dict with a flat ``results`` list of ``{"type", "data"}`` records
            (restricted to the selected result type, if any), ``page_metadata``
            reporting no further pages, and any ``messages`` from the API.
        """
        # No pagination - return empty after first page
        if page > 1:
            return {"results": [], "page_metadata": {"hasNext": False}}

        payload = self._build_payload(1)
        response = self._client._make_request("POST", self._endpoint, json=payload)

        # "results" is an object with one array per result type. The `or`
        # guards tolerate keys that are present but null in the response.
        results_obj = response.get("results") or {}
        flat_results = []
        for result_type, response_key in self._RESULT_GROUPS:
            if self._result_type is not None and self._result_type != result_type:
                continue
            flat_results.extend(
                {"type": result_type, "data": item}
                for item in results_obj.get(response_key) or []
            )

        # Return flattened structure for QueryBuilder
        return {
            "results": flat_results,
            "page_metadata": {"hasNext": False},
            "messages": response.get("messages", []),
        }

    def _transform_result(self, result: Dict[str, Any]) -> Agency | None:
        """Transform a flattened result record into a model object.

        Args:
            result: A ``{"type": ..., "data": ...}`` record produced by
                ``_execute_query``.

        Returns:
            An ``Agency`` for toptier matches, a ``SubTierAgency`` for subtier
            and office matches, or ``None`` for empty or unrecognised records.
        """
        if not result:
            return None

        result_type = result.get("type")
        data = result.get("data") or {}

        if result_type == "toptier":
            # The autocomplete payload only carries "code"; expose it under
            # both "code" and "toptier_code" for the Agency model.
            agency_data = {
                "code": data.get("code"),
                "toptier_code": data.get("code"),
                "name": data.get("name"),
                "abbreviation": data.get("abbreviation"),
            }
            return Agency(agency_data, self._client)

        if result_type in ("subtier", "office"):
            # NOTE(review): offices are wrapped as SubTierAgency, as in the
            # original code; confirm that model accepts office payloads.
            return SubTierAgency(data, self._client)

        return None

    def count(self) -> int:
        """Get total count of matching agencies/offices.

        This performs the autocomplete request and counts the (filtered)
        results, since the endpoint exposes no separate count.

        Returns:
            Total number of matching results

        Raises:
            ValidationError: If no search text has been set.
        """
        logger.debug(f"{self.__class__.__name__}.count() called")

        if not self._search_text:
            raise ValidationError("search_text is required. Use with_search_text() method.")

        # Execute query to get all results
        response = self._execute_query(1)
        count = len(response.get("results", []))

        logger.info(
            f"{self.__class__.__name__}.count() = {count} results "
            f"for search text '{self._search_text}'"
        )
        return count

    def with_search_text(self, search_text: str) -> AgenciesSearch:
        """Set the search text for the query.

        Args:
            search_text: Text to search for in agency names

        Returns:
            New AgenciesSearch instance with search text set
        """
        clone = self._clone()
        clone._search_text = search_text
        return clone

    def _with_result_type(self, result_type: str) -> AgenciesSearch:
        """Return a clone whose results are restricted to ``result_type``."""
        clone = self._clone()
        clone._result_type = result_type
        return clone

    def toptier(self) -> AgenciesSearch:
        """Filter to only return toptier agency matches.

        Returns:
            New AgenciesSearch instance filtered to toptier agencies
        """
        return self._with_result_type("toptier")

    def subtier(self) -> AgenciesSearch:
        """Filter to only return subtier agency matches.

        Returns:
            New AgenciesSearch instance filtered to subtier agencies
        """
        return self._with_result_type("subtier")

    def office(self) -> AgenciesSearch:
        """Filter to only return office matches.

        Returns:
            New AgenciesSearch instance filtered to offices
        """
        return self._with_result_type("office")

    def _clone(self) -> AgenciesSearch:
        """Create immutable copy for chaining.

        Every AgenciesSearch-specific setting is copied explicitly, whatever
        the base class copies, so chains like
        ``.toptier().with_search_text("x")`` keep the earlier filter and limit.
        """
        clone = super()._clone()
        clone._search_text = self._search_text
        clone._limit = self._limit
        clone._result_type = self._result_type
        return clone