Coverage for common/sql.py: 32%

743 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2026-02-05 06:46 -0600

1""" 

2crate_anon/common/sql.py 

3 

4=============================================================================== 

5 

6 Copyright (C) 2015, University of Cambridge, Department of Psychiatry. 

7 Created by Rudolf Cardinal (rnc1001@cam.ac.uk). 

8 

9 This file is part of CRATE. 

10 

11 CRATE is free software: you can redistribute it and/or modify 

12 it under the terms of the GNU General Public License as published by 

13 the Free Software Foundation, either version 3 of the License, or 

14 (at your option) any later version. 

15 

16 CRATE is distributed in the hope that it will be useful, 

17 but WITHOUT ANY WARRANTY; without even the implied warranty of 

18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19 GNU General Public License for more details. 

20 

21 You should have received a copy of the GNU General Public License 

22 along with CRATE. If not, see <https://www.gnu.org/licenses/>. 

23 

24=============================================================================== 

25 

26**Low-level SQL manipulation functions.** 

27 

28These are about the manipulation of SQL as text (e.g. for query building 

29assistance for researchers, or for interpreting SQL data types in data 

30dictionaries), not about a higher-level approach like SQLAlchemy. 

31 

32""" 

33 

34from collections import OrderedDict 

35from dataclasses import dataclass 

36import functools 

37import logging 

38import re 

39from typing import Any, Dict, Iterable, List, Tuple, Union, Optional 

40 

41from cardinal_pythonlib.json_utils.serialize import ( 

42 METHOD_PROVIDES_INIT_KWARGS, 

43 METHOD_STRIP_UNDERSCORE, 

44 register_for_json, 

45) 

46from cardinal_pythonlib.lists import unique_list 

47from cardinal_pythonlib.reprfunc import mapped_repr_stripping_underscores 

48from cardinal_pythonlib.sizeformatter import sizeof_fmt 

49from cardinal_pythonlib.sql.literals import ( 

50 sql_date_literal, 

51 sql_string_literal, 

52) 

53from cardinal_pythonlib.sql.sql_grammar import SqlGrammar, text_from_parsed 

54from cardinal_pythonlib.sql.sql_grammar_factory import ( 

55 make_grammar, 

56 mysql_grammar, 

57) 

58from cardinal_pythonlib.sql.validation import ( 

59 SQLTYPES_INTEGER, 

60 SQLTYPES_BIT, 

61 SQLTYPES_FLOAT, 

62 SQLTYPES_TEXT, 

63 SQLTYPES_OTHER_NUMERIC, 

64) 

65from cardinal_pythonlib.sqlalchemy.core_query import count_star 

66from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName 

67from cardinal_pythonlib.sqlalchemy.schema import ( 

68 column_creation_ddl, 

69 execute_ddl, 

70) 

71from cardinal_pythonlib.timing import MultiTimerContext, timer 

72from pyparsing import ParseResults 

73from sqlalchemy import inspect 

74from sqlalchemy.dialects.mssql.base import MS_2012_VERSION 

75from sqlalchemy.engine.base import Engine 

76from sqlalchemy.engine.interfaces import Dialect 

77from sqlalchemy.exc import CompileError 

78from sqlalchemy.orm.session import Session 

79from sqlalchemy.schema import Column, Table 

80from sqlalchemy.sql.sqltypes import TypeEngine 

81 

82from crate_anon.common.stringfunc import get_spec_match_regex 

83 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# =============================================================================
# Types
# =============================================================================

# A (str, list) pair; presumably SQL text plus the arguments to bind into it
# -- NOTE(review): confirm against callers.
SqlArgsTupleType = Tuple[str, List[Any]]


# =============================================================================
# Constants
# =============================================================================

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Generic
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Name of a timing category for commits (presumably used with the imported
# MultiTimerContext/timer utilities -- confirm at call sites).
TIMING_COMMIT = "commit"

# SQL operators that take no right-hand value at all.
SQL_OPS_VALUE_UNNECESSARY = ["IS NULL", "IS NOT NULL"]
# SQL operators whose right-hand side is a list of values.
SQL_OPS_MULTIPLE_VALUES = ["IN", "NOT IN"]

# Convenience unions of the SQL type-name lists from
# cardinal_pythonlib.sql.validation.
SQLTYPES_INTEGER_OR_BIT = SQLTYPES_INTEGER + SQLTYPES_BIT
SQLTYPES_FLOAT_OR_OTHER_NUMERIC = SQLTYPES_FLOAT + SQLTYPES_OTHER_NUMERIC

# Data type labels used by the query builder. Must match querybuilder.js:
QB_DATATYPE_INTEGER = "int"
QB_DATATYPE_FLOAT = "float"
QB_DATATYPE_DATE = "date"
QB_DATATYPE_STRING = "string"
QB_DATATYPE_STRING_FULLTEXT = "string_fulltext"
QB_DATATYPE_UNKNOWN = "unknown"
# The query-builder labels that denote string types (plain or full-text).
QB_STRING_TYPES = [QB_DATATYPE_STRING, QB_DATATYPE_STRING_FULLTEXT]

118 

COLTYPE_WITH_ONE_INTEGER_REGEX = re.compile(r"^([A-Za-z_]+)\((-?\d+)\)$")
# Matches a column type with a single integer parameter, e.g. "VARCHAR(32)":
# ... start, group(type name: letters/underscore), literal (,
# group(optional_minus_sign digits), literal ), end.
# NB: the range "A-z" would also accept the punctuation characters
# [ \ ] ^ ` that lie between "Z" and "a" in ASCII; hence "A-Za-z".

122 

# Dictionaries for the different dialects, mapping an (upper-case) text column
# type name to its maximum length, or to its default length where the length
# may be omitted (e.g. MySQL CHAR). Types such as VARCHAR, for which the user
# must specify a length, are not included. A value of None means "no maximum".

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# SQLAlchemy dialects
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DATABRICKS_COLTYPE_TO_LEN = {
    # https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html  # noqa: E501
    "STRING": None  # There is no maximum.
}
MSSQL_COLTYPE_TO_LEN = {
    # The "N" prefix means Unicode.
    # https://docs.microsoft.com/en-us/sql/t-sql/data-types/char-and-varchar-transact-sql?view=sql-server-ver15  # noqa: E501
    # https://docs.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver15  # noqa: E501
    # https://docs.microsoft.com/en-us/sql/t-sql/data-types/ntext-text-and-image-transact-sql?view=sql-server-ver15  # noqa: E501
    # The "_MAX" keys stand for the "(MAX)" forms, e.g. NVARCHAR(MAX).
    "NVARCHAR_MAX": 2**30 - 1,
    # Can specify NVARCHAR(1) to NVARCHAR(4000), or NVARCHAR(MAX) for 2^30 - 1.
    "VARCHAR_MAX": 2**31 - 1,
    # Can specify VARCHAR(1) to VARCHAR(8000), or VARCHAR(MAX) for 2^31 - 1.
    "TEXT": 2**31 - 1,
    "NTEXT": 2**30 - 1,
}
MYSQL_COLTYPE_TO_LEN = {
    # https://dev.mysql.com/doc/refman/8.0/en/string-type-overview.html
    "CHAR": 1,  # can specify CHAR(0) to CHAR(255), but if omitted, length is 1
    "TINYTEXT": 255,  # 2^8 - 1
    "TEXT": 65535,  # 2^16 - 1
    "MEDIUMTEXT": 16777215,  # 2^24 - 1
    "LONGTEXT": 4294967295,  # 2^32 - 1
}

# Dialect name -> the type-to-length dictionary above for that dialect.
# Dialects not listed here have no lookup.
DIALECT_TO_STRING_LEN_LOOKUP = {
    SqlaDialectName.DATABRICKS: DATABRICKS_COLTYPE_TO_LEN,
    SqlaDialectName.MSSQL: MSSQL_COLTYPE_TO_LEN,
    SqlaDialectName.MYSQL: MYSQL_COLTYPE_TO_LEN,
}

161 

162 

163# ============================================================================= 

164# Helper classes 

165# ============================================================================= 

166 

167 

@dataclass
class IndexCreationInfo:
    """
    Describes an index to be created: its name, the column(s) it covers, and
    whether it should be unique.
    """

    index_name: str  #: Name of the index
    column: Union[str, List[str]]  #: Column name(s) to index
    unique: bool = False  #: Make a unique index?

    @property
    def column_names(self) -> str:
        """
        The indexed column name(s) as a single string: the name itself for a
        single column, or the names joined with ``", "`` for several.
        """
        cols = self.column
        return cols if isinstance(cols, str) else ", ".join(cols)

182 

183 

184# ============================================================================= 

185# SQL elements: identifiers 

186# ============================================================================= 

187 

188 

@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class SchemaId:
    """
    Represents a database schema. This is a bit complex:

    - In SQL Server, schemas live within databases. Tables can be referred to
      as ``table``, ``schema.table``, or ``database.schema.table``.

      - https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/sql/ownership-and-user-schema-separation-in-sql-server
      - The default schema is named ``dbo``.

    - In PostgreSQL, schemas live within databases. Tables can be referred to
      as ``table``, ``schema.table``, or ``database.schema.table``.

      - https://www.postgresql.org/docs/current/static/ddl-schemas.html
      - The default schema is named ``public``.

    - In MySQL, "database" and "schema" are synonymous. Tables can be referred
      to as ``table`` or ``database.table`` (= ``schema.table``).

      - https://stackoverflow.com/questions/11618277/difference-between-schema-database-in-mysql

    Instances compare and sort by ``(db, schema)``, and are hashable
    consistently with that equality.
    """  # noqa: E501

    def __init__(self, db: str = "", schema: str = "") -> None:
        """
        Args:
            db: database name
            schema: schema name

        Raises:
            AssertionError: if either name contains ``'.'`` (which would make
                :attr:`schema_tag` ambiguous)
        """
        assert "." not in db, f"Bad database name ({db!r}); can't include '.'"
        assert (
            "." not in schema
        ), f"Bad schema name ({schema!r}); can't include '.'"
        self._db = db
        self._schema = schema

    @property
    def schema_tag(self) -> str:
        """
        String suitable for encoding the SchemaId e.g. in a single HTML form.
        Takes the format ``database.schema``; the inverse is
        :meth:`from_schema_tag`.

        The :func:`__init__` function has already checked the assumption of no
        ``'.'`` characters in either part.
        """
        return f"{self._db}.{self._schema}"

    @classmethod
    def from_schema_tag(cls, tag: str) -> "SchemaId":
        """
        Returns a :class:`SchemaId` from a tag of the form ``db.schema``.

        Raises:
            AssertionError: if the tag does not contain exactly one ``'.'``
        """
        parts = tag.split(".")
        assert len(parts) == 2, f"Bad schema tag {tag!r}"
        db, schema = parts
        return cls(db, schema)

    def __bool__(self) -> bool:
        """
        Returns:
            is there a named schema? (The database part is not considered.)
        """
        return bool(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SchemaId):
            return NotImplemented  # let Python try the reflected operation
        return (  # ordering is for speed
            self._schema == other._schema and self._db == other._db
        )

    def __lt__(self, other: "SchemaId") -> bool:
        if not isinstance(other, SchemaId):
            return NotImplemented
        return (self._db, self._schema) < (other._db, other._schema)

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. (Hashing str(self)
        # would go via make_identifier(), which asserts when both parts are
        # empty, making a blank SchemaId unhashable.)
        return hash((self._db, self._schema))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this schema using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, database=self._db, schema=self._schema)

    def table_id(self, table: str) -> "TableId":
        """
        Returns a :class:`TableId` combining this schema and the specified
        table.

        Args:
            table: name of the table
        """
        return TableId(db=self._db, schema=self._schema, table=table)

    def column_id(self, table: str, column: str) -> "ColumnId":
        """
        Returns a :class:`ColumnId` combining this schema and the specified
        table/column.

        Args:
            table: name of the table
            column: name of the column
        """
        return ColumnId(
            db=self._db, schema=self._schema, table=table, column=column
        )

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    def __str__(self) -> str:
        # Uses make_identifier(), so this asserts if both parts are empty.
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(self, ["_db", "_schema"])

    def is_present(self) -> bool:
        """
        Does this schema have a ``database`` and/or ``schema`` part (i.e. is
        it not blank)?
        """
        return bool(self._db or self._schema)

    def is_blank(self) -> bool:
        """
        Is this a blank/nonfunctional schema, with no ``database`` or
        ``schema`` part?
        """
        return not self.is_present()

332 

333 

@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class TableId:
    """
    Represents a database table, optionally qualified by database and schema.

    Instances compare and sort by ``(db, schema, table)``, and are hashable
    consistently with that equality.
    """

    def __init__(
        self, db: str = "", schema: str = "", table: str = ""
    ) -> None:
        """
        Args:
            db: database name
            schema: schema name
            table: table name
        """
        self._db = db
        self._schema = schema
        self._table = table

    def __bool__(self) -> bool:
        # Truthy only if a table name is present.
        return bool(self._table)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TableId):
            return NotImplemented  # let Python try the reflected operation
        return (  # ordering is for speed
            self._table == other._table
            and self._schema == other._schema
            and self._db == other._db
        )

    def __lt__(self, other: "TableId") -> bool:
        if not isinstance(other, TableId):
            return NotImplemented
        return (self._db, self._schema, self._table) < (
            other._db,
            other._schema,
            other._table,
        )

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. (Hashing str(self)
        # would go via make_identifier(), which asserts when all parts are
        # empty, making a blank TableId() unhashable.)
        return hash((self._db, self._schema, self._table))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(
            grammar, database=self._db, schema=self._schema, table=self._table
        )

    @property
    def schema_id(self) -> SchemaId:
        """
        Returns a :class:`SchemaId` for the schema of our table.
        """
        return SchemaId(db=self._db, schema=self._schema)

    def column_id(self, column: str) -> "ColumnId":
        """
        Returns a :class:`ColumnId` combining this table and the specified
        column.

        Args:
            column: name of the column
        """
        return ColumnId(
            db=self._db, schema=self._schema, table=self._table, column=column
        )

    def database_schema_part(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table's database/schema (without the
        table part) using the specified SQL grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, database=self._db, schema=self._schema)

    def table_part(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table's table name (only) using the
        specified SQL grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, table=self._table)

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    @property
    def table(self) -> str:
        """
        Returns the table part.
        """
        return self._table

    def __str__(self) -> str:
        # Uses make_identifier(), so this asserts if all parts are empty.
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ["_db", "_schema", "_table"]
        )

453 

454 

@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class ColumnId:
    """
    Represents a database column, optionally qualified by database, schema
    and table.

    Instances compare and sort by ``(db, schema, table, column)``, and are
    hashable consistently with that equality (like :class:`SchemaId` and
    :class:`TableId`).
    """

    def __init__(
        self, db: str = "", schema: str = "", table: str = "", column: str = ""
    ) -> None:
        """
        Args:
            db: database name
            schema: schema name
            table: table name
            column: column name
        """
        self._db = db
        self._schema = schema
        self._table = table
        self._column = column

    def __bool__(self) -> bool:
        # Truthy only if a column name is present.
        return bool(self._column)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ColumnId):
            return NotImplemented  # let Python try the reflected operation
        return (  # ordering is for speed
            self._column == other._column
            and self._table == other._table
            and self._schema == other._schema
            and self._db == other._db
        )

    def __lt__(self, other: "ColumnId") -> bool:
        if not isinstance(other, ColumnId):
            return NotImplemented
        return (self._db, self._schema, self._table, self._column) < (
            other._db,
            other._schema,
            other._table,
            other._column,
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None (unhashable); hash the
        # same fields that __eq__ compares.
        return hash((self._db, self._schema, self._table, self._column))

    @property
    def is_valid(self) -> bool:
        """
        Do we know about a table and a column, at least?
        """
        return bool(self._table and self._column)  # the minimum

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this column (including whichever of
        database/schema/table are present) using the specified SQL grammar,
        quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(
            grammar,
            database=self._db,
            schema=self._schema,
            table=self._table,
            column=self._column,
        )

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    @property
    def table(self) -> str:
        """
        Returns the table part.
        """
        return self._table

    @property
    def column(self) -> str:
        """
        Returns the column part.
        """
        return self._column

    @property
    def schema_id(self) -> SchemaId:
        """
        Returns a :class:`SchemaId` for the schema of our column.
        """
        return SchemaId(db=self._db, schema=self._schema)

    @property
    def table_id(self) -> TableId:
        """
        Returns a :class:`TableId` for our table.
        """
        return TableId(db=self._db, schema=self._schema, table=self._table)

    @property
    def has_table_and_column(self) -> bool:
        """
        Do we know about a table and a column? (Same test as
        :attr:`is_valid`.)
        """
        return bool(self._table and self._column)

    def __str__(self) -> str:
        # Uses make_identifier(), so this asserts if all parts are empty.
        return self.identifier(mysql_grammar)  # specific one unimportant

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ["_db", "_schema", "_table", "_column"]
        )

579 

580 

def split_db_schema_table(db_schema_table: str) -> TableId:
    """
    Parses a dotted, unquoted table reference into a :class:`TableId`.

    Args:
        db_schema_table:
            ``database.schema.table``, ``schema.table``, or ``table``;
            components that are not given are left as empty strings

    Returns:
        a :class:`TableId`

    Raises:
        :exc:`ValueError` if there are more than three dot-separated parts
    """
    parts = db_schema_table.split(".")
    if len(parts) > 3:
        raise ValueError(f"Bad db_schema_table: {db_schema_table}")
    # str.split() always yields at least one part; left-pad the missing
    # leading components with "" so we always have (db, schema, table).
    db, schema, table = [""] * (3 - len(parts)) + parts
    return TableId(db=db, schema=schema, table=table)

606 

607 

def split_db_schema_table_column(db_schema_table_col: str) -> ColumnId:
    """
    Parses a dotted, unquoted column reference into a :class:`ColumnId`.

    Args:
        db_schema_table_col:
            ``database.schema.table.column``, ``schema.table.column``,
            ``table.column``, or ``column``; components that are not given
            are left as empty strings

    Returns:
        a :class:`ColumnId`

    Raises:
        :exc:`ValueError` if there are more than four dot-separated parts
    """
    parts = db_schema_table_col.split(".")
    if len(parts) > 4:
        raise ValueError(f"Bad db_schema_table_col: {db_schema_table_col}")
    # str.split() always yields at least one part; left-pad the missing
    # leading components with "" so we always have four.
    db, schema, table, column = [""] * (4 - len(parts)) + parts
    return ColumnId(db=db, schema=schema, table=table, column=column)

636 

637 

def columns_to_table_column_hierarchy(
    columns: List["ColumnId"], sort: bool = True
) -> List[Tuple["TableId", List["ColumnId"]]]:
    """
    Groups a list of column IDs by the table each belongs to.

    Args:
        columns: list of :class:`ColumnId` objects
        sort: sort by table, and column within table? If false, tables appear
            in order of their first appearance in ``columns``, and columns
            keep their input order within each table.

    Returns:
        a list of tuples, each ``table, columns``, where ``table`` is a
        :class:`TableId` and ``columns`` is a list of :class:`ColumnId`
    """
    # One pass, grouping by table. Dicts preserve insertion order, so tables
    # come out in first-seen order (as unique_list() would give) without
    # re-scanning every column once per table.
    columns_by_table = {}  # type: Dict[TableId, List[ColumnId]]
    for c in columns:
        columns_by_table.setdefault(c.table_id, []).append(c)
    tables = list(columns_by_table)
    if sort:
        tables.sort()
    table_column_map = []  # type: List[Tuple[TableId, List[ColumnId]]]
    for t in tables:
        t_columns = columns_by_table[t]  # a fresh list; input is untouched
        if sort:
            t_columns.sort()
        table_column_map.append((t, t_columns))
    return table_column_map

662 

663 

664# ============================================================================= 

665# Using SQL grammars (but without reference to Django models, for testing) 

666# ============================================================================= 

667 

668 

def make_identifier(
    grammar: SqlGrammar,
    database: str = None,
    schema: str = None,
    table: str = None,
    column: str = None,
) -> str:
    """
    Builds a dotted SQL identifier, quoting each component as the given
    grammar requires.

    Args:
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        database: database name
        schema: schema name
        table: table name
        column: column name

    Returns:
        the quoted components joined with ``.``, in the order database,
        schema, table, column; components that are ``None`` or empty are
        skipped

    Raises:
        AssertionError: if every component is missing
    """
    quoted_parts = []  # type: List[str]
    for part in (database, schema, table, column):
        if not part:
            continue  # missing/empty component: leave it out entirely
        quoted_parts.append(grammar.quote_identifier_if_required(part))
    assert quoted_parts, "make_identifier(): No elements passed!"
    return ".".join(quoted_parts)

699 

700 

def dumb_make_identifier(
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table: Optional[str] = None,
    column: Optional[str] = None,
) -> str:
    """
    Makes an SQL-style identifier by joining all the parts with ``.``, without
    bothering to quote them.

    Args:
        database: database name
        schema: schema name
        table: table name
        column: column name

    Returns:
        a string as above in the order "database, schema, table, column", but
        omitting any missing (``None`` or empty) parts

    Raises:
        AssertionError: if every part is missing (as for
            :func:`make_identifier`)
    """
    # Build a real list: a bare filter() object is always truthy, so the
    # emptiness check below could never fire.
    elements = [x for x in (database, schema, table, column) if x]
    assert elements, "dumb_make_identifier(): No elements passed!"
    return ".".join(elements)

725 

726 

def parser_add_result_column(
    parsed: ParseResults, column: str, grammar: SqlGrammar
) -> ParseResults:
    """
    Appends a result column to a parsed SELECT statement, unless it is
    already there. For example, adding ``d`` to

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        WHERE conditions;

    gives

    .. code-block:: sql

        SELECT a, b, c, d
        FROM sometable
        WHERE conditions;

    The SELECT list must already contain at least one column, because the new
    one is appended after a comma.

    Args:
        parsed: a `pyparsing.ParseResults` result (modified in place)
        column: column name
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`

    Returns:
        ``parsed`` itself, as a `pyparsing.ParseResults` result
    """
    current_columns = parsed.select_expression.select_columns.asList()
    if column in current_columns:
        return parsed  # nothing to do
    parsed_column = grammar.get_result_column().parseString(
        column, parseAll=True
    )
    parsed.select_expression.extend([",", parsed_column])
    return parsed

765 

766 

class JoinInfo:
    """
    Simple description of one SQL join: the table to join, how to join it,
    and (optionally) the join condition.
    """

    def __init__(
        self,
        table: str,
        join_type: str = "INNER JOIN",
        join_condition: str = "",
    ) -> None:
        """
        Args:
            table: table to be joined in
            join_type: join method, e.g. ``"INNER JOIN"``
            join_condition: join condition, e.g. ``"ON x = y"``; an empty
                string means no condition
        """
        self.table = table
        self.join_type = join_type
        self.join_condition = join_condition

787 

788 

def parser_add_from_tables(
    parsed: ParseResults, join_info_list: List[JoinInfo], grammar: SqlGrammar
) -> ParseResults:
    """
    Takes a parsed SQL statement of the form

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        WHERE conditions;

    and adds one or more joined tables, e.g. ``JoinInfo("othertable", "INNER
    JOIN", "ON table.key = othertable.key")``, to give

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        INNER JOIN othertable ON table.key = othertable.key
        WHERE conditions;

    A table that is already in the FROM clause is skipped. This includes a
    table added earlier in the same call, so duplicates in
    ``join_info_list`` are joined only once.

    Presupposes that there is at least one table already in the FROM clause.

    Args:
        parsed: a `pyparsing.ParseResults` result (modified in place)
        join_info_list: list of :class:`JoinInfo` objects
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`

    Returns:
        ``parsed`` itself, as a `pyparsing.ParseResults` result
    """
    existing_tables = parsed.join_source.from_tables.asList()
    for ji in join_info_list:
        if ji.table in existing_tables:
            continue  # already there
        parsed_join = grammar.get_join_op().parseString(
            ji.join_type, parseAll=True
        )[0]  # e.g. INNER JOIN
        parsed_table = grammar.get_table_spec().parseString(
            ji.table, parseAll=True
        )[0]
        extrabits = [parsed_join, parsed_table]
        if ji.join_condition:  # e.g. ON x = y
            extrabits.append(
                grammar.get_join_constraint().parseString(
                    ji.join_condition, parseAll=True
                )[0]
            )
        parsed.join_source.extend(extrabits)
        # Record the newly joined table, so that a repeat of it later in
        # join_info_list is skipped rather than joined a second time.
        existing_tables.append(ji.table)
    return parsed

848 

849 

def get_first_from_table(
    parsed: ParseResults,
    match_db: str = "",
    match_schema: str = "",
    match_table: str = "",
) -> TableId:
    """
    Returns the first table in the FROM clause of a parsed SELECT statement,
    optionally the first one satisfying the ``match_*`` constraints. An empty
    constraint matches anything.

    Args:
        parsed: a `pyparsing.ParseResults` result
        match_db: optional database name to constrain the result to
        match_schema: optional schema name to constrain the result to
        match_table: optional table name to constrain the result to

    Returns:
        a :class:`TableId`, which will be empty (``TableId()``) if no table
        matched
    """
    for entry in parsed.join_source.from_tables.asList():
        if isinstance(entry, list):
            # A table may come back wrapped in a one-element list.
            assert len(entry) == 1
            entry = entry[0]
        candidate = split_db_schema_table(entry)
        constraints = (
            (match_db, candidate.db),
            (match_schema, candidate.schema),
            (match_table, candidate.table),
        )
        if any(wanted and actual != wanted for wanted, actual in constraints):
            continue
        return candidate
    return TableId()

885 

886 

def set_distinct_within_parsed(p: ParseResults, action: str = "set") -> None:
    """
    Changes, in place, whether a parsed SQL SELECT statement is DISTINCT.

    Args:
        p: a `pyparsing.ParseResults` result
        action: ``"set"`` to make it DISTINCT; ``"clear"`` to make it not
            DISTINCT; ``"toggle"`` to flip the current state.

    Raises:
        :exc:`ValueError` for any other ``action``
    """
    ss = p.select_specifier  # type: ParseResults
    if action not in ("set", "clear", "toggle"):
        raise ValueError("action must be one of set/clear/toggle")
    is_distinct = "DISTINCT" in ss.asList()
    if action == "toggle":
        want_distinct = not is_distinct
    else:
        want_distinct = action == "set"
    if want_distinct and not is_distinct:
        ss.append("DISTINCT")
    elif is_distinct and not want_distinct:
        # NOTE(review): this empties the whole select specifier, not just
        # the DISTINCT keyword; if the grammar can put other keywords here,
        # they would be removed too -- confirm.
        del ss[:]

910 

911 

def set_distinct(
    sql: str,
    grammar: SqlGrammar,
    action: str = "set",
    formatted: bool = True,
    debug: bool = False,
    debug_verbose: bool = False,
) -> str:
    """
    Parses an SQL SELECT statement, changes whether it is DISTINCT, and
    returns it as text again.

    Args:
        sql: SQL statement as text
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        action: one of ``"set"``, ``"clear"``, ``"toggle"``; see
            :func:`set_distinct_within_parsed`
        formatted: pretty-format the result?
        debug: show debugging information to the Python log
        debug_verbose: if debugging, also log parse-tree dumps

    Returns:
        the modified SQL statement, as a string
    """
    parsed = grammar.get_select_statement().parseString(sql, parseAll=True)
    if debug:
        log.info(f"START: {sql}")
        if debug_verbose:
            log.debug("start dump:\n" + parsed.dump())
    set_distinct_within_parsed(parsed, action=action)
    modified_sql = text_from_parsed(parsed, formatted=formatted)
    if debug:
        log.info(f"END: {modified_sql}")
        if debug_verbose:
            log.debug("end dump:\n" + parsed.dump())
    return modified_sql

948 

949 

def toggle_distinct(
    sql: str,
    grammar: SqlGrammar,
    formatted: bool = True,
    debug: bool = False,
    debug_verbose: bool = False,
) -> str:
    """
    Flips the DISTINCT status of an SQL SELECT statement; a shortcut for
    :func:`set_distinct` with ``action="toggle"``.

    Args:
        sql: SQL statement as text
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        formatted: pretty-format the result?
        debug: show debugging information to the Python log
        debug_verbose: be verbose when debugging

    Returns:
        the modified SQL statement, as a string
    """
    options = dict(
        formatted=formatted, debug=debug, debug_verbose=debug_verbose
    )
    return set_distinct(sql=sql, grammar=grammar, action="toggle", **options)

979 

980 

981# ============================================================================= 

982# SQLAlchemy reflection and DDL 

983# ============================================================================= 

984 

# Module-wide switch read by _exec_ddl() and execute(): when true, they print
# the SQL instead of running it. Change it via set_print_not_execute().
_global_print_not_execute_sql = False


def set_print_not_execute(print_not_execute: bool) -> None:
    """
    Sets the module-wide flag deciding whether this module's DDL/SQL helpers
    print their SQL (``True``) or execute it (``False``).

    Args:
        print_not_execute: print (not execute)?
    """
    global _global_print_not_execute_sql
    _global_print_not_execute_sql = print_not_execute

998 

999 

def _exec_ddl(engine: Engine, sql: str) -> None:
    """
    Runs raw SQL as DDL, or just prints it, depending on the flag set by
    :func:`set_print_not_execute`. The SQL is always logged at DEBUG level.

    Args:
        engine: SQLAlchemy database Engine
        sql: raw SQL to execute (or print)
    """
    log.debug(sql)
    if not _global_print_not_execute_sql:
        execute_ddl(engine, sql=sql)
        return
    # The newline before ";" matters if the SQL ends in a comment.
    print(format_sql_for_print(sql) + "\n;")

1017 

1018 

def execute(engine: Engine, sql: str) -> None:
    """
    Executes plain SQL in a transaction.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`. The SQL is always logged at DEBUG level.

    Args:
        engine: SQLAlchemy database Engine
        sql: raw SQL to execute (or print)
    """
    log.debug(sql)
    if _global_print_not_execute_sql:
        print(format_sql_for_print(sql) + "\n;")
        # extra \n in case the SQL ends in a comment
    else:
        # engine.begin() commits on success and rolls back on error.
        with engine.begin() as connection:
            # A plain string must go via exec_driver_sql() (SQLAlchemy 1.4+):
            # Connection.execute() rejects plain strings in SQLAlchemy 2.0.
            # Like the old raw-string path, this sends the SQL to the DBAPI
            # as-is, without parsing ":name" bind parameters.
            connection.exec_driver_sql(sql)

1037 

1038 

def add_columns(engine: Engine, table: Table, columns: List[Column]) -> None:
    """
    Adds any of the given columns that the table doesn't already have.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        columns: SQLAlchemy Column objects to add to the table

    Existing columns are detected case-insensitively, by reading the table's
    current column names from the database; those columns are skipped.

    Behaviour of different database systems:

    - ANSI SQL: add one column at a time: ``ALTER TABLE t ADD [COLUMN] coldef``

      - i.e. "COLUMN" optional, one at a time, no parentheses
      - https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt

    - MySQL: ``ALTER TABLE t ADD [COLUMN] (a INT, b VARCHAR(32));``

      - i.e. "COLUMN" optional, parentheses required for >1, multiple OK
      - https://dev.mysql.com/doc/refman/5.7/en/alter-table.html

    - MS SQL Server: ``ALTER TABLE t ADD a INT, b VARCHAR(32);``

      - i.e. no "COLUMN", no parentheses, multiple OK
      - https://msdn.microsoft.com/en-us/library/ms190238.aspx
      - https://msdn.microsoft.com/en-us/library/ms190273.aspx
      - https://stackoverflow.com/questions/2523676

    Because these differ, this function issues one ``ALTER TABLE ... ADD``
    per column. SQLAlchemy doesn't provide a shortcut for this.
    """
    already_present = get_column_names(
        engine, tablename=table.name, to_lower=True
    )

    # Phase 1: work out which columns are new, and generate their DDL.
    new_column_defs = []  # type: List[str]
    for col in columns:
        if col.name.lower() in already_present:
            log.debug(
                f"Table {table.name!r}: column {col.name!r} "
                f"already exists; not adding"
            )
            continue
        new_column_defs.append(column_creation_ddl(col, engine.dialect))

    # Phase 2: add them, one statement per column.
    for coldef in new_column_defs:
        log.info(f"Table {table.name!r}: adding column {coldef!r}")
        _exec_ddl(engine, f"ALTER TABLE {table.name} ADD {coldef}")

1091 

1092 

def drop_columns(
    engine: Engine, table: Table, column_names: Iterable[str]
) -> None:
    """
    Drops the named columns from a table, skipping any that aren't there.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        column_names: names of columns to drop

    Presence is checked case-insensitively against the table's current
    column names, read from the database. Each column is dropped with its own
    ``ALTER TABLE ... DROP COLUMN`` statement.
    """
    present = get_column_names(engine, tablename=table.name, to_lower=True)
    for name in column_names:
        if name.lower() in present:
            log.info(f"Table {table.name!r}: dropping column {name!r}")
            # Syntax references:
            # SQL Server:
            #   http://www.techonthenet.com/sql_server/tables/alter_table.php
            # MySQL:
            #   http://dev.mysql.com/doc/refman/5.7/en/alter-table.html
            _exec_ddl(engine, f"ALTER TABLE {table.name} DROP COLUMN {name}")
        else:
            log.debug(
                f"Table {table.name!r}: column {name!r} "
                f"does not exist; not dropping"
            )

1126 

1127 

def add_indexes(
    engine: Engine, table: Table, index_info_list: Iterable[IndexCreationInfo]
) -> None:
    """
    Creates the requested indexes on a table, skipping any whose name already
    exists.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine:
            SQLAlchemy database Engine
        table:
            SQLAlchemy Table object
        index_info_list:
            Index(es) to create: list of :class:`IndexCreationInfo` objects.
            Each supplies ``index_name``, ``column_names`` (inserted verbatim
            between the parentheses of the CREATE INDEX statement) and
            ``unique``.

    Index names are compared case-insensitively with those already present on
    the table, as read from the database.
    """
    present = get_index_names(engine, tablename=table.name, to_lower=True)
    for info in index_info_list:
        index_name = info.index_name
        column = info.column_names
        if index_name.lower() in present:
            log.debug(
                f"Table {table.name!r}: index {index_name!r} "
                f"already exists; not adding"
            )
            continue
        log.info(
            f"Table {table.name!r}: adding index {index_name!r} on "
            f"column {column!r}"
        )
        unique = " UNIQUE" if info.unique else ""
        _exec_ddl(
            engine,
            f"""
            CREATE{unique} INDEX {index_name}
            ON {table.name} ({column})
            """,
        )

1168 

1169 

def drop_indexes(
    engine: Engine, table: Table, index_names: Iterable[str]
) -> None:
    """
    Drops indexes from a table.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        index_names: names of indexes to drop

    Indexes not currently present on the table are skipped. The check is
    case-insensitive, against names read from the database.

    Dialect-specific syntax:

    - MySQL: ``ALTER TABLE t DROP INDEX i``
    - SQL Server: ``DROP INDEX t.i``
    - PostgreSQL, SQLite: ``DROP INDEX i`` (index names are not qualified by
      table name in these systems)

    Raises:
        :exc:`ValueError` if an index needs dropping but the engine's dialect
        is not one of the above.
    """
    dialect_name = engine.dialect.name
    existing_index_names = get_index_names(
        engine, tablename=table.name, to_lower=True
    )
    for index_name in index_names:
        if index_name.lower() not in existing_index_names:
            log.debug(
                f"Table {table.name!r}: index {index_name!r} "
                f"does not exist; not dropping"
            )
            continue
        log.info(f"Table {table.name!r}: dropping index {index_name!r}")
        if dialect_name == "mysql":
            sql = f"ALTER TABLE {table.name} DROP INDEX {index_name}"
        elif dialect_name == "mssql":
            sql = f"DROP INDEX {table.name}.{index_name}"
        elif dialect_name in ("postgresql", "sqlite"):
            sql = f"DROP INDEX {index_name}"
        else:
            # An explicit raise, not "assert False": an assert is removed
            # under "python -O", which would leave "sql" unbound here.
            raise ValueError(f"Unknown dialect: {dialect_name}")
        _exec_ddl(engine, sql)

1202 

1203 

def get_table_names(
    engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reflects the names of all tables in the database.

    Args:
        engine: SQLAlchemy database Engine
        to_lower: lower-case the returned names?
        sort: sort the names (case-insensitively)?

    Returns:
        list of table names
    """
    names = inspect(engine).get_table_names()
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())

1226 

1227 

def get_view_names(
    engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reflects the names of all views in the database.

    Args:
        engine: SQLAlchemy database Engine
        to_lower: lower-case the returned names?
        sort: sort the names (case-insensitively)?

    Returns:
        list of view names
    """
    names = inspect(engine).get_view_names()
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())

1250 

1251 

def get_column_names(
    engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reads column names afresh from the database, for a specific table (in
    case metadata is out of date).

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        to_lower: convert column names to lower case?
        sort: sort column names (case-insensitively)?

    Returns:
        list of column names, in database order unless ``sort`` is set
    """
    inspector = inspect(engine)
    column_names = [col["name"] for col in inspector.get_columns(tablename)]
    if to_lower:
        column_names = [x.lower() for x in column_names]
    if sort:
        column_names = sorted(column_names, key=lambda x: x.lower())
    return column_names

1277 

1278 

def get_index_names(
    engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reflects the names of a table's indexes from the database.

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        to_lower: lower-case the returned names?
        sort: sort the names (case-insensitively)?

    Returns:
        list of index names (unnamed entries are omitted)
    """
    # http://docs.sqlalchemy.org/en/latest/core/reflection.html
    # Entries with a falsy name are skipped: at least for SQL Server, there
    # always seems to be a blank one with {'name': None, ...}.
    names = [
        idx["name"]
        for idx in inspect(engine).get_indexes(tablename)
        if idx["name"]
    ]
    if to_lower:
        names = [name.lower() for name in names]
    if not sort:
        return names
    return sorted(names, key=lambda name: name.lower())

1306 

1307 

def ensure_columns_present(
    engine: Engine, tablename: str, column_names: Iterable[str]
) -> None:
    """
    Checks, case-insensitively, that a table has every one of the given
    columns.

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        column_names: names of required columns

    Raises:
        :exc:`ValueError` naming the first missing column, if any are missing
    """
    # The table's columns are reflected even when there is nothing to check.
    existing = get_column_names(engine, tablename=tablename, to_lower=True)
    if not column_names:
        return
    first_missing = next(
        (col for col in column_names if col.lower() not in existing), None
    )
    if first_missing is not None:
        raise ValueError(
            f"Column {first_missing!r} missing from table {tablename!r}, "
            f"whose columns are {existing!r}"
        )

1336 

1337 

def create_view(engine: Engine, viewname: str, select_sql: str) -> None:
    """
    Creates (or replaces) a view defined by a SELECT statement.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        viewname: view name
        select_sql: SQL SELECT statement for this view

    MySQL supports ``CREATE OR REPLACE VIEW``. For other dialects, any
    existing view of this name is dropped (quietly) first.
    """
    if engine.dialect.name != "mysql":
        # SQL Server lacks CREATE OR REPLACE VIEW:
        # https://stackoverflow.com/questions/18534919
        drop_view(engine, viewname, quiet=True)
        create_clause = "CREATE VIEW"
    else:
        create_clause = "CREATE OR REPLACE VIEW"
    log.info(f"Creating view: {viewname!r}")
    _exec_ddl(engine, f"{create_clause} {viewname} AS {select_sql}")

1359 

1360 

def assert_view_has_same_num_rows(
    engine: Engine, basetable: str, viewname: str
) -> None:
    """
    Ensures that a view gives the same number of rows as a table. (For use in
    situations where this should hold; views don't have to do this in general!)

    Args:
        engine: SQLAlchemy database Engine
        basetable: name of the table that this view should have a 1:1
            relationship to
        viewname: view name

    Raises:
        :exc:`AssertionError` if they don't have the same number of rows
    """
    # Note that this relies on the data, i.e. design failures MAY cause this
    # check to fail, but won't necessarily (e.g. if the table is empty).
    n_base = count_star(engine, basetable)
    n_view = count_star(engine, viewname)
    if n_view != n_base:
        # Explicit raise rather than an "assert" statement: this is a runtime
        # data check and must not disappear when Python runs with -O.
        raise AssertionError(
            f"View bug: view {viewname} has {n_view} records but its base "
            f"table {basetable} has {n_base}; they should be equal"
        )

1386 

1387 

def drop_view(engine: Engine, viewname: str, quiet: bool = False) -> None:
    """
    Drops a view if it exists (checked case-insensitively).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        viewname: view name
        quiet: suppress the INFO-level log message about dropping
    """
    # Existence is checked by reflection rather than with DROP VIEW IF EXISTS:
    # MySQL has that, but SQL Server only has it from SQL Server 2016 onwards.
    # - https://msdn.microsoft.com/en-us/library/ms173492.aspx
    # - http://dev.mysql.com/doc/refman/5.7/en/drop-view.html
    if viewname.lower() in get_view_names(engine, to_lower=True):
        if not quiet:
            log.info(f"Dropping view: {viewname!r}")
        _exec_ddl(engine, f"DROP VIEW {viewname}")
    else:
        log.debug(f"View {viewname} does not exist; not dropping")

1412 

1413 

def get_column_fk_description(c: Column) -> str:
    """
    Standardized description of a column's foreign keys.

    Args:
        c:
            SQLAlchemy Column

    Returns:
        ``""`` if the column has no foreign keys; otherwise ``"FK to "``
        followed by the comma-separated ``table.column`` targets, ordered by
        table name then column name.
    """
    targets = sorted(
        (fk.column for fk in c.foreign_keys),
        key=lambda col: (col.table.name, col.name),
    )
    if not targets:
        return ""
    return "FK to " + ", ".join(
        f"{col.table.name}.{col.name}" for col in targets
    )

1429 

1430 

@dataclass
class ReflectedColumnInfo:
    """
    Provides information about a column reflected from a database, with
    optional additional information from a CRATE data dictionary, +/- a
    description of values in that column (for researcher reports).
    """

    column: Column
    # If set, used in preference to the column's SQLAlchemy-level comment.
    override_comment: Optional[str] = None
    # Annotation from a CRATE data dictionary, if any.
    crate_annotation: Optional[str] = None
    # Description of the values found in this column, if any.
    values_info: Optional[str] = None

    @property
    def name(self) -> str:
        """
        Alias for :attr:`columnname`.
        """
        return self.columnname

    @property
    def columnname(self) -> str:
        """
        The column's name, exactly as reflected.
        """
        # Do not manipulate the case of SOURCE tables/columns.
        # If you do, they can fail to match the SQLAlchemy
        # introspection and cause a crash.
        return self.column.name

    @property
    def tablename(self) -> str:
        """
        Name of the table that owns the column.
        """
        return self.column.table.name

    @property
    def tablename_columname(self) -> str:
        """
        ``table.column`` string for this column.
        """
        return f"{self.column.table.name}.{self.column.name}"

    @property
    def sqla_coltype(self) -> TypeEngine:
        """
        The column's SQLAlchemy type object.
        """
        return self.column.type

    @property
    def sql_type(self) -> str:
        """
        The column's type as SQL text.

        Raises:
            :exc:`CompileError` if the type can't be rendered; the offending
            column is logged at CRITICAL level first.
        """
        try:
            return str(self.column.type)
        except CompileError:
            log.critical(f"Column that failed was: {self.column!r}")
            raise

    @property
    def datatype_sqltext(self) -> str:
        """
        Synonym for :attr:`sql_type`.
        """
        return self.sql_type

    @property
    def pk(self) -> bool:
        """
        Is the column part of the primary key?
        """
        return self.column.primary_key

    @property
    def nullable(self) -> bool:
        """
        Is the column nullable?
        """
        return self.column.nullable

    @property
    def comment(self) -> str:
        """
        The database comment, if present, or another that has been supplied.
        ``override_comment`` takes priority; ``""`` if neither exists.
        """
        db_comment = getattr(self.column, "comment", "")
        # ... not all dialects support reflecting comments;
        # https://docs.sqlalchemy.org/en/14/core/reflection.html
        return self.override_comment or db_comment or ""

    @property
    def nullable_str(self) -> str:
        """
        Human-oriented nullability, for reports.
        """
        return "✓" if self.nullable else "NOT NULL"

    @property
    def pk_str(self) -> str:
        """
        ``"PK"`` for primary-key columns, else ``""``.
        """
        return "PK" if self.pk else ""

    @property
    def fk_str(self) -> str:
        """
        Description of the column's foreign keys (``""`` if none).
        """
        return get_column_fk_description(self.column)

    def get_column_source_description(self, with_fk: bool = True) -> str:
        """
        Returns a description of where the column is from, used as a suffix for
        data dictionary comment generation.

        Args:
            with_fk:
                Include foreign key descriptions (helpful because CRATE doesn't
                reproduce FK relationships in the destination DDL).
        """
        if with_fk:
            fk_str = self.fk_str
            if fk_str:
                fk_str = "; " + fk_str
        else:
            fk_str = ""
        return f" [from {self.tablename_columname}{fk_str}]"

    @property
    def crate_annotation_str(self) -> str:
        """
        Human-oriented version for report.
        """
        return self.crate_annotation or "?"

    @property
    def values_info_str(self) -> str:
        """
        Human-oriented version for report.
        """
        return self.values_info or "?"

1540 

1541 

1542# ============================================================================= 

1543# ViewMaker 

1544# ============================================================================= 

1545 

1546 

class ViewMaker:
    """
    View-building assistance class.

    Accumulates the SELECT, FROM and WHERE elements of a view drawn from a
    single base table, builds the view's SQL, and keeps notes (lookup tables,
    index requests) that a calling framework may use later.
    """

    def __init__(
        self,
        viewname: str,
        engine: Engine,
        basetable: str,
        existing_to_lower: bool = False,
        rename: Optional[Dict[str, str]] = None,
        userobj: Any = None,
        enforce_same_n_rows_as_base: bool = True,
        insert_basetable_columns: bool = True,
    ) -> None:
        """
        Args:
            viewname: name of the view
            engine: SQLAlchemy database Engine
            basetable: name of the single base table that this view draws from
            existing_to_lower: translate column names to lower case in the
                view?
            rename: optional dictionary mapping ``from_name: to_name`` to
                translate column names in the view; a falsy ``to_name`` omits
                that column entirely
            userobj: optional object (e.g. `argparse.Namespace`,
                dictionary...), not used by this class, and purely to store
                information for others' benefit
            enforce_same_n_rows_as_base: ensure that the view produces the
                same number of rows as its base table?
            insert_basetable_columns: start drafting the view by including all
                columns from the base table?
        """
        rename = rename or {}
        assert basetable, "ViewMaker: basetable missing!"
        self.viewname = viewname
        self.engine = engine
        self.basetable = basetable
        self.userobj = userobj  # only for others' benefit
        self.enforce_same_n_rows_as_base = enforce_same_n_rows_as_base
        self.select_elements = []  # type: List[str]
        self.from_elements = [basetable]  # type: List[str]
        self.where_elements = []  # type: List[str]
        self.lookup_tables = []  # type: List[str]
        self.index_requests = OrderedDict()  # type: Dict[str, List[str]]

        if insert_basetable_columns:
            grammar = make_grammar(engine.dialect.name)

            def q(identifier: str) -> str:
                return grammar.quote_identifier_if_required(identifier)

            for colname in get_column_names(
                engine, tablename=basetable, to_lower=existing_to_lower
            ):
                if colname in rename:
                    rename_to = rename[colname]
                    if not rename_to:
                        # Renamed to nothing: leave this column out.
                        continue
                    as_clause = f" AS {q(rename_to)}"
                else:
                    as_clause = ""
                self.select_elements.append(
                    f"{q(basetable)}.{q(colname)}{as_clause}"
                )
            assert self.select_elements, (
                "Must have some active SELECT " "elements from base table"
            )

    def add_select(self, element: str) -> None:
        """
        Add an element to the SELECT clause of the draft view's SQL
        (meaning: add e.g. a result column).
        """
        self.select_elements.append(element)

    def add_from(self, element: str) -> None:
        """
        Add an element to the FROM clause of the draft view's SQL statement.
        """
        self.from_elements.append(element)

    def add_where(self, element: str) -> None:
        """
        Add an element to the WHERE clause of the draft view's SQL statement.
        Multiple elements are combined with AND.
        """
        self.where_elements.append(element)

    def get_sql(self) -> str:
        """
        Returns the view's SELECT SQL (the body of the CREATE VIEW statement).
        """
        assert self.select_elements, "ViewMaker: no SELECT elements!"
        if self.where_elements:
            where = "\n WHERE {}".format(
                "\n AND ".join(self.where_elements)
            )
        else:
            where = ""
        return (
            "\n SELECT {select_elements}"
            "\n FROM {from_elements}{where}".format(
                select_elements=",\n ".join(self.select_elements),
                from_elements="\n ".join(self.from_elements),
                where=where,
            )
        )

    def create_view(self, engine: Engine) -> None:
        """
        Creates the view.

        Whether we act or just print is conditional on previous calls to
        :func:`set_print_not_execute`.

        If ``enforce_same_n_rows_as_base`` is set, check the number of rows
        returned matches the base table. That check is skipped in "print, don't
        execute" mode. In that mode the CREATE VIEW statement is only printed,
        so the view may not exist, or may be a stale earlier version, and
        counting its rows would fail or be meaningless.

        Args:
            engine: SQLAlchemy database Engine
        """
        create_view(engine, self.viewname, self.get_sql())
        if (
            self.enforce_same_n_rows_as_base
            and not _global_print_not_execute_sql
        ):
            assert_view_has_same_num_rows(
                engine, self.basetable, self.viewname
            )

    def drop_view(self, engine: Engine) -> None:
        """
        Drops the view.

        Whether we act or just print is conditional on previous calls to
        :func:`set_print_not_execute`.

        Args:
            engine: SQLAlchemy database Engine
        """
        drop_view(engine, self.viewname)

    def record_lookup_table(self, table: str) -> None:
        """
        Keep a record of a lookup table. The framework may wish to suppress
        these from a data dictionary later (e.g. create a view, suppress the
        messier raw data). See :meth:`get_lookup_tables`.

        Args:
            table: table name
        """
        if table not in self.lookup_tables:
            self.lookup_tables.append(table)

    def get_lookup_tables(self) -> List[str]:
        """
        Returns all lookup tables that we have recorded. See
        :meth:`record_lookup_table`.
        """
        return self.lookup_tables

    def request_index(self, table: str, column: str) -> None:
        """
        Note a request that a specific column be indexed. The framework can use
        the ViewMaker to keep a note of these requests, and then add index
        hints to a data dictionary if it wishes. See
        :meth:`get_index_request_dict`.

        Args:
            table: table name
            column: column name
        """
        if table not in self.index_requests:
            self.index_requests[table] = []  # type: List[str]
        if column not in self.index_requests[table]:
            self.index_requests[table].append(column)

    def get_index_request_dict(self) -> Dict[str, List[str]]:
        """
        Returns all our recorded index requests, as a dictionary mapping each
        table name to a list of column names to be indexed. See
        :meth:`request_index`.
        """
        return self.index_requests

    def record_lookup_table_keyfield(
        self, table: str, keyfield: Union[str, Iterable[str]]
    ) -> None:
        """
        Makes a note that a table is a lookup table, and its key field(s)
        should be indexed. See :meth:`get_lookup_tables`,
        :meth:`get_index_request_dict`.

        Args:
            table: table name
            keyfield: field name, or iterable (e.g. list) of them
        """
        if isinstance(keyfield, str):
            keyfield = [keyfield]
        self.record_lookup_table(table)
        for kf in keyfield:
            self.request_index(table, kf)

    def record_lookup_table_keyfields(
        self,
        table_keyfield_tuples: Iterable[Tuple[str, Union[str, Iterable[str]]]],
    ) -> None:
        """
        Make a note of a whole set of lookup table / key field groups. See
        :meth:`record_lookup_table_keyfield`.

        Args:
            table_keyfield_tuples:
                iterable (e.g. list) of tuples of the format ``tablename,
                keyfield``. Each will be passed to
                :meth:`record_lookup_table_keyfield`.
        """
        for t, k in table_keyfield_tuples:
            self.record_lookup_table_keyfield(t, k)

1764 

1765 

1766# ============================================================================= 

1767# TransactionSizeLimiter 

1768# ============================================================================= 

1769 

1770 

class TransactionSizeLimiter:
    """
    Class to allow us to limit the size of database transactions.

    Callers report each batch of inserted data via :meth:`notify`. Once the
    running row or byte total for the current transaction reaches a
    configured limit, the session is committed and the totals restart at zero.
    """

    def __init__(
        self,
        session: Session,
        max_rows_before_commit: int = None,
        max_bytes_before_commit: int = None,
    ) -> None:
        """
        Args:
            session: SQLAlchemy database Session
            max_rows_before_commit:
                COMMIT once this many rows have been inserted in the current
                transaction. ``None`` for no row limit.
            max_bytes_before_commit:
                COMMIT once this many bytes have been inserted in the current
                transaction. ``None`` for no byte limit.
        """
        self._session = session
        self._max_rows_before_commit = max_rows_before_commit
        self._max_bytes_before_commit = max_bytes_before_commit
        # Running totals for the transaction in progress.
        self._bytes_in_transaction = 0
        self._rows_in_transaction = 0

    def commit(self) -> None:
        """
        COMMITs the session (timed under ``TIMING_COMMIT``) and zeroes the
        running totals.
        """
        with MultiTimerContext(timer, TIMING_COMMIT):
            self._session.commit()
        self._rows_in_transaction = 0
        self._bytes_in_transaction = 0

    def notify(
        self, n_rows: int, n_bytes: int, force_commit: bool = False
    ) -> None:
        """
        Tells the limiter about data just inserted into the database. If this
        takes the running byte or row total to (or past) its limit, a COMMIT is
        issued.

        Args:
            n_rows: number of rows inserted
            n_bytes: number of bytes inserted
            force_commit: COMMIT now, regardless of the totals?
        """
        if force_commit:
            # The commit resets the totals, so there is no need to add to them.
            self.commit()
            return

        self._bytes_in_transaction += n_bytes
        self._rows_in_transaction += n_rows
        byte_limit = self._max_bytes_before_commit
        row_limit = self._max_rows_before_commit

        # The byte limit is checked first; at most one reason is logged.
        if (
            byte_limit is not None
            and self._bytes_in_transaction >= byte_limit
        ):
            log.debug(
                f"Triggering early commit based on byte count "
                f"(reached {sizeof_fmt(self._bytes_in_transaction)}, "
                f"limit is {sizeof_fmt(byte_limit)})"
            )
        elif (
            row_limit is not None
            and self._rows_in_transaction >= row_limit
        ):
            log.debug(
                f"Triggering early commit based on row count "
                f"(reached {self._rows_in_transaction} rows, "
                f"limit is {row_limit})"
            )
        else:
            return
        self.commit()

1845 

1846 

1847# ============================================================================= 

1848# Specification matching 

1849# ============================================================================= 

1850 

1851 

def _matches_tabledef(table: str, tabledef: str) -> bool:
    """
    Tests a table name against a single wildcard-based table definition.

    Args:
        table: tablename
        tabledef: ``fnmatch``-style pattern (e.g.
            ``"patient_address_table_*"``)

    Returns:
        whether the pattern's compiled regex matches ``table``
    """
    return bool(get_spec_match_regex(tabledef).match(table))

1863 

1864 

def matches_tabledef(table: str, tabledef: Union[str, List[str]]) -> bool:
    """
    Tests a table name against one or more wildcard-based table definitions.

    Args:
        table: table name
        tabledef: ``fnmatch``-style pattern (e.g.
            ``"patient_address_table_*"``), or list of them

    Returns:
        ``True`` if any pattern matches; ``False`` for an empty (or ``None``)
        list of patterns
    """
    if isinstance(tabledef, str):
        return _matches_tabledef(table, tabledef)
    if not tabledef:
        return False
    # A non-empty collection of patterns: stop at the first match.
    for pattern in tabledef:
        if _matches_tabledef(table, pattern):
            return True
    return False

1880 

1881 

def _matches_fielddef(table: str, field: str, fielddef: str) -> bool:
    """
    Tests a table/field pair against a single wildcard-based field definition.

    Args:
        table: tablename
        field: fieldname
        fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"``,
            ``"*.nhs_number"``)

    Returns:
        If the definition names no table, whether the field matches.
        Otherwise, whether both the table and the field match.
    """
    parts = split_db_schema_table_column(fielddef)
    column_regex = get_spec_match_regex(parts.column)
    if parts.table:
        table_regex = get_spec_match_regex(parts.table)
        return bool(table_regex.match(table)) and bool(
            column_regex.match(field)
        )
    return bool(column_regex.match(field))

1902 

1903 

def matches_fielddef(
    table: str, field: str, fielddef: Union[str, List[str]]
) -> bool:
    """
    Tests a table/field pair against one or more wildcard-based field
    definitions.

    Args:
        table: table name
        field: fieldname
        fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"`` or
            ``"*.nhs_number"``), or list of them

    Returns:
        ``True`` if any definition matches; ``False`` for an empty (or
        ``None``) list of definitions
    """
    if isinstance(fielddef, str):
        return _matches_fielddef(table, field, fielddef)
    if not fielddef:
        return False
    # A non-empty collection of definitions: stop at the first match.
    for definition in fielddef:
        if _matches_fielddef(table, field, definition):
            return True
    return False

1922 

1923 

1924# ============================================================================= 

1925# More SQL 

1926# ============================================================================= 

1927 

1928 

def sql_fragment_cast_to_int(
    expr: str,
    big: bool = True,
    dialect: Dialect = None,
    viewmaker: ViewMaker = None,
) -> str:
    """
    Takes an SQL expression and coerces it to an integer. For Microsoft SQL
    Server.

    Args:
        expr: starting SQL expression
        big: use BIGINT, not INTEGER?
        dialect: optional :class:`sqlalchemy.engine.interfaces.Dialect`. If
            ``None`` and we have a ``viewmaker``, use the viewmaker's dialect.
            Otherwise, assume SQL Server.
        viewmaker: optional :class:`ViewMaker`

    Returns:
        modified SQL expression

    Raises:
        :exc:`ValueError` if the dialect is known and is not SQL Server

    *Notes*

    Conversion to INT:

    - https://stackoverflow.com/questions/2000045
    - https://stackoverflow.com/questions/14719760 (this one in particular!)
    - https://stackoverflow.com/questions/14692131

      - see LIKE example.
      - see ISNUMERIC();
        https://msdn.microsoft.com/en-us/library/ms186272.aspx;
        but that includes non-integer numerics

    - https://msdn.microsoft.com/en-us/library/ms174214(v=sql.120).aspx;
      relates to the SQL Server Management Studio "Find and Replace"
      dialogue box, not to SQL itself!

    - https://stackoverflow.com/questions/29206404/mssql-regular-expression

    Note that the regex-like expression supported by LIKE is extremely limited.

    - https://msdn.microsoft.com/en-us/library/ms179859.aspx

    - The only things supported are:

      .. code-block:: none

        %   any characters
        _   any single character
        []  single character in range or set, e.g. [a-f], [abcdef]
        [^] single character NOT in range or set, e.g. [^a-f], [^abcdef]

    SQL Server does not support a REGEXP command directly.

    So the best bet is to have the LIKE clause check for a non-integer:

    .. code-block:: sql

        CASE
            WHEN something LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    ... which doesn't deal with spaces properly, but there you go.
    Could also strip whitespace left/right:

    .. code-block:: sql

        CASE
            WHEN LTRIM(RTRIM(something)) LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    That only works for positive integers.

    LTRIM/RTRIM are not ANSI SQL.
    Nor are unusual LIKE clauses; see
    https://stackoverflow.com/questions/712580/list-of-special-characters-for-sql-like-clause

    The other, for SQL Server 2012 or higher, is TRY_CAST:

    .. code-block:: sql

        TRY_CAST(something AS BIGINT)

    ... which returns NULL upon failure; see
    https://msdn.microsoft.com/en-us/library/hh974669.aspx

    Therefore, our **method** is as follows:

    - If the database supports TRY_CAST, use that.
    - Otherwise if we're using SQL Server, use a CASE/CAST construct. This
      includes SQL Server dialects whose server version is not (yet) known.
    - Otherwise, raise :exc:`ValueError` as we don't know what to do.
    """
    inttype = "BIGINT" if big else "INTEGER"
    if dialect is None and viewmaker is not None:
        dialect = viewmaker.engine.dialect
    if dialect is None:
        sql_server = True
        supports_try_cast = False
    else:
        # noinspection PyUnresolvedReferences
        sql_server = dialect.name == "mssql"
        # server_version_info is typically only populated once the dialect has
        # connected to a server, so it may be None (or empty) here. Comparing
        # None with a tuple raises TypeError, so treat an unknown version as
        # "TRY_CAST not available" and fall back to the CASE/CAST construct.
        server_version = getattr(dialect, "server_version_info", None)
        supports_try_cast = (
            sql_server
            and bool(server_version)
            and server_version >= MS_2012_VERSION
        )
    if supports_try_cast:
        return f"TRY_CAST({expr} AS {inttype})"
    elif sql_server:
        # Doesn't support negative integers.
        return (
            f"CASE WHEN LTRIM(RTRIM({expr})) LIKE '%[^0-9]%' "
            f"THEN NULL ELSE CAST({expr} AS {inttype}) END"
        )
    else:
        # noinspection PyUnresolvedReferences
        raise ValueError(
            f"Code not yet written for convert-to-int for "
            f"dialect {dialect.name}"
        )

2052 

2053 

2054# ============================================================================= 

2055# Abstracted SQL WHERE condition 

2056# ============================================================================= 

2057 

2058 

@register_for_json(method=METHOD_PROVIDES_INIT_KWARGS)
@functools.total_ordering
class WhereCondition:
    """
    Ancillary class for building SQL WHERE expressions from our web forms.

    The essence of it is ``WHERE column op value_or_values``.

    Equality and ordering consider the raw SQL, column, operation and
    value(s) only. Comparing with an object that is not a
    :class:`WhereCondition` returns ``NotImplemented``, so Python falls back
    to its default behaviour (e.g. ``==`` gives ``False``) rather than
    raising :exc:`AttributeError`.
    """

    def __init__(
        self,
        column_id: ColumnId = None,
        op: str = "",
        datatype: str = "",
        value_or_values: Any = None,
        raw_sql: str = "",
        from_table_for_raw_sql: TableId = None,
    ) -> None:
        """
        Args:
            column_id:
                :class:`ColumnId` for the column
            op:
                operation (e.g. ``=``, ``<``, ``<=``, etc.); stored
                upper-cased
            datatype:
                data type string that must match values in our
                ``querybuilder.js``; see source code. We use this to know how
                to build SQL literal values. (Not terribly elegant, but it
                works; SQL injection isn't a particular concern because we
                let our users run any SQL they want and ensure the connection
                is made read-only.)
            value_or_values:
                ``None``, single value, or list of values. Which is appropriate
                depends on the operation. For example, ``IS NULL`` takes no
                value; ``=`` takes one; ``IN`` takes many.
            raw_sql:
                override any thinking we might wish to do, and just return this
                raw SQL
            from_table_for_raw_sql:
                if we are using raw SQL, provide a :class:`TableId` for the
                relevant table here
        """
        self._column_id = column_id
        self._op = op.upper()
        self._datatype = datatype
        self._value = value_or_values
        self._no_value = False
        self._multivalue = False
        self._raw_sql = raw_sql
        self._from_table_for_raw_sql = from_table_for_raw_sql

        if not self._raw_sql:
            # Check that the value(s) match what the operation expects.
            # NB these are assertions, so they are skipped under "python -O".
            if self._op in SQL_OPS_VALUE_UNNECESSARY:
                self._no_value = True
                assert value_or_values is None, "Superfluous value passed"
            elif self._op in SQL_OPS_MULTIPLE_VALUES:
                self._multivalue = True
                assert isinstance(value_or_values, list), "Need list"
            else:
                assert not isinstance(
                    value_or_values, list
                ), "Need single value"

    def init_kwargs(self) -> Dict:
        """
        Returns the keyword arguments that recreate this object via
        ``__init__``. This class is registered for JSON serialization with
        ``METHOD_PROVIDES_INIT_KWARGS``, which uses this method.
        """
        return {
            "column_id": self._column_id,
            "op": self._op,
            "datatype": self._datatype,
            "value_or_values": self._value,
            "raw_sql": self._raw_sql,
            "from_table_for_raw_sql": self._from_table_for_raw_sql,
        }

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__qualname__}("
            f"column_id={self._column_id!r}, "
            f"op={self._op!r}, "
            f"datatype={self._datatype!r}, "
            f"value_or_values={self._value!r}, "
            f"raw_sql={self._raw_sql!r}, "
            f"from_table_for_raw_sql={self._from_table_for_raw_sql!r}"
            ")"
        )

    def _comparison_key(self) -> Tuple[Any, Any, Any, Any]:
        """
        Returns the tuple used for equality and ordering: raw SQL, column ID,
        operation, and value(s). The data type and the raw-SQL table are
        deliberately not part of it.
        """
        return self._raw_sql, self._column_id, self._op, self._value

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, WhereCondition):
            # Let Python try the reflected operation / identity fallback.
            return NotImplemented
        return self._comparison_key() == other._comparison_key()

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, WhereCondition):
            return NotImplemented
        return self._comparison_key() < other._comparison_key()

    @property
    def column_id(self) -> ColumnId:
        """
        Returns the :class:`ColumnId` provided at creation.
        """
        return self._column_id

    @property
    def table_id(self) -> TableId:
        """
        Returns a :class:`TableId`:

        - for raw SQL, our ``from_table_for_raw_sql`` attribute
        - otherwise, the table ID extracted from our ``column_id`` attribute
        """
        if self._raw_sql:
            return self._from_table_for_raw_sql
        return self.column_id.table_id

    def table_str(self, grammar: SqlGrammar) -> str:
        """
        Returns the table identifier in the specified SQL grammar.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return self.table_id.identifier(grammar)

    def sql(self, grammar: SqlGrammar) -> str:
        """
        Returns the WHERE clause (without ``WHERE`` itself!) for our condition,
        in the specified SQL grammar. Some examples might be:

        - ``somecol = 3``
        - ``othercol IN (6, 7, 8)``
        - ``thirdcol IS NOT NULL``
        - ``textcol LIKE '%paracetamol%'``
        - ``MATCH (fulltextcol) AGAINST ('paracetamol')`` (MySQL)
        - ``CONTAINS(fulltextcol, 'paracetamol')`` (SQL Server)

        If raw SQL was supplied at creation, it is returned unchanged.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        if self._raw_sql:
            return self._raw_sql

        col = self._column_id.identifier(grammar)
        op = self._op

        if self._no_value:
            return f"{col} {op}"

        # Choose how each value is rendered as an SQL literal.
        if self._datatype in QB_STRING_TYPES:
            element_converter = sql_string_literal
        elif self._datatype == QB_DATATYPE_DATE:
            element_converter = sql_date_literal
        elif self._datatype == QB_DATATYPE_INTEGER:
            element_converter = str
        elif self._datatype == QB_DATATYPE_FLOAT:
            element_converter = str
        else:
            # Safe default
            element_converter = sql_string_literal

        if self._multivalue:
            literal = "({})".format(
                ", ".join(element_converter(v) for v in self._value)
            )
        else:
            literal = element_converter(self._value)

        if self._op == "MATCH":  # MySQL
            return f"MATCH ({col}) AGAINST ({literal})"
        elif self._op == "CONTAINS":  # SQL Server
            return f"CONTAINS({col}, {literal})"
        else:
            return f"{col} {op} {literal}"

2245 

2246 

2247# ============================================================================= 

2248# SQL formatting 

2249# ============================================================================= 

2250 

2251 

def format_sql_for_print(sql: str) -> str:
    """
    Very simple SQL formatting, for display.

    - Each tab becomes a single space.
    - Trailing whitespace is stripped from every line.
    - Lines that are then empty are dropped.
    - The indentation common to all remaining lines is removed, so the
      least-indented line starts at column 0.

    Args:
        sql: SQL text, possibly spanning several lines

    Returns:
        the tidied SQL, with lines joined by newlines (``""`` if nothing is
        left)
    """
    kept = []
    for raw_line in sql.splitlines():
        tidied = raw_line.replace("\t", " ").rstrip()
        if tidied:
            kept.append(tidied)
    if not kept:
        return ""
    # Amount of leading whitespace shared by every remaining line.
    common_indent = min(len(text) - len(text.lstrip()) for text in kept)
    return "\n".join(text[common_indent:] for text in kept)

2272 

2273 

2274# ============================================================================= 

2275# Plain SQL types 

2276# ============================================================================= 

2277 

2278 

def is_sql_column_type_textual(column_type: str, min_length: int = 1) -> bool:
    """
    Does an SQL column type look textual?

    Only the first whitespace-separated word of ``column_type`` is examined,
    case-insensitively. For example, ``"VARCHAR(50) NOT NULL"`` is judged on
    ``"VARCHAR(50)"``.

    Args:
        column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
        min_length: what's the minimum string length we'll say "yes" to?

    Returns:
        is it a textual column (of the minimum length or more)?

        - Text types without an explicit length always count.
        - Text types with a negative length (e.g. ``-1``) or ``MAX`` count as
          unlimited length.
        - Empty, whitespace-only, or unparseable types give ``False``.

    Note:

    - For SQL Server's NVARCHAR(MAX),
      :meth:`crate_anon.crateweb.research.research_db_info._schema_query_microsoft`
      returns "NVARCHAR(-1)"
    """
    if not column_type:
        return False
    words = column_type.upper().split()
    if not words:
        # Whitespace-only input: nothing to inspect.
        return False
    column_type = words[0]
    if column_type in SQLTYPES_TEXT:
        # A text type without a specific length
        return True
    m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
    if m is None:
        return False
    basetype = m.group(1)
    if basetype not in SQLTYPES_TEXT:
        return False
    length_str = m.group(2)
    if length_str == "MAX":
        # e.g. VARCHAR(MAX); unlimited length (consistent with
        # coltype_length_if_text, which also recognizes "MAX").
        return True
    try:
        length = int(length_str)
    except (TypeError, ValueError):
        return False
    # Negative lengths (e.g. -1) are how some schema queries report MAX.
    return length >= min_length or length < 0

2309 

2310 

def coltype_length_if_text(column_type: str, dialect: str) -> Optional[int]:
    """
    Find the length of an SQL text column type.

    Args:
        column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
            (case-insensitive)
        dialect: the SQL dialect the column type is from (looked up in
            ``DIALECT_TO_STRING_LEN_LOOKUP`` for length-less text types)

    Returns:
        length of the column, or ``None`` if it's not a text column. Also
        ``None`` for a ``MAX`` / ``-1`` length unless the dialect is SQL
        Server and the type is ``VARCHAR`` or ``NVARCHAR``.

    Raises:
        ValueError: if ``column_type`` is a text type without a length and
            either the dialect or that type's default length in the dialect
            is unknown.
    """
    column_type = column_type.upper()
    if column_type in SQLTYPES_TEXT:
        # No length specified - get the dialect's default for this type.
        try:
            lookup = DIALECT_TO_STRING_LEN_LOOKUP[dialect]
        except KeyError:
            possible = list(DIALECT_TO_STRING_LEN_LOOKUP.keys())
            raise ValueError(
                f"CRATE doesn't properly understand SQL dialect {dialect!r}. "
                f"Supported: {possible}"
            ) from None
        try:
            return lookup[column_type]
        except KeyError:
            raise ValueError(
                f"For SQL dialect {dialect!r}, CRATE doesn't know the length "
                f"for string data type {column_type!r}"
            ) from None

    # Length specified - get it from the column type.
    m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
    if m is None:
        # Not the correct type of column
        return None
    basetype = m.group(1)
    length = m.group(2)
    if basetype not in SQLTYPES_TEXT:
        # e.g. "INT(11)" or "DECIMAL(10)": a parenthesized number, but not a
        # text type, so there is no text length to report.
        return None
    if length in ("MAX", "-1"):
        if dialect == SqlaDialectName.MSSQL:
            if basetype == "VARCHAR":
                return MSSQL_COLTYPE_TO_LEN["VARCHAR_MAX"]
            if basetype == "NVARCHAR":
                return MSSQL_COLTYPE_TO_LEN["NVARCHAR_MAX"]
        return None
    try:
        return int(length)
    except (TypeError, ValueError):
        # Not the correct type of column
        return None

2362 

2363 

def escape_quote_in_literal(s: str) -> str:
    r"""
    Escapes every single quote (``'``) in ``s`` as ``\'``.

    Doubling the quote (``''``) would be the alternative. The backslash form
    is used for consistency with :func:`escape_percent_in_literal`, which
    also escapes with a backslash prefix. Backslashes already present in
    ``s`` are left as they are.
    """
    return "\\'".join(s.split("'"))

2370 

2371 

def escape_percent_in_literal(sql: str) -> str:
    r"""
    Escapes every ``%`` in ``sql`` as ``\%``, for use in LIKE clauses,
    where an unescaped ``%`` is a wildcard.

    - https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
    """
    return "\\%".join(sql.split("%"))

2380 

2381 

def escape_percent_for_python_dbapi(sql: str) -> str:
    """
    Doubles every ``%`` in ``sql`` (``%`` becomes ``%%``).

    Use this for SQL run from Python through a DBAPI driver whose argument
    placeholders use ``%`` (e.g. ``%s``), so that literal percent signs are
    not mistaken for placeholders.
    """
    return "%%".join(sql.split("%"))

2389 

2390 

def escape_sql_string_literal(s: str) -> str:
    """
    Escapes a fragment of text for use inside an SQL string literal.

    Single quotes are escaped first (:func:`escape_quote_in_literal`), then
    percent signs (:func:`escape_percent_in_literal`). The surrounding
    quotes are not added; see :func:`make_string_literal`.
    """
    quotes_escaped = escape_quote_in_literal(s)
    return escape_percent_in_literal(quotes_escaped)

2397 

2398 

def make_string_literal(s: str) -> str:
    """
    Turns a Python string into a complete SQL string literal: the text is
    escaped with :func:`escape_sql_string_literal` and wrapped in single
    quotes.
    """
    quote = "'"
    return quote + escape_sql_string_literal(s) + quote

2405 

2406 

def escape_sql_string_or_int_literal(s: Union[str, int]) -> str:
    """
    Renders an integer or a string as an SQL literal.

    Integers are written with :func:`str`. Anything else is passed to
    :func:`make_string_literal`, which quotes and escapes it. Note that
    ``bool`` is a subclass of ``int``, so ``True`` becomes ``"True"``.
    """
    return str(s) if isinstance(s, int) else make_string_literal(s)

2416 

2417 

def translate_sql_qmark_to_percent(sql: str) -> str:
    """
    Translates SQL that uses ``?`` placeholders into SQL that uses ``%s``
    placeholders, without breaking literal ``'?'`` or ``'%'`` (e.g. inside
    string literals).

    Method:

    1. Every ``%`` is doubled to ``%%`` (via
       :func:`escape_percent_for_python_dbapi`), so that literal percent
       signs survive DBAPI parameter substitution.
    2. Every ``?`` outside single quotes becomes ``%s``. The in-quotes state
       flips on each ``'`` character, whatever precedes it. A doubled quote
       (``''``) therefore leaves the state unchanged, but a backslash-escaped
       quote (``\\'``) flips it like any other quote.

    *Notes*

    - MySQL likes ``?`` as a placeholder.

      - https://dev.mysql.com/doc/refman/5.7/en/sql-syntax-prepared-statements.html

    - Python DBAPI allows several: ``%s``, ``?``, ``:1``, ``:name``,
      ``%(name)s``.

      - https://www.python.org/dev/peps/pep-0249/#paramstyle

    - Django uses ``%s``.

      - https://docs.djangoproject.com/en/1.8/topics/db/sql/

    - Microsoft likes ``?``, ``@paramname``, and ``:paramname``.

      - https://msdn.microsoft.com/en-us/library/yy6y35y8(v=vs.110).aspx

    - We need to parse SQL with argument placeholders.

      - See :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` classes,
        particularly: ``bind_parameter``

    ``?`` is preferred because ``%`` is used in LIKE clauses, and the
    databases we're using like it. So:

    - We use ``%s`` when using ``cursor.execute()`` directly, via Django.
    - We use ``?`` when talking to users, and
      :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` objects, so that
      the visual appearance matches what they expect from their database.
    """  # noqa: E501
    escaped = escape_percent_for_python_dbapi(sql)
    pieces = []
    inside_literal = False
    for ch in escaped:
        if ch == "'":
            inside_literal = not inside_literal
            pieces.append(ch)
        elif ch == "?" and not inside_literal:
            pieces.append("%s")
        else:
            pieces.append(ch)
    return "".join(pieces)

2472 

2473 

def decorate_index_name(
    idxname: str, tablename: str = None, engine: Engine = None
) -> str:
    """
    Amends a database index name where the database requires it.

    SQLite (which we won't use much, but do use for testing!) won't accept
    two indexes with the same name applying to different tables. So, for
    SQLite only, the table name is appended (``<idxname>_<tablename>``).

    Args:
        idxname:
            The original index name.
        tablename:
            The name of the table. If falsy, the name is returned unchanged.
        engine:
            The SQLAlchemy engine, from which we obtain the dialect. If
            falsy, the name is returned unchanged.

    Returns:
        The index name, amended if necessary.
    """
    if tablename and engine and engine.dialect.name == "sqlite":
        return f"{idxname}_{tablename}"
    return idxname