Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_genomedb.py: 90%

462 statements  

coverage.py v7.5.1, created at 2024-06-12 16:31 -0400

import collections
import dataclasses
import functools
import itertools
import pathlib
import re
import sqlite3
import typing

from abc import ABC, abstractmethod
from typing import Any, Optional

import click
import h5py
import numpy
import typing_extensions

from cogent3 import get_moltype, make_seq, make_table
from cogent3.app.composable import define_app
from cogent3.core.annotation import Feature
from cogent3.core.annotation_db import (
    FeatureDataType,
    OptionalInt,
    OptionalStr,
    _select_records_sql,
)
from cogent3.core.sequence import Sequence
from cogent3.parse.gff import GffRecord, gff_parser, is_gff3
from cogent3.util.io import iter_splitlines
from cogent3.util.table import Table
from numpy.typing import NDArray

from ensembl_lite._config import Config, InstalledConfig
from ensembl_lite._db_base import Hdf5Mixin, SqliteDbMixin
from ensembl_lite._faster_fasta import quicka_parser
from ensembl_lite._species import Species
from ensembl_lite._util import (
    _HDF5_BLOSC2_KWARGS,
    PathType,
    get_stableid_prefix,
)


_SEQDB_NAME = "genome_sequence.hdf5_blosc2"
_ANNOTDB_NAME = "features.ensembl_gff3db"

_typed_id = re.compile(
    r"\b[a-z]+:", flags=re.IGNORECASE
)  # ensembl stableids are prefixed by the type
_feature_id = re.compile(r"(?<=\bID=)[^;]+")
_exon_id = re.compile(r"(?<=\bexon_id=)[^;]+")
_parent_id = re.compile(r"(?<=\bParent=)[^;]+")


def _lower_case_match(match) -> str:
    return match.group(0).lower()


def tidy_gff3_stableids(attrs: str) -> str:
    """makes the feature type prefix lowercase in gff3 attribute fields"""
    return _typed_id.sub(_lower_case_match, attrs)
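# Illustrative example: the feature-type prefix in GFF3 attribute fields is
# lower-cased so that the regexes above and the biotype handling below are
# case-consistent, e.g.
#   tidy_gff3_stableids("ID=Gene:ENSG00000000001;Parent=Gene:X")
#   returns "ID=gene:ENSG00000000001;Parent=gene:X"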

class EnsemblGffRecord(GffRecord):
    __slots__ = GffRecord.__slots__ + ("feature_id",)

    def __init__(self, feature_id: Optional[int] = None, **kwargs):
        is_canonical = kwargs.pop("is_canonical", None)
        super().__init__(**kwargs)
        self.feature_id = feature_id
        if is_canonical:
            self.attrs = "Ensembl_canonical;" + (self.attrs or "")

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other):
        return self.name == getattr(other, "name", other)

    @property
    def stableid(self):
        return _typed_id.sub("", self.name or "")

    @property
    def is_canonical(self):
        attrs = self.attrs or ""
        return "Ensembl_canonical" in attrs

    def update_from_attrs(self) -> None:
        """updates attributes from the attrs string

        Notes
        -----
        also updates biotype from the prefix in the name
        """
        attrs = self.attrs
        id_regex = _feature_id if "ID=" in attrs else _exon_id
        attr = tidy_gff3_stableids(attrs)
        if feature_id := id_regex.search(attr):
            self.name = feature_id.group()

        if pid := _parent_id.search(attr):
            parents = pid.group().split(",")
            # not sure how to handle multiple-parent features
            # so taking the first ID as the parent for now
            self.parent_id = parents[0]

        if ":" in (self.name or ""):
            biotype = self.name.split(":")[0]
            self.biotype = "mrna" if biotype == "transcript" else biotype

    @property
    def size(self) -> int:
        """the sum of span segments"""
        return 0 if self.spans is None else sum(abs(s - e) for s, e in self.spans)
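# Illustrative example (placeholder identifiers; assumes, as elsewhere in this
# module, that GffRecord fields can be passed as keyword arguments):
# update_from_attrs() derives the name, parent and biotype from the GFF3
# attributes string, e.g.
#   rec = EnsemblGffRecord(
#       attrs="ID=transcript:ENST00000000001;Parent=gene:ENSG00000000002"
#   )
#   rec.update_from_attrs()
#   rec.name       -> "transcript:ENST00000000001"
#   rec.parent_id  -> "gene:ENSG00000000002"
#   rec.biotype    -> "mrna"  (the "transcript" prefix maps to "mrna")
#   rec.stableid   -> "ENST00000000001"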

def custom_gff_parser(
    path: PathType, num_fake_ids: int
) -> tuple[dict[str, EnsemblGffRecord], int]:
    """replacement for cogent3 merged_gff_records"""
    reduced = {}
    gff3 = is_gff3(path)
    for record in gff_parser(
        iter_splitlines(path),
        gff3=gff3,
        make_record=EnsemblGffRecord,
    ):
        record.update_from_attrs()
        if not record.name:
            record.name = f"unknown-{num_fake_ids}"
            num_fake_ids += 1

        if record.name not in reduced:
            record.spans = record.spans or []
            reduced[record] = record

        reduced[record].spans.append([record.start, record.stop])
        reduced[record].start = min(reduced[record].start, record.start)
        reduced[record].stop = max(reduced[record].stop, record.stop)

    return reduced, num_fake_ids
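# Usage sketch (illustrative; "annotations.gff3" is a placeholder path):
#   records, num_fake_ids = custom_gff_parser("annotations.gff3", 0)
# records maps each record (hashed and equated on its name) to itself, with
# spans, start and stop merged across all GFF lines sharing that name;
# num_fake_ids counts records that needed a generated "unknown-<n>" name.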

DbTypes = typing.Union[sqlite3.Connection, "EnsemblGffDb"]


class EnsemblGffDb(SqliteDbMixin):
    _biotype_schema = {
        "type": "TEXT COLLATE NOCASE",
        "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
    }
    _feature_schema = {
        "seqid": "TEXT COLLATE NOCASE",
        "source": "TEXT COLLATE NOCASE",
        "biotype_id": "TEXT",
        "start": "INTEGER",
        "stop": "INTEGER",
        "score": "TEXT",  # check defn
        "strand": "TEXT",
        "phase": "TEXT",
        "attributes": "TEXT",
        "comments": "TEXT",
        "spans": "array",  # aggregation of coords across records
        "stableid": "TEXT",
        "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
        "is_canonical": "INTEGER",
    }
    # relationships are directional, but can span levels, e.g.
    # gene -> transcript -> CDS / Exon
    # gene -> CDS
    _related_feature_schema = {"gene_id": "INTEGER", "related_id": "INTEGER"}

    _index_columns = {
        "feature": (
            "seqid",
            "stableid",
            "start",
            "stop",
            "is_canonical",
            "biotype_id",
        ),
        "related_feature": ("gene_id", "related_id"),
    }

    def __init__(
        self,
        source: PathType = ":memory:",
        db: typing.Optional[DbTypes] = None,
    ):
        self.source = source
        if isinstance(db, self.__class__):
            db = db.db

        self._db = db
        self._init_tables()
        self._create_views()

    def __hash__(self):
        return id(self)

    def __eq__(self, other) -> bool:
        return id(self) == id(other)

    def _create_views(self) -> None:
        """define views to simplify queries"""
        sql = """
        CREATE VIEW IF NOT EXISTS gff AS
        SELECT f.seqid as seqid,
               b.type as biotype,
               f.start as start,
               f.stop as stop,
               f.strand as strand,
               f.spans as spans,
               f.stableid as name,
               f.is_canonical as is_canonical,
               f.id as feature_id
        FROM feature f
        JOIN biotype b ON f.biotype_id = b.id
        """
        self._execute_sql(sql)
        # view to query for child given parent id and vice versa
        p2c = """
        CREATE VIEW IF NOT EXISTS parent_to_child AS
        SELECT fc.stableid as name,
               fp.stableid as parent_stableid,
               fc.seqid as seqid,
               b.type as biotype,
               fc.start as start,
               fc.stop as stop,
               fc.strand as strand,
               fc.spans as spans,
               fc.is_canonical as is_canonical
        FROM related_feature r
        JOIN biotype b ON fc.biotype_id = b.id
        JOIN feature fp ON fp.id = r.gene_id
        JOIN feature fc ON fc.id = r.related_id
        """
        self._execute_sql(p2c)
        c2p = """
        CREATE VIEW IF NOT EXISTS child_to_parent AS
        SELECT fp.stableid as name,
               fc.stableid as child_stableid,
               fp.seqid as seqid,
               b.type as biotype,
               fp.start as start,
               fp.stop as stop,
               fp.strand as strand,
               fp.is_canonical as is_canonical,
               fp.spans as spans
        FROM related_feature r
        JOIN biotype b ON fp.biotype_id = b.id
        JOIN feature fp ON fp.id = r.gene_id
        JOIN feature fc ON fc.id = r.related_id
        """
        self._execute_sql(c2p)

    def __len__(self) -> int:
        return self.num_records()

    @functools.cache
    def _get_biotype_id(self, biotype: str) -> int:
        sql = "INSERT OR IGNORE INTO biotype(type) VALUES (?) RETURNING id"
        result = self.db.execute(sql, (biotype,)).fetchone()
        return result["id"]

    def _build_feature(self, kwargs) -> EnsemblGffRecord:
        # not supporting this at present, which comes from cogent3
        # alignment objects
        kwargs.pop("on_alignment", None)
        return EnsemblGffRecord(**kwargs)

    def add_feature(
        self, *, feature: typing.Optional[EnsemblGffRecord] = None, **kwargs
    ) -> None:
        """updates the feature_id attribute"""
        if feature is None:
            feature = self._build_feature(kwargs)

        id_cols = ("biotype_id", "id")
        cols = [col for col in self._feature_schema if col not in id_cols]
        # do conversion to numpy array after the above statement to avoid issue of
        # having a numpy array in a conditional
        feature.spans = numpy.array(feature.spans)
        feature.start = feature.start or int(feature.spans.min())
        feature.stop = feature.stop or int(feature.spans.max())
        vals = [feature[col] for col in cols] + [self._get_biotype_id(feature.biotype)]
        cols += ["biotype_id"]
        placeholders = ",".join("?" * len(cols))
        sql = f"INSERT INTO feature({','.join(cols)}) VALUES ({placeholders}) RETURNING id"
        result = self.db.execute(sql, tuple(vals)).fetchone()
        feature.feature_id = result["id"]

    def add_records(
        self,
        *,
        records: typing.Iterable[EnsemblGffRecord],
        gene_relations: dict[EnsemblGffRecord, set[EnsemblGffRecord]],
    ) -> None:
        for record in records:
            self.add_feature(feature=record)

        # now add the relationships
        sql = "INSERT INTO related_feature(gene_id, related_id) VALUES (?,?)"
        for gene, children in gene_relations.items():
            if gene.feature_id is None:
                raise ValueError(f"gene.feature_id not defined for {gene!r}")

            child_ids = [child.feature_id for child in children]
            if None in child_ids:
                raise ValueError(f"child.feature_id not defined for {children!r}")

            comb = [tuple(c) for c in itertools.product([gene.feature_id], child_ids)]
            self.db.executemany(sql, comb)

    def num_records(self) -> int:
        return self._execute_sql("SELECT COUNT(*) as count FROM feature").fetchone()[
            "count"
        ]
    def _get_records_matching(
        self, table_name: str, **kwargs
    ) -> typing.Iterator[sqlite3.Row]:
        """return all fields"""
        columns = kwargs.pop("columns", None)
        allow_partial = kwargs.pop("allow_partial", False)
        # now build and execute the select
        sql, vals = _select_records_sql(
            table_name=table_name,
            conditions=kwargs,
            columns=columns,
            allow_partial=allow_partial,
        )
        yield from self._execute_sql(sql, values=vals)

    def get_features_matching(
        self,
        *,
        seqid: OptionalStr = None,
        biotype: OptionalStr = None,
        name: OptionalStr = None,
        start: OptionalInt = None,
        stop: OptionalInt = None,
        strand: OptionalStr = None,
        attributes: OptionalStr = None,
        allow_partial: bool = False,
        **kwargs,
    ) -> typing.Iterator[FeatureDataType]:
        kwargs = {
            k: v
            for k, v in locals().items()
            if k not in ("self", "kwargs") and v is not None
        }
        # alignment features are created by the user, specific to an alignment
        columns = ("seqid", "biotype", "spans", "strand", "name")
        query_args = {**kwargs}

        for result in self._get_records_matching(
            table_name="gff", columns=columns, **query_args
        ):
            result = dict(zip(columns, result))
            result["spans"] = [tuple(c) for c in result["spans"]]
            yield result

    def get_feature_children(
        self,
        *,
        name: str,
        **kwargs,
    ) -> typing.List[FeatureDataType]:
        cols = "seqid", "biotype", "spans", "strand", "name"
        results = {}
        for result in self._get_records_matching(
            table_name="parent_to_child", columns=cols, parent_stableid=name, **kwargs
        ):
            result = dict(zip(cols, result))
            result["spans"] = [tuple(c) for c in result["spans"]]
            results[result["name"]] = result
        return list(results.values())

    def get_feature_parent(
        self,
        *,
        name: str,
        **kwargs,
    ) -> typing.List[FeatureDataType]:
        cols = "seqid", "biotype", "spans", "strand", "name"
        results = {}
        for result in self._get_records_matching(
            table_name="child_to_parent", columns=cols, child_stableid=name
        ):
            result = dict(zip(cols, result))
            results[result["name"]] = result
        return list(results.values())

    def get_records_matching(
        self,
        *,
        biotype: OptionalStr = None,
        seqid: OptionalStr = None,
        name: OptionalStr = None,
        start: OptionalInt = None,
        stop: OptionalInt = None,
        strand: OptionalStr = None,
        attributes: OptionalStr = None,
        allow_partial: bool = False,
    ) -> typing.Iterator[FeatureDataType]:
        kwargs = {
            k: v for k, v in locals().items() if k not in ("self", "allow_partial")
        }
        sql, vals = _select_records_sql("gff", kwargs, allow_partial=allow_partial)
        col_names = None
        for result in self._execute_sql(sql, values=vals):
            if col_names is None:
                col_names = result.keys()
            yield {c: result[c] for c in col_names}

    def biotype_counts(self) -> dict[str, int]:
        sql = "SELECT biotype, COUNT(*) as count FROM gff GROUP BY biotype"
        result = self._execute_sql(sql).fetchall()
        return {r["biotype"]: r["count"] for r in result}

    def subset(
        self,
        *,
        source: PathType = ":memory:",
        biotype: OptionalStr = None,
        seqid: OptionalStr = None,
        name: OptionalStr = None,
        start: OptionalInt = None,
        stop: OptionalInt = None,
        strand: OptionalStr = None,
        attributes: OptionalStr = None,
        allow_partial: bool = False,
    ) -> typing_extensions.Self:
        """returns a new db instance with records matching the provided conditions"""
        # make sure python, not numpy, integers
        start = start if start is None else int(start)
        stop = stop if stop is None else int(stop)

        kwargs = {k: v for k, v in locals().items() if k not in {"self", "source"}}

        newdb = self.__class__(source=source)
        if not len(self):
            return newdb

        # we need to recreate the values that get passed to add_records
        # so first identify the feature IDs that match the criteria
        cols = None

        feature_ids = {}
        for r in self._get_records_matching(table_name="gff", **kwargs):
            if cols is None:
                cols = r.keys()
            r = dict(zip(cols, r))
            feature_id = r.pop("feature_id")
            feature = EnsemblGffRecord(**r)
            feature_ids[feature_id] = feature

        # now build the related features by selecting the rows with matches
        # in both columns to feature_ids
        ids = ",".join(str(i) for i in feature_ids)
        sql = f"""
        SELECT gene_id, related_id FROM related_feature
        WHERE gene_id IN ({ids}) AND related_id IN ({ids})
        """
        related = collections.defaultdict(set)
        for record in self._execute_sql(sql):
            gene_id, related_id = record["gene_id"], record["related_id"]
            gene = feature_ids[gene_id]
            related[gene].add(feature_ids[related_id])

        newdb.add_records(records=feature_ids.values(), gene_relations=related)
        return newdb
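# Usage sketch (illustrative; the path is a placeholder and the gff3 is assumed
# to contain gene/mRNA/CDS records):
#   records, _ = custom_gff_parser("annotations.gff3", 0)
#   related = make_gene_relationships(records)  # defined below
#   db = EnsemblGffDb()  # source defaults to ":memory:"
#   db.add_records(records=records.values(), gene_relations=related)
#   list(db.get_features_matching(biotype="gene"))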

def make_gene_relationships(
    records: typing.Sequence[EnsemblGffRecord],
) -> dict[EnsemblGffRecord, set[EnsemblGffRecord]]:
    """returns all feature children of genes"""
    related = {}
    for record in records:
        biotype = related.get(record.biotype.lower(), {})
        biotype[record.name] = record
        related[record.biotype.lower()] = biotype

    # reduce the related dict into gene_id by child/grandchild ID
    genes = {}
    for cds_record in related["cds"].values():
        mrna_record = related["mrna"][cds_record.parent_id]
        if mrna_record.is_canonical:
            # we make the CDS identifiable as being canonical
            # this token is used by the is_canonical property
            cds_record.attrs = f"Ensembl_canonical;{cds_record.attrs}"

        gene = related["gene"][mrna_record.parent_id]
        gene_relationships = genes.get(gene.name, set())
        gene_relationships.update((cds_record, mrna_record))
        genes[gene] = gene_relationships

    return genes
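# Illustrative outcome: for a gene with one canonical mRNA and its CDS, the
# returned mapping is {gene_record: {mrna_record, cds_record}}, and the CDS
# record's attrs gain the "Ensembl_canonical;" prefix so its is_canonical
# property then reports True.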

def get_stableid_prefixes(records: typing.Sequence[EnsemblGffRecord]) -> set[str]:
    """returns the prefixes of the stableids"""
    prefixes = set()
    for record in records:
        record.update_from_attrs()
        try:
            prefix = get_stableid_prefix(record.stableid)
        except ValueError:
            continue
        prefixes.add(prefix)
    return prefixes


def make_annotation_db(
    src_dest: tuple[pathlib.Path, pathlib.Path]
) -> tuple[str, set[str]]:
    """convert a gff3 file into an EnsemblGffDb

    Parameters
    ----------
    src_dest
        path to gff3 file, path to write AnnotationDb
    """
    src, dest = src_dest
    db_name = dest.parent.name
    if dest.exists():
        return db_name, set()

    db = EnsemblGffDb(source=dest)
    records, _ = custom_gff_parser(src, 0)
    prefixes = get_stableid_prefixes(tuple(records.keys()))
    related = make_gene_relationships(records)
    db.add_records(records=records.values(), gene_relations=related)
    db.make_indexes()
    db.close()
    del db
    return db_name, prefixes


def _rename(label: str) -> str:
    return label.split()[0]
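# e.g. _rename("1 dna:chromosome chromosome:GRCh38:1 REF") -> "1"; only the
# first whitespace-separated field of a FASTA label is kept as the seqid
# (the label shown is an illustrative Ensembl-style header, not a real one).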

@define_app
class fasta_to_hdf5:
    def __init__(self, config: Config, label_to_name=_rename):
        self.config = config
        self.label_to_name = label_to_name

    def main(self, db_name: str) -> bool:
        src_dir = self.config.staging_genomes / db_name
        dest_dir = self.config.install_genomes / db_name

        seq_store = SeqsDataHdf5(
            source=dest_dir / _SEQDB_NAME,
            species=Species.get_species_name(db_name),
            mode="w",
        )

        src_dir = src_dir / "fasta"
        for path in src_dir.glob("*.fa.gz"):
            # for label, seq in quicka_parser(path, one_seq=False):
            for label, seq in quicka_parser(path):
                seqid = self.label_to_name(label)
                seq_store.add_record(seq, seqid)
                del seq

        seq_store.close()

        return True


T = tuple[PathType, list[tuple[str, str]]]


class SeqsDataABC(ABC):
    """interface for genome sequence storage"""

    # the storage reference, e.g. path to file
    source: PathType
    species: str
    mode: str  # as per standard file opening modes, r, w, a
    _is_open = False
    _file: Optional[Any] = None

    @abstractmethod
    def __hash__(self): ...

    @abstractmethod
    def add_record(self, seq: str, seqid: str): ...

    @abstractmethod
    def add_records(self, *, records: typing.Iterable[list[str, str]]): ...

    @abstractmethod
    def get_seq_str(
        self, *, seqid: str, start: Optional[int] = None, stop: Optional[int] = None
    ) -> str: ...

    @abstractmethod
    def get_seq_arr(
        self, *, seqid: str, start: Optional[int] = None, stop: Optional[int] = None
    ) -> NDArray[numpy.uint8]: ...

    @abstractmethod
    def get_coord_names(self) -> tuple[str]: ...

    @abstractmethod
    def close(self): ...


@define_app
class str2arr:
    """convert string to array of uint8"""

    def __init__(self, moltype: str = "dna", max_length=None):
        moltype = get_moltype(moltype)
        self.canonical = "".join(moltype)
        self.max_length = max_length
        extended = "".join(list(moltype.alphabets.degen))
        self.translation = b"".maketrans(
            extended.encode("utf8"),
            "".join(chr(i) for i in range(len(extended))).encode("utf8"),
        )

    def main(self, data: str) -> numpy.ndarray:
        if self.max_length:
            data = data[: self.max_length]

        b = data.encode("utf8").translate(self.translation)
        return numpy.array(memoryview(b), dtype=numpy.uint8)


@define_app
class arr2str:
    """convert array of uint8 to str"""

    def __init__(self, moltype: str = "dna", max_length=None):
        moltype = get_moltype(moltype)
        self.canonical = "".join(moltype)
        self.max_length = max_length
        extended = "".join(list(moltype.alphabets.degen))
        self.translation = b"".maketrans(
            "".join(chr(i) for i in range(len(extended))).encode("utf8"),
            extended.encode("utf8"),
        )

    def main(self, data: numpy.ndarray) -> str:
        if self.max_length:
            data = data[: self.max_length]

        b = data.tobytes().translate(self.translation)
        return bytearray(b).decode("utf8")
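# Illustrative round trip: str2arr and arr2str are inverse byte translations
# over the degenerate alphabet, so, assuming the input only contains characters
# from that alphabet,
#   to_arr, to_str = str2arr(), arr2str()
#   to_str(to_arr("ACGGTT")) == "ACGGTT"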

@dataclasses.dataclass
class SeqsDataHdf5(Hdf5Mixin, SeqsDataABC):
    """HDF5 sequence data storage"""

    def __init__(
        self,
        source: PathType,
        species: Optional[str] = None,
        mode: str = "r",
        in_memory: bool = False,
    ):
        # note that species are converted into the Ensembl db prefix

        source = pathlib.Path(source)
        self.source = source

        if mode == "r" and not source.exists():
            raise OSError(f"{self.source!s} not found")

        species = Species.get_ensembl_db_prefix(species) if species else None
        self.mode = "w-" if mode == "w" else mode
        if in_memory:
            h5_kwargs = dict(
                driver="core",
                backing_store=False,
            )
        else:
            h5_kwargs = {}

        try:
            self._file: h5py.File = h5py.File(source, mode=self.mode, **h5_kwargs)
        except OSError:
            print(source)
            raise
        self._str2arr = str2arr(moltype="dna")
        self._arr2str = arr2str(moltype="dna")
        self._is_open = True
        if "r" not in self.mode and "species" not in self._file.attrs:
            assert species
            self._file.attrs["species"] = species

        if (
            species
            and (file_species := self._file.attrs.get("species", None)) != species
        ):
            raise ValueError(f"{self.source.name!r} {file_species!r} != {species}")
        self.species = self._file.attrs["species"]

    def __hash__(self):
        return id(self)

    @functools.singledispatchmethod
    def add_record(self, seq: str, seqid: str):
        seq = self._str2arr(seq)
        self.add_record(seq, seqid)

    @add_record.register
    def _(self, seq: numpy.ndarray, seqid: str):
        if seqid in self._file:
            stored = self._file[seqid]
            if (seq == stored).all():
                # already seen this seq
                return
            # but it's different, which is a problem
            raise ValueError(f"{seqid!r} already present but with different seq")

        self._file.create_dataset(
            name=seqid, data=seq, chunks=True, **_HDF5_BLOSC2_KWARGS
        )

    def add_records(self, *, records: typing.Iterable[list[str, str]]):
        for seqid, seq in records:
            self.add_record(seq, seqid)

    def get_seq_str(
        self, *, seqid: str, start: Optional[int] = None, stop: Optional[int] = None
    ) -> str:
        return self._arr2str(self.get_seq_arr(seqid=seqid, start=start, stop=stop))

    def get_seq_arr(
        self, *, seqid: str, start: Optional[int] = None, stop: Optional[int] = None
    ) -> NDArray[numpy.uint8]:
        if not self._is_open:
            raise OSError(f"{self.source.name!r} is closed")

        return self._file[seqid][start:stop]

    def get_coord_names(self) -> tuple[str]:
        """names of chromosomes / contigs"""
        return tuple(self._file)
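# Usage sketch (illustrative; the species value is a placeholder and an
# in-memory HDF5 file is assumed to be acceptable):
#   store = SeqsDataHdf5(
#       source=_SEQDB_NAME, species="homo_sapiens", mode="w", in_memory=True
#   )
#   store.add_record("ACGGTT", "1")
#   store.get_seq_str(seqid="1", start=2, stop=4)  # -> "GG"
#   store.close()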

# todo: this wrapping class is required for memory efficiency because
# the cogent3 SequenceCollection class is not designed for large sequence
# collections, either large sequences or large numbers of sequences. The
# longer term solution is improving SequenceCollections,
# which is underway 🎉
class Genome:
    """class to be replaced by cogent3 sequence collection when that
    has been modernised"""

    def __init__(
        self,
        *,
        species: str,
        seqs: SeqsDataABC,
        annots: EnsemblGffDb,
    ) -> None:
        self.species = species
        self._seqs = seqs
        self.annotation_db = annots

    def get_seq(
        self,
        *,
        seqid: str,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        namer: typing.Callable | None = None,
        with_annotations: bool = True,
    ) -> Sequence:
        """returns annotated sequence

        Parameters
        ----------
        seqid
            name of chromosome etc.
        start
            starting position of slice in python coordinates, defaults
            to 0
        stop
            ending position of slice in python coordinates, defaults
            to length of coordinate
        namer
            callback for naming the sequence. Callback must take four
            arguments: species, seqid, start, stop. Default is
            species:seqid:start-stop.
        with_annotations
            assign annotation_db to seq

        Notes
        -----
        Full annotations are bound to the instance.
        """
        seq = self._seqs.get_seq_str(seqid=seqid, start=start, stop=stop)
        if namer:
            name = namer(self.species, seqid, start, stop)
        else:
            name = f"{self.species}:{seqid}:{start}-{stop}"
        # we use seqid to make the sequence here because that identifies the
        # parent seq identity, required for querying annotations
        seq = make_seq(seq, name=seqid, moltype="dna")
        seq.name = name
        seq.annotation_offset = start or 0
        seq.annotation_db = self.annotation_db if with_annotations else None
        return seq

    def get_features(
        self,
        *,
        biotype: OptionalStr = None,
        seqid: OptionalStr = None,
        name: OptionalStr = None,
        start: OptionalInt = None,
        stop: OptionalInt = None,
    ) -> typing.Iterable[Feature]:
        for ft in self.annotation_db.get_features_matching(
            biotype=biotype, seqid=seqid, name=name, start=start, stop=stop
        ):
            seqid = ft.pop("seqid")
            ft["spans"] = numpy.array(ft["spans"])
            start = int(ft["spans"].min())
            stop = int(ft["spans"].max())
            ft["spans"] = ft["spans"] - start
            seq = self.get_seq(
                seqid=seqid, start=start, stop=stop, with_annotations=True
            )
            # because self.get_seq() automatically names seqs differently
            seq.name = seqid
            yield seq.make_feature(ft)

    def get_gene_cds(self, name: str, is_canonical: bool = True):
        for cds in self.annotation_db.get_feature_children(
            name=name, biotype="cds", is_canonical=is_canonical
        ):
            seqid = cds.pop("seqid")
            cds["spans"] = numpy.array(cds["spans"])
            start = cds["spans"].min()
            stop = cds["spans"].max()
            seq = self.get_seq(seqid=seqid, start=start, stop=stop)
            cds["spans"] = cds["spans"] - start
            yield seq.make_feature(feature=cds)

    def get_ids_for_biotype(self, biotype: str, limit: OptionalInt = None):
        annot_db = self.annotation_db
        sql = "SELECT name from gff WHERE biotype=?"
        val = (biotype,)
        if limit:
            sql += " LIMIT ?"
            val = val + (limit,)
        for result in annot_db._execute_sql(sql, val):
            yield result["name"].split(":")[-1]

    def close(self):
        self._seqs.close()
        self.annotation_db.db.close()
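# Usage sketch (illustrative; config and species values are placeholders, and
# load_genome is defined below):
#   genome = load_genome(config=installed_config, species="homo_sapiens")
#   seq = genome.get_seq(seqid="1", start=1000, stop=2000)
#   genes = list(genome.get_features(biotype="gene", seqid="1"))
#   genome.close()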

def load_genome(*, config: InstalledConfig, species: str):
    """returns the Genome with bound seqs and features"""
    genome_path = config.installed_genome(species) / _SEQDB_NAME
    seqs = SeqsDataHdf5(source=genome_path, species=species, mode="r")
    ann_path = config.installed_genome(species) / _ANNOTDB_NAME
    ann = EnsemblGffDb(source=ann_path)
    return Genome(species=species, seqs=seqs, annots=ann)


def get_seqs_for_ids(
    *,
    config: InstalledConfig,
    species: str,
    names: list[str],
    make_seq_name: typing.Optional[typing.Callable] = None,
) -> typing.Iterable[Sequence]:
    genome = load_genome(config=config, species=species)
    # is it possible to do batch query for all names?
    for name in names:
        cds = list(genome.get_gene_cds(name=name, is_canonical=False))
        if not cds:
            continue

        feature = cds[0]
        seq = feature.get_slice()
        if callable(make_seq_name):
            seq.name = make_seq_name(feature)
        else:
            seq.name = f"{species}-{name}"
        seq.info["species"] = species
        seq.info["name"] = name
        # disconnect from annotation so the closure of the genome
        # does not cause issues when run in parallel
        seq.annotation_db = None
        yield seq

    genome.close()
    del genome


def load_annotations_for_species(*, path: pathlib.Path) -> EnsemblGffDb:
    """returns the annotation Db for species"""
    if not path.exists():
        click.secho(f"{path.name!r} is missing", fg="red")
        exit(1)
    return EnsemblGffDb(source=path)


def get_gene_table_for_species(
    *, annot_db: EnsemblGffDb, limit: Optional[int], species: Optional[str] = None
) -> Table:
    """
    returns gene data from a GffDb

    Parameters
    ----------
    annot_db
        feature db
    limit
        limit the number of records returned
    species
        species name, overrides inference from annot_db.source
    """
    species = species or annot_db.source.parent.name

    columns = (
        "species",
        "name",
        "seqid",
        "source",
        "biotype",
        "start",
        "stop",
        "score",
        "strand",
        "phase",
    )
    rows = []
    for i, record in enumerate(annot_db.get_records_matching(biotype="gene")):
        rows.append([species] + [record.get(c, None) for c in columns[1:]])
        if i == limit:
            break

    return make_table(header=columns, data=rows)
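# Usage sketch (illustrative; the path is a placeholder and get_species_summary
# is defined below):
#   db = load_annotations_for_species(
#       path=pathlib.Path("homo_sapiens") / _ANNOTDB_NAME
#   )
#   genes = get_gene_table_for_species(annot_db=db, limit=10)
#   summary = get_species_summary(annot_db=db, species="homo_sapiens")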

def get_species_summary(
    *, annot_db: EnsemblGffDb, species: Optional[str] = None
) -> Table:
    """
    returns the Table summarising data for the species

    Parameters
    ----------
    annot_db
        feature db
    species
        species name, overrides inference from annot_db.source
    """
    # for now, just biotype
    species = species or annot_db.source.parent.name
    counts = annot_db.biotype_counts()
    try:
        common_name = Species.get_common_name(species)
    except ValueError:
        common_name = species

    return Table(
        header=("biotype", "count"),
        data=list(counts.items()),
        title=f"{common_name} features",
        column_templates={"count": lambda x: f"{x:,}"},
    )