Coverage for dynamodx / transact_writer.py: 68%
110 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-23 16:15 -0300
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-23 16:15 -0300
1from typing import TYPE_CHECKING, Any, Self, Type, TypedDict
3import jmespath
5from .types import deserialize, serialize, to_dict
7if TYPE_CHECKING:
8 from mypy_boto3_dynamodb.client import DynamoDBClient
9 from mypy_boto3_dynamodb.literals import ReturnValuesOnConditionCheckFailureType
10 from mypy_boto3_dynamodb.type_defs import TransactWriteItemTypeDef
11else:
12 DynamoDBClient = Any
13 ReturnValuesOnConditionCheckFailureType = Any
14 TransactWriteItemTypeDef = Any
class TransactionCanceledReason(TypedDict):
    """Why one operation of a canceled transaction failed.

    ``TransactWriter._flush`` builds one of these from each entry of the
    ``CancellationReasons`` list in the error response that carries a
    ``Message``.
    """

    # Value of the cancellation reason's ``Code`` field.
    code: str
    # Value of the cancellation reason's ``Message`` field.
    message: str
    # The request dict (e.g. ``{'Put': {...}}``) that was sent for this item.
    operation: dict
    # Deserialized ``Item`` from the cancellation reason, or ``{}`` if absent.
    old_item: dict
class TransactionOperationFailed(Exception):
    """Error for a single operation that caused a transaction cancellation.

    Attributes:
        msg: Error message, also passed to ``Exception``.
        reason: Cancellation details for the failing operation.
    """

    msg: str
    reason: TransactionCanceledReason

    def __init__(
        self,
        msg: str = '',
        *,
        reason: TransactionCanceledReason,
    ) -> None:
        super().__init__(msg)
        # Assigned left to right: ``msg`` first, then ``reason``.
        self.msg, self.reason = msg, reason
class TransactionCanceledException(Exception):
    """Error raised by ``TransactWriter`` when a transaction is canceled.

    It is raised when fail-fast mode is off, or when no cancellation reason
    carries a message.

    Attributes:
        msg: Error message, also passed to ``Exception``.
        reasons: Cancellation details, one per failed operation that reported
            a message. Defaults to a new empty list.
    """

    def __init__(
        self,
        msg: str = '',
        *,
        reasons: list[TransactionCanceledReason] | None = None,
    ) -> None:
        super().__init__(msg)
        self.msg = msg
        # Use a fresh list per instance. A mutable ``[]`` default would be one
        # shared object across every exception created without ``reasons``.
        self.reasons = [] if reasons is None else reasons
51class TransactOperation:
52 def __init__(
53 self,
54 operation: dict,
55 exc_cls: type[Exception] | None = None,
56 ) -> None:
57 self.operation = operation
58 self.exc_cls = exc_cls
class TransactWriter:
    """Buffer DynamoDB write operations and send them with ``TransactWriteItems``.

    Operations added through :meth:`condition`, :meth:`put`, :meth:`delete`
    and :meth:`update` are queued. As soon as ``flush_amount`` operations are
    buffered, they are sent together in one ``transact_write_items`` call.
    When used as a context manager, anything still buffered is sent on exit.

    Args:
        table_name: Default table for operations that don't pass ``table_name``.
        flush_amount: Maximum number of operations per transaction. A flush is
            triggered when the buffer reaches this size. Must be >= 1.
        client: DynamoDB client used to call ``transact_write_items``.
        fail_fast: On cancellation, raise for the first failed operation that
            reported a message. Otherwise collect every such reason into a
            ``TransactionCanceledException``.

    Raises:
        ValueError: If ``flush_amount`` is less than 1.
    """

    def __init__(
        self,
        table_name: str,
        *,
        flush_amount: int = 50,
        client: DynamoDBClient,
        fail_fast: bool = True,
    ) -> None:
        # A non-positive flush amount would make every flush send an empty
        # slice. ``__exit__`` would then loop forever because the buffer
        # never shrinks.
        if flush_amount < 1:
            raise ValueError(f'flush_amount must be >= 1, got {flush_amount!r}')

        self._table_name = table_name
        self._items_buffer: list[TransactOperation] = []
        self._flush_amount = flush_amount
        self._client = client
        self._fail_fast = fail_fast

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc_details) -> None:
        # When we exit, we need to keep flushing whatever's left
        # until there's nothing left in our items buffer.
        # NOTE(review): this also runs when the ``with`` body raised, so the
        # buffered operations are still committed in that case. Confirm this
        # is intended.
        while self._items_buffer:
            self._flush()

    @staticmethod
    def _build_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
    ) -> dict:
        """Collect the optional request parameters shared by all operations.

        Falsy arguments are omitted from the result. ``expr_attr_values`` is
        serialized before being stored.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        if return_on_cond_fail:
            attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail

        return attrs

    def condition(
        self,
        key: dict,
        cond_expr: str,
        *,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``ConditionCheck`` requiring the item at ``key`` to match ``cond_expr``.

        Args:
            key: Primary key of the item to check. It is passed through
                ``serialize`` before sending.
            cond_expr: Condition expression the item must satisfy.
            table_name: Overrides the writer's default table.
            expr_attr_names: ``ExpressionAttributeNames`` placeholders.
            expr_attr_values: ``ExpressionAttributeValues``. Serialized
                before sending.
            return_on_cond_fail: ``ReturnValuesOnConditionCheckFailure``.
            exc_cls: Exception class raised in fail-fast mode if this check
                causes the transaction to be canceled.
        """
        attrs = self._build_attrs(
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'ConditionCheck': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        ConditionExpression=cond_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def put(
        self,
        item: dict,
        *,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        cond_expr: str | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``Put`` of ``item``.

        Args:
            item: Item to write. If its class has a truthy
                ``_is_dynamodb_mapped`` attribute, ``to_dict`` converts it
                first. The result is passed through ``serialize``.
            table_name: Overrides the writer's default table.
            expr_attr_names: ``ExpressionAttributeNames`` placeholders.
            expr_attr_values: ``ExpressionAttributeValues``. Serialized
                before sending.
            cond_expr: Optional ``ConditionExpression``.
            return_on_cond_fail: ``ReturnValuesOnConditionCheckFailure``.
            exc_cls: Exception class raised in fail-fast mode if this put
                causes the transaction to be canceled.
        """
        is_dynamodb_mapped = getattr(item.__class__, '_is_dynamodb_mapped', False)
        serialized = serialize(to_dict(item) if is_dynamodb_mapped else item)  # type: ignore
        attrs = self._build_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'Put': dict(
                        TableName=table_name or self._table_name,
                        Item=serialized,
                        **attrs,
                    )
                },
                exc_cls,
            ),
        )

    def delete(
        self,
        key: dict,
        *,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``Delete`` of the item at ``key``.

        Args:
            key: Primary key of the item to delete. It is passed through
                ``serialize`` before sending.
            table_name: Overrides the writer's default table.
            cond_expr: Optional ``ConditionExpression``.
            expr_attr_names: ``ExpressionAttributeNames`` placeholders.
            expr_attr_values: ``ExpressionAttributeValues``. Serialized
                before sending.
            return_on_cond_fail: ``ReturnValuesOnConditionCheckFailure``.
            exc_cls: Exception class raised in fail-fast mode if this delete
                causes the transaction to be canceled.
        """
        attrs = self._build_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'Delete': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        **attrs,
                    )
                },
                exc_cls,
            ),
        )

    def update(
        self,
        key: dict,
        update_expr: str,
        *,
        cond_expr: str | None = None,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue an ``Update`` of the item at ``key`` using ``update_expr``.

        Args:
            key: Primary key of the item to update. It is passed through
                ``serialize`` before sending.
            update_expr: ``UpdateExpression`` to apply.
            cond_expr: Optional ``ConditionExpression``.
            table_name: Overrides the writer's default table.
            expr_attr_names: ``ExpressionAttributeNames`` placeholders.
            expr_attr_values: ``ExpressionAttributeValues``. Serialized
                before sending.
            return_on_cond_fail: ``ReturnValuesOnConditionCheckFailure``.
            exc_cls: Exception class raised in fail-fast mode if this update
                causes the transaction to be canceled.
        """
        attrs = self._build_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'Update': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        UpdateExpression=update_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def _add_op_and_process(self, op: TransactOperation) -> None:
        """Append ``op`` to the buffer and flush if the buffer is full."""
        self._items_buffer.append(op)
        self._flush_if_needed()

    def _flush_if_needed(self) -> None:
        """Flush once the buffer holds at least ``flush_amount`` operations."""
        if len(self._items_buffer) >= self._flush_amount:
            self._flush()

    def _flush(self) -> bool:
        """Send up to ``flush_amount`` buffered operations in one transaction.

        The operations are removed from the buffer before the request is made,
        so they are not retried if it fails.

        Returns:
            True when the transaction succeeds.

        Raises:
            Exception: In fail-fast mode, an instance of the failing
                operation's ``exc_cls`` (default ``TransactionOperationFailed``)
                for the first cancellation reason that has a ``Message``.
            TransactionCanceledException: Otherwise, carrying every
                cancellation reason that has a ``Message``. This is also raised
                in fail-fast mode when no reason has a ``Message``.
        """
        items_to_send = self._items_buffer[: self._flush_amount]
        self._items_buffer = self._items_buffer[self._flush_amount :]

        transact_items: list[TransactWriteItemTypeDef] = [
            op.operation  # type: ignore
            for op in items_to_send
        ]

        try:
            self._client.transact_write_items(TransactItems=transact_items)
        except self._client.exceptions.TransactionCanceledException as err:
            error_msg = jmespath.search("Error.Message || 'Unknown'", err.response)
            # Index i of CancellationReasons is matched to index i of
            # items_to_send (assumes one entry per sent item, as the indexing
            # below requires).
            cancellations = err.response.get('CancellationReasons', [])
            reasons = []

            for idx, reason in enumerate(cancellations):
                # Reasons without a message are skipped.
                if 'Message' not in reason:
                    continue

                item = items_to_send[idx]
                cancellation_reason = TransactionCanceledReason(
                    code=reason['Code'],  # type: ignore
                    message=reason['Message'],
                    operation=item.operation,
                    old_item=deserialize(reason.get('Item', {})),
                )

                if self._fail_fast:
                    exc_cls = item.exc_cls or TransactionOperationFailed
                    raise _exc_for_reason(
                        exc_cls, error_msg, cancellation_reason
                    ) from err

                reasons.append(cancellation_reason)

            raise TransactionCanceledException(error_msg, reasons=reasons) from err
        else:
            return True
def _exc_for_reason(
    exc_cls: Type[Exception],
    msg: str,
    reason: TransactionCanceledReason,
) -> Exception:
    """Build an exception of type ``exc_cls`` for a failed transaction operation.

    Subclasses of ``TransactionOperationFailed`` receive ``reason`` through
    their ``reason`` keyword. Any other class is instantiated with ``msg``
    alone and gets ``reason`` attached as a ``__reason__`` attribute.

    Returns:
        The exception instance. It is returned, not raised.
    """
    if not issubclass(exc_cls, TransactionOperationFailed):
        plain_exc = exc_cls(msg)
        setattr(plain_exc, '__reason__', reason)
        return plain_exc

    return exc_cls(msg, reason=reason)