Coverage for dynamodx / transact_writer.py: 68%

108 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-17 01:32 -0300

1from typing import TYPE_CHECKING, Any, Self, Type, TypedDict 

2 

3import jmespath 

4 

5from .types import deserialize, serialize 

6 

7if TYPE_CHECKING: 

8 from mypy_boto3_dynamodb.client import DynamoDBClient 

9 from mypy_boto3_dynamodb.literals import ReturnValuesOnConditionCheckFailureType 

10 from mypy_boto3_dynamodb.type_defs import TransactWriteItemTypeDef 

11else: 

12 DynamoDBClient = Any 

13 TransactWriteItemTypeDef = Any 

14 ReturnValuesOnConditionCheckFailureType = Any 

15 

16 

# Details for one operation that caused a transaction to be cancelled.
#   code      -- cancellation reason code reported by DynamoDB
#   message   -- human-readable cancellation message reported by DynamoDB
#   operation -- the raw TransactWriteItems entry that was sent
#   old_item  -- deserialized item returned with the reason ({} if none)
TransactionCanceledReason = TypedDict(
    'TransactionCanceledReason',
    {
        'code': str,
        'message': str,
        'operation': dict,
        'old_item': dict,
    },
)

22 

23 

class TransactionOperationFailed(Exception):
    """Exception describing one cancelled operation of a transaction.

    Attributes:
        msg: The message passed at construction (also the exception text).
        reason: Cancellation details for the operation that failed.
    """

    msg: str
    reason: TransactionCanceledReason

    def __init__(
        self,
        msg: str = '',
        *,
        reason: TransactionCanceledReason,
    ) -> None:
        """Store ``msg`` and the keyword-only ``reason`` on the exception."""
        self.msg, self.reason = msg, reason
        super().__init__(msg)

37 

38 

class TransactionCanceledException(Exception):
    """Raised when a transaction is cancelled.

    Attributes:
        msg: The message passed at construction (also the exception text).
        reasons: Cancellation details collected for the failed operations.
    """

    msg: str
    reasons: list[TransactionCanceledReason]

    def __init__(
        self,
        msg: str = '',
        *,
        reasons: list[TransactionCanceledReason] | None = None,
    ) -> None:
        """Store ``msg`` and ``reasons`` (defaults to a fresh empty list)."""
        super().__init__(msg)
        self.msg = msg
        # A literal ``[]`` default would be a single list shared by every
        # instance created without ``reasons``. Mutating one instance's list
        # would then leak into all the others.
        self.reasons = reasons if reasons is not None else []

49 

50 

51class TransactOperation: 

52 def __init__( 

53 self, 

54 operation: dict, 

55 exc_cls: type[Exception] | None = None, 

56 ) -> None: 

57 self.operation = operation 

58 self.exc_cls = exc_cls 

59 

60 

class TransactWriter:
    """Buffer DynamoDB write operations and send them with ``TransactWriteItems``.

    Operations are appended to an in-memory buffer. The buffer is sent in
    chunks of at most ``flush_amount`` items, either as soon as it reaches
    that size or when the context manager exits.

    NOTE: each chunk is a separate ``transact_write_items`` call, so
    atomicity holds only within a single chunk. If more than
    ``flush_amount`` operations are queued, earlier chunks may already be
    committed when a later chunk is cancelled.
    """

    def __init__(
        self,
        table_name: str,
        *,
        flush_amount: int = 50,
        client: DynamoDBClient,
        fail_fast: bool = True,
    ) -> None:
        """
        Args:
            table_name: Table used by operations that don't pass ``table_name``.
            flush_amount: Maximum number of operations per request (>= 1).
            client: DynamoDB client used to send the transactions.
            fail_fast: If True, a cancelled transaction raises an exception
                for the first failed operation. Otherwise it raises a
                ``TransactionCanceledException`` listing every failed
                operation.

        Raises:
            ValueError: If ``flush_amount`` is less than 1.
        """
        if flush_amount < 1:
            # A non-positive chunk size would send empty (invalid)
            # transactions and never drain the buffer, so ``__exit__``
            # could loop forever.
            raise ValueError(f'flush_amount must be >= 1, got {flush_amount}')

        self._table_name = table_name
        self._items_buffer: list[TransactOperation] = []
        self._flush_amount = flush_amount
        self._client = client
        self._fail_fast = fail_fast

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc_details) -> None:
        """Send every buffered operation, one chunk per request."""
        # NOTE(review): exc_details is ignored, so buffered operations are
        # still sent even when the ``with`` body raised. Confirm this is
        # intended.
        while self._items_buffer:
            self._flush()

    @staticmethod
    def _build_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
    ) -> dict:
        """Collect the optional request parameters shared by all operations.

        Falsy arguments are left out of the result. ``expr_attr_values`` is
        serialized to the DynamoDB wire format.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        if return_on_cond_fail:
            attrs['ReturnValuesOnConditionCheckFailure'] = return_on_cond_fail

        return attrs

    def condition(
        self,
        key: dict,
        cond_expr: str,
        *,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``ConditionCheck`` requiring the item at ``key`` to match ``cond_expr``."""
        # cond_expr is mandatory here, so it is passed explicitly below
        # rather than through the helper.
        attrs = self._build_attrs(
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'ConditionCheck': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        ConditionExpression=cond_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def put(
        self,
        item: dict,
        *,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        cond_expr: str | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``Put`` of ``item``, optionally guarded by ``cond_expr``."""
        attrs = self._build_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'Put': dict(
                        TableName=table_name or self._table_name,
                        Item=serialize(item),
                        **attrs,
                    )
                },
                exc_cls,
            ),
        )

    def delete(
        self,
        key: dict,
        *,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ``Delete`` of the item at ``key``, optionally guarded by ``cond_expr``."""
        attrs = self._build_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'Delete': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        **attrs,
                    )
                },
                exc_cls,
            ),
        )

    def update(
        self,
        key: dict,
        update_expr: str,
        *,
        cond_expr: str | None = None,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        return_on_cond_fail: ReturnValuesOnConditionCheckFailureType | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue an ``Update`` of the item at ``key`` using ``update_expr``."""
        attrs = self._build_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
            return_on_cond_fail=return_on_cond_fail,
        )

        self._add_op_and_process(
            TransactOperation(
                {
                    'Update': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        UpdateExpression=update_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def _add_op_and_process(self, op: TransactOperation) -> None:
        """Buffer ``op`` and flush if the buffer is full."""
        self._items_buffer.append(op)
        self._flush_if_needed()

    def _flush_if_needed(self) -> None:
        """Flush one chunk once the buffer holds ``flush_amount`` operations."""
        if len(self._items_buffer) >= self._flush_amount:
            self._flush()

    def _flush(self) -> bool:
        """Send up to ``flush_amount`` buffered operations in one transaction.

        The operations are removed from the buffer before the request is
        made, so a later flush does not resend them.

        Returns:
            True when the transaction succeeds.

        Raises:
            Exception: In fail-fast mode, an instance of the failing
                operation's ``exc_cls`` (default ``TransactionOperationFailed``)
                built from the first cancellation reason that has a message.
            TransactionCanceledException: When fail-fast is off, or when no
                cancellation reason has a message.
        """
        items_to_send = self._items_buffer[: self._flush_amount]
        self._items_buffer = self._items_buffer[self._flush_amount :]

        transact_items: list[TransactWriteItemTypeDef] = [
            item.operation  # type: ignore
            for item in items_to_send
        ]

        try:
            self._client.transact_write_items(TransactItems=transact_items)
        except self._client.exceptions.TransactionCanceledException as err:
            error_msg = jmespath.search("Error.Message || 'Unknown'", err.response)
            cancellations = err.response.get('CancellationReasons', [])
            reasons = []

            # CancellationReasons is positionally aligned with TransactItems,
            # so idx maps each reason back to the operation that produced it.
            for idx, reason in enumerate(cancellations):
                # Reasons without a Message are skipped.
                if 'Message' not in reason:
                    continue

                item = items_to_send[idx]
                cancellation_reason = TransactionCanceledReason(
                    code=reason['Code'],  # type: ignore
                    message=reason['Message'],
                    operation=item.operation,
                    old_item=deserialize(reason.get('Item', {})),
                )

                if self._fail_fast:
                    exc_cls = item.exc_cls or TransactionOperationFailed
                    raise _exc_for_reason(
                        exc_cls, error_msg, cancellation_reason
                    ) from err

                reasons.append(cancellation_reason)

            raise TransactionCanceledException(error_msg, reasons=reasons) from err
        else:
            return True

284 

285 

def _exc_for_reason(
    exc_cls: Type[Exception],
    msg: str,
    reason: TransactionCanceledReason,
) -> Exception:
    """Build (but do not raise) an ``exc_cls`` instance describing ``reason``.

    Subclasses of ``TransactionOperationFailed`` receive ``reason`` through
    their keyword-only constructor argument. Any other exception class is
    constructed with ``msg`` alone, and the reason is attached afterwards
    as a ``__reason__`` attribute.
    """
    if not issubclass(exc_cls, TransactionOperationFailed):
        # Arbitrary exception classes can't be assumed to accept ``reason``.
        generic_exc = exc_cls(msg)
        setattr(generic_exc, '__reason__', reason)
        return generic_exc

    return exc_cls(msg, reason=reason)
296 return exc