Coverage for cc_modules/cc_taskfactory.py: 49%
43 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-15 15:51 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-15 15:51 +0100
1"""
2camcops_server/cc_modules/cc_taskfactory.py
4===============================================================================
6 Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
7 Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
9 This file is part of CamCOPS.
11 CamCOPS is free software: you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
16 CamCOPS is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
21 You should have received a copy of the GNU General Public License
22 along with CamCOPS. If not, see <https://www.gnu.org/licenses/>.
24===============================================================================
26**Functions to fetch tasks from the database.**
28"""
30import logging
31from typing import Optional, Type, TYPE_CHECKING, Union
33from cardinal_pythonlib.logs import BraceStyleAdapter
34import pyramid.httpexceptions as exc
35from sqlalchemy.orm import Query, Session as SqlASession
37from camcops_server.cc_modules.cc_task import (
38 tablename_to_task_class_dict,
39 Task,
40)
41from camcops_server.cc_modules.cc_taskindex import TaskIndexEntry
43if TYPE_CHECKING:
44 from camcops_server.cc_modules.cc_request import CamcopsRequest
46log = BraceStyleAdapter(logging.getLogger(__name__))
49# =============================================================================
50# Task query helpers
51# =============================================================================
def task_query_restricted_to_permitted_users(
    req: "CamcopsRequest",
    q: Query,
    cls: Union[Type[Task], Type[TaskIndexEntry]],
    as_dump: bool,
) -> Optional[Query]:
    """
    Restricts an SQLAlchemy ORM query to permitted users, for a given
    task class. THIS IS A KEY SECURITY FUNCTION.

    Args:
        req:
            the :class:`camcops_server.cc_modules.cc_request.CamcopsRequest`
        q:
            the SQLAlchemy ORM query
        cls:
            the class of the task type, or the
            :class:`camcops_server.cc_modules.cc_taskindex.TaskIndexEntry`
            class
        as_dump:
            use the "dump" permissions rather than the "view" permissions?

    Returns:
        - the original query, unmodified, if the user is a superuser;
        - ``None`` if the user has no relevant group permissions at all.
          Callers MUST check for this; it means nothing is permitted;
        - otherwise, the query filtered to rows whose group is one the
          user may view (or dump, if ``as_dump`` is true).
    """
    user = req.user

    if user.superuser:
        return q  # anything goes

    # Implement group security. Simple:
    if as_dump:
        group_ids = user.ids_of_groups_user_may_dump
    else:
        group_ids = user.ids_of_groups_user_may_see

    if not group_ids:
        # No permitted groups at all: signal "nothing permitted" to the
        # caller rather than returning a query.
        return None

    if cls is TaskIndexEntry:
        # The task index names its group column ``group_id``...
        # noinspection PyUnresolvedReferences
        q = q.filter(cls.group_id.in_(group_ids))  # type: ignore[union-attr]
    else:  # a kind of Task
        # ... whereas task classes name it ``_group_id``.
        q = q.filter(cls._group_id.in_(group_ids))  # type: ignore[union-attr]

    return q
103# =============================================================================
104# Make a single task given its base table name and server PK
105# =============================================================================
def task_factory(
    req: "CamcopsRequest", basetable: str, serverpk: int
) -> Optional[Task]:
    """
    Load a task from the database and return it.
    Filters to tasks permitted to the current user.

    Args:
        req: the :class:`camcops_server.cc_modules.cc_request.CamcopsRequest`
        basetable: name of the task's base table
        serverpk: server PK of the task

    Returns:
        the task, or ``None`` if the PK doesn't exist or the current user
        is not permitted to see it

    Raises:
        :exc:`HTTPBadRequest` if the table doesn't exist
    """
    d = tablename_to_task_class_dict()
    try:
        cls = d[basetable]
    except KeyError:
        # The KeyError is an internal detail; report a client error instead.
        raise exc.HTTPBadRequest(
            f"No such task table: {basetable!r}"
        ) from None
    dbsession = req.dbsession
    # noinspection PyProtectedMember
    q = dbsession.query(cls).filter(cls._pk == serverpk)
    restricted_q = task_query_restricted_to_permitted_users(
        req, q, cls, as_dump=False
    )
    if restricted_q is None:
        # The user may not view any group, so no task can be permitted.
        # (Previously this fell through to ``None.first()`` and raised
        # AttributeError.)
        return None
    return restricted_q.first()
def task_factory_no_security_checks(
    dbsession: SqlASession, basetable: str, serverpk: int
) -> Optional[Task]:
    """
    Load a task from the database and return it.

    Performs NO permission filtering: no user or request is consulted, so
    the task is returned if it exists, whoever is asking. Callers are
    responsible for ensuring that is appropriate. For a user-restricted
    lookup, use :func:`task_factory` instead.

    Args:
        dbsession: a :class:`sqlalchemy.orm.session.Session`
        basetable: name of the task's base table
        serverpk: server PK of the task

    Returns:
        the task, or ``None`` if the PK doesn't exist

    Raises:
        :exc:`KeyError` if the table doesn't exist
    """
    d = tablename_to_task_class_dict()
    cls = d[basetable]  # may raise KeyError
    # noinspection PyProtectedMember
    q = dbsession.query(cls).filter(cls._pk == serverpk)
    return q.first()
164# =============================================================================
165# Make a single task given its base table name and client-side keys
166# =============================================================================
def task_factory_clientkeys_no_security_checks(
    dbsession: SqlASession,
    basetable: str,
    client_id: int,
    device_id: int,
    era: str,
) -> Optional[Task]:
    """
    Load a task from the database, identified by its client-side keys,
    and return it.

    Performs NO permission filtering: no user or request is consulted, so
    the task is returned if it exists, whoever is asking. Callers are
    responsible for ensuring that is appropriate.

    Args:
        dbsession: a :class:`sqlalchemy.orm.session.Session`
        basetable: name of the task's base table
        client_id: task's ``_id`` value
        device_id: task's ``_device_id`` value
        era: task's ``_era`` value

    Returns:
        the first matching task, or ``None`` if none exists

    Raises:
        :exc:`KeyError` if the table doesn't exist
    """
    d = tablename_to_task_class_dict()
    cls = d[basetable]  # may raise KeyError
    # All three client-side keys must match.
    # noinspection PyProtectedMember
    q = (
        dbsession.query(cls)
        .filter(cls.id == client_id)
        .filter(cls._device_id == device_id)
        .filter(cls._era == era)
    )
    return q.first()