Coverage for /home/benjarobin/Bootlin/projects/Schneider-Electric-Senux/sbom-cve-check/src/sbom_cve_check/cve_db/db_git.py: 69%

188 statements  

« prev     ^ index     » next       coverage.py v7.11.1, created at 2025-11-28 15:37 +0100

1# -*- coding: utf-8 -*- 

2# SPDX-License-Identifier: GPL-2.0-only 

3 

4import abc 

5import contextlib 

6import fcntl 

7import hashlib 

8import json 

9import logging 

10import os 

11import pathlib 

12from collections.abc import Callable, Generator 

13from datetime import UTC, datetime, timedelta 

14from typing import Any 

15 

16from ..utils import parsing 

17from ..utils.git import GitRepo 

18from .annot_base import AnnotDatabase 

19from .db_base import CveDatabase, CveDatabaseT 

20 

21_logger = logging.getLogger(__name__) 

22 

23 

24class GitDatabase: 

25 @staticmethod 

26 def parse_init_params(args: dict[str, Any]) -> None: 

27 parsing.update_timedelta_param( 

28 args, "auto_update_max_age", special_vals={"off": None} 

29 ) 

30 parsing.update_boolean_param(args, "max_age_since_last_commit") 

31 parsing.update_integer_param(args, "git_fetch_depth") 

32 

33 def __init__( 

34 self, 

35 git_dir: pathlib.Path, 

36 git_url: str, 

37 is_annotation: bool, 

38 *, 

39 auto_update_max_age: timedelta | None = timedelta(hours=20), 

40 max_age_since_last_commit: bool | None = None, 

41 cache_index_path: str | None = ".sbom-cve-check-cache-index.json", 

42 git_branch: str | None = None, 

43 git_ref: str | None = None, 

44 git_fetch_depth: int | None = None, 

45 ) -> None: 

46 self._git_repo = GitRepo(git_dir) 

47 self._git_clone_url: str = git_url 

48 

49 self.auto_update_max_age: timedelta | None = auto_update_max_age 

50 if max_age_since_last_commit is None: 

51 max_age_since_last_commit = not is_annotation 

52 self.max_age_since_last_commit: bool = max_age_since_last_commit 

53 

54 self._cache_index_path: pathlib.Path | None = None 

55 if cache_index_path: 

56 cache_index_path = os.path.expandvars(cache_index_path) 

57 self._cache_index_path = pathlib.Path(cache_index_path).expanduser() 

58 if not self._cache_index_path.is_absolute(): 

59 self._cache_index_path = self._git_repo.path.joinpath(cache_index_path) 

60 self._cache_index_path = self._cache_index_path.resolve() 

61 

62 self._git_branch = git_branch 

63 self._git_ref = git_ref 

64 self._git_fetch_depth = git_fetch_depth 

65 self._lock_fd: int = -1 

66 

67 self._git_repo.path.mkdir(parents=True, exist_ok=True) 

68 self._take_lock(False) 

69 

70 def __del__(self) -> None: 

71 self._release_lock() 

72 

73 def _take_lock(self, exclusive: bool) -> None: 

74 if self._lock_fd < 0: 

75 self._lock_fd = os.open(self._git_repo.path, os.O_RDONLY | os.O_NOCTTY) 

76 fcntl.flock(self._lock_fd, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH) 

77 

78 def _release_lock(self) -> None: 

79 if self._lock_fd >= 0: 

80 os.close(self._lock_fd) 

81 self._lock_fd = -1 

82 

83 @contextlib.contextmanager 

84 def _lock_exclusive_taken(self) -> Generator[None, Any, None]: 

85 try: 

86 self._take_lock(True) 

87 yield None 

88 finally: 

89 self._take_lock(False) 

90 

91 def initialize(self) -> None: 

92 """ 

93 Initialize this database. 

94 When updating the git database, take the lock exclusively. 

95 """ 

96 self.update(force_update=False) 

97 

98 def update(self, force_update: bool = True) -> None: 

99 """Update the database to the latest version""" 

100 with self._lock_exclusive_taken(): 

101 self._update(force_update) 

102 

103 def _auto_update_needed(self) -> bool: 

104 if self.auto_update_max_age is None: 

105 return False 

106 

107 if self.max_age_since_last_commit: 

108 d = self.get_date_last_commit() 

109 else: 

110 d = self.get_date_last_update() 

111 

112 if d is not None: 

113 return self._compute_elapsed_time(d) >= self.auto_update_max_age 

114 return True 

115 

116 def _update(self, force_update: bool) -> None: 

117 repo_dir = self._git_repo.path 

118 

119 # If the repository does not exist yet, we need to clone it 

120 if not self._git_repo.is_git_repo(): 

121 _logger.info("Downloading to %s from %s", repo_dir, self._git_clone_url) 

122 

123 # If the branch name is missing consider that the ref is a tag 

124 branch = self._git_branch 

125 if self._git_ref and (not branch): 

126 branch = self._git_ref 

127 

128 if self._git_fetch_depth is not None: 

129 depth = self._git_fetch_depth 

130 else: 

131 depth = 0 if self._git_ref else 1 

132 self._git_repo.clone(self._git_clone_url, branch, depth) 

133 return 

134 

135 # Shortcut: If asked for a specific reference, and the reference already 

136 # exist in the git repository, switch to this revision 

137 if self._git_ref and self._git_repo.is_valid_object(self._git_ref): 

138 self._git_repo.switch_detach_head(self._git_ref, True) 

139 return 

140 

141 # Get default remote branch if none was provided 

142 if not self._git_branch: 

143 self._git_branch = self._git_repo.get_default_remote_branch() 

144 if not self._git_branch: 

145 self._git_branch = self._git_repo.update_default_remote_branch() 

146 

147 # We need to check if we are on the correct branch if the user did not ask 

148 # for a specific reference. 

149 curr_branch: str | None = None 

150 if not force_update: 

151 if not self._git_ref: 

152 curr_branch = self._git_repo.get_current_branch() 

153 else: 

154 force_update = True 

155 

156 # If we are not on the correct branch, or if an update was forced, or if the 

157 # last update is too old, trigger a fetch 

158 if ( 

159 force_update 

160 or (self._git_branch != curr_branch) 

161 or self._auto_update_needed() 

162 ): 

163 fetch_url = self._git_repo.get_fetch_url() 

164 _logger.info("Updating %s from %s", repo_dir, fetch_url) 

165 

166 fetch_depths: list[int] = [ 

167 0 if self._git_fetch_depth is None else self._git_fetch_depth, 

168 0, 

169 ] 

170 for fetch_depth in fetch_depths: 

171 unshallow = False 

172 if (fetch_depth <= 0) and self._git_ref: 

173 unshallow = self._git_repo.is_shallow_repository() 

174 

175 self._git_repo.fetch(self._git_branch, unshallow, fetch_depth) 

176 

177 if ( 

178 (fetch_depth <= 0) 

179 or (not self._git_ref) 

180 or (self._git_repo.is_valid_object(self._git_ref)) 

181 ): 

182 break 

183 

184 # And finally update the repository 

185 if self._git_ref: 

186 self._git_repo.switch_detach_head(self._git_ref, True) 

187 else: 

188 self._git_repo.switch_force_create_branch(self._git_branch, True) 

189 

190 def get_date_last_commit(self) -> datetime | None: 

191 """ 

192 Get the date time of last commit from the git repository. 

193 

194 Get commiter date of the last commit (HEAD). 

195 Return None if the database does not exist yet. 

196 """ 

197 if not self._git_repo.is_git_repo(): 

198 return None 

199 return self._git_repo.get_date_last_commit() 

200 

201 def get_date_last_update(self) -> datetime | None: 

202 return self._git_repo.get_date_last_update() 

203 

204 @staticmethod 

205 def _compute_elapsed_time(start_date: datetime) -> timedelta: 

206 """ 

207 Get the elapsed time since a start date (should be in UTC) 

208 """ 

209 date_now = datetime.now(UTC) 

210 delta = date_now - start_date 

211 zero_delta = timedelta(0) 

212 return max(zero_delta, delta) 

213 

214 def create_index( 

215 self, 

216 create_index: Callable[[], dict[str, set[str]]], 

217 cfg_hash: str, 

218 ) -> dict[str, set[str]]: 

219 """ 

220 Create the database index if the index previously computed is not associated 

221 with the last git commit. Otherwise, load the previous index from disk. 

222 """ 

223 

224 def set_default(obj: object) -> list[str]: 

225 if isinstance(obj, set): 

226 return list(obj) 

227 raise TypeError 

228 

229 if not self._git_repo.is_git_repo(): 

230 raise RuntimeError(f"Database {self._git_repo.path} was not initialized") 

231 

232 commit_id = self._git_repo.resolve_sha_ref("HEAD") 

233 

234 if self._cache_index_path and self._cache_index_path.is_file(): 

235 with self._cache_index_path.open(encoding="utf-8") as f: 

236 cache_obj = json.load(f) 

237 if isinstance(cache_obj, dict): 

238 index_comp_cve = cache_obj.get("index") 

239 cache_commit = cache_obj.get("commit") 

240 cache_hash = cache_obj.get("cfg_hash") 

241 if ( 

242 (cache_commit == commit_id) 

243 and (cache_hash == cfg_hash) 

244 and isinstance(index_comp_cve, dict) 

245 ): 

246 return index_comp_cve 

247 

248 index_comp_cve = create_index() 

249 assert index_comp_cve 

250 

251 if self._cache_index_path: 

252 with ( 

253 self._lock_exclusive_taken(), 

254 self._cache_index_path.open("w", encoding="utf-8") as f, 

255 ): 

256 cache_obj = { 

257 "commit": commit_id, 

258 "cfg_hash": cfg_hash, 

259 "index": index_comp_cve, 

260 } 

261 f.write(json.dumps(cache_obj, default=set_default)) 

262 

263 return index_comp_cve 

264 

265 

class GitCveDatabase(CveDatabase, abc.ABC):
    """Base class of the CVE databases stored in a git repository."""

    @classmethod
    def create_from_config(cls: type[CveDatabaseT], **kwargs: Any) -> CveDatabaseT:
        """
        Build the database from its configuration values.

        The git related options of ``kwargs`` are first passed through
        ``GitDatabase.parse_init_params``; the parent class then does the
        actual construction.
        """
        GitDatabase.parse_init_params(kwargs)
        return super().create_from_config(**kwargs)

    def __init__(
        self, path: pathlib.Path, name: str, git_url: str, **kwargs: Any
    ) -> None:
        """
        :param path: Directory of the git repository (stored resolved).
        :param name: Database name, forwarded to the parent class.
        :param git_url: URL of the git repository.
        :param kwargs: Remaining options, forwarded to ``GitDatabase``.
        """
        super().__init__(name)
        resolved_dir = path.resolve()
        self._git_dir: pathlib.Path = resolved_dir
        # A CVE database is never an annotation database
        self._git_db = GitDatabase(
            resolved_dir, git_url, is_annotation=False, **kwargs
        )

    @property
    def git_database(self) -> GitDatabase:
        """The underlying git database helper."""
        return self._git_db

    def _initialize(self) -> None:
        """Initialize the underlying git database."""
        self._git_db.initialize()

    def create_index(self) -> dict[str, set[str]]:
        """
        Return the database index, going through the git database index cache.

        The parent implementation is used as index builder, and the
        configuration hash is left empty.
        """
        build_index = super().create_index
        return self._git_db.create_index(build_index, "")

288 

289 

class GitAnnotDatabase(AnnotDatabase, abc.ABC):
    """
    Base class of the annotation databases optionally stored in a git repository.

    Without git URL the database is a plain directory and no git database
    helper is created.
    """

    @classmethod
    def create_from_config(cls: type[CveDatabaseT], **kwargs: Any) -> CveDatabaseT:
        """
        Build the database from its configuration values.

        The git related options of ``kwargs`` are first passed through
        ``GitDatabase.parse_init_params``; the parent class then does the
        actual construction.
        """
        GitDatabase.parse_init_params(kwargs)
        return super().create_from_config(**kwargs)

    def __init__(
        self,
        path: pathlib.Path,
        name: str,
        obsolete_assessment_check: bool | None = None,
        git_url: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        :param path: Directory of the database. It must already exist when no
            git URL is given (strict resolution).
        :param name: Database name, forwarded to the parent class.
        :param obsolete_assessment_check: Forwarded to the parent class.
        :param git_url: URL of the git repository, if any.
        :param kwargs: Remaining options, forwarded to ``GitDatabase`` when a
            git URL is given.
        """
        super().__init__(name, obsolete_assessment_check)
        must_exist = not git_url
        self._db_dir: pathlib.Path = path.resolve(strict=must_exist)
        self._git_db: GitDatabase | None = (
            GitDatabase(self._db_dir, git_url, True, **kwargs) if git_url else None
        )

    @property
    def git_database(self) -> GitDatabase | None:
        """The underlying git database helper, or None without git URL."""
        return self._git_db

    def _initialize(self) -> None:
        """Initialize the underlying git database, if there is one."""
        if self._git_db is None:
            return
        self._git_db.initialize()

    @abc.abstractmethod
    def _get_index_cache_invalidation_data(self) -> bytes:
        """
        Return the raw bytes from which the index cache invalidation key is built.

        The cached index is discarded when this data changes, which is how a
        change of the annotation database invalidates the cache.
        """

    def create_index(self) -> dict[str, set[str]]:
        """
        Return the database index.

        With a git database, the index goes through its cache, keyed by the
        SHA-256 of the invalidation data, a ``||`` separator and the database
        directory. Otherwise the parent implementation is used directly.
        """
        if self._git_db is None:
            return super().create_index()

        hasher = hashlib.sha256()
        for chunk in (
            self._get_index_cache_invalidation_data(),
            b"||",
            self._db_dir.as_posix().encode(),
        ):
            hasher.update(chunk)
        return self._git_db.create_index(super().create_index, hasher.hexdigest())