Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_homologydb.py: 85%

221 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-06-12 16:31 -0400

1from __future__ import annotations 

2 

3import dataclasses 

4import typing 

5 

6import blosc2 

7 

8from cogent3 import make_unaligned_seqs 

9from cogent3.app.composable import LOADER, NotCompleted, define_app 

10from cogent3.app.io import compress, decompress, pickle_it, unpickle_it 

11from cogent3.app.typing import ( 

12 IdentifierType, 

13 SerialisableType, 

14 UnalignedSeqsType, 

15) 

16from cogent3.parse.table import FilteringParser 

17from cogent3.util.io import PathType, iter_splitlines 

18 

19from ensembl_lite._config import InstalledConfig 

20from ensembl_lite._db_base import SqliteDbMixin 

21from ensembl_lite._genomedb import load_genome 

22from ensembl_lite._species import Species 

23 

24 

_HOMOLOGYDB_NAME = "homologies.sqlitedb"

# blosc2 based (de)compression apps plus pickling apps; ``inflate`` is the
# composition of the decompressor with the unpickler
compressor = compress(compressor=blosc2.compress2)
decompressor = decompress(decompressor=blosc2.decompress2)
pickler, unpickler = pickle_it(), unpickle_it()
inflate = decompressor + unpickler

32 

33 

34@dataclasses.dataclass(slots=True) 

35class species_genes: 

36 """contains gene IDs for species""" 

37 

38 species: str 

39 gene_ids: typing.Optional[list[str]] = None 

40 

41 def __hash__(self): 

42 return hash(self.species) 

43 

44 def __eq__(self, other): 

45 return self.species == other.species and self.gene_ids == other.gene_ids 

46 

47 def __post_init__(self): 

48 self.gene_ids = [] if not self.gene_ids else list(self.gene_ids) 

49 

50 def __getstate__(self) -> tuple[str, tuple[str, ...]]: 

51 return self.species, tuple(self.gene_ids) 

52 

53 def __setstate__(self, args): 

54 species, gene_ids = args 

55 self.species = species 

56 self.gene_ids = list(gene_ids) 

57 

58 

@dataclasses.dataclass
class homolog_group:
    """gene IDs belonging to the same homology group

    Parameters
    ----------
    relationship
        the homology relationship type shared by all members
    gene_ids
        the member gene IDs. Any iterable is accepted and stored as a
        mutable ``set``; ``None`` or an empty value becomes ``set()``.
    source
        gene ID the group was derived from. If not provided, it is the first
        element of a sequence ``gene_ids``, or an arbitrary member of a set
        ``gene_ids`` (``None`` for an empty group).
    """

    relationship: str
    gene_ids: set[str] | None = None
    source: str | None = None

    def __post_init__(self):
        if not self.gene_ids:
            self.gene_ids = set()
        elif not isinstance(self.gene_ids, set):
            # keep the first element available for ``source`` before the
            # ordering is lost by conversion to a set
            ordered = list(self.gene_ids)
            if self.source is None and ordered:
                self.source = ordered[0]
            self.gene_ids = set(ordered)
        if self.source is None:
            self.source = next(iter(self.gene_ids), None)

    def __hash__(self):
        # allow hashing, but bearing in mind we are updating gene values:
        # the hash is tied to the identity of the gene_ids set, so in-place
        # updates (``|=``) keep it stable.
        # NOTE: two distinct instances with equal contents compare equal but
        # will (almost always) hash differently.
        return hash((hash(self.relationship), id(self.gene_ids)))

    def __eq__(self, other):
        try:
            return (
                self.relationship == other.relationship
                and self.gene_ids == other.gene_ids
            )
        except AttributeError:
            return NotImplemented

    def __getstate__(self) -> tuple[str, set[str] | None, str | None]:
        return self.relationship, self.gene_ids, self.source

    def __setstate__(self, state: tuple[str, set[str] | None, str | None]):
        relationship, gene_ids, source = state
        self.relationship = relationship
        self.gene_ids = gene_ids
        self.source = source

    def __len__(self):
        return len(self.gene_ids or ())

    def __or__(self, other):
        """returns a new group whose members are the union of both groups

        Raises
        ------
        ValueError
            if the relationship types differ
        """
        if other.relationship != self.relationship:
            raise ValueError(
                f"relationship type {self.relationship!r} != {other.relationship!r}"
            )
        return self.__class__(
            relationship=self.relationship, gene_ids=self.gene_ids | other.gene_ids
        )

    def species_ids(self) -> dict[str, list[str]]:
        """returns {species: [gene_id, ...], ...}

        The species key is the db prefix derived from each stable ID via
        ``Species.get_db_prefix_from_stableid``.
        """
        result: dict[str, list[str]] = {}
        for gene_id in self.gene_ids:
            sp = Species.get_db_prefix_from_stableid(gene_id)
            result.setdefault(sp, []).append(gene_id)
        return result

112 

113 

# {<relationship type>: (homolog_group, ...), ...}
T: typing.TypeAlias = dict[str, tuple[homolog_group, ...]]

115 

116 

def grouped_related(
    data: typing.Iterable[tuple[str, str, str]],
) -> T:
    """determines related groups of genes

    Parameters
    ----------
    data
        iterable of (relationship type, gene id 1, gene id 2) records

    Returns
    -------
    {<relationship type>: (homolog_group, ...), ...}

    Notes
    -----
    For a specific relationship type, a gene belongs to exactly one group.
    Each pair is treated as an edge and the groups are the connected
    components. A pair that links genes already in two different groups
    merges those groups, so the result does not depend on record order.
    """
    # grouped is {<relationship type>: {gene id: homolog_group}}. Genes that
    # belong to the same group map to the same homolog_group instance
    grouped: dict[str, dict[str, homolog_group]] = {}
    for rel_type, gene_id_1, gene_id_2 in data:
        relationship = grouped.setdefault(rel_type, {})
        group_1 = relationship.get(gene_id_1)
        group_2 = relationship.get(gene_id_2)

        if group_1 is None and group_2 is None:
            val = homolog_group(relationship=rel_type)
        elif group_2 is None or group_1 is group_2:
            val = group_1
        elif group_1 is None:
            val = group_2
        else:
            # the pair bridges two existing groups: fold the smaller into
            # the larger and repoint all of its members, otherwise the
            # smaller group would keep a stale copy of the shared gene
            if len(group_1) >= len(group_2):
                val, other = group_1, group_2
            else:
                val, other = group_2, group_1
            val.gene_ids |= other.gene_ids
            for gene_id in other.gene_ids:
                relationship[gene_id] = val

        # in-place union keeps id(val.gene_ids), which homolog_group hashes on
        val.gene_ids |= {gene_id_1, gene_id_2}
        relationship[gene_id_1] = relationship[gene_id_2] = val

    # every gene maps to its group instance, so a set collapses duplicates
    return {
        rel_type: tuple(set(groups.values())) for rel_type, groups in grouped.items()
    }

159 

160 

161def _gene_id_to_group(series: tuple[homolog_group, ...]) -> dict[str:homolog_group]: 

162 """converts series of homolog_group instances to {geneid: groupl, ..}""" 

163 result = {} 

164 for group in series: 

165 result.update({gene_id: group for gene_id in group.gene_ids}) 

166 return result 

167 

168 

169def _add_unique( 

170 a: dict[str, homolog_group], 

171 b: dict[str, homolog_group], 

172 combined: dict[str, homolog_group], 

173) -> dict[str, homolog_group]: 

174 unique = a.keys() - b.keys() 

175 combined.update(**{gene_id: a[gene_id] for gene_id in unique}) 

176 return combined 

177 

178 

def merge_grouped(group1: T, group2: T) -> T:
    """merges homolog_group with overlapping members

    For each relationship type present in both inputs, the groups from both
    sides are combined into connected components. Any two groups that share
    a gene are merged, transitively, so every gene ends up in exactly one
    group. A relationship type present in only one input is copied through
    unchanged. Input groups are not mutated; groups without overlaps keep
    their original instances.
    """
    joint = {}
    for rel_type in group1.keys() | group2.keys():
        if rel_type not in group1 or rel_type not in group2:
            joint[rel_type] = group1.get(rel_type, group2.get(rel_type))
            continue

        # {gene_id: group currently holding it}; updated as groups merge
        gene_to_group = {}
        for group in (*group1[rel_type], *group2[rel_type]):
            # distinct already-seen groups sharing a member with this one,
            # keyed on identity to avoid relying on homolog_group hashing
            overlapping = {}
            for gene_id in group.gene_ids or ():
                existing = gene_to_group.get(gene_id)
                if existing is not None:
                    overlapping[id(existing)] = existing

            merged = group
            for existing in overlapping.values():
                # ``|`` returns a new instance, leaving inputs untouched
                merged = merged | existing
            # merged is a superset of every absorbed group, so this repoints
            # all of their members and leaves no stale references
            for gene_id in merged.gene_ids or ():
                gene_to_group[gene_id] = merged

        unique_groups = {id(grp): grp for grp in gene_to_group.values()}
        joint[rel_type] = tuple(unique_groups.values())

    return joint

214 

215 

216# the homology db stores pairwise relationship information 

class HomologyDb(SqliteDbMixin):
    """sqlite store of homology groups

    Tables
    ------
    relationship
        homology types (``homology_type``) and their integer ``id``
    homology
        one row per homology group, referencing ``relationship.id``
    member
        ``(gene_id, homology_id)`` rows assigning genes to homology groups

    The ``related_groups`` view joins all three tables.
    """

    table_names = "homology", "relationship", "member"

    _relationship_schema = {
        "homology_type": "TEXT",
        "id": "INTEGER PRIMARY KEY",
    }
    _homology_schema = {
        "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
        "relationship_id": "INTEGER",
    }
    _member_schema = {
        "gene_id": "TEXT",  # stableid of gene, defined by Ensembl
        "homology_id": "INTEGER",
        "PRIMARY KEY": ("gene_id", "homology_id"),
    }

    _index_columns = {
        "homology": ("relationship_id",),
        "relationship": ("homology_type",),
        "member": ("gene_id", "homology_id"),
    }

    def __init__(self, source=":memory:"):
        self.source = source
        # cache of {homology_type: relationship.id}
        self._relationship_types = {}
        # table setup is provided by SqliteDbMixin
        self._init_tables()
        self._create_views()

    def _create_views(self):
        """define views to simplify queries"""
        # we want to be able to query for all ortholog groups of a
        # particular type. For example, get all groups of IDs of
        # type one-to-one orthologs
        sql = """
        CREATE VIEW IF NOT EXISTS related_groups AS
        SELECT r.homology_type as homology_type,
               r.id as relationship_id,
               h.id as homology_id , m.gene_id as gene_id
        FROM homology h JOIN relationship r ON h.relationship_id = r.id
        JOIN member as m ON m.homology_id = h.id
        """
        self._execute_sql(sql)

    def _make_relationship_type_id(self, rel_type: str) -> int:
        """returns the relationship.id value for rel_type, creating it if needed

        A relationship row already present in the database (e.g. when source
        is an existing file) is reused, instead of inserting a duplicate
        that would split that type's groups across two relationship ids.
        """
        if rel_type not in self._relationship_types:
            sql = (
                "SELECT id FROM relationship WHERE homology_type = ? "
                "ORDER BY id LIMIT 1"
            )
            row = self.db.execute(sql, (rel_type,)).fetchone()
            if row is None:
                sql = "INSERT INTO relationship(homology_type) VALUES (?) RETURNING id"
                row = self.db.execute(sql, (rel_type,)).fetchone()
            self._relationship_types[rel_type] = row[0]
        return self._relationship_types[rel_type]

    def _get_homology_group_id(
        self, *, relationship_id: int, gene_ids: typing.Optional[tuple[str]] = None
    ) -> int:
        """returns a homology.id for this relationship id

        If gene_ids is given and any of them already belongs to a group with
        this relationship id, that group's id is returned. Otherwise a new
        homology row is created.

        NOTE(review): when gene_ids span several existing groups only the
        first match (LIMIT 1) is reused; those groups are not merged here.
        """
        if gene_ids is None:
            sql = "INSERT INTO homology(relationship_id) VALUES (?) RETURNING id"
            return self.db.execute(sql, (relationship_id,)).fetchone()[0]

        # check if gene_ids exist
        id_placeholders = ",".join("?" * len(gene_ids))
        sql = f"""
        SELECT r.homology_id as homology_id
        FROM related_groups r
        WHERE r.relationship_id = ? AND r.gene_id IN ({id_placeholders})
        LIMIT 1
        """
        result = self.db.execute(sql, (relationship_id,) + gene_ids).fetchone()
        if result is None:
            return self._get_homology_group_id(relationship_id=relationship_id)

        return result[0]

    def add_records(
        self,
        *,
        records: typing.Sequence[homolog_group],
        relationship_type: str,
    ) -> None:
        """inserts homology data from records

        Parameters
        ----------
        records
            a sequence of homolog group instances, all with the same
            relationship type
        relationship_type
            the relationship type

        Raises
        ------
        ValueError
            if a record's relationship differs from relationship_type
        """
        assert relationship_type is not None
        rel_type_id = self._make_relationship_type_id(relationship_type)
        # we now iterate over the homology groups
        # we get a new homology id, then add all genes for that group
        # using the IGNORE to skip duplicates
        sql = "INSERT OR IGNORE INTO member(gene_id,homology_id) VALUES (?, ?)"
        for group in records:
            if group.relationship != relationship_type:
                raise ValueError(f"{group.relationship=} != {relationship_type=}")

            homology_id = self._get_homology_group_id(
                relationship_id=rel_type_id, gene_ids=tuple(group.gene_ids)
            )
            values = [(gene_id, homology_id) for gene_id in group.gene_ids]
            self.db.executemany(sql, values)
        self.db.commit()

    def get_related_to(
        self, *, gene_id: str, relationship_type: str
    ) -> homolog_group | tuple:
        """return genes with relationship type to gene_id

        Returns an empty tuple if gene_id has no group of that type.
        """
        sql = """
        SELECT r.homology_id as homology_id
        FROM related_groups r
        WHERE r.homology_type = ? AND r.gene_id = ?
        """
        homology_id = self._execute_sql(sql, (relationship_type, gene_id)).fetchone()
        if not homology_id:
            return ()
        homology_id = homology_id["homology_id"]
        sql = """
        SELECT GROUP_CONCAT(r.gene_id) as gene_ids
        FROM related_groups r
        WHERE r.homology_id = ?
        """
        result = self._execute_sql(sql, (homology_id,)).fetchone()
        # GROUP_CONCAT joins with "," by default
        return homolog_group(
            relationship=relationship_type,
            gene_ids=set(result["gene_ids"].split(",")),
            source=gene_id,
        )

    def get_related_groups(
        self, relationship_type: str
    ) -> typing.Sequence[homolog_group]:
        """returns all groups of relationship type"""
        sql = """
        SELECT GROUP_CONCAT(r.gene_id) as gene_ids
        FROM related_groups r
        WHERE r.homology_type = ?
        GROUP BY r.homology_id
        """
        return [
            homolog_group(
                relationship=relationship_type,
                gene_ids=set(group["gene_ids"].split(",")),
            )
            for group in self._execute_sql(sql, (relationship_type,)).fetchall()
        ]

    def num_records(self):
        """returns the number of rows in the member table"""
        return list(
            self._execute_sql("SELECT COUNT(*) as count FROM member").fetchone()
        )[0]

369 

370 

def load_homology_db(
    *,
    path: PathType,
) -> HomologyDb:
    """returns a HomologyDb whose source is path"""
    homology_db = HomologyDb(source=path)
    return homology_db

376 

377 

@define_app(app_type=LOADER)
class load_homologies:
    """loads an Ensembl homology tsv file and groups related genes

    Only rows whose species and homology_species are both in
    allowed_species are retained.
    """

    def __init__(self, allowed_species: typing.Iterable[str]):
        # a set is required by the subset test in _matching_species; copying
        # also accepts lists/tuples and isolates us from caller mutation
        self._allowed_species = set(allowed_species)
        # map the Ensembl columns to HomologyDb columns

        self.src_cols = (
            "homology_type",
            "species",
            "gene_stable_id",
            "homology_species",
            "homology_gene_stable_id",
        )
        self.dest_col = (
            "relationship",
            "species_1",
            "gene_id_1",
            "species_2",
            "gene_id_2",
        )
        self._reader = FilteringParser(
            row_condition=self._matching_species, columns=self.src_cols, sep="\t"
        )

    def _matching_species(self, row):
        """True when both species columns of row are allowed

        assumes row follows src_cols order (index 1 species, 3
        homology_species) — TODO confirm against FilteringParser
        """
        return {row[1], row[3]} <= self._allowed_species

    def main(self, path: IdentifierType) -> SerialisableType:
        """returns grouped_related output for the homology records in path"""
        parser = self._reader(iter_splitlines(path, chunk_size=500_000))
        header = next(parser)
        assert list(header) == list(self.src_cols), (header, self.src_cols)
        # (homology_type, gene_stable_id, homology_gene_stable_id)
        return grouped_related((row[0], row[2], row[4]) for row in parser)

410 

411 

@define_app
class collect_seqs:
    """given a config and homolog group, loads genome instances on demand
    and extracts sequences

    Genomes are cached per species on the instance. For every gene, the
    first canonical CDS is sliced out; genes without a CDS are skipped.
    """

    def __init__(
        self,
        config: InstalledConfig,
        make_seq_name: typing.Optional[typing.Callable] = None,
        verbose: bool = False,
    ):
        self._config = config
        self._genomes = {}
        self._namer = make_seq_name
        self._verbose = verbose

    def main(self, homologs: homolog_group) -> UnalignedSeqsType:
        """returns the CDS sequences for homologs, or NotCompleted if none"""
        namer = self._namer
        collected = []
        for species, gene_names in homologs.species_ids().items():
            try:
                genome = self._genomes[species]
            except KeyError:
                genome = self._genomes[species] = load_genome(
                    config=self._config, species=species
                )

            for gene_name in gene_names:
                features = list(genome.get_gene_cds(name=gene_name, is_canonical=True))
                if not features:
                    if self._verbose:
                        print(f"no cds for {gene_name}")
                    continue

                first = features[0]
                seq = first.get_slice()
                if namer is None:
                    seq.name = f"{species}-{gene_name}"
                else:
                    seq.name = namer(first)
                seq.info["species"] = species
                seq.info["name"] = gene_name
                # drop the annotation db reference so the genome is not
                # captured when this app is run in parallel
                seq.annotation_db = None
                collected.append(seq)

        if collected:
            return make_unaligned_seqs(data=collected, moltype="dna")

        return NotCompleted(
            type="FAIL", origin=self, message=f"no CDS for {homologs}"
        )