"""
Validation (and value resolution) of DynamoDB update-expression ASTs.

See docstring class Validator below for more details on validation
"""
from abc import abstractmethod
from copy import deepcopy

from moto.dynamodb2.exceptions import (
    AttributeIsReservedKeyword,
    ExpressionAttributeValueNotDefined,
    AttributeDoesNotExist,
    ExpressionAttributeNameNotDefined,
    IncorrectOperandType,
    InvalidUpdateExpressionInvalidDocumentPath,
    ProvidedKeyDoesNotExist,
    EmptyKeyAttributeException,
    UpdateHashRangeKeyException,
)
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.parsing.ast_nodes import (
    ExpressionAttribute,
    UpdateExpressionPath,
    UpdateExpressionSetAction,
    UpdateExpressionAddAction,
    UpdateExpressionDeleteAction,
    UpdateExpressionRemoveAction,
    DDBTypedValue,
    ExpressionAttributeValue,
    ExpressionAttributeName,
    DepthFirstTraverser,
    NoneExistingPath,
    UpdateExpressionFunction,
    ExpressionPathDescender,
    UpdateExpressionValue,
    ExpressionValueOperator,
    ExpressionSelector,
)
from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords


class ExpressionAttributeValueProcessor(DepthFirstTraverser):
    """Replace every ExpressionAttributeValue node with the typed value it refers to.

    The value placeholder is looked up in the expression attribute values supplied
    with the request and replaced by a DDBTypedValue wrapping a DynamoType.
    """

    def __init__(self, expression_attribute_values):
        self.expression_attribute_values = expression_attribute_values

    def _processing_map(self):
        return {
            ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
        }

    def replace_expression_attribute_value_with_value(self, node):
        """A node representing an Expression Attribute Value. Resolve and replace value.

        Args:
            node(ExpressionAttributeValue): The placeholder node.

        Returns:
            DDBTypedValue: The referenced value wrapped as a DynamoType.

        Raises:
            ExpressionAttributeValueNotDefined: The placeholder has no entry in the
                supplied expression attribute values.
        """
        assert isinstance(node, ExpressionAttributeValue)
        attribute_value_name = node.get_value_name()
        try:
            target = self.expression_attribute_values[attribute_value_name]
        except KeyError:
            raise ExpressionAttributeValueNotDefined(
                attribute_value=attribute_value_name
            )
        return DDBTypedValue(DynamoType(target))


class ExpressionPathResolver(object):
    """Resolve an update-expression path against the attributes of an item."""

    def __init__(self, expression_attribute_names):
        self.expression_attribute_names = expression_attribute_names

    @classmethod
    def raise_exception_if_keyword(cls, attribute):
        """Raise AttributeIsReservedKeyword if `attribute` (case-insensitively) is a reserved keyword."""
        if attribute.upper() in ReservedKeywords.get_reserved_keywords():
            raise AttributeIsReservedKeyword(attribute)

    def resolve_expression_path(self, item, update_expression_path):
        """Resolve an UpdateExpressionPath node against `item`.

        See resolve_expression_path_nodes for the return value and raised exceptions.
        """
        assert isinstance(update_expression_path, UpdateExpressionPath)
        return self.resolve_expression_path_nodes(item, update_expression_path.children)

    def resolve_expression_path_nodes(self, item, update_expression_path_nodes):
        """Walk the path elements down from `item.attrs`.

        Args:
            item(Item): The item whose attributes are navigated.
            update_expression_path_nodes(list): The children of an UpdateExpressionPath.

        Returns:
            DDBTypedValue: Wrapping the value found at the end of the path, or
            NoneExistingPath: When an element of the path cannot be found. It is
                marked creatable only when the missing element is the last one of the
                path (for a list index: an out-of-bounds index on the last element).

        Raises:
            ExpressionAttributeNameNotDefined: A name placeholder is not defined.
            AttributeIsReservedKeyword: A literal attribute name is a reserved keyword.
            InvalidUpdateExpressionInvalidDocumentPath: An index is applied to a non-list.
            NotImplementedError: The path contains an unsupported node type.
        """
        target = item.attrs
        last_position = len(update_expression_path_nodes) - 1

        for position, child in enumerate(update_expression_path_nodes):
            # Compare positions, not values: an earlier path element may compare equal
            # to the last one without being the last one.
            is_last_element = position == last_position

            # First replace placeholder with attribute_name
            attr_name = None
            if isinstance(child, ExpressionAttributeName):
                attr_placeholder = child.get_attribute_name_placeholder()
                try:
                    attr_name = self.expression_attribute_names[attr_placeholder]
                except KeyError:
                    raise ExpressionAttributeNameNotDefined(attr_placeholder)
            elif isinstance(child, ExpressionAttribute):
                attr_name = child.get_attribute_name()
                # Only literal names are checked; placeholder-mapped names are not.
                self.raise_exception_if_keyword(attr_name)

            if attr_name is not None:
                # Resolve attribute_name
                try:
                    target = target[attr_name]
                except (KeyError, TypeError):
                    if is_last_element:
                        return NoneExistingPath(creatable=True)
                    return NoneExistingPath()
            else:
                if isinstance(child, ExpressionPathDescender):
                    continue
                elif isinstance(child, ExpressionSelector):
                    index = child.get_index()
                    if target.is_list():
                        try:
                            target = target[index]
                        except IndexError:
                            # When a list goes out of bounds when assigning that is no problem when at the assignment
                            # side. It will just append to the list.
                            if is_last_element:
                                return NoneExistingPath(creatable=True)
                            return NoneExistingPath()
                    else:
                        raise InvalidUpdateExpressionInvalidDocumentPath
                else:
                    raise NotImplementedError(
                        "Path resolution for {t}".format(t=type(child))
                    )
        return DDBTypedValue(target)

    def resolve_expression_path_nodes_to_dynamo_type(
        self, item, update_expression_path_nodes
    ):
        """Resolve the path and return the DynamoType found at its end.

        Raises:
            ProvidedKeyDoesNotExist: Some element of the path does not exist.
        """
        node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
        if isinstance(node, NoneExistingPath):
            raise ProvidedKeyDoesNotExist()
        assert isinstance(node, DDBTypedValue)
        return node.get_value()


class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
    """Resolve UpdateExpressionPath nodes against the item being updated.

    Paths on the value side of an action are replaced by what they point to. The
    path being assigned to (first child of an action) is only checked for being
    resolvable or creatable, and is left in the tree unchanged.
    """

    def _processing_map(self):
        return {
            UpdateExpressionSetAction: self.disable_resolving,
            UpdateExpressionPath: self.process_expression_path_node,
        }

    def __init__(self, expression_attribute_names, item):
        self.expression_attribute_names = expression_attribute_names
        self.item = item
        # True while traversing a value (non-first) child of an action node.
        self.resolving = False

    def pre_processing_of_child(self, parent_node, child_id):
        """
        Enable resolving only for children of a SET/REMOVE/DELETE/ADD action that are not the first.
        The first child is the path being acted upon; later children are the values to resolve.
        """
        if isinstance(
            parent_node,
            (
                UpdateExpressionSetAction,
                UpdateExpressionRemoveAction,
                UpdateExpressionDeleteAction,
                UpdateExpressionAddAction,
            ),
        ):
            if child_id == 0:
                self.resolving = False
            else:
                self.resolving = True

    def disable_resolving(self, node=None):
        """Turn resolving off and return `node` unchanged."""
        self.resolving = False
        return node

    def process_expression_path_node(self, node):
        """Resolve the path when resolving is enabled; otherwise only validate it.

        Returns:
            The resolved node (DDBTypedValue or NoneExistingPath) when resolving,
            else the original path node.

        Raises:
            InvalidUpdateExpressionInvalidDocumentPath: A non-resolved (target) path
                does not exist and cannot be created.
        """
        if self.resolving:
            return self.resolve_expression_path(node)
        else:
            # Still resolve but return original node to make sure path is correct. Just make sure nodes are creatable.
            result_node = self.resolve_expression_path(node)
            if (
                isinstance(result_node, NoneExistingPath)
                and not result_node.is_creatable()
            ):
                raise InvalidUpdateExpressionInvalidDocumentPath()

            return node

    def resolve_expression_path(self, node):
        return ExpressionPathResolver(
            self.expression_attribute_names
        ).resolve_expression_path(self.item, node)


class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
    """
    At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
    expression as per the official AWS docs:
        https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/
        Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
    """

    def _processing_map(self):
        return {UpdateExpressionFunction: self.process_function}

    def process_function(self, node):
        """Evaluate an `if_not_exists` or `list_append` function node.

        `if_not_exists` yields the second argument when the first is a
        NoneExistingPath, else the first. `list_append` yields a new list made of
        a copy of the first list followed by the elements of the second.

        Raises:
            IncorrectOperandType: A `list_append` argument is not a list.
            NotImplementedError: Any other function name.
        """
        assert isinstance(node, UpdateExpressionFunction)
        function_name = node.get_function_name()
        first_arg = node.get_nth_argument(1)
        second_arg = node.get_nth_argument(2)

        if function_name == "if_not_exists":
            if isinstance(first_arg, NoneExistingPath):
                result = second_arg
            else:
                result = first_arg
            assert isinstance(result, (DDBTypedValue, NoneExistingPath))
            return result
        elif function_name == "list_append":
            # Copy so that appending does not mutate the value the first argument resolved to.
            first_arg = deepcopy(
                self.get_list_from_ddb_typed_value(first_arg, function_name)
            )
            second_arg = self.get_list_from_ddb_typed_value(second_arg, function_name)
            for list_element in second_arg.value:
                first_arg.value.append(list_element)
            return DDBTypedValue(first_arg)
        else:
            raise NotImplementedError(
                "Unsupported function for moto {name}".format(name=function_name)
            )

    @classmethod
    def get_list_from_ddb_typed_value(cls, node, function_name):
        """Return the DynamoType inside `node`, raising IncorrectOperandType if it is not a list."""
        assert isinstance(node, DDBTypedValue)
        dynamo_value = node.get_value()
        assert isinstance(dynamo_value, DynamoType)
        if not dynamo_value.is_list():
            raise IncorrectOperandType(function_name, dynamo_value.type)
        return dynamo_value


class NoneExistingPathChecker(DepthFirstTraverser):
    """
    Pass through the AST and make sure there are no nonexistent paths left.
    """

    def _processing_map(self):
        return {NoneExistingPath: self.raise_none_existing_path}

    def raise_none_existing_path(self, node):
        raise AttributeDoesNotExist


class ExecuteOperations(DepthFirstTraverser):
    """Evaluate `+` / `-` operations between resolved values."""

    def _processing_map(self):
        return {UpdateExpressionValue: self.process_update_expression_value}

    def process_update_expression_value(self, node):
        """
        If an UpdateExpressionValue only has a single child the node will be replaced with the child.
        Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them
        Args:
            node(Node):

        Returns:
            Node: The resulting node of the operation if present or the child.
        """
        assert isinstance(node, UpdateExpressionValue)
        if len(node.children) == 1:
            return node.children[0]
        elif len(node.children) == 3:
            operator_node = node.children[1]
            assert isinstance(operator_node, ExpressionValueOperator)
            operator = operator_node.get_operator()
            left_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[0])
            right_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[2])
            if operator == "+":
                return self.get_sum(left_operand, right_operand)
            elif operator == "-":
                return self.get_subtraction(left_operand, right_operand)
            else:
                raise NotImplementedError(
                    "Moto does not support operator {operator}".format(
                        operator=operator
                    )
                )
        else:
            raise NotImplementedError(
                "UpdateExpressionValue only has implementations for 1 or 3 children."
            )

    @classmethod
    def get_dynamo_value_from_ddb_typed_value(cls, node):
        assert isinstance(node, DDBTypedValue)
        dynamo_value = node.get_value()
        assert isinstance(dynamo_value, DynamoType)
        return dynamo_value

    @classmethod
    def get_sum(cls, left_operand, right_operand):
        """
        Args:
            left_operand(DynamoType):
            right_operand(DynamoType):

        Returns:
            DDBTypedValue:

        Raises:
            IncorrectOperandType: The operands cannot be added.
        """
        try:
            return DDBTypedValue(left_operand + right_operand)
        except TypeError:
            raise IncorrectOperandType("+", left_operand.type)

    @classmethod
    def get_subtraction(cls, left_operand, right_operand):
        """
        Args:
            left_operand(DynamoType):
            right_operand(DynamoType):

        Returns:
            DDBTypedValue:

        Raises:
            IncorrectOperandType: The operands cannot be subtracted.
        """
        try:
            return DDBTypedValue(left_operand - right_operand)
        except TypeError:
            raise IncorrectOperandType("-", left_operand.type)


class EmptyStringKeyValueValidator(DepthFirstTraverser):
    """Reject SET actions that assign an empty string/binary to a key attribute."""

    def __init__(self, key_attributes, expression_attribute_names=None):
        """
        Args:
            key_attributes(list): Names of the key attributes.
            expression_attribute_names(dict): Optional placeholder -> name mapping used
                to translate name placeholders before comparing against key_attributes.
        """
        self.key_attributes = key_attributes
        self.expression_attribute_names = expression_attribute_names or {}

    def _processing_map(self):
        return {UpdateExpressionSetAction: self.check_for_empty_string_key_value}

    def check_for_empty_string_key_value(self, node):
        """A node representing a SET action. Check that keys are not being assigned empty strings.

        Raises:
            EmptyKeyAttributeException: An empty S/B value is assigned to a key attribute.
        """
        assert isinstance(node, UpdateExpressionSetAction)
        assert len(node.children) == 2
        key = node.children[0].children[0].children[0]
        # A name placeholder must be translated, otherwise `SET #k = :empty` bypasses the check.
        key = self.expression_attribute_names.get(key, key)
        val_node = node.children[1].children[0]
        if (
            not val_node.value
            and val_node.type in ["S", "B"]
            and key in self.key_attributes
        ):
            raise EmptyKeyAttributeException(key_in_index=True)
        return node


class UpdateHashRangeKeyValidator(DepthFirstTraverser):
    """Reject update expressions whose paths start at a hash or range key of the table."""

    def __init__(self, table_key_attributes, expression_attribute_names=None):
        """
        Args:
            table_key_attributes(list): Names of the table's hash/range key attributes.
            expression_attribute_names(dict): Optional placeholder -> name mapping used
                to translate name placeholders before comparing against the keys.
        """
        self.table_key_attributes = table_key_attributes
        self.expression_attribute_names = expression_attribute_names or {}

    def _processing_map(self):
        return {UpdateExpressionPath: self.check_for_hash_or_range_key}

    def check_for_hash_or_range_key(self, node):
        """Check that hash and range keys are not updated.

        Raises:
            UpdateHashRangeKeyException: The path's first element is a table key.
        """
        key_to_update = node.children[0].children[0]
        # A name placeholder must be translated, otherwise `SET #k = :v` bypasses the check.
        key_to_update = self.expression_attribute_names.get(
            key_to_update, key_to_update
        )
        if key_to_update in self.table_key_attributes:
            raise UpdateHashRangeKeyException(key_to_update)
        return node


class Validator(object):
    """
    A validator is used to validate expressions which are passed in as an AST.

    The AST is deep-copied and then passed through each processor returned by
    get_ast_processors, in order; each processor's output is the next one's input.
    """

    def __init__(
        self,
        expression,
        expression_attribute_names,
        expression_attribute_values,
        item,
        table,
    ):
        """
        Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
        validation.

        Args:
            expression(Node): The root node of the AST representing the expression to be validated
            expression_attribute_names(ExpressionAttributeNames):
            expression_attribute_values(ExpressionAttributeValues):
            item(Item): The item which will be updated (pointed to by Key of update_item)
            table(Table): The table containing the item; subclasses read its key attributes.
        """
        self.expression_attribute_names = expression_attribute_names
        self.expression_attribute_values = expression_attribute_values
        self.item = item
        self.table = table
        self.processors = self.get_ast_processors()
        # Work on a copy so that the caller's AST is not mutated by the processors.
        self.node_to_validate = deepcopy(expression)

    @abstractmethod
    def get_ast_processors(self):
        """Get the different processors that go through the AST tree and processes the nodes."""

    def validate(self):
        """Run all processors over the AST and return the resulting node."""
        n = self.node_to_validate
        for processor in self.processors:
            n = processor.traverse(n)
        return n


class UpdateExpressionValidator(Validator):
    def get_ast_processors(self):
        """Get the different processors that go through the AST tree and processes the nodes."""
        processors = [
            UpdateHashRangeKeyValidator(
                self.table.table_key_attrs, self.expression_attribute_names
            ),
            ExpressionAttributeValueProcessor(self.expression_attribute_values),
            ExpressionAttributeResolvingProcessor(
                self.expression_attribute_names, self.item
            ),
            UpdateExpressionFunctionEvaluator(),
            NoneExistingPathChecker(),
            ExecuteOperations(),
            EmptyStringKeyValueValidator(
                self.table.key_attributes, self.expression_attribute_names
            ),
        ]
        return processors