Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_util.py: 71%

228 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-06-12 16:32 -0400

1from __future__ import annotations 

2 

3import contextlib 

4import functools 

5import inspect 

6import os 

7import pathlib 

8import re 

9import shutil 

10import subprocess 

11import sys 

12import typing 

13import uuid 

14 

15from hashlib import md5 

16from tempfile import mkdtemp 

17from typing import IO, Callable, Union 

18 

19import blosc2 

20import hdf5plugin 

21import numba 

22import numpy 

23 

24from cogent3.app.composable import define_app 

25from cogent3.util.parallel import as_completed 

26 

27 

# any of these types may be used to refer to a file system location
PathType = Union[str, pathlib.Path, os.PathLike]

# compression filter settings for HDF5 datasets: Blosc2 using the blosclz
# codec at compression level 9 with bit-shuffling
_HDF5_BLOSC2_KWARGS = hdf5plugin.Blosc2(
    cname="blosclz",
    clevel=9,
    filters=hdf5plugin.Blosc2.BITSHUFFLE,
)

33 

34 

def md5sum(data: bytes, *args) -> str:
    """returns the hexadecimal MD5 digest of data

    Notes
    -----
    Any extra positional arguments are accepted and ignored, giving this
    function a call signature compatible with ``checksum``.
    """
    hasher = md5(data)
    return hasher.hexdigest()

43 

44 

# based on https://www.reddit.com/r/learnpython/comments/9bpgjl/implementing_bsd_16bit_checksum/
# and https://www.gnu.org/software/coreutils/manual/html_node/sum-invocation.html#sum-invocation
@numba.jit(nopython=True)
def checksum(data: bytes, size: int):  # pragma: no cover
    """computes BSD style checksum

    Parameters
    ----------
    data
        content to checksum; iterated element by element, with each element
        added to the running sum as an integer byte value
    size
        number of bytes in the content; only used to compute the block count

    Returns
    -------
    tuple of (checksum, nblocks), where checksum is the 16-bit BSD checksum
    and nblocks is size / 1024 rounded up

    Notes
    -----
    Compiled by numba in nopython mode, so the body must stay within the
    subset of Python/NumPy that numba supports.
    NOTE(review): the concrete type callers pass as ``data`` (bytes vs a
    uint8 array) is decided elsewhere — confirm before changing this body.
    """
    # equivalent to command line BSD sum
    nb = numpy.ceil(size / 1024)
    cksum = 0
    for c in data:
        # rotate the 16-bit running sum right by one bit (low bit -> bit 15)
        cksum = (cksum >> 1) + ((cksum & 1) << 15)
        cksum += c
        # keep the running sum within 16 bits
        cksum &= 0xFFFF
    return cksum, int(nb)

58 

59 

def _get_resource_dir() -> PathType:
    """returns the absolute path of the resource directory

    The directory is taken from the ENSEMBLDBRC environment variable when it
    is set, otherwise it is the directory of the ``ensembl_lite.data``
    package. A leading ``~`` is expanded.

    Raises
    ------
    ValueError
        if the resolved directory does not exist
    """
    location = os.environ.get("ENSEMBLDBRC")
    if location is None:
        from ensembl_lite import data

        location = pathlib.Path(data.__file__).parent

    resolved = pathlib.Path(location).expanduser().absolute()
    if resolved.exists():
        return resolved

    raise ValueError(f"ENSEMBLDBRC directory {str(resolved)!r} does not exist")

74 

75 

def get_resource_path(resource: PathType) -> PathType:
    """returns the path of ``resource`` within the ENSEMBLDBRC directory

    Parameters
    ----------
    resource
        path relative to the ENSEMBLDBRC directory

    Raises
    ------
    FileNotFoundError
        if the resulting path does not exist
    """
    path = ENSEMBLDBRC / resource
    # an explicit check rather than ``assert``, which is skipped under -O
    if not path.exists():
        raise FileNotFoundError(f"resource {str(path)!r} does not exist")
    return path

80 

81 

# Directory holding essential resource files, such as the species/common
# name map and the sample download.cfg. It is resolved once, at import time,
# from the ENSEMBLDBRC environment variable or, when that is unset, from the
# package's data directory. Importing fails with ValueError if it is missing.
ENSEMBLDBRC = _get_resource_dir()

85 

86 

def exec_command(cmnd, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
    """executes shell command and returns stdout if completes exit code 0

    Parameters
    ----------
    cmnd : str
        shell command to be executed
    stdout, stderr : streams
        Default value (PIPE) captures the process output. Setting to None
        lets the output pass straight through to this process's streams.

    Returns
    -------
    The captured standard output decoded as utf8, or None if stdout was not
    captured.

    Notes
    -----
    On a non-zero exit code the command and any captured stderr are written
    to sys.stderr and the interpreter exits (SystemExit) with that code.
    The command is run with ``shell=True``, so ``cmnd`` must never contain
    untrusted input.
    """
    proc = subprocess.Popen(cmnd, shell=True, stdout=stdout, stderr=stderr)
    out, err = proc.communicate()
    if proc.returncode != 0:
        # err is bytes when captured and None otherwise; decode it so the
        # message shows the text rather than a b'...' repr (or "None")
        msg = err.decode("utf8", errors="replace") if err is not None else ""
        sys.stderr.write(f"FAILED: {cmnd}\n{msg}")
        sys.exit(proc.returncode)
    return out.decode("utf8") if out is not None else None

105 

106 

class CaseInsensitiveString(str):
    """A case-insensitive string class. Comparisons are also case-insensitive.

    Equality (``==`` and ``!=``) and hashing use the lower-cased text, so
    instances match dict keys and set members regardless of case. Ordering
    comparisons (``<`` etc.) are inherited from ``str`` and remain
    case-sensitive.
    """

    def __new__(cls, arg, h=None):
        # h is not used; it is kept so existing call signatures still work
        n = str.__new__(cls, str(arg))
        n._lower = str.lower(n)
        n._hash = hash(n._lower)
        return n

    def __eq__(self, other):
        if not isinstance(other, str):
            # let Python fall back to its default handling for non-strings
            return NotImplemented
        return self._lower == str.lower(other)

    def __ne__(self, other):
        # without this, str.__ne__ is inherited and compares case-sensitively,
        # making ``a != b`` disagree with ``a == b``
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    def __hash__(self):
        # dict hashing done via lower case
        return self._hash

    def __str__(self):
        # a plain str with the original casing
        return "".join(list(self))

125 

126 

def load_ensembl_checksum(path: PathType) -> dict:
    """loads the BSD checksums from Ensembl CHECKSUMS file

    Parameters
    ----------
    path
        path (str or Path) to a CHECKSUMS file. Each non-blank line has
        three whitespace-separated fields: checksum, block count, file name.

    Returns
    -------
    dict mapping file name to (checksum, block count) as ints. Any README
    entry is dropped.
    """
    result = {}
    # wrap in Path so str paths are accepted, as the PathType hint promises
    for line in pathlib.Path(path).read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        s, b, p = line.split()
        result[p] = int(s), int(b)
    result.pop("README", None)
    return result

138 

139 

def load_ensembl_md5sum(path: PathType) -> dict:
    """loads the md5 sum from Ensembl MD5SUM file

    Parameters
    ----------
    path
        path (str or Path) to an MD5SUM file. Each non-blank line has two
        whitespace-separated fields: md5 digest, file name.

    Returns
    -------
    dict mapping file name to its md5 digest string. Any README entry is
    dropped.
    """
    result = {}
    # wrap in Path so str paths are accepted, as the PathType hint promises
    for line in pathlib.Path(path).read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        s, p = line.split()
        result[p] = s
    result.pop("README", None)
    return result

151 

152 

class atomic_write:
    """performs atomic write operations, cleans up if fails

    Data is written to a temporary file. That file is moved onto ``path``
    only when the context exits without an exception, or when ``close()`` is
    called. On failure the destination is left untouched and the temporary
    file is discarded.
    """

    def __init__(self, path: PathType, tmpdir=None, mode="wb", encoding=None):
        """

        Parameters
        ----------
        path
            path to file
        tmpdir
            directory where temporary file will be created. If None, a new
            directory is created next to ``path`` and removed afterwards. A
            directory supplied by the caller is never removed; only the
            temporary file inside it is.
        mode
            file writing mode
        encoding
            text encoding, applied only when ``mode`` is a text mode
        """
        path = pathlib.Path(path).expanduser()

        self._path = path
        self._mode = mode
        self._file = None
        self._encoding = encoding
        # only a directory this instance created may be deleted on cleanup
        self._owns_tmpdir = tmpdir is None
        self._tmppath = self._make_tmppath(tmpdir)

        self.succeeded = None
        self._close_func = self._close_rename_standard

    def _make_tmppath(self, tmpdir):
        """returns path of temporary file

        Parameters
        ----------
        tmpdir: Path
            to directory

        Returns
        -------
        full path to a temporary file

        Notes
        -----
        Uses a random uuid as the file name, adds suffixes from path
        """
        suffixes = "".join(self._path.suffixes)
        parent = self._path.parent
        name = f"{uuid.uuid4()}{suffixes}"
        tmpdir = (
            pathlib.Path(mkdtemp(dir=parent))
            if tmpdir is None
            else pathlib.Path(tmpdir)
        )

        if not tmpdir.exists():
            raise FileNotFoundError(f"{tmpdir} directory does not exist")

        return tmpdir / name

    def _get_fileobj(self):
        """returns file to be written to, opening it on first use"""
        if self._file is None:
            # open() rejects an encoding argument in binary mode
            encoding = None if "b" in self._mode else self._encoding
            self._file = open(self._tmppath, self._mode, encoding=encoding)

        return self._file

    def __enter__(self) -> IO:
        return self._get_fileobj()

    def _cleanup_tmp(self):
        """removes the temporary file, and its directory if we created it"""
        if self._owns_tmpdir:
            shutil.rmtree(self._tmppath.parent, ignore_errors=True)
        else:
            self._tmppath.unlink(missing_ok=True)

    def _close_rename_standard(self, src):
        """moves the completed temporary file src onto the destination"""
        dest = pathlib.Path(self._path)
        # replace() overwrites an existing destination in a single step,
        # instead of a separate unlink followed by rename
        src.replace(dest)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.succeeded is not None:
            # already finalised by an earlier close() / __exit__
            return

        if exc_type is None:
            # creates an empty file if close() is called before any write
            self._get_fileobj()
        if self._file is not None:
            self._file.close()

        self.succeeded = False
        try:
            if exc_type is None:
                self._close_func(self._tmppath)
                self.succeeded = True
        finally:
            self._cleanup_tmp()

    def write(self, text):
        """writes text to file"""
        fileobj = self._get_fileobj()
        fileobj.write(text)

    def close(self):
        """closes file"""
        self.__exit__(None, None, None)

250 

251 

# Ensembl signature file name -> function that parses that file
_sig_load_funcs = {
    "CHECKSUMS": load_ensembl_checksum,
    "MD5SUM": load_ensembl_md5sum,
}
# Ensembl signature file name -> function that computes that signature
_sig_calc_funcs = {
    "CHECKSUMS": checksum,
    "MD5SUM": md5sum,
}
# names of files that are never themselves checksummed
_dont_checksum = re.compile("(CHECKSUMS|MD5SUM|README)")
# names of the signature files
_sig_file = re.compile("(CHECKSUMS|MD5SUM)")

256 

257 

def dont_checksum(path: PathType) -> bool:
    """True if path refers to a file that should not be checksummed

    The whole of ``str(path)`` is searched for CHECKSUMS, MD5SUM or README,
    so a match in a directory component also counts.
    """
    found = _dont_checksum.search(str(path))
    return found is not None

260 

261 

@functools.singledispatch
def is_signature(path: PathType) -> bool:
    """True if path names an Ensembl signature file (CHECKSUMS or MD5SUM)

    For path objects only the final component (``path.name``) is searched.
    For str the whole string is searched.
    """
    file_name = path.name
    return bool(_sig_file.search(file_name))


@is_signature.register
def _is_signature_str(path: str) -> bool:
    return bool(_sig_file.search(path))

270 

271 

@functools.singledispatch
def get_sig_calc_func(sig_path) -> Callable:
    """returns signature calculating function based on Ensembl path name

    Only str names are supported; any other type raises NotImplementedError.
    An unrecognised name raises KeyError.
    """
    raise NotImplementedError(f"{type(sig_path)} not supported")


@get_sig_calc_func.register
def _sig_calc_func_for_str(sig_path: str) -> Callable:
    calculator = _sig_calc_funcs[sig_path]
    return calculator

281 

282 

def get_signature_data(path: PathType) -> dict:
    """loads an Ensembl signature file (CHECKSUMS or MD5SUM)

    The parser is chosen by the file name of ``path``. A KeyError is raised
    if that name is not a known signature file.

    Returns
    -------
    dict mapping file names to their recorded signatures
    """
    # accept str as well as Path, as the PathType hint promises
    path = pathlib.Path(path)
    return _sig_load_funcs[path.name](path)

285 

286 

def rich_display(c3t, title_justify="left"):
    """converts a cogent3 Table to a Rich Table and displays it

    Parameters
    ----------
    c3t
        cogent3 Table whose header, columns, title and column format
        templates are displayed
    title_justify
        justification of the table title

    Notes
    -----
    Columns with a format template are rendered through that template.
    Columns whose dtype name contains "int" or "float" are right-justified
    and not wrapped; all others are left-justified. Cell values rich cannot
    render directly (e.g. numpy numbers) are converted with ``str``.
    """
    from rich.console import Console
    from rich.protocol import is_renderable
    from rich.table import Table

    cols = c3t.columns
    columns = []
    for c in c3t.header:
        if tmplt := c3t._column_templates.get(c, None):
            col = [tmplt(v) for v in cols[c]]
        else:
            col = cols[c]
        # Table.add_row only accepts str / rich renderables, so convert any
        # other value (None is passed through unchanged)
        columns.append([v if v is None or is_renderable(v) else str(v) for v in col])

    rich_table = Table(
        title=c3t.title,
        highlight=True,
        title_justify=title_justify,
        title_style="bold blue",
    )
    for col in c3t.header:
        numeric_type = any(v in cols[col].dtype.name for v in ("int", "float"))
        j = "right" if numeric_type else "left"
        rich_table.add_column(col, justify=j, no_wrap=numeric_type)

    for row in zip(*columns):
        rich_table.add_row(*row)

    console = Console()
    console.print(rich_table)

317 

318 

# characters that separate the parts of a file name
_seps = re.compile(r"[-._\s]")


def _name_parts(path: str) -> list[str]:
    """returns the lower-cased components of the file name of path"""
    file_name = pathlib.Path(path).name
    lowered = file_name.lower()
    return _seps.split(lowered)

324 

325 

def _simple_check(
    align_parts: typing.Sequence[str], tree_parts: typing.Sequence[str]
) -> int:
    """returns the number of leading elements the two sequences share

    Counting stops at the first position where they differ, or when the
    shorter sequence is exhausted.
    """
    matches = 0
    for a, b in zip(align_parts, tree_parts):
        if a != b:
            break
        matches += 1

    return matches

335 

336 

def trees_for_aligns(aligns, trees) -> dict[str, str]:
    """matches each alignment path to the tree path whose name best agrees

    Parameters
    ----------
    aligns
        alignment file paths
    trees
        tree file paths

    Returns
    -------
    dict mapping each alignment path to a tree path. The chosen tree shares
    the longest leading run of file-name parts (as split by ``_name_parts``)
    with the alignment. Ties go to the greatest tree path.

    Raises
    ------
    ValueError
        if an alignment shares no leading name part with any tree, including
        when ``trees`` is empty
    """
    tree_parts = {p: _name_parts(p) for p in trees}
    result = {}
    for align in aligns:
        align_parts = _name_parts(align)
        # default keeps an empty ``trees`` on the same "no tree" error path
        v, p = max(
            (
                (_simple_check(align_parts, parts), tree)
                for tree, parts in tree_parts.items()
            ),
            default=(0, None),
        )
        if v == 0:
            raise ValueError(f"no tree for {align}")

        result[align] = p

    return result

353 

354 

@define_app
def _str_to_bytes(data: str) -> bytes:
    """returns data encoded as utf8 bytes"""
    encoded = data.encode(encoding="utf8")
    return encoded

359 

360 

@define_app
def _bytes_to_str(data: bytes) -> str:
    """returns data decoded from utf8 into a str"""
    decoded = data.decode(encoding="utf8")
    return decoded

365 

366 

@define_app
def blosc_compress_it(data: bytes) -> bytes:
    """compresses data with blosc2 at level 9 using the SHUFFLE filter"""
    compressed = blosc2.compress(data, clevel=9, filter=blosc2.Filter.SHUFFLE)
    return compressed

370 

371 

@define_app
def blosc_decompress_it(data: bytes, as_bytearray=True) -> bytes:
    """decompresses blosc2-compressed data, always returning bytes"""
    decompressed = blosc2.decompress(data, as_bytearray=as_bytearray)
    return bytes(decompressed)

375 

376 

# composed apps: str -> utf8 bytes -> blosc2-compressed bytes ...
elt_compress_it = _str_to_bytes() + blosc_compress_it()
# ... and the reverse: blosc2-compressed bytes -> bytes -> str
elt_decompress_it = blosc_decompress_it() + _bytes_to_str()

379 

# biotype labels of the form "<biotype>:" that precede identifiers
_biotypes = re.compile(r"(gene|transcript|exon|mRNA|rRNA|protein):")


def sanitise_stableid(stableid: str) -> str:
    """returns stableid with every "<biotype>:" label removed

    Notes
    -----
    Ensembl GFF3 files write identifiers as <biotype>:<identifier>, e.g.
    "gene:ENSG00000139618"; the biotype part is redundant and is removed.
    Only the biotypes listed in ``_biotypes`` are recognised.
    """
    cleaned = _biotypes.sub("", stableid)
    return cleaned

394@contextlib.contextmanager 

395def fake_wake(*args, **kwargs): 

396 yield 

397 

398 

class SerialisableMixin:
    """mixin class that records the arguments given to the constructor

    On creation, ``self._init_vals`` is set to a dict mapping each parameter
    name of the class's ``__init__`` (excluding ``self``) to the value that
    was supplied. Parameters not supplied are filled from their defaults.
    """

    def __new__(cls, *args, **kwargs):
        instance = object.__new__(cls)
        signature = inspect.signature(cls.__init__)
        # cls fills the slot of the not-yet-created ``self`` argument
        bound = signature.bind_partial(cls, *args, **kwargs)
        bound.apply_defaults()
        recorded = bound.arguments
        recorded.pop("self", None)
        instance._init_vals = recorded
        return instance

413 

414 

def get_iterable_tasks(
    *,
    func: typing.Callable,
    series: typing.Sequence,
    max_workers: typing.Optional[int],
    **kwargs,
) -> typing.Iterator:
    """returns an iterator of func applied to each element of series

    With ``max_workers == 1`` this is the builtin ``map`` run in the calling
    process, and ``kwargs`` are ignored. Otherwise the work is handed to
    cogent3's ``as_completed`` together with ``max_workers`` and ``kwargs``.
    NOTE(review): presumably that runs the tasks in parallel and yields
    results as they finish, so order may differ from ``series`` — confirm.
    """
    if max_workers != 1:
        return as_completed(func, series, max_workers=max_workers, **kwargs)
    return map(func, series)

426 

427 

# From http://mart.ensembl.org/info/genome/stable_ids/prefixes.html
# The Ensembl stable id structure is
# [species prefix][feature type prefix][a unique eleven digit number]
# feature type prefixes are
# E exon
# FM Ensembl protein family
# G gene
# GT gene tree
# P protein
# R regulatory feature
# T transcript
_feature_type_1 = {"E", "G", "P", "R", "T"}
_feature_type_2 = {"FM", "GT"}


def get_stableid_prefix(stableid: str) -> str:
    """returns the prefix component of a stableid

    The prefix is everything before the feature type, e.g. "ENS" for
    "ENSG00000139618".

    Raises
    ------
    ValueError
        if stableid is shorter than 15 characters, or the character before
        the final eleven is not a known single-letter feature type (and the
        two characters before them are not a two-letter one)
    """
    if len(stableid) < 15:
        raise ValueError(f"{stableid!r} too short")

    # a two-letter feature type occupies the 2 characters before the digits
    if stableid[-13:-11] in _feature_type_2:
        return stableid[:-13]

    feature_type = stableid[-12]
    if feature_type not in _feature_type_1:
        # report the character that was actually checked
        raise ValueError(f"{stableid!r} has unknown feature type {feature_type!r}")
    return stableid[:-12]