Coverage for crateweb/research/tests/sql_writer_tests.py: 100%
43 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-27 10:34 -0500
1"""
2crate_anon/crateweb/research/tests/sql_writer_tests.py
4===============================================================================
6 Copyright (C) 2015, University of Cambridge, Department of Psychiatry.
7 Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
9 This file is part of CRATE.
11 CRATE is free software: you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
16 CRATE is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
21 You should have received a copy of the GNU General Public License
22 along with CRATE. If not, see <https://www.gnu.org/licenses/>.
24===============================================================================
26Test sql_writer.py.
28"""
30import logging
32from cardinal_pythonlib.sql.sql_grammar_factory import make_grammar
33from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
34from django.test import TestCase
36from crate_anon.common.sql import ColumnId, WhereCondition
37from crate_anon.crateweb.core.constants import (
38 DJANGO_DEFAULT_CONNECTION,
39 RESEARCH_DB_CONNECTION_NAME,
40)
41from crate_anon.crateweb.research.errors import DatabaseStructureNotUnderstood
42from crate_anon.crateweb.research.sql_writer import (
43 add_to_select,
44 SelectElement,
45)
47log = logging.getLogger(__name__)
class AddToSelectTests(TestCase):
    """
    Tests for ``add_to_select()`` from
    ``crate_anon.crateweb.research.sql_writer``.
    """

    # Django only permits queries against the databases listed here.
    databases = {DJANGO_DEFAULT_CONNECTION, RESEARCH_DB_CONNECTION_NAME}

    def setUp(self) -> None:
        super().setUp()
        # MySQL-dialect grammar, passed to every add_to_select() call below.
        self.grammar = make_grammar(SqlaDialectName.MYSQL)

    @staticmethod
    def _normalize_sql(sql: str) -> str:
        """
        Normalize an SQL string for comparison.

        Every run of whitespace (spaces, tabs, newlines) is collapsed to a
        single space, leading/trailing whitespace is removed, and any space
        immediately before a comma is dropped.

        Args:
            sql: the SQL text to normalize.

        Returns:
            The normalized SQL text.
        """
        # Collapse whitespace FIRST, so that any whitespace run preceding a
        # comma (e.g. "a  ," or "a\n,") becomes exactly " ," and is then
        # removed. Doing the replacement first would leave "a ," behind.
        collapsed = " ".join(sql.split())
        return collapsed.replace(" ,", ",")

    def assert_query_equal(self, actual: str, expected: str) -> None:
        """
        Assert that two SQL strings match, ignoring whitespace differences
        (including whitespace before commas).

        Both sides are normalized, so ``expected`` may also be written
        across several lines or with irregular spacing.

        Args:
            actual: SQL produced by the code under test.
            expected: the SQL we expect.
        """
        self.assertEqual(
            self._normalize_sql(actual), self._normalize_sql(expected)
        )

    def test_second_table_joined(self) -> None:
        """
        Selecting a column from a second table adds it to the SELECT list
        and joins the table (NATURAL JOIN when magic_join is off).
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t2", column="c"))
            ],
            # magic_join requires DB knowledge hence Django
            magic_join=False,
        )
        self.assert_query_equal(
            sql,
            (
                "SELECT t1.a, t1.b, t2.c FROM t1 NATURAL JOIN t2 "
                "WHERE t1.col1 > 5"
            ),
        )

    def test_another_column_added(self) -> None:
        """
        A column from the existing table is appended to the SELECT list,
        even if it is already selected (no de-duplication is expected).
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t1", column="a"))
            ],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b, t1.a FROM t1 WHERE t1.col1 > 5"
        )

    def test_select_element_added_to_nothing(self) -> None:
        """
        Adding a column to an empty query builds a fresh SELECT ... FROM.
        """
        sql = add_to_select(
            "",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t2", column="c"))
            ],
        )
        self.assert_query_equal(sql, "SELECT t2.c FROM t2")

    def test_first_where_condition_added(self) -> None:
        """
        A WHERE condition added to a query without one creates the WHERE
        clause.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col1 > 5")],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5"
        )

    def test_second_where_condition_added(self) -> None:
        """
        A WHERE condition added to a query with one condition is ANDed on.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")],
        )

        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t1.col2 < 3"
        )

    def test_third_where_condition_added(self) -> None:
        """
        A WHERE condition added to an existing AND-chain is appended to the
        end of that chain.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t3.col99 = 100",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")],
        )
        self.assert_query_equal(
            sql,
            (
                "SELECT t1.a, t1.b FROM t1 "
                "WHERE t1.col1 > 5 AND t3.col99 = 100 AND t1.col2 < 3"
            ),
        )

    def test_multiple_wheres_added_to_none(self) -> None:
        """
        Several WHERE conditions added to a query without one are ANDed
        together, in the order given.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1",
            grammar=self.grammar,
            where_conditions=[
                WhereCondition(raw_sql="t1.col1 > 99"),
                WhereCondition(raw_sql="t1.col2 < 999"),
            ],
        )
        self.assert_query_equal(
            sql,
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 99 AND t1.col2 < 999",
        )

    def test_raises_when_table_does_not_exist(self) -> None:
        """
        Adding a column from a table that add_to_select() cannot relate to
        the query raises DatabaseStructureNotUnderstood.
        """
        # NOTE(review): magic_join is left at its default here (unlike
        # test_second_table_joined); presumably the join logic consults the
        # research DB structure and fails to find research.blobdoc —
        # confirm against sql_writer.add_to_select().
        column_id = ColumnId(
            schema="research", table="blobdoc", column="_src_hash"
        )
        with self.assertRaises(DatabaseStructureNotUnderstood):
            add_to_select(
                "SELECT foo from bar",
                grammar=self.grammar,
                select_elements=[SelectElement(column_id=column_id)],
            )