Coverage for crateweb/nlp_classification/database_connection.py: 81%

27 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-27 10:34 -0500

1from typing import Any, Iterable, Optional, Sequence 

2 

3from django.db import connections 

4 

5 

class DatabaseConnection:
    """Run simple ``SELECT`` queries against a named Django database connection.

    The connection is looked up in ``django.db.connections`` by
    ``connection_name`` every time a query is executed.

    NOTE: column names, the table name and the ``where`` clause are
    interpolated into the SQL text verbatim (see ``get_sql``), so they must
    come from trusted code, never from user input. Only the values in
    ``params`` are handed to ``cursor.execute`` separately from the SQL text.
    """

    def __init__(self, connection_name: str) -> None:
        """Store the key used to look up the connection.

        Args:
            connection_name: Key into ``django.db.connections``.
        """
        self.connection_name = connection_name

    def fetchone_as_dict(
        self,
        column_names: Iterable[str],
        table_name: str,
        where: Optional[str] = None,
        params: Optional[Sequence[Any]] = None,
    ) -> dict[str, Any]:
        """Fetch the first matching row as a ``{column_name: value}`` dict.

        Args:
            column_names: Columns to select; also used as the dict keys, in
                the same order as the selected values.
            table_name: Table to select from.
            where: Optional SQL condition placed after ``WHERE``.
            params: Optional values passed to ``cursor.execute`` alongside
                the SQL.

        Returns:
            A dict mapping each column name to the value at the same position
            in the fetched row, or an empty dict if no row was fetched.
        """
        # Materialise once: the names are needed twice (to build the SQL and
        # to label the row), and a one-shot iterable such as a generator
        # would be exhausted after the first pass.
        columns = list(column_names)
        sql = self.get_sql(columns, table_name, where, params)

        with connections[self.connection_name].cursor() as cursor:
            cursor.execute(sql, params)
            row = cursor.fetchone()

        if not row:
            return {}

        return {name: row[index] for index, name in enumerate(columns)}

    def fetchall(
        self,
        column_names: Iterable[str],
        table_name: str,
        where: Optional[str] = None,
        params: Optional[Sequence[Any]] = None,
    ) -> Iterable[Any]:
        """Yield every matching row, one at a time.

        This is a generator: the query is not executed until the first row
        is requested.

        Args:
            column_names: Columns to select.
            table_name: Table to select from.
            where: Optional SQL condition placed after ``WHERE``.
            params: Optional values passed to ``cursor.execute`` alongside
                the SQL.

        Yields:
            Each row exactly as returned by ``cursor.fetchall()``.
        """
        sql = self.get_sql(column_names, table_name, where, params)

        with connections[self.connection_name].cursor() as cursor:
            cursor.execute(sql, params)
            rows = cursor.fetchall()

        # All rows are already in memory, so yield them only after the cursor
        # has been released; a consumer that pauses or abandons the generator
        # then never keeps the cursor open.
        yield from rows

    def get_sql(
        self,
        column_names: Iterable[str],
        table_name: str,
        where: Optional[str] = None,
        params: Optional[Sequence[Any]] = None,
    ) -> str:
        """Build the ``SELECT`` statement text used by the fetch methods.

        Args:
            column_names: Columns to select, joined with ``", "``.
            table_name: Table to select from.
            where: Optional SQL condition; when empty or ``None`` no
                ``WHERE`` clause is added.
            params: Not used when building the SQL; accepted so the
                signature mirrors the fetch methods.

        Returns:
            The SQL string. All arguments are inserted verbatim with no
            quoting or escaping, so they must not contain untrusted input.
        """
        select_list = ", ".join(column_names)
        sql = f"SELECT {select_list} FROM {table_name}"

        if where:
            sql = f"{sql} WHERE {where}"

        return sql