Coverage for src/dataknobs_data/backends/postgres_vector.py: 0%
100 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
1"""PostgreSQL vector support utilities."""
3from __future__ import annotations
5import logging
6from typing import TYPE_CHECKING, Any
8if TYPE_CHECKING:
9 import asyncpg
10 import numpy as np
12logger = logging.getLogger(__name__)
def check_pgvector_extension_sync(db: Any) -> bool:
    """Report whether the pgvector extension is installed (sync version).

    Args:
        db: PostgresDB connection object exposing a ``query`` method that
            returns a DataFrame-like result (``.empty``, ``.iloc``).

    Returns:
        True if pgvector is installed, False otherwise (including when the
        probe query itself fails).
    """
    try:
        frame = db.query("""
            SELECT EXISTS (
                SELECT 1 FROM pg_extension WHERE extname = 'vector'
            ) as exists
        """)
        if frame.empty:
            return False
        return bool(frame.iloc[0]["exists"])
    except Exception as exc:
        # Best-effort probe: any failure is treated as "not available".
        logger.debug(f"Could not check pgvector extension: {exc}")
        return False
def install_pgvector_extension_sync(db: Any) -> bool:
    """Ensure the pgvector extension is installed (sync version).

    Installs the extension only when the check reports it missing.

    Args:
        db: PostgresDB connection object

    Returns:
        True if installation successful or already installed, False on error.
    """
    try:
        if check_pgvector_extension_sync(db):
            logger.debug("pgvector extension already installed")
        else:
            db.execute("CREATE EXTENSION IF NOT EXISTS vector")
            logger.info("Successfully installed pgvector extension")
        return True
    except Exception as exc:
        # Installation typically needs superuser rights; degrade gracefully.
        logger.warning(f"Could not install pgvector extension: {exc}")
        return False
async def check_pgvector_extension(conn: asyncpg.Connection) -> bool:
    """Report whether the pgvector extension is installed (async version).

    Args:
        conn: AsyncPG connection

    Returns:
        True if pgvector is installed, False otherwise.
    """
    installed = await conn.fetchval("""
        SELECT EXISTS (
            SELECT 1 FROM pg_extension WHERE extname = 'vector'
        )
    """)
    # fetchval may return None on odd results; coerce to a strict bool.
    return bool(installed)
async def install_pgvector_extension(conn: asyncpg.Connection) -> bool:
    """Ensure the pgvector extension is installed (async version).

    Installs the extension only when the check reports it missing.

    Args:
        conn: AsyncPG connection

    Returns:
        True if installation successful or already installed, False on error.
    """
    try:
        if await check_pgvector_extension(conn):
            logger.debug("pgvector extension already installed")
        else:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            logger.info("Successfully installed pgvector extension")
        return True
    except Exception as exc:
        # Installation typically needs superuser rights; degrade gracefully.
        logger.warning(f"Could not install pgvector extension: {exc}")
        return False
def get_vector_operator(metric: str) -> str:
    """Map a distance metric name to its pgvector SQL operator.

    Args:
        metric: Distance metric (cosine, euclidean, inner_product;
            aliases: l2, ip). Case-insensitive.

    Returns:
        PostgreSQL operator string; unknown metrics fall back to cosine.
    """
    key = metric.lower()
    if key in ("euclidean", "l2"):
        return "<->"  # L2 distance
    if key in ("inner_product", "ip"):
        return "<#>"  # Negative inner product
    # "cosine" and anything unrecognized both resolve to cosine distance.
    return "<=>"
def get_optimal_index_type(num_vectors: int) -> tuple[str, dict[str, Any]]:
    """Determine optimal index type based on dataset size.

    Args:
        num_vectors: Number of vectors in dataset

    Returns:
        Tuple of (index_type, index_parameters)
    """
    if num_vectors < 10000:
        # For small datasets, use IVFFlat with fewer lists. Clamp to >= 1:
        # pgvector rejects "lists = 0", which num_vectors // 10 previously
        # produced for datasets with fewer than 10 vectors.
        return "ivfflat", {"lists": max(1, min(100, num_vectors // 10))}
    elif num_vectors < 1000000:
        # For medium datasets, use IVFFlat with standard parameters
        lists = int(num_vectors ** 0.5)  # Square root heuristic
        return "ivfflat", {"lists": min(lists, 5000)}
    else:
        # For large datasets, consider HNSW (if available in pgvector version)
        # Note: HNSW requires pgvector 0.5.0+
        return "hnsw", {"m": 16, "ef_construction": 200}
def build_vector_index_sql(
    table_name: str,
    schema_name: str,
    column_name: str,
    dimensions: int,
    metric: str = "cosine",
    index_type: str = "ivfflat",
    index_params: dict[str, Any] | None = None,
    field_name: str | None = None
) -> str:
    """Build SQL for creating a vector index.

    Args:
        table_name: Name of table
        schema_name: Schema name
        column_name: SQL expression for vector column (expected to already
            include any dimension cast)
        dimensions: Vector dimensions. NOTE(review): currently unused by
            this function — the cast is expected to be embedded in
            ``column_name``.
        metric: Distance metric
        index_type: Type of index (ivfflat, hnsw); anything else falls back
            to a plain btree index
        index_params: Index-specific parameters
        field_name: Original field name for index naming

    Returns:
        SQL CREATE INDEX statement

    Note:
        Identifiers and parameters are interpolated directly into the SQL
        string — callers are assumed to pass trusted values (consistent
        with the rest of this module).
    """
    index_params = index_params or {}

    # Determine field name for index naming; fall back to parsing it out
    # of the column expression when the caller did not supply one.
    if not field_name:
        field_name = extract_field_name(column_name)

    index_name = get_vector_index_name(table_name, field_name, metric)

    # Determine operator class based on metric (unknown metrics default
    # to the cosine operator class, matching get_vector_operator()).
    op_class = {
        "cosine": "vector_cosine_ops",
        "euclidean": "vector_l2_ops",
        "l2": "vector_l2_ops",
        "inner_product": "vector_ip_ops",
        "ip": "vector_ip_ops",
        "dot_product": "vector_ip_ops",
    }.get(metric.lower(), "vector_cosine_ops")

    if index_type == "ivfflat":
        lists = index_params.get("lists", 100)
        # IVFFlat requires proper parentheses for functional indexes with operator class
        # The column_name should already include the dimension cast
        return f"""
            CREATE INDEX IF NOT EXISTS {index_name}
            ON {schema_name}.{table_name}
            USING ivfflat (({column_name}) {op_class})
            WITH (lists = {lists})
        """
    elif index_type == "hnsw":
        m = index_params.get("m", 16)
        ef_construction = index_params.get("ef_construction", 200)
        # HNSW index (requires pgvector 0.5.0+)
        # The column_name should already include the dimension cast
        return f"""
            CREATE INDEX IF NOT EXISTS {index_name}
            ON {schema_name}.{table_name}
            USING hnsw (({column_name}) {op_class})
            WITH (m = {m}, ef_construction = {ef_construction})
        """
    else:
        # Default to basic index (btree is not usable for vector similarity
        # search; this branch only covers unexpected index_type values)
        return f"""
            CREATE INDEX IF NOT EXISTS {index_name}
            ON {schema_name}.{table_name}
            USING btree ({column_name})
        """
def sanitize_identifier(name: str) -> str:
    """Sanitize a string to be used as a database identifier.

    Replaces SQL operators, punctuation, and whitespace with underscores,
    collapses runs of underscores, and trims underscores at both ends.

    Args:
        name: Raw string that may contain special characters

    Returns:
        Sanitized string safe for use as identifier
    """
    import re
    # Replace special chars, then collapse the resulting underscore runs.
    collapsed = re.sub(r"_+", "_", re.sub(r"[->()'\[\]:,\s]+", "_", name))
    return collapsed.strip("_")
def extract_field_name(column_expression: str) -> str:
    """Extract field name from a column expression.

    Args:
        column_expression: SQL expression like "(data->'field'->>'value')::vector"

    Returns:
        Extracted field name or 'vector' as fallback
    """
    import re
    # Patterns tried in order: JSON arrow access, double-arrow access,
    # JSONPath, then any quoted string.
    for pattern in (
        r"data->'([^']+)'",
        r"data->>'([^']+)'",
        r"\$\.([^'\"]+)",
        r"'([^']+)'",
    ):
        found = re.search(pattern, column_expression)
        if found:
            return found.group(1)

    # No pattern matched: sanitize the whole expression as a last resort.
    cleaned = sanitize_identifier(column_expression)
    return cleaned or "vector"
def get_vector_index_name(table_name: str, field_name: str, metric: str = "cosine") -> str:
    """Generate consistent index name for vector field.

    Args:
        table_name: Name of the table
        field_name: Name of the vector field (or column expression)
        metric: Distance metric

    Returns:
        Index name string of the form ``idx_<table>_<field>_<metric>``
    """
    # Sanitize every component so the result is a valid identifier.
    parts = [sanitize_identifier(raw) for raw in (table_name, field_name, metric)]
    return f"idx_{parts[0]}_{parts[1]}_{parts[2]}"
def build_vector_column_expression(field_name: str, dimensions: int | None = None, for_index: bool = False) -> str:
    """Build SQL expression for vector column from JSON field.

    Args:
        field_name: Name of the vector field in JSON
        dimensions: Optional dimensions for casting
        for_index: Whether this is for index creation (currently produces
            the same expression as the query form)

    Returns:
        SQL expression for vector column
    """
    suffix = f"({dimensions})" if dimensions else ""
    # VectorFields are stored as JSON objects with a 'value' key; both the
    # index and query forms extract that key and cast it to vector, so the
    # for_index flag does not change the output today.
    return f"(data->'{field_name}'->>'value')::vector{suffix}"
def get_vector_count_sql(schema_name: str, table_name: str, field_name: str) -> str:
    """Get SQL to count rows whose JSON data contains the vector field.

    Args:
        schema_name: Database schema
        table_name: Table name
        field_name: Vector field name

    Returns:
        SQL query string

    Note:
        Identifiers and the field name are interpolated directly; callers
        are assumed to pass trusted values (consistent with this module).
    """
    sql = f"""
        SELECT COUNT(*) as count
        FROM {schema_name}.{table_name}
        WHERE data ? '{field_name}'
    """
    return sql
def get_index_check_sql(schema_name: str, table_name: str, field_name: str) -> tuple[str, list[Any]]:
    """Get SQL to check if vector index exists.

    Uses positional parameters ($1..$3) for asyncpg-style execution; the
    field name is matched as a substring of the index name.

    Args:
        schema_name: Database schema
        table_name: Table name
        field_name: Vector field name

    Returns:
        Tuple of (SQL query, parameters)
    """
    sql = """
    SELECT COUNT(*) > 0 as has_index
    FROM pg_indexes
    WHERE schemaname = $1
    AND tablename = $2
    AND indexname LIKE $3
    """
    pattern = f"%{field_name}%"
    params: list[Any] = [schema_name, table_name, pattern]
    return sql, params
def format_vector_for_postgres(vector: np.ndarray | list[float]) -> str:
    """Format vector for PostgreSQL vector column.

    Args:
        vector: Numpy array or list of floats

    Returns:
        PostgreSQL vector literal, e.g. ``[0.1,0.2,0.3]``
    """
    # Duck-typed numpy support: anything with .tolist() is converted first.
    values = vector.tolist() if hasattr(vector, 'tolist') else vector
    parts = [str(float(component)) for component in values]
    return "[" + ",".join(parts) + "]"
def parse_postgres_vector(vector_str: str) -> list[float]:
    """Parse PostgreSQL vector string to list of floats.

    Args:
        vector_str: PostgreSQL vector string like '[0.1,0.2,0.3]'

    Returns:
        List of floats; empty list for None/empty/'[]' input.
    """
    if not vector_str:
        return []

    # Tolerate surrounding whitespace (e.g. from text-mode query results):
    # previously ' [1,2]' or '[ ]' raised ValueError. Strip padding, the
    # brackets, then any interior padding before splitting.
    inner = vector_str.strip().strip("[]").strip()
    if not inner:
        return []
    return [float(v) for v in inner.split(",")]