Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_aligndb.py: 96%
233 statements
« prev ^ index » next coverage.py v7.2.3, created at 2024-03-25 13:40 +1100
« prev ^ index » next coverage.py v7.2.3, created at 2024-03-25 13:40 +1100
1from __future__ import annotations
3import os
4import typing
6from collections import defaultdict
7from dataclasses import dataclass
9import numpy
11from cogent3.core.alignment import Alignment
12from numpy.typing import NDArray
13from rich.progress import track
15from ensembl_lite._db_base import SqliteDbMixin, _compressed_array_proxy
16from ensembl_lite._util import sanitise_stableid
19@dataclass(slots=True)
20class AlignRecord:
21 """a record from an AlignDb
23 Notes
24 -----
25 Can return fields as attributes or like a dict using the field name as
26 a string.
27 """
29 source: str
30 block_id: str
31 species: str
32 seqid: str
33 start: int
34 stop: int
35 strand: str
36 gap_spans: numpy.ndarray
38 def __getitem__(self, item):
39 return getattr(self, item)
41 def __setitem__(self, item, value):
42 setattr(self, item, value)
44 def __eq__(self, other):
45 attrs = "source", "block_id", "species", "seqid", "start", "stop", "strand"
46 for attr in attrs:
47 if getattr(self, attr) != getattr(other, attr):
48 return False
49 return (self.gap_spans == other.gap_spans).all()
ReturnType = tuple[str, tuple]  # (the sql statement, the corresponding values)
55# todo add a table and methods to support storing the species tree used
56# for the alignment and for getting the species tree
class AlignDb(SqliteDbMixin):
    """storage of alignment records in a sqlite table

    Notes
    -----
    Each row of the table holds the fields of one AlignRecord. Rows that
    share a block_id are returned together by get_records_matching().
    """

    table_name = "align"
    _align_schema = {
        "source": "TEXT",  # the file path
        "block_id": "TEXT",  # <source file path>-<alignment number>
        "species": "TEXT",
        "seqid": "TEXT",
        "start": "INTEGER",
        "stop": "INTEGER",
        "strand": "TEXT",
        "gap_spans": "compressed_array",
    }

    def __init__(self, *, source=":memory:"):
        """
        Parameters
        ----------
        source
            location to store the db, defaults to in memory only
        """
        # note that data is destroyed
        self.source = source
        self._db = None
        self._init_tables()

    def add_records(self, records: typing.Sequence[AlignRecord]):
        """bulk insert of records into the table

        Parameters
        ----------
        records
            the AlignRecord instances to store

        Notes
        -----
        The records, and the sequence holding them, are not modified.
        """
        col_order = [
            row[1]
            for row in self.db.execute(
                f"PRAGMA table_info({self.table_name})"
            ).fetchall()
        ]
        # build the rows separately so the caller's records are left intact,
        # gap_spans is wrapped for storage as a compressed array
        rows = [
            [
                _compressed_array_proxy(record[c]) if c == "gap_spans" else record[c]
                for c in col_order
            ]
            for record in records
        ]

        val_placeholder = ", ".join("?" * len(col_order))
        sql = f"INSERT INTO {self.table_name} ({', '.join(col_order)}) VALUES ({val_placeholder})"
        self.db.executemany(sql, rows)

    def _get_block_id(
        self,
        *,
        species,
        seqid: str,
        start: int | None,
        stop: int | None,
    ) -> list[str]:
        """returns the rows (with a block_id column) of records that match

        Parameters
        ----------
        species, seqid
            a record must have these values
        start, stop
            optional coordinates. If both are given, a record matches when
            it contains start, or contains stop, or lies entirely between
            them. If only one is given, the record must contain it.
        """
        sql = f"SELECT DISTINCT block_id from {self.table_name} WHERE species = ? AND seqid = ?"
        values = species, seqid
        if start is not None and stop is not None:
            # a match if start or stop are within the record start/stop,
            # or the record is fully spanned by start/stop
            sql = (
                f"{sql} AND ((start <= ? AND ? < stop)"
                " OR (start <= ? AND ? < stop)"
                " OR (? <= start AND stop <= ?))"
            )
            values += (start, start, stop, stop, start, stop)
        elif start is not None:
            # the aligned segment overlaps start
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (start, start)
        elif stop is not None:
            # the aligned segment overlaps stop
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (stop, stop)

        return self.db.execute(sql, values).fetchall()

    def get_records_matching(
        self,
        *,
        species,
        seqid: str,
        start: int | None = None,
        stop: int | None = None,
    ) -> typing.Iterable[list[AlignRecord]]:
        """returns the records of every block that matches the query

        Returns
        -------
        An iterable of lists. Each list holds all the AlignRecord instances
        that share one block_id.
        """
        # make sure python, not numpy, integers
        start = None if start is None else int(start)
        stop = None if stop is None else int(stop)

        # We need the block IDs for all records for a species whose coordinates
        # lie in the range (start, stop). We then search for all records with
        # each block id. We return full records.
        # Client code is responsible for creating Aligned sequence instances
        # and the Alignment.

        block_ids = [
            r["block_id"]
            for r in self._get_block_id(
                species=species, seqid=seqid, start=start, stop=stop
            )
        ]

        results = defaultdict(list)
        if not block_ids:
            # nothing matched, so no need to query again
            return results.values()

        values = ", ".join("?" * len(block_ids))
        sql = f"SELECT * from {self.table_name} WHERE block_id IN ({values})"
        for row in self.db.execute(sql, block_ids).fetchall():
            record = {k: row[k] for k in row.keys()}
            results[record["block_id"]].append(AlignRecord(**record))

        return results.values()

    def get_species_names(self) -> list[str]:
        """return the list of species names"""
        return list(self.get_distinct("species"))
def get_alignment(
    align_db: AlignDb,
    genomes: dict,
    ref_species: str,
    seqid: str,
    ref_start: int | None = None,
    ref_end: int | None = None,
    namer: typing.Callable | None = None,
    mask_features: list[str] | None = None,
) -> typing.Generator[Alignment]:
    """yields one cogent3 Alignment per matching alignment block

    Parameters
    ----------
    align_db
        the database queried for alignment records
    genomes
        maps a species name to the object providing get_seq()
    ref_species, seqid
        identify the reference sequence used to select blocks
    ref_start, ref_end
        genomic coordinates on the reference, optional
    namer
        passed through to get_seq()
    mask_features
        if provided, passed as biotypes to with_masked_annotations()

    Raises
    ------
    ValueError if ref_species is not in genomes, or a block has no record
    for the reference species and seqid.
    """
    from ensembl_lite._convert import gap_coords_to_seq

    if ref_species not in genomes:
        raise ValueError(f"unknown species {ref_species!r}")

    blocks = align_db.get_records_matching(
        species=ref_species, seqid=seqid, start=ref_start, stop=ref_end
    )

    for block in blocks:
        # Step 1: find the reference record of this block. Its gaps let us
        # turn the requested genomic coordinates into alignment coordinates,
        # which are then shared by every sequence in the block.
        ref_record = next(
            (r for r in block if r.species == ref_species and r.seqid == seqid),
            None,
        )
        if ref_record is None:
            raise ValueError(f"no matching alignment record for {ref_species!r}")

        # record start / stop and ref_start / ref_end are genomic positions
        block_start, block_end = ref_record.start, ref_record.stop
        ref_gaps = GapPositions(
            ref_record.gap_spans, seq_length=block_end - block_start
        )

        # clip the request so it stays inside this aligned block
        lower = max(ref_start or block_start, block_start)
        upper = min(ref_end or block_end, block_end)

        # express the coordinates relative to the aligned segment
        if ref_record.strand == "-":
            # on the minus strand the genome stop is the alignment start
            rel_start, rel_end = block_end - upper, block_end - lower
        else:
            rel_start, rel_end = lower - block_start, upper - block_start

        align_start = ref_gaps.from_seq_to_align_index(rel_start)
        align_end = ref_gaps.from_seq_to_align_index(rel_end)

        # Step 2: for each species, map the alignment coordinates back to
        # that species' own sequence, fetch it and reinsert the gaps
        aligned_seqs = []
        for record in block:
            genome = genomes[record.species]
            record_gaps = GapPositions(
                record.gap_spans, seq_length=record.stop - record.start
            )

            first = record_gaps.from_align_to_seq_index(align_start)
            last = record_gaps.from_align_to_seq_index(align_end)
            num_bases = last - first
            on_minus = record.strand == "-"
            if on_minus:
                # for the minus strand, alignment start is the genome stop
                first = record_gaps.seq_length - last

            begin = record.start + first
            seq = genome.get_seq(
                seqid=record.seqid,
                start=begin,
                stop=begin + num_bases,
                namer=namer,
            )
            # restrict the gaps of this sequence to the sub-alignment
            sub_gaps = record_gaps[align_start:align_end]

            if on_minus:
                seq = seq.rc()

            aligned_seqs.append(gap_coords_to_seq(sub_gaps.gaps, seq))

        aln = Alignment(aligned_seqs)
        if mask_features:
            aln = aln.with_masked_annotations(biotypes=mask_features)
        yield aln
266def _gap_spans(
267 gap_pos: NDArray[int], gap_cum_lengths: NDArray[int]
268) -> tuple[NDArray[int], NDArray[int]]:
269 """returns 1D arrays in alignment coordinates of
270 gap start, gap stop"""
271 if not len(gap_pos):
272 r = numpy.array([], dtype=gap_pos.dtype)
273 return r, r
275 sum_to_prev = 0
276 gap_starts = numpy.empty(gap_pos.shape[0], dtype=gap_pos.dtype)
277 gap_ends = numpy.empty(gap_pos.shape[0], dtype=gap_pos.dtype)
278 for i, pos in enumerate(gap_pos):
279 gap_starts[i] = sum_to_prev + pos
280 gap_ends[i] = pos + gap_cum_lengths[i]
281 sum_to_prev = gap_cum_lengths[i]
283 return numpy.array(gap_starts), numpy.array(gap_ends)
286@dataclass(slots=True)
287class GapPositions:
288 """records gap insertion index and length
290 Notes
291 -----
292 This very closely parallels the cogent3.core.location.Map class,
293 but is more memory efficient. When that class has been updated,
294 this can be removed.
295 """
297 # 2D numpy int array,
298 # each row is a gap
299 # column 0 is sequence index of gap **relative to the alignment**
300 # column 1 is gap length
301 gaps: NDArray[NDArray[int]]
302 # length of the underlying sequence
303 seq_length: int
305 def __post_init__(self):
306 if not len(self.gaps):
307 # can get a zero length array with shape != (0, 0)
308 # e.g. by slicing gaps[:0], but since there's no data
309 # we force it to have zero elements on both dimensions
310 self.gaps = self.gaps.reshape((0, 0))
312 # make gap array immutable
313 self.gaps.flags.writeable = False
315 def __getitem__(self, item: slice) -> typing.Self:
316 # we're assuming that this gap object is associated with a sequence
317 # that will also be sliced. Hence, we need to shift the gap insertion
318 # positions relative to this newly sliced sequence.
319 if item.step:
320 raise NotImplementedError(
321 f"{type(self).__name__!r} does not support strides"
322 )
323 start = item.start or 0
324 stop = item.stop or len(self)
325 gaps = self.gaps.copy()
326 if start < 0 or stop < 0:
327 raise NotImplementedError(
328 f"{type(self).__name__!r} does not support negative indexes"
329 )
331 if not len(gaps):
332 cum_lengths = numpy.array([], dtype=gaps.dtype)
333 pos = cum_lengths
334 else:
335 cum_lengths = gaps[:, 1].cumsum()
336 pos = gaps[:, 0]
338 gap_starts, gap_ends = _gap_spans(pos, cum_lengths)
340 if not len(gaps) or stop < gap_starts[0] or start >= gap_ends[-1]:
341 return self.__class__(
342 gaps=numpy.array([], dtype=gaps.dtype), seq_length=stop - start
343 )
345 # second column of spans is gap ends
346 # which gaps does it fall between
347 l = numpy.searchsorted(gap_ends, start, side="left")
348 if gap_starts[l] <= start < gap_ends[l]:
349 # start is within a gap
350 begin = l
351 begin_diff = start - gap_starts[l]
352 gaps[l, 1] -= begin_diff
353 shift = start - cum_lengths[l - 1] - begin_diff if l else gaps[l, 0]
354 elif start == gap_ends[l]:
355 # at gap boundary
356 begin = l + 1
357 shift = start - cum_lengths[l]
358 else:
359 # not within a gap
360 begin = l
361 shift = start - cum_lengths[l - 1] if l else start
363 # start search for stop from l index
364 r = numpy.searchsorted(gap_ends[l:], stop, side="right") + l
365 if r == len(gaps):
366 # stop is after last gap
367 end = r
368 elif gap_starts[r] < stop <= gap_ends[r]:
369 # within gap
370 end = r + 1
371 end_diff = gap_ends[r] - stop
372 gaps[r, 1] -= end_diff
373 else:
374 end = r
376 result = gaps[begin:end]
377 result[:, 0] -= shift
378 if not len(result):
379 # no gaps
380 seq_length = stop - start
381 else:
382 seq_length = self.from_align_to_seq_index(
383 stop
384 ) - self.from_align_to_seq_index(start)
386 return self.__class__(gaps=result, seq_length=seq_length)
388 def __len__(self):
389 total_gaps = self.gaps[:, 1].sum() if len(self.gaps) else 0
390 return total_gaps + self.seq_length
392 def from_seq_to_align_index(self, seq_index: int) -> int:
393 """convert a sequence index into an alignment index"""
394 if seq_index < 0:
395 raise NotImplementedError(f"{seq_index} negative align_index not supported")
397 if not len(self.gaps) or seq_index < self.gaps[0, 0]:
398 return seq_index
400 # this statement replace when we change self.gaps to include [gap pos, cumsum]
401 cum_gap_lengths = self.gaps[:, 1].cumsum()
402 gap_pos = self.gaps[:, 0]
404 if seq_index >= self.gaps[-1, 0]:
405 return seq_index + cum_gap_lengths[-1]
407 # find gap position before seq_index
408 index = numpy.searchsorted(gap_pos, seq_index, side="left")
409 if seq_index < gap_pos[index]:
410 gap_lengths = cum_gap_lengths[index - 1] if index else 0
411 else:
412 gap_lengths = cum_gap_lengths[index]
414 return seq_index + gap_lengths
416 def from_align_to_seq_index(self, align_index: int) -> int:
417 """converts alignment index to sequence index"""
418 if align_index < 0:
419 raise NotImplementedError(
420 f"{align_index} negative align_index not supported"
421 )
423 if not len(self.gaps) or align_index < self.gaps[0, 0]:
424 return align_index
426 # replace the following call when we change self.gaps to include
427 # [gap pos, cumsum]
428 # these are alignment indices for gaps
429 cum_lengths = self.gaps[:, 1].cumsum()
430 gap_starts, gap_ends = _gap_spans(self.gaps[:, 0], cum_lengths)
431 if align_index >= gap_ends[-1]:
432 return align_index - cum_lengths[-1]
434 index = numpy.searchsorted(gap_ends, align_index, side="left")
435 if align_index < gap_starts[index]:
436 # before the gap at index
437 return align_index - cum_lengths[index - 1]
439 if align_index == gap_ends[index]:
440 # after the gap at index
441 return align_index - cum_lengths[index]
443 if gap_starts[index] <= align_index < gap_ends[index]:
444 # within the gap at index
445 # so the gap insertion position is the sequence position
446 return self.gaps[index, 0]
def write_alignments(
    *,
    align_db: AlignDb,
    genomes: dict,
    limit: int | None,
    mask_features: list[str],
    outdir: os.PathLike,
    ref_species: str,
    stableids: list[str],
    show_progress: bool = True,
):
    """writes the alignments for the locations of stable IDs to outdir

    Parameters
    ----------
    align_db
        the database of alignment records
    genomes
        maps a species name to its genome, must include ref_species
    limit
        if set, only this many located stable IDs are processed
    mask_features
        passed to get_alignment()
    outdir
        directory the files are written into
    ref_species
        the species whose annotations are searched for the stable IDs
    stableids
        the stable IDs to write alignments for
    show_progress
        displays a progress bar if True

    Returns
    -------
    True

    Notes
    -----
    A stable ID is skipped unless it matches exactly one annotation record.
    A single alignment is written to <stableid>.fa.gz, multiple alignments
    to <stableid>-<number>.fa.gz.
    """
    from pathlib import Path

    # the path is built with "/" below, which needs a pathlib path
    outdir = Path(outdir)

    # then the coordinates for the id's
    ref_genome = genomes[ref_species]
    locations = []
    for stableid in stableids:
        record = list(ref_genome.annotation_db.get_records_matching(name=stableid))
        if len(record) != 1:
            # no match, or an ambiguous one
            continue

        record = record[0]
        locations.append(
            (
                stableid,
                ref_species,
                record["seqid"],
                record["start"],
                record["stop"],
            )
        )

    if limit:
        locations = locations[:limit]

    for stableid, species, seqid, start, end in track(
        locations, disable=not show_progress
    ):
        alignments = list(
            get_alignment(
                align_db,
                genomes,
                species,
                seqid,
                start,
                end,
                mask_features=mask_features,
            )
        )
        stableid = sanitise_stableid(stableid)
        if len(alignments) == 1:
            outpath = outdir / f"{stableid}.fa.gz"
            alignments[0].write(outpath)
        elif len(alignments) > 1:
            # number the files so that none is overwritten
            for i, aln in enumerate(alignments):
                outpath = outdir / f"{stableid}-{i}.fa.gz"
                aln.write(outpath)

    return True