Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_aligndb.py: 96%

233 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2024-03-25 13:40 +1100

1from __future__ import annotations 

2 

3import os 

4import typing 

5 

6from collections import defaultdict 

7from dataclasses import dataclass 

8 

9import numpy 

10 

11from cogent3.core.alignment import Alignment 

12from numpy.typing import NDArray 

13from rich.progress import track 

14 

15from ensembl_lite._db_base import SqliteDbMixin, _compressed_array_proxy 

16from ensembl_lite._util import sanitise_stableid 

17 

18 

19@dataclass(slots=True) 

20class AlignRecord: 

21 """a record from an AlignDb 

22 

23 Notes 

24 ----- 

25 Can return fields as attributes or like a dict using the field name as 

26 a string. 

27 """ 

28 

29 source: str 

30 block_id: str 

31 species: str 

32 seqid: str 

33 start: int 

34 stop: int 

35 strand: str 

36 gap_spans: numpy.ndarray 

37 

38 def __getitem__(self, item): 

39 return getattr(self, item) 

40 

41 def __setitem__(self, item, value): 

42 setattr(self, item, value) 

43 

44 def __eq__(self, other): 

45 attrs = "source", "block_id", "species", "seqid", "start", "stop", "strand" 

46 for attr in attrs: 

47 if getattr(self, attr) != getattr(other, attr): 

48 return False 

49 return (self.gap_spans == other.gap_spans).all() 

50 

51 

52ReturnType = tuple[str, tuple] # the sql statement and corresponding values 

53 

54 

# TODO: add a table and methods to support storing the species tree used
# for the alignment and for getting the species tree
class AlignDb(SqliteDbMixin):
    """sqlite store of alignment records

    Each row describes one species' segment within an alignment block;
    rows belonging to the same alignment share a block_id.
    """

    table_name = "align"
    _align_schema = {
        "source": "TEXT",  # the file path
        "block_id": "TEXT",  # <source file path>-<alignment number>
        "species": "TEXT",
        "seqid": "TEXT",
        "start": "INTEGER",
        "stop": "INTEGER",
        "strand": "TEXT",
        "gap_spans": "compressed_array",
    }

    def __init__(self, *, source=":memory:"):
        """
        Parameters
        ----------
        source
            location to store the db, defaults to in memory only
        """
        # note that data is destroyed
        # NOTE(review): presumably _init_tables (from SqliteDbMixin) recreates
        # the table at `source`; confirm before relying on persistence
        self.source = source
        self._db = None
        self._init_tables()

    def add_records(self, records: typing.Sequence[AlignRecord]):
        """bulk insert AlignRecord instances

        Notes
        -----
        The records passed in are NOT modified; gap_spans are wrapped for
        compressed storage only in the rows sent to the database.
        """
        # column 1 of each PRAGMA table_info row is the column name, so this
        # gives the table's columns in their declared order
        col_order = [
            row[1]
            for row in self.db.execute(
                f"PRAGMA table_info({self.table_name})"
            ).fetchall()
        ]
        rows = [
            [
                _compressed_array_proxy(record.gap_spans)
                if col == "gap_spans"
                else record[col]
                for col in col_order
            ]
            for record in records
        ]

        val_placeholder = ", ".join("?" * len(col_order))
        sql = f"INSERT INTO {self.table_name} ({', '.join(col_order)}) VALUES ({val_placeholder})"
        self.db.executemany(sql, rows)

    def _get_block_id(
        self,
        *,
        species,
        seqid: str,
        start: int | None,
        stop: int | None,
    ) -> list:
        """return rows holding the distinct block_id of records for
        species / seqid that match the query coordinates"""
        sql = (
            f"SELECT DISTINCT block_id from {self.table_name} "
            "WHERE species = ? AND seqid = ?"
        )
        values = species, seqid
        if start is not None and stop is not None:
            # a record matches if it begins at or before the query stop and
            # ends after the query start. This covers records containing
            # start, records containing stop AND records lying entirely
            # within [start, stop].
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (stop, start)
        elif start is not None:
            # the aligned segment overlaps start
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (start, start)
        elif stop is not None:
            # the aligned segment overlaps stop
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (stop, stop)

        return self.db.execute(sql, values).fetchall()

    def get_records_matching(
        self,
        *,
        species,
        seqid: str,
        start: int | None = None,
        stop: int | None = None,
    ) -> typing.Iterable[AlignRecord]:
        """return the records of every alignment block that has a record
        for species / seqid matching the coordinates

        Returns
        -------
        An iterable of lists of AlignRecord, one list per block_id.
        """
        # make sure python, not numpy, integers
        start = None if start is None else int(start)
        stop = None if stop is None else int(stop)

        # We need the block IDs for all records for a species whose coordinates
        # lie in the range (start, stop). We then search for all records with
        # each block id. We return full records.
        # Client code is responsible for creating Aligned sequence instances
        # and the Alignment.
        block_ids = [
            r["block_id"]
            for r in self._get_block_id(
                species=species, seqid=seqid, start=start, stop=stop
            )
        ]

        results = defaultdict(list)
        if block_ids:
            # no matching blocks means nothing further to query
            values = ", ".join("?" * len(block_ids))
            sql = f"SELECT * from {self.table_name} WHERE block_id IN ({values})"
            for row in self.db.execute(sql, block_ids).fetchall():
                fields = {k: row[k] for k in row.keys()}
                results[fields["block_id"]].append(AlignRecord(**fields))

        return results.values()

    def get_species_names(self) -> list[str]:
        """return the list of species names"""
        return list(self.get_distinct("species"))

160 

161 

def _ref_align_span(
    block: list[AlignRecord],
    ref_species: str,
    seqid: str,
    ref_start: int | None,
    ref_end: int | None,
) -> tuple[int, int]:
    """return (align_start, align_end) of the reference region within block

    The first record in block for ref_species / seqid is used. The requested
    genomic region is clipped to that record, made relative to the aligned
    segment (accounting for strand) and converted to alignment indices.

    Raises
    ------
    ValueError if block has no record for ref_species / seqid
    """
    for align_record in block:
        if align_record.species == ref_species and align_record.seqid == seqid:
            break
    else:
        raise ValueError(f"no matching alignment record for {ref_species!r}")

    # ref_start, ref_end are genomic positions and the align_record
    # start / stop are also genomic positions
    genome_start = align_record.start
    genome_end = align_record.stop
    gaps = GapPositions(align_record.gap_spans, seq_length=genome_end - genome_start)

    # make sure the sequence start and stop are within this aligned block.
    # Explicit None checks so a coordinate of 0 is not treated as unspecified.
    seq_start = genome_start if ref_start is None else max(ref_start, genome_start)
    seq_end = genome_end if ref_end is None else min(ref_end, genome_end)
    # make these coordinates relative to the aligned segment
    if align_record.strand == "-":
        # if record is on minus strand, then genome stop is
        # the alignment start
        seq_start, seq_end = genome_end - seq_end, genome_end - seq_start
    else:
        seq_start = seq_start - genome_start
        seq_end = seq_end - genome_start

    return gaps.from_seq_to_align_index(seq_start), gaps.from_seq_to_align_index(
        seq_end
    )


def get_alignment(
    align_db: AlignDb,
    genomes: dict,
    ref_species: str,
    seqid: str,
    ref_start: int | None = None,
    ref_end: int | None = None,
    namer: typing.Callable | None = None,
    mask_features: list[str] | None = None,
) -> typing.Generator[Alignment]:
    """yields cogent3 Alignments

    Parameters
    ----------
    align_db
        the alignment records
    genomes
        species name -> genome object providing get_seq()
    ref_species, seqid
        the reference species and sequence the coordinates refer to
    ref_start, ref_end
        genomic coordinates on the reference; None means the whole block
    namer
        passed through to genome.get_seq()
    mask_features
        biotypes whose annotations are masked in each alignment

    Raises
    ------
    ValueError
        if ref_species is not in genomes, or a block lacks a reference record
    """
    from ensembl_lite._convert import gap_coords_to_seq

    if ref_species not in genomes:
        raise ValueError(f"unknown species {ref_species!r}")

    align_records = align_db.get_records_matching(
        species=ref_species, seqid=seqid, start=ref_start, stop=ref_end
    )

    # sample the sequences
    for block in align_records:
        # alignment indices of the reference region; used to slice every
        # other sequence in the block
        align_start, align_end = _ref_align_span(
            block, ref_species, seqid, ref_start, ref_end
        )

        seqs = []
        for align_record in block:
            # NOTE(review): raises KeyError if a species in the block is
            # missing from genomes
            genome = genomes[align_record.species]
            # We need to convert the alignment coordinates into sequence
            # coordinates for this species.
            genome_start = align_record.start
            genome_end = align_record.stop
            gaps = GapPositions(
                align_record.gap_spans, seq_length=genome_end - genome_start
            )

            seq_start = gaps.from_align_to_seq_index(align_start)
            seq_end = gaps.from_align_to_seq_index(align_end)
            seq_length = seq_end - seq_start
            if align_record.strand == "-":
                # if it's neg strand, the alignment start is the genome stop
                seq_start = gaps.seq_length - seq_end

            s = genome.get_seq(
                seqid=align_record.seqid,
                start=genome_start + seq_start,
                stop=genome_start + seq_start + seq_length,
                namer=namer,
            )
            # we now trim the gaps for this sequence to the sub-alignment
            gaps = gaps[align_start:align_end]

            if align_record.strand == "-":
                s = s.rc()

            seqs.append(gap_coords_to_seq(gaps.gaps, s))

        aln = Alignment(seqs)
        if mask_features:
            aln = aln.with_masked_annotations(biotypes=mask_features)
        yield aln

264 

265 

266def _gap_spans( 

267 gap_pos: NDArray[int], gap_cum_lengths: NDArray[int] 

268) -> tuple[NDArray[int], NDArray[int]]: 

269 """returns 1D arrays in alignment coordinates of 

270 gap start, gap stop""" 

271 if not len(gap_pos): 

272 r = numpy.array([], dtype=gap_pos.dtype) 

273 return r, r 

274 

275 sum_to_prev = 0 

276 gap_starts = numpy.empty(gap_pos.shape[0], dtype=gap_pos.dtype) 

277 gap_ends = numpy.empty(gap_pos.shape[0], dtype=gap_pos.dtype) 

278 for i, pos in enumerate(gap_pos): 

279 gap_starts[i] = sum_to_prev + pos 

280 gap_ends[i] = pos + gap_cum_lengths[i] 

281 sum_to_prev = gap_cum_lengths[i] 

282 

283 return numpy.array(gap_starts), numpy.array(gap_ends) 

284 

285 

286@dataclass(slots=True) 

287class GapPositions: 

288 """records gap insertion index and length 

289 

290 Notes 

291 ----- 

292 This very closely parallels the cogent3.core.location.Map class, 

293 but is more memory efficient. When that class has been updated, 

294 this can be removed. 

295 """ 

296 

297 # 2D numpy int array, 

298 # each row is a gap 

299 # column 0 is sequence index of gap **relative to the alignment** 

300 # column 1 is gap length 

301 gaps: NDArray[NDArray[int]] 

302 # length of the underlying sequence 

303 seq_length: int 

304 

305 def __post_init__(self): 

306 if not len(self.gaps): 

307 # can get a zero length array with shape != (0, 0) 

308 # e.g. by slicing gaps[:0], but since there's no data 

309 # we force it to have zero elements on both dimensions 

310 self.gaps = self.gaps.reshape((0, 0)) 

311 

312 # make gap array immutable 

313 self.gaps.flags.writeable = False 

314 

315 def __getitem__(self, item: slice) -> typing.Self: 

316 # we're assuming that this gap object is associated with a sequence 

317 # that will also be sliced. Hence, we need to shift the gap insertion 

318 # positions relative to this newly sliced sequence. 

319 if item.step: 

320 raise NotImplementedError( 

321 f"{type(self).__name__!r} does not support strides" 

322 ) 

323 start = item.start or 0 

324 stop = item.stop or len(self) 

325 gaps = self.gaps.copy() 

326 if start < 0 or stop < 0: 

327 raise NotImplementedError( 

328 f"{type(self).__name__!r} does not support negative indexes" 

329 ) 

330 

331 if not len(gaps): 

332 cum_lengths = numpy.array([], dtype=gaps.dtype) 

333 pos = cum_lengths 

334 else: 

335 cum_lengths = gaps[:, 1].cumsum() 

336 pos = gaps[:, 0] 

337 

338 gap_starts, gap_ends = _gap_spans(pos, cum_lengths) 

339 

340 if not len(gaps) or stop < gap_starts[0] or start >= gap_ends[-1]: 

341 return self.__class__( 

342 gaps=numpy.array([], dtype=gaps.dtype), seq_length=stop - start 

343 ) 

344 

345 # second column of spans is gap ends 

346 # which gaps does it fall between 

347 l = numpy.searchsorted(gap_ends, start, side="left") 

348 if gap_starts[l] <= start < gap_ends[l]: 

349 # start is within a gap 

350 begin = l 

351 begin_diff = start - gap_starts[l] 

352 gaps[l, 1] -= begin_diff 

353 shift = start - cum_lengths[l - 1] - begin_diff if l else gaps[l, 0] 

354 elif start == gap_ends[l]: 

355 # at gap boundary 

356 begin = l + 1 

357 shift = start - cum_lengths[l] 

358 else: 

359 # not within a gap 

360 begin = l 

361 shift = start - cum_lengths[l - 1] if l else start 

362 

363 # start search for stop from l index 

364 r = numpy.searchsorted(gap_ends[l:], stop, side="right") + l 

365 if r == len(gaps): 

366 # stop is after last gap 

367 end = r 

368 elif gap_starts[r] < stop <= gap_ends[r]: 

369 # within gap 

370 end = r + 1 

371 end_diff = gap_ends[r] - stop 

372 gaps[r, 1] -= end_diff 

373 else: 

374 end = r 

375 

376 result = gaps[begin:end] 

377 result[:, 0] -= shift 

378 if not len(result): 

379 # no gaps 

380 seq_length = stop - start 

381 else: 

382 seq_length = self.from_align_to_seq_index( 

383 stop 

384 ) - self.from_align_to_seq_index(start) 

385 

386 return self.__class__(gaps=result, seq_length=seq_length) 

387 

388 def __len__(self): 

389 total_gaps = self.gaps[:, 1].sum() if len(self.gaps) else 0 

390 return total_gaps + self.seq_length 

391 

392 def from_seq_to_align_index(self, seq_index: int) -> int: 

393 """convert a sequence index into an alignment index""" 

394 if seq_index < 0: 

395 raise NotImplementedError(f"{seq_index} negative align_index not supported") 

396 

397 if not len(self.gaps) or seq_index < self.gaps[0, 0]: 

398 return seq_index 

399 

400 # this statement replace when we change self.gaps to include [gap pos, cumsum] 

401 cum_gap_lengths = self.gaps[:, 1].cumsum() 

402 gap_pos = self.gaps[:, 0] 

403 

404 if seq_index >= self.gaps[-1, 0]: 

405 return seq_index + cum_gap_lengths[-1] 

406 

407 # find gap position before seq_index 

408 index = numpy.searchsorted(gap_pos, seq_index, side="left") 

409 if seq_index < gap_pos[index]: 

410 gap_lengths = cum_gap_lengths[index - 1] if index else 0 

411 else: 

412 gap_lengths = cum_gap_lengths[index] 

413 

414 return seq_index + gap_lengths 

415 

416 def from_align_to_seq_index(self, align_index: int) -> int: 

417 """converts alignment index to sequence index""" 

418 if align_index < 0: 

419 raise NotImplementedError( 

420 f"{align_index} negative align_index not supported" 

421 ) 

422 

423 if not len(self.gaps) or align_index < self.gaps[0, 0]: 

424 return align_index 

425 

426 # replace the following call when we change self.gaps to include 

427 # [gap pos, cumsum] 

428 # these are alignment indices for gaps 

429 cum_lengths = self.gaps[:, 1].cumsum() 

430 gap_starts, gap_ends = _gap_spans(self.gaps[:, 0], cum_lengths) 

431 if align_index >= gap_ends[-1]: 

432 return align_index - cum_lengths[-1] 

433 

434 index = numpy.searchsorted(gap_ends, align_index, side="left") 

435 if align_index < gap_starts[index]: 

436 # before the gap at index 

437 return align_index - cum_lengths[index - 1] 

438 

439 if align_index == gap_ends[index]: 

440 # after the gap at index 

441 return align_index - cum_lengths[index] 

442 

443 if gap_starts[index] <= align_index < gap_ends[index]: 

444 # within the gap at index 

445 # so the gap insertion position is the sequence position 

446 return self.gaps[index, 0] 

447 

448 

def write_alignments(
    *,
    align_db: AlignDb,
    genomes: dict,
    limit: int | None,
    mask_features: list[str],
    outdir: os.PathLike | str,
    ref_species: str,
    stableids: list[str],
    show_progress: bool = True,
):
    """write alignments for each stableid of ref_species to outdir

    For each stableid with exactly one matching annotation record on the
    reference genome, the alignments overlapping that record's coordinates
    are written as ``<stableid>.fa.gz``. If there are several alignments,
    they are written as ``<stableid>-<i>.fa.gz``. Stableids with no match,
    or with multiple matching records, are skipped.

    Parameters
    ----------
    limit
        if truthy, only the first ``limit`` locations are written
    outdir
        destination directory, a str or path-like
    show_progress
        display a progress bar

    Returns
    -------
    True
    """
    import pathlib

    outdir = pathlib.Path(outdir)
    # then the coordinates for the id's
    ref_genome = genomes[ref_species]
    locations = []
    for stableid in stableids:
        records = list(ref_genome.annotation_db.get_records_matching(name=stableid))
        if len(records) != 1:
            # unknown or ambiguous stableid
            continue
        record = records[0]
        locations.append(
            (
                stableid,
                ref_species,
                record["seqid"],
                record["start"],
                record["stop"],
            )
        )
        if limit and 0 < limit <= len(locations):
            # no need to query annotations for locations that would be dropped
            break

    if limit:
        locations = locations[:limit]

    for stableid, species, seqid, start, end in track(
        locations, disable=not show_progress
    ):
        alignments = list(
            get_alignment(
                align_db,
                genomes,
                species,
                seqid,
                start,
                end,
                mask_features=mask_features,
            )
        )
        stableid = sanitise_stableid(stableid)
        if len(alignments) == 1:
            alignments[0].write(outdir / f"{stableid}.fa.gz")
        else:
            # zero alignments writes nothing
            for i, aln in enumerate(alignments):
                aln.write(outdir / f"{stableid}-{i}.fa.gz")

    return True