Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_db_base.py: 97%
90 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-06-12 16:31 -0400
« prev ^ index » next coverage.py v7.5.1, created at 2024-06-12 16:31 -0400
1import dataclasses
2import sqlite3
4from ensembl_lite._util import PathType, SerialisableMixin
# type alias for a (sql statement, corresponding values) pair
ReturnType = tuple[str, tuple]  # the sql statement and corresponding values
def _make_table_sql(
    table_name: str,
    columns: dict,
) -> str:
    """makes the SQL for creating a table

    Parameters
    ----------
    table_name : str
        name of the table
    columns : dict
        {<column name>: <column SQL type>, ...}. An optional
        ``"PRIMARY KEY"`` entry gives the primary key column name(s),
        either as a single string or as a sequence of strings.
        This mapping is NOT modified.

    Returns
    -------
    str
        a ``CREATE TABLE IF NOT EXISTS`` statement
    """
    # Work on a copy: callers may pass a schema dict that is shared (e.g. a
    # class attribute), and popping from it would drop the primary key for
    # every later call.
    columns = dict(columns)
    primary_key = columns.pop("PRIMARY KEY", None)
    columns_types = ", ".join(f"{name} {ctype}" for name, ctype in columns.items())
    if primary_key:
        if isinstance(primary_key, str):
            # a bare string is one column name, not a sequence of names
            primary_key = (primary_key,)
        columns_types = f"{columns_types}, PRIMARY KEY ({','.join(primary_key)})"
    return f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_types})"
class SqliteDbMixin(SerialisableMixin):
    """Mixin providing a lazily opened SQLite connection and query helpers.

    Attributes read by this mixin that subclasses are expected to supply
    (inferred from usage below; confirm against subclasses):

    - ``table_name``: table queried by ``num_records`` / ``get_distinct``
    - ``_init_vals``: dict of constructor arguments, used by pickling and repr
    - ``_<table name>_schema``: ``{column: SQL type, ...}`` dicts from which
      ``_init_tables`` creates tables
    - ``_index_columns``: ``{table name: [column, ...]}`` used by
      ``make_indexes``
    """

    table_name = None
    _db = None
    source: PathType = ":memory:"

    def __getstate__(self):
        return {**self._init_vals}

    def __setstate__(self, state):
        # this will reset connections to read only db's
        # (the object is rebuilt through its constructor rather than by
        # unpickling connection state)
        obj = self.__class__(**state)
        self.__dict__.update(obj.__dict__)

    def __repr__(self):
        name = self.__class__.__name__
        total_records = len(self)
        args = ", ".join(
            f"{k}={repr(v) if isinstance(v, str) else v}"
            for k, v in self._init_vals.items()
            if k != "data"
        )
        return f"{name}({args}, total_records={total_records})"

    def __len__(self):
        return self.num_records()

    def __eq__(self, other):
        # Equal only when both objects share the very same connection.
        # NOTE: defining __eq__ without __hash__ makes instances unhashable.
        return isinstance(other, self.__class__) and other.db is self.db

    def _init_tables(self) -> None:
        """Open the connection if needed and create every schema table."""
        db = self.db
        db.row_factory = sqlite3.Row

        # try and reduce memory usage (negative value => size in KiB, ~2 MiB)
        db.execute("PRAGMA cache_size = -2048;")

        # A bit of magic.
        # Assumes schema attributes named as `_<table name>_schema`
        for attr_name in dir(self):
            if not attr_name.endswith("_schema"):
                continue
            table_name = "_".join(attr_name.split("_")[1:-1])
            schema = getattr(self, attr_name)
            # pass a copy: the schema is typically a class attribute shared by
            # all instances and must not be altered by SQL generation
            db.execute(_make_table_sql(table_name, dict(schema)))

    @property
    def db(self) -> sqlite3.Connection:
        """The SQLite connection to ``source``, opened on first access."""
        if self._db is None:
            self._db = sqlite3.connect(
                self.source,
                detect_types=sqlite3.PARSE_DECLTYPES,
                # allow use from threads other than the creating one
                check_same_thread=False,
            )
            self._db.row_factory = sqlite3.Row

        return self._db

    def _execute_sql(self, cmnd: str, values=None) -> sqlite3.Cursor:
        """Execute ``cmnd`` with ``values`` inside a transaction.

        Returns the cursor so callers can fetch results.
        """
        with self.db:
            # context manager commits on success and rolls back on error
            cursor = self.db.cursor()
            cursor.execute(cmnd, values or [])
        return cursor

    def num_records(self):
        """Number of rows in ``table_name``."""
        sql = f"SELECT COUNT(*) as count FROM {self.table_name}"
        return self._execute_sql(sql).fetchone()[0]

    def close(self):
        """Commit pending changes and close the connection, if one is open."""
        if self._db is None:
            # never connected: nothing to commit, don't open a connection
            # only to close it again
            return
        self._db.commit()
        self._db.close()

    def get_distinct(self, column: str) -> set[str]:
        """Distinct values of ``column`` in ``table_name``.

        NOTE: ``column`` is interpolated directly into the SQL, so it must be
        a trusted column name.
        """
        sql = f"SELECT DISTINCT {column} from {self.table_name}"
        return {r[column] for r in self._execute_sql(sql).fetchall()}

    def make_indexes(self):
        """adds db indexes for core attributes

        Index names are qualified by table name. SQLite index names are
        unique per database, so an unqualified ``<col>_index`` for a column
        shared by two tables would make the second ``CREATE INDEX IF NOT
        EXISTS`` a silent no-op, leaving that table unindexed.
        """
        for table_name, columns in self._index_columns.items():
            for col in columns:
                index = f"{table_name}_{col}_index"
                self._execute_sql(
                    f"CREATE INDEX IF NOT EXISTS {index} on {table_name}({col})"
                )
126# HDF5 base class
@dataclasses.dataclass
class Hdf5Mixin(SerialisableMixin):
    """HDF5 sequence data storage

    Base class that manages the lifetime of a file handle held on ``_file``.
    Subclasses presumably open the handle and set ``_is_open``; this class
    only flushes and closes it. Subclasses must also provide ``mode`` and
    ``_init_vals``, which are read by the pickling hooks.
    """

    # file handle; class-level default until an instance sets its own
    _file = None
    # whether ``_file`` is currently open and owned by this instance
    _is_open = False

    def __getstate__(self):
        # Unpickling rebuilds the object through its constructor (see
        # __setstate__), so instances in write/append modes are refused.
        if {"w", "a"}.intersection(self.mode):
            raise NotImplementedError(f"pickling not supported for mode={self.mode!r}")
        return self._init_vals.copy()

    def __setstate__(self, state):
        fresh = self.__class__(**state)
        self.__dict__.update(fresh.__dict__)
        # ``self`` now shares ``fresh``'s file handle. Because this class has a
        # __del__, detach ``fresh`` from the handle so that garbage-collecting
        # it does not close the file out from under ``self``.
        fresh._is_open = False
        fresh._file = None

    def __del__(self):
        handle = self._file
        if handle is not None:
            # only flush a handle this instance still considers open
            if self._is_open:
                handle.flush()
            handle.close()
        self._is_open = False

    def close(self):
        """Flush and close the file if it is open; otherwise do nothing."""
        if not self._is_open:
            return
        self._file.flush()
        self._file.close()
        self._is_open = False