Coverage for src / superset_io / api.py: 73%
112 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-02 13:07 +0100
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-02 13:07 +0100
1from __future__ import annotations
3import io
4import json
5import logging
6import shutil
7import tempfile
8import zipfile
9from pathlib import Path
11import requests
13from superset_io.session import SupersetApiSession
14from superset_io.utils import (
15 validate_assets_bundle_structure,
16 zipfile_buffer_from_folder,
17 zipfile_buffer_from_zipfile,
18)
# Logger used by the API client below; registered under the "superset_io" name.
log = logging.getLogger(name="superset_io")
class SuperSetApiClient:
    """Client for Superset's REST API on top of a ``SupersetApiSession``.

    Offers a connectivity smoke test (``test_connection``) and export/import
    of the complete assets bundle (``download_assets`` / ``upload_assets``).
    """

    session: SupersetApiSession

    def __init__(self, session: SupersetApiSession):
        # NOTE(review): the session is called both with relative paths
        # ("/health", "/api/v1/...") and with absolute URLs built from
        # ``session.base_url``; presumably it resolves relative paths against
        # base_url — confirm in superset_io.session.
        self.session = session

    # ---------------------------- Public Methods ---------------------------- #

    def test_connection(self) -> bool:
        """Test that the Superset server is reachable and credentials work.

        Smoke test that:
          0) We can connect via /health
          1) Bearer auth is accepted (GET /api/v1/log/)
          2) CSRF token header is accepted (POST /api/v1/assets/import/)

        Checks run in order and stop at the first failure.

        Returns:
            True if all checks pass, False otherwise. Network-level failures
            (connection refused, timeouts, ...) are logged and reported as
            False rather than raised.
        """
        return (
            self._check_health()
            and self._check_bearer_auth()
            and self._check_csrf()
        )

    def download_assets(self, dst_path: Path):
        """Download all assets from the server and write them to ``dst_path``.

        If ``dst_path`` has a ``.zip`` suffix (case-insensitive), the exported
        bundle is written verbatim as a zip file; missing parent directories
        are created. Otherwise ``dst_path`` is treated as a directory: it is
        created if missing, must be empty, and receives the contents of the
        bundle's single top-level ``assets_export*`` folder.

        Raises:
            ValueError: if the destination directory is not empty, or the
                bundle does not contain exactly one ``assets_export*`` folder.
            requests.HTTPError: if the export request fails.
        """
        as_zip = dst_path.suffix.lower() == ".zip"

        # Validate the destination before downloading anything.
        if not as_zip:
            dst_path.mkdir(parents=True, exist_ok=True)
            if any(dst_path.iterdir()):
                raise ValueError(f"Destination directory '{dst_path}' is not empty")

        zip_bytes, zip_file = self._get_assets_zip()
        with zip_file:  # release the archive handle on every path
            if as_zip:
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                dst_path.write_bytes(zip_bytes)
            else:
                self._extract_assets_export(zip_file, dst_path)

    def upload_assets(self, src_path: Path, *, overwrite: bool = False):
        """Upload and restore assets from disk.

        Args:
            src_path: a ``.zip`` bundle, or a directory that directly contains
                the bundle contents (including its ``metadata.yaml``).
            overwrite: forwarded to the import endpoint; when True, existing
                assets on the server are overwritten.

        Raises:
            Whatever ``validate_assets_bundle_structure`` raises for an invalid
            bundle (checked before contacting the server), and
            ``requests.HTTPError`` if the import request fails.
        """
        if src_path.suffix.lower() == ".zip":
            zipfile_buffer = zipfile_buffer_from_zipfile(src_path)
        else:
            zipfile_buffer = zipfile_buffer_from_folder(src_path)

        # raise for invalid zips, already before trying the endpoint
        validate_assets_bundle_structure(zipfile_buffer)

        self._post_assets(zipfile_buffer=zipfile_buffer, overwrite=overwrite)

    # ----------------------------- Internal Use ----------------------------- #

    def _url(self, path: str) -> str:
        """Join ``path`` (which starts with "/") onto the session's base URL."""
        return self.session.base_url.rstrip("/") + path

    def _referer(self) -> str:
        """Return the base URL with exactly one trailing slash, used as Referer."""
        return self._url("/")

    def _check_health(self) -> bool:
        """Step 0: return True if GET /health succeeds, logging the outcome."""
        try:
            res = self.session.get("/health")
            res.raise_for_status()
        except requests.HTTPError as e:
            log.error(f"❌ Server not reachable: {e}")
            # Response.__bool__ is False for 4xx/5xx, so compare against None.
            if e.response is not None:
                log.debug(f" {e.response.text}")
            return False
        except requests.RequestException as e:
            log.error(f"❌ Server not reachable: {e}")
            return False
        log.info("✅ Server is reachable.")
        return True

    def _check_bearer_auth(self) -> bool:
        """Step 1: return True if an endpoint requiring the bearer token works."""
        msg = "❌ Could not access API that requires bearer token"
        try:
            res = self.session.get("/api/v1/log/")
            res.raise_for_status()
        except requests.HTTPError as e:
            if e.response is not None and e.response.status_code == 401:
                msg += (
                    ". Check credentials and JWT_ALGORITHM in your superset_config.py"
                )
            log.error(f"{msg}\n {e}")
            return False
        except requests.RequestException as e:
            log.error(f"{msg}\n {e}")
            return False
        log.info("✅ Access token working, can download assets.")
        return True

    def _check_csrf(self) -> bool:
        """Step 2: return True if a CSRF-protected POST gets past CSRF checks.

        Pick a POST endpoint that requires CSRF and send an invalid payload so
        nothing is created. Expectation:
          - If CSRF is missing/invalid => typically 400/403 (CSRF-related)
          - If CSRF is accepted => typically 400/422 (payload validation) or 403
        A 2xx response or an ``INVALID_PAYLOAD_FORMAT_ERROR`` error counts as
        success; anything else is treated as a failure.
        """
        if not self.session.headers.get("X-CSRFToken"):
            log.error(
                "❌ No X-CSRFToken set on session; cannot validate CSRF handling."
            )
            return False

        try:
            res = self.session.post(
                "/api/v1/assets/import/",
                json={},  # invalid; we only want to get past CSRF
                headers={"Referer": self._referer()},
            )
        except requests.RequestException as e:
            log.error(f"❌ CSRF validation failed: {e}")
            return False

        try:
            res.raise_for_status()
        except requests.HTTPError as e:
            if self._error_type(res) != "INVALID_PAYLOAD_FORMAT_ERROR":
                log.error(f"❌ CSRF validation failed: {e} {res.text}")
                return False

        log.info("✅ CSRF token working, can upload assets.")
        return True

    @staticmethod
    def _error_type(res) -> str | None:
        """Return ``errors[0]["error_type"]`` from a JSON error body, else None."""
        try:
            return res.json()["errors"][0]["error_type"]
        except (ValueError, KeyError, IndexError, TypeError):
            # ValueError covers non-JSON bodies; the others cover other shapes.
            return None

    @staticmethod
    def _extract_assets_export(zip_file: zipfile.ZipFile, dst_path: Path) -> None:
        """Move the contents of the zip's single ``assets_export*`` folder to dst_path.

        Raises:
            ValueError: if the archive does not contain exactly one top-level
                ``assets_export*`` directory.
        """
        # Extract to temp dir, get the assets and move to dst_path
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)
            zip_file.extractall(tmp_path)

            src_folders = [
                f
                for f in tmp_path.iterdir()
                if f.is_dir() and f.name.startswith("assets_export")
            ]
            if len(src_folders) != 1:
                raise ValueError(
                    "Did not find a single `assets_export` folder in zip. "
                    "This should not happen."
                )

            for item in src_folders[0].iterdir():
                shutil.move(item, dst_path / item.name)

    def _get_dashboard(self, dashboard_id: int):
        """Get a single dashboard as parsed JSON.

        For actual downloading, better use more modern assets api.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        response = self.session.get(self._url(f"/api/v1/dashboard/{dashboard_id}"))
        response.raise_for_status()
        return response.json()

    def _get_dashboards(self):
        """Get an overview of all dashboards as parsed JSON.

        For actual downloading, better use more modern assets api.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        response = self.session.get(self._url("/api/v1/dashboard/"))
        response.raise_for_status()
        return response.json()

    def _get_assets_zip(self) -> tuple[bytes, zipfile.ZipFile]:
        """Get all assets from the server as zip.

        Returns:
            ``(raw_zip_bytes, open_zipfile)``; the caller should close the
            ``ZipFile`` when done.

        Raises:
            requests.HTTPError: on a non-2xx response.
            zipfile.BadZipFile: if the response body is not a zip archive.
        """
        response = self.session.get(self._url("/api/v1/assets/export"))
        response.raise_for_status()

        # if the zip gets big we might need to consider streaming
        zip_bytes = response.content
        return zip_bytes, zipfile.ZipFile(io.BytesIO(zip_bytes), "r")

    def _post_assets(
        self,
        zipfile_buffer: io.BytesIO | Path | bytes,
        overwrite: bool = False,
        passwords: dict[str, str] | None = None,
        ssh_tunnel_passwords: dict[str, str] | None = None,
        ssh_tunnel_private_key_passwords: dict[str, str] | None = None,
        ssh_tunnel_private_keys: dict[str, str] | None = None,
    ):
        """POST an assets bundle to ``/api/v1/assets/import/``.

        Args:
            zipfile_buffer: the zip bundle as an in-memory buffer, a path to a
                zip file, or raw bytes.
            overwrite: when True, sends ``overwrite=true``.
            passwords, ssh_tunnel_*: optional mappings, each sent as a JSON
                string form field (``{}`` when not given).

        Returns:
            The response object from the session.

        Raises:
            ValueError: if ``zipfile_buffer`` has an unsupported type.
            requests.HTTPError: on a non-2xx response.
        """
        # Get the zip content
        if isinstance(zipfile_buffer, io.BytesIO):
            zip_content = zipfile_buffer.getvalue()
        elif isinstance(zipfile_buffer, Path):
            zip_content = zipfile_buffer.read_bytes()
        elif isinstance(zipfile_buffer, bytes):
            zip_content = zipfile_buffer
        else:
            raise ValueError("zipfile must be io.BytesIO, Path, or bytes")

        # only the file goes into `files=`
        files = {
            "bundle": (
                "name_does_not_matter.zip",
                zip_content,
                "application/zip",
            )
        }

        # all non-file fields go into `data=`
        data = {
            "passwords": json.dumps(passwords or {}),
            "ssh_tunnel_passwords": json.dumps(ssh_tunnel_passwords or {}),
            "ssh_tunnel_private_keys": json.dumps(ssh_tunnel_private_keys or {}),
            "ssh_tunnel_private_key_passwords": json.dumps(
                ssh_tunnel_private_key_passwords or {}
            ),
        }
        if overwrite:
            data["overwrite"] = "true"

        # requests must generate the multipart Content-Type itself so that it
        # carries the boundary parameter. Omitting the key is not enough:
        # requests.Session merges per-request headers over session headers, so
        # a session-level Content-Type would survive. A None value explicitly
        # removes it. (Assumes SupersetApiSession follows requests.Session
        # merge semantics — TODO confirm.)
        headers = {
            key: value
            for key, value in self.session.headers.items()
            if key.lower() != "content-type"
        }
        headers["Content-Type"] = None
        headers["Referer"] = self._referer()

        res = self.session.post(
            "/api/v1/assets/import/",
            files=files,
            data=data,
            headers=headers,
        )

        res.raise_for_status()

        return res