Coverage for common/sql.py: 32%
743 statements
« prev ^ index » next coverage.py v7.8.0, created at 2026-02-05 06:46 -0600
« prev ^ index » next coverage.py v7.8.0, created at 2026-02-05 06:46 -0600
1"""
2crate_anon/common/sql.py
4===============================================================================
6 Copyright (C) 2015, University of Cambridge, Department of Psychiatry.
7 Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
9 This file is part of CRATE.
11 CRATE is free software: you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
16 CRATE is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
21 You should have received a copy of the GNU General Public License
22 along with CRATE. If not, see <https://www.gnu.org/licenses/>.
24===============================================================================
26**Low-level SQL manipulation functions.**
28These are about the manipulation of SQL as text (e.g. for query building
29assistance for researchers, or for interpreting SQL data types in data
30dictionaries), not about a higher-level approach like SQLAlchemy.
32"""
34from collections import OrderedDict
35from dataclasses import dataclass
36import functools
37import logging
38import re
39from typing import Any, Dict, Iterable, List, Tuple, Union, Optional
41from cardinal_pythonlib.json_utils.serialize import (
42 METHOD_PROVIDES_INIT_KWARGS,
43 METHOD_STRIP_UNDERSCORE,
44 register_for_json,
45)
46from cardinal_pythonlib.lists import unique_list
47from cardinal_pythonlib.reprfunc import mapped_repr_stripping_underscores
48from cardinal_pythonlib.sizeformatter import sizeof_fmt
49from cardinal_pythonlib.sql.literals import (
50 sql_date_literal,
51 sql_string_literal,
52)
53from cardinal_pythonlib.sql.sql_grammar import SqlGrammar, text_from_parsed
54from cardinal_pythonlib.sql.sql_grammar_factory import (
55 make_grammar,
56 mysql_grammar,
57)
58from cardinal_pythonlib.sql.validation import (
59 SQLTYPES_INTEGER,
60 SQLTYPES_BIT,
61 SQLTYPES_FLOAT,
62 SQLTYPES_TEXT,
63 SQLTYPES_OTHER_NUMERIC,
64)
65from cardinal_pythonlib.sqlalchemy.core_query import count_star
66from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
67from cardinal_pythonlib.sqlalchemy.schema import (
68 column_creation_ddl,
69 execute_ddl,
70)
71from cardinal_pythonlib.timing import MultiTimerContext, timer
72from pyparsing import ParseResults
73from sqlalchemy import inspect
74from sqlalchemy.dialects.mssql.base import MS_2012_VERSION
75from sqlalchemy.engine.base import Engine
76from sqlalchemy.engine.interfaces import Dialect
77from sqlalchemy.exc import CompileError
78from sqlalchemy.orm.session import Session
79from sqlalchemy.schema import Column, Table
80from sqlalchemy.sql.sqltypes import TypeEngine
82from crate_anon.common.stringfunc import get_spec_match_regex
# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# =============================================================================
# Types
# =============================================================================

# A string paired with a list of arbitrary values -- presumably SQL text plus
# the arguments to be bound to it (TODO confirm against callers).
SqlArgsTupleType = Tuple[str, List[Any]]
94# =============================================================================
95# Constants
96# =============================================================================
98# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
99# Generic
100# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Label string; presumably the key under which commit operations are timed
# (TODO confirm against callers).
TIMING_COMMIT = "commit"

# SQL operators that take no right-hand value...
SQL_OPS_VALUE_UNNECESSARY = ["IS NULL", "IS NOT NULL"]
# ... and SQL operators that take several values.
SQL_OPS_MULTIPLE_VALUES = ["IN", "NOT IN"]

# Convenience unions of the SQL type lists from cardinal_pythonlib.
SQLTYPES_INTEGER_OR_BIT = SQLTYPES_INTEGER + SQLTYPES_BIT
SQLTYPES_FLOAT_OR_OTHER_NUMERIC = SQLTYPES_FLOAT + SQLTYPES_OTHER_NUMERIC

# Query-builder datatype names. These must match querybuilder.js:
QB_DATATYPE_INTEGER = "int"
QB_DATATYPE_FLOAT = "float"
QB_DATATYPE_DATE = "date"
QB_DATATYPE_STRING = "string"
QB_DATATYPE_STRING_FULLTEXT = "string_fulltext"
QB_DATATYPE_UNKNOWN = "unknown"
# The subset of the above that represent strings:
QB_STRING_TYPES = [
    QB_DATATYPE_STRING,
    QB_DATATYPE_STRING_FULLTEXT,
]
# Matches a column type with exactly one integer argument, e.g. "VARCHAR(10)":
#
#   ^             start
#   ([A-Za-z]+)   group 1: alphabetical type name
#   \(            literal (
#   (-?\d+)       group 2: optional minus sign, then digits
#   \)            literal )
#   $             end
#
# NB: the character class is [A-Za-z], not [A-z]. The range A-z also spans the
# ASCII punctuation that sits between "Z" and "a" ([ \ ] ^ _ `), which is not
# alphabetical.
COLTYPE_WITH_ONE_INTEGER_REGEX = re.compile(r"^([A-Za-z]+)\((-?\d+)\)$")
123# Dictionaries for the different dialects mapping text column type to length
124# or default length.
125# Doesn't include things like VARCHAR which require the user to specify length
127# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
128# SQLAlchemy dialects
129# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Databricks:
# https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html # noqa: E501
DATABRICKS_COLTYPE_TO_LEN = {
    "STRING": None,  # There is no maximum.
}

# Microsoft SQL Server. (The "N" prefix means Unicode.)
# https://docs.microsoft.com/en-us/sql/t-sql/data-types/char-and-varchar-transact-sql?view=sql-server-ver15  # noqa: E501
# https://docs.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver15  # noqa: E501
# https://docs.microsoft.com/en-us/sql/t-sql/data-types/ntext-text-and-image-transact-sql?view=sql-server-ver15  # noqa: E501
MSSQL_COLTYPE_TO_LEN = {
    # Can specify NVARCHAR(1) to NVARCHAR(4000), or NVARCHAR(MAX) for 2^30 - 1.
    "NVARCHAR_MAX": 2**30 - 1,
    # Can specify VARCHAR(1) to VARCHAR(8000), or VARCHAR(MAX) for 2^31 - 1.
    "VARCHAR_MAX": 2**31 - 1,
    "TEXT": 2**31 - 1,
    "NTEXT": 2**30 - 1,
}

# MySQL:
# https://dev.mysql.com/doc/refman/8.0/en/string-type-overview.html
MYSQL_COLTYPE_TO_LEN = {
    # Can specify CHAR(0) to CHAR(255), but if omitted, the length is 1.
    "CHAR": 1,
    "TINYTEXT": 2**8 - 1,  # 255
    "TEXT": 2**16 - 1,  # 65535
    "MEDIUMTEXT": 2**24 - 1,  # 16777215
    "LONGTEXT": 2**32 - 1,  # 4294967295
}

# Maps SQLAlchemy dialect name to the relevant dictionary above.
DIALECT_TO_STRING_LEN_LOOKUP = {
    SqlaDialectName.DATABRICKS: DATABRICKS_COLTYPE_TO_LEN,
    SqlaDialectName.MSSQL: MSSQL_COLTYPE_TO_LEN,
    SqlaDialectName.MYSQL: MYSQL_COLTYPE_TO_LEN,
}
163# =============================================================================
164# Helper classes
165# =============================================================================
@dataclass
class IndexCreationInfo:
    """
    Describes an index to be created: its name, the column(s) it covers, and
    whether it is a unique index.
    """

    index_name: str  #: Name of the index
    column: Union[str, List[str]]  #: Column name(s) to index
    unique: bool = False  #: Make a unique index?

    @property
    def column_names(self) -> str:
        """
        Returns the column name(s) as a single string. A single column name is
        returned unchanged; several are joined with ``", "``.
        """
        if not isinstance(self.column, str):
            # Multiple columns
            return ", ".join(self.column)
        # Single column
        return self.column
184# =============================================================================
185# SQL elements: identifiers
186# =============================================================================
@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class SchemaId:
    """
    Represents a database schema. This is a bit complex:

    - In SQL Server, schemas live within databases. Tables can be referred to
      as ``table``, ``schema.table``, or ``database.schema.table``.

      - https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/sql/ownership-and-user-schema-separation-in-sql-server
      - The default schema is named ``dbo``.

    - In PostgreSQL, schemas live within databases. Tables can be referred to
      as ``table``, ``schema.table``, or ``database.schema.table``.

      - https://www.postgresql.org/docs/current/static/ddl-schemas.html
      - The default schema is named ``public``.

    - In MySQL, "database" and "schema" are synonymous. Tables can be referred
      to as ``table`` or ``database.table`` (= ``schema.table``).

      - https://stackoverflow.com/questions/11618277/difference-between-schema-database-in-mysql

    """  # noqa: E501

    def __init__(self, db: str = "", schema: str = "") -> None:
        """
        Args:
            db: database name
            schema: schema name
        """
        assert "." not in db, f"Bad database name ({db!r}); can't include '.'"
        assert (
            "." not in schema
        ), f"Bad schema name ({schema!r}); can't include '.'"
        self._db = db
        self._schema = schema

    @property
    def schema_tag(self) -> str:
        """
        String suitable for encoding the SchemaId e.g. in a single HTML form.
        Takes the format ``database.schema``.

        The :func:`__init__` function has already checked the assumption of no
        ``'.'`` characters in either part.
        """
        return f"{self._db}.{self._schema}"

    @classmethod
    def from_schema_tag(cls, tag: str) -> "SchemaId":
        """
        Returns a :class:`SchemaId` from a tag of the form ``db.schema``.
        """
        parts = tag.split(".")
        assert len(parts) == 2, f"Bad schema tag {tag!r}"
        db, schema = parts
        # Use cls, not SchemaId, so that subclasses get their own type back.
        return cls(db, schema)

    def __bool__(self) -> bool:
        """
        Returns:
            is there a named schema?
        """
        return bool(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SchemaId):
            # Let Python try the reflected operation / fall back to "unequal",
            # rather than raising AttributeError for e.g. ``x == None``.
            return NotImplemented
        return (  # ordering is for speed
            self._schema == other._schema and self._db == other._db
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, SchemaId):
            return NotImplemented
        return (self._db, self._schema) < (other._db, other._schema)

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. (Hashing str(self)
        # would fail for a blank SchemaId, because make_identifier() asserts
        # that it has at least one element.)
        return hash((self._db, self._schema))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this schema using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, database=self._db, schema=self._schema)

    def table_id(self, table: str) -> "TableId":
        """
        Returns a :class:`TableId` combining this schema and the specified
        table.

        Args:
            table: name of the table
        """
        return TableId(db=self._db, schema=self._schema, table=table)

    def column_id(self, table: str, column: str) -> "ColumnId":
        """
        Returns a :class:`ColumnId` combining this schema and the specified
        table/column.

        Args:
            table: name of the table
            column: name of the column
        """
        return ColumnId(
            db=self._db, schema=self._schema, table=table, column=column
        )

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    def __str__(self) -> str:
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(self, ["_db", "_schema"])

    def is_present(self) -> bool:
        """
        Is this a non-blank schema, i.e. does it have a ``database`` and/or a
        ``schema`` part? (The opposite of :meth:`is_blank`.)
        """
        return bool(self._db or self._schema)

    def is_blank(self) -> bool:
        """
        Is this a blank/nonfunctional schema, with no ``database`` or
        ``schema`` part?
        """
        return not self.is_present()
@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class TableId:
    """
    Represents a database table.
    """

    def __init__(
        self, db: str = "", schema: str = "", table: str = ""
    ) -> None:
        """
        Args:
            db: database name
            schema: schema name
            table: table name
        """
        self._db = db
        self._schema = schema
        self._table = table

    def __bool__(self) -> bool:
        """
        Returns:
            is there a named table?
        """
        return bool(self._table)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TableId):
            # Let Python try the reflected operation / fall back to "unequal",
            # rather than raising AttributeError for e.g. ``x == None``.
            return NotImplemented
        return (  # ordering is for speed
            self._table == other._table
            and self._schema == other._schema
            and self._db == other._db
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, TableId):
            return NotImplemented
        return (self._db, self._schema, self._table) < (
            other._db,
            other._schema,
            other._table,
        )

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. (Hashing str(self)
        # would fail for a blank TableId, because make_identifier() asserts
        # that it has at least one element.)
        return hash((self._db, self._schema, self._table))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(
            grammar, database=self._db, schema=self._schema, table=self._table
        )

    @property
    def schema_id(self) -> SchemaId:
        """
        Returns a :class:`SchemaId` for the schema of our table.
        """
        return SchemaId(db=self._db, schema=self._schema)

    def column_id(self, column: str) -> "ColumnId":
        """
        Returns a :class:`ColumnId` combining this table and the specified
        column.

        Args:
            column: name of the column
        """
        return ColumnId(
            db=self._db, schema=self._schema, table=self._table, column=column
        )

    def database_schema_part(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table's database/schema (without the
        table part) using the specified SQL grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, database=self._db, schema=self._schema)

    def table_part(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table's table name (only) using the
        specified SQL grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, table=self._table)

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    @property
    def table(self) -> str:
        """
        Returns the table part.
        """
        return self._table

    def __str__(self) -> str:
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ["_db", "_schema", "_table"]
        )
@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class ColumnId:
    """
    Represents a database column.
    """

    def __init__(
        self, db: str = "", schema: str = "", table: str = "", column: str = ""
    ) -> None:
        """
        Args:
            db: database name
            schema: schema name
            table: table name
            column: column name
        """
        self._db = db
        self._schema = schema
        self._table = table
        self._column = column

    def __bool__(self) -> bool:
        """
        Returns:
            is there a named column?
        """
        return bool(self._column)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ColumnId):
            # Let Python try the reflected operation / fall back to "unequal",
            # rather than raising AttributeError for e.g. ``x == None``.
            return NotImplemented
        return (  # ordering is for speed
            self._column == other._column
            and self._table == other._table
            and self._schema == other._schema
            and self._db == other._db
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, ColumnId):
            return NotImplemented
        return (self._db, self._schema, self._table, self._column) < (
            other._db,
            other._schema,
            other._table,
            other._column,
        )

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ makes a class unhashable. Provide a
        # hash over exactly the fields that __eq__ compares, so that ColumnId
        # objects can be used in sets and as dictionary keys, like SchemaId
        # and TableId.
        return hash((self._db, self._schema, self._table, self._column))

    @property
    def is_valid(self) -> bool:
        """
        Do we know about a table and a column, at least?
        """
        return bool(self._table and self._column)  # the minimum

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this column using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(
            grammar,
            database=self._db,
            schema=self._schema,
            table=self._table,
            column=self._column,
        )

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    @property
    def table(self) -> str:
        """
        Returns the table part.
        """
        return self._table

    @property
    def column(self) -> str:
        """
        Returns the column part.
        """
        return self._column

    @property
    def schema_id(self) -> SchemaId:
        """
        Returns a :class:`SchemaId` for the schema of our column.
        """
        return SchemaId(db=self._db, schema=self._schema)

    @property
    def table_id(self) -> TableId:
        """
        Returns a :class:`TableId` for our table.
        """
        return TableId(db=self._db, schema=self._schema, table=self._table)

    @property
    def has_table_and_column(self) -> bool:
        """
        Do we know about a table and a column? (Same test as
        :attr:`is_valid`.)
        """
        return self.is_valid

    def __str__(self) -> str:
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ["_db", "_schema", "_table", "_column"]
        )

    # def html(self, grammar: SqlGrammar, bold_column: bool = True) -> str:
    #     components = [
    #         html.escape(grammar.quote_identifier_if_required(x))
    #         for x in [self._db, self._schema, self._table, self._column]
    #         if x]
    #     if not components:
    #         return ''
    #     if bold_column:
    #         components[-1] = f"<b>{components[-1]}</b>"
    #     return ".".join(components)
def split_db_schema_table(db_schema_table: str) -> TableId:
    """
    Parses a plain dotted identifier into a :class:`TableId`.

    Args:
        db_schema_table:
            one of: ``database.schema.table``, ``schema.table``, ``table``

    Returns:
        a :class:`TableId`; parts not present in the input are left as empty
        strings

    Raises:
        :exc:`ValueError` if the input has more than three dot-separated parts

    """
    max_parts = 3  # database, schema, table
    parts = db_schema_table.split(".")  # always yields at least one part
    if len(parts) > max_parts:
        raise ValueError(f"Bad db_schema_table: {db_schema_table}")
    # The rightmost part is always the table; pad missing leading parts.
    n_missing = max_parts - len(parts)
    db, schema, table = [""] * n_missing + parts
    return TableId(db=db, schema=schema, table=table)
def split_db_schema_table_column(db_schema_table_col: str) -> ColumnId:
    """
    Parses a plain dotted identifier into a :class:`ColumnId`.

    Args:
        db_schema_table_col:
            one of: ``database.schema.table.column``, ``schema.table.column``,
            ``table.column``, ``column``

    Returns:
        a :class:`ColumnId`; parts not present in the input are left as empty
        strings

    Raises:
        :exc:`ValueError` if the input has more than four dot-separated parts

    """
    max_parts = 4  # database, schema, table, column
    parts = db_schema_table_col.split(".")  # always yields at least one part
    if len(parts) > max_parts:
        raise ValueError(f"Bad db_schema_table_col: {db_schema_table_col}")
    # The rightmost part is always the column; pad missing leading parts.
    n_missing = max_parts - len(parts)
    db, schema, table, column = [""] * n_missing + parts
    return ColumnId(db=db, schema=schema, table=table, column=column)
def columns_to_table_column_hierarchy(
    columns: List["ColumnId"], sort: bool = True
) -> List[Tuple["TableId", List["ColumnId"]]]:
    """
    Converts a flat list of column IDs into a list of tables, each with the
    columns that belong to it.

    Args:
        columns: list of :class:`ColumnId` objects
        sort: sort by table, and column within table?

    Returns:
        a list of tuples, each ``table, columns``, where ``table`` is a
        :class:`TableId` and ``columns`` is a list of :class:`ColumnId`.
        If ``sort`` is false, tables appear in order of first appearance and
        columns keep their input order.

    """
    # Group in a single pass. Dictionaries preserve insertion order, so tables
    # come out in order of first appearance, as a "unique list" would give.
    columns_by_table = {}  # type: Dict[TableId, List[ColumnId]]
    for c in columns:
        columns_by_table.setdefault(c.table_id, []).append(c)
    tables = list(columns_by_table)
    if sort:
        tables.sort()
    table_column_map = []  # type: List[Tuple[TableId, List[ColumnId]]]
    for t in tables:
        t_columns = columns_by_table[t]
        if sort:
            t_columns.sort()
        table_column_map.append((t, t_columns))
    return table_column_map
664# =============================================================================
665# Using SQL grammars (but without reference to Django models, for testing)
666# =============================================================================
def make_identifier(
    grammar: "SqlGrammar",
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table: Optional[str] = None,
    column: Optional[str] = None,
) -> str:
    """
    Builds a dotted SQL identifier. Each part supplied is quoted (if required)
    in the style of the given SQL grammar, and the quoted parts are joined
    with ``.``.

    Args:
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        database: database name
        schema: schema name
        table: table name
        column: column name

    Returns:
        a string as above in the order "database, schema, table, column", but
        omitting any missing (``None`` or empty) parts

    """
    quoted_parts = []  # type: List[str]
    for part in (database, schema, table, column):
        if not part:
            continue  # missing; skip it
        quoted_parts.append(grammar.quote_identifier_if_required(part))
    assert quoted_parts, "make_identifier(): No elements passed!"
    return ".".join(quoted_parts)
def dumb_make_identifier(
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table: Optional[str] = None,
    column: Optional[str] = None,
) -> str:
    """
    Makes an SQL-style identifier by joining all the parts with ``.``, without
    bothering to quote them.

    Args:
        database: database name
        schema: schema name
        table: table name
        column: column name

    Returns:
        a string as above in the order "database, schema, table, column", but
        omitting any missing (``None`` or empty) parts

    Raises:
        :exc:`AssertionError` if no parts are supplied (as for
        :func:`make_identifier`)

    """
    # Build a real list: a lazy filter() object is always truthy, so asserting
    # on it would never detect the "nothing passed" case.
    elements = [x for x in (database, schema, table, column) if x]
    assert elements, "dumb_make_identifier(): No elements passed!"
    return ".".join(elements)
def parser_add_result_column(
    parsed: ParseResults, column: str, grammar: SqlGrammar
) -> ParseResults:
    """
    Ensures that a parsed ``SELECT`` statement includes a given result column.

    For example, starting with

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        WHERE conditions;

    and adding the column ``d`` gives

    .. code-block:: sql

        SELECT a, b, c, d
        FROM sometable
        WHERE conditions;

    Presupposes that there is at least one column already in the SELECT
    statement. If the column is already listed, nothing is changed.

    Args:
        parsed: a `pyparsing.ParseResults` result
        column: column name
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`

    Returns:
        the same `pyparsing.ParseResults` object that was passed in (modified
        in place if the column was added)

    """
    if column in parsed.select_expression.select_columns.asList():
        # Already present; nothing to do.
        return parsed
    new_column = grammar.get_result_column().parseString(
        column, parseAll=True
    )
    parsed.select_expression.extend([",", new_column])
    return parsed
class JoinInfo:
    """
    Simple container describing one SQL join: which table to join in, how to
    join it, and on what condition.
    """

    def __init__(
        self,
        table: str,
        join_type: str = "INNER JOIN",
        join_condition: str = "",
    ) -> None:
        """
        Args:
            table: table to be joined in
            join_type: join method, e.g. ``"INNER JOIN"``
            join_condition: join condition, e.g. ``"ON x = y"``
        """
        self.join_type = join_type  # e.g. "INNER JOIN"
        self.table = table
        self.join_condition = join_condition  # e.g. "ON x = y"
def parser_add_from_tables(
    parsed: ParseResults, join_info_list: List[JoinInfo], grammar: SqlGrammar
) -> ParseResults:
    """
    Adds joined tables to the FROM clause of a parsed ``SELECT`` statement.

    For example, starting with

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        WHERE conditions;

    and adding ``JoinInfo("othertable", "INNER JOIN", "ON table.key =
    othertable.key")`` gives

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        INNER JOIN othertable ON table.key = othertable.key
        WHERE conditions;

    Presupposes that there is at least one table already in the FROM clause.
    Tables that were already in the FROM clause on entry are skipped.

    Args:
        parsed: a `pyparsing.ParseResults` result
        join_info_list: list of :class:`JoinInfo` objects
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`

    Returns:
        the same `pyparsing.ParseResults` object that was passed in (modified
        in place if any joins were added)

    """
    # Snapshot taken once, before any additions.
    tables_on_entry = parsed.join_source.from_tables.asList()
    for join_info in join_info_list:
        if join_info.table not in tables_on_entry:
            # Join operator, e.g. INNER JOIN:
            join_op = grammar.get_join_op().parseString(
                join_info.join_type, parseAll=True
            )[0]
            # Table being joined in:
            table_spec = grammar.get_table_spec().parseString(
                join_info.table, parseAll=True
            )[0]
            new_elements = [join_op, table_spec]
            if join_info.join_condition:
                # Join constraint, e.g. ON x = y:
                join_constraint = grammar.get_join_constraint().parseString(
                    join_info.join_condition, parseAll=True
                )[0]
                new_elements.append(join_constraint)
            parsed.join_source.extend(new_elements)
    return parsed
def get_first_from_table(
    parsed: ParseResults,
    match_db: str = "",
    match_schema: str = "",
    match_table: str = "",
) -> TableId:
    """
    Given the parsed results of a SELECT statement, returns the first table
    in its FROM clause that satisfies the (optional) ``match*`` constraints.

    Args:
        parsed: a `pyparsing.ParseResults` result
        match_db: optional database name to constrain the result to
        match_schema: optional schema name to constrain the result to
        match_table: optional table name to constrain the result to

    Returns:
        a :class:`TableId`, which will be empty in case of failure
    """
    for entry in parsed.join_source.from_tables.asList():
        if isinstance(entry, list):
            # A nested entry must wrap exactly one table name.
            assert len(entry) == 1
            entry = entry[0]
        candidate = split_db_schema_table(entry)
        # A blank constraint matches anything.
        db_ok = not match_db or candidate.db == match_db
        schema_ok = not match_schema or candidate.schema == match_schema
        table_ok = not match_table or candidate.table == match_table
        if db_ok and schema_ok and table_ok:
            return candidate
    return TableId()  # nothing matched
def set_distinct_within_parsed(p: "ParseResults", action: str = "set") -> None:
    """
    Changes, in place, whether a parsed SQL statement is marked DISTINCT.

    Args:
        p: a `pyparsing.ParseResults` result
        action: ``"set"`` to turn DISTINCT on; ``"clear"`` to turn it off;
            or ``"toggle"`` to toggle it.

    Raises:
        :exc:`ValueError` if ``action`` is not one of the above
    """
    ss = p.select_specifier  # type: ParseResults
    if action not in ("set", "clear", "toggle"):
        raise ValueError("action must be one of set/clear/toggle")
    currently_distinct = "DISTINCT" in ss.asList()
    if currently_distinct:
        if action != "set":
            # "clear", or "toggle" from on to off: empty the specifier.
            del ss[:]
    elif action != "clear":
        # "set", or "toggle" from off to on.
        ss.append("DISTINCT")
def set_distinct(
    sql: str,
    grammar: SqlGrammar,
    action: str = "set",
    formatted: bool = True,
    debug: bool = False,
    debug_verbose: bool = False,
) -> str:
    """
    Parses an SQL SELECT statement, changes its DISTINCT status, and returns
    the result as text.

    Args:
        sql: SQL statement as text
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        action: one of ``"set"``, ``"clear"``, ``"toggle"``; see
            :func:`set_distinct_within_parsed`
        formatted: pretty-format the result?
        debug: show debugging information to the Python log
        debug_verbose: be verbose when debugging (only has an effect if
            ``debug`` is also set)

    Returns:
        the modified SQL statement, as a string

    """
    parsed = grammar.get_select_statement().parseString(sql, parseAll=True)
    verbose = debug and debug_verbose
    if debug:
        log.info(f"START: {sql}")
    if verbose:
        log.debug("start dump:\n" + parsed.dump())
    set_distinct_within_parsed(parsed, action=action)
    result = text_from_parsed(parsed, formatted=formatted)
    if debug:
        log.info(f"END: {result}")
    if verbose:
        log.debug("end dump:\n" + parsed.dump())
    return result
def toggle_distinct(
    sql: str,
    grammar: SqlGrammar,
    formatted: bool = True,
    debug: bool = False,
    debug_verbose: bool = False,
) -> str:
    """
    Flips the DISTINCT status of an SQL statement: a convenience wrapper
    around :func:`set_distinct` with ``action="toggle"``.

    Args:
        sql: SQL statement as text
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        formatted: pretty-format the result?
        debug: show debugging information to the Python log
        debug_verbose: be verbose when debugging

    Returns:
        the modified SQL statement, as a string

    """
    options = dict(
        action="toggle",
        formatted=formatted,
        debug=debug,
        debug_verbose=debug_verbose,
    )
    return set_distinct(sql=sql, grammar=grammar, **options)
981# =============================================================================
982# SQLAlchemy reflection and DDL
983# =============================================================================
# Module-wide switch: when true, the DDL/SQL helpers in this module print
# their SQL rather than executing it. Change via set_print_not_execute().
_global_print_not_execute_sql = False


def set_print_not_execute(print_not_execute: bool) -> None:
    """
    Sets the (nasty) module-level flag that determines whether DDL commands
    issued from this module are printed instead of being executed.

    Args:
        print_not_execute: print (not execute)?
    """
    global _global_print_not_execute_sql
    _global_print_not_execute_sql = print_not_execute
def _exec_ddl(engine: Engine, sql: str) -> None:
    """
    Runs a piece of SQL as DDL, or prints it instead.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`. Either way, the SQL is sent to the debug
    log first.

    Args:
        engine: SQLAlchemy database Engine
        sql: raw SQL to execute (or print)
    """
    log.debug(sql)
    if not _global_print_not_execute_sql:
        execute_ddl(engine, sql=sql)
        return
    # Print-only mode. The extra \n is in case the SQL ends in a comment.
    print(format_sql_for_print(sql) + "\n;")
def execute(engine: Engine, sql: str) -> None:
    """
    Executes plain SQL in a transaction.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`. Either way, the SQL is sent to the debug
    log first.

    Args:
        engine: SQLAlchemy database Engine
        sql: raw SQL to execute (or print)
    """
    from sqlalchemy import text  # local import; used only here

    log.debug(sql)
    if _global_print_not_execute_sql:
        print(format_sql_for_print(sql) + "\n;")
        # extra \n in case the SQL ends in a comment
    else:
        with engine.begin() as connection:
            # SQLAlchemy 2.0 does not accept a plain string here; raw SQL must
            # be wrapped in text(). NB: text() treats ":name" as a bind
            # parameter, so a literal colon followed by a word character must
            # be escaped ("\:") by the caller.
            connection.execute(text(sql))
def add_columns(engine: Engine, table: Table, columns: List[Column]) -> None:
    """
    Adds columns to a table, skipping any whose name (compared
    case-insensitively) already exists in the table.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        columns: SQLAlchemy Column objects to add to the table

    Behaviour of different database systems:

    - ANSI SQL: add one column at a time: ``ALTER TABLE ADD [COLUMN] coldef``

      - i.e. "COLUMN" optional, one at a time, no parentheses
      - https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt

    - MySQL: ``ALTER TABLE ADD [COLUMN] (a INT, b VARCHAR(32));``

      - i.e. "COLUMN" optional, parentheses required for >1, multiple OK
      - https://dev.mysql.com/doc/refman/5.7/en/alter-table.html

    - MS SQL Server: ``ALTER TABLE ADD a INT, b VARCHAR(32);``

      - i.e. no "COLUMN", no parentheses, multiple OK
      - https://msdn.microsoft.com/en-us/library/ms190238.aspx
      - https://msdn.microsoft.com/en-us/library/ms190273.aspx
      - https://stackoverflow.com/questions/2523676

    This function therefore operates one at a time.

    SQLAlchemy doesn't provide a shortcut for this.

    """
    names_present = get_column_names(
        engine, tablename=table.name, to_lower=True
    )
    # Phase 1: work out the DDL for every column that is genuinely new.
    pending_ddl = []  # type: List[str]
    for col in columns:
        if col.name.lower() in names_present:
            log.debug(
                f"Table {table.name!r}: column {col.name!r} "
                f"already exists; not adding"
            )
            continue
        pending_ddl.append(column_creation_ddl(col, engine.dialect))
    # Phase 2: add them, one statement per column.
    for coldef in pending_ddl:
        log.info(f"Table {table.name!r}: adding column {coldef!r}")
        _exec_ddl(engine, f"ALTER TABLE {table.name} ADD {coldef}")
def drop_columns(
    engine: Engine, table: Table, column_names: Iterable[str]
) -> None:
    """
    Drops columns from a table, skipping any whose name (compared
    case-insensitively) is not present in the table.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        column_names: names of columns to drop

    Columns are dropped one by one.

    """
    names_present = get_column_names(
        engine, tablename=table.name, to_lower=True
    )
    for colname in column_names:
        if colname.lower() in names_present:
            log.info(f"Table {table.name!r}: dropping column {colname!r}")
            # SQL Server:
            #   http://www.techonthenet.com/sql_server/tables/alter_table.php
            # MySQL:
            #   http://dev.mysql.com/doc/refman/5.7/en/alter-table.html
            _exec_ddl(
                engine, f"ALTER TABLE {table.name} DROP COLUMN {colname}"
            )
            continue
        log.debug(
            f"Table {table.name!r}: column {colname!r} "
            f"does not exist; not dropping"
        )
def add_indexes(
    engine: Engine, table: Table, index_info_list: Iterable[IndexCreationInfo]
) -> None:
    """
    Creates indexes on a table, skipping any whose name already exists.

    Index names are compared case-insensitively with the indexes currently
    present on the table.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine:
            SQLAlchemy database Engine
        table:
            SQLAlchemy Table object
        index_info_list:
            Index(es) to create: list of :class:`IndexCreationInfo` objects.
            Each object's ``index_name``, ``column_names`` and ``unique``
            attributes are used; ``column_names`` is inserted verbatim
            between the parentheses of the ``CREATE INDEX`` statement.
    """
    already_present = get_index_names(
        engine, tablename=table.name, to_lower=True
    )
    for info in index_info_list:
        index_name = info.index_name
        column = info.column_names
        if index_name.lower() in already_present:
            log.debug(
                f"Table {table.name!r}: index {index_name!r} "
                f"already exists; not adding"
            )
            continue
        log.info(
            f"Table {table.name!r}: adding index {index_name!r} on "
            f"column {column!r}"
        )
        _exec_ddl(
            engine,
            f"""
                CREATE{" UNIQUE" if info.unique else ""} INDEX {index_name}
                ON {table.name} ({column})
            """,
        )
def drop_indexes(
    engine: Engine, table: Table, index_names: Iterable[str]
) -> None:
    """
    Drops indexes from a table.

    Index names are compared case-insensitively with the indexes currently
    present on the table; names that are not present are skipped (with a
    debug-level log message).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        index_names: names of indexes to drop

    Raises:
        :exc:`AssertionError` if an index needs dropping but the engine's
        dialect is neither ``mysql`` nor ``mssql``
    """
    existing_index_names = get_index_names(
        engine, tablename=table.name, to_lower=True
    )
    dialect_name = engine.dialect.name
    for index_name in index_names:
        if index_name.lower() not in existing_index_names:
            log.debug(
                f"Table {table.name!r}: index {index_name!r} "
                f"does not exist; not dropping"
            )
            continue
        log.info(f"Table {table.name!r}: dropping index {index_name!r}")
        if dialect_name == "mysql":
            sql = f"ALTER TABLE {table.name} DROP INDEX {index_name}"
        elif dialect_name == "mssql":
            sql = f"DROP INDEX {table.name}.{index_name}"
        else:
            # Raise explicitly rather than "assert False": assert statements
            # are removed under "python -O", which would leave "sql" unbound.
            raise AssertionError(f"Unknown dialect: {dialect_name}")
        _exec_ddl(engine, sql)
def get_table_names(
    engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Returns the names of all tables in the database, as reported by
    SQLAlchemy's inspector.

    Args:
        engine: SQLAlchemy database Engine
        to_lower: convert table names to lower case?
        sort: sort table names (case-insensitively)?

    Returns:
        list of table names
    """
    names = inspect(engine).get_table_names()
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())
def get_view_names(
    engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Returns the names of all views in the database, as reported by
    SQLAlchemy's inspector.

    Args:
        engine: SQLAlchemy database Engine
        to_lower: convert view names to lower case?
        sort: sort view names (case-insensitively)?

    Returns:
        list of view names
    """
    names = inspect(engine).get_view_names()
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())
def get_column_names(
    engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reads column names afresh from the database, for a specific table (in
    case metadata is out of date).

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        to_lower: convert column names to lower case?
        sort: sort column names (case-insensitively)?

    Returns:
        list of column names
    """
    names = [
        column_info["name"]
        for column_info in inspect(engine).get_columns(tablename)
    ]
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())
def get_index_names(
    engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reads index names from the database, for a specific table. Indexes
    reported without a name are omitted.

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        to_lower: convert index names to lower case?
        sort: sort index names (case-insensitively)?

    Returns:
        list of index names
    """
    # http://docs.sqlalchemy.org/en/latest/core/reflection.html
    names = []  # type: List[str]
    for index_info in inspect(engine).get_indexes(tablename):
        name = index_info["name"]
        # At least for SQL Server, there always seems to be a blank one
        # with {'name': None, ...}; skip those.
        if name:
            names.append(name)
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())
def ensure_columns_present(
    engine: Engine, tablename: str, column_names: Iterable[str]
) -> None:
    """
    Checks that every named column exists in a table, raising an exception
    for the first one that does not.

    Operates in case-insensitive fashion.

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        column_names: names of required columns

    Raises:
        :exc:`ValueError` if any are missing
    """
    present = get_column_names(engine, tablename=tablename, to_lower=True)
    if not column_names:
        return
    for required in column_names:
        if required.lower() in present:
            continue
        raise ValueError(
            f"Column {required!r} missing from table {tablename!r}, "
            f"whose columns are {present!r}"
        )
def create_view(engine: Engine, viewname: str, select_sql: str) -> None:
    """
    Creates a view, replacing any existing view of the same name.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        viewname: view name
        select_sql: SQL SELECT statement for this view
    """
    is_mysql = engine.dialect.name == "mysql"
    if is_mysql:
        # MySQL has CREATE OR REPLACE VIEW.
        create_keywords = "CREATE OR REPLACE VIEW"
    else:
        # SQL Server doesn't: https://stackoverflow.com/questions/18534919
        # ... so drop any existing view first.
        drop_view(engine, viewname, quiet=True)
        create_keywords = "CREATE VIEW"
    sql = f"{create_keywords} {viewname} AS {select_sql}"
    log.info(f"Creating view: {viewname!r}")
    _exec_ddl(engine, sql)
def assert_view_has_same_num_rows(
    engine: Engine, basetable: str, viewname: str
) -> None:
    """
    Checks that a view yields exactly as many rows as a table. (For use in
    situations where this should hold; views don't have to do this in
    general!)

    Args:
        engine: SQLAlchemy database Engine
        basetable: name of the table that this view should have a 1:1
            relationship to
        viewname: view name

    Raises:
        :exc:`AssertionError` if they don't have the same number of rows
    """
    # This check is data-dependent: a badly designed view MAY fail it, but
    # won't necessarily do so (e.g. if the table is empty).
    n_base = count_star(engine, basetable)
    n_view = count_star(engine, viewname)
    row_counts_equal = n_view == n_base
    assert row_counts_equal, (
        f"View bug: view {viewname} has {n_view} records but its base table "
        f"{basetable} has {n_base}; they should be equal"
    )
def drop_view(engine: Engine, viewname: str, quiet: bool = False) -> None:
    """
    Drops a view, if it exists (checked case-insensitively against the
    database's current view names).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        viewname: view name
        quiet: don't announce this to the Python log
    """
    # We check for existence ourselves: MySQL has DROP VIEW IF EXISTS, but
    # SQL Server only has that from SQL Server 2016 onwards.
    # - https://msdn.microsoft.com/en-us/library/ms173492.aspx
    # - http://dev.mysql.com/doc/refman/5.7/en/drop-view.html
    if viewname.lower() not in get_view_names(engine, to_lower=True):
        log.debug(f"View {viewname} does not exist; not dropping")
        return
    if not quiet:
        log.info(f"Dropping view: {viewname!r}")
    _exec_ddl(engine, f"DROP VIEW {viewname}")
def get_column_fk_description(c: Column) -> str:
    """
    Standardized description of a column's foreign keys.

    Args:
        c:
            SQLAlchemy Column

    Returns:
        an empty string if the column has no foreign keys; otherwise
        ``"FK to "`` followed by a comma-separated list of ``table.column``
        targets, ordered by table name and then column name
    """
    targets = [
        (fk.column.table.name, fk.column.name) for fk in c.foreign_keys
    ]
    if not targets:
        return ""
    targets.sort()
    descriptions = [f"{table}.{column}" for table, column in targets]
    return "FK to " + ", ".join(descriptions)
@dataclass
class ReflectedColumnInfo:
    """
    Describes a column reflected from a database, optionally supplemented by
    information from a CRATE data dictionary and/or a description of the
    values found in that column (for researcher reports).
    """

    column: Column
    override_comment: str = None  # can override SQLAlchemy-level comment
    crate_annotation: str = None
    values_info: str = None

    @property
    def name(self) -> str:
        """
        Column name (synonym for :attr:`columnname`).
        """
        return self.columnname

    @property
    def columnname(self) -> str:
        """
        Column name, exactly as reflected.
        """
        # Do not manipulate the case of SOURCE tables/columns.
        # If you do, they can fail to match the SQLAlchemy
        # introspection and cause a crash.
        return self.column.name

    @property
    def tablename(self) -> str:
        """
        Name of the table to which the column belongs.
        """
        return self.column.table.name

    @property
    def tablename_columname(self) -> str:
        """
        The column described as ``table.column``.
        """
        table = self.column.table.name
        column = self.column.name
        return f"{table}.{column}"

    @property
    def sqla_coltype(self) -> TypeEngine:
        """
        The column's SQLAlchemy type object.
        """
        return self.column.type

    @property
    def sql_type(self) -> str:
        """
        The column's type, rendered as a string. If rendering fails with a
        :exc:`CompileError`, the offending column is logged and the exception
        is re-raised.
        """
        try:
            type_text = str(self.column.type)
        except CompileError:
            log.critical(f"Column that failed was: {self.column!r}")
            raise
        return type_text

    @property
    def datatype_sqltext(self) -> str:
        """
        Synonym for :attr:`sql_type`.
        """
        return self.sql_type

    @property
    def pk(self) -> bool:
        """
        Is the column part of the primary key?
        """
        return self.column.primary_key

    @property
    def nullable(self) -> bool:
        """
        Is the column nullable?
        """
        return self.column.nullable

    @property
    def comment(self) -> str:
        """
        The override comment if one was supplied; otherwise the database
        comment, if present; otherwise an empty string.
        """
        # Not all dialects support reflecting comments, hence getattr();
        # https://docs.sqlalchemy.org/en/14/core/reflection.html
        if self.override_comment:
            return self.override_comment
        return getattr(self.column, "comment", "") or ""

    @property
    def nullable_str(self) -> str:
        """
        Human-oriented version of :attr:`nullable`.
        """
        if self.nullable:
            return "✓"
        return "NOT NULL"

    @property
    def pk_str(self) -> str:
        """
        ``"PK"`` for a primary key column; otherwise an empty string.
        """
        if self.pk:
            return "PK"
        return ""

    @property
    def fk_str(self) -> str:
        """
        Description of the column's foreign keys; see
        :func:`get_column_fk_description`.
        """
        return get_column_fk_description(self.column)

    def get_column_source_description(self, with_fk: bool = True) -> str:
        """
        Returns a description of where the column is from, used as a suffix
        for data dictionary comment generation.

        Args:
            with_fk:
                Include foreign key descriptions (helpful because CRATE
                doesn't reproduce FK relationships in the destination DDL).
        """
        fk_suffix = ""
        if with_fk:
            fk_description = self.fk_str
            if fk_description:
                fk_suffix = "; " + fk_description
        return f" [from {self.tablename_columname}{fk_suffix}]"

    @property
    def crate_annotation_str(self) -> str:
        """
        Human-oriented version for report: the CRATE annotation, or ``"?"``
        if there is none.
        """
        return self.crate_annotation or "?"

    @property
    def values_info_str(self) -> str:
        """
        Human-oriented version for report: the values description, or
        ``"?"`` if there is none.
        """
        return self.values_info or "?"
1542# =============================================================================
1543# ViewMaker
1544# =============================================================================
class ViewMaker:
    """
    Helper for drafting and creating an SQL view based on a single table.

    It accumulates SELECT, FROM and WHERE elements, and also keeps notes (on
    lookup tables and index requests) for the benefit of other code.
    """

    def __init__(
        self,
        viewname: str,
        engine: Engine,
        basetable: str,
        existing_to_lower: bool = False,
        rename: Dict[str, str] = None,
        userobj: Any = None,
        enforce_same_n_rows_as_base: bool = True,
        insert_basetable_columns: bool = True,
    ) -> None:
        """
        Args:
            viewname: name of the view
            engine: SQLAlchemy database Engine
            basetable: name of the single base table that this view draws from
            existing_to_lower: translate column names to lower case in the
                view?
            rename: optional dictionary mapping ``from_name: to_name`` to
                translate column names in the view; a column mapped to an
                empty/false value is left out of the view
            userobj: optional object (e.g. `argparse.Namespace`,
                dictionary...), not used by this class, and purely to store
                information for others' benefit
            enforce_same_n_rows_as_base: ensure that the view produces the
                same number of rows as its base table?
            insert_basetable_columns: start drafting the view by including all
                columns from the base table?
        """
        rename = rename or {}
        assert basetable, "ViewMaker: basetable missing!"
        self.viewname = viewname
        self.engine = engine
        self.basetable = basetable
        self.userobj = userobj  # only for others' benefit
        self.enforce_same_n_rows_as_base = enforce_same_n_rows_as_base
        self.select_elements = []  # type: List[str]
        self.from_elements = [basetable]  # type: List[str]
        self.where_elements = []  # type: List[str]
        self.lookup_tables = []  # type: List[str]
        self.index_requests = OrderedDict()  # type: Dict[str, List[str]]

        if insert_basetable_columns:
            grammar = make_grammar(engine.dialect.name)
            quote = grammar.quote_identifier_if_required

            base_columns = get_column_names(
                engine, tablename=basetable, to_lower=existing_to_lower
            )
            for colname in base_columns:
                alias = ""
                if colname in rename:
                    new_name = rename[colname]
                    if not new_name:
                        # Renamed to nothing: omit this column.
                        continue
                    alias = f" AS {quote(new_name)}"
                self.select_elements.append(
                    f"{quote(basetable)}.{quote(colname)}{alias}"
                )
            assert self.select_elements, (
                "Must have some active SELECT " "elements from base table"
            )

    def add_select(self, element: str) -> None:
        """
        Appends an element (e.g. a result column) to the SELECT clause of the
        draft view's SQL.
        """
        self.select_elements.append(element)

    def add_from(self, element: str) -> None:
        """
        Appends an element to the FROM clause of the draft view's SQL.
        """
        self.from_elements.append(element)

    def add_where(self, element: str) -> None:
        """
        Appends an element to the WHERE clause of the draft view's SQL.
        WHERE elements are combined with ``AND``.
        """
        self.where_elements.append(element)

    def get_sql(self) -> str:
        """
        Returns the SELECT statement for the view, built from the elements
        added so far.
        """
        assert self.select_elements, "ViewMaker: no SELECT elements!"
        where = ""
        if self.where_elements:
            where = "\n WHERE " + "\n AND ".join(self.where_elements)
        select_part = ",\n ".join(self.select_elements)
        from_part = "\n ".join(self.from_elements)
        return f"\n SELECT {select_part}\n FROM {from_part}{where}"

    def create_view(self, engine: Engine) -> None:
        """
        Creates the view.

        Whether we act or just print is conditional on previous calls to
        :func:`set_print_not_execute`.

        If ``enforce_same_n_rows_as_base`` is set, check the number of rows
        returned matches the base table.

        Args:
            engine: SQLAlchemy database Engine
        """
        create_view(engine, self.viewname, self.get_sql())
        if not self.enforce_same_n_rows_as_base:
            return
        assert_view_has_same_num_rows(engine, self.basetable, self.viewname)

    def drop_view(self, engine: Engine) -> None:
        """
        Drops the view.

        Whether we act or just print is conditional on previous calls to
        :func:`set_print_not_execute`.

        Args:
            engine: SQLAlchemy database Engine
        """
        drop_view(engine, self.viewname)

    def record_lookup_table(self, table: str) -> None:
        """
        Keep a record of a lookup table (once only per table). The framework
        may wish to suppress these from a data dictionary later (e.g. create
        a view, suppress the messier raw data). See
        :func:`get_lookup_tables`.

        Args:
            table: table name
        """
        if table in self.lookup_tables:
            return
        self.lookup_tables.append(table)

    def get_lookup_tables(self) -> List[str]:
        """
        Returns all lookup tables that we have recorded. See
        :func:`record_lookup_table`.
        """
        return self.lookup_tables

    def request_index(self, table: str, column: str) -> None:
        """
        Note a request that a specific column be indexed (once only per
        table/column pair). The framework can use the ViewMaker to keep a
        note of these requests, and then add index hints to a data dictionary
        if it wishes. See :func:`get_index_request_dict`.

        Args:
            table: table name
            column: column name
        """
        requested = self.index_requests.setdefault(table, [])
        if column not in requested:
            requested.append(column)

    def get_index_request_dict(self) -> Dict[str, List[str]]:
        """
        Returns all our recorded index requests, as a dictionary mapping each
        table name to a list of column names to be indexed. See
        :func:`request_index`.
        """
        return self.index_requests

    def record_lookup_table_keyfield(
        self, table: str, keyfield: Union[str, Iterable[str]]
    ) -> None:
        """
        Makes a note that a table is a lookup table, and its key field(s)
        should be indexed. See :func:`get_lookup_tables`,
        :func:`get_index_request_dict`.

        Args:
            table: table name
            keyfield: field name, or iterable (e.g. list) of them
        """
        keyfields = [keyfield] if isinstance(keyfield, str) else keyfield
        self.record_lookup_table(table)
        for field in keyfields:
            self.request_index(table, field)

    def record_lookup_table_keyfields(
        self,
        table_keyfield_tuples: Iterable[Tuple[str, Union[str, Iterable[str]]]],
    ) -> None:
        """
        Make a note of a whole set of lookup table / key field groups. See
        :func:`record_lookup_table_keyfield`.

        Args:
            table_keyfield_tuples:
                iterable (e.g. list) of tuples of the format ``tablename,
                keyfield``. Each will be passed to
                :func:`record_lookup_table_keyfield`.
        """
        for table, keyfield in table_keyfield_tuples:
            self.record_lookup_table_keyfield(table, keyfield)
1766# =============================================================================
1767# TransactionSizeLimiter
1768# =============================================================================
class TransactionSizeLimiter:
    """
    Keeps running totals of the rows/bytes written in the current database
    transaction, and commits once a configured limit is reached.
    """

    def __init__(
        self,
        session: Session,
        max_rows_before_commit: int = None,
        max_bytes_before_commit: int = None,
    ) -> None:
        """
        Args:
            session: SQLAlchemy database Session
            max_rows_before_commit: how many rows should we insert before
                triggering a COMMIT? ``None`` for no limit.
            max_bytes_before_commit: how many bytes should we insert before
                triggering a COMMIT? ``None`` for no limit.
        """
        self._session = session
        self._max_rows_before_commit = max_rows_before_commit
        self._max_bytes_before_commit = max_bytes_before_commit
        self._bytes_in_transaction = 0
        self._rows_in_transaction = 0

    def commit(self) -> None:
        """
        Performs a database COMMIT (timing it as we do so) and resets our
        counters.
        """
        with MultiTimerContext(timer, TIMING_COMMIT):
            self._session.commit()
        self._bytes_in_transaction = 0
        self._rows_in_transaction = 0

    def notify(
        self, n_rows: int, n_bytes: int, force_commit: bool = False
    ) -> None:
        """
        Use this function to notify the limiter of data that you've inserted
        into the database. If the total number of rows or bytes reaches a
        limit that we've set, this will trigger a COMMIT.

        Args:
            n_rows: number of rows inserted
            n_bytes: number of bytes inserted
            force_commit: force a COMMIT?
        """
        if force_commit:
            # The commit resets the counters, so there is nothing to add.
            self.commit()
            return

        self._bytes_in_transaction += n_bytes
        self._rows_in_transaction += n_rows

        byte_limit = self._max_bytes_before_commit
        row_limit = self._max_rows_before_commit

        # The byte limit takes precedence over the row limit.
        if (
            byte_limit is not None
            and self._bytes_in_transaction >= byte_limit
        ):
            log.debug(
                f"Triggering early commit based on byte count "
                f"(reached {sizeof_fmt(self._bytes_in_transaction)}, "
                f"limit is {sizeof_fmt(byte_limit)})"
            )
            self.commit()
            return

        if row_limit is not None and self._rows_in_transaction >= row_limit:
            log.debug(
                f"Triggering early commit based on row count "
                f"(reached {self._rows_in_transaction} rows, "
                f"limit is {row_limit})"
            )
            self.commit()
1847# =============================================================================
1848# Specification matching
1849# =============================================================================
def _matches_tabledef(table: str, tabledef: str) -> bool:
    """
    Does the table name match a single wildcard-based table definition?

    Args:
        table: tablename
        tabledef: ``fnmatch``-style pattern (e.g.
            ``"patient_address_table_*"``)
    """
    table_regex = get_spec_match_regex(tabledef)
    return table_regex.match(table) is not None
def matches_tabledef(table: str, tabledef: Union[str, List[str]]) -> bool:
    """
    Does the table name match the wildcard-based table definition (or, if
    several are given, any of them)?

    Args:
        table: table name
        tabledef: ``fnmatch``-style pattern (e.g.
            ``"patient_address_table_*"``), or list of them; an empty list
            (or ``None``) matches nothing
    """
    if isinstance(tabledef, str):
        patterns = [tabledef]
    else:
        patterns = tabledef or []
    return any(_matches_tabledef(table, pattern) for pattern in patterns)
def _matches_fielddef(table: str, field: str, fielddef: str) -> bool:
    """
    Does the table/field name match a single wildcard-based field definition?

    Args:
        table: tablename
        field: fieldname
        fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"``,
            ``"*.nhs_number"``)
    """
    spec = split_db_schema_table_column(fielddef)
    field_regex = get_spec_match_regex(spec.column)
    field_matches = bool(field_regex.match(field))
    if not spec.table:
        # No table part in the pattern: the field part alone decides.
        return field_matches
    # The pattern has a table part: both table and field must match.
    table_regex = get_spec_match_regex(spec.table)
    return bool(table_regex.match(table)) and field_matches
def matches_fielddef(
    table: str, field: str, fielddef: Union[str, List[str]]
) -> bool:
    """
    Does the table/field name match the wildcard-based field definition (or,
    if several are given, any of them)?

    Args:
        table: table name
        field: fieldname
        fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"`` or
            ``"*.nhs_number"``), or list of them; an empty list (or
            ``None``) matches nothing
    """
    if isinstance(fielddef, str):
        patterns = [fielddef]
    else:
        patterns = fielddef or []
    return any(
        _matches_fielddef(table, field, pattern) for pattern in patterns
    )
1924# =============================================================================
1925# More SQL
1926# =============================================================================
def sql_fragment_cast_to_int(
    expr: str,
    big: bool = True,
    dialect: Dialect = None,
    viewmaker: ViewMaker = None,
) -> str:
    """
    Takes an SQL expression and coerces it to an integer. For Microsoft SQL
    Server.

    Args:
        expr: starting SQL expression
        big: use BIGINT, not INTEGER?
        dialect: optional :class:`sqlalchemy.engine.interfaces.Dialect`. If
            ``None`` and we have a ``viewmaker``, use the viewmaker's dialect.
            Otherwise, assume SQL Server.
        viewmaker: optional :class:`ViewMaker`

    Returns:
        modified SQL expression

    Raises:
        :exc:`ValueError` if the dialect is known and is not SQL Server

    *Notes*

    Conversion to INT:

    - https://stackoverflow.com/questions/2000045
    - https://stackoverflow.com/questions/14719760 (this one in particular!)
    - https://stackoverflow.com/questions/14692131

      - see LIKE example.
      - see ISNUMERIC();
        https://msdn.microsoft.com/en-us/library/ms186272.aspx;
        but that includes non-integer numerics

    - https://msdn.microsoft.com/en-us/library/ms174214(v=sql.120).aspx;
      relates to the SQL Server Management Studio "Find and Replace"
      dialogue box, not to SQL itself!

    - https://stackoverflow.com/questions/29206404/mssql-regular-expression

    Note that the regex-like expression supported by LIKE is extremely limited.

    - https://msdn.microsoft.com/en-us/library/ms179859.aspx

    - The only things supported are:

      .. code-block:: none

        %   any characters
        _   any single character
        []  single character in range or set, e.g. [a-f], [abcdef]
        [^] single character NOT in range or set, e.g. [^a-f], [^abcdef]

    SQL Server does not support a REGEXP command directly.

    So the best bet is to have the LIKE clause check for a non-integer:

    .. code-block:: sql

        CASE
            WHEN something LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    ... which doesn't deal with spaces properly, but there you go.
    Could also strip whitespace left/right:

    .. code-block:: sql

        CASE
            WHEN LTRIM(RTRIM(something)) LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    That only works for positive integers.

    LTRIM/RTRIM are not ANSI SQL.
    Nor are unusual LIKE clauses; see
    https://stackoverflow.com/questions/712580/list-of-special-characters-for-sql-like-clause

    The other, for SQL Server 2012 or higher, is TRY_CAST:

    .. code-block:: sql

        TRY_CAST(something AS BIGINT)

    ... which returns NULL upon failure; see
    https://msdn.microsoft.com/en-us/library/hh974669.aspx

    Therefore, our **method** is as follows:

    - If the database supports TRY_CAST, use that.
    - Otherwise if we're using SQL Server, use a CASE/CAST construct. (This
      includes the case where the SQL Server version is not known.)
    - Otherwise, raise :exc:`ValueError` as we don't know what to do.

    """
    inttype = "BIGINT" if big else "INTEGER"
    if dialect is None and viewmaker is not None:
        dialect = viewmaker.engine.dialect
    if dialect is None:
        # Nothing to go on: assume SQL Server of unknown (old) version.
        sql_server = True
        supports_try_cast = False
    else:
        # noinspection PyUnresolvedReferences
        sql_server = dialect.name == "mssql"
        # The server version may be unavailable (None), e.g. if the dialect
        # has not yet been initialized from a live connection. Comparing
        # None with a version tuple would raise TypeError, so treat an
        # unknown version as "no TRY_CAST" and use the CASE/CAST fallback.
        version_info = getattr(dialect, "server_version_info", None)
        supports_try_cast = (
            sql_server
            and version_info is not None
            and version_info >= MS_2012_VERSION
        )
    if supports_try_cast:
        return f"TRY_CAST({expr} AS {inttype})"
    elif sql_server:
        # Doesn't support negative integers.
        return (
            f"CASE WHEN LTRIM(RTRIM({expr})) LIKE '%[^0-9]%' "
            f"THEN NULL ELSE CAST({expr} AS {inttype}) END"
        )
    else:
        # noinspection PyUnresolvedReferences
        raise ValueError(
            f"Code not yet written for convert-to-int for "
            f"dialect {dialect.name}"
        )
2054# =============================================================================
2055# Abstracted SQL WHERE condition
2056# =============================================================================
2059@register_for_json(method=METHOD_PROVIDES_INIT_KWARGS)
2060@functools.total_ordering
2061class WhereCondition:
2062 """
2063 Ancillary class for building SQL WHERE expressions from our web forms.
2065 The essence of it is ``WHERE column op value_or_values``.
2066 """
2068 def __init__(
2069 self,
2070 column_id: ColumnId = None,
2071 op: str = "",
2072 datatype: str = "",
2073 value_or_values: Any = None,
2074 raw_sql: str = "",
2075 from_table_for_raw_sql: TableId = None,
2076 ) -> None:
2077 """
2078 Args:
2079 column_id:
2080 :class:`ColumnId` for the column
2081 op:
2082 operation (e.g. ``=``, ``<``, ``<=``, etc.)
2083 datatype:
2084 data type string that must match values in our
2085 ``querybuilder.js``; see source code. We use this to know how
2086 to build SQL literal values. (Not terribly elegant, but it
2087 works; SQL injection isn't a particular concern because we
2088 let our users run any SQL they want and ensure the connection
2089 is made read-only.)
2090 value_or_values:
2091 ``None``, single value, or list of values. Which is appropriate
2092 depends on the operation. For example, ``IS NULL`` takes no
2093 value; ``=`` takes one; ``IN`` takes many.
2094 raw_sql:
2095 override any thinking we might wish to do, and just return this
2096 raw SQL
2097 from_table_for_raw_sql:
2098 if we are using raw SQL, provide a :class:`TableId` for the
2099 relevant table here
2100 """
2101 self._column_id = column_id
2102 self._op = op.upper()
2103 self._datatype = datatype
2104 self._value = value_or_values
2105 self._no_value = False
2106 self._multivalue = False
2107 self._raw_sql = raw_sql
2108 self._from_table_for_raw_sql = from_table_for_raw_sql
2110 if not self._raw_sql:
2111 if self._op in SQL_OPS_VALUE_UNNECESSARY:
2112 self._no_value = True
2113 assert value_or_values is None, "Superfluous value passed"
2114 elif self._op in SQL_OPS_MULTIPLE_VALUES:
2115 self._multivalue = True
2116 assert isinstance(value_or_values, list), "Need list"
2117 else:
2118 assert not isinstance(
2119 value_or_values, list
2120 ), "Need single value"
2122 def init_kwargs(self) -> Dict:
2123 return {
2124 "column_id": self._column_id,
2125 "op": self._op,
2126 "datatype": self._datatype,
2127 "value_or_values": self._value,
2128 "raw_sql": self._raw_sql,
2129 "from_table_for_raw_sql": self._from_table_for_raw_sql,
2130 }
2132 def __repr__(self) -> str:
2133 return (
2134 "{qualname}("
2135 "column_id={column_id}, "
2136 "op={op}, "
2137 "datatype={datatype}, "
2138 "value_or_values={value_or_values}, "
2139 "raw_sql={raw_sql}, "
2140 "from_table_for_raw_sql={from_table_for_raw_sql}"
2141 ")".format(
2142 qualname=self.__class__.__qualname__,
2143 column_id=repr(self._column_id),
2144 op=repr(self._op),
2145 datatype=repr(self._datatype),
2146 value_or_values=repr(self._value),
2147 raw_sql=repr(self._raw_sql),
2148 from_table_for_raw_sql=repr(self._from_table_for_raw_sql),
2149 )
2150 )
2152 def __eq__(self, other: "WhereCondition") -> bool:
2153 return (
2154 self._raw_sql == other._raw_sql
2155 and self._column_id == other._column_id
2156 and self._op == other._op
2157 and self._value == other._value
2158 )
2160 def __lt__(self, other: "WhereCondition") -> bool:
2161 return (self._raw_sql, self._column_id, self._op, self._value) < (
2162 other._raw_sql,
2163 other._column_id,
2164 other._op,
2165 other._value,
2166 )
2168 @property
2169 def column_id(self) -> ColumnId:
2170 """
2171 Returns the :class:`ColumnId` provided at creation.
2172 """
2173 return self._column_id
@property
def table_id(self) -> TableId:
    """
    The :class:`TableId` that this condition applies to:

    - if raw SQL was supplied, our ``from_table_for_raw_sql`` attribute;
    - otherwise, the table ID taken from our ``column_id`` attribute.
    """
    return (
        self._from_table_for_raw_sql
        if self._raw_sql
        else self.column_id.table_id
    )
def table_str(self, grammar: SqlGrammar) -> str:
    """
    Returns the identifier of our table (see :attr:`table_id`), rendered for
    the given SQL grammar.

    Args:
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
    """
    table = self.table_id
    return table.identifier(grammar)
def sql(self, grammar: SqlGrammar) -> str:
    """
    Returns the text of this condition for use within a WHERE clause (the
    ``WHERE`` keyword itself is not included), in the specified SQL grammar.
    Examples of the output:

    - ``somecol = 3``
    - ``othercol IN (6, 7, 8)``
    - ``thirdcol IS NOT NULL``
    - ``textcol LIKE '%paracetamol%'``
    - ``MATCH (fulltextcol) AGAINST ('paracetamol')`` (MySQL)
    - ``CONTAINS(fulltextcol, 'paracetamol')`` (SQL Server)

    If raw SQL was supplied, it is returned unmodified.

    Args:
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
    """
    raw = self._raw_sql
    if raw:
        return raw

    column = self._column_id.identifier(grammar)
    operator = self._op

    if self._no_value:
        # Operators that stand alone, with no value following them.
        return f"{column} {operator}"

    # Choose how each Python value is converted to an SQL literal.
    datatype = self._datatype
    if datatype in QB_STRING_TYPES:
        to_literal = sql_string_literal
    elif datatype == QB_DATATYPE_DATE:
        to_literal = sql_date_literal
    elif datatype in (QB_DATATYPE_INTEGER, QB_DATATYPE_FLOAT):
        to_literal = str
    else:
        # Unknown datatype: quote as a string, as the safe default.
        to_literal = sql_string_literal

    if self._multivalue:
        rendered = "(" + ", ".join(map(to_literal, self._value)) + ")"
    else:
        rendered = to_literal(self._value)

    if operator == "MATCH":  # MySQL
        return f"MATCH ({column}) AGAINST ({rendered})"
    if operator == "CONTAINS":  # SQL Server
        return f"CONTAINS({column}, {rendered})"
    return f"{column} {operator} {rendered}"
2247# =============================================================================
2248# SQL formatting
2249# =============================================================================
def format_sql_for_print(sql: str) -> str:
    """
    Very simple SQL formatting.

    Replaces each tab with a single space, strips trailing whitespace from
    every line, drops lines that are then empty, and removes the indentation
    that is common to all remaining lines.
    """
    cleaned = []
    for raw_line in sql.splitlines():
        text = raw_line.replace("\t", " ").rstrip()
        if text:
            cleaned.append(text)
    # The smallest left padding across all lines is what we can shift by.
    shared_indent = min(
        (len(text) - len(text.lstrip()) for text in cleaned), default=0
    )
    if shared_indent > 0:
        cleaned = [text[shared_indent:] for text in cleaned]
    return "\n".join(cleaned)
2274# =============================================================================
2275# Plain SQL types
2276# =============================================================================
def is_sql_column_type_textual(column_type: str, min_length: int = 1) -> bool:
    """
    Does an SQL column type look textual?

    Only the first whitespace-delimited word of ``column_type`` is examined,
    case-insensitively.

    Args:
        column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
        min_length: what's the minimum string length we'll say "yes" to?

    Returns:
        is it a textual column (of the minimum length or more)? Empty,
        ``None``, or whitespace-only input gives ``False``.

    Note:

    - For SQL Server's NVARCHAR(MAX),
      :meth:`crate_anon.crateweb.research.research_db_info._schema_query_microsoft`
      returns "NVARCHAR(-1)". A negative length, or a length of ``MAX``, is
      therefore treated as "unlimited" and satisfies any ``min_length``.
    """  # noqa: E501
    if not column_type:
        return False
    words = column_type.upper().split()
    if not words:
        # Whitespace only: nothing to examine.
        return False
    type_word = words[0]
    if type_word in SQLTYPES_TEXT:
        # A text type without a specific length
        return True
    m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(type_word)
    if m is None:
        # Not of the form TYPE(length)
        return False
    basetype = m.group(1)
    length_str = m.group(2)
    if basetype not in SQLTYPES_TEXT:
        return False
    if length_str == "MAX":
        # Unlimited length, equivalent to the "-1" that SQL Server reports.
        return True
    try:
        length = int(length_str)
    except ValueError:
        return False
    return length >= min_length or length < 0
def coltype_length_if_text(column_type: str, dialect: str) -> Optional[int]:
    """
    Find the length of an SQL text column type.

    Args:
        column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
        dialect: the SQL dialect the column type is from

    Returns:
        length of the column, or ``None`` if the type is neither a known
        bare text type nor of the form ``TYPE(length)`` with a usable length.

    Raises:
        ValueError: if the type is a bare text type (no length given) but
            the dialect, or the default length of that type in the dialect,
            is unknown.

    NOTE(review): when a length is given explicitly, the base type is not
    checked against the known text types, so any ``TYPE(n)`` yields ``n``.
    """
    coltype_upper = column_type.upper()

    if coltype_upper not in SQLTYPES_TEXT:
        # Length specified - get it from the column type
        found = COLTYPE_WITH_ONE_INTEGER_REGEX.match(coltype_upper)
        if found is None:
            # Not the correct type of column
            return None
        base_type = found.group(1)
        length_text = found.group(2)
        if length_text in ("MAX", "-1"):
            # "Unlimited" length: only resolvable for SQL Server types.
            if dialect == SqlaDialectName.MSSQL:
                if base_type == "VARCHAR":
                    return MSSQL_COLTYPE_TO_LEN["VARCHAR_MAX"]
                if base_type == "NVARCHAR":
                    return MSSQL_COLTYPE_TO_LEN["NVARCHAR_MAX"]
            return None
        try:
            return int(length_text)
        except ValueError:
            # Not the correct type of column
            return None

    # No length specified - get the default for this dialect
    try:
        defaults_for_dialect = DIALECT_TO_STRING_LEN_LOOKUP[dialect]
    except KeyError:
        possible = list(DIALECT_TO_STRING_LEN_LOOKUP.keys())
        raise ValueError(
            f"CRATE doesn't properly understand SQL dialect {dialect!r}. "
            f"Supported: {possible}"
        )
    try:
        return defaults_for_dialect[coltype_upper]
    except KeyError:
        raise ValueError(
            f"For SQL dialect {dialect!r}, CRATE doesn't know the length "
            f"for string data type {coltype_upper!r}"
        )
def escape_quote_in_literal(s: str) -> str:
    r"""
    Escapes each single quote (``'``) by converting it to ``\'``.

    The alternative would be ``''``; ``\'`` is used for consistency with
    percent escaping.
    """
    return "\\'".join(s.split("'"))
def escape_percent_in_literal(sql: str) -> str:
    r"""
    Escapes each ``%`` by converting it to ``\%``.
    Use this for LIKE clauses.

    - https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
    """
    return sql.translate(str.maketrans({"%": r"\%"}))
def escape_percent_for_python_dbapi(sql: str) -> str:
    """
    Escapes each ``%`` by doubling it to ``%%``.
    Use this for SQL within Python where ``%`` characters are used for
    argument placeholders.
    """
    return "%%".join(sql.split("%"))
def escape_sql_string_literal(s: str) -> str:
    """
    Escapes a fragment of an SQL string literal: first its single quotes
    (via :func:`escape_quote_in_literal`), then its percent signs (via
    :func:`escape_percent_in_literal`).
    """
    quotes_escaped = escape_quote_in_literal(s)
    return escape_percent_in_literal(quotes_escaped)
def make_string_literal(s: str) -> str:
    """
    Converts a Python string into an SQL string literal: the string is
    escaped via :func:`escape_sql_string_literal` and wrapped in single
    quotes.
    """
    return "'" + escape_sql_string_literal(s) + "'"
def escape_sql_string_or_int_literal(s: Union[str, int]) -> str:
    """
    Converts an integer or a string into an SQL literal.

    Integers are rendered via ``str()``; anything else is passed to
    :func:`make_string_literal`, giving a single-quoted, escaped string.
    """
    return str(s) if isinstance(s, int) else make_string_literal(s)
def translate_sql_qmark_to_percent(sql: str) -> str:
    """
    Translates SQL that uses ``?`` placeholders into SQL that uses ``%s``
    placeholders, leaving alone any ``?`` found between single quotes (i.e.
    inside string literals) and doubling every ``%`` so that literal percent
    signs survive parameter substitution.

    *Notes*

    - MySQL likes ``?`` as a placeholder.

      - https://dev.mysql.com/doc/refman/5.7/en/sql-syntax-prepared-statements.html

    - Python DBAPI allows several: ``%s``, ``?``, ``:1``, ``:name``,
      ``%(name)s``.

      - https://www.python.org/dev/peps/pep-0249/#paramstyle

    - Django uses ``%s``.

      - https://docs.djangoproject.com/en/1.8/topics/db/sql/

    - Microsoft like ``?``, ``@paramname``, and ``:paramname``.

      - https://msdn.microsoft.com/en-us/library/yy6y35y8(v=vs.110).aspx

    - We need to parse SQL with argument placeholders.

      - See :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` classes,
        particularly: ``bind_parameter``

    I prefer ``?``, because ``%`` is used in LIKE clauses, and the databases
    we're using like it.

    So:

    - We use ``%s`` when using ``cursor.execute()`` directly, via Django.
    - We use ``?`` when talking to users, and
      :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` objects, so that
      the visual appearance matches what they expect from their database.

    """  # noqa: E501
    # 1. Escape % characters
    escaped = escape_percent_for_python_dbapi(sql)
    # 2. Replace ? characters that are not within quotes with %s.
    pieces = []
    inside_literal = False
    for ch in escaped:
        if ch == "'":
            # Every single quote toggles whether we are within a literal.
            inside_literal = not inside_literal
        pieces.append("%s" if ch == "?" and not inside_literal else ch)
    return "".join(pieces)
def decorate_index_name(
    idxname: str, tablename: str = None, engine: Engine = None
) -> str:
    """
    Amend the name of a database index. Specifically, this is because SQLite
    (which we won't use much, but do use for testing!) won't accept two
    indexes with the same names applying to different tables.

    Args:
        idxname:
            The original index name.
        tablename:
            The name of the table.
        engine:
            The SQLAlchemy engine, from which we obtain the dialect.

    Returns:
        ``{idxname}_{tablename}`` if both a table name and an engine were
        given and the engine's dialect is SQLite; otherwise ``idxname``
        unchanged.
    """
    if tablename and engine and engine.dialect.name == "sqlite":
        return f"{idxname}_{tablename}"
    return idxname