Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/install.py: 38% (162 statements)

from __future__ import annotations

import pathlib
import shutil
import typing

from collections import Counter

from cogent3 import load_annotations, make_seq, open_
from cogent3.parse.fasta import MinimalFastaParser
from cogent3.parse.table import FilteringParser
from cogent3.util import parallel as PAR
from rich.progress import Progress, track

from ensembl_lite import maf
from ensembl_lite._aligndb import AlignDb
from ensembl_lite._config import _COMPARA_NAME, Config
from ensembl_lite._genomedb import (
    _ANNOTDB_NAME,
    _SEQDB_NAME,
    CompressedGenomeSeqsDb,
)
from ensembl_lite._homologydb import HomologyDb
from ensembl_lite.convert import seq_to_gap_coords
from ensembl_lite.species import Species
from ensembl_lite.util import elt_compress_it


def _rename(label: str) -> str:
    """returns the first whitespace-delimited field of a fasta label"""
    return label.split()[0]


def _get_seqs(src: pathlib.Path) -> list[tuple[str, bytes]]:
    """returns (seqid, compressed sequence) pairs from a fasta file"""
    with open_(src) as infile:
        data = infile.read().splitlines()

    name_seqs = list(MinimalFastaParser(data))
    labels = Counter(n for n, _ in name_seqs)
    if max(labels.values()) != 1:
        multiples = {k: c for k, c in labels.items() if c > 1}
        msg = f"Some seqids are not unique for {src.parent.name!r}: {multiples}"
        raise RuntimeError(msg)
    return [(_rename(name), elt_compress_it(seq)) for name, seq in name_seqs]
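
# A sketch of the records produced above (illustrative only; the path and the
# seqid are hypothetical):
#
#   >>> recs = _get_seqs(pathlib.Path("staging/homo_sapiens/fasta/chr1.fa.gz"))
#   >>> recs[0][0]   # the seqid is the first field of the fasta label
#   'chr1'
#   >>> type(recs[0][1])   # the sequence is stored compressed
#   <class 'bytes'>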



def _load_one_annotations(src_dest: tuple[pathlib.Path, pathlib.Path]) -> bool:
    """loads annotations from src into dest, skipping if dest already exists"""
    src, dest = src_dest
    if dest.exists():
        return True

    _ = load_annotations(path=src, write_path=dest)
    return True


def _make_src_dest_annotation_paths(
    src_dir: pathlib.Path, dest_dir: pathlib.Path
) -> list[tuple[pathlib.Path, pathlib.Path]]:
    src_dir = src_dir / "gff3"
    dest = dest_dir / _ANNOTDB_NAME
    paths = list(src_dir.glob("*.gff3.gz"))
    return [(path, dest) for path in paths]


T = tuple[pathlib.Path, list[tuple[str, bytes]]]


def _prepped_seqs(
    src_dir: pathlib.Path,
    dest_dir: pathlib.Path,
    progress: Progress,
    max_workers: int | None,
) -> T:
    """returns the seq db path for one species plus all its (seqid, compressed seq) records"""
    src_dir = src_dir / "fasta"
    paths = list(src_dir.glob("*.fa.gz"))
    dest = dest_dir / _SEQDB_NAME
    all_seqs = []

    common_name = Species.get_common_name(src_dir.parent.name)
    msg = f"📚🗜️ {common_name} seqs"
    load = progress.add_task(msg, total=len(paths))
    for result in PAR.as_completed(_get_seqs, paths, max_workers=max_workers):
        all_seqs.extend(result)
        progress.update(load, advance=1, description=msg)

    progress.update(load, visible=False)
    return dest, all_seqs
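
# The pairing returned above feeds directly into the genome seq db (sketch;
# the directory variables and species name are hypothetical):
#
#   >>> dest, records = _prepped_seqs(staging_dir, install_dir, progress, 4)
#   >>> dest.parent.name   # the species directory, e.g. "homo_sapiens"
#   >>> records[0]         # e.g. ("chr1", <compressed bytes>)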



def local_install_genomes(
    config: Config, force_overwrite: bool, max_workers: int | None
) -> None:
    if force_overwrite:
        shutil.rmtree(config.install_genomes, ignore_errors=True)

    # we create the local installation
    config.install_genomes.mkdir(parents=True, exist_ok=True)
    # we create subdirectories for each species
    for db_name in list(config.db_names):
        sp_dir = config.install_genomes / db_name
        sp_dir.mkdir(parents=True, exist_ok=True)

    db_names = list(config.db_names)
    if max_workers:
        max_workers = min(len(db_names) + 1, max_workers)

    # for each species, we identify the source and dest paths for annotations
    src_dest_paths = []
    for db_name in config.db_names:
        src_dir = config.staging_genomes / db_name
        dest_dir = config.install_genomes / db_name
        src_dest_paths.extend(_make_src_dest_annotation_paths(src_dir, dest_dir))

    # we load the individual gff3 files and write them to annotation db's
    with Progress(transient=True) as progress:
        msg = "Installing 🧬 features"
        writing = progress.add_task(total=len(src_dest_paths), description=msg)
        for _ in PAR.as_completed(
            _load_one_annotations, src_dest_paths, max_workers=max_workers
        ):
            progress.update(writing, description=msg, advance=1)

    with Progress(transient=True) as progress:
        writing = progress.add_task(total=len(db_names), description="Installing 🧬")
        for db_name in db_names:
            src_dir = config.staging_genomes / db_name
            dest_dir = config.install_genomes / db_name
            dest, records = _prepped_seqs(src_dir, dest_dir, progress, max_workers)
            db = CompressedGenomeSeqsDb(source=dest, species=dest.parent.name)
            db.add_compressed_records(records=records)
            db.close()
            progress.update(writing, description="Installing 🧬", advance=1)
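
# Typical invocation (illustrative; constructing the Config instance happens
# elsewhere in ensembl_lite):
#
#   >>> local_install_genomes(config, force_overwrite=False, max_workers=4)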



def seq2gaps(record: dict) -> dict:
    """replaces the aligned sequence in record with its gap coordinates"""
    seq = make_seq(record.pop("seq"))
    record["gap_spans"], _ = seq_to_gap_coords(seq)
    return record
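
# The decomposition performed by seq_to_gap_coords (conceptual sketch; the
# exact coordinate convention is defined in ensembl_lite.convert): an aligned
# sequence such as "AC--GT" reduces to gap data, e.g. one gap of length 2
# after ungapped position 2, so only coordinates need to be stored in AlignDb.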



def _load_one_align(path: pathlib.Path) -> list[dict]:
    """returns one record dict per sequence in each alignment block of a maf file"""
    records = []
    for block_id, align in enumerate(maf.parse(path)):
        converted = []
        for maf_name, seq in align.items():
            record = maf_name.to_dict()
            record["block_id"] = block_id
            record["source"] = path.name
            record["seq"] = seq
            converted.append(seq2gaps(record))
        records.extend(converted)
    return records
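
# Shape of one such record (illustrative; the remaining keys come from
# maf_name.to_dict() and depend on the maf name parsing in ensembl_lite.maf):
#
#   {..., "block_id": 0, "source": "10_primates.epo.1.maf.gz", "gap_spans": ...}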



def local_install_compara(
    config: Config, force_overwrite: bool, max_workers: int | None
) -> None:
    if force_overwrite:
        shutil.rmtree(config.install_path / _COMPARA_NAME, ignore_errors=True)

    for align_name in config.align_names:
        src_dir = config.staging_aligns / align_name
        dest_dir = config.install_aligns
        dest_dir.mkdir(parents=True, exist_ok=True)
        # write out to a db named for this alignment
        db = AlignDb(source=(dest_dir / f"{align_name}.sqlitedb"))
        records = []
        paths = list(src_dir.glob(f"{align_name}*maf*"))
        if max_workers:
            max_workers = min(len(paths) + 1, max_workers)

        for result in track(
            PAR.as_completed(_load_one_align, paths, max_workers=max_workers),
            transient=True,
            description="Installing alignments",
            total=len(paths),
        ):
            records.extend(result)

        db.add_records(records=records)
        db.close()
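
# For example (the alignment name here is hypothetical), an align_name of
# "10_primates.epo" collects staged files matching "10_primates.epo*maf*" and
# writes their records to <config.install_aligns>/10_primates.epo.sqlitedb.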



class LoadHomologies:
    def __init__(self, allowed_species: set[str]):
        self._allowed_species = allowed_species
        # map the Ensembl column names to HomologyDb column names
        self.src_cols = (
            "homology_type",
            "species",
            "gene_stable_id",
            "protein_stable_id",
            "homology_species",
            "homology_gene_stable_id",
            "homology_protein_stable_id",
        )
        self.dest_col = (
            "relationship",
            "species_1",
            "gene_id_1",
            "prot_id_1",
            "species_2",
            "gene_id_2",
            "prot_id_2",
            "source",
        )
        self._reader = FilteringParser(
            row_condition=self._matching_species, columns=self.src_cols, sep="\t"
        )

    def _matching_species(self, row):
        # both species in the row must belong to the allowed set
        return {row[1], row[4]} <= self._allowed_species

    def __call__(self, paths: typing.Iterable[pathlib.Path]) -> list:
        final = []
        for path in paths:
            with open_(path) as infile:
                # we bulk load because it's faster than the default line-by-line
                # iteration on a file
                data = infile.read().splitlines()

            rows = list(self._reader(data))
            header = rows.pop(0)
            assert list(header) == list(self.src_cols), (header, self.src_cols)
            rows = [r + [path.name] for r in rows]
            final.extend(rows)

        return final
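
# Usage sketch (illustrative; the paths and species names are hypothetical):
#
#   >>> loader = LoadHomologies(allowed_species={"homo_sapiens", "mus_musculus"})
#   >>> rows = loader(pathlib.Path("staging/homologies/homo_sapiens").glob("*.tsv.gz"))
#
# Each row holds the src_cols values with the source file name appended, which
# is why dest_col has one extra trailing "source" column.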



def local_install_homology(
    config: Config, force_overwrite: bool, max_workers: int | None
) -> None:
    if force_overwrite:
        shutil.rmtree(config.install_homologies, ignore_errors=True)

    config.install_homologies.mkdir(parents=True, exist_ok=True)

    outpath = config.install_homologies / "homologies.sqlitedb"
    db = HomologyDb(source=outpath)

    # each entry is the list of homology tsv paths for one species
    dirnames = []
    for sp in config.db_names:
        path = config.staging_homologies / sp
        dirnames.append(list(path.glob("*.tsv.gz")))

    loader = LoadHomologies(allowed_species=set(config.db_names))
    # On test cases, running in parallel gave only a 30% speedup, due to the
    # overhead of pickling the data, and a considerable increase in memory use.
    # So we run in serial to avoid memory issues, since it's reasonably fast anyway.
    if max_workers:
        max_workers = min(len(dirnames) + 1, max_workers)

    with Progress(transient=True) as progress:
        msg = "Installing homologies"
        writing = progress.add_task(total=len(dirnames), description=msg)
        for rows in PAR.as_completed(loader, dirnames, max_workers=max_workers):
            db.add_records(records=rows, col_order=loader.dest_col)
            del rows
            progress.update(writing, description=msg, advance=1)

    no_records = len(db) == 0
    db.close()
    if no_records:
        outpath.unlink()
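
# Each installer can be invoked separately (sketch, assuming a configured
# Config instance):
#
#   >>> local_install_genomes(config, force_overwrite=False, max_workers=4)
#   >>> local_install_compara(config, force_overwrite=False, max_workers=4)
#   >>> local_install_homology(config, force_overwrite=False, max_workers=4)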