1"""Utilities related to determining the reachability of code (in semantic analysis)."""
2
3from typing import Tuple, TypeVar, Union, Optional
4from typing_extensions import Final
5
6from mypy.nodes import (
7    Expression, IfStmt, Block, AssertStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, ComparisonExpr,
8    StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, Import, ImportFrom,
9    ImportAll, LITERAL_YES
10)
11from mypy.options import Options
12from mypy.traverser import TraverserVisitor
13from mypy.literals import literal
14
# Possible outcomes when statically inferring the truth value of a condition.
ALWAYS_TRUE = 1  # type: Final
MYPY_TRUE = 2  # type: Final  # holds while type checking, but not at runtime
ALWAYS_FALSE = 3  # type: Final
MYPY_FALSE = 4  # type: Final  # fails while type checking, but holds at runtime
TRUTH_VALUE_UNKNOWN = 5  # type: Final

# Truth value of `not <cond>`, keyed by the truth value of `<cond>`.
inverted_truth_mapping = {
    ALWAYS_TRUE: ALWAYS_FALSE, ALWAYS_FALSE: ALWAYS_TRUE,
    TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
    MYPY_TRUE: MYPY_FALSE, MYPY_FALSE: MYPY_TRUE,
}  # type: Final

# Operator that keeps a comparison equivalent when its two operands are
# swapped, e.g. `a < b` is the same as `b > a`.
reverse_op = {
    "==": "==",
    "!=": "!=",
    "<": ">",
    ">": "<",
    "<=": ">=",
    ">=": "<=",
}  # type: Final
37
38
def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
    """Mark the statically dead branches of an if/elif/else chain.

    Each condition is evaluated in order with infer_condition_value():

    - a (mypy-)false condition marks its own body unreachable, and the scan
      continues with the next condition;
    - the first (mypy-)true condition ends the scan. Every later elif body and
      the else body (created as an empty Block if absent) are marked
      unreachable. If the condition is only true under mypy, the imports in
      its body are flagged as mypy-only.
    """
    for position, condition in enumerate(s.expr):
        verdict = infer_condition_value(condition, options)
        if verdict in (ALWAYS_FALSE, MYPY_FALSE):
            mark_block_unreachable(s.body[position])
            continue
        if verdict not in (ALWAYS_TRUE, MYPY_TRUE):
            # Unknown at analysis time: leave this branch alone.
            continue

        if verdict == MYPY_TRUE:
            # False at runtime, which affects import priorities.
            mark_block_mypy_only(s.body[position])
        for later_body in s.body[position + 1:]:
            mark_block_unreachable(later_body)

        # Guarantee an (unreachable) else body so the type checker knows
        # control always flows through the taken branch.
        if not s.else_body:
            s.else_body = Block([])
        mark_block_unreachable(s.else_body)
        break
63
64
def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
    """Return True if the assert's condition is ALWAYS_FALSE or MYPY_FALSE."""
    verdict = infer_condition_value(s.expr, options)
    return verdict == ALWAYS_FALSE or verdict == MYPY_FALSE
67
68
def infer_condition_value(expr: Expression, options: Options) -> int:
    """Infer whether the given condition is always true/false.

    Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
    MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
    false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.

    Recognized conditions:
    - a name or attribute named PY2, PY3, MYPY or TYPE_CHECKING, or any name
      listed in options.always_true / options.always_false;
    - comparisons on sys.version_info and sys.platform, and
      sys.platform.startswith(...);
    - `not`, `and` and `or` combinations of the above.
    """
    if isinstance(expr, UnaryExpr) and expr.op == 'not':
        # Negate whatever the operand evaluates to. Recursing (instead of
        # unwrapping a single `not` and then returning early for `and`/`or`)
        # makes `not (A or B)` and `not not A` come out correctly.
        return inverted_truth_mapping[infer_condition_value(expr.expr, options)]

    if isinstance(expr, OpExpr) and expr.op in ('and', 'or'):
        left = infer_condition_value(expr.left, options)
        if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or
                (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')):
            # Either `True and <other>` or `False or <other>`: the result will
            # always be the right-hand-side.
            return infer_condition_value(expr.right, options)
        # The result will always be the left-hand-side (e.g. ALWAYS_* or
        # TRUTH_VALUE_UNKNOWN).
        return left

    if isinstance(expr, (NameExpr, MemberExpr)):
        # Only the final name component matters, e.g. `typing.TYPE_CHECKING`.
        name = expr.name
        pyversion = options.python_version
        if name == 'PY2':
            return ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
        if name == 'PY3':
            return ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
        if name == 'MYPY' or name == 'TYPE_CHECKING':
            return MYPY_TRUE
        if name in options.always_true:
            return ALWAYS_TRUE
        if name in options.always_false:
            return ALWAYS_FALSE
        return TRUTH_VALUE_UNKNOWN

    result = consider_sys_version_info(expr, options.python_version)
    if result == TRUTH_VALUE_UNKNOWN:
        result = consider_sys_platform(expr, options.platform)
    return result
118
119
def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int:
    """Statically evaluate a comparison against sys.version_info.

    Handled shapes, with the sys.version_info operand on either side:

    - sys.version_info[<int>] <op> <int>, for index 0 or 1;
    - sys.version_info[<lo>:<hi>] <op> <tuple of ints>, for 0 <= lo < hi <= 2
      (omitted bounds default to 0 and 2);
    - sys.version_info <op> <tuple of ints>, treated like [0:2].

    For the tuple shapes, the comparison is evaluated only when the tuple has
    the same length as the selected slice of pyversion. The exception is the
    ordering operators (<, <=, >, >=), which are also evaluated when the
    tuple is shorter.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    # Only a single (non-chained) equality or ordering comparison is handled.
    if not isinstance(expr, ComparisonExpr) or len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    op = expr.operators[0]
    if op not in ('==', '!=', '<=', '>=', '<', '>'):
        return TRUTH_VALUE_UNKNOWN

    first, second = expr.operands[0], expr.operands[1]
    version_part = contains_sys_version_info(first)
    constant = contains_int_or_tuple_of_ints(second)
    if version_part is None or constant is None:
        # Retry with the operands swapped, mirroring the operator to match.
        version_part = contains_sys_version_info(second)
        constant = contains_int_or_tuple_of_ints(first)
        op = reverse_op[op]

    if isinstance(version_part, int) and isinstance(constant, int):
        # Only the major and minor components of pyversion are consulted.
        if not 0 <= version_part <= 1:
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(pyversion[version_part], op, constant)

    if isinstance(version_part, tuple) and isinstance(constant, tuple):
        lo, hi = version_part
        start = 0 if lo is None else lo
        stop = 2 if hi is None else hi
        if 0 <= start < stop <= 2:
            selected = pyversion[start:stop]
            if len(selected) == len(constant):
                return fixed_comparison(selected, op, constant)
            if len(selected) > len(constant) and op not in ('==', '!='):
                return fixed_comparison(selected, op, constant)
    return TRUTH_VALUE_UNKNOWN
162
163
def consider_sys_platform(expr: Expression, platform: str) -> int:
    """Consider whether expr is a comparison involving sys.platform.

    Supported forms:
    - sys.platform == 'linux' and sys.platform != 'win32', with the string
      literal on either side of the operator (e.g. 'win32' == sys.platform);
    - sys.platform.startswith('win').

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    if isinstance(expr, ComparisonExpr):
        # Let's not yet support chained comparisons.
        if len(expr.operators) > 1:
            return TRUTH_VALUE_UNKNOWN
        op = expr.operators[0]
        if op not in ('==', '!='):
            return TRUTH_VALUE_UNKNOWN
        platform_side, other_side = expr.operands[0], expr.operands[1]
        if not is_sys_attr(platform_side, 'platform'):
            # == and != are symmetric, so a reversed comparison can simply be
            # swapped back (consider_sys_version_info likewise accepts
            # reversed operands).
            platform_side, other_side = other_side, platform_side
        if not is_sys_attr(platform_side, 'platform'):
            return TRUTH_VALUE_UNKNOWN
        if not isinstance(other_side, (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(platform, op, other_side.value)
    elif isinstance(expr, CallExpr):
        if not isinstance(expr.callee, MemberExpr):
            return TRUTH_VALUE_UNKNOWN
        if len(expr.args) != 1 or not isinstance(expr.args[0], (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        if not is_sys_attr(expr.callee.expr, 'platform'):
            return TRUTH_VALUE_UNKNOWN
        if expr.callee.name != 'startswith':
            return TRUTH_VALUE_UNKNOWN
        if platform.startswith(expr.args[0].value):
            return ALWAYS_TRUE
        else:
            return ALWAYS_FALSE
    else:
        return TRUTH_VALUE_UNKNOWN
201
202
# Operand types accepted by fixed_comparison(): a single version component
# (int), a platform string (str), or a slice of the version tuple. Both
# operands of a call must resolve to the same constraint.
Targ = TypeVar('Targ', int, str, Tuple[int, ...])
204
205
def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
    """Evaluate `left <op> right` for two values known at analysis time.

    Return ALWAYS_TRUE or ALWAYS_FALSE according to the comparison result,
    or TRUTH_VALUE_UNKNOWN if op is not one of ==, !=, <=, >=, <, >.
    """
    if op == '==':
        holds = left == right
    elif op == '!=':
        holds = left != right
    elif op == '<=':
        holds = left <= right
    elif op == '>=':
        holds = left >= right
    elif op == '<':
        holds = left < right
    elif op == '>':
        holds = left > right
    else:
        return TRUTH_VALUE_UNKNOWN
    return ALWAYS_TRUE if holds else ALWAYS_FALSE
221
222
def contains_int_or_tuple_of_ints(expr: Expression
                                  ) -> Union[None, int, Tuple[int], Tuple[int, ...]]:
    """Extract a constant int, or a constant tuple of ints, from expr.

    An IntExpr yields its value. A TupleExpr that literal() classifies as
    LITERAL_YES and whose items are all IntExpr yields a tuple of their
    values. Anything else yields None.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if not isinstance(expr, TupleExpr) or literal(expr) != LITERAL_YES:
        return None
    int_items = [item for item in expr.items if isinstance(item, IntExpr)]
    if len(int_items) != len(expr.items):
        # At least one element is not an integer literal.
        return None
    return tuple(item.value for item in int_items)
236
237
def contains_sys_version_info(expr: Expression
                              ) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
    """Describe how expr reads sys.version_info, if it does.

    Return:
    - (None, None) for a bare sys.version_info (same as sys.version_info[:]);
    - an int for sys.version_info[<int literal>];
    - (begin, end) for sys.version_info[<begin>:<end>], where each bound is an
      int literal or omitted (None) and any step is the literal 1;
    - None for anything else.
    """
    if is_sys_attr(expr, 'version_info'):
        return (None, None)
    if not (isinstance(expr, IndexExpr) and is_sys_attr(expr.base, 'version_info')):
        return None

    subscript = expr.index
    if isinstance(subscript, IntExpr):
        return subscript.value
    if not isinstance(subscript, SliceExpr):
        return None

    step = subscript.stride
    if step is not None and not (isinstance(step, IntExpr) and step.value == 1):
        return None
    lower, upper = subscript.begin_index, subscript.end_index
    if lower is not None and not isinstance(lower, IntExpr):
        return None
    if upper is not None and not isinstance(upper, IntExpr):
        return None
    return (None if lower is None else lower.value,
            None if upper is None else upper.value)
261
262
def is_sys_attr(expr: Expression, name: str) -> bool:
    """Return True if expr is syntactically the attribute access `sys.<name>`."""
    # TODO: This currently doesn't work with code like this:
    # - import sys as _sys
    # - from sys import version_info
    # TODO: Guard against a local named sys, etc.
    # (Though later passes will still do most checking.)
    return (isinstance(expr, MemberExpr)
            and expr.name == name
            and isinstance(expr.expr, NameExpr)
            and expr.expr.name == 'sys')
273
274
def mark_block_unreachable(block: Block) -> None:
    """Flag block as unreachable, along with every import nested inside it."""
    import_marker = MarkImportsUnreachableVisitor()
    block.is_unreachable = True
    block.accept(import_marker)
278
279
class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Traverse a subtree and set is_unreachable on every Import, ImportFrom
    and ImportAll node found."""

    def visit_import(self, node: Import) -> None:
        self._mark_unreachable(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._mark_unreachable(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._mark_unreachable(node)

    @staticmethod
    def _mark_unreachable(node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared by all three import visitors.
        node.is_unreachable = True
291
292
def mark_block_mypy_only(block: Block) -> None:
    """Flag every import nested in block as mypy-only (this affects priority)."""
    import_marker = MarkImportsMypyOnlyVisitor()
    block.accept(import_marker)
295
296
class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Traverse a subtree and set is_mypy_only (which affects priority) on
    every Import, ImportFrom and ImportAll node found."""

    def visit_import(self, node: Import) -> None:
        self._mark_mypy_only(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._mark_mypy_only(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._mark_mypy_only(node)

    @staticmethod
    def _mark_mypy_only(node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared by all three import visitors.
        node.is_mypy_only = True
308