Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_aligndb.py: 93%

187 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-06-12 16:31 -0400

1from __future__ import annotations 

2 

3import pathlib 

4import typing 

5 

6from collections import defaultdict 

7from dataclasses import dataclass 

8 

9import h5py 

10import numpy 

11 

12from cogent3.core.alignment import Aligned, Alignment 

13from cogent3.core.location import _DEFAULT_GAP_DTYPE, IndelMap 

14from rich.progress import track 

15 

16from ensembl_lite._db_base import Hdf5Mixin, SqliteDbMixin 

17from ensembl_lite._util import _HDF5_BLOSC2_KWARGS, PathType, sanitise_stableid 

18 

19 

# File suffix of the HDF5 gap store that AlignDb creates next to an on-disk
# sqlite database.
_GAP_STORE_SUFFIX = "hdf5_blosc2"

# Zero-length gap array; AlignRecord.gap_data hands out copies of it as the
# gap positions and gap lengths of a record that has no gaps.
_no_gaps = numpy.empty(0, dtype=_DEFAULT_GAP_DTYPE)

23 

24 

25@dataclass(slots=True) 

26class AlignRecord: 

27 """a record from an AlignDb 

28 

29 Notes 

30 ----- 

31 Can return fields as attributes or like a dict using the field name as 

32 a string. 

33 """ 

34 

35 source: str 

36 block_id: str 

37 species: str 

38 seqid: str 

39 start: int 

40 stop: int 

41 strand: str 

42 gap_spans: numpy.ndarray 

43 

44 def __getitem__(self, item): 

45 return getattr(self, item) 

46 

47 def __setitem__(self, item, value): 

48 setattr(self, item, value) 

49 

50 def __eq__(self, other): 

51 attrs = "source", "block_id", "species", "seqid", "start", "stop", "strand" 

52 for attr in attrs: 

53 if getattr(self, attr) != getattr(other, attr): 

54 return False 

55 return (self.gap_spans == other.gap_spans).all() 

56 

57 @property 

58 def gap_data(self): 

59 if len(self.gap_spans): 

60 gap_pos, gap_lengths = self.gap_spans.T 

61 else: 

62 gap_pos, gap_lengths = _no_gaps.copy(), _no_gaps.copy() 

63 

64 return gap_pos, gap_lengths 

65 

66 

67ReturnType = tuple[str, tuple] # the sql statement and corresponding values 

68 

69 

class GapStore(Hdf5Mixin):
    """HDF5-backed store of the gap data from aligned sequences

    Each record is a dataset named by the string form of an integer index.
    The file carries an ``align_name`` attribute identifying the alignment.
    """

    def __init__(
        self,
        source: PathType,
        align_name: typing.Optional[str] = None,
        mode: str = "r",
        in_memory: bool = False,
    ):
        """
        Parameters
        ----------
        source
            path to the HDF5 file
        align_name
            name of the alignment; required when a writable store has no
            name recorded yet and, if given, must match the recorded name
        mode
            h5py file mode; "w" is mapped to "w-" (h5py: create, fail if
            the file exists)
        in_memory
            if True, use h5py's "core" driver with no backing store, so
            nothing is written to disk

        Raises
        ------
        ValueError
            if a name is needed but align_name is missing, or if align_name
            differs from the name recorded in the file
        """
        self.source = pathlib.Path(source)
        self.mode = "w-" if mode == "w" else mode
        h5_kwargs = dict(driver="core", backing_store=False) if in_memory else {}
        try:
            self._file = h5py.File(source, mode=self.mode, **h5_kwargs)
        except OSError:
            # report which path failed before propagating the h5py error
            print(source)
            raise

        try:
            if "r" not in self.mode and "align_name" not in self._file.attrs:
                # explicit exception rather than ``assert``, which is
                # stripped when Python runs with -O
                if not align_name:
                    raise ValueError("align_name is required to create a GapStore")
                self._file.attrs["align_name"] = align_name
            if (
                align_name
                and (file_species := self._file.attrs.get("align_name", None))
                != align_name
            ):
                raise ValueError(
                    f"{self.source.name!r} {file_species!r} != {align_name}"
                )
            self.align_name = self._file.attrs["align_name"]
        except Exception:
            # don't leak the open h5py handle when validation fails
            self._file.close()
            raise

    def add_record(self, *, index: int, gaps: numpy.ndarray):
        """stores gaps under index

        Re-adding identical gaps for an existing index is a no-op.

        Raises
        ------
        ValueError
            if index is already stored with different gaps
        """
        # dataset names must be strings
        index = str(index)
        if index in self._file:
            # read the stored data and compare shape and values; an
            # elementwise ``==`` would broadcast mismatched shapes
            if numpy.array_equal(gaps, self._file[index][:]):
                # already seen this index
                return
            # but it's different, which is a problem
            raise ValueError(f"{index!r} already present but with different gaps")
        self._file.create_dataset(
            name=index, data=gaps, chunks=True, **_HDF5_BLOSC2_KWARGS
        )
        self._file.flush()

    def get_record(self, *, index: int) -> numpy.ndarray:
        """returns the gap array stored under index"""
        return self._file[str(index)][:]

122 

123 

# TODO: add a table, and methods, to store the species tree used for the
# alignment and to retrieve it

class AlignDb(SqliteDbMixin):
    """sqlite database of aligned-segment records

    Each row describes one species' segment of an alignment block. The gap
    data for a row are held in a bound GapStore, keyed by the row ``id``.
    """

    table_name = "align"
    _align_schema = {
        "id": "INTEGER PRIMARY KEY",  # used to uniquely identify gap_spans in bound GapStore
        "source": "TEXT",  # the file path
        "block_id": "TEXT",  # <source file path>-<alignment number>
        "species": "TEXT",
        "seqid": "TEXT",
        "start": "INTEGER",
        "stop": "INTEGER",
        "strand": "TEXT",
    }

    _index_columns = {"align": ("id", "block_id", "seqid", "start", "stop")}

    def __init__(self, *, source=":memory:", mode="a"):
        """
        Parameters
        ----------
        source
            location to store the db, defaults to in memory only. For an
            on-disk db the gap data go in a sibling file named
            ``<stem>.<_GAP_STORE_SUFFIX>``.
        mode
            file mode passed to the bound GapStore
        """
        source = pathlib.Path(source)
        self.source = source
        if source.name == ":memory:":
            # in-memory GapStore (no backing store), so gap data are not
            # persisted
            gap_path = "memory"
            kwargs = dict(in_memory=True)
        else:
            gap_path = source.parent / f"{source.stem}.{_GAP_STORE_SUFFIX}"
            kwargs = dict(in_memory=False)

        self.gap_store = GapStore(
            source=gap_path, align_name=source.stem, mode=mode, **kwargs
        )
        self._db = None
        self._init_tables()

    def add_records(self, records: typing.Sequence[AlignRecord]):
        """bulk insert records

        Each record's gap_spans are written to the GapStore under the row id
        sqlite assigns to that record.
        """
        # every table column except the auto-assigned id, in table order
        col_order = [
            row[1]
            for row in self.db.execute(
                f"PRAGMA table_info({self.table_name})"
            ).fetchall()
            if row[1] != "id"
        ]
        val_placeholder = ", ".join("?" * len(col_order))
        sql = (
            f"INSERT INTO {self.table_name} ({', '.join(col_order)}) "
            f"VALUES ({val_placeholder}) RETURNING id"
        )

        for record in records:
            row = self.db.execute(sql, [record[c] for c in col_order]).fetchone()
            self.gap_store.add_record(index=row["id"], gaps=record.gap_spans)

    def _block_id_query(
        self,
        *,
        species,
        seqid: str,
        start: int | None,
        stop: int | None,
    ) -> tuple[str, tuple]:
        """returns (sql, values) selecting block_id of records for species
        and seqid that overlap the query coordinates"""
        sql = f"SELECT block_id from {self.table_name} WHERE species = ? AND seqid = ?"
        values = species, seqid
        if start is not None and stop is not None:
            # the record begins at or before stop and ends after start. This
            # matches records containing start or stop AND records lying
            # wholly inside [start, stop], which testing only whether start
            # or stop fall within a record misses.
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (stop, start)
        elif start is not None:
            # the aligned segment overlaps start
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (start, start)
        elif stop is not None:
            # the aligned segment overlaps stop
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (stop, stop)
        return sql, values

    def _get_block_id(
        self,
        *,
        species,
        seqid: str,
        start: int | None,
        stop: int | None,
    ) -> list:
        """returns rows (with a block_id column) for records of species on
        seqid overlapping the query coordinates"""
        sql, values = self._block_id_query(
            species=species, seqid=seqid, start=start, stop=stop
        )
        return self.db.execute(sql, values).fetchall()

    def get_records_matching(
        self,
        *,
        species,
        seqid: str,
        start: int | None = None,
        stop: int | None = None,
    ) -> typing.Iterable[list[AlignRecord]]:
        """returns the records of every alignment block in which species'
        seqid overlaps the query coordinates, grouped as one list per block

        Client code is responsible for creating Aligned sequence instances
        and the Alignment.
        """
        # make sure python, not numpy, integers
        start = None if start is None else int(start)
        stop = None if stop is None else int(stop)

        # Select every record whose block contains a matching record for
        # species. A subquery (rather than an IN list with one bound
        # variable per block id) avoids SQLite's bound-variable limit on
        # large queries and duplicate ids; ORDER BY id keeps insertion order.
        block_sql, values = self._block_id_query(
            species=species, seqid=seqid, start=start, stop=stop
        )
        sql = (
            f"SELECT * from {self.table_name} "
            f"WHERE block_id IN ({block_sql}) ORDER BY id"
        )
        results = defaultdict(list)
        for row in self.db.execute(sql, values).fetchall():
            record = {k: row[k] for k in row.keys()}
            index = record.pop("id")
            record["gap_spans"] = self.gap_store.get_record(index=index)
            results[record["block_id"]].append(AlignRecord(**record))

        return results.values()

    def get_species_names(self) -> list[str]:
        """return the list of species names"""
        return list(self.get_distinct("species"))

245 

246 

def _record_indel_map(align_record: AlignRecord) -> IndelMap:
    """IndelMap for one record's aligned segment (sequence length stop - start)"""
    gap_pos, gap_lengths = align_record.gap_data
    return IndelMap(
        gap_pos=gap_pos,
        gap_lengths=gap_lengths,
        parent_length=align_record.stop - align_record.start,
    )


def _ref_align_span(
    block: list[AlignRecord],
    ref_species: str,
    seqid: str,
    ref_start: int | None,
    ref_end: int | None,
) -> tuple[int, int]:
    """alignment indices (align_start, align_end) corresponding to ref_start
    and ref_end within the reference species' record of block

    Raises ValueError if block has no record for ref_species on seqid.
    """
    for align_record in block:
        if align_record.species == ref_species and align_record.seqid == seqid:
            break
    else:
        raise ValueError(f"no matching alignment record for {ref_species!r}")

    # ref_start / ref_end and the record start / stop are all genomic positions
    genome_start = align_record.start
    genome_end = align_record.stop
    gaps = _record_indel_map(align_record)

    # clip the requested range to this aligned block; compare with None (not
    # truthiness) so that a coordinate of 0 is honoured
    seq_start = genome_start if ref_start is None else max(ref_start, genome_start)
    seq_end = genome_end if ref_end is None else min(ref_end, genome_end)
    # make these coordinates relative to the aligned segment
    if align_record.strand == "-":
        # on the minus strand the genome stop is the alignment start
        seq_start, seq_end = genome_end - seq_end, genome_end - seq_start
    else:
        seq_start = seq_start - genome_start
        seq_end = seq_end - genome_start

    return gaps.get_align_index(seq_start), gaps.get_align_index(seq_end)


def _aligned_segment(
    align_record: AlignRecord,
    genome,
    align_start: int,
    align_end: int,
    namer: typing.Callable | None,
) -> Aligned:
    """Aligned sequence for one record, restricted to alignment slice
    [align_start:align_end]"""
    gaps = _record_indel_map(align_record)
    genome_start = align_record.start

    # convert the alignment indices into positions within this record's
    # segment
    seq_start = gaps.get_seq_index(align_start)
    seq_end = gaps.get_seq_index(align_end)
    seq_length = seq_end - seq_start
    if align_record.strand == "-":
        # on the minus strand the alignment start is the genome stop
        seq_start = gaps.parent_length - seq_end

    s = genome.get_seq(
        seqid=align_record.seqid,
        start=genome_start + seq_start,
        stop=genome_start + seq_start + seq_length,
        namer=namer,
        with_annotations=False,
    )
    # trim the gaps for this sequence to the sub-alignment
    gaps = gaps[align_start:align_end]

    if align_record.strand == "-":
        s = s.rc()

    return Aligned(gaps, s)


def get_alignment(
    align_db: AlignDb,
    genomes: dict,
    ref_species: str,
    seqid: str,
    ref_start: int | None = None,
    ref_end: int | None = None,
    namer: typing.Callable | None = None,
    mask_features: list[str] | None = None,
) -> typing.Generator[Alignment]:
    """yields cogent3 Alignments

    One Alignment is yielded per block returned by
    ``align_db.get_records_matching`` for the reference location.

    Parameters
    ----------
    align_db
        the alignment database
    genomes
        species name -> genome; must contain every species in the blocks
    ref_species, seqid, ref_start, ref_end
        the reference location; a None coordinate means "to the block edge"
    namer
        passed through to ``genome.get_seq``
    mask_features
        biotypes to mask in each alignment, if given

    Raises
    ------
    ValueError
        if ref_species is not in genomes, or a block lacks a record for
        ref_species on seqid
    """
    if ref_species not in genomes:
        raise ValueError(f"unknown species {ref_species!r}")

    align_records = align_db.get_records_matching(
        species=ref_species, seqid=seqid, start=ref_start, stop=ref_end
    )

    for block in align_records:
        # The reference record's gaps translate ref_start / ref_end into
        # alignment indices; those indices are then used to slice every
        # species' sequence in the block.
        align_start, align_end = _ref_align_span(
            block, ref_species, seqid, ref_start, ref_end
        )

        seqs = []
        for align_record in block:
            genome = genomes[align_record.species]
            seqs.append(
                _aligned_segment(align_record, genome, align_start, align_end, namer)
            )

        aln = Alignment(seqs)
        # NOTE(review): ``genome`` is the genome of the LAST record in block,
        # so annotations (e.g. for masking) come from that species only —
        # confirm this is intended rather than the reference genome.
        aln.annotation_db = genome.annotation_db
        if mask_features:
            aln = aln.with_masked_annotations(biotypes=mask_features)
        yield aln

356 

357 

def write_alignments(
    *,
    align_db: AlignDb,
    genomes: dict,
    limit: int | None,
    mask_features: list[str],
    outdir: PathType,
    ref_species: str,
    stableids: list[str],
    show_progress: bool = True,
):
    """writes the alignments covering each stable ID's reference location

    For each stable ID found in the reference genome's annotations, the
    alignments returned by get_alignment are written to
    ``<outdir>/<sanitised id>.fa.gz``. If there are several alignments, they
    are written to ``<outdir>/<sanitised id>-<i>.fa.gz`` instead.

    Parameters
    ----------
    limit
        if truthy, only the first ``limit`` located stable IDs are written
    outdir
        output directory; a str or path-like is accepted

    Returns
    -------
    True
    """
    # PathType may be a plain str, which does not support the ``/`` operator
    outdir = pathlib.Path(outdir)

    # then the coordinates for the id's
    ref_genome = genomes[ref_species]
    locations = []
    for stableid in stableids:
        matches = list(ref_genome.annotation_db.get_records_matching(name=stableid))
        # stable IDs with no, or more than one, annotation record are skipped
        if len(matches) != 1:
            continue
        record = matches[0]
        locations.append(
            (stableid, ref_species, record["seqid"], record["start"], record["stop"])
        )

    if limit:
        locations = locations[:limit]

    for stableid, species, seqid, start, end in track(
        locations, disable=not show_progress
    ):
        alignments = list(
            get_alignment(
                align_db,
                genomes,
                species,
                seqid,
                start,
                end,
                mask_features=mask_features,
            )
        )
        stableid = sanitise_stableid(stableid)
        if len(alignments) == 1:
            alignments[0].write(outdir / f"{stableid}.fa.gz")
        else:
            # zero alignments writes nothing
            for i, aln in enumerate(alignments):
                aln.write(outdir / f"{stableid}-{i}.fa.gz")

    return True

413 

414 return True