Coverage for src/dataknobs_data/backends/postgres_vector.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""PostgreSQL vector support utilities.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from typing import TYPE_CHECKING, Any 

7 

8if TYPE_CHECKING: 

9 import asyncpg 

10 import numpy as np 

11 

12logger = logging.getLogger(__name__) 

13 

14 

def check_pgvector_extension_sync(db: Any) -> bool:
    """Check if pgvector extension is installed (sync version).

    Args:
        db: PostgresDB connection object

    Returns:
        True if pgvector is installed, False otherwise
    """
    exists_sql = """
        SELECT EXISTS (
            SELECT 1 FROM pg_extension WHERE extname = 'vector'
        ) as exists
    """
    try:
        frame = db.query(exists_sql)
    except Exception as exc:
        logger.debug(f"Could not check pgvector extension: {exc}")
        return False
    if frame.empty:
        return False
    return bool(frame.iloc[0]["exists"])

34 

35 

def install_pgvector_extension_sync(db: Any) -> bool:
    """Install pgvector extension if not already installed (sync version).

    Args:
        db: PostgresDB connection object

    Returns:
        True if installation successful or already installed
    """
    try:
        if check_pgvector_extension_sync(db):
            # Nothing to do; the extension is already present.
            logger.debug("pgvector extension already installed")
        else:
            db.execute("CREATE EXTENSION IF NOT EXISTS vector")
            logger.info("Successfully installed pgvector extension")
        return True
    except Exception as exc:
        # Installation typically needs superuser rights; treat failure as
        # a soft condition and report it to the caller via the return value.
        logger.warning(f"Could not install pgvector extension: {exc}")
        return False

58 

59 

async def check_pgvector_extension(conn: asyncpg.Connection) -> bool:
    """Check if pgvector extension is installed.

    Args:
        conn: AsyncPG connection

    Returns:
        True if pgvector is installed, False otherwise
    """
    exists_sql = """
        SELECT EXISTS (
            SELECT 1 FROM pg_extension WHERE extname = 'vector'
        )
    """
    found = await conn.fetchval(exists_sql)
    return bool(found)

75 

76 

async def install_pgvector_extension(conn: asyncpg.Connection) -> bool:
    """Install pgvector extension if not already installed.

    Args:
        conn: AsyncPG connection

    Returns:
        True if installation successful or already installed
    """
    try:
        already_present = await check_pgvector_extension(conn)
        if already_present:
            logger.debug("pgvector extension already installed")
        else:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            logger.info("Successfully installed pgvector extension")
        return True
    except Exception as exc:
        # CREATE EXTENSION usually requires elevated privileges; surface
        # the failure through the boolean result rather than raising.
        logger.warning(f"Could not install pgvector extension: {exc}")
        return False

99 

100 

def get_vector_operator(metric: str) -> str:
    """Get PostgreSQL vector operator for distance metric.

    Args:
        metric: Distance metric (cosine, euclidean, inner_product)

    Returns:
        PostgreSQL operator string
    """
    normalized = metric.lower()
    if normalized in ("euclidean", "l2"):
        return "<->"  # L2 distance
    if normalized in ("inner_product", "ip"):
        return "<#>"  # Negative inner product
    # Cosine distance; also the fallback for unrecognized metrics.
    return "<=>"

118 

119 

def get_optimal_index_type(num_vectors: int) -> tuple[str, dict[str, Any]]:
    """Determine optimal index type based on dataset size.

    Args:
        num_vectors: Number of vectors in dataset

    Returns:
        Tuple of (index_type, index_parameters)
    """
    if num_vectors < 10000:
        # Small datasets: IVFFlat with fewer lists. Clamp to >= 1 because
        # pgvector rejects ivfflat indexes with lists = 0, which the
        # num_vectors // 10 heuristic would produce for < 10 vectors.
        return "ivfflat", {"lists": max(1, min(100, num_vectors // 10))}
    elif num_vectors < 1000000:
        # Medium datasets: square-root heuristic, capped at 5000 lists.
        lists = int(num_vectors ** 0.5)
        return "ivfflat", {"lists": min(lists, 5000)}
    else:
        # Large datasets: HNSW (requires pgvector 0.5.0+).
        return "hnsw", {"m": 16, "ef_construction": 200}

140 

141 

def build_vector_index_sql(
    table_name: str,
    schema_name: str,
    column_name: str,
    dimensions: int,
    metric: str = "cosine",
    index_type: str = "ivfflat",
    index_params: dict[str, Any] | None = None,
    field_name: str | None = None
) -> str:
    """Build SQL for creating a vector index.

    Args:
        table_name: Name of table
        schema_name: Schema name
        column_name: SQL expression for vector column
        dimensions: Vector dimensions
        metric: Distance metric
        index_type: Type of index (ivfflat, hnsw)
        index_params: Index-specific parameters
        field_name: Original field name for index naming

    Returns:
        SQL CREATE INDEX statement

    NOTE(review): ``dimensions`` is never referenced in this body — the
    dimension cast is expected to already be embedded in ``column_name``;
    confirm whether the parameter is still needed by callers.
    NOTE(review): ``schema_name``/``table_name``/``column_name`` are
    interpolated directly into the SQL text; callers must supply trusted,
    pre-validated identifiers.
    """
    index_params = index_params or {}

    # Determine field name for index naming
    if not field_name:
        field_name = extract_field_name(column_name)

    index_name = get_vector_index_name(table_name, field_name, metric)

    # Determine operator class based on metric; unrecognized metrics fall
    # back to cosine, mirroring get_vector_operator().
    op_class = {
        "cosine": "vector_cosine_ops",
        "euclidean": "vector_l2_ops",
        "l2": "vector_l2_ops",
        "inner_product": "vector_ip_ops",
        "ip": "vector_ip_ops",
        "dot_product": "vector_ip_ops",
    }.get(metric.lower(), "vector_cosine_ops")

    if index_type == "ivfflat":
        lists = index_params.get("lists", 100)
        # IVFFlat requires proper parentheses for functional indexes with operator class
        # The column_name should already include the dimension cast
        return f"""
            CREATE INDEX IF NOT EXISTS {index_name}
            ON {schema_name}.{table_name}
            USING ivfflat (({column_name}) {op_class})
            WITH (lists = {lists})
        """
    elif index_type == "hnsw":
        m = index_params.get("m", 16)
        ef_construction = index_params.get("ef_construction", 200)
        # HNSW index (requires pgvector 0.5.0+)
        # The column_name should already include the dimension cast
        return f"""
            CREATE INDEX IF NOT EXISTS {index_name}
            ON {schema_name}.{table_name}
            USING hnsw (({column_name}) {op_class})
            WITH (m = {m}, ef_construction = {ef_construction})
        """
    else:
        # Default to basic index
        # NOTE(review): a btree index on a vector expression supports only
        # equality/ordering, not nearest-neighbor search — presumably a
        # deliberate safe fallback for unknown index types; confirm.
        return f"""
            CREATE INDEX IF NOT EXISTS {index_name}
            ON {schema_name}.{table_name}
            USING btree ({column_name})
        """

213 

214 

def sanitize_identifier(name: str) -> str:
    """Sanitize a string to be used as a database identifier.

    Removes or replaces special characters that are not valid in identifiers.

    Args:
        name: Raw string that may contain special characters

    Returns:
        Sanitized string safe for use as identifier
    """
    import re
    # Runs of SQL operators / punctuation / whitespace become a single "_".
    special_chars = re.compile(r"[->()'\[\]:,\s]+")
    underscore_runs = re.compile(r"_+")
    collapsed = underscore_runs.sub("_", special_chars.sub("_", name))
    # Drop any underscores left dangling at either end.
    return collapsed.strip("_")

234 

235 

def extract_field_name(column_expression: str) -> str:
    """Extract field name from a column expression.

    Args:
        column_expression: SQL expression like "(data->'field'->>'value')::vector"

    Returns:
        Extracted field name or 'vector' as fallback
    """
    import re

    # Ordered from most to least specific; the first match wins.
    json_path_patterns = (
        r"data->'([^']+)'",   # data->'field'
        r"data->>'([^']+)'",  # data->>'field'
        r"\$\.([^'\"]+)",     # $.field (JSONPath)
        r"'([^']+)'",         # Any quoted string
    )
    for pattern in json_path_patterns:
        found = re.search(pattern, column_expression)
        if found:
            return found.group(1)

    # Fallback: sanitize the raw expression, defaulting to 'vector'.
    cleaned = sanitize_identifier(column_expression)
    return cleaned or "vector"

262 

263 

def get_vector_index_name(table_name: str, field_name: str, metric: str = "cosine") -> str:
    """Generate consistent index name for vector field.

    Args:
        table_name: Name of the table
        field_name: Name of the vector field (or column expression)
        metric: Distance metric

    Returns:
        Index name string
    """
    # Sanitize every component so the result is a valid identifier.
    parts = (
        sanitize_identifier(table_name),
        sanitize_identifier(field_name),
        sanitize_identifier(metric),
    )
    return "idx_{}_{}_{}".format(*parts)

281 

282 

def build_vector_column_expression(field_name: str, dimensions: int | None = None, for_index: bool = False) -> str:
    """Build SQL expression for vector column from JSON field.

    Args:
        field_name: Name of the vector field in JSON
        dimensions: Optional dimensions for casting
        for_index: Accepted for interface compatibility; the index and
            query expressions are currently identical, so this flag has
            no effect on the result.

    Returns:
        SQL expression for vector column
    """
    # VectorFields are stored as JSON objects with a 'value' key, so both
    # index creation and querying extract data->'field'->>'value'.  The
    # original had separate (but identical) branches per for_index; they
    # are collapsed here.
    dim_cast = f"({dimensions})" if dimensions else ""
    return f"(data->'{field_name}'->>'value')::vector{dim_cast}"

303 

304 

def get_vector_count_sql(schema_name: str, table_name: str, field_name: str) -> str:
    """Get SQL to count vectors in a field.

    Args:
        schema_name: Database schema
        table_name: Table name
        field_name: Vector field name

    Returns:
        SQL query string

    NOTE(review): identifiers and the field name are interpolated directly
    into the SQL text (the ``?`` here is the jsonb key-existence operator,
    not a bind placeholder); callers must supply trusted values.
    """
    return f"""
        SELECT COUNT(*) as count
        FROM {schema_name}.{table_name}
        WHERE data ? '{field_name}'
    """

321 

322 

def get_index_check_sql(schema_name: str, table_name: str, field_name: str) -> tuple[str, list[Any]]:
    """Get SQL to check if vector index exists.

    Args:
        schema_name: Database schema
        table_name: Table name
        field_name: Vector field name

    Returns:
        Tuple of (SQL query, parameters)
    """
    # Uses asyncpg-style positional placeholders ($1..$3).
    sql = """
        SELECT COUNT(*) > 0 as has_index
        FROM pg_indexes
        WHERE schemaname = $1
        AND tablename = $2
        AND indexname LIKE $3
    """
    # Matches any index whose name contains the field name; index names are
    # generated by get_vector_index_name() in this module.
    index_pattern = f"%{field_name}%"
    return sql, [schema_name, table_name, index_pattern]

343 

344 

def format_vector_for_postgres(vector: np.ndarray | list[float]) -> str:
    """Format vector for PostgreSQL vector column.

    Args:
        vector: Numpy array or list of floats

    Returns:
        PostgreSQL vector string format
    """
    # Duck-typed numpy support: anything with .tolist() is converted first.
    values = vector.tolist() if hasattr(vector, "tolist") else vector
    # pgvector literal syntax: '[v1,v2,...]' with no spaces.
    rendered = [str(float(component)) for component in values]
    return "[" + ",".join(rendered) + "]"

359 

360 

def parse_postgres_vector(vector_str: str) -> list[float]:
    """Parse PostgreSQL vector string to list of floats.

    Args:
        vector_str: PostgreSQL vector string like '[0.1,0.2,0.3]'

    Returns:
        List of floats (empty for empty/whitespace-only input)
    """
    if not vector_str:
        return []

    # Tolerate surrounding whitespace (e.g. ' [1,2] ') and an empty vector
    # body ('[]' or '[ ]'): the previous exact "[]" check let those through
    # to float(''), which raises ValueError.
    inner = vector_str.strip().strip("[]").strip()
    if not inner:
        return []

    return [float(component.strip()) for component in inner.split(",")]