Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_db_base.py: 96%
80 statements
« prev ^ index » next coverage.py v7.2.3, created at 2024-03-25 13:40 +1100
« prev ^ index » next coverage.py v7.2.3, created at 2024-03-25 13:40 +1100
1import dataclasses
2import inspect
3import sqlite3
5import numpy
7from ensembl_lite._util import blosc_compress_it, blosc_decompress_it
@dataclasses.dataclass(slots=True)
class _compressed_array_proxy:
    """this exists only to automate conversion of a customised sqlite type

    sqlite3 adapters are registered against a Python type, so the array is
    wrapped in this dedicated class to give the adapter a type that is
    unique to this library.
    """

    # the wrapped array
    array: numpy.ndarray
ReturnType = tuple[str, tuple]  # the sql statement and corresponding values

# module-level compress / decompress callables, created once and shared by
# the sqlite adapter and converter functions defined in this module
_compressor = blosc_compress_it()
_decompressor = blosc_decompress_it()
def compressed_array_to_sqlite(data):
    """converts a proxied array into compressed bytes for storage in sqlite

    The array held on ``data.array`` is cast to int32, serialised to raw
    bytes and then passed through the module-level compressor.
    """
    # todo change to cogent3 approach for cross-platform support
    as_int32 = data.array.astype(numpy.int32)
    raw_bytes = as_int32.tobytes()
    return _compressor(raw_bytes)
def decompressed_sqlite_to_array(data):
    """converts compressed bytes from sqlite back into an int32 array

    A non-empty result is reshaped to two columns. An empty result is
    returned unchanged, as a 1D array of length zero.
    """
    # todo change to cogent3 approach for cross-platform support
    raw_bytes = _decompressor(data)
    flat = numpy.frombuffer(raw_bytes, dtype=numpy.int32)
    if not len(flat):
        return flat

    num_rows = flat.shape[0] // 2
    return flat.reshape((num_rows, 2))
# Register the conversion functions with sqlite.
# These conversion functions are tied to a type, so that type needs to be
# unique to this tool. The best way to ensure that is to use a
# <libname_type> name and to wrap a fundamental type with a proxy.
sqlite3.register_adapter(_compressed_array_proxy, compressed_array_to_sqlite)
sqlite3.register_converter("compressed_array", decompressed_sqlite_to_array)
def _make_table_sql(
    table_name: str,
    columns: dict,
) -> str:
    """builds the SQL statement that creates a table if it is missing

    Parameters
    ----------
    table_name : str
        name of the table
    columns : dict
        {<column name>: <column SQL type>, ...}

    Returns
    -------
    str
        a ``CREATE TABLE IF NOT EXISTS`` statement listing the columns in
        the iteration order of ``columns``
    """
    column_defs = (f"{col} {sql_type}" for col, sql_type in columns.items())
    joined = ", ".join(column_defs)
    return f"CREATE TABLE IF NOT EXISTS {table_name} ({joined});"
class SqliteDbMixin:
    """mixin providing sqlite3 connection handling and basic queries

    Notes
    -----
    The class attributes ``source`` (passed to ``sqlite3.connect``) and
    ``table_name`` (queried by ``num_records`` and ``get_distinct``) are
    ``None`` here and need to be given values before those methods are used.

    Table schemas are declared as attributes named ``_<table name>_schema``
    whose value is a ``{<column name>: <column SQL type>}`` dict.
    ``_init_tables`` creates a table for each of them.

    The arguments used to construct an instance are recorded on
    ``_init_vals``. Pickling stores only those values and unpickling calls
    the constructor again with them.
    """

    table_name = None
    _db = None
    source = None

    def __new__(cls, *args, **kwargs):
        """creates the instance and records its constructor arguments"""
        obj = object.__new__(cls)
        init_sig = inspect.signature(cls.__init__)
        # cls stands in for ``self`` so the remaining arguments are bound
        # to the parameter names of __init__
        bargs = init_sig.bind_partial(cls, *args, **kwargs)
        bargs.apply_defaults()
        init_vals = bargs.arguments
        init_vals.pop("self", None)
        obj._init_vals = init_vals
        return obj

    def __getstate__(self):
        """returns a copy of the recorded constructor arguments"""
        return {**self._init_vals}

    def __setstate__(self, state):
        """rebuilds the instance by calling the constructor with ``state``"""
        # this will reset connections to read only db's
        obj = self.__class__(**state)
        self.__dict__.update(obj.__dict__)

    def __repr__(self):
        name = self.__class__.__name__
        total_records = len(self)
        # the "data" argument is left out of the displayed arguments
        args = ", ".join(
            f"{k}={repr(v) if isinstance(v, str) else v}"
            for k, v in self._init_vals.items()
            if k != "data"
        )
        return f"{name}({args}, total_records={total_records})"

    def __len__(self):
        return self.num_records()

    def __eq__(self, other):
        # equal only when both share the very same connection object
        return isinstance(other, self.__class__) and other.db is self.db

    def _open_connection(self) -> sqlite3.Connection:
        """returns a new connection to ``source``

        Declared column types are used to select registered converters
        (``PARSE_DECLTYPES``) and rows are returned as ``sqlite3.Row``.
        """
        conn = sqlite3.connect(
            self.source,
            detect_types=sqlite3.PARSE_DECLTYPES,
            check_same_thread=False,
        )
        conn.row_factory = sqlite3.Row
        return conn

    def _init_tables(self) -> None:
        """connects to ``source`` and creates every declared table

        Always opens a new connection, replacing any held on ``_db``.
        """
        # is source an existing db
        self._db = self._open_connection()

        # A bit of magic.
        # Assumes schema attributes named as `_<table name>_schema`
        for attr_name in dir(self):
            if not attr_name.endswith("_schema"):
                continue
            # drop the leading "" and trailing "schema" components
            table_name = "_".join(attr_name.split("_")[1:-1])
            schema = getattr(self, attr_name)
            self._execute_sql(_make_table_sql(table_name, schema))

    @property
    def db(self) -> sqlite3.Connection:
        """the connection to ``source``, opened on first access"""
        if self._db is None:
            self._db = self._open_connection()

        return self._db

    def _execute_sql(self, cmnd: str, values=None) -> sqlite3.Cursor:
        """executes one statement and returns the cursor

        Parameters
        ----------
        cmnd : str
            the SQL statement
        values
            values for the statement's placeholders, nothing is bound
            if this is empty or None
        """
        with self.db:
            # context manager ensures safe transactions
            cursor = self.db.cursor()
            cursor.execute(cmnd, values or [])
            return cursor

    def num_records(self):
        """returns the number of rows in ``table_name``"""
        sql = f"SELECT COUNT(*) as count FROM {self.table_name}"
        return self._execute_sql(sql).fetchone()[0]

    def close(self):
        """commits pending changes and closes the connection, if one was opened"""
        conn = self._db
        if conn is None:
            # going via the ``db`` property would open a connection here
            # only to close it straight away
            return

        conn.commit()
        conn.close()

    def get_distinct(self, column: str) -> set[str]:
        """returns the distinct values of ``column`` in ``table_name``"""
        sql = f"SELECT DISTINCT {column} from {self.table_name}"
        return {r[column] for r in self._execute_sql(sql).fetchall()}