Coverage for dynamodx / transact_writer.py: 68%

110 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-23 16:15 -0300

1from typing import TYPE_CHECKING, Any, Self, Type, TypedDict 

2 

3import jmespath 

4 

5from .types import deserialize, serialize, to_dict 

6 

7if TYPE_CHECKING: 

8 from mypy_boto3_dynamodb.client import DynamoDBClient 

9 from mypy_boto3_dynamodb.literals import ReturnValuesOnConditionCheckFailureType 

10 from mypy_boto3_dynamodb.type_defs import TransactWriteItemTypeDef 

11else: 

12 DynamoDBClient = Any 

13 ReturnValuesOnConditionCheckFailureType = Any 

14 TransactWriteItemTypeDef = Any 

15 

16 

# Details about one operation whose cancellation reason was reported when a
# transaction was cancelled.
#   code      -- the reason's ``Code`` as reported by DynamoDB
#   message   -- the reason's ``Message`` as reported by DynamoDB
#   operation -- the request dict that was sent for this operation
#   old_item  -- the reason's ``Item`` passed through ``deserialize``
#                (``deserialize({})`` when no ``Item`` was returned)
TransactionCanceledReason = TypedDict(
    'TransactionCanceledReason',
    {
        'code': str,
        'message': str,
        'operation': dict,
        'old_item': dict,
    },
)

22 

23 

class TransactionOperationFailed(Exception):
    """Exception tied to the single operation that a cancellation reason names.

    Attributes:
        msg: The error message; also the exception's only positional arg.
        reason: The cancellation details for the operation.
    """

    msg: str
    reason: TransactionCanceledReason

    def __init__(self, msg: str = '', *, reason: TransactionCanceledReason) -> None:
        self.msg, self.reason = msg, reason
        super().__init__(msg)

37 

38 

class TransactionCanceledException(Exception):
    """Exception carrying every collected cancellation reason of a transaction.

    Attributes:
        msg: The error message; also the exception's only positional arg.
        reasons: The collected cancellation reasons (empty list by default).
    """

    msg: str
    reasons: list[TransactionCanceledReason]

    def __init__(
        self,
        msg: str = '',
        *,
        reasons: list[TransactionCanceledReason] | None = None,
    ) -> None:
        super().__init__(msg)
        self.msg = msg
        # Build a fresh list per instance. A literal ``[]`` default is created
        # once and shared, so mutating one exception's ``reasons`` would leak
        # into every other instance created without ``reasons``.
        self.reasons = [] if reasons is None else reasons

49 

50 

51class TransactOperation: 

52 def __init__( 

53 self, 

54 operation: dict, 

55 exc_cls: type[Exception] | None = None, 

56 ) -> None: 

57 self.operation = operation 

58 self.exc_cls = exc_cls 

59 

60 

class TransactWriter:
    """Buffer DynamoDB write operations and send them with ``transact_write_items``.

    Operations queued via :meth:`condition`, :meth:`put`, :meth:`delete` and
    :meth:`update` are appended to an in-memory buffer. Whenever the buffer
    holds ``flush_amount`` operations, that many are sent in one
    ``transact_write_items`` call. Anything left over is sent when the
    context manager exits.

    NOTE: every flush is a separate request. Operations that land in
    different flushes are not sent together.
    """

    def __init__(
        self,
        table_name: str,
        *,
        flush_amount: int = 50,
        client: DynamoDBClient,
        fail_fast: bool = True,
    ) -> None:
        """Create a writer.

        Args:
            table_name: Default table for operations that do not pass a
                (truthy) ``table_name`` of their own.
            flush_amount: Maximum number of operations per
                ``transact_write_items`` call. Must be at least 1.
            client: DynamoDB client used to send the requests.
            fail_fast: If True, a cancelled transaction raises the exception
                of the first operation whose cancellation reason has a
                ``Message`` (its ``exc_cls`` or ``TransactionOperationFailed``).
                If False, all such reasons are collected into one
                ``TransactionCanceledException``.

        Raises:
            ValueError: If ``flush_amount`` is less than 1.
        """
        # A non-positive flush amount never worked. ``_flush`` would send an
        # empty or truncated slice and, in ``__exit__``, loop forever because
        # the buffer never shrinks. Reject it up front.
        if flush_amount < 1:
            raise ValueError(f'flush_amount must be >= 1, got {flush_amount}')

        self._table_name = table_name
        self._items_buffer: list[TransactOperation] = []
        self._flush_amount = flush_amount
        self._client = client
        self._fail_fast = fail_fast

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc_details) -> None:
        """Flush the buffer until it is empty; exceptions are not suppressed.

        NOTE(review): this also flushes when the ``with`` body raised, so
        operations queued before the error are still sent — confirm this is
        the intended behavior.
        """
        while self._items_buffer:
            self._flush()

    def condition(
        self,
        key: dict,
        cond_expr: str,
        *,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``ConditionCheck`` operation.

        Args:
            key: Item key; converted with ``serialize`` before sending.
            cond_expr: Sent as ``ConditionExpression`` (always included).
            table_name: Table override; a falsy value uses the writer default.
            expr_attr_names: Sent as ``ExpressionAttributeNames`` when truthy.
            expr_attr_values: Serialized and sent as
                ``ExpressionAttributeValues`` when truthy.
            return_on_cond_fail: Sent as
                ``ReturnValuesOnConditionCheckFailure`` when truthy. Any
                ``Item`` in the matching cancellation reason surfaces as
                ``old_item``.
            exc_cls: Exception raised for this operation under ``fail_fast``.
        """
        attrs = self._optional_attrs(
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )
        self._add_op_and_process(
            TransactOperation(
                {
                    'ConditionCheck': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        ConditionExpression=cond_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def put(
        self,
        item: dict,
        *,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        cond_expr: str | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``Put`` operation.

        Args:
            item: Item to write. If its class has a truthy
                ``_is_dynamodb_mapped`` attribute, it is converted with
                ``to_dict`` first. The result is passed through ``serialize``.
            table_name: Table override; a falsy value uses the writer default.
            expr_attr_names: Sent as ``ExpressionAttributeNames`` when truthy.
            expr_attr_values: Serialized and sent as
                ``ExpressionAttributeValues`` when truthy.
            cond_expr: Sent as ``ConditionExpression`` when truthy.
            return_on_cond_fail: Sent as
                ``ReturnValuesOnConditionCheckFailure`` when truthy.
            exc_cls: Exception raised for this operation under ``fail_fast``.
        """
        is_dynamodb_mapped = getattr(item.__class__, '_is_dynamodb_mapped', False)
        serialized = serialize(to_dict(item) if is_dynamodb_mapped else item)  # type: ignore
        attrs = self._optional_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )
        self._add_op_and_process(
            TransactOperation(
                {
                    'Put': dict(
                        TableName=table_name or self._table_name,
                        Item=serialized,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def delete(
        self,
        key: dict,
        *,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``Delete`` operation.

        Args:
            key: Item key; converted with ``serialize`` before sending.
            table_name: Table override; a falsy value uses the writer default.
            cond_expr: Sent as ``ConditionExpression`` when truthy.
            expr_attr_names: Sent as ``ExpressionAttributeNames`` when truthy.
            expr_attr_values: Serialized and sent as
                ``ExpressionAttributeValues`` when truthy.
            return_on_cond_fail: Sent as
                ``ReturnValuesOnConditionCheckFailure`` when truthy.
            exc_cls: Exception raised for this operation under ``fail_fast``.
        """
        attrs = self._optional_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )
        self._add_op_and_process(
            TransactOperation(
                {
                    'Delete': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def update(
        self,
        key: dict,
        update_expr: str,
        *,
        cond_expr: str | None = None,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue an ``Update`` operation.

        Args:
            key: Item key; converted with ``serialize`` before sending.
            update_expr: Sent as ``UpdateExpression`` (always included).
            cond_expr: Sent as ``ConditionExpression`` when truthy.
            table_name: Table override; a falsy value uses the writer default.
            expr_attr_names: Sent as ``ExpressionAttributeNames`` when truthy.
            expr_attr_values: Serialized and sent as
                ``ExpressionAttributeValues`` when truthy.
            return_on_cond_fail: Sent as
                ``ReturnValuesOnConditionCheckFailure`` when truthy.
            exc_cls: Exception raised for this operation under ``fail_fast``.
        """
        attrs = self._optional_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )
        self._add_op_and_process(
            TransactOperation(
                {
                    'Update': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        UpdateExpression=update_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    @staticmethod
    def _optional_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
    ) -> dict:
        """Build the optional request fields shared by every operation type.

        Falsy arguments are omitted entirely rather than sent as empty values.
        Expression attribute values are passed through ``serialize``.
        """
        attrs: dict = {}
        if cond_expr:
            attrs['ConditionExpression'] = cond_expr
        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names
        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
        if return_on_cond_fail:
            attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail
        return attrs

    def _add_op_and_process(self, op: TransactOperation) -> None:
        """Append ``op`` to the buffer and flush if the buffer is full."""
        self._items_buffer.append(op)
        self._flush_if_needed()

    def _flush_if_needed(self) -> None:
        """Flush once when the buffer holds at least ``flush_amount`` ops."""
        if len(self._items_buffer) >= self._flush_amount:
            self._flush()

    def _flush(self) -> bool:
        """Send up to ``flush_amount`` buffered operations in one request.

        The sent operations are removed from the buffer before the request,
        so they are not retried if it fails.

        Returns:
            True when the request succeeds.

        Raises:
            TransactionOperationFailed: (or the operation's ``exc_cls``) with
                ``fail_fast`` when the transaction is cancelled.
            TransactionCanceledException: When the transaction is cancelled
                and either ``fail_fast`` is off or no reason has a ``Message``.
        """
        items_to_send = self._items_buffer[: self._flush_amount]
        self._items_buffer = self._items_buffer[self._flush_amount :]

        transact_items: list[TransactWriteItemTypeDef] = [
            op.operation  # type: ignore
            for op in items_to_send
        ]

        try:
            self._client.transact_write_items(TransactItems=transact_items)
        except self._client.exceptions.TransactionCanceledException as err:
            raise self._cancellation_error(err, items_to_send) from err
        return True

    def _cancellation_error(
        self,
        err: Any,
        items_to_send: list[TransactOperation],
    ) -> Exception:
        """Translate a cancelled-transaction client error into the exception to raise.

        ``CancellationReasons`` is assumed to be index-aligned with
        ``items_to_send``. Reasons without a ``Message`` are skipped.
        """
        error_msg = jmespath.search("Error.Message || 'Unknown'", err.response)
        cancellations = err.response.get('CancellationReasons', [])
        reasons: list[TransactionCanceledReason] = []

        for idx, reason in enumerate(cancellations):
            # Presumably entries without a Message are operations that did
            # not cause the cancellation — verify against DynamoDB responses.
            if 'Message' not in reason:
                continue

            op = items_to_send[idx]
            cancellation_reason = TransactionCanceledReason(
                code=reason['Code'],
                message=reason['Message'],
                operation=op.operation,
                old_item=deserialize(reason.get('Item', {})),
            )

            if self._fail_fast:
                return _exc_for_reason(
                    op.exc_cls or TransactionOperationFailed,
                    error_msg,
                    cancellation_reason,
                )

            reasons.append(cancellation_reason)

        return TransactionCanceledException(error_msg, reasons=reasons)

286 

287 

def _exc_for_reason(
    exc_cls: Type[Exception],
    msg: str,
    reason: TransactionCanceledReason,
) -> Exception:
    """Instantiate ``exc_cls`` for a cancelled operation (does not raise it).

    ``TransactionOperationFailed`` subclasses receive the reason through their
    ``reason`` keyword. Any other exception class is built from ``msg`` alone,
    and the reason is attached afterwards as the ``__reason__`` attribute.
    """
    if not issubclass(exc_cls, TransactionOperationFailed):
        instance = exc_cls(msg)
        setattr(instance, '__reason__', reason)
        return instance
    return exc_cls(msg, reason=reason)
297 setattr(exc, '__reason__', reason) 

298 return exc