Coverage for crateweb/research/tests/sql_writer_tests.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-27 10:34 -0500

1""" 

2crate_anon/crateweb/research/tests/sql_writer_tests.py 

3 

4=============================================================================== 

5 

6 Copyright (C) 2015, University of Cambridge, Department of Psychiatry. 

7 Created by Rudolf Cardinal (rnc1001@cam.ac.uk). 

8 

9 This file is part of CRATE. 

10 

11 CRATE is free software: you can redistribute it and/or modify 

12 it under the terms of the GNU General Public License as published by 

13 the Free Software Foundation, either version 3 of the License, or 

14 (at your option) any later version. 

15 

16 CRATE is distributed in the hope that it will be useful, 

17 but WITHOUT ANY WARRANTY; without even the implied warranty of 

18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19 GNU General Public License for more details. 

20 

21 You should have received a copy of the GNU General Public License 

22 along with CRATE. If not, see <https://www.gnu.org/licenses/>. 

23 

24=============================================================================== 

25 

26Test sql_writer.py. 

27 

28""" 

29 

30import logging 

31 

32from cardinal_pythonlib.sql.sql_grammar_factory import make_grammar 

33from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName 

34from django.test import TestCase 

35 

36from crate_anon.common.sql import ColumnId, WhereCondition 

37from crate_anon.crateweb.core.constants import ( 

38 DJANGO_DEFAULT_CONNECTION, 

39 RESEARCH_DB_CONNECTION_NAME, 

40) 

41from crate_anon.crateweb.research.errors import DatabaseStructureNotUnderstood 

42from crate_anon.crateweb.research.sql_writer import ( 

43 add_to_select, 

44 SelectElement, 

45) 

46 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

48 

49 

class AddToSelectTests(TestCase):
    """
    Tests for :func:`add_to_select`, parsing/writing SQL with a MySQL grammar.
    """

    # Databases this test case is permitted to use (Django's
    # ``TestCase.databases`` attribute).
    databases = {DJANGO_DEFAULT_CONNECTION, RESEARCH_DB_CONNECTION_NAME}

    def setUp(self) -> None:
        super().setUp()
        # Fresh MySQL-dialect grammar per test, passed to add_to_select().
        self.grammar = make_grammar(SqlaDialectName.MYSQL)

    @staticmethod
    def _normalize_query(sql: str) -> str:
        """
        Return ``sql`` with every whitespace run collapsed to a single space,
        leading/trailing whitespace removed, and any space before a comma
        dropped.
        """
        # Collapse whitespace FIRST: a single str.replace(" ,", ",") pass on
        # the raw string would leave a space behind when several spaces (or a
        # newline) precede a comma, e.g. "a  ,b" -> "a ,b".
        collapsed = " ".join(sql.split())
        return collapsed.replace(" ,", ",")

    def assert_query_equal(self, actual: str, expected: str) -> None:
        """
        Assert that two SQL strings match, ignoring whitespace differences
        (including whitespace before commas).

        Both strings are normalized, so the comparison is symmetric and the
        expected value may be written with any convenient spacing.

        Args:
            actual: SQL produced by the code under test.
            expected: SQL it should be equivalent to.
        """
        self.assertEqual(
            self._normalize_query(actual), self._normalize_query(expected)
        )

    def test_second_table_joined(self) -> None:
        """
        Adding a column from a table absent from the FROM clause joins that
        table (here via NATURAL JOIN, with magic_join disabled).
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t2", column="c"))
            ],
            # magic_join requires DB knowledge hence Django
            magic_join=False,
        )
        self.assert_query_equal(
            sql,
            (
                "SELECT t1.a, t1.b, t2.c FROM t1 NATURAL JOIN t2 "
                "WHERE t1.col1 > 5"
            ),
        )

    def test_another_column_added(self) -> None:
        """
        Adding a column from a table already in the query appends it to the
        SELECT list without adding a join (even if it duplicates a column).
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t1", column="a"))
            ],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b, t1.a FROM t1 WHERE t1.col1 > 5"
        )

    def test_select_element_added_to_nothing(self) -> None:
        """
        Adding a column to an empty query produces a new SELECT from that
        column's table.
        """
        sql = add_to_select(
            "",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t2", column="c"))
            ],
        )
        self.assert_query_equal(sql, "SELECT t2.c FROM t2")

    def test_first_where_condition_added(self) -> None:
        """
        Adding a condition to a query without a WHERE clause creates one.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col1 > 5")],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5"
        )

    def test_second_where_condition_added(self) -> None:
        """
        A new condition is ANDed onto an existing single-condition WHERE.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")],
        )

        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t1.col2 < 3"
        )

    def test_third_where_condition_added(self) -> None:
        """
        A new condition is appended with AND after existing ANDed conditions.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t3.col99 = 100",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")],
        )
        self.assert_query_equal(
            sql,
            (
                "SELECT t1.a, t1.b FROM t1 "
                "WHERE t1.col1 > 5 AND t3.col99 = 100 AND t1.col2 < 3"
            ),
        )

    def test_multiple_wheres_added_to_none(self) -> None:
        """
        Several conditions added to a query without WHERE are joined by AND,
        in the order given.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1",
            grammar=self.grammar,
            where_conditions=[
                WhereCondition(raw_sql="t1.col1 > 99"),
                WhereCondition(raw_sql="t1.col2 < 999"),
            ],
        )
        self.assert_query_equal(
            sql,
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 99 AND t1.col2 < 999",
        )

    def test_raises_when_table_does_not_exist(self) -> None:
        """
        Adding a column from ``research.blobdoc`` (presumed, per this test's
        name, to be absent from the known database structure) raises
        DatabaseStructureNotUnderstood.
        """
        column_id = ColumnId(
            schema="research", table="blobdoc", column="_src_hash"
        )
        with self.assertRaises(DatabaseStructureNotUnderstood):
            add_to_select(
                "SELECT foo from bar",
                grammar=self.grammar,
                select_elements=[SelectElement(column_id=column_id)],
            )