Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_aligndb.py: 93%
187 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-06-12 16:31 -0400
« prev ^ index » next coverage.py v7.5.1, created at 2024-06-12 16:31 -0400
1from __future__ import annotations
3import pathlib
4import typing
6from collections import defaultdict
7from dataclasses import dataclass
9import h5py
10import numpy
12from cogent3.core.alignment import Aligned, Alignment
13from cogent3.core.location import _DEFAULT_GAP_DTYPE, IndelMap
14from rich.progress import track
16from ensembl_lite._db_base import Hdf5Mixin, SqliteDbMixin
17from ensembl_lite._util import _HDF5_BLOSC2_KWARGS, PathType, sanitise_stableid
# empty gap array; AlignRecord.gap_data hands out copies of it when a
# record has no gaps
_no_gaps = numpy.empty((0,), dtype=_DEFAULT_GAP_DTYPE)

# suffix of the gap store file that AlignDb places next to its own source
_GAP_STORE_SUFFIX = "hdf5_blosc2"
25@dataclass(slots=True)
26class AlignRecord:
27 """a record from an AlignDb
29 Notes
30 -----
31 Can return fields as attributes or like a dict using the field name as
32 a string.
33 """
35 source: str
36 block_id: str
37 species: str
38 seqid: str
39 start: int
40 stop: int
41 strand: str
42 gap_spans: numpy.ndarray
44 def __getitem__(self, item):
45 return getattr(self, item)
47 def __setitem__(self, item, value):
48 setattr(self, item, value)
50 def __eq__(self, other):
51 attrs = "source", "block_id", "species", "seqid", "start", "stop", "strand"
52 for attr in attrs:
53 if getattr(self, attr) != getattr(other, attr):
54 return False
55 return (self.gap_spans == other.gap_spans).all()
57 @property
58 def gap_data(self):
59 if len(self.gap_spans):
60 gap_pos, gap_lengths = self.gap_spans.T
61 else:
62 gap_pos, gap_lengths = _no_gaps.copy(), _no_gaps.copy()
64 return gap_pos, gap_lengths
# the sql statement and corresponding values
ReturnType: typing.TypeAlias = tuple[str, tuple]
class GapStore(Hdf5Mixin):
    """store gap data from aligned sequences in a HDF5 file

    Notes
    -----
    Each record is one dataset, named by the string form of its integer
    index. The file attribute "align_name" records the alignment the store
    belongs to.
    """

    def __init__(
        self,
        source: PathType,
        align_name: typing.Optional[str] = None,
        mode: str = "r",
        in_memory: bool = False,
    ):
        """
        Parameters
        ----------
        source
            path of the HDF5 file
        align_name
            name of the alignment. Required when the file is opened in a
            mode that does not contain "r" and it has no "align_name"
            attribute yet. If given, it must equal the stored value.
        mode
            h5py file mode, "w" is replaced by "w-"
        in_memory
            if True, the h5py "core" driver is used without a backing store

        Raises
        ------
        OSError
            if h5py cannot open source, the path is printed before re-raising
        ValueError
            if align_name differs from the value stored in the file
        """
        self.source = pathlib.Path(source)
        # per the h5py docs "w-" fails if the file exists, "w" would truncate
        self.mode = "w-" if mode == "w" else mode
        h5_kwargs = (
            dict(
                driver="core",
                backing_store=False,
            )
            if in_memory
            else {}
        )
        try:
            self._file = h5py.File(source, mode=self.mode, **h5_kwargs)
        except OSError:
            # show which path failed, then let the caller handle the error
            print(source)
            raise

        if "r" not in self.mode and "align_name" not in self._file.attrs:
            # a store being written for the first time must be labelled
            assert align_name, "align_name is required to label a new gap store"
            self._file.attrs["align_name"] = align_name

        if (
            align_name
            and (file_species := self._file.attrs.get("align_name", None)) != align_name
        ):
            raise ValueError(f"{self.source.name!r} {file_species!r} != {align_name}")

        self.align_name = self._file.attrs["align_name"]

    def add_record(self, *, index: int, gaps: numpy.ndarray):
        """stores gaps under index

        Parameters
        ----------
        index
            identifier of the record, converted to a string dataset name
        gaps
            the gap data to store

        Raises
        ------
        ValueError
            if index is already present with different gaps

        Notes
        -----
        Adding identical gaps for an existing index does nothing. The file
        is flushed after every new record.
        """
        # dataset names must be strings
        index = str(index)
        if index in self._file:
            stored = self._file[index][:]
            # array_equal also compares shape. An elementwise == followed by
            # .all() broadcasts, or errors on incompatible shapes, and so can
            # mistake different gaps for a duplicate.
            if numpy.array_equal(gaps, stored):
                # already seen this index
                return
            # but it's different, which is a problem
            raise ValueError(f"{index!r} already present but with different gaps")

        self._file.create_dataset(
            name=index, data=gaps, chunks=True, **_HDF5_BLOSC2_KWARGS
        )
        self._file.flush()

    def get_record(self, *, index: int) -> numpy.ndarray:
        """returns the complete gap data stored under index"""
        return self._file[str(index)][:]
124# todo add a table and methods to support storing the species tree used
125# for the alignment and for getting the species tree
class AlignDb(SqliteDbMixin):
    """table of aligned segments, gap data is kept in a bound GapStore

    Notes
    -----
    Each row describes one aligned segment of one species. The row id is
    the index of that segment's gap data in the GapStore.
    """

    table_name = "align"
    _align_schema = {
        "id": "INTEGER PRIMARY KEY",  # used to uniquely identify gap_spans in bound GapStore
        "source": "TEXT",  # the file path
        "block_id": "TEXT",  # <source file path>-<alignment number>
        "species": "TEXT",
        "seqid": "TEXT",
        "start": "INTEGER",
        "stop": "INTEGER",
        "strand": "TEXT",
    }

    _index_columns = {"align": ("id", "block_id", "seqid", "start", "stop")}

    def __init__(self, *, source=":memory:", mode="a"):
        """
        Parameters
        ----------
        source
            location to store the db, defaults to in memory only
        mode
            file mode passed to the GapStore
        """
        # note that data is destroyed
        # NOTE(review): presumably this refers to the in-memory case - confirm
        source = pathlib.Path(source)
        self.source = source
        if source.name == ":memory:":
            gap_path = "memory"
            kwargs = dict(in_memory=True)
        else:
            # the gap store sits next to the db, same stem, different suffix
            gap_path = source.parent / f"{source.stem}.{_GAP_STORE_SUFFIX}"
            kwargs = dict(in_memory=False)

        self.gap_store = GapStore(
            source=gap_path, align_name=source.stem, mode=mode, **kwargs
        )
        self._db = None
        self._init_tables()

    def add_records(self, records: typing.Sequence[AlignRecord]):
        """inserts records into the table and their gaps into the gap store"""
        # bulk insert
        # the second field of each table_info row is the column name; id is
        # left out so sqlite assigns it
        col_order = [
            row[1]
            for row in self.db.execute(
                f"PRAGMA table_info({self.table_name})"
            ).fetchall()
            if row[1] != "id"
        ]
        val_placeholder = ", ".join("?" * len(col_order))
        sql = f"INSERT INTO {self.table_name} ({', '.join(col_order)}) VALUES ({val_placeholder}) RETURNING id"

        for record in records:
            inserted = self.db.execute(sql, [record[c] for c in col_order]).fetchone()
            # the new row id is the key of the gaps in the gap store
            self.gap_store.add_record(index=inserted["id"], gaps=record.gap_spans)

    def _get_block_id(
        self,
        *,
        species,
        seqid: str,
        start: int | None,
        stop: int | None,
    ) -> list:
        """returns rows holding the block_id of records matching the query

        Parameters
        ----------
        species, seqid
            a record must match both
        start, stop
            optional coordinates. If both are given, a record matches when
            it contains start, contains stop, or lies entirely between
            them. If only one is given, the record must contain it. If
            neither is given, all records for species and seqid match.
        """
        sql = f"SELECT block_id from {self.table_name} WHERE species = ? AND seqid = ?"
        values = species, seqid
        if start is not None and stop is not None:
            # A record matches if it contains start or stop. The last clause
            # adds records that lie entirely inside the query range, which
            # contain neither coordinate and would otherwise be missed.
            sql = (
                f"{sql} AND ((start <= ? AND ? < stop)"
                " OR (start <= ? AND ? < stop)"
                " OR (? <= start AND stop <= ?))"
            )
            values += (start, start, stop, stop, start, stop)
        elif start is not None or stop is not None:
            # the aligned segment overlaps whichever coordinate was given
            position = start if start is not None else stop
            sql = f"{sql} AND start <= ? AND ? < stop"
            values += (position, position)

        return self.db.execute(sql, values).fetchall()

    def get_records_matching(
        self,
        *,
        species,
        seqid: str,
        start: int | None = None,
        stop: int | None = None,
    ) -> typing.Iterable[AlignRecord]:
        """returns the records of every block matching the query

        Returns
        -------
        One list of AlignRecord per block, each containing all records
        sharing that block_id, not only those of the query species.
        """
        # make sure python, not numpy, integers
        start = None if start is None else int(start)
        stop = None if stop is None else int(stop)

        # We need the block IDs for all records for a species whose coordinates
        # lie in the range (start, stop). We then search for all records with
        # each block id. We return full records.
        # Client code is responsible for creating Aligned sequence instances
        # and the Alignment.

        block_ids = [
            r["block_id"]
            for r in self._get_block_id(
                species=species, seqid=seqid, start=start, stop=stop
            )
        ]

        results = defaultdict(list)
        if not block_ids:
            # nothing matched, so there is nothing further to query
            return results.values()

        values = ", ".join("?" * len(block_ids))
        sql = f"SELECT * from {self.table_name} WHERE block_id IN ({values})"
        for row in self.db.execute(sql, block_ids).fetchall():
            record = {k: row[k] for k in row.keys()}
            # id is not an AlignRecord field, it is the key of the gaps
            index = record.pop("id")
            record["gap_spans"] = self.gap_store.get_record(index=index)
            results[record["block_id"]].append(AlignRecord(**record))

        return results.values()

    def get_species_names(self) -> list[str]:
        """return the list of species names"""
        return list(self.get_distinct("species"))
def get_alignment(
    align_db: AlignDb,
    genomes: dict,
    ref_species: str,
    seqid: str,
    ref_start: int | None = None,
    ref_end: int | None = None,
    namer: typing.Callable | None = None,
    mask_features: list[str] | None = None,
) -> typing.Generator[Alignment]:
    """yields cogent3 Alignments

    Parameters
    ----------
    align_db
        source of the alignment records
    genomes
        maps species name to a genome, which must provide get_seq() and
        annotation_db
    ref_species
        species the coordinates refer to, must be a key of genomes
    seqid
        sequence identifier within ref_species
    ref_start, ref_end
        genomic coordinates on the reference, clipped to each block
    namer
        passed to get_seq() of each genome
    mask_features
        if given, passed as biotypes to with_masked_annotations()

    Raises
    ------
    ValueError
        if ref_species is not in genomes, or a block has no record for
        ref_species and seqid

    Notes
    -----
    One alignment is produced per block returned by align_db.
    """
    if ref_species not in genomes:
        raise ValueError(f"unknown species {ref_species!r}")

    blocks = align_db.get_records_matching(
        species=ref_species, seqid=seqid, start=ref_start, stop=ref_end
    )

    for block in blocks:
        # Step 1: find the reference record of this block. Its gaps let us
        # express the requested genomic region as alignment coordinates,
        # which apply to every member of the block.
        ref_record = next(
            (
                member
                for member in block
                if member.species == ref_species and member.seqid == seqid
            ),
            None,
        )
        if ref_record is None:
            raise ValueError(f"no matching alignment record for {ref_species!r}")

        # record start / stop are genomic positions, as are ref_start, ref_end
        block_start = ref_record.start
        block_end = ref_record.stop
        positions, lengths = ref_record.gap_data
        ref_map = IndelMap(
            gap_pos=positions,
            gap_lengths=lengths,
            parent_length=block_end - block_start,
        )

        # clip the requested region so it lies within this aligned block
        region_start = max(ref_start or block_start, block_start)
        region_end = min(ref_end or block_end, block_end)

        # express the region relative to the aligned segment
        if ref_record.strand == "-":
            # on the minus strand the genome stop is the alignment start
            region_start, region_end = (
                block_end - region_end,
                block_end - region_start,
            )
        else:
            region_start = region_start - block_start
            region_end = region_end - block_start

        align_start = ref_map.get_align_index(region_start)
        align_end = ref_map.get_align_index(region_end)

        # Step 2: for every member, turn the alignment coordinates back into
        # sequence coordinates, fetch that sequence and pair it with its
        # share of the gaps.
        seqs = []
        for member in block:
            genome = genomes[member.species]
            member_start = member.start
            member_end = member.stop
            positions, lengths = member.gap_data
            member_map = IndelMap(
                gap_pos=positions,
                gap_lengths=lengths,
                parent_length=member_end - member_start,
            )

            seq_start = member_map.get_seq_index(align_start)
            seq_end = member_map.get_seq_index(align_end)
            seq_length = seq_end - seq_start
            on_minus_strand = member.strand == "-"
            if on_minus_strand:
                # for the minus strand the alignment start is the genome stop
                seq_start = member_map.parent_length - seq_end

            first = member_start + seq_start
            s = genome.get_seq(
                seqid=member.seqid,
                start=first,
                stop=first + seq_length,
                namer=namer,
                with_annotations=False,
            )
            # keep only the gaps that belong to the sub-alignment
            member_map = member_map[align_start:align_end]

            if on_minus_strand:
                s = s.rc()

            seqs.append(Aligned(member_map, s))

        aln = Alignment(seqs)
        # NOTE(review): genome is that of the last record in the block, not
        # necessarily the reference species - confirm this is intended
        aln.annotation_db = genome.annotation_db
        if mask_features:
            aln = aln.with_masked_annotations(biotypes=mask_features)
        yield aln
def write_alignments(
    *,
    align_db: AlignDb,
    genomes: dict,
    limit: int | None,
    mask_features: list[str],
    outdir: PathType,
    ref_species: str,
    stableids: list[str],
    show_progress: bool = True,
):
    """writes the alignments overlapping each stable id to outdir

    Parameters
    ----------
    align_db
        source of the alignment records
    genomes
        maps species name to a genome, must contain ref_species
    limit
        if truthy, only the first limit locations are processed
    mask_features
        passed on to get_alignment()
    outdir
        directory to write to, a string or a path
    ref_species
        species whose annotation_db is searched for the stable ids
    stableids
        identifiers to look up by name in the reference annotation_db
    show_progress
        if False, the progress display is disabled

    Returns
    -------
    True

    Notes
    -----
    A stable id is skipped unless it matches exactly one annotation record.
    One alignment is written to <stableid>.fa.gz, several alignments to
    <stableid>-<number>.fa.gz. The stable id is sanitised before being
    used in a file name.
    """
    # outdir may be a plain string, which does not support the / operator
    outdir = pathlib.Path(outdir)

    # then the coordinates for the id's
    ref_genome = genomes[ref_species]
    locations = []
    for stableid in stableids:
        matches = list(ref_genome.annotation_db.get_records_matching(name=stableid))
        if len(matches) != 1:
            # unknown or ambiguous identifier, there is no single location
            continue

        record = matches[0]
        locations.append(
            (
                stableid,
                ref_species,
                record["seqid"],
                record["start"],
                record["stop"],
            )
        )

    if limit:
        locations = locations[:limit]

    for stableid, species, seqid, start, end in track(
        locations, disable=not show_progress
    ):
        alignments = list(
            get_alignment(
                align_db,
                genomes,
                species,
                seqid,
                start,
                end,
                mask_features=mask_features,
            )
        )
        stableid = sanitise_stableid(stableid)
        if len(alignments) == 1:
            outpath = outdir / f"{stableid}.fa.gz"
            alignments[0].write(outpath)
        elif len(alignments) > 1:
            # number the files so the alignments do not overwrite each other
            for i, aln in enumerate(alignments):
                outpath = outdir / f"{stableid}-{i}.fa.gz"
                aln.write(outpath)

    return True