Coverage for dj/utils.py: 100%

64 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Utility functions. 

3""" 

4import logging 

5import os 

6import re 

7from enum import Enum 

8from functools import lru_cache 

9 

10# pylint: disable=line-too-long 

11from typing import Iterator, List, Optional 

12 

13from dotenv import load_dotenv 

14from rich.logging import RichHandler 

15from sqlalchemy.engine import Engine 

16from sqlmodel import Session, create_engine 

17from yarl import URL 

18 

19from dj.config import Settings 

20from dj.errors import DJException 

21from dj.service_clients import QueryServiceClient 

22 

23 

def setup_logging(loglevel: str) -> None:
    """
    Configure the root logger to write through a Rich handler.

    ``loglevel`` is upper-cased and looked up as an attribute of the
    ``logging`` module (e.g. ``"debug"`` -> ``logging.DEBUG``).

    :raises ValueError: if that attribute does not exist or is not an int.
    """
    numeric_level = getattr(logging, loglevel.upper(), None)
    if isinstance(numeric_level, int):
        rich_handler = RichHandler(rich_tracebacks=True)
        logging.basicConfig(
            level=numeric_level,
            format="[%(asctime)s] %(levelname)s: %(name)s: %(message)s",
            datefmt="[%X]",
            handlers=[rich_handler],
            # ``force`` drops handlers already on the root logger, so a second
            # call reconfigures logging instead of silently doing nothing.
            force=True,
        )
        return
    raise ValueError(f"Invalid log level: {loglevel}")

40 

41 

@lru_cache
def get_settings() -> Settings:
    """
    Build the application ``Settings`` once and memoize the result.

    The dotenv path comes from the ``DOTENV_FILE`` environment variable
    (``.env`` when unset). It is passed to ``load_dotenv`` before
    ``Settings`` is constructed. Because of ``lru_cache``, later calls return
    the same object without repeating that work.
    """
    env_path = os.environ.get("DOTENV_FILE", ".env")
    load_dotenv(env_path)
    settings = Settings()
    return settings

50 

51 

def get_engine() -> Engine:
    """
    Create a SQLAlchemy engine for the metadata database.

    The connection URL is taken from the ``index`` field of the settings
    returned by ``get_settings()``. A new engine is built on every call.
    """
    return create_engine(get_settings().index)

60 

61 

def get_session() -> Iterator[Session]:
    """
    Yield a database session intended to live for a single request.

    The session is opened with ``autoflush=False`` on an engine from
    ``get_engine()``. It is used as a context manager, so it is closed once
    the generator finishes.
    """
    metadata_engine = get_engine()
    with Session(metadata_engine, autoflush=False) as db_session:  # pragma: no cover
        yield db_session

70 

71 

def get_query_service_client() -> Optional[QueryServiceClient]:
    """
    Build a query service client from the configured settings.

    Returns ``None`` when ``settings.query_service`` is falsy. Otherwise it
    returns a ``QueryServiceClient`` constructed with that value.
    """
    query_service = get_settings().query_service
    if query_service:
        return QueryServiceClient(query_service)
    return None  # pragma: no cover

80 

81 

def get_issue_url(
    baseurl: URL = URL("https://github.com/DataJunction/dj/issues/new"),
    title: Optional[str] = None,
    body: Optional[str] = None,
    labels: Optional[List[str]] = None,
) -> URL:
    """
    Build a GitHub "new issue" URL with pre-filled query parameters.

    ``title`` and ``body`` are included whenever they are not ``None``.
    ``labels`` is included only when non-empty. Each label is stripped of
    surrounding whitespace and the labels are joined with commas.

    https://docs.github.com/en/issues/tracking-your-work-with-issues/creating-an-issue#creating-an-issue-from-a-url-query
    """
    params = {}
    if title is not None:
        params["title"] = title
    if body is not None:
        params["body"] = body
    if labels:
        params["labels"] = ",".join(label.strip() for label in labels)

    # ``URL % mapping`` returns a new URL with the mapping merged into its query.
    return baseurl % params

101 

102 

class VersionUpgrade(str, Enum):
    """
    The kind of version upgrade: a major bump or a minor bump.
    """

    # Members mix in ``str``, so each one compares equal to its plain value.
    MAJOR = "major"  # bump the major part
    MINOR = "minor"  # bump the minor part

110 

111 

class Version:
    """
    Represents a basic semantic version with only major & minor parts.

    Used for tracking node versioning. Rendered as ``v<major>.<minor>``.
    Two versions compare equal (and hash equal) when both parts match.
    """

    # Compiled once rather than on every ``parse`` call. The pattern is only
    # anchored at the start, so trailing text (e.g. the ``.3`` in ``v1.2.3``)
    # is ignored.
    _VERSION_REGEX = re.compile(r"^v(?P<major>[0-9]+)\.(?P<minor>[0-9]+)")

    def __init__(self, major: int, minor: int):
        self.major = major
        self.minor = minor

    def __str__(self) -> str:
        return f"v{self.major}.{self.minor}"

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.major!r}, {self.minor!r})"

    def __eq__(self, other: object) -> bool:
        # Without this, equality falls back to identity, so
        # ``Version.parse("v1.0") == Version(1, 0)`` would be False.
        if not isinstance(other, Version):
            return NotImplemented
        return (self.major, self.minor) == (other.major, other.minor)

    def __hash__(self) -> int:
        # Kept consistent with ``__eq__`` so equal versions hash equally.
        return hash((self.major, self.minor))

    @classmethod
    def parse(cls, version_string: str) -> "Version":
        """
        Parse a version string such as ``"v1.2"``.

        Only the leading ``v<major>.<minor>`` is examined; anything after it
        is ignored. The result is an instance of ``cls``, so subclasses get
        instances of themselves.

        :param version_string: text expected to start with ``v<major>.<minor>``.
        :raises DJException: with HTTP status 500 if the string does not
            start with ``v<major>.<minor>``.
        """
        matcher = cls._VERSION_REGEX.match(version_string)
        if not matcher:
            raise DJException(
                http_status_code=500,
                message=f"Unparseable version {version_string}!",
            )
        return cls(int(matcher.group("major")), int(matcher.group("minor")))

    def next_minor_version(self) -> "Version":
        """
        Return a new version with the minor part incremented by one.
        """
        return type(self)(self.major, self.minor + 1)

    def next_major_version(self) -> "Version":
        """
        Return a new version with the major part incremented and minor reset to 0.
        """
        return type(self)(self.major + 1, 0)

151 

152 

def get_namespace_from_name(name: str) -> str:
    """
    Return the namespace part of a qualified node name.

    The namespace is everything before the last ``.``; for example,
    ``"a.b.c"`` yields ``"a.b"``. A name containing no ``.`` is placed in
    the ``"default"`` namespace.
    """
    namespace, separator, _ = name.rpartition(".")
    return namespace if separator else "default"