Coverage for dynamodx / transact_writer.py: 68%
108 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-17 01:32 -0300
1from typing import TYPE_CHECKING, Any, Self, Type, TypedDict
3import jmespath
5from .types import deserialize, serialize
7if TYPE_CHECKING:
8 from mypy_boto3_dynamodb.client import DynamoDBClient
9 from mypy_boto3_dynamodb.literals import ReturnValuesOnConditionCheckFailureType
10 from mypy_boto3_dynamodb.type_defs import TransactWriteItemTypeDef
11else:
12 DynamoDBClient = Any
13 TransactWriteItemTypeDef = Any
14 ReturnValuesOnConditionCheckFailureType = Any
class TransactionCanceledReason(TypedDict):
    """Why one operation inside a cancelled transaction failed.

    Built from a single entry of the ``CancellationReasons`` list that
    DynamoDB returns with a ``TransactionCanceledException``.
    """

    # Cancellation code reported for the operation.
    code: str
    # Human-readable message reported for the operation.
    message: str
    # The TransactItems entry that was submitted for this operation.
    operation: dict
    # Deserialized item returned with the reason; empty dict when absent.
    old_item: dict
class TransactionOperationFailed(Exception):
    """Default error for a single failed operation in a transaction.

    Attributes:
        msg: The message the exception was created with.
        reason: Details of the failed operation (code, message, submitted
            operation and any returned item).
    """

    msg: str
    reason: TransactionCanceledReason

    def __init__(self, msg: str = '', *, reason: TransactionCanceledReason) -> None:
        # Keep ``args``/``str(exc)`` consistent with the stored message.
        super().__init__(msg)
        self.msg = msg
        self.reason = reason
class TransactionCanceledException(Exception):
    """Error for a cancelled transaction, carrying per-operation reasons.

    Attributes:
        msg: The message the exception was created with.
        reasons: The ``TransactionCanceledReason`` entries describing the
            failed operations; an empty list when none were supplied.
    """

    def __init__(
        self,
        msg: str = '',
        *,
        reasons: list[TransactionCanceledReason] | None = None,
    ) -> None:
        super().__init__(msg)
        self.msg = msg
        # Give each instance its own list. A shared ``[]`` default would make
        # reasons appended to one exception show up on every other one.
        self.reasons = reasons if reasons is not None else []
51class TransactOperation:
52 def __init__(
53 self,
54 operation: dict,
55 exc_cls: type[Exception] | None = None,
56 ) -> None:
57 self.operation = operation
58 self.exc_cls = exc_cls
61class TransactWriter:
62 def __init__(
63 self,
64 table_name: str,
65 *,
66 flush_amount: int = 50,
67 client: DynamoDBClient,
68 fail_fast: bool = True,
69 ) -> None:
70 self._table_name = table_name
71 self._items_buffer: list[TransactOperation] = []
72 self._flush_amount = flush_amount
73 self._client = client
74 self._fail_fast = fail_fast
76 def __enter__(self) -> Self:
77 return self
79 def __exit__(self, *exc_details) -> None:
80 # When we exit, we need to keep flushing whatever's left
81 # until there's nothing left in our items buffer.
82 while self._items_buffer:
83 self._flush()
85 def condition(
86 self,
87 key: dict,
88 cond_expr: str,
89 *,
90 table_name: str | None = None,
91 expr_attr_names: dict | None = None,
92 expr_attr_values: dict | None = None,
93 return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
94 exc_cls: Type[Exception] | None = None,
95 ) -> None:
96 attrs: dict = {}
98 if expr_attr_names:
99 attrs['ExpressionAttributeNames'] = expr_attr_names
101 if expr_attr_values:
102 attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
104 if return_on_cond_fail:
105 attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail
107 self._add_op_and_process(
108 TransactOperation(
109 {
110 'ConditionCheck': dict(
111 TableName=table_name or self._table_name,
112 Key=serialize(key),
113 ConditionExpression=cond_expr,
114 **attrs,
115 )
116 },
117 exc_cls,
118 )
119 )
121 def put(
122 self,
123 item: dict,
124 *,
125 table_name: str | None = None,
126 expr_attr_names: dict | None = None,
127 expr_attr_values: dict | None = None,
128 cond_expr: str | None = None,
129 return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
130 exc_cls: Type[Exception] | None = None,
131 ) -> None:
132 attrs: dict = {}
134 if cond_expr:
135 attrs['ConditionExpression'] = cond_expr
137 if expr_attr_names:
138 attrs['ExpressionAttributeNames'] = expr_attr_names
140 if expr_attr_values:
141 attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
143 if return_on_cond_fail:
144 attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail
146 self._add_op_and_process(
147 TransactOperation(
148 {
149 'Put': dict(
150 TableName=table_name or self._table_name,
151 Item=serialize(item),
152 **attrs,
153 )
154 },
155 exc_cls,
156 ),
157 )
159 def delete(
160 self,
161 key: dict,
162 *,
163 table_name: str | None = None,
164 cond_expr: str | None = None,
165 expr_attr_names: dict | None = None,
166 expr_attr_values: dict | None = None,
167 return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
168 exc_cls: Type[Exception] | None = None,
169 ) -> None:
170 attrs: dict = {}
172 if cond_expr:
173 attrs['ConditionExpression'] = cond_expr
175 if expr_attr_names:
176 attrs['ExpressionAttributeNames'] = expr_attr_names
178 if expr_attr_values:
179 attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
181 if return_on_cond_fail:
182 attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail
184 self._add_op_and_process(
185 TransactOperation(
186 {
187 'Delete': dict(
188 TableName=table_name or self._table_name,
189 Key=serialize(key),
190 **attrs,
191 )
192 },
193 exc_cls,
194 ),
195 )
197 def update(
198 self,
199 key: dict,
200 update_expr: str,
201 *,
202 cond_expr: str | None = None,
203 table_name: str | None = None,
204 expr_attr_names: dict | None = None,
205 expr_attr_values: dict | None = None,
206 return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
207 exc_cls: Type[Exception] | None = None,
208 ) -> None:
209 attrs: dict = {}
211 if cond_expr:
212 attrs['ConditionExpression'] = cond_expr
214 if expr_attr_names:
215 attrs['ExpressionAttributeNames'] = expr_attr_names
217 if expr_attr_values:
218 attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
220 if return_on_cond_fail:
221 attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail
223 self._add_op_and_process(
224 TransactOperation(
225 {
226 'Update': dict(
227 TableName=table_name or self._table_name,
228 Key=serialize(key),
229 UpdateExpression=update_expr,
230 **attrs,
231 )
232 },
233 exc_cls,
234 )
235 )
237 def _add_op_and_process(self, op: TransactOperation) -> None:
238 self._items_buffer.append(op)
239 self._flush_if_needed()
241 def _flush_if_needed(self) -> None:
242 if len(self._items_buffer) >= self._flush_amount:
243 self._flush()
245 def _flush(self) -> bool:
246 items_to_send = self._items_buffer[: self._flush_amount]
247 self._items_buffer = self._items_buffer[self._flush_amount :]
249 transact_items: list[TransactWriteItemTypeDef] = [
250 item.operation # type: ignore
251 for item in items_to_send
252 ]
254 try:
255 self._client.transact_write_items(TransactItems=transact_items)
256 except self._client.exceptions.TransactionCanceledException as err:
257 error_msg = jmespath.search("Error.Message || 'Unknown'", err.response)
258 cancellations = err.response.get('CancellationReasons', [])
259 reasons = []
261 for idx, reason in enumerate(cancellations):
262 if 'Message' not in reason:
263 continue
265 item = items_to_send[idx]
266 cancellation_reason = TransactionCanceledReason(
267 code=reason['Code'], # type: ignore
268 message=reason['Message'],
269 operation=item.operation,
270 old_item=deserialize(reason.get('Item', {})),
271 )
273 if self._fail_fast:
274 exc_cls = item.exc_cls or TransactionOperationFailed
275 raise _exc_for_reason(
276 exc_cls, error_msg, cancellation_reason
277 ) from err
279 reasons.append(cancellation_reason)
281 raise TransactionCanceledException(error_msg, reasons=reasons) from err
282 else:
283 return True
def _exc_for_reason(
    exc_cls: Type[Exception],
    msg: str,
    reason: TransactionCanceledReason,
) -> Exception:
    """Instantiate ``exc_cls`` for one cancelled operation (does not raise it).

    :class:`TransactionOperationFailed` subclasses receive ``reason`` via their
    keyword-only ``reason`` argument. Any other exception type is built from
    ``msg`` alone and gets ``reason`` attached as a ``__reason__`` attribute.
    """
    if not issubclass(exc_cls, TransactionOperationFailed):
        # Arbitrary exception classes are only given the message positionally.
        generic_exc = exc_cls(msg)
        generic_exc.__reason__ = reason  # type: ignore[attr-defined]
        return generic_exc

    return exc_cls(msg, reason=reason)