Coverage for jumpstarter_driver_flashers/driver.py: 95%

139 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-30 18:45 +0200

1from dataclasses import dataclass, field 

2from pathlib import Path 

3 

4import anyio.to_thread 

5from jumpstarter_driver_http.driver import HttpServer 

6from jumpstarter_driver_tftp.driver import Tftp 

7from jumpstarter_driver_uboot.driver import UbootConsole 

8from oras.provider import Registry 

9 

10from .bundle import FlasherBundleManifestV1Alpha1 

11from jumpstarter.common.exceptions import ConfigurationError 

12from jumpstarter.driver import Driver, export 

13 

14 

15@dataclass(kw_only=True) 

16class BaseFlasher(Driver): 

17 """driver for Jumpstarter""" 

18 

19 flasher_bundle: str = field(default="quay.io/jumpstarter-dev/jumpstarter-flasher-test:latest") 

20 variant: None | str = field(default=None) 

21 manifest: str = field(default="manifest.yaml") 

22 default_target: None | str = field(default=None) 

23 cache_dir: str = field(default="/var/lib/jumpstarter/flasher") 

24 tftp_dir: str = field(default="/var/lib/tftpboot") 

25 http_dir: str = field(default="/var/www/html") 

26 

27 def __post_init__(self): 

28 if hasattr(super(), "__post_init__"): 

29 super().__post_init__() 

30 

31 # Ensure required children are present if not already instantiated 

32 # in configuration 

33 if "tftp" not in self.children: 

34 self.children["tftp"] = Tftp(root_dir=self.tftp_dir) 

35 self.tftp = self.children["tftp"] 

36 

37 if "http" not in self.children: 

38 self.children["http"] = HttpServer(root_dir=self.http_dir) 

39 self.http = self.children["http"] 

40 

41 # Ensure required children are present, the following are not auto-created 

42 if "serial" not in self.children: 

43 raise ConfigurationError( 

44 "'serial' instance is required for BaseFlasher either via a ref ir a direct child instance" 

45 ) 

46 

47 if "power" not in self.children: 

48 raise ConfigurationError( 

49 "'power' instance is required for BaseFlasher either via a ref ir a direct child instance" 

50 ) 

51 

52 if "uboot" not in self.children: 

53 self.children["uboot"] = UbootConsole( 

54 children={ 

55 "power": self.children["power"], 

56 "serial": self.children["serial"], 

57 } 

58 ) 

59 

60 # bundles that have already been downloaded in the current session 

61 self._downloaded = {} 

62 

63 @classmethod 

64 def client(cls) -> str: 

65 return "jumpstarter_driver_flashers.client.BaseFlasherClient" 

66 

67 @export 

68 async def get_default_target(self): 

69 """Return the default target""" 

70 return self.default_target 

71 

72 @export 

73 async def setup_flasher_bundle(self, force_flash_bundle: str | None = None): 

74 """Setup flasher bundle 

75 

76 This method sets all the files in place in the tftp server 

77 so that the target can download from bootloader. 

78 """ 

79 

80 # the client is requesting a different flasher bundle 

81 if force_flash_bundle: 

82 self.flasher_bundle = force_flash_bundle 

83 

84 manifest = await self.get_flasher_manifest() 

85 kernel_path = await self._get_file_path(manifest.spec.kernel.file) 

86 self.logger.info(f"Setting up kernel in tftp: {kernel_path}") 

87 await self.tftp.storage.copy_exporter_file(kernel_path, kernel_path.name) 

88 

89 initram_file = manifest.get_initram_file() 

90 if initram_file: 

91 initram_path = await self._get_file_path(initram_file) 

92 self.logger.info(f"Setting up initram in tftp: {initram_path}") 

93 await self.tftp.storage.copy_exporter_file(initram_path, initram_path.name) 

94 

95 dtb_file = manifest.get_dtb_file(self.variant) if manifest.spec.dtb else None 

96 if dtb_file: 

97 dtb_path = await self._get_file_path(dtb_file) 

98 self.logger.info(f"Setting up dtb in tftp: {dtb_path}") 

99 await self.tftp.storage.copy_exporter_file(dtb_path, dtb_path.name) 

100 

101 @export 

102 def set_dtb(self, handle): 

103 """Provide a different dtb from client""" 

104 raise NotImplementedError 

105 

106 @export 

107 async def use_dtb_variant(self, variant): 

108 """Provide a different dtb reference from the flasher bundle""" 

109 manifest = await self.get_flasher_manifest() 

110 # Check if the variant exists in the manifest 

111 if not manifest.spec.dtb or variant not in manifest.spec.dtb.variants: 

112 variant_list = [] 

113 if manifest.spec.dtb: 

114 variant_list = list(manifest.spec.dtb.variants.keys()) 

115 raise ValueError( 

116 f"DTB variant {variant} not found in the flasher bundle, " 

117 f"available variants are: {variant_list}." 

118 ) 

119 self.variant = variant 

120 

121 def set_kernel(self, handle): 

122 """Provide a different kernel from client""" 

123 raise NotImplementedError 

124 

125 def set_initram(self, handle): 

126 """Provide a different initram from client""" 

127 raise NotImplementedError 

128 

129 def _download_to_cache(self) -> str: 

130 """Download the bundle to the cache 

131 

132 This function downloads the bundle contents to the cache directory, 

133 if it was already downloaded during this session we don't download it again. 

134 """ 

135 if self._downloaded.get(self.flasher_bundle): 

136 self.logger.debug(f"Bundled already downloaded: {self.flasher_bundle}") 

137 return self._downloaded[self.flasher_bundle] 

138 

139 oras_client = Registry() 

140 self.logger.info(f"Downloading bundle: {self.flasher_bundle}") 

141 

142 # make a filesystem valid name for the cache directory 

143 bundle_subdir = self.flasher_bundle.replace(":", "_").replace("/", "_") 

144 bundle_dir = Path(self.cache_dir) / bundle_subdir 

145 

146 # ensure the bundle dir exists 

147 bundle_dir.mkdir(parents=True, exist_ok=True) 

148 oras_client.pull(self.flasher_bundle, outdir=bundle_dir) 

149 

150 self.logger.info(f"Bundle downloaded to {bundle_dir}") 

151 

152 # mark this bundle as downloaded for the current object lifetime 

153 self._downloaded[self.flasher_bundle] = bundle_dir 

154 return bundle_dir 

155 

156 async def _get_file_path(self, filename) -> Path: 

157 """Get the bundle contents path. 

158 

159 This function will ensure that the bundle is downloaded into cache, and 

160 then return the path to the requested file in the cache directory. 

161 """ 

162 if filename is None: 

163 raise ValueError("filename cannot be None") 

164 bundle_dir = await anyio.to_thread.run_sync(self._download_to_cache) 

165 return Path(bundle_dir) / filename 

166 

167 @export 

168 async def get_flasher_manifest_yaml(self) -> str: 

169 """Return the manifest yaml as a string for client side consumption""" 

170 with open(await self._get_file_path(self.manifest)) as f: 

171 return f.read() 

172 

173 async def get_flasher_manifest(self) -> FlasherBundleManifestV1Alpha1: 

174 filename = await self._get_file_path(self.manifest) 

175 return FlasherBundleManifestV1Alpha1.from_file(filename) 

176 

177 @export 

178 async def get_kernel_filename(self) -> str: 

179 """Return the kernel filename""" 

180 manifest = await self.get_flasher_manifest() 

181 return Path(manifest.get_kernel_file()).name 

182 

183 @export 

184 async def get_initram_filename(self) -> str | None: 

185 """Return the initram filename""" 

186 manifest = await self.get_flasher_manifest() 

187 filename = manifest.get_initram_file() 

188 if filename: 

189 return Path(filename).name 

190 

191 @export 

192 async def get_dtb_filename(self) -> str: 

193 """Return the dtb filename""" 

194 manifest = await self.get_flasher_manifest() 

195 dtb_file = manifest.get_dtb_file(self.variant) 

196 if dtb_file: 

197 return Path(dtb_file).name 

198 else: 

199 return "" 

200 

201 @export 

202 async def get_dtb_address(self) -> str: 

203 """Return the dtb address""" 

204 manifest = await self.get_flasher_manifest() 

205 return manifest.get_dtb_address() 

206 

207 @export 

208 async def get_kernel_address(self) -> str: 

209 """Return the kernel address""" 

210 manifest = await self.get_flasher_manifest() 

211 return manifest.get_kernel_address() 

212 

213 @export 

214 async def get_initram_address(self) -> str: 

215 """Return the initram address""" 

216 manifest = await self.get_flasher_manifest() 

217 return manifest.get_initram_address() 

218 

219 @export 

220 async def get_bootcmd(self) -> str: 

221 """Return the bootcmd""" 

222 manifest = await self.get_flasher_manifest() 

223 return manifest.get_boot_cmd(self.variant) 

224 

225@dataclass(kw_only=True) 

226class TIJ784S4Flasher(BaseFlasher): 

227 """driver for Jumpstarter""" 

228 

229 flasher_bundle: str = "quay.io/jumpstarter-dev/jumpstarter-flasher-ti-j784s4:latest" 

230 

231 

232@dataclass(kw_only=True) 

233class RCarS4Flasher(BaseFlasher): 

234 """RCarS4 driver for Jumpstarter""" 

235 

236 flasher_bundle: str = "quay.io/jumpstarter-dev/jumpstarter-flasher-rcar-s4:latest"