Coverage for dj/service_clients.py: 100%
47 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""Clients for various configurable services."""
2from typing import TYPE_CHECKING, List, Optional
3from urllib.parse import urljoin
4from uuid import UUID
6import requests
7from requests.adapters import HTTPAdapter
8from urllib3 import Retry
10from dj.errors import DJQueryServiceClientException
11from dj.models.column import Column
12from dj.models.query import QueryCreate, QueryWithResults
13from dj.sql.parsing.types import ColumnType
15if TYPE_CHECKING:
16 from dj.models.engine import Engine
class RequestsSessionWithEndpoint(requests.Session):
    """
    Creates a requests session that comes with an endpoint that all
    subsequent requests will use as a prefix.

    Relative URLs given to the session (e.g. ``"/queries/"``) are appended to
    ``endpoint`` while keeping any path the endpoint already has, so
    ``http://host/api`` + ``/queries/`` becomes ``http://host/api/queries/``.
    Absolute URLs are passed through unchanged.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        retry_strategy: Optional[Retry] = None,
    ):
        """
        :param endpoint: base URL that relative request URLs are resolved against.
            When falsy, URLs are used exactly as given.
        :param retry_strategy: retry policy mounted for both http:// and https://.
        """
        super().__init__()
        self.endpoint = endpoint
        # NOTE(review): retry_strategy=None is handed straight to HTTPAdapter,
        # which passes it to urllib3's Retry.from_int; that substitutes
        # Retry.DEFAULT rather than "no retries" -- confirm this is intended.
        self.mount("http://", HTTPAdapter(max_retries=retry_strategy))
        self.mount("https://", HTTPAdapter(max_retries=retry_strategy))

    def request(self, method, url, *args, **kwargs):
        """
        Make the request with the full URL.
        """
        url = self.construct_url(url)
        # requests.Session.request routes through prepare_request below, which
        # calls construct_url again; that second call is a no-op because the
        # URL is already absolute at that point.
        return super().request(method, url, *args, **kwargs)

    def prepare_request(self, request, *args, **kwargs):
        """
        Prepare the request with the full URL.

        Covers requests built by hand and sent via ``Session.send`` without
        going through ``request``.
        """
        request.url = self.construct_url(request.url)
        return super().prepare_request(
            request,
            *args,
            **kwargs,
        )

    def construct_url(self, url):
        """
        Construct full URL based off the endpoint.

        A leading "/" on ``url`` means "relative to the endpoint", not
        "relative to the host root", so a path component in ``endpoint`` is
        preserved. Absolute and protocol-relative (``//host/...``) URLs are
        resolved by ``urljoin`` as usual, which leaves absolute URLs unchanged.

        :param url: request URL or path.
        :returns: the resolved URL.
        """
        if not self.endpoint or not url:
            # Same short-circuits urljoin itself applies to empty inputs.
            return urljoin(self.endpoint, url)
        # urljoin drops the last segment of a base that lacks a trailing slash,
        # and a url starting with "/" discards the base's whole path. Normalize
        # both sides so the endpoint really behaves as a prefix.
        base = self.endpoint if self.endpoint.endswith("/") else self.endpoint + "/"
        if url.startswith("/") and not url.startswith("//"):
            url = url[1:]
        return urljoin(base, url)
class QueryServiceClient:
    """
    Client for the query service.

    Every call goes through a ``RequestsSessionWithEndpoint`` rooted at ``uri``.
    A response whose status is 400 or above (``not response.ok``) is raised as
    ``DJQueryServiceClientException``.
    """

    def __init__(self, uri: str, retries: int = 2):
        """
        :param uri: base URL of the query service; request paths are resolved
            against it by ``RequestsSessionWithEndpoint``.
        :param retries: total retry budget for urllib3's ``Retry``.
        """
        self.uri = uri
        retry_strategy = Retry(
            total=retries,
            backoff_factor=1.5,
            # Retry on rate limiting and transient server/gateway errors.
            # NOTE(review): once retries for these statuses are exhausted,
            # requests may raise its own RetryError instead of returning the
            # response, bypassing the response.ok checks below -- confirm.
            status_forcelist=[429, 500, 502, 503, 504],
            # NOTE(review): POST is not idempotent; a submit_query whose first
            # attempt reached the service but got a 5xx back may be submitted
            # more than once -- confirm this is acceptable.
            allowed_methods=["GET", "POST", "PUT", "PATCH"],
        )
        self.requests_session = RequestsSessionWithEndpoint(
            endpoint=self.uri,
            retry_strategy=retry_strategy,
        )

    def get_columns_for_table(
        self,
        catalog: str,
        schema: str,
        table: str,
        engine: Optional["Engine"] = None,
    ) -> List[Column]:
        """
        Retrieves columns for a table.

        :param catalog: catalog containing the table.
        :param schema: schema containing the table.
        :param table: table name.
        :param engine: when given, its ``name`` and ``version`` are sent as the
            ``engine`` / ``engine_version`` query parameters.
        :returns: one ``Column`` per entry in the response's ``columns`` list.
        :raises DJQueryServiceClientException: if the service responds with an
            error status.
        """
        params = (
            {"engine": engine.name, "engine_version": engine.version}
            if engine
            else {}
        )
        response = self.requests_session.get(
            f"/table/{catalog}.{schema}.{table}/columns/",
            params=params,
        )
        # Without this check an error body would surface as a KeyError or
        # JSON decode error on the lookup below.
        if not response.ok:
            raise DJQueryServiceClientException(
                message=f"Error response from query service: {response.text}",
            )
        table_columns = response.json()["columns"]
        return [
            Column(name=column["name"], type=ColumnType(column["type"]))
            for column in table_columns
        ]

    def submit_query(
        self,
        query_create: QueryCreate,
    ) -> QueryWithResults:
        """
        Submit a query to the query service.

        :param query_create: query to submit; serialized with ``.dict()`` as the
            JSON request body.
        :returns: the service's description of the submitted query, with ``id``
            converted to a ``UUID``.
        :raises DJQueryServiceClientException: if the service responds with an
            error status.
        """
        response = self.requests_session.post(
            "/queries/",
            json=query_create.dict(),
        )
        # Check the status before parsing: error bodies are not guaranteed to
        # be JSON or to carry a "message" key.
        if not response.ok:
            detail = self._error_detail(response)
            raise DJQueryServiceClientException(
                message=f"Error response from query service: {detail}",
            )
        query_info = response.json()
        query_info["id"] = UUID(query_info["id"])
        return QueryWithResults(**query_info)

    def get_query(
        self,
        query_id: str,
    ) -> QueryWithResults:
        """
        Get a previously submitted query.

        :param query_id: id of a query previously submitted to the service.
        :returns: the query as reported by the service.
        :raises DJQueryServiceClientException: if the service responds with an
            error status.
        """
        response = self.requests_session.get(f"/queries/{query_id}/")
        if not response.ok:
            raise DJQueryServiceClientException(
                message=f"Error response from query service: {response.text}",
            )
        query_info = response.json()
        return QueryWithResults(**query_info)

    @staticmethod
    def _error_detail(response: requests.Response):
        """
        Pick the most useful error text from a failed response.

        Uses the ``message`` field of a JSON object body when present, and falls
        back to the raw body text otherwise.
        """
        try:
            payload = response.json()
        except ValueError:
            # requests raises a ValueError subclass when the body is not JSON
            # (e.g. an HTML error page from a proxy).
            return response.text
        if isinstance(payload, dict) and "message" in payload:
            return payload["message"]
        return response.text