Coverage for /home/benjarobin/Bootlin/projects/Schneider-Electric-Senux/sbom-cve-check/src/sbom_cve_check/vuln/cvss.py: 94%

126 statements  

« prev     ^ index     » next       coverage.py v7.11.1, created at 2025-11-28 15:37 +0100

1# -*- coding: utf-8 -*- 

2# SPDX-License-Identifier: GPL-2.0-only 

3 

4import dataclasses 

5from collections.abc import Callable, Iterable 

6from enum import Enum 

7from typing import Any, Optional, TypeVar 

8 

9from ..utils.class_utils import Singleton 

10 

# Enum class accepted, and handed back unchanged, by the decorator built by
# ``register_cvss_metric``.
_CvssEnumT = TypeVar("_CvssEnumT", bound=Enum)

# Grouping key returned by the ``key`` callable of ``group_cvss_metrics``.
GroupByT = TypeVar("GroupByT")

13 

14 

class CvssMetricsRegistry(metaclass=Singleton):
    """
    Registry of the Enum types used to decode CVSS vector metrics.

    Entries are keyed by ``(vers, key)``: the CVSS major version and the metric
    abbreviation as written in a vector string (e.g. ``(3, "AV")``). Each entry
    stores the Enum type decoding that metric's values, together with a
    human-readable description of the metric.

    NOTE(review): the class uses the project ``Singleton`` metaclass, so
    presumably every ``CvssMetricsRegistry()`` call yields the same shared
    instance. Confirm against ``utils.class_utils.Singleton``.
    """

    def __init__(self) -> None:
        # (major version, metric key) -> (Enum type, metric description)
        self._metrics: dict[tuple[int, str], tuple[type[Enum], str]] = {}

    def register_type(self, t: type[Enum], vers: int, key: str, desc: str) -> None:
        """
        Register ``t`` as the Enum type decoding metric ``key`` of CVSS ``vers``.

        :param t: Enum type whose member values are the metric's vector values
        :param vers: CVSS major version
        :param key: metric abbreviation, e.g. ``"AV"``
        :param desc: human-readable metric name, e.g. ``"Attack Vector"``
        :raises ValueError: if a type is already registered for ``(vers, key)``
        """
        metric_key = (vers, key)
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let a duplicate registration silently
        # replace the first one.
        if metric_key in self._metrics:
            raise ValueError(
                f"CVSS metric {key!r} is already registered for version {vers}"
            )
        self._metrics[metric_key] = (t, desc)

    def desc(self, vers: int, key: str) -> str | None:
        """
        Return the description of metric ``key`` of CVSS ``vers``.

        :return: the registered description, or ``None`` if the metric is not
            registered
        """
        info = self._metrics.get((vers, key))
        return info[1] if info else None

    def value(self, vers: int, key: str, val: str) -> Enum | None:
        """
        Decode the vector value ``val`` of metric ``key`` of CVSS ``vers``.

        :return: the Enum member whose value is ``val``, or ``None`` when no
            type is registered for ``(vers, key)``
        :raises ValueError: if a type is registered but ``val`` is not one of
            its member values (raised by the Enum lookup)
        """
        info = self._metrics.get((vers, key))
        if not info:
            return None
        return info[0](val)

33 

34 

def register_cvss_metric(
    vers: int, key: str, desc: str
) -> Callable[[type[_CvssEnumT]], type[_CvssEnumT]]:
    """
    Build a class decorator that registers an Enum as a CVSS metric decoder.

    The decorated Enum is recorded in ``CvssMetricsRegistry`` under
    ``(vers, key)`` and returned as-is.

    :param vers: CVSS major version (the first item of a ``CvssVersion`` value)
    :param key: metric abbreviation used in vector strings, e.g. ``"AV"``
    :param desc: human-readable metric name, e.g. ``"Attack Vector"``
    :return: the class decorator
    """

    def _register(enum_cls: type[_CvssEnumT]) -> type[_CvssEnumT]:
        registry = CvssMetricsRegistry()
        registry.register_type(enum_cls, vers, key, desc)
        return enum_cls

    return _register

43 

44 

class CvssVersion(Enum):
    """
    CVSS specification version, valued as a ``(major, minor)`` tuple.

    ``UNKNOWN`` is ``(0, 0)``; presumably it marks a version that could not be
    determined. Confirm against its users.
    """

    UNKNOWN = (0, 0)
    V2_0 = (2, 0)
    V3_0 = (3, 0)
    V3_1 = (3, 1)
    V4_0 = (4, 0)

51 

52 

class CvssSeverity(Enum):
    """
    Qualitative CVSS severity rating.

    Members are declared from least to most severe. Their values are the
    lowercase rating names.
    """

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

59 

60 

@register_cvss_metric(2, "AV", "Access Vector")
class Cvss2AccessVector(Enum):
    """CVSS v2 "Access Vector" (``AV``) metric; values are the vector letters."""

    LOCAL = "L"
    ADJACENT_NETWORK = "A"
    NETWORK = "N"

66 

67 

@register_cvss_metric(3, "AV", "Attack Vector")
class Cvss3AttackVector(Enum):
    """CVSS v3 "Attack Vector" (``AV``) metric; values are the vector letters."""

    LOCAL = "L"
    ADJACENT_NETWORK = "A"
    NETWORK = "N"
    PHYSICAL = "P"

74 

75 

@register_cvss_metric(4, "AV", "Attack Vector")
class Cvss4AttackVector(Enum):
    """CVSS v4 "Attack Vector" (``AV``) metric; values are the vector letters."""

    LOCAL = "L"
    ADJACENT = "A"
    NETWORK = "N"
    PHYSICAL = "P"

82 

83 

@dataclasses.dataclass(frozen=True)
class CvssMetric:
    """
    Immutable CVSS metric: a score together with its vector string.

    Fields:

    - ``cvss_ver``: CVSS version of the metric
    - ``score``: base score
    - ``vector_str``: vector string, e.g. ``CVSS:3.1/AV:N/AC:L/...``
    - ``severity``: qualitative severity, or ``None`` (``parse_cve_db_metric``
      leaves it ``None`` for CVSS v2)
    - ``source``: free-form origin of the metric, as given by the caller
    """

    cvss_ver: CvssVersion
    score: float
    vector_str: str
    severity: CvssSeverity | None = None
    source: str | None = None

    def cmp_key(self) -> tuple[Any, ...]:
        """
        Return a key identifying the information carried by this metric.

        The key holds the version tuple, the score and the sorted vector
        components. Component order in the vector string therefore does not
        matter, and ``severity``/``source`` are ignored.
        """
        return (
            self.cvss_ver.value,
            self.score,
            tuple(sorted(self.vector_str.split("/"))),
        )

    def decode_vector(self) -> dict[str, Enum]:
        """
        Decode the vector components that have a registered Enum type.

        Components are ``KEY:VALUE`` pairs separated by ``/``. The following
        are skipped:

        - the ``CVSS:x.y`` version prefix (matched by name, so a vector that
          lacks the prefix still keeps its first metric);
        - empty or colon-less components;
        - components whose key has no type registered for this CVSS major
          version (so an unregistered version yields an empty result).

        :return: mapping from metric abbreviation (e.g. ``"AV"``) to its
            decoded Enum member
        :raises ValueError: if a registered metric carries a value that is not
            a member of its Enum type
        """
        registry = CvssMetricsRegistry()
        vers = self.cvss_ver.value[0]
        metrics: dict[str, Enum] = {}
        for part in self.vector_str.split("/"):
            key, sep, val = part.partition(":")
            # Skip the version prefix by name rather than by position, and
            # ignore malformed parts (e.g. produced by a trailing "/") instead
            # of failing on tuple unpacking.
            if not sep or key == "CVSS":
                continue
            decoded = registry.value(vers, key, val)
            if decoded is not None:
                metrics[key] = decoded
        return metrics

    @staticmethod
    def compute_severity_from_score(score: float) -> CvssSeverity:
        """
        Map a score to its qualitative severity.

        Thresholds: ``< 0.1`` NONE, ``< 4.0`` LOW, ``< 7.0`` MEDIUM,
        ``< 9.0`` HIGH, otherwise CRITICAL (the CVSS v3.x rating scale).
        """
        if score < 0.1:
            return CvssSeverity.NONE
        if score < 4.0:
            return CvssSeverity.LOW
        if score < 7.0:
            return CvssSeverity.MEDIUM
        if score < 9.0:
            return CvssSeverity.HIGH
        return CvssSeverity.CRITICAL

    @staticmethod
    def parse_cve_db_metric(
        json_obj: dict[str, Any] | None,
        *,
        source: str | None = None,
        version: CvssVersion | None = None,
    ) -> Optional["CvssMetric"]:
        """
        Build a ``CvssMetric`` from a CVSS JSON object of a CVE database.

        The keys read are ``version``, ``baseScore``, ``vectorString`` and,
        except for CVSS v2, ``baseSeverity``.

        NOTE(review): presumably this is an NVD-style ``cvssData`` object.
        Confirm against callers.

        :param json_obj: JSON object to parse, or ``None``
        :param source: value stored in ``CvssMetric.source``
        :param version: expected CVSS version, also used as the version when
            the object has no recognized ``version`` value
        :return: the parsed metric, or ``None`` in any of these cases:

            - ``json_obj`` is ``None``;
            - the version cannot be determined;
            - ``baseScore`` or ``vectorString`` is missing;
            - for CVSS > 2, ``baseSeverity`` is missing, not a string, or not
              a known severity.
        :raises ValueError: if ``version`` is given and disagrees with the
            version found in the object
        """
        if json_obj is None:
            return None

        v = json_obj.get("version")
        cvss_ver: CvssVersion | None = version
        if v == "2.0":
            cvss_ver = CvssVersion.V2_0
        elif v == "3.0":
            cvss_ver = CvssVersion.V3_0
        elif v == "3.1":
            cvss_ver = CvssVersion.V3_1
        elif v == "4.0":
            cvss_ver = CvssVersion.V4_0

        if cvss_ver is None:
            return None

        if (version is not None) and (version != cvss_ver):
            raise ValueError(
                f"Unexpected CvssMetric version {version} != {cvss_ver}"
            )

        score = json_obj.get("baseScore")
        vector_str = json_obj.get("vectorString")

        if (score is None) or (vector_str is None):
            return None

        severity = None
        # CVSS v2 metrics are built without a severity.
        if cvss_ver != CvssVersion.V2_0:
            sev = json_obj.get("baseSeverity")
            # Also rejects non-string values, which would otherwise raise
            # AttributeError on ``.lower()`` instead of returning None.
            if not isinstance(sev, str):
                return None
            try:
                severity = CvssSeverity(sev.lower())
            except ValueError:
                return None

        return CvssMetric(
            cvss_ver=cvss_ver,
            score=float(score),
            vector_str=vector_str,
            severity=severity,
            source=source,
        )

175 

176 

def group_cvss_metrics(
    metrics: Iterable[CvssMetric], key: Callable[[CvssMetric], GroupByT]
) -> dict[GroupByT, tuple[CvssMetric, ...]]:
    """
    Group CVSS metrics by a caller-defined key, most important metric first.

    :param metrics: CVSS metrics to group
    :param key: function deriving the (hashable) grouping key from a metric
    :return: mapping from each key to the metrics sharing it. Keys appear in
        the order they were first produced. Within a group, metrics are sorted
        in descending order of, in turn:

        - score;
        - CVSS version;
        - vector string length;
        - source (a missing source counts as ``""``);
        - vector string.
    """

    def _importance(m: CvssMetric) -> tuple[Any, ...]:
        # Sort key; applied with reverse=True so larger tuples come first.
        return (
            m.score,
            m.cvss_ver.value,
            len(m.vector_str),
            m.source or "",
            m.vector_str,
        )

    grouped: dict[GroupByT, list[CvssMetric]] = {}
    for metric in metrics:
        grouped.setdefault(key(metric), []).append(metric)
    return {
        k: tuple(sorted(group, key=_importance, reverse=True))
        for k, group in grouped.items()
    }