Coverage for src/superset_io/api.py: 73%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-02 13:07 +0100

1from __future__ import annotations 

2 

3import io 

4import json 

5import logging 

6import shutil 

7import tempfile 

8import zipfile 

9from pathlib import Path 

10 

11import requests 

12 

13from superset_io.session import SupersetApiSession 

14from superset_io.utils import ( 

15 validate_assets_bundle_structure, 

16 zipfile_buffer_from_folder, 

17 zipfile_buffer_from_zipfile, 

18) 

19 

# Module logger. It is named after the package rather than ``__name__``,
# presumably so other modules in ``superset_io`` share the same logger.
log = logging.getLogger(name="superset_io")

21 

22 

class SuperSetApiClient:
    """Client for Superset's REST API, focused on exporting/importing assets.

    Every HTTP call goes through the injected ``session``.
    """

    # Session used for all requests. From its usage below it provides
    # ``get``/``post``, a ``headers`` mapping and a ``base_url`` string.
    session: SupersetApiSession

    def __init__(self, session: SupersetApiSession):
        self.session = session

    # ---------------------------- Public Methods ---------------------------- #

    def test_connection(self) -> bool:
        """Test that the server is reachable and accepts our credentials.

        Smoke test that:
          0) We can connect via /health
          1) Bearer auth is accepted (GET /api/v1/log/)
          2) CSRF token header is accepted (POST /api/v1/assets/import/)

        Network failures (connection refused, timeouts, ...) are reported the
        same way as HTTP error statuses: logged, then ``False`` is returned.

        Returns:
            True if all checks pass, False as soon as one fails.
        """
        # 0) Server reachable
        try:
            res = self.session.get("/health")
            res.raise_for_status()
        except requests.RequestException as e:
            log.error(f"❌ Server not reachable: {e}")
            # `Response.__bool__` is False for 4xx/5xx responses, so a plain
            # truthiness test would never log the body; compare to None.
            if e.response is not None:
                log.debug(f" {e.response.text}")
            return False
        log.info("✅ Server is reachable.")

        # 1) Access token works
        try:
            res = self.session.get("/api/v1/log/")
            res.raise_for_status()
        except requests.RequestException as e:
            msg = "❌ Could not access API that requires bearer token"
            if e.response is not None and e.response.status_code == 401:
                msg += (
                    ". Check credentials and JWT_ALGORITHM in your superset_config.py"
                )
            log.error(f"{msg}\n {e}")
            return False
        log.info("✅ Access token working, can download assets.")

        # 2) CSRF works: POST an invalid payload to an endpoint that requires
        # CSRF, so nothing gets created. A payload-validation error
        # (INVALID_PAYLOAD_FORMAT_ERROR) proves the request got past CSRF;
        # any other error is treated as a CSRF failure.
        if not self.session.headers.get("X-CSRFToken"):
            log.error(
                "❌ No X-CSRFToken set on session; cannot validate CSRF handling."
            )
            return False

        try:
            res = self.session.post(
                "/api/v1/assets/import/",
                json={},  # invalid; we only want to get past CSRF
                headers={"Referer": self.session.base_url.rstrip("/") + "/"},
            )
            res.raise_for_status()
        except requests.HTTPError as e:
            if self._superset_error_type(e.response) != "INVALID_PAYLOAD_FORMAT_ERROR":
                body = e.response.text if e.response is not None else ""
                log.error(f"❌ CSRF validation failed: {e} {body}")
                return False
        except requests.RequestException as e:
            log.error(f"❌ CSRF validation failed: {e}")
            return False
        log.info("✅ CSRF token working, can upload assets.")

        return True

    def download_assets(self, dst_path: Path):
        """Download all assets from the server and write them to ``dst_path``.

        If ``dst_path`` has a ``.zip`` suffix (case-insensitive), the exported
        archive is written there unchanged and missing parent directories are
        created. Otherwise ``dst_path`` is treated as a directory. It is created
        if needed, must be empty, and receives the contents of the archive's
        single top-level ``assets_export*`` folder.

        Raises:
            ValueError: if the destination directory is not empty, or the
                archive does not contain exactly one ``assets_export*`` folder.
        """
        if dst_path.suffix.lower() == ".zip":
            zip_bytes, zip_file = self._get_assets_zip()
            zip_file.close()  # only the raw bytes are needed here
            dst_path.parent.mkdir(parents=True, exist_ok=True)
            dst_path.write_bytes(zip_bytes)
            return

        # Check the destination before downloading, so we fail fast.
        dst_path.mkdir(parents=True, exist_ok=True)
        if any(dst_path.iterdir()):
            raise ValueError(f"Destination directory '{dst_path}' is not empty")

        _, zip_file = self._get_assets_zip()

        # Extract to a temp dir, locate the export folder and move its
        # contents to dst_path.
        with zip_file, tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)
            zip_file.extractall(tmp_path)

            src_folders = [
                f
                for f in tmp_path.iterdir()
                if f.is_dir() and f.name.startswith("assets_export")
            ]
            if len(src_folders) != 1:
                raise ValueError(
                    "Did not find a single `assets_export` folder in zip. "
                    "This should not happen."
                )

            for item in src_folders[0].iterdir():
                shutil.move(item, dst_path / item.name)

    def upload_assets(self, src_path: Path):
        """Upload an assets bundle from disk and import it on the server.

        ``src_path`` is either a ``.zip`` file or a directory that directly
        contains the bundle's metadata file (presumably ``metadata.yaml``,
        check ``validate_assets_bundle_structure``). A directory is turned into
        a zip buffer by ``zipfile_buffer_from_folder``.

        The bundle is validated locally before anything is sent, so malformed
        bundles raise without hitting the endpoint.
        """
        if src_path.suffix.lower() == ".zip":
            zipfile_buffer = zipfile_buffer_from_zipfile(src_path)
        else:
            zipfile_buffer = zipfile_buffer_from_folder(src_path)

        # raise for invalid zips, already before trying the endpoint
        validate_assets_bundle_structure(zipfile_buffer)

        self._post_assets(zipfile_buffer=zipfile_buffer)

    # ----------------------------- Internal Use ----------------------------- #

    def _api_url(self, path: str) -> str:
        """Return an absolute URL for ``path`` (which starts with ``/``).

        Trailing slashes on ``base_url`` are stripped to avoid ``//`` in URLs.
        """
        return f"{self.session.base_url.rstrip('/')}{path}"

    @staticmethod
    def _superset_error_type(response) -> str | None:
        """Return ``errors[0].error_type`` from a Superset error response.

        Returns None if there is no response or its body does not have that
        shape (not JSON, missing keys, empty list, ...).
        """
        if response is None:
            return None
        try:
            return response.json()["errors"][0]["error_type"]
        except (ValueError, KeyError, IndexError, TypeError):
            # ValueError covers JSON decode errors raised by `.json()`.
            return None

    def _get_dashboard(self, dashboard_id: int):
        """Get a single dashboard as decoded JSON.

        For actual downloading, better use the more modern assets API.

        Raises:
            requests.HTTPError: if the request fails.
        """
        response = self.session.get(self._api_url(f"/api/v1/dashboard/{dashboard_id}"))
        response.raise_for_status()
        return response.json()

    def _get_dashboards(self):
        """Get an overview of all dashboards as decoded JSON.

        For actual downloading, better use the more modern assets API.

        Raises:
            requests.HTTPError: if the request fails.
        """
        response = self.session.get(self._api_url("/api/v1/dashboard/"))
        response.raise_for_status()
        return response.json()

    def _get_assets_zip(self) -> tuple[bytes, zipfile.ZipFile]:
        """Export all assets from the server as a zip archive.

        Returns:
            The raw archive bytes and a read-only ``ZipFile`` over them. The
            caller owns the ``ZipFile`` and should close it.

        Raises:
            requests.HTTPError: if the export request fails.
        """
        response = self.session.get(self._api_url("/api/v1/assets/export"))
        response.raise_for_status()

        # The whole archive is held in memory; if exports get big we might
        # need to consider streaming.
        zip_bytes = response.content
        return zip_bytes, zipfile.ZipFile(io.BytesIO(zip_bytes), "r")

    def _post_assets(
        self,
        zipfile_buffer: io.BytesIO | Path | bytes,
        overwrite: bool = False,
        passwords: dict[str, str] | None = None,
        ssh_tunnel_passwords: dict[str, str] | None = None,
        ssh_tunnel_private_key_passwords: dict[str, str] | None = None,
        ssh_tunnel_private_keys: dict[str, str] | None = None,
    ):
        """POST a bundle to ``/api/v1/assets/import/`` as multipart form data.

        Args:
            zipfile_buffer: The bundle as an in-memory buffer, a path to a zip
                file on disk, or raw zip bytes.
            overwrite: If True, send the form field ``overwrite=true``.
            passwords, ssh_tunnel_passwords, ssh_tunnel_private_key_passwords,
            ssh_tunnel_private_keys: Each sent JSON-encoded (``{}`` when None)
                in the form field of the same name.

        Returns:
            The server response.

        Raises:
            ValueError: if ``zipfile_buffer`` has an unsupported type.
            requests.HTTPError: if the server answers with an error status.
        """
        # Get the zip content
        if isinstance(zipfile_buffer, io.BytesIO):
            zip_content = zipfile_buffer.getvalue()
        elif isinstance(zipfile_buffer, Path):
            zip_content = zipfile_buffer.read_bytes()
        elif isinstance(zipfile_buffer, bytes):
            zip_content = zipfile_buffer
        else:
            raise ValueError("zipfile must be io.BytesIO, Path, or bytes")

        # Only the file goes into `files=`. The filename appears not to matter.
        files = {
            "bundle": (
                "name_does_not_matter.zip",
                zip_content,
                "application/zip",
            )
        }

        # All non-file fields go into `data=`.
        data = {
            "passwords": json.dumps(passwords or {}),
            "ssh_tunnel_passwords": json.dumps(ssh_tunnel_passwords or {}),
            "ssh_tunnel_private_keys": json.dumps(ssh_tunnel_private_keys or {}),
            "ssh_tunnel_private_key_passwords": json.dumps(
                ssh_tunnel_private_key_passwords or {}
            ),
        }
        if overwrite:
            data["overwrite"] = "true"

        # requests must generate the multipart Content-Type itself so that it
        # carries the boundary. Merely dropping the key from a copy is not
        # enough: for a requests.Session, session-level headers are merged
        # back into each request, and only a method-level value of None removes
        # one. Header names are matched case-insensitively.
        headers = {
            key: value
            for key, value in self.session.headers.items()
            if key.lower() != "content-type"
        }
        headers["Content-Type"] = None
        headers["Referer"] = self.session.base_url.rstrip("/") + "/"

        res = self.session.post(
            "/api/v1/assets/import/",
            files=files,
            data=data,
            headers=headers,
        )

        res.raise_for_status()

        return res