Coverage for dj/service_clients.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1"""Clients for various configurable services.""" 

2from typing import TYPE_CHECKING, List, Optional 

3from urllib.parse import urljoin 

4from uuid import UUID 

5 

6import requests 

7from requests.adapters import HTTPAdapter 

8from urllib3 import Retry 

9 

10from dj.errors import DJQueryServiceClientException 

11from dj.models.column import Column 

12from dj.models.query import QueryCreate, QueryWithResults 

13from dj.sql.parsing.types import ColumnType 

14 

15if TYPE_CHECKING: 

16 from dj.models.engine import Engine 

17 

18 

class RequestsSessionWithEndpoint(requests.Session):
    """
    A ``requests.Session`` bound to a base endpoint.

    Every URL handed to the session is resolved against ``endpoint`` with
    ``urllib.parse.urljoin`` before the request is prepared or sent.
    """

    def __init__(self, endpoint: str = None, retry_strategy: Retry = None):
        super().__init__()
        self.endpoint = endpoint
        # Mount a separate adapter per scheme, each carrying the same retry
        # configuration.
        for scheme in ("http://", "https://"):
            self.mount(scheme, HTTPAdapter(max_retries=retry_strategy))

    def construct_url(self, url):
        """
        Resolve ``url`` against the session's endpoint and return the result.
        """
        full_url = urljoin(self.endpoint, url)
        return full_url

    def prepare_request(self, request, *args, **kwargs):
        """
        Rewrite the request's URL to the full URL, then prepare it as usual.
        """
        request.url = self.construct_url(request.url)
        return super().prepare_request(request, *args, **kwargs)

    def request(self, method, url, *args, **kwargs):
        """
        Send a request to ``url`` after resolving it against the endpoint.
        """
        full_url = self.construct_url(url)
        return super().request(method, full_url, *args, **kwargs)

54 

55 

class QueryServiceClient:  # pylint: disable=too-few-public-methods
    """
    Client for the query service.

    All calls go through a ``RequestsSessionWithEndpoint`` rooted at ``uri``
    and configured with a retry strategy. Non-OK responses are reported as
    ``DJQueryServiceClientException``.
    """

    def __init__(self, uri: str, retries: int = 2):
        """
        Build the client.

        Args:
            uri: Base endpoint of the query service.
            retries: Total number of retries handed to the retry strategy.
        """
        self.uri = uri
        retry_strategy = Retry(
            total=retries,
            backoff_factor=1.5,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "POST", "PUT", "PATCH"],
        )
        self.requests_session = RequestsSessionWithEndpoint(
            endpoint=self.uri,
            retry_strategy=retry_strategy,
        )

    @staticmethod
    def _error_detail(response) -> str:
        """
        Extract a human-readable error detail from a failed response.

        Prefers the ``message`` field of a JSON object body; falls back to the
        raw response text when the body is not JSON or has no such field.
        """
        try:
            payload = response.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON (e.g. proxy pages).
            return response.text
        if isinstance(payload, dict) and "message" in payload:
            return payload["message"]
        return response.text

    def get_columns_for_table(
        self,
        catalog: str,
        schema: str,
        table: str,
        engine: Optional["Engine"] = None,
    ) -> List[Column]:
        """
        Retrieves columns for a table.

        Args:
            catalog: Catalog name, first part of the table path.
            schema: Schema name, second part of the table path.
            table: Table name, third part of the table path.
            engine: Optional engine; when given, its name and version are
                sent as query parameters.

        Returns:
            One ``Column`` per entry in the response's ``columns`` list.

        Raises:
            DJQueryServiceClientException: If the service returns a non-OK
                response.
        """
        params = (
            {
                "engine": engine.name,
                "engine_version": engine.version,
            }
            if engine
            else {}
        )
        response = self.requests_session.get(
            f"/table/{catalog}.{schema}.{table}/columns/",
            params=params,
        )
        # Fail with a service-level error rather than a KeyError/decode error
        # when the service did not return a column listing.
        if not response.ok:
            raise DJQueryServiceClientException(
                message=(
                    "Error response from query service: "
                    f"{self._error_detail(response)}"
                ),
            )
        table_columns = response.json()["columns"]
        return [
            Column(name=column["name"], type=ColumnType(column["type"]))
            for column in table_columns
        ]

    def submit_query(
        self,
        query_create: QueryCreate,
    ) -> QueryWithResults:
        """
        Submit a query to the query service

        Args:
            query_create: The query to submit; sent as the JSON request body.

        Returns:
            The query, with its ``id`` converted to a ``UUID``.

        Raises:
            DJQueryServiceClientException: If the service returns a non-OK
                response.
        """
        response = self.requests_session.post(
            "/queries/",
            json=query_create.dict(),
        )
        # Check the status before trusting the body: an error response may
        # not be JSON or may lack a "message" field.
        if not response.ok:
            raise DJQueryServiceClientException(
                message=(
                    "Error response from query service: "
                    f"{self._error_detail(response)}"
                ),
            )
        query_info = response.json()
        query_info["id"] = UUID(query_info["id"])
        return QueryWithResults(**query_info)

    def get_query(
        self,
        query_id: str,
    ) -> QueryWithResults:
        """
        Get a previously submitted query

        Args:
            query_id: Identifier of the query to fetch.

        Returns:
            The query built from the service's JSON response.

        Raises:
            DJQueryServiceClientException: If the service returns a non-OK
                response.
        """
        response = self.requests_session.get(f"/queries/{query_id}/")
        if not response.ok:
            raise DJQueryServiceClientException(
                message=f"Error response from query service: {response.text}",
            )
        query_info = response.json()
        return QueryWithResults(**query_info)