Coverage for jumpstarter_driver_flashers/driver.py: 97%

116 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-05 20:29 +0000

1from dataclasses import dataclass, field 

2from pathlib import Path 

3 

4import anyio.to_thread 

5from jumpstarter_driver_http.driver import HttpServer 

6from jumpstarter_driver_tftp.driver import Tftp 

7from jumpstarter_driver_uboot.driver import UbootConsole 

8from oras.provider import Registry 

9 

10from .bundle import FlasherBundleManifestV1Alpha1 

11from jumpstarter.common.exceptions import ConfigurationError 

12from jumpstarter.driver import Driver, export 

13 

14 

@dataclass(kw_only=True)
class BaseFlasher(Driver):
    """Driver that stages a flasher bundle in a TFTP server for a target.

    The flasher bundle is pulled with ``oras`` into ``cache_dir``; its
    ``manifest.yaml`` describes a kernel, an optional initram and DTB
    variants. ``setup_flasher_bundle`` copies those files into the TFTP
    storage so the target can download them from its bootloader.

    Children:
        tftp: created with ``root_dir=tftp_dir`` unless configured.
        http: created with ``root_dir=http_dir`` unless configured.
        serial, power: required; ``ConfigurationError`` if missing.
        uboot: created from ``power`` and ``serial`` unless configured.
    """

    # reference of the flasher bundle pulled with oras
    flasher_bundle: str = field(default="quay.io/jumpstarter-dev/jumpstarter-flasher-test:latest")
    # directory holding downloaded bundles, one subdirectory per bundle reference
    cache_dir: str = field(default="/var/lib/jumpstarter/flasher")
    # root_dir of the Tftp child when it is auto-created
    tftp_dir: str = field(default="/var/lib/tftpboot")
    # root_dir of the HttpServer child when it is auto-created
    http_dir: str = field(default="/var/www/html")

    def __post_init__(self):
        """Complete and validate the child drivers.

        Raises:
            ConfigurationError: if the ``serial`` or ``power`` child is missing.
        """
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        # Ensure required children are present if not already instantiated
        # in configuration
        if "tftp" not in self.children:
            self.children["tftp"] = Tftp(root_dir=self.tftp_dir)
        self.tftp = self.children["tftp"]

        if "http" not in self.children:
            self.children["http"] = HttpServer(root_dir=self.http_dir)
        self.http = self.children["http"]

        # Ensure required children are present, the following are not auto-created
        for required in ("serial", "power"):
            if required not in self.children:
                raise ConfigurationError(
                    f"'{required}' instance is required for BaseFlasher either via a ref or a direct child instance"
                )

        if "uboot" not in self.children:
            self.children["uboot"] = UbootConsole(
                children={
                    "power": self.children["power"],
                    "serial": self.children["serial"],
                }
            )

        # bundle reference -> cache directory, for bundles already pulled by this object
        self._downloaded: dict[str, Path] = {}
        self._use_dtb = None  # use default dtb unless set by client

    @classmethod
    def client(cls) -> str:
        """Return the import path of the matching client class."""
        return "jumpstarter_driver_flashers.client.BaseFlasherClient"

    @export
    async def setup_flasher_bundle(self, force_flash_bundle: str | None = None):
        """Stage the flasher bundle files in the TFTP server.

        Copies the kernel, the initram (only if the manifest declares one) and
        the selected DTB into the TFTP storage under their base file names, so
        the target can download them from the bootloader.

        Args:
            force_flash_bundle: if truthy, replaces ``flasher_bundle`` with this
                reference before staging.
        """
        # the client is requesting a different flasher bundle
        if force_flash_bundle:
            self.flasher_bundle = force_flash_bundle

        manifest = await self.get_flasher_manifest()

        kernel_path = await self._get_file_path(manifest.spec.kernel.file)
        await self._copy_to_tftp("kernel", kernel_path)

        if manifest.spec.initram:
            initram_path = await self._get_file_path(manifest.spec.initram.file)
            await self._copy_to_tftp("initram", initram_path)

        dtb_path = await self._get_file_path(manifest.get_dtb_file(self._use_dtb))
        await self._copy_to_tftp("dtb", dtb_path)

    async def _copy_to_tftp(self, label: str, path: Path):
        """Log and copy *path* into the TFTP storage under its base file name."""
        self.logger.info(f"Setting up {label} in tftp: {path}")
        await self.tftp.storage.copy_exporter_file(path, path.name)

    @export
    def set_dtb(self, handle):
        """Provide a different dtb from client (not implemented)."""
        raise NotImplementedError

    @export
    async def use_dtb_variant(self, variant):
        """Select a DTB variant from the flasher bundle for later staging.

        Raises:
            ValueError: if the manifest has no DTB for *variant*.
        """
        manifest = await self.get_flasher_manifest()
        if manifest.get_dtb_file(variant) is None:
            raise ValueError(
                f"DTB variant {variant} not found in the flasher bundle, "
                f"available variants are: {list(manifest.spec.dtb.variants.keys())}"
            )
        self._use_dtb = variant

    def set_kernel(self, handle):
        """Provide a different kernel from client (not implemented)."""
        raise NotImplementedError

    def set_initram(self, handle):
        """Provide a different initram from client (not implemented)."""
        raise NotImplementedError

    def _download_to_cache(self) -> Path:
        """Pull the current bundle into the cache and return its directory.

        A bundle already pulled during this object's lifetime is not pulled
        again. This call blocks on the network pull; ``_get_file_path`` runs it
        in a worker thread.
        """
        cached = self._downloaded.get(self.flasher_bundle)
        if cached:
            self.logger.debug(f"Bundle already downloaded: {self.flasher_bundle}")
            return cached

        oras_client = Registry()
        self.logger.info(f"Downloading bundle: {self.flasher_bundle}")

        # make a filesystem valid name for the cache directory
        bundle_subdir = self.flasher_bundle.replace(":", "_").replace("/", "_")
        bundle_dir = Path(self.cache_dir) / bundle_subdir

        # ensure the bundle dir exists
        bundle_dir.mkdir(parents=True, exist_ok=True)
        oras_client.pull(self.flasher_bundle, outdir=bundle_dir)

        self.logger.info(f"Bundle downloaded to {bundle_dir}")

        # mark this bundle as downloaded for the current object lifetime
        self._downloaded[self.flasher_bundle] = bundle_dir
        return bundle_dir

    async def _get_file_path(self, filename) -> Path:
        """Return the cached path of *filename* inside the current bundle.

        Ensures the bundle is downloaded first, without blocking the event loop.
        """
        bundle_dir = await anyio.to_thread.run_sync(self._download_to_cache)
        return Path(bundle_dir) / filename

    @export
    async def get_flasher_manifest_yaml(self) -> str:
        """Return the manifest yaml as a string for client side consumption"""
        with open(await self._get_file_path("manifest.yaml")) as f:
            return f.read()

    async def get_flasher_manifest(self) -> FlasherBundleManifestV1Alpha1:
        """Load and parse ``manifest.yaml`` from the current bundle."""
        filename = await self._get_file_path("manifest.yaml")
        return FlasherBundleManifestV1Alpha1.from_file(filename)

    @export
    async def get_kernel_filename(self) -> str:
        """Return the base name of the kernel file"""
        manifest = await self.get_flasher_manifest()
        return Path(manifest.get_kernel_file()).name

    @export
    async def get_initram_filename(self) -> str | None:
        """Return the base name of the initram file, or None if the bundle has none"""
        manifest = await self.get_flasher_manifest()
        filename = manifest.get_initram_file()
        if filename:
            return Path(filename).name
        return None

    @export
    async def get_dtb_filename(self) -> str:
        """Return the base name of the currently selected dtb file"""
        manifest = await self.get_flasher_manifest()
        return Path(manifest.get_dtb_file(self._use_dtb)).name

    @export
    async def get_dtb_address(self) -> str:
        """Return the dtb address"""
        manifest = await self.get_flasher_manifest()
        return manifest.get_dtb_address()

    @export
    async def get_kernel_address(self) -> str:
        """Return the kernel address"""
        manifest = await self.get_flasher_manifest()
        return manifest.get_kernel_address()

    @export
    async def get_initram_address(self) -> str:
        """Return the initram address"""
        manifest = await self.get_flasher_manifest()
        return manifest.get_initram_address()

197 return manifest.get_initram_address() 

198 

199 

@dataclass(kw_only=True)
class TIJ784S4Flasher(BaseFlasher):
    """Flasher driver defaulting to the TI J784S4 flasher bundle.

    Behaves exactly like ``BaseFlasher``; only the default
    ``flasher_bundle`` reference differs.
    """

    flasher_bundle: str = field(default="quay.io/jumpstarter-dev/jumpstarter-flasher-ti-j784s4:latest")