1"""
See the docstring of the Validator class below for more details on validation.
3"""
4from abc import abstractmethod
5from copy import deepcopy
6
7from moto.dynamodb2.exceptions import (
8    AttributeIsReservedKeyword,
9    ExpressionAttributeValueNotDefined,
10    AttributeDoesNotExist,
11    ExpressionAttributeNameNotDefined,
12    IncorrectOperandType,
13    InvalidUpdateExpressionInvalidDocumentPath,
14    ProvidedKeyDoesNotExist,
15    EmptyKeyAttributeException,
16    UpdateHashRangeKeyException,
17)
18from moto.dynamodb2.models import DynamoType
19from moto.dynamodb2.parsing.ast_nodes import (
20    ExpressionAttribute,
21    UpdateExpressionPath,
22    UpdateExpressionSetAction,
23    UpdateExpressionAddAction,
24    UpdateExpressionDeleteAction,
25    UpdateExpressionRemoveAction,
26    DDBTypedValue,
27    ExpressionAttributeValue,
28    ExpressionAttributeName,
29    DepthFirstTraverser,
30    NoneExistingPath,
31    UpdateExpressionFunction,
32    ExpressionPathDescender,
33    UpdateExpressionValue,
34    ExpressionValueOperator,
35    ExpressionSelector,
36)
37from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords
38
39
class ExpressionAttributeValueProcessor(DepthFirstTraverser):
    """Swap ExpressionAttributeValue nodes for the typed values they refer to.

    Each ExpressionAttributeValue node is looked up by name in the supplied
    expression attribute values and replaced by a DDBTypedValue node.
    """

    def __init__(self, expression_attribute_values):
        # Lookup table indexed by the name returned from node.get_value_name().
        self.expression_attribute_values = expression_attribute_values

    def _processing_map(self):
        """Map the node type handled by this traverser to its handler."""
        handler = self.replace_expression_attribute_value_with_value
        return {ExpressionAttributeValue: handler}

    def replace_expression_attribute_value_with_value(self, node):
        """Resolve an ExpressionAttributeValue node into a DDBTypedValue node.

        Args:
            node(ExpressionAttributeValue): the placeholder node to resolve.

        Returns:
            DDBTypedValue: wraps DynamoType(<value stored under the name>).

        Raises:
            ExpressionAttributeValueNotDefined: the name of the node is not a
                key of the expression attribute values.
        """
        assert isinstance(node, ExpressionAttributeValue)
        placeholder = node.get_value_name()
        try:
            raw_value = self.expression_attribute_values[placeholder]
        except KeyError:
            raise ExpressionAttributeValueNotDefined(attribute_value=placeholder)
        else:
            return DDBTypedValue(DynamoType(raw_value))
60
61
class ExpressionPathResolver(object):
    """Resolve the nodes of an update expression path against an item."""

    def __init__(self, expression_attribute_names):
        # Lookup table used to translate attribute name placeholders.
        self.expression_attribute_names = expression_attribute_names

    @classmethod
    def raise_exception_if_keyword(cls, attribute):
        """Raise if the upper-cased attribute name is a reserved keyword.

        Raises:
            AttributeIsReservedKeyword: `attribute.upper()` is contained in
                ReservedKeywords.get_reserved_keywords().
        """
        if attribute.upper() in ReservedKeywords.get_reserved_keywords():
            raise AttributeIsReservedKeyword(attribute)

    def resolve_expression_path(self, item, update_expression_path):
        """Resolve the children of an UpdateExpressionPath node against `item`.

        See resolve_expression_path_nodes for the return value and exceptions.
        """
        assert isinstance(update_expression_path, UpdateExpressionPath)
        return self.resolve_expression_path_nodes(item, update_expression_path.children)

    @staticmethod
    def _none_existing_path_for(child, path_nodes):
        """Build the NoneExistingPath for a lookup that failed at `child`.

        Only a failure at the last node of the path is marked as creatable.
        """
        if child == path_nodes[-1]:
            return NoneExistingPath(creatable=True)
        return NoneExistingPath()

    def resolve_expression_path_nodes(self, item, update_expression_path_nodes):
        """Walk `item.attrs` following the given path nodes.

        Args:
            item: object whose `attrs` attribute is the starting point.
            update_expression_path_nodes(list): the path nodes, in order.

        Returns:
            DDBTypedValue wrapping the value found at the end of the path, or
            a NoneExistingPath when a lookup fails (creatable only when the
            failing node is the last node of the path).

        Raises:
            ExpressionAttributeNameNotDefined: an attribute name placeholder
                is not present in the expression attribute names.
            AttributeIsReservedKeyword: a literal attribute name is reserved.
            InvalidUpdateExpressionInvalidDocumentPath: a selector is applied
                to a target for which is_list() is false.
            NotImplementedError: a node of an unsupported type is encountered.
        """
        target = item.attrs

        for child in update_expression_path_nodes:
            # First replace placeholder with attribute_name
            attr_name = None
            if isinstance(child, ExpressionAttributeName):
                attr_placeholder = child.get_attribute_name_placeholder()
                try:
                    attr_name = self.expression_attribute_names[attr_placeholder]
                except KeyError:
                    raise ExpressionAttributeNameNotDefined(attr_placeholder)
            elif isinstance(child, ExpressionAttribute):
                attr_name = child.get_attribute_name()
                self.raise_exception_if_keyword(attr_name)

            if attr_name is not None:
                # Resolve attribute_name
                try:
                    target = target[attr_name]
                except (KeyError, TypeError):
                    return self._none_existing_path_for(
                        child, update_expression_path_nodes
                    )
            elif isinstance(child, ExpressionPathDescender):
                # A descender does not change the current target.
                continue
            elif isinstance(child, ExpressionSelector):
                index = child.get_index()
                if not target.is_list():
                    raise InvalidUpdateExpressionInvalidDocumentPath
                try:
                    target = target[index]
                except IndexError:
                    # When a list goes out of bounds when assigning that is no
                    # problem when at the assignment side. It will just append
                    # to the list.
                    return self._none_existing_path_for(
                        child, update_expression_path_nodes
                    )
            else:
                raise NotImplementedError(
                    "Path resolution for {t}".format(t=type(child))
                )
        # A leftover debugging print() of non-DynamoType targets used to live
        # here; library code must not write to stdout, so it has been removed.
        return DDBTypedValue(target)

    def resolve_expression_path_nodes_to_dynamo_type(
        self, item, update_expression_path_nodes
    ):
        """Resolve the path nodes and return the value held by the result.

        Raises:
            ProvidedKeyDoesNotExist: the path resolves to a NoneExistingPath.
        """
        node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
        if isinstance(node, NoneExistingPath):
            raise ProvidedKeyDoesNotExist()
        assert isinstance(node, DDBTypedValue)
        return node.get_value()
130
131
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
    """Resolve UpdateExpressionPath nodes against the item being updated.

    Whether a path node is replaced by its resolved value depends on the
    `resolving` flag, which is maintained while children are processed.
    """

    def __init__(self, expression_attribute_names, item):
        self.expression_attribute_names = expression_attribute_names
        self.item = item
        # Paths are only replaced by their resolved value while this is True.
        self.resolving = False

    def _processing_map(self):
        """Map the handled node types to their handlers."""
        return {
            UpdateExpressionSetAction: self.disable_resolving,
            UpdateExpressionPath: self.process_expression_path_node,
        }

    def pre_processing_of_child(self, parent_node, child_id):
        """
        Switch resolving on or off before a child of an action node is handled.

        For SET, REMOVE, DELETE and ADD action nodes the first child
        (child_id == 0) is the path being acted upon, so resolving is turned
        off for it; for every later child resolving is turned on. Children of
        any other parent leave the flag untouched.
        """
        action_types = (
            UpdateExpressionSetAction,
            UpdateExpressionRemoveAction,
            UpdateExpressionDeleteAction,
            UpdateExpressionAddAction,
        )
        if not isinstance(parent_node, action_types):
            return
        self.resolving = not (child_id == 0)

    def disable_resolving(self, node=None):
        """Turn resolving off and hand the node back unchanged."""
        self.resolving = False
        return node

    def process_expression_path_node(self, node):
        """Resolve a path node; replace it only when resolving is enabled.

        When resolving is disabled the path is still resolved, purely to
        verify it: a path that neither exists nor is creatable raises
        InvalidUpdateExpressionInvalidDocumentPath. The original node is
        returned in that mode.
        """
        resolved = self.resolve_expression_path(node)
        if self.resolving:
            return resolved

        not_creatable = (
            isinstance(resolved, NoneExistingPath) and not resolved.is_creatable()
        )
        if not_creatable:
            raise InvalidUpdateExpressionInvalidDocumentPath()
        return node

    def resolve_expression_path(self, node):
        """Resolve `node` against self.item with a fresh ExpressionPathResolver."""
        resolver = ExpressionPathResolver(self.expression_attribute_names)
        return resolver.resolve_expression_path(self.item, node)
186
187
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
    """
    Evaluate the functions that may occur in a DDB UpdateExpression.

    At time of writing there are only 2 such functions (if_not_exists and
    list_append). They both are specific to the SET expression as per the
    official AWS docs:
        https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/
        Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
    """

    def _processing_map(self):
        """Map the handled node type to its handler."""
        return {UpdateExpressionFunction: self.process_function}

    def process_function(self, node):
        """Evaluate an UpdateExpressionFunction node and return its result node.

        if_not_exists yields the second argument when the first argument is a
        NoneExistingPath, otherwise the first argument. list_append yields a
        DDBTypedValue holding a deep copy of the first list with the elements
        of the second list appended.

        Raises:
            NotImplementedError: for any other function name.
        """
        assert isinstance(node, UpdateExpressionFunction)
        function_name = node.get_function_name()
        first_arg = node.get_nth_argument(1)
        second_arg = node.get_nth_argument(2)

        if function_name == "if_not_exists":
            chosen = second_arg if isinstance(first_arg, NoneExistingPath) else first_arg
            assert isinstance(chosen, (DDBTypedValue, NoneExistingPath))
            return chosen

        if function_name == "list_append":
            # Copy first so the list referenced by the first argument is not
            # mutated by the append below.
            combined = deepcopy(
                self.get_list_from_ddb_typed_value(first_arg, function_name)
            )
            tail = self.get_list_from_ddb_typed_value(second_arg, function_name)
            for element in tail.value:
                combined.value.append(element)
            return DDBTypedValue(combined)

        raise NotImplementedError(
            "Unsupported function for moto {name}".format(name=function_name)
        )

    @classmethod
    def get_list_from_ddb_typed_value(cls, node, function_name):
        """Return the DynamoType held by `node`, insisting that it is a list.

        Raises:
            IncorrectOperandType: the held value reports is_list() as false.
        """
        assert isinstance(node, DDBTypedValue)
        held_value = node.get_value()
        assert isinstance(held_value, DynamoType)
        if held_value.is_list():
            return held_value
        raise IncorrectOperandType(function_name, held_value.type)
233
234
class NoneExistingPathChecker(DepthFirstTraverser):
    """
    Walk the AST and fail as soon as a NoneExistingPath node is handled.
    """

    def _processing_map(self):
        """Map the handled node type to its handler."""
        handlers = {}
        handlers[NoneExistingPath] = self.raise_none_existing_path
        return handlers

    def raise_none_existing_path(self, node):
        """Always raise AttributeDoesNotExist; the node itself is not inspected."""
        raise AttributeDoesNotExist
245
246
class ExecuteOperations(DepthFirstTraverser):
    """Collapse UpdateExpressionValue nodes by executing their operator."""

    def _processing_map(self):
        """Map the handled node type to its handler."""
        return {UpdateExpressionValue: self.process_update_expression_value}

    def process_update_expression_value(self, node):
        """
        Reduce an UpdateExpressionValue node to a single node.

        A node with one child is replaced by that child. A node with three
        children has an ExpressionValueOperator in the middle which tells how
        the outer two operands are combined ("+" or "-").

        Args:
            node(Node):

        Returns:
            Node: The resulting node of the operation if present or the child.
        """
        assert isinstance(node, UpdateExpressionValue)
        child_count = len(node.children)
        if child_count == 1:
            return node.children[0]
        if child_count != 3:
            raise NotImplementedError(
                "UpdateExpressionValue only has implementations for 1 or 3 children."
            )

        operator_node = node.children[1]
        assert isinstance(operator_node, ExpressionValueOperator)
        operator = operator_node.get_operator()
        lhs = self.get_dynamo_value_from_ddb_typed_value(node.children[0])
        rhs = self.get_dynamo_value_from_ddb_typed_value(node.children[2])

        if operator == "+":
            return self.get_sum(lhs, rhs)
        if operator == "-":
            return self.get_subtraction(lhs, rhs)
        raise NotImplementedError(
            "Moto does not support operator {operator}".format(operator=operator)
        )

    @classmethod
    def get_dynamo_value_from_ddb_typed_value(cls, node):
        """Return the DynamoType held by a DDBTypedValue node."""
        assert isinstance(node, DDBTypedValue)
        held_value = node.get_value()
        assert isinstance(held_value, DynamoType)
        return held_value

    @classmethod
    def get_sum(cls, left_operand, right_operand):
        """
        Add both operands and wrap the outcome in a DDBTypedValue.

        A TypeError raised while doing so is reported as IncorrectOperandType
        for "+" with the type of the left operand.

        Args:
            left_operand(DynamoType):
            right_operand(DynamoType):

        Returns:
            DDBTypedValue:
        """
        try:
            total = DDBTypedValue(left_operand + right_operand)
        except TypeError:
            raise IncorrectOperandType("+", left_operand.type)
        return total

    @classmethod
    def get_subtraction(cls, left_operand, right_operand):
        """
        Subtract the right operand from the left one and wrap the outcome.

        A TypeError raised while doing so is reported as IncorrectOperandType
        for "-" with the type of the left operand.

        Args:
            left_operand(DynamoType):
            right_operand(DynamoType):

        Returns:
            DDBTypedValue:
        """
        try:
            difference = DDBTypedValue(left_operand - right_operand)
        except TypeError:
            raise IncorrectOperandType("-", left_operand.type)
        return difference
321
322
class EmptyStringKeyValueValidator(DepthFirstTraverser):
    """Reject SET actions that assign an empty value to a key attribute."""

    def __init__(self, key_attributes):
        # Container of attribute names that must not receive an empty value.
        self.key_attributes = key_attributes

    def _processing_map(self):
        """Map the handled node type to its handler."""
        return {UpdateExpressionSetAction: self.check_for_empty_string_key_value}

    def check_for_empty_string_key_value(self, node):
        """Check a SET action node; keys must not be assigned empty values.

        Raises:
            EmptyKeyAttributeException: the assigned value is falsy, has type
                "S" or "B", and the assigned attribute is one of the key
                attributes.

        Returns:
            The node itself, unchanged.
        """
        assert isinstance(node, UpdateExpressionSetAction)
        assert len(node.children) == 2
        # First child is the path being set; its first grandchild is compared
        # against the key attributes below.
        target_path = node.children[0]
        key = target_path.children[0].children[0]
        # Second child carries the value being assigned.
        assigned = node.children[1]
        val_node = assigned.children[0]

        is_empty = not val_node.value
        if is_empty and val_node.type in ("S", "B") and key in self.key_attributes:
            raise EmptyKeyAttributeException(key_in_index=True)
        return node
343
344
class UpdateHashRangeKeyValidator(DepthFirstTraverser):
    """Reject update expression paths that start at a table key attribute."""

    def __init__(self, table_key_attributes):
        # Container of attribute names that may not be updated.
        self.table_key_attributes = table_key_attributes

    def _processing_map(self):
        """Map the handled node type to its handler."""
        return {UpdateExpressionPath: self.check_for_hash_or_range_key}

    def check_for_hash_or_range_key(self, node):
        """Check that hash and range keys are not updated.

        Raises:
            UpdateHashRangeKeyException: the first grandchild of the path node
                is one of the table key attributes.

        Returns:
            The node itself, unchanged.
        """
        first_path_element = node.children[0]
        key_to_update = first_path_element.children[0]
        if key_to_update not in self.table_key_attributes:
            return node
        raise UpdateHashRangeKeyException(key_to_update)
358
359
class Validator(object):
    """
    Base class for validating an expression that is handed over as an AST.
    """

    def __init__(
        self,
        expression,
        expression_attribute_names,
        expression_attribute_values,
        item,
        table,
    ):
        """
        Store the validation context and prepare the AST processors.

        Besides validation the Validator should also replace referenced parts
        of an item which is cheapest upon validation.

        Args:
            expression(Node): The root node of the AST representing the expression to be validated
            expression_attribute_names(ExpressionAttributeNames):
            expression_attribute_values(ExpressionAttributeValues):
            item(Item): The item which will be updated (pointed to by Key of update_item)
            table: stored as-is on the instance for use by subclasses
        """
        self.expression_attribute_names = expression_attribute_names
        self.expression_attribute_values = expression_attribute_values
        self.item = item
        self.table = table
        # Built after the context above is stored so that subclasses can
        # read it from within get_ast_processors().
        self.processors = self.get_ast_processors()
        # Work on a private copy; the expression passed in is left untouched.
        self.node_to_validate = deepcopy(expression)

    @abstractmethod
    def get_ast_processors(self):
        """Get the different processors that go through the AST tree and processes the nodes."""

    def validate(self):
        """Feed the AST through every processor in order.

        The node returned by one processor's traverse() is the input of the
        next one.

        Returns:
            The node returned by the last processor, or the copied expression
            itself when there are no processors.
        """
        current_node = self.node_to_validate
        for ast_processor in self.processors:
            current_node = ast_processor.traverse(current_node)
        return current_node
399
400
class UpdateExpressionValidator(Validator):
    """Validator equipped with the processors for update expressions."""

    def get_ast_processors(self):
        """Build the processors that walk the AST, in the order they are applied."""
        hash_range_key_validator = UpdateHashRangeKeyValidator(
            self.table.table_key_attrs
        )
        value_processor = ExpressionAttributeValueProcessor(
            self.expression_attribute_values
        )
        resolving_processor = ExpressionAttributeResolvingProcessor(
            self.expression_attribute_names, self.item
        )
        function_evaluator = UpdateExpressionFunctionEvaluator()
        none_existing_path_checker = NoneExistingPathChecker()
        operation_executor = ExecuteOperations()
        empty_key_validator = EmptyStringKeyValueValidator(self.table.key_attributes)
        return [
            hash_range_key_validator,
            value_processor,
            resolving_processor,
            function_evaluator,
            none_existing_path_checker,
            operation_executor,
            empty_key_validator,
        ]
416