1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from typing import Callable, Optional
7
8import libcst as cst
9from libcst import parse_expression
10from libcst._nodes.tests.base import CSTNodeTest
11from libcst.metadata import CodeRange
12from libcst.testing.utils import data_provider
13
14
class NumberTest(CSTNodeTest):
    """Tests for the numeric literal nodes ``Integer``, ``Float`` and ``Imaginary``.

    Valid cases pair a node with its expected source text, the parser applied
    to that text, and optionally the expected ``CodeRange`` of the node.
    Invalid cases build nodes with unbalanced parentheses.
    """

    @data_provider(
        (
            # Simple numbers, one per literal kind
            (cst.Integer("5"), "5", parse_expression),
            (
                cst.Float("5.5"),
                "5.5",
                parse_expression,
                CodeRange((1, 0), (1, 3)),
            ),
            (
                cst.Imaginary("5j"),
                "5j",
                parse_expression,
                CodeRange((1, 0), (1, 2)),
            ),
            # Negated number
            (
                cst.UnaryOperation(operator=cst.Minus(), expression=cst.Integer("5")),
                "-5",
                parse_expression,
                CodeRange((1, 0), (1, 2)),
            ),
            # In parentheses; the expected range excludes the node's own
            # outer parentheses (it starts at column 1, after the "(").
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "(-5)",
                parse_expression,
                CodeRange((1, 1), (1, 3)),
            ),
            # Parentheses on both the operation and its operand; the operand's
            # parentheses are inside the operation's range.
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer(
                        "5", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(-(5))",
                parse_expression,
                CodeRange((1, 1), (1, 5)),
            ),
            # Doubly negated number
            (
                cst.UnaryOperation(
                    operator=cst.Minus(),
                    expression=cst.UnaryOperation(
                        operator=cst.Minus(), expression=cst.Integer("5")
                    ),
                ),
                "--5",
                parse_expression,
                CodeRange((1, 0), (1, 3)),
            ),
            # Multiple nested parentheses
            (
                cst.Integer(
                    "5",
                    lpar=(cst.LeftParen(), cst.LeftParen()),
                    rpar=(cst.RightParen(), cst.RightParen()),
                ),
                "((5))",
                parse_expression,
                CodeRange((1, 2), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Plus(),
                    expression=cst.Integer(
                        "5",
                        lpar=(cst.LeftParen(), cst.LeftParen()),
                        rpar=(cst.RightParen(), cst.RightParen()),
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(+((5)))",
                parse_expression,
                CodeRange((1, 1), (1, 7)),
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
        position: Optional[CodeRange] = None,
    ) -> None:
        """Validate ``node`` against its expected source ``code``.

        Args:
            node: The node under test.
            code: The source text ``node`` is expected to correspond to.
            parser: Parser passed through to ``validate_node`` for ``code``.
            position: Expected ``CodeRange`` of ``node``; ``None`` skips the
                position expectation.
        """
        self.validate_node(node, code, parser, expected_position=position)

    @data_provider(
        (
            # Every literal value below is itself valid (note "5j", not
            # "5i", for Imaginary), so the only possible validation failure
            # is the unbalanced parentheses being tested.
            (
                lambda: cst.Integer("5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Integer("5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Float("5.5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Float("5.5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Imaginary("5j", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Check that building a node with unbalanced parentheses is rejected.

        Args:
            get_node: Zero-argument factory that constructs the node. It is
                deferred so ``assert_invalid`` can observe the failure
                (presumably raised during construction — confirm in the base
                class).
            expected_re: Regex the resulting error message should match.
        """
        self.assert_invalid(get_node, expected_re)
132