"""Utilities related to determining the reachability of code (in semantic analysis)."""

from typing import Tuple, TypeVar, Union, Optional
from typing_extensions import Final

from mypy.nodes import (
    Expression, IfStmt, Block, AssertStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, ComparisonExpr,
    StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, Import, ImportFrom,
    ImportAll, LITERAL_YES
)
from mypy.options import Options
from mypy.traverser import TraverserVisitor
from mypy.literals import literal

# Inferred truth value of an expression.
ALWAYS_TRUE = 1  # type: Final
MYPY_TRUE = 2  # type: Final  # True in mypy, False at runtime
ALWAYS_FALSE = 3  # type: Final
MYPY_FALSE = 4  # type: Final  # False in mypy, True at runtime
TRUTH_VALUE_UNKNOWN = 5  # type: Final

# Truth value of `not <expr>` given the truth value of `<expr>`.
inverted_truth_mapping = {
    ALWAYS_TRUE: ALWAYS_FALSE,
    ALWAYS_FALSE: ALWAYS_TRUE,
    TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
    MYPY_TRUE: MYPY_FALSE,
    MYPY_FALSE: MYPY_TRUE,
}  # type: Final

# Operator to use when the two operands of a comparison are swapped.
reverse_op = {"==": "==",
              "!=": "!=",
              "<": ">",
              ">": "<",
              "<=": ">=",
              ">=": "<=",
              }  # type: Final


def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
    """Mark the if/elif/else bodies of `s` that can be proven unreachable.

    Conditions are examined in order. A body whose condition is inferred to be
    always false is marked unreachable. At the first condition inferred to be
    always true, every later body (including the else body) is marked
    unreachable and the scan stops.
    """
    for i, condition in enumerate(s.expr):
        result = infer_condition_value(condition, options)
        if result in (ALWAYS_FALSE, MYPY_FALSE):
            # The condition is considered always false, so we skip the if/elif body.
            mark_block_unreachable(s.body[i])
        elif result in (ALWAYS_TRUE, MYPY_TRUE):
            # This condition is considered always true, so all of the remaining
            # elif/else bodies should not be checked.
            if result == MYPY_TRUE:
                # This condition is false at runtime; this will affect
                # import priorities.
                mark_block_mypy_only(s.body[i])
            for body in s.body[i + 1:]:
                mark_block_unreachable(body)

            # Make sure else body always exists and is marked as
            # unreachable so the type checker always knows that
            # all control flow paths will flow through the if
            # statement body.
            if not s.else_body:
                s.else_body = Block([])
            mark_block_unreachable(s.else_body)
            break


def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
    """Return True if the asserted expression is inferred to be always false."""
    return infer_condition_value(s.expr, options) in (ALWAYS_FALSE, MYPY_FALSE)


def infer_condition_value(expr: Expression, options: Options) -> int:
    """Infer whether the given condition is always true/false.

    Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
    MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
    false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.
    """
    if isinstance(expr, UnaryExpr) and expr.op == 'not':
        # Evaluate the operand on its own and invert the outcome. Recursing
        # (instead of carrying a "negated" flag) makes sure the negation is
        # also applied to `and`/`or` operands and to nested `not`s.
        return inverted_truth_mapping[infer_condition_value(expr.expr, options)]

    pyversion = options.python_version
    name = ''
    result = TRUTH_VALUE_UNKNOWN
    if isinstance(expr, NameExpr):
        name = expr.name
    elif isinstance(expr, MemberExpr):
        # Only the attribute name is considered, not the object it is read from.
        name = expr.name
    elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'):
        left = infer_condition_value(expr.left, options)
        if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or
                (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')):
            # Either `True and <other>` or `False or <other>`: the result will
            # always be the right-hand-side.
            return infer_condition_value(expr.right, options)
        else:
            # The result will always be the left-hand-side (e.g. ALWAYS_* or
            # TRUTH_VALUE_UNKNOWN).
            return left
    else:
        result = consider_sys_version_info(expr, pyversion)
        if result == TRUTH_VALUE_UNKNOWN:
            result = consider_sys_platform(expr, options.platform)
    if result == TRUTH_VALUE_UNKNOWN:
        if name == 'PY2':
            result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
        elif name == 'PY3':
            result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
        elif name == 'MYPY' or name == 'TYPE_CHECKING':
            result = MYPY_TRUE
        elif name in options.always_true:
            result = ALWAYS_TRUE
        elif name in options.always_false:
            result = ALWAYS_FALSE
    return result


def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int:
    """Consider whether expr is a comparison involving sys.version_info.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    # Cases supported:
    # - sys.version_info[<int>] <compare_op> <int>
    # - sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>
    # - sys.version_info <compare_op> <tuple_of_1_or_2_ints>
    #   (in this case <compare_op> must be >, >=, <, <=, but cannot be ==, !=)
    if not isinstance(expr, ComparisonExpr):
        return TRUTH_VALUE_UNKNOWN
    # Let's not yet support chained comparisons.
    if len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    op = expr.operators[0]
    if op not in ('==', '!=', '<=', '>=', '<', '>'):
        return TRUTH_VALUE_UNKNOWN

    index = contains_sys_version_info(expr.operands[0])
    thing = contains_int_or_tuple_of_ints(expr.operands[1])
    if index is None or thing is None:
        # Retry with the operands swapped (e.g. `(3, 6) <= sys.version_info`),
        # mirroring the operator so the comparison keeps its meaning.
        index = contains_sys_version_info(expr.operands[1])
        thing = contains_int_or_tuple_of_ints(expr.operands[0])
        op = reverse_op[op]
    if isinstance(index, int) and isinstance(thing, int):
        # sys.version_info[i] <compare_op> k
        if 0 <= index <= 1:
            return fixed_comparison(pyversion[index], op, thing)
        else:
            return TRUTH_VALUE_UNKNOWN
    elif isinstance(index, tuple) and isinstance(thing, tuple):
        lo, hi = index
        if lo is None:
            lo = 0
        if hi is None:
            hi = 2
        # Only slices that stay within the first two version components
        # (major, minor) are evaluated.
        if 0 <= lo < hi <= 2:
            val = pyversion[lo:hi]
            # Equality tests are only decided when both sides have the same
            # length; ordering tests also accept a shorter right-hand tuple.
            if len(val) == len(thing) or len(val) > len(thing) and op not in ('==', '!='):
                return fixed_comparison(val, op, thing)
    return TRUTH_VALUE_UNKNOWN


def consider_sys_platform(expr: Expression, platform: str) -> int:
    """Consider whether expr is a comparison involving sys.platform.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    # Cases supported:
    # - sys.platform == 'posix'
    # - sys.platform != 'win32'
    # - sys.platform.startswith('win')
    if isinstance(expr, ComparisonExpr):
        # Let's not yet support chained comparisons.
        if len(expr.operators) > 1:
            return TRUTH_VALUE_UNKNOWN
        op = expr.operators[0]
        if op not in ('==', '!='):
            return TRUTH_VALUE_UNKNOWN
        # Only the `sys.platform <op> <string>` operand order is recognized.
        if not is_sys_attr(expr.operands[0], 'platform'):
            return TRUTH_VALUE_UNKNOWN
        right = expr.operands[1]
        if not isinstance(right, (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(platform, op, right.value)
    elif isinstance(expr, CallExpr):
        if not isinstance(expr.callee, MemberExpr):
            return TRUTH_VALUE_UNKNOWN
        if len(expr.args) != 1 or not isinstance(expr.args[0], (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        if not is_sys_attr(expr.callee.expr, 'platform'):
            return TRUTH_VALUE_UNKNOWN
        if expr.callee.name != 'startswith':
            return TRUTH_VALUE_UNKNOWN
        if platform.startswith(expr.args[0].value):
            return ALWAYS_TRUE
        else:
            return ALWAYS_FALSE
    else:
        return TRUTH_VALUE_UNKNOWN


Targ = TypeVar('Targ', int, str, Tuple[int, ...])


def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
    """Evaluate `left <op> right` for two known values.

    Return ALWAYS_TRUE or ALWAYS_FALSE according to the comparison result,
    or TRUTH_VALUE_UNKNOWN if `op` is not one of ==, !=, <=, >=, <, >.
    """
    rmap = {False: ALWAYS_FALSE, True: ALWAYS_TRUE}
    if op == '==':
        return rmap[left == right]
    if op == '!=':
        return rmap[left != right]
    if op == '<=':
        return rmap[left <= right]
    if op == '>=':
        return rmap[left >= right]
    if op == '<':
        return rmap[left < right]
    if op == '>':
        return rmap[left > right]
    return TRUTH_VALUE_UNKNOWN


def contains_int_or_tuple_of_ints(expr: Expression
                                  ) -> Union[None, int, Tuple[int], Tuple[int, ...]]:
    """Return the value of an int literal or of a tuple made only of int literals.

    Return None for any other expression.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if isinstance(expr, TupleExpr):
        if literal(expr) == LITERAL_YES:
            thing = []
            for x in expr.items:
                if not isinstance(x, IntExpr):
                    return None
                thing.append(x.value)
            return tuple(thing)
    return None


def contains_sys_version_info(expr: Expression
                              ) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
    """Describe how `expr` refers to sys.version_info, if it does.

    Return:
    - (None, None) for a bare `sys.version_info`;
    - the int index for `sys.version_info[<int>]`;
    - a (begin, end) pair for `sys.version_info[<begin>:<end>]`, where an
      omitted bound is None and the stride, if given, must be the literal 1;
    - None for anything else.
    """
    if is_sys_attr(expr, 'version_info'):
        return (None, None)  # Same as sys.version_info[:]
    if isinstance(expr, IndexExpr) and is_sys_attr(expr.base, 'version_info'):
        index = expr.index
        if isinstance(index, IntExpr):
            return index.value
        if isinstance(index, SliceExpr):
            if index.stride is not None:
                if not isinstance(index.stride, IntExpr) or index.stride.value != 1:
                    return None
            begin = end = None
            if index.begin_index is not None:
                if not isinstance(index.begin_index, IntExpr):
                    return None
                begin = index.begin_index.value
            if index.end_index is not None:
                if not isinstance(index.end_index, IntExpr):
                    return None
                end = index.end_index.value
            return (begin, end)
    return None


def is_sys_attr(expr: Expression, name: str) -> bool:
    """Return True if `expr` has the syntactic form `sys.<name>`."""
    # TODO: This currently doesn't work with code like this:
    # - import sys as _sys
    # - from sys import version_info
    if isinstance(expr, MemberExpr) and expr.name == name:
        if isinstance(expr.expr, NameExpr) and expr.expr.name == 'sys':
            # TODO: Guard against a local named sys, etc.
            # (Though later passes will still do most checking.)
            return True
    return False


def mark_block_unreachable(block: Block) -> None:
    """Flag `block`, and every import statement nested in it, as unreachable."""
    block.is_unreachable = True
    block.accept(MarkImportsUnreachableVisitor())


class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Visitor that flags all imports nested within a node as unreachable."""

    def visit_import(self, node: Import) -> None:
        node.is_unreachable = True

    def visit_import_from(self, node: ImportFrom) -> None:
        node.is_unreachable = True

    def visit_import_all(self, node: ImportAll) -> None:
        node.is_unreachable = True


def mark_block_mypy_only(block: Block) -> None:
    """Set is_mypy_only on every import statement nested in `block`."""
    block.accept(MarkImportsMypyOnlyVisitor())


class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Visitor that sets is_mypy_only (which affects priority)."""

    def visit_import(self, node: Import) -> None:
        node.is_mypy_only = True

    def visit_import_from(self, node: ImportFrom) -> None:
        node.is_mypy_only = True

    def visit_import_all(self, node: ImportAll) -> None:
        node.is_mypy_only = True