Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_db_base.py: 96%

80 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2024-03-25 13:40 +1100

1import dataclasses 

2import inspect 

3import sqlite3 

4 

5import numpy 

6 

7from ensembl_lite._util import blosc_compress_it, blosc_decompress_it 

8 

9 

10@dataclasses.dataclass(slots=True) 

11class _compressed_array_proxy: 

12 """this exists only to automate conversion of a customised sqlite type""" 

13 

14 array: numpy.ndarray 

15 

16 

# Return shape of SQL-building helpers: (sql statement, values to bind).
ReturnType = tuple[str, tuple]

# Shared blosc (de)compression callables, used by the sqlite adapter and
# converter functions for compressed arrays.
_compressor, _decompressor = blosc_compress_it(), blosc_decompress_it()

21 

22 

def compressed_array_to_sqlite(data):
    """sqlite3 adapter: turn a ``_compressed_array_proxy`` into compressed bytes.

    The wrapped array is cast to int32, serialised with ``tobytes`` (native
    byte order, C order) and passed through the module's blosc compressor.
    """
    # TODO: change to cogent3 approach for cross-platform support
    raw_bytes = data.array.astype(numpy.int32).tobytes()
    return _compressor(raw_bytes)

26 

27 

def decompressed_sqlite_to_array(data):
    """sqlite3 converter: decompress stored bytes back into a two-column array.

    Parameters
    ----------
    data
        the raw value sqlite hands to a converter registered for the
        ``compressed_array`` declared type

    Returns
    -------
    numpy.ndarray
        int32 array of shape ``(n, 2)``. An empty payload yields shape
        ``(0, 2)``, so empty and non-empty values share the same
        column count.

    Raises
    ------
    ValueError
        if the decompressed payload does not hold an even number of int32
        values (it cannot be split into rows of two)
    """
    # TODO: change to cogent3 approach for cross-platform support; the bytes
    # are read as native-endian int32, mirroring compressed_array_to_sqlite
    flat = numpy.frombuffer(_decompressor(data), dtype=numpy.int32)
    # reshape unconditionally: -1 infers the row count, including 0 rows,
    # so an empty payload becomes (0, 2) rather than a 1-D (0,) array
    return flat.reshape(-1, 2)

35 

36 

# Register how sqlite3 stores and loads compressed arrays.
# Adapters are chosen by Python type, and converters by the declared column
# type (enabled by connecting with detect_types=sqlite3.PARSE_DECLTYPES).
# Both keys therefore need to be specific to this tool, which is why arrays
# are wrapped in the private _compressed_array_proxy type instead of
# registering numpy.ndarray directly.
# NOTE(review): the declared type name "compressed_array" carries no library
# prefix, despite the intent that it be unique to this tool. Presumably
# existing database schemas already use this name, so renaming it would need
# a migration — confirm before changing.
sqlite3.register_adapter(
    _compressed_array_proxy,
    compressed_array_to_sqlite,
)
sqlite3.register_converter(
    "compressed_array",
    decompressed_sqlite_to_array,
)

43 

44 

def _make_table_sql(
    table_name: str,
    columns: dict,
) -> str:
    """Build a ``CREATE TABLE IF NOT EXISTS`` statement.

    Parameters
    ----------
    table_name : str
        name of the table to create
    columns : dict
        ordered mapping of ``{<column name>: <column SQL type>, ...}``

    Returns
    -------
    str
        the SQL statement. Names are interpolated verbatim (no quoting), so
        callers must supply trusted identifiers.
    """
    column_defs = ", ".join(
        f"{column} {sql_type}" for column, sql_type in columns.items()
    )
    return f"CREATE TABLE IF NOT EXISTS {table_name} ({column_defs});"

64 

65 

class SqliteDbMixin:
    """Shared sqlite plumbing for database-backed classes.

    Subclasses are expected to set ``table_name`` (the table used by
    ``num_records`` / ``get_distinct``) and ``source`` (the database argument
    handed to ``sqlite3.connect``). Tables are declared as attributes named
    ``_<table name>_schema``, each a ``{column name: SQL type}`` dict, and are
    created by ``_init_tables``.

    Note: ``__eq__`` is defined without ``__hash__``, so instances are
    unhashable.
    """

    table_name = None
    _db = None  # sqlite3.Connection once opened; created lazily by ``db``
    source = None

    def __new__(cls, *args, **kwargs):
        # Capture the constructor arguments (with defaults applied) so that
        # pickling can rebuild the instance via the constructor and __repr__
        # can display them.
        obj = object.__new__(cls)
        init_sig = inspect.signature(cls.__init__)
        bargs = init_sig.bind_partial(cls, *args, **kwargs)
        bargs.apply_defaults()
        init_vals = bargs.arguments
        init_vals.pop("self", None)
        obj._init_vals = init_vals
        return obj

    def __getstate__(self):
        # only the constructor arguments are pickled, never the connection
        return {**self._init_vals}

    def __setstate__(self, state):
        # re-running the constructor re-opens connections (e.g. to read-only dbs)
        obj = self.__class__(**state)
        self.__dict__.update(obj.__dict__)

    def __repr__(self):
        name = self.__class__.__name__
        total_records = len(self)
        # "data" is excluded from the repr
        args = ", ".join(
            f"{k}={repr(v) if isinstance(v, str) else v}"
            for k, v in self._init_vals.items()
            if k != "data"
        )
        return f"{name}({args}, total_records={total_records})"

    def __len__(self):
        return self.num_records()

    def __eq__(self, other):
        # equal only when both share the very same connection object; note
        # that accessing ``db`` opens a connection on either side if needed
        return isinstance(other, self.__class__) and other.db is self.db

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to ``self.source`` with this mixin's settings.

        Declared column types are parsed (so registered converters apply),
        the connection may be used from other threads, and rows are
        returned as ``sqlite3.Row``.
        """
        conn = sqlite3.connect(
            self.source,
            detect_types=sqlite3.PARSE_DECLTYPES,
            check_same_thread=False,
        )
        conn.row_factory = sqlite3.Row
        return conn

    def _init_tables(self) -> None:
        """(Re)connect and create every table declared by a ``*_schema`` attribute."""
        # always opens a fresh connection, replacing any existing one
        self._db = self._connect()

        # A bit of magic: assumes schema attributes are named
        # `_<table name>_schema`; the table name is the text between the
        # leading component and the trailing "schema".
        for attr in dir(self):
            if not attr.endswith("_schema"):
                continue
            table_name = "_".join(attr.split("_")[1:-1])
            schema = getattr(self, attr)
            self._execute_sql(_make_table_sql(table_name, schema))

    @property
    def db(self) -> sqlite3.Connection:
        """The connection to ``source``, opened on first access."""
        if self._db is None:
            self._db = self._connect()
        return self._db

    def _execute_sql(self, cmnd: str, values=None) -> sqlite3.Cursor:
        """Execute one statement within a transaction and return its cursor."""
        with self.db:
            # the connection context manager commits on success and rolls
            # back if the statement raises
            cursor = self.db.cursor()
            cursor.execute(cmnd, values or [])
            return cursor

    def num_records(self):
        """Number of rows in ``self.table_name``."""
        sql = f"SELECT COUNT(*) as count FROM {self.table_name}"
        return self._execute_sql(sql).fetchone()[0]

    def close(self):
        """Commit and close the connection, if one has been opened."""
        if self._db is None:
            # never connected: going through ``db`` here would open a brand-new
            # connection (possibly creating a database file) only to close it
            return
        self._db.commit()
        self._db.close()

    def get_distinct(self, column: str) -> set[str]:
        """Distinct values of ``column`` in ``self.table_name``.

        ``column`` is interpolated into the SQL unquoted, so it must be a
        trusted column name.
        """
        sql = f"SELECT DISTINCT {column} from {self.table_name}"
        return {r[column] for r in self._execute_sql(sql).fetchall()}