Coverage for /Users/gavin/repos/EnsemblLite/src/ensembl_lite/_db_base.py: 97%

90 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-06-12 16:31 -0400

1import dataclasses 

2import sqlite3 

3 

4from ensembl_lite._util import PathType, SerialisableMixin 

5 

6 

# Pair of (SQL statement, its corresponding values).
ReturnType = tuple[str, tuple]

8 

9 

10def _make_table_sql( 

11 table_name: str, 

12 columns: dict, 

13) -> str: 

14 """makes the SQL for creating a table 

15 

16 Parameters 

17 ---------- 

18 table_name : str 

19 name of the table 

20 columns : dict 

21 {<column name>: <column SQL type>, ...} 

22 

23 Returns 

24 ------- 

25 str 

26 """ 

27 primary_key = columns.pop("PRIMARY KEY", None) 

28 columns_types = ", ".join([f"{name} {ctype}" for name, ctype in columns.items()]) 

29 if primary_key: 

30 columns_types = f"{columns_types}, PRIMARY KEY ({','.join(primary_key)})" 

31 sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_types})" 

32 return sql 

33 

34 

class SqliteDbMixin(SerialisableMixin):
    """Base mixin for classes backed by a single sqlite3 database.

    The connection to ``source`` is opened lazily by the ``db`` property,
    with ``PARSE_DECLTYPES`` type detection, ``check_same_thread=False`` and
    ``sqlite3.Row`` row objects.

    Attributes a subclass is expected to supply (not defined here):
    ``_init_vals`` (constructor arguments, used for pickling and ``repr``),
    ``_index_columns`` (``{table name: [column, ...]}``, used by
    ``make_indexes``) and any number of ``_<table name>_schema`` dicts
    consumed by ``_init_tables``. ``table_name`` must be set for
    ``num_records`` / ``get_distinct``.
    """

    table_name = None
    _db = None
    source: PathType = ":memory:"

    def __getstate__(self):
        # pickle only the constructor arguments, never the live connection
        return {**self._init_vals}

    def __setstate__(self, state):
        # rebuild through the constructor; this will reset connections to
        # read only db's
        obj = self.__class__(**state)
        self.__dict__.update(obj.__dict__)

    def __repr__(self):
        name = self.__class__.__name__
        total_records = len(self)
        # every constructor argument except "data" is shown
        args = ", ".join(
            f"{k}={repr(v) if isinstance(v, str) else v}"
            for k, v in self._init_vals.items()
            if k != "data"
        )
        return f"{name}({args}, total_records={total_records})"

    def __len__(self):
        return self.num_records()

    def __eq__(self, other):
        # equal only when both objects share the very same connection object
        return isinstance(other, self.__class__) and other.db is self.db

    def _init_tables(self) -> None:
        """connects (if needed) and creates every table declared by a schema

        A bit of magic: assumes schema attributes are named
        ``_<table name>_schema`` and hold ``{column: SQL type}`` dicts as
        understood by ``_make_table_sql``.
        """
        # reuse an existing connection or open one via the db property, so
        # the connection settings live in a single place
        db = self.db
        db.row_factory = sqlite3.Row

        cursor = db.cursor()
        # try and reduce memory usage
        cursor.execute("PRAGMA cache_size = -2048;")

        for attr in dir(self):
            if not attr.endswith("_schema"):
                continue
            table_name = "_".join(attr.split("_")[1:-1])
            # pass a copy so the (typically class-level) schema dict is never
            # modified while building the SQL
            schema = dict(getattr(self, attr))
            cursor.execute(_make_table_sql(table_name, schema))

    @property
    def db(self) -> sqlite3.Connection:
        """the sqlite3 connection to ``source``, opened on first access"""
        if self._db is None:
            self._db = sqlite3.connect(
                self.source,
                detect_types=sqlite3.PARSE_DECLTYPES,
                check_same_thread=False,
            )
            self._db.row_factory = sqlite3.Row

        return self._db

    def _execute_sql(self, cmnd: str, values=None) -> sqlite3.Cursor:
        """executes cmnd (with optional bound values) and returns the cursor"""
        with self.db:
            # context manager ensures safe transactions
            cursor = self.db.cursor()
            cursor.execute(cmnd, values or [])
            return cursor

    def num_records(self):
        """number of rows in ``self.table_name``"""
        sql = f"SELECT COUNT(*) as count FROM {self.table_name}"
        return self._execute_sql(sql).fetchone()[0]

    def close(self):
        """commits and closes the connection, if one was ever opened"""
        if self._db is None:
            # Never connected: going through self.db here would open a new
            # connection (creating an empty file for a path source) only to
            # close it again.
            return
        self._db.commit()
        self._db.close()

    def get_distinct(self, column: str) -> set[str]:
        """the distinct values of ``column`` in ``self.table_name``"""
        sql = f"SELECT DISTINCT {column} from {self.table_name}"
        return {r[column] for r in self._execute_sql(sql).fetchall()}

    def make_indexes(self):
        """adds db indexes for core attributes

        One index is created per column listed in ``self._index_columns``
        (``{table name: [column, ...]}``).
        """
        sql = "CREATE INDEX IF NOT EXISTS %(index)s on %(table)s(%(col)s)"
        for table_name, columns in self._index_columns.items():
            for col in columns:
                # Index names share a single namespace per SQLite database:
                # an unqualified "<col>_index" for a column present in two
                # tables would turn the second CREATE INDEX IF NOT EXISTS
                # into a silent no-op, leaving that table unindexed.
                index = f"{table_name}_{col}_index"
                self._execute_sql(sql % dict(table=table_name, index=index, col=col))

124 

125 

# HDF5 base class
@dataclasses.dataclass
class Hdf5Mixin(SerialisableMixin):
    """HDF5 sequence data storage

    Subclasses are expected to provide ``mode`` (the file open mode) and
    ``_init_vals`` (constructor arguments, used for pickling), and to set
    ``_file`` / ``_is_open`` when they open their file (TODO confirm against
    the subclasses).
    """

    # unannotated, so these are plain class-level defaults rather than
    # dataclass fields
    _file = None
    _is_open = False

    def __getstate__(self):
        # Unpickling reopens the file with the same mode (see __setstate__),
        # so any mode that can modify or create the file is rejected: write
        # ("w", "w-"), append ("a"), exclusive create ("x") and read/write
        # ("r+").
        if set(self.mode) & {"w", "a", "x", "+"}:
            raise NotImplementedError(f"pickling not supported for mode={self.mode!r}")
        return self._init_vals.copy()

    def __setstate__(self, state):
        obj = self.__class__(**state)
        self.__dict__.update(obj.__dict__)
        # because we have a __del__ method, and self attributes point to
        # attributes on obj, we need to modify obj state so that garbage
        # collection does not screw up self
        obj._is_open = False
        obj._file = None

    def __del__(self):
        # flush only while still marked open, but close whenever a file
        # object is held; close() below does not reset _file, so this may
        # close an already-closed file (assumes that is harmless — TODO
        # confirm for the file type used by subclasses)
        if self._is_open and self._file is not None:
            self._file.flush()
        if self._file is not None:
            self._file.close()
        self._is_open = False

    def close(self):
        """flushes and closes the file if it is open"""
        if self._is_open:
            self._file.flush()
            self._file.close()
        self._is_open = False