Coverage for src/usaspending/queries/subawards_search.py: 90%
79 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 17:15 -0700
1"""Subawards search query builder for USASpending data."""
3from __future__ import annotations
5from typing import Any, Dict, TYPE_CHECKING, Optional
7from ..exceptions import ValidationError
8from ..models.subaward import SubAward
9from .awards_search import AwardsSearch
10from ..logging_config import USASpendingLogger
11from ..config import AWARD_TYPE_GROUPS
13if TYPE_CHECKING:
14 from ..client import USASpending
# Module-level logger for this query builder, obtained from the project's
# USASpendingLogger helper and named after this module.
logger = USASpendingLogger.get_logger(__name__)
class SubAwardsSearch(AwardsSearch):
    """
    Builds and executes a subawards search query, allowing for complex
    filtering on subaward data. This class extends AwardsSearch to reuse
    filter logic while specializing for subawards.

    Every request built here sets ``subawards=True`` and
    ``spending_level="subawards"``. After :meth:`for_award`, the prime
    award's unique ID is also sent as the ``award_unique_id`` filter, both
    for result pages and for :meth:`count_awards_by_type`.
    """

    def __init__(self, client: "USASpending"):
        """
        Initializes the SubAwardsSearch query builder.

        Args:
            client: The USASpending client instance.
        """
        super().__init__(client)
        # Prime-award identifier set by for_award(); None (or any falsy
        # value) means "do not restrict to a single prime award".
        self._award_id: Optional[str] = None

    def _clone(self) -> SubAwardsSearch:
        """
        Return a copy of this builder so chained calls never mutate the original.

        ``_filters`` and ``_filter_objects`` are copied with ``.copy()``
        (a shallow copy); all other state is copied by reference/value.
        """
        clone = SubAwardsSearch(self._client)
        clone._filters = self._filters.copy()
        clone._filter_objects = self._filter_objects.copy()
        clone._page_size = self._page_size
        clone._total_limit = self._total_limit
        clone._max_pages = self._max_pages
        clone._order_by = self._order_by
        clone._order_direction = self._order_direction
        clone._award_id = self._award_id
        return clone

    def _with_award_filter(
        self, filters: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Return a new filter dict with the prime-award filter applied, if any.

        A fresh dict is always returned so a caller-owned (possibly shared)
        filter dict is never mutated in place. ``None`` is treated as an
        empty filter set.

        Args:
            filters: The filters to extend, or None.

        Returns:
            A copy of ``filters`` with ``award_unique_id`` set when this
            builder was restricted via :meth:`for_award`.
        """
        merged: Dict[str, Any] = dict(filters or {})
        if self._award_id:
            merged["award_unique_id"] = self._award_id
        return merged

    def _build_payload(self, page: int) -> Dict[str, Any]:
        """
        Constructs the final API request payload for subawards.

        Overrides parent to always include subawards=true and
        spending_level=subawards, and to add the prime-award filter when one
        was set via :meth:`for_award`.

        Args:
            page: The page number passed through to the parent builder.

        Returns:
            The request payload dict produced by the parent, extended with
            the subaward-specific keys.
        """
        payload = super()._build_payload(page)

        # Always search for subawards
        payload["subawards"] = True
        payload["spending_level"] = "subawards"

        # If filtering by a specific award, add it to the filters. The
        # filters dict is replaced by a copy rather than mutated, so any
        # dict shared with builder state stays untouched.
        if self._award_id:
            payload["filters"] = self._with_award_filter(payload.get("filters"))

        return payload

    def _transform_result(self, result: Dict[str, Any]) -> SubAward:
        """Transforms a single API result item into a SubAward model."""
        return SubAward(result, self._client)

    def _get_fields(self) -> list[str]:
        """
        Determines the list of fields to request based on award type filters.

        Returns different field sets depending on the award type codes:
        - Only contract codes: contract subaward fields
        - Only grant codes: grant subaward fields
        - Both, or neither (including codes from other groups only): the
          union of both field sets, contract fields first, without
          duplicates.

        Returns:
            A new list of field names (callers may mutate it freely).
        """
        # Get award type codes from filters
        award_types = self._get_award_type_codes()

        # Determine if we're dealing with contracts or grants
        is_contract = False
        is_grant = False

        for category_name, codes in AWARD_TYPE_GROUPS.items():
            if not award_types & frozenset(codes.keys()):
                continue
            if category_name == "contracts":
                is_contract = True
            elif category_name == "grants":
                is_grant = True

        # Return appropriate field set
        if is_contract and not is_grant:
            return SubAward.CONTRACT_SUBAWARD_FIELDS.copy()
        if is_grant and not is_contract:
            return SubAward.GRANT_SUBAWARD_FIELDS.copy()

        # Both or neither: request the union. dict.fromkeys dedupes while
        # keeping first-seen order, so the field list is identical on every
        # run (a set of strings iterates in a hash-randomized order).
        return list(
            dict.fromkeys(
                [
                    *SubAward.CONTRACT_SUBAWARD_FIELDS,
                    *SubAward.GRANT_SUBAWARD_FIELDS,
                ]
            )
        )

    def count(self) -> int:
        """
        Get the total count of subawards.

        If filtering by a specific award, makes a single GET to the
        ``/awards/count/subaward/{award_id}/`` endpoint. Otherwise it counts
        by iterating over every result of this query, which issues one
        request per page. NOTE(review): presumably that count is capped by
        any total-limit/max-pages settings honored by the parent's
        iteration; confirm against AwardsSearch.__iter__.

        Returns:
            The number of subawards.
        """
        logger.debug(f"{self.__class__.__name__}.count() called")

        # If we have an award_id filter, use the efficient count endpoint
        if self._award_id:
            endpoint = f"/awards/count/subaward/{self._award_id}/"

            from ..logging_config import log_query_execution
            log_query_execution(logger, "SubAwardsSearch.count", 1, endpoint)

            # Send the request to the count endpoint
            response = self._client._make_request("GET", endpoint)

            # Extract count from response
            total = response.get("subawards", 0)

            logger.info(
                f"{self.__class__.__name__}.count() = {total} subawards for award {self._award_id}"
            )
            return total

        # No dedicated count endpoint is used for general subaward searches,
        # so count by iterating every result (one request per page).
        return sum(1 for _ in self)

    def count_awards_by_type(self) -> Dict[str, int]:
        """
        Override parent method to use subawards-specific count endpoint.

        Sends the aggregated filters (plus the prime-award filter set via
        :meth:`for_award`, if any) to ``/search/spending_by_award_count/``
        with ``subawards=True`` and ``spending_level="subawards"``.

        Returns:
            A dictionary mapping award type categories to their subaward
            counts, or an empty dict if the response carries none.
        """
        endpoint = "/search/spending_by_award_count/"
        # Apply the same prime-award restriction that result pages use, so
        # for_award(...).count_awards_by_type() is not silently unfiltered.
        final_filters = self._with_award_filter(self._aggregate_filters())

        payload = {
            "filters": final_filters,
            "subawards": True,  # Always count subawards
            "spending_level": "subawards",
        }

        from ..logging_config import log_query_execution

        log_query_execution(
            logger, "SubAwardsSearch.count_awards_by_type", len(self._filter_objects), endpoint
        )

        # Send the request to the count endpoint
        response = self._client._make_request("POST", endpoint, json=payload)

        # The documented API response carries the per-category counts under
        # "results"; "aggregations" is still honored for backward
        # compatibility with existing callers/fixtures.
        if "results" in response:
            return response["results"]
        return response.get("aggregations", {})

    def for_award(self, award_id: str) -> SubAwardsSearch:
        """
        Filter subawards for a specific prime award.

        Args:
            award_id: The unique generated award identifier. Surrounding
                whitespace is stripped.

        Returns:
            A new SubAwardsSearch instance with the award filter applied.

        Raises:
            ValidationError: If award_id is empty, falsy, or whitespace-only.

        Example:
            >>> subawards = client.subawards.for_award("CONT_AWD_123...")
            >>> for sub in subawards:
            ...     print(f"{sub.sub_awardee_name}: ${sub.sub_award_amount:,.2f}")
        """
        if not award_id:
            raise ValidationError("award_id cannot be empty")

        normalized = str(award_id).strip()
        if not normalized:
            # A whitespace-only ID would otherwise become "", which the rest
            # of this class treats as "no award filter", silently widening
            # the search to all subawards.
            raise ValidationError("award_id cannot be empty")

        clone = self._clone()
        clone._award_id = normalized
        return clone