Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py: 72%

230 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 17:10 +0200

1import hashlib 

2from abc import ABCMeta, abstractmethod 

3from collections.abc import AsyncGenerator 

4from dataclasses import dataclass, field 

5from pathlib import Path 

6from tempfile import NamedTemporaryFile, TemporaryDirectory, _TemporaryFileWrapper 

7from typing import Any 

8from uuid import UUID, uuid4 

9 

10from anyio.streams.file import FileReadStream, FileWriteStream 

11from opendal import AsyncFile, AsyncOperator 

12from opendal import Metadata as OpendalMetadata 

13from pydantic import validate_call 

14 

15from .adapter import AsyncFileStream 

16from .common import Capability, HashAlgo, Metadata, Mode, PresignedRequest 

17from jumpstarter.driver import Driver, export 

18 

19 

20@dataclass(kw_only=True) 

21class Opendal(Driver): 

22 scheme: str 

23 kwargs: dict[str, str] 

24 

25 _operator: AsyncOperator = field(init=False) 

26 _fds: dict[UUID, AsyncFile] = field(init=False, default_factory=dict) 

27 _metadata: dict[UUID, OpendalMetadata] = field(init=False, default_factory=dict) 

28 

29 @classmethod 

30 def client(cls) -> str: 

31 return "jumpstarter_driver_opendal.client.OpendalClient" 

32 

33 def __post_init__(self): 

34 if hasattr(super(), "__post_init__"): 

35 super().__post_init__() 

36 

37 self._operator = AsyncOperator(self.scheme, **self.kwargs) 

38 

39 @export 

40 @validate_call(validate_return=True) 

41 async def open(self, /, path: str, mode: Mode) -> UUID: 

42 try: 

43 metadata = await self._operator.stat(path) 

44 except Exception: 

45 metadata = None 

46 file = await self._operator.open(path, mode) 

47 uuid = uuid4() 

48 

49 self._metadata[uuid] = metadata 

50 self._fds[uuid] = file 

51 

52 return uuid 

53 

54 @export 

55 @validate_call(validate_return=True) 

56 async def file_read(self, /, fd: UUID, dst: Any) -> None: 

57 async with self.resource(dst) as res: 

58 stream = AsyncFileStream(file=self._fds[fd], metadata=self._metadata[fd]) 

59 async for chunk in stream: 

60 await res.send(chunk) 

61 

62 @export 

63 @validate_call(validate_return=True) 

64 async def file_write(self, /, fd: UUID, src: Any) -> None: 

65 async with self.resource(src) as res: 

66 stream = AsyncFileStream(file=self._fds[fd], metadata=self._metadata[fd]) 

67 async for chunk in res: 

68 await stream.send(chunk) 

69 

70 @export 

71 @validate_call(validate_return=True) 

72 async def file_seek(self, /, fd: UUID, pos: int, whence: int = 0) -> int: 

73 return await self._fds[fd].seek(pos, whence) 

74 

75 @export 

76 @validate_call(validate_return=True) 

77 async def file_tell(self, /, fd: UUID) -> int: 

78 return await self._fds[fd].tell() 

79 

80 @export 

81 @validate_call(validate_return=True) 

82 async def file_close(self, /, fd: UUID) -> None: 

83 await self._fds[fd].close() 

84 

85 @export 

86 @validate_call(validate_return=True) 

87 async def file_closed(self, /, fd: UUID) -> bool: 

88 return await self._fds[fd].closed 

89 

90 @export 

91 @validate_call(validate_return=True) 

92 async def file_readable(self, /, fd: UUID) -> bool: 

93 return await self._fds[fd].readable() 

94 

95 @export 

96 @validate_call(validate_return=True) 

97 async def file_seekable(self, /, fd: UUID) -> bool: 

98 return await self._fds[fd].seekable() 

99 

100 @export 

101 @validate_call(validate_return=True) 

102 async def file_writable(self, /, fd: UUID) -> bool: 

103 return await self._fds[fd].writable() 

104 

105 @export 

106 @validate_call(validate_return=True) 

107 async def stat(self, /, path: str) -> Metadata: 

108 return Metadata.model_validate(await self._operator.stat(path), from_attributes=True) 

109 

110 @export 

111 @validate_call(validate_return=True) 

112 async def hash(self, /, path: str, algo: HashAlgo = "sha256") -> str: 

113 match algo: 

114 case "md5": 

115 m = hashlib.md5() 

116 case "sha256": 

117 m = hashlib.sha256() 

118 async with await self._operator.open(path, "rb") as f: 

119 while True: 

120 data = await f.read(size=65536) 

121 if len(data) == 0: 

122 break 

123 m.update(data) 

124 

125 return m.hexdigest() 

126 

127 @export 

128 @validate_call(validate_return=True) 

129 async def copy(self, /, source: str, target: str): 

130 await self._operator.copy(source, target) 

131 

132 @export 

133 @validate_call(validate_return=True) 

134 async def rename(self, /, source: str, target: str): 

135 await self._operator.rename(source, target) 

136 

137 @export 

138 @validate_call(validate_return=True) 

139 async def remove_all(self, /, path: str): 

140 await self._operator.remove_all(path) 

141 

142 @export 

143 @validate_call(validate_return=True) 

144 async def create_dir(self, /, path: str): 

145 await self._operator.create_dir(path) 

146 

147 @export 

148 @validate_call(validate_return=True) 

149 async def delete(self, /, path: str): 

150 await self._operator.delete(path) 

151 

152 @export 

153 @validate_call(validate_return=True) 

154 async def exists(self, /, path: str) -> bool: 

155 return await self._operator.exists(path) 

156 

157 @export 

158 async def list(self, /, path: str) -> AsyncGenerator[str, None]: 

159 async for entry in await self._operator.list(path): 

160 yield entry.path 

161 

162 @export 

163 async def scan(self, /, path: str) -> AsyncGenerator[str, None]: 

164 async for entry in await self._operator.scan(path): 

165 yield entry.path 

166 

167 @export 

168 @validate_call(validate_return=True) 

169 async def presign_stat(self, /, path: str, expire_second: int) -> PresignedRequest: 

170 return PresignedRequest.model_validate( 

171 await self._operator.presign_stat(path, expire_second), from_attributes=True 

172 ) 

173 

174 @export 

175 @validate_call(validate_return=True) 

176 async def presign_read(self, /, path: str, expire_second: int) -> PresignedRequest: 

177 return PresignedRequest.model_validate( 

178 await self._operator.presign_read(path, expire_second), from_attributes=True 

179 ) 

180 

181 @export 

182 @validate_call(validate_return=True) 

183 async def presign_write(self, /, path: str, expire_second: int) -> PresignedRequest: 

184 return PresignedRequest.model_validate( 

185 await self._operator.presign_write(path, expire_second), from_attributes=True 

186 ) 

187 

188 @export 

189 @validate_call(validate_return=True) 

190 async def capability(self, /) -> Capability: 

191 return Capability.model_validate(self._operator.capability(), from_attributes=True) 

192 

193 async def copy_exporter_file(self, /, source: Path, target: str): 

194 """Copy a file from the exporter to the target path. 

195 This function is intended to be used on the exporter side to copy files to the target path. 

196 """ 

197 async with await AsyncOperator("fs", root=source.parent.as_posix()).open(source.name, "rb") as src: 

198 async with await self._operator.open(target, "wb") as dst: 

199 while True: 

200 data = await src.read(size=65536) 

201 if len(data) == 0: 

202 break 

203 await dst.write(bs=data) 

204 

205 

206class FlasherInterface(metaclass=ABCMeta): 

207 @classmethod 

208 def client(cls) -> str: 

209 return "jumpstarter_driver_opendal.client.FlasherClient" 

210 

211 @abstractmethod 

212 def flash(self, source, partition: str | None = None): ... 

213 

214 @abstractmethod 

215 def dump(self, target, partition: str | None = None): ... 

216 

217 

@dataclass
class MockFlasher(FlasherInterface, Driver):
    """Test flasher that stores each partition as a file in a private temporary directory."""

    _tempdir: TemporaryDirectory = field(default_factory=TemporaryDirectory)

    def __path(self, partition: str | None = None) -> str:
        """Return the backing file path for ``partition`` (``"default"`` when None)."""
        name = "default" if partition is None else partition
        return str(Path(self._tempdir.name).joinpath(name))

    @export
    async def flash(self, source, partition: str | None = None):
        """Copy every chunk of the client resource ``source`` into the partition file."""
        destination = self.__path(partition)
        async with (
            await FileWriteStream.from_path(destination) as stream,
            self.resource(source) as res,
        ):
            async for chunk in res:
                await stream.send(chunk)

    @export
    async def dump(self, target, partition: str | None = None):
        """Stream the partition file back to the client resource ``target``."""
        origin = self.__path(partition)
        async with (
            await FileReadStream.from_path(origin) as stream,
            self.resource(target) as res,
        ):
            async for chunk in stream:
                await res.send(chunk)

240 

241 

class StorageMuxInterface(metaclass=ABCMeta):
    """Abstract storage-mux interface.

    Implementations provide async ``host``, ``dut`` and ``off`` operations plus
    ``write``/``read`` transfers that take a client resource handle.
    """

    @classmethod
    def client(cls) -> str:
        """Dotted path of the client class paired with this interface."""
        client_path = "jumpstarter_driver_opendal.client.StorageMuxClient"
        return client_path

    @abstractmethod
    async def host(self):
        """Switch the mux to the host side."""

    @abstractmethod
    async def dut(self):
        """Switch the mux to the DUT side."""

    @abstractmethod
    async def off(self):
        """Switch the mux off."""

    @abstractmethod
    async def write(self, src: str):
        """Write data from the client resource ``src``."""

    @abstractmethod
    async def read(self, dst: str):
        """Read data into the client resource ``dst``."""

261 

262 

class StorageMuxFlasherInterface(StorageMuxInterface):
    """Storage-mux interface that is paired with a flasher-style client."""

    @classmethod
    def client(cls) -> str:
        """Dotted path of the client class paired with this interface."""
        client_path = "jumpstarter_driver_opendal.client.StorageMuxFlasherClient"
        return client_path

267 

268 

@dataclass
class MockStorageMux(StorageMuxInterface, Driver):
    """Test storage mux backed by a single named temporary file.

    Switching operations are no-ops; ``write`` and ``read`` copy data into and
    out of the temporary file.
    """

    file: _TemporaryFileWrapper = field(default_factory=NamedTemporaryFile)

    @export
    async def host(self):
        """No-op: the mock has nothing to switch."""

    @export
    async def dut(self):
        """No-op: the mock has nothing to switch."""

    @export
    async def off(self):
        """No-op: the mock has nothing to switch."""

    @export
    async def write(self, src: str):
        """Copy every chunk of the client resource ``src`` into the backing file."""
        backing_path = self.file.name
        async with (
            await FileWriteStream.from_path(backing_path) as stream,
            self.resource(src) as res,
        ):
            async for chunk in res:
                await stream.send(chunk)

    @export
    async def read(self, dst: str):
        """Stream the backing file to the client resource ``dst``."""
        backing_path = self.file.name
        async with (
            await FileReadStream.from_path(backing_path) as stream,
            self.resource(dst) as res,
        ):
            async for chunk in stream:
                await res.send(chunk)

297 await res.send(chunk) 

298 

299 

@dataclass
class MockStorageMuxFlasher(StorageMuxFlasherInterface, MockStorageMux):
    """``MockStorageMux`` exposed through the storage-mux flasher client."""