# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import operator
from operator import attrgetter
from contextlib import contextmanager, suppress
from collections import namedtuple
from functools import total_ordering
import itertools as it
from weakref import ref
import threading
import types
from typing import (Any, Callable, ClassVar, Dict, Generator,
                    Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
                    Type, Union, cast, Iterable, Hashable)

import numpy as np

from . import dtypes
from .config import FLAGS, config
from . import linear_util as lu

from jax._src import source_info_util
from ._src.util import (safe_zip, safe_map, partial, curry, prod, partialmethod,
                        tuple_insert, tuple_delete, as_hashable_function,
                        HashableFunction)
from ._src.pprint_util import pp, vcat, PrettyPrint

from ._src import traceback_util
# Register this file with traceback_util (presumably so that its frames are
# filtered out of user-facing tracebacks -- see jax._src.traceback_util).
traceback_util.register_exclusion(__file__)

# Module-level switch consulted by the `assert skip_checks or ...` guards in
# this module: when True, those internal consistency assertions are bypassed.
# Initialized from the negation of the `jax_enable_checks` flag.
# TODO(mattjj): move this into debug_state
skip_checks = not FLAGS.jax_enable_checks

@contextmanager
def skipping_checks():
  """Context manager for temporarily disabling internal checks.

  Sets the module-level `skip_checks` to True for the duration of the `with`
  body and restores the previous value on exit, even if the body raises.
  """
  global skip_checks
  old_value, skip_checks = skip_checks, True
  try:
    yield
  finally:
    skip_checks = old_value

@contextmanager
def checking_leaks():
  """Context manager for temporarily enabling tracer leak checks.

  Sets `debug_state.check_leaks` to True (for the current thread only, since
  `debug_state` is thread-local) and restores the previous value on exit.
  """
  old_value, debug_state.check_leaks = debug_state.check_leaks, True
  try:
    yield
  finally:
    debug_state.check_leaks = old_value

class DebugState(threading.local):
  """Per-thread debugging switches.

  `check_leaks` is initialized from the `jax_check_tracer_leaks` flag; when it
  is set, `new_main`/`new_base_main` verify on exit that the MainTrace they
  created is no longer referenced.
  """
  def __init__(self):
    self.check_leaks = FLAGS.jax_check_tracer_leaks
debug_state = DebugState()

# Shadow the builtins for the rest of this module with the `_src.util`
# variants (presumably length-checking; see that module). Note that code below
# relies on `map` being eager and returning a list -- e.g. `eval_jaxpr`
# evaluates `subfuns + in_vals` and calls `map(write, ...)` for its effects.
zip = safe_zip
map = safe_map


# -------------------- jaxprs --------------------

class Jaxpr:
  """A jaxpr: constant and input binders, a list of equations, and outputs."""
  constvars: List['Var']
  invars: List['Var']
  outvars: List['Atom']
  eqns: List['JaxprEqn']

  def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
               outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
    """Build a Jaxpr, copying each sequence argument into a fresh list.

    Args:
      constvars: list of variables introduced for constants. Array constants are
        replaced with such variables while scalar constants are kept inline.
      invars: list of input variables. Together, `constvars` and `invars` are
        the inputs to the Jaxpr.
      outvars: list of output variables.
      eqns: list of equations.
    """
    self.constvars = list(constvars)
    self.invars = list(invars)
    self.outvars = list(outvars)
    self.eqns = list(eqns)

  # Both str() and repr() render via `pp_jaxpr` (defined elsewhere).
  def __str__(self):
    return str(pp_jaxpr(self))
  __repr__ = __str__


def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every Jaxpr found directly among the values of `params`.

  Tuple-valued params are searched one level deep, element by element. A
  ClosedJaxpr contributes its inner `.jaxpr`. Nested jaxprs are not visited.
  """
  for val in params.values():
    vals = val if isinstance(val, tuple) else (val,)
    for v in vals:
      if isinstance(v, Jaxpr):
        yield v
      elif isinstance(v, ClosedJaxpr):
        yield v.jaxpr


def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Generator for all subjaxprs found in the params of jaxpr.eqns.

  Does not descend recursively into the found subjaxprs.
  """
  for eqn in jaxpr.eqns:
    yield from jaxprs_in_params(eqn.params)


class ClosedJaxpr:
  """A Jaxpr paired with values for its constvars (one value per constvar)."""
  jaxpr: Jaxpr
  consts: List[Any]

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    # One const per constvar; the `consts` sequence is copied into a list.
    assert len(consts) == len(jaxpr.constvars)
    self.jaxpr = jaxpr
    self.consts = list(consts)

  @property
  def in_avals(self):
    """Avals of the jaxpr's input variables, in order."""
    return [v.aval for v in self.jaxpr.invars]

  @property
  def out_avals(self):
    """Avals of the jaxpr's outputs (Vars or Literals), in order."""
    return [v.aval for v in self.jaxpr.outvars]

  @property
  def literals(self):
    return self.consts  # backwards compatible alias

  def map_jaxpr(self, f):
    """Return a new ClosedJaxpr with `f` applied to the jaxpr; consts kept."""
    return ClosedJaxpr(f(self.jaxpr), self.consts)

  def __str__(self): return str(self.jaxpr)
  def __repr__(self): return repr(self.jaxpr)

# Evaluate `closed_jaxpr` on `args` via `eval_jaxpr`. The `curry` decorator
# (from _src.util) presumably lets `jaxpr_as_fun(closed_jaxpr)` alone return a
# callable awaiting `*args` -- confirm against _src.util.curry.
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)


class JaxprEqn(NamedTuple):
  """One equation: `outvars = primitive[params](*invars)`."""
  invars: List['Atom']
  outvars: List['Var']
  primitive: 'Primitive'
  params: Dict[str, Any]
  source_info: Optional[source_info_util.Traceback]

  def __repr__(self): return str(pp_eqn(self)).rstrip()

def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
  """Construct a JaxprEqn; `source_info` defaults to None."""
  return JaxprEqn(invars, outvars, primitive, params, source_info)


@total_ordering
class Var:
  """A variable binder in a jaxpr.

  No `__eq__`/`__hash__` is defined, so two Vars compare equal only when they
  are the same object. Ordering compares `(count, suffix)` via `__lt__`; the
  other ordering operators are synthesized by `functools.total_ordering`.
  """
  # TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
  # by object id, but pretty printing might collide.
  count: int
  suffix: str
  aval: 'AbstractValue'

  def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
    self.count = count
    self.suffix = suffix
    # The stored aval is normalized through `raise_to_shaped` (defined
    # elsewhere in this module).
    self.aval = raise_to_shaped(aval)

  def __lt__(self, other):
    if isinstance(other, Var):
      return (self.count, self.suffix) < (other.count, other.suffix)
    return NotImplemented

  def __repr__(self):
    # Spell `count` in base 26 with 'a'..'z' as the digits ('a' is zero),
    # most significant digit first, then append the suffix.
    letters = []
    remaining = self.count
    while True:
      remaining, digit = divmod(remaining, 26)
      letters.append(chr(ord('a') + digit))
      if remaining == 0:
        break
    return ''.join(reversed(letters)) + self.suffix

def _jaxpr_vars(jaxpr):
  """Iterate over the variables bound at the top level of `jaxpr`.

  Produces the input variables, then the constant variables, then each
  equation's output variables in equation order. Jaxprs nested inside equation
  params are not visited.
  """
  eqn_binders = (v for eqn in jaxpr.eqns for v in eqn.outvars)
  return it.chain(jaxpr.invars, jaxpr.constvars, eqn_binders)

def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
           suffix: str = '') -> Callable[['AbstractValue'], Var]:
  """Return a factory of fresh `Var`s that print with `suffix`.

  Each call of the returned function takes an aval and builds a Var whose
  `count` is one more than the previous call's. With `jaxprs`, counting begins
  one past the largest `count` among the variables those jaxprs bind, so the
  new Vars never reuse one of their counts; otherwise counting begins at 0.
  """
  if jaxprs is None:
    first = 0
  else:
    existing_counts = (v.count for j in jaxprs for v in _jaxpr_vars(j))
    first = max(existing_counts, default=-1) + 1
  fresh_counts = it.count(first)
  return lambda aval: Var(next(fresh_counts), suffix, aval)

# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
  """Type of the `dropvar` singleton: a binder whose value is never read.

  Skips `Var.__init__`. `count` and `suffix` are class attributes, and `aval`
  always reports `abstract_unit`.
  """
  count = -1
  suffix = ''
  def __init__(self): pass
  @property
  def aval(self): return abstract_unit
  def __repr__(self): return '_'
dropvar = DropVar()

class Literal:
  """A constant that appears inline as an equation operand.

  `__init__` fills the `hash` slot as follows:
  - `hash(val)` if `val` is hashable;
  - otherwise, if `type(val)` is registered in `literalable_types`,
    `hash((val.item(), val.dtype))`, or None if that also fails;
  - for any other unhashable value, the slot is left unset.
  `__repr__` distinguishes the last case.
  Literal objects themselves are deliberately not hashable (`__hash__` asserts).
  """
  __slots__ = ["val", "hash"]

  val: Any
  hash: Optional[int]

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
    except TypeError:
      if type(val) in literalable_types:
        try:
          self.hash = hash((val.item(), val.dtype))
        except (TypeError, AttributeError, ValueError):
          self.hash = None

  @property
  def aval(self):
    """The value's aval, normalized through `raise_to_shaped`."""
    return raise_to_shaped(get_aval(self.val))

  def __hash__(self):
    assert False

  def __repr__(self):
    # An unset `hash` slot marks a value of an unregistered, unhashable type.
    if hasattr(self, 'hash'):
      return '{}'.format(self.val)
    else:
      return 'Literal(val={})'.format(self.val)

# Types whose unhashable instances may still be stored in a Literal (hashed via
# `.item()` and `.dtype`).
literalable_types: Set[type] = set()

# An equation operand: either a variable or an inline constant.
Atom = Union[Var, Literal]

class Primitive:
  """A named operation that equations apply and traces interpret.

  The evaluation and abstract-evaluation rules are installed per instance with
  `def_impl` / `def_abstract_eval`, and `bind` can be replaced with
  `def_custom_bind`.
  """
  name: str
  multiple_results = False  # set for multi-output primitives
  call_primitive = False    # set for call primitives processed in final style
  map_primitive = False     # set for map primitives processed in final style

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return '{}'.format(self.name)


  def bind(self, *args, **params):
    """Apply the primitive to `args` under the current top trace.

    Finds the top trace for `args`, raises every argument to it, lets it
    process the primitive, then lowers the result(s) with `full_lower`. Returns
    a list when `multiple_results` is set, otherwise a single value.
    """
    assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args
    top_trace = find_top_trace(args)
    tracers = map(top_trace.full_raise, args)
    out = top_trace.process_primitive(self, tracers, params)
    return map(full_lower, out) if self.multiple_results else full_lower(out)

  def def_impl(self, impl):
    """Install `impl` as this primitive's evaluation rule; return it."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Install `abstract_eval` as the abstract evaluation rule; return it."""
    self.abstract_eval = abstract_eval
    return abstract_eval

  def def_custom_bind(self, bind):
    """Replace `bind` on this instance with `bind`; return it."""
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    raise NotImplementedError("Evaluation rule for '{}' not implemented"
                              .format(self.name))

  def abstract_eval(self, *args, **params):
    raise NotImplementedError("Abstract evaluation for '{}' not implemented"
                              .format(self.name))


# -------------------- lifting --------------------

# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
# core.py. Plan to move all these utilities to jaxpr.py.
def extract_call_jaxpr(
  primitive: Primitive,
  params: Dict[str, Any]) -> Tuple[Optional[Jaxpr], Dict[str, Any]]:
  """Extract the call primitive subjaxpr from the params.

  Returns the subjaxpr and the params without the "call_jaxpr" value. If this is
  not a call primitive then returns (None, params). The input dict is not
  modified.
  """
  if not (primitive.call_primitive or primitive.map_primitive):
    return (None, params)
  else:
    assert "call_jaxpr" in params
    new_params = dict(params)
    del new_params["call_jaxpr"]
    return (params["call_jaxpr"], new_params)


def traverse_jaxpr_params(f, params):
  """Applies f to each jaxpr parameter and returns a tuple of returned values.

  Only values whose exact type is Jaxpr or ClosedJaxpr are visited; for a
  ClosedJaxpr, `f` receives its inner `.jaxpr`.
  """
  return tuple(f(param if type(param) is Jaxpr else param.jaxpr)
               for param in params.values()
               if type(param) in (Jaxpr, ClosedJaxpr))


def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
  """Evaluate `jaxpr` by binding each equation's primitive in order.

  Args:
    jaxpr: the program to evaluate.
    consts: values for `jaxpr.constvars`, in order.
    *args: values for `jaxpr.invars`, in order.

  Returns:
    A list with the value of each of `jaxpr.outvars`.

  Call and map primitives receive, as their first bind argument, a wrapped
  function that recursively evaluates their `call_jaxpr`. Map primitives also
  get their `out_axes` param converted into an `out_axes_thunk`.
  """
  def read(v):
    if type(v) is Literal:
      return v.val
    else:
      return env[v]

  def write(v, val):
    env[v] = val

  env: Dict[Var, Any] = {}
  write(unitvar, unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.invars)
    call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr is not None:
      subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
    else:
      subfuns = []
    if eqn.primitive.map_primitive:
      out_axes = params['out_axes']
      # Bind the value as a default argument rather than closing over
      # `params`, which is reassigned on every loop iteration; otherwise a
      # thunk invoked after this iteration would see a later equation's axes.
      out_axes_thunk = HashableFunction(lambda out_axes=out_axes: out_axes,
                                        closure=out_axes)
      bind_params = dict(params, out_axes_thunk=out_axes_thunk)
      del bind_params['out_axes']
    else:
      bind_params = params
    with source_info_util.user_context(eqn.source_info):
      ans = eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)


# -------------------- tracing --------------------


class Trace:
  """Base class for an interpreter bound to one MainTrace and a sublevel.

  Subclasses implement `pure`, `lift`, `sublift`, `process_primitive` and,
  as needed, the `process_*` hooks for call, map and custom-derivative
  primitives.
  """
  __slots__ = ['main', 'level', 'sublevel']

  main: 'MainTrace'
  level: int
  sublevel: 'Sublevel'

  def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

  def full_raise(self, val) -> 'Tracer':
    """Return `val` as a tracer belonging to this trace.

    - Non-tracers are wrapped with `pure`.
    - A tracer of the same MainTrace is returned as-is at the same sublevel,
      or `sublift`-ed when it comes from a lower sublevel.
    - A tracer from a lower level is `lift`-ed.
    - Every other combination means the tracer escaped its trace and raises
      UnexpectedTracerError.
    """
    if not isinstance(val, Tracer):
      return self.pure(val)
    val._assert_live()
    level = self.level
    sublevel = self.sublevel
    if val._trace.main is self.main:
      if val._trace.sublevel == sublevel:
        return val
      elif val._trace.sublevel < sublevel:
        return self.sublift(val)
      else:
        raise escaped_tracer_error(
            val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
    elif val._trace.level < level:
      if val._trace.sublevel > sublevel:
        raise escaped_tracer_error(
            val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
      return self.lift(val)
    elif val._trace.level > level:
      raise escaped_tracer_error(
          val, f"Can't lift level {val} to {self}")
    else:  # val._trace.level == self.level:
      raise escaped_tracer_error(
          val, f"Different traces at same level: {val}, {self}")

  def pure(self, val):
    raise NotImplementedError("must override")

  def lift(self, tracer):
    raise NotImplementedError("must override")

  def sublift(self, tracer):
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_call to handle call-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_map(self, call_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_map to handle map-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    msg = (f"{type(self)} must override process_custom_jvp_call "
           "to handle custom_jvp primitives")
    raise NotImplementedError(msg)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    msg = (f"{type(self)} must override process_custom_vjp_call "
           "to handle custom_vjp primitives")
    raise NotImplementedError(msg)

def escaped_tracer_error(tracer, detail=None):
  """Build (but do not raise) an UnexpectedTracerError for `tracer`.

  The message includes `detail` when given. When available, it also includes
  the tracer's creation site (`tracer._line_info`, with up to
  `jax_tracer_error_num_traceback_frames` frames) and the source info of the
  function being traced.
  """
  num_frames = FLAGS.jax_tracer_error_num_traceback_frames
  msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
         "through global state from a previously traced function.\n"
         "The functions being transformed should not save traced values to "
         "global state.")
  if detail:
    msg += " Detail: {}.".format(detail)
  try:
    line_info = tracer._line_info
  except AttributeError:
    pass
  else:
    msg += ('\nThe tracer that caused this error was created on line '
            f'{source_info_util.summarize(line_info)}.\n')
    if num_frames > 0:
      msg += (f'When the tracer was created, the final {num_frames} stack '
              'frames (most recent last) excluding JAX-internal frames were:\n'
              f'{source_info_util.summarize(line_info, num_frames=num_frames)}')
  try:
    fun_source_info = tracer._trace.main.source_info
  except AttributeError:
    pass
  else:
    msg += ('\nThe function being traced when the tracer leaked was '
            f'{fun_source_info}.')
  msg += ('\nTo catch the leak earlier, try setting the environment variable '
          'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
          'manager.')
  return UnexpectedTracerError(msg)

class UnexpectedTracerError(Exception): pass

class Tracer:
  """Base class for traced values.

  Operators, conversions, indexing, and any attribute not found normally are
  forwarded to `self.aval`. The aval's `_`-prefixed methods implement them and
  receive the tracer as their first argument. Because `__eq__` is defined
  without `__hash__`, Tracer instances are unhashable.
  """
  __array_priority__ = 1000
  __slots__ = ['_trace', '__weakref__', '_line_info']

  def __array__(self, *args, **kw):
    msg = ("The numpy.ndarray conversion method __array__() was called on "
           f"the JAX Tracer object {self}.\n\n"
           "This error can occur when a JAX Tracer object is passed to a raw "
           "numpy function, or a method on a numpy.ndarray object. You might "
           "want to check that you are using `jnp` together with "
           "`import jax.numpy as jnp` rather than using `np` via "
           "`import numpy as np`. If this error arises on a line that involves "
           "array indexing, like `x[idx]`, it may be that the array being "
           "indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
           "JAX Tracer instance; in that case, you can instead write "
           "`jax.device_put(x)[idx]`.")
    raise Exception(msg)

  def __init__(self, trace: Trace):
    self._trace = trace

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __len__(self):
    return self.aval._len(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def _assert_live(self) -> None:
    pass  # Override for liveness checking

  # Python looks up special methods only on classes, not instances. This means
  # these methods needs to be defined explicitly rather than relying on
  # __getattr__.
  def __neg__(self): return self.aval._neg(self)
  def __pos__(self): return self.aval._pos(self)
  def __eq__(self, other): return self.aval._eq(self, other)
  def __ne__(self, other): return self.aval._ne(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __le__(self, other): return self.aval._le(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __ge__(self, other): return self.aval._ge(self, other)
  def __abs__(self): return self.aval._abs(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __sub__(self, other): return self.aval._sub(self, other)
  def __rsub__(self, other): return self.aval._rsub(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __div__(self, other): return self.aval._div(self, other)
  def __rdiv__(self, other): return self.aval._rdiv(self, other)
  def __truediv__(self, other): return self.aval._truediv(self, other)
  def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
  def __floordiv__(self, other): return self.aval._floordiv(self, other)
  def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
  def __divmod__(self, other): return self.aval._divmod(self, other)
  def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
  def __mod__(self, other): return self.aval._mod(self, other)
  def __rmod__(self, other): return self.aval._rmod(self, other)
  def __pow__(self, other): return self.aval._pow(self, other)
  def __rpow__(self, other): return self.aval._rpow(self, other)
  def __matmul__(self, other): return self.aval._matmul(self, other)
  def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
  def __and__(self, other): return self.aval._and(self, other)
  def __rand__(self, other): return self.aval._rand(self, other)
  def __or__(self, other): return self.aval._or(self, other)
  def __ror__(self, other): return self.aval._ror(self, other)
  def __xor__(self, other): return self.aval._xor(self, other)
  def __rxor__(self, other): return self.aval._rxor(self, other)
  def __invert__(self): return self.aval._invert(self)
  def __lshift__(self, other): return self.aval._lshift(self, other)
  def __rlshift__(self, other): return self.aval._rlshift(self, other)
  def __rshift__(self, other): return self.aval._rshift(self, other)
  def __rrshift__(self, other): return self.aval._rrshift(self, other)
  def __getitem__(self, idx): return self.aval._getitem(self, idx)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __bool__(self): return self.aval._bool(self)
  def __int__(self): return self.aval._int(self)
  def __long__(self): return self.aval._long(self)
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)
  def __float__(self): return self.aval._float(self)
  def __complex__(self): return self.aval._complex(self)

  def __setitem__(self, idx, val):
    raise TypeError("JAX 'Tracer' objects do not support item assignment")

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    """Forward a failed attribute lookup to `self.aval`.

    An `aval_property` found on the aval is evaluated with this tracer. An
    `aval_method` is bound to this tracer. Any other attribute is returned
    unchanged.

    Raises:
      AttributeError: if the aval has no attribute `name`.
    """
    # if the aval property raises an AttributeError, gets caught here
    assert skip_checks or name != "aval"

    try:
      attr = getattr(self.aval, name)
    except AttributeError as err:
      # `getattr` signals a missing attribute with AttributeError (never
      # KeyError); re-raise it naming the tracer class rather than the aval.
      raise AttributeError(
          "{} has no attribute {}".format(self.__class__.__name__, name)
      ) from err
    else:
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        return types.MethodType(attr.fun, self)
      else:
        return attr

  def __repr__(self):
    base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
    contents = self._contents()
    if contents:
      base += pp(' with ') >> vcat(pp('{} = '.format(name)) >> pp_payload
                                   for name, pp_payload in contents)
    return str(base)

  def _contents(self):
    # (name, pretty-printed repr) for each slot of the concrete class; empty
    # if any slot is unset.
    try:
      return [(name, pp(repr(getattr(self, name)))) for name in self.__slots__]
    except AttributeError:
      return ()

  # Copying a tracer returns the same object.
  def __copy__(self):
    return self

  def __deepcopy__(self, unused_memo):
    return self

  def _origin_msg(self) -> str:
    return ""

# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])


class EvalTrace(Trace):
  """Trace that evaluates primitives eagerly with their `impl` rules."""
  # See comments in https://github.com/google/jax/pull/3370
  def pure(self, x): return x
  lift = sublift = pure

  def process_primitive(self, primitive, tracers, params):
    return primitive.impl(*tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    return primitive.impl(f, *tracers, **params)
  process_map = process_call

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)


class MainTrace:
  """One entry of the trace stack: its level, Trace subclass, and payload.

  `payload` keyword arguments are forwarded to the Trace constructor by
  `with_cur_sublevel`. Equality compares level, trace type and payload; the
  hash omits the payload, which stays consistent with equality.
  """
  level: int
  trace_type: Type[Trace]
  payload: Dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level = level
    self.trace_type = trace_type
    self.payload = payload

  def __repr__(self) -> str:
    return "MainTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self) -> int:
    return hash((self.level, self.trace_type))

  def __eq__(self, other: object) -> bool:
    return (isinstance(other, MainTrace) and
            self.level == other.level and
            self.trace_type == other.trace_type and
            self.payload == other.payload)

  def with_cur_sublevel(self):
    """Instantiate `trace_type` for this main at the current sublevel."""
    return self.trace_type(self, cur_sublevel(), **self.payload)

class TraceStack:
  """Stack of MainTraces, plus the designated `dynamic` MainTrace.

  Initially both `stack[0]` and `dynamic` are a level-0 EvalTrace main.
  `next_level()` is the level a newly pushed main should take.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack: List[MainTrace]
  dynamic: MainTrace

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    return len(self.stack)

  def push(self, main_trace: MainTrace) -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    # Join into one string, innermost main first. The module-level `map` is
    # eager and returns a list, so interpolating it directly would print a
    # list repr with escaped newlines.
    stack_str = ''.join(' {}\n'.format(main) for main in reversed(self.stack))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    """Shallow copy: a new list of the same MainTraces, same `dynamic`."""
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new

class Sublevel(int): pass
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
AxisName = Hashable

class TraceState:
  """Tracing state: the trace stack, the sublevel stack, and the axis env.

  A fresh state has a single EvalTrace main, sublevels `[Sublevel(0)]`, and
  an empty axis environment.
  """
  trace_stack: TraceStack
  substack: List[Sublevel]
  axis_env: List[AxisEnvFrame]

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    """Copy with fresh lists; MainTrace/AxisEnvFrame entries are shared."""
    new = self.__new__(TraceState)
    new.trace_stack = self.trace_stack.copy()
    new.substack = self.substack[:]
    new.axis_env = self.axis_env[:]
    return new

# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
  def __init__(self):
    self.trace_state = TraceState()
thread_local_state = ThreadLocalState()

def trace_state_clean() -> bool:
  """Return True iff this thread's trace state equals a freshly built one."""
  trace_state = thread_local_state.trace_state
  return (trace_state.substack == [Sublevel(0)] and
          trace_state.axis_env == [] and
          trace_state.trace_stack.stack == [MainTrace(0, EvalTrace)] and
          trace_state.trace_stack.dynamic == MainTrace(0, EvalTrace))

def reset_trace_state() -> bool:
  "Reset the global trace state and return True if it was already clean."
  # Re-initializes this thread's TraceState object in place.
  if not trace_state_clean():
    thread_local_state.trace_state.__init__()  # type: ignore
    return False
  else:
    return True

def cur_sublevel() -> Sublevel:
  """Return the innermost (most recently pushed) sublevel of this thread."""
  return thread_local_state.trace_state.substack[-1]

@contextmanager
def new_main(trace_type: Type[Trace],
             dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  """Push a new MainTrace of `trace_type` for the duration of the `with` body.

  The new main takes level `stack.next_level()` and stores `payload`. With
  `dynamic=True` it also becomes the stack's `dynamic` main, and the previous
  one is restored on exit. When `debug_state.check_leaks` is set and the body
  exits without raising, an Exception is raised if the MainTrace is still
  referenced after this function drops its own reference.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main

  try:
    yield main
  finally:
    stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic

  # Leak check: after deleting the local name, a live weakref means something
  # else still holds the MainTrace.
  if debug_state.check_leaks:
    t = ref(main)
    del main
    if t() is not None:
      raise Exception(f'Leaked trace {t()}')

@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
  """Temporarily replace the level-0 and dynamic mains with a new `trace_type`.

  Both the previous base main (`stack.stack[0]`) and the previous dynamic main
  are restored on exit. The leak check behaves as in `new_main`.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  main = MainTrace(0, trace_type)
  prev_dynamic, stack.dynamic = stack.dynamic, main
  prev_base, stack.stack[0] = stack.stack[0], main
  try:
    yield main
  finally:
    stack.dynamic = prev_dynamic
    stack.stack[0] = prev_base

  if debug_state.check_leaks:
    t = ref(main)
    del main
    if t() is not None:
      raise Exception('Leaked trace {}'.format(t()))

@contextmanager
def eval_context():
  """Run the `with` body with a fresh EvalTrace as the base and dynamic main."""
  with new_base_main(EvalTrace):
    yield

@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  """Push a new sublevel (the current substack depth) for the `with` body."""
  sublevel = Sublevel(len(thread_local_state.trace_state.substack))
  thread_local_state.trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    thread_local_state.trace_state.substack.pop()

  # TODO(mattjj): to check sublevel leaks, we need to make Sublevel weakref-able
  # if debug_state.check_leaks:
  #   t = ref(sublevel)
  #   del sublevel
  #   if t() is not None:
  #     raise Exception('Leaked sublevel {}'.format(t()))

def maybe_new_sublevel(trace):
  """Return `new_sublevel()` if `trace` belongs to the dynamic main.

  Otherwise return `suppress()`, which serves here as a no-op context manager.
  """
  # dynamic traces run the WrappedFun, so we raise the sublevel for them
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  return new_sublevel() if trace.main is dynamic else suppress()

def full_lower(val):
  """Return `val.full_lower()` for tracers; return other values unchanged."""
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def find_top_trace(xs) -> Trace:
  """Return the Trace (at the current sublevel) that should process `xs`.

  Uses the MainTrace of the highest-level tracer in `xs`. The stack's
  `dynamic` main is used instead when `xs` contains no tracers, or when
  `dynamic` has a strictly higher level.
  """
  top_tracer = max((x for x in xs if isinstance(x, Tracer)),
                    default=None, key=attrgetter('_trace.level'))
  if top_tracer is not None:
    top_tracer._assert_live()
    top_main = top_tracer._trace.main  # type: Optional[MainTrace]
  else:
    top_main = None
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  top_main = (dynamic if top_main is None or dynamic.level > top_main.level
              else top_main)
  return top_main and top_main.with_cur_sublevel()  # type: ignore


# -------------------- abstract values --------------------


class AbstractValue:
  """Base class for abstract values (avals).

  Subclasses override `at_least_vspace` and `join`. The default `__repr__`
  lists instance `__dict__` entries as `k=v`, falling back to the bare class
  name when the instance has no `__dict__`.
  """
  __slots__: List[str] = []
  _num_buffers: int = 1  # number of buffers used to represent the value.

  def at_least_vspace(self):
    raise NotImplementedError("must override")

  def __repr__(self):
    try:
      kv_pairs = ('{}={}'.format(k, v) for k, v in self.__dict__.items())
      return '{}({})'.format(self.__class__.__name__, ','.join(kv_pairs))
    except AttributeError:
      return self.__class__.__name__

  def strip_weak_type(self) -> 'AbstractValue':
    """Return this aval without a weak type; the base class returns self."""
    return self

  def join(self, other):
    raise NotImplementedError("must override")

# An AbstractValue subclass with no behavior of its own; `bot` is its shared
# instance (presumably the lattice bottom -- confirm with its users).
class Bot(AbstractValue): pass

bot = Bot()

# Aval of the `unit` value; `abstract_unit` is the singleton instance.
class AbstractUnit(AbstractValue):
  # TODO(jakevdp): make it possible to set zero buffers
  # _num_buffers = 0
  def at_least_vspace(self): return self
  # Joining is only valid with abstract_unit itself (asserted unless checks
  # are skipped).
  def join(self, other):
    if not skip_checks:
      assert other is abstract_unit, other
    return self
  def _eq(self, self_traced, other): return get_aval(other) is self
  def str_short(self): return '*'

abstract_unit = AbstractUnit()

def lattice_join(x: Optional[AbstractValue],
                 y: Optional[AbstractValue]) -> AbstractValue:
  """Join two avals, treating None as the identity.

  `join` is called on whichever operand's type is the more general one (its
  type is a superclass of, or equal to, the other's type).

  Raises:
    TypeError: if neither operand's type is a subclass of the other's.
  """
  if x is None:
    return cast(AbstractValue, y)
  elif y is None:
    return cast(AbstractValue, x)
  elif isinstance(x, type(y)):
    return y.join(x)
  elif isinstance(y, type(x)):
    return x.join(y)
  else:
    raise TypeError((x, y))

# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any

def valid_jaxtype(x):
  """Returns True iff `concrete_aval(x)` succeeds, i.e. `x` is a JAX type."""
  try:
    concrete_aval(x)
  except TypeError:
    return False
  else:
    return True

def check_valid_jaxtype(x):
  """Raises TypeError unless `valid_jaxtype(x)` holds."""
  if not valid_jaxtype(x):
    raise TypeError(f"{x} of type {type(x)} is not a valid JAX type")


def concrete_aval(x):
  """Returns the abstract value of a non-tracer value `x`.

  Walks the MRO of `type(x)` and uses the first handler registered in
  `pytype_aval_mappings`, so subclasses inherit their base class's handler.

  Raises:
    TypeError: if no type in the MRO has a registered handler.
  """
  for typ in type(x).mro():
    handler = pytype_aval_mappings.get(typ)
    if handler: return handler(x)
  raise TypeError(f"{type(x)} is not a valid JAX type")


def get_aval(x):
  """Returns `x.aval` for tracers and `concrete_aval(x)` for anything else."""
  if isinstance(x, Tracer):
    return x.aval
  else:
    return concrete_aval(x)


# Python type -> function producing that value's abstract value. Consulted by
# `concrete_aval` along the MRO of the value's type.
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}


class Unit:
  """Type of the `unit` placeholder value; prints as '*'."""
  def __repr__(self): return '*'
unit = Unit()
literalable_types.add(Unit)

class UnitVar(Var):
  """Var stand-in for `unit`; its aval is always `abstract_unit`."""
  count = -1
  suffix = ''
  def __init__(self): pass  # does not call Var.__init__
  @property
  def aval(self): return abstract_unit
  def __repr__(self): return '*'
unitvar = UnitVar()

pytype_aval_mappings[Unit] = lambda _: abstract_unit

class ConcretizationTypeError(TypeError): pass

def raise_concretization_error(val: Tracer, context=""):
  """Raises ConcretizationTypeError for abstract tracer `val`.

  Args:
    val: the tracer that was used where a concrete value was required; its
      `_origin_msg()` is included in the error text.
    context: extra explanation inserted into the message.
  """
  msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
         + context + "\n\n"
         + val._origin_msg() + "\n\n"
         "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
         f"Encountered tracer value: {val}")
  raise ConcretizationTypeError(msg)


def concretization_function_error(fun, suggest_astype=False):
  """Builds an aval hook that raises a concretization error for `fun`.

  Args:
    fun: the Python conversion function (e.g. `bool`, `int`) being applied to
      a tracer; its name (or `fun` itself if it has no `__name__`) appears in
      the message.
    suggest_astype: if True, the message also suggests `x.astype(...)` /
      `jnp.array(x, ...)` as a dtype-conversion alternative.

  Returns:
    A function `error(self, arg)` that always raises ConcretizationTypeError
    about tracer `arg`.
  """
  fname = getattr(fun, "__name__", fun)
  fname_context = f"The problem arose with the `{fname}` function. "
  if suggest_astype:
    # Use `fname` (not `fun.__name__`) so objects without `__name__` still work.
    fname_context += ("If trying to convert the data type of a value, "
                      f"try using `x.astype({fname})` "
                      f"or `jnp.array(x, {fname})` instead.")
  def error(self, arg):
    raise_concretization_error(arg, fname_context)
  return error


def concrete_or_error(force: Any, val: Any, context=""):
  """Like force(val), but gives the context in the error message.

  Args:
    force: conversion applied to the concrete value; None means identity.
    val: a concrete value, or a tracer.
    context: text included in the error if `val` is not concrete.

  Returns:
    `force` applied to `val`, or to the value inside a tracer whose aval is a
    ConcreteArray.

  Raises:
    ConcretizationTypeError: if `val` is a tracer without a concrete aval.
  """
  if force is None:
    force = lambda x: x
  if isinstance(val, Tracer):
    if isinstance(val.aval, ConcreteArray):
      return force(val.aval.val)
    else:
      raise_concretization_error(val, context)
  else:
    return force(val)

class UnshapedArray(AbstractValue):
  """Abstract array value that carries only a dtype and a weak_type flag.

  `array_abstraction_level` is 2 here, 1 for ShapedArray and 0 for
  ConcreteArray.
  """
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 2

  def __init__(self, dtype, weak_type=False):
    self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
    self.weak_type = weak_type

  def __eq__(self, other):
    return (type(self) is type(other) and self.dtype == other.dtype and
            self.weak_type == other.weak_type)

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.dtype, self.weak_type))

  def __repr__(self):
    return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
                             ", weak_type=True" if self.weak_type else "")

  # Conversions of an abstract value to Python scalars are errors.
  _bool = _nonzero = concretization_function_error(bool)
  _float = concretization_function_error(float, True)
  _int = concretization_function_error(int, True)
  _complex = concretization_function_error(complex, True)
  _hex = concretization_function_error(hex)
  _oct = concretization_function_error(oct)

  def at_least_vspace(self) -> AbstractValue:
    """Returns the tangent-space aval (dtype mapped to its tangent dtype)."""
    return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
                         self.weak_type)

  def join(self, other):
    """Joins with an aval of the same dtype; weak_type survives only if both
    sides agree. Raises TypeError on dtype mismatch."""
    if self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      else:
        return UnshapedArray(self.dtype, weak_type=False)
    else:
      raise TypeError(self, other)

  def str_short(self) -> str:
    return self.dtype.name

  def strip_weak_type(self) -> 'UnshapedArray':
    """Returns a copy of the aval with weak_type=False."""
    return UnshapedArray(self.dtype) if self.weak_type else self

  @property
  def shape(self):
    msg = ("UnshapedArray has no shape. Please open an issue at "
           "https://github.com/google/jax/issues because it's unexpected for "
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)

class ShapedArray(UnshapedArray):
  """Abstract array value with a known (canonicalized) shape and dtype."""
  __slots__ = ['shape']
  array_abstraction_level = 1

  def __init__(self, shape, dtype, weak_type=False):
    super().__init__(dtype, weak_type=weak_type)
    self.shape = canonicalize_shape(shape)

  ndim = property(lambda self: len(self.shape))
  size = property(lambda self: prod(self.shape))

  # Class-level hooks assigned elsewhere in the codebase (None by default).
  broadcast: ClassVar[Optional[aval_method]] = None
  transpose: ClassVar[Optional[aval_method]] = None
  reshape: ClassVar[Optional[aval_method]] = None
  _iter: ClassVar[Optional[staticmethod]] = None

  def __eq__(self, other):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.shape, self.dtype, self.weak_type))

  def at_least_vspace(self):
    """Returns the tangent-space aval with the same shape."""
    return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                       self.weak_type)

  def join(self, other):
    """Same shape and dtype -> ShapedArray; same dtype only -> UnshapedArray;
    otherwise TypeError."""
    if self.shape == other.shape and self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      else:
        return ShapedArray(self.shape, self.dtype, weak_type=False)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    else:
      raise TypeError(self, other)

  def str_short(self):
    shapestr = ','.join(map(str, self.shape))
    return '{}[{}]'.format(self.dtype.name, shapestr)

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  def _len(self, ignored_tracer):
    return len(self)

  def strip_weak_type(self):
    return ShapedArray(self.shape, self.dtype) if self.weak_type else self


def _forward_to_value(self, fun, ignored_tracer, *args):
  """Applies `fun` to a ConcreteArray's value, ignoring the tracer argument."""
  return fun(self.val, *args)

class ConcreteArray(ShapedArray):
  """ShapedArray that also holds the concrete value `val`."""
  __slots__ = ['val']
  array_abstraction_level = 0

  def __init__(self, val, weak_type=False):
    super().__init__(np.shape(val), np.result_type(val), weak_type=weak_type)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    self.val = val
    assert self.dtype != np.dtype('O'), val

  def __eq__(self, other):
    if (type(self) is type(other) and self.dtype == other.dtype
        and self.shape == other.shape and self.weak_type == other.weak_type):
      with eval_context():  # in case self.val is a DeviceArray
        return (self.val == other.val).all()
    else:
      return False

  def __hash__(self):
    # NOTE(review): hashes by identity of `val` while `__eq__` compares values,
    # so equal instances holding distinct `val` objects hash differently.
    # Presumably intentional to avoid hashing array contents -- confirm.
    return id(self.val)

  def at_least_vspace(self):
    """Returns a (non-concrete) ShapedArray in the tangent space."""
    return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                       weak_type=self.weak_type)

  def join(self, other) -> UnshapedArray:
    """Equal -> self; same shape and dtype -> ShapedArray; same dtype only ->
    UnshapedArray (weak only if both are weak); otherwise TypeError."""
    if self == other:
      return self
    elif self.shape == other.shape and self.dtype == other.dtype:
      return ShapedArray(self.shape, self.dtype,
                         weak_type=self.weak_type and other.weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype,
                           weak_type=self.weak_type and other.weak_type)
    else:
      raise TypeError(self, other)

  def str_short(self) -> str:
    return str(self.val)

  def strip_weak_type(self) -> 'ConcreteArray':
    return ConcreteArray(self.val) if self.weak_type else self

  # These conversions can use the concrete value directly.
  _bool = _nonzero = partialmethod(_forward_to_value, bool)
  _int = partialmethod(_forward_to_value, int)
  _hex = partialmethod(_forward_to_value, hex)
  _oct = partialmethod(_forward_to_value, oct)

  _float = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)

def primal_dtype_to_tangent_dtype(primal_dtype):
  """Returns `float0` for non-inexact dtypes, else `primal_dtype` unchanged."""
  if not dtypes.issubdtype(primal_dtype, np.inexact):
    return dtypes.float0
  else:
    return primal_dtype

class AbstractToken(AbstractValue):
  """Abstract value of tokens; only joins with other AbstractTokens."""
  def join(self, other):
    if isinstance(other, AbstractToken):
      return self
    else:
      assert False, f"Cannot join {self} with {other}"
  def str_short(self): return 'Tok'
  def at_least_vspace(self): return self

abstract_token: AbstractToken = AbstractToken()


def raise_to_shaped(aval: AbstractValue, weak_type=None):
  """Maps `aval` to its shaped counterpart via `raise_to_shaped_mappings`.

  The MRO of `type(aval)` is searched, so ConcreteArray uses the ShapedArray
  entry. `weak_type` defaults to `aval.weak_type` (False if absent).

  Raises:
    TypeError: if no type in the MRO has a mapping.
  """
  if weak_type is None:
    weak_type = getattr(aval, 'weak_type', False)
  for typ in type(aval).mro():
    handler = raise_to_shaped_mappings.get(typ)
    if handler: return handler(aval, weak_type)
  raise TypeError(type(aval))

raise_to_shaped_mappings : Dict[type, Callable] = {
  AbstractUnit: lambda aval, _: aval,
  AbstractToken: lambda aval, _: aval,
  ShapedArray: lambda aval, weak_type: ShapedArray(aval.shape, aval.dtype, weak_type=weak_type)
}

# Registry for valid dimension types. This is used by masking.Poly.
_DIMENSION_TYPES: Set[type] = {int}

def _canonicalize_dimension(dim):
  """Returns `dim` as-is for registered dimension types, else operator.index."""
  if type(dim) in _DIMENSION_TYPES:
    return dim
  else:
    return operator.index(dim)

def canonicalize_shape(shape):
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of integers.

  Raises:
    TypeError: if `shape` is not an iterable of integer-like dimensions. The
      message adds a `jit` hint when any dimension is an abstract tracer.
  """
  try:
    return tuple(map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  msg = ("Shapes must be 1D sequences of concrete values of integer type, "
         "got {}.")
  try:
    has_abstract_dim = any(
        isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
        and not isinstance(get_aval(x), ConcreteArray) for x in shape)
  except TypeError:
    # `shape` is not iterable at all (e.g. a bare int); still raise the
    # descriptive error below instead of a bare "not iterable" TypeError.
    has_abstract_dim = False
  if has_abstract_dim:
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  raise TypeError(msg.format(shape))


# ------------------- Call -------------------

def apply_todos(todos, outs):
  """Applies `todos` to `outs` last-first, full-lowering after each one."""
  todos_list = list(todos)
  while todos_list:
    outs = map(full_lower, todos_list.pop()(outs))
  return outs

class _IgnoreElemList(list):
  """Compares equal to all other _ignore_elem_lists."""
  def __hash__(self): return 0
  def __eq__(self, other):
    return type(other) is _IgnoreElemList
  def __ne__(self, other):
    # list.__ne__ would compare contents; keep `!=` consistent with `__eq__`.
    return not self == other

@lu.transformation_with_aux
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
                       level: Optional[int], params_tuple: tuple,
                       out_axes_transforms, *args):
  """Post-processes outputs that belong to traces above `level`.

  After the wrapped function returns `outs`, repeatedly picks the
  highest-level trace among output tracers whose level exceeds `level` (any
  tracer if `level` is None), raises all outputs into it and calls
  `primitive.post_process`, collecting the returned todos. For map
  primitives, post_process also returns an out-axes transform, which is
  appended to `out_axes_transforms`. The aux output is the tuple of todos,
  later consumed by `apply_todos`.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  todo = []
  assert not out_axes_transforms
  while True:
    tracers = [x for x in outs if isinstance(x, Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = primitive.post_process(trace, outs, params)
    if isinstance(primitive, MapPrimitive):
      cur_todo, out_axes_transform = cur_todo
      out_axes_transforms.append(out_axes_transform)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable

def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
              fun, *args, **params):
  """Binds a call or map primitive applied to `fun` and `args`.

  Raises `args` into the top trace, runs `primitive.process` under
  `maybe_new_sublevel`, then applies the todos recorded by
  `process_env_traces` and full-lowers the results. For map primitives,
  `params['out_axes_thunk']` is wrapped so that it also applies the out-axes
  transforms gathered during post-processing.
  """
  out_axes_transforms = _IgnoreElemList()
  if primitive.map_primitive:
    out_axes_thunk = params['out_axes_thunk']
    # The new thunk depends deterministically on the old thunk and the wrapped function.
    # Any caching already has to include the wrapped function as part of the key, so we
    # only use the previous thunk for equality checks.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      for t in out_axes_transforms:
        out_axes = t(out_axes)
      return out_axes
    params = dict(params, out_axes_thunk=new_out_axes_thunk)
  params_tuple = tuple(params.items())
  top_trace = find_top_trace(args)
  fun, env_trace_todo = process_env_traces(
      fun, primitive, top_trace and top_trace.level,
      params_tuple, out_axes_transforms)
  tracers = map(top_trace.full_raise, args)
  with maybe_new_sublevel(top_trace):
    outs = primitive.process(top_trace, fun, tracers, params)
  return map(full_lower, apply_todos(env_trace_todo(), outs))


class CallPrimitive(Primitive):
  """Primitive that calls a function; dispatches to `trace.process_call`."""
  multiple_results = True
  call_primitive = True

  def bind(self, fun, *args, **params):
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    return trace.process_call(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_call(self, out_tracers, params)

def call_impl(f: lu.WrappedFun, *args, **params):
  """Default impl: calls `f` on `args`, ignoring the primitive's params."""
  del params  # params parameterize the call primitive, not the function
  return f.call_wrapped(*args)

call_p = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)

named_call_p = CallPrimitive('named_call')
named_call_p.def_impl(call_impl)

# ------------------- Map -------------------

class MapPrimitive(Primitive):
  """Primitive that maps a function; dispatches to `trace.process_map`."""
  multiple_results = True
  map_primitive = True

  def bind(self, fun, *args, **params):
    # One `in_axes` entry is required per positional argument.
    assert len(params['in_axes']) == len(args)
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    return trace.process_map(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_map(self, out_tracers, params)

@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  """Pushes one AxisEnvFrame on the thread-local axis env for the block."""
  frame = AxisEnvFrame(axis_name, size, tag)
  thread_local_state.trace_state.axis_env.append(frame)
  try:
    yield
  finally:
    thread_local_state.trace_state.axis_env.pop()

@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
  """Pushes one untagged AxisEnvFrame per `(name, size)` pair for the block."""
  frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
  thread_local_state.trace_state.axis_env.extend(frames)
  try:
    yield
  finally:
    # Pop one entry per frame pushed (a `del env[-n:]` slice would remove
    # everything when n == 0).
    for _ in frames:
      thread_local_state.trace_state.axis_env.pop()


# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object.
# We don't want to use the function object itself because that might
# persist references to the function object.
# TODO(mattjj): revisit this unique axis name strategy
class _TempAxisName:
  """Axis name derived from `id(obj)`; equal iff the stored ids are equal."""

  def __init__(self, obj):
    self.id = id(obj)

  def __repr__(self):
    return f'<axis {hex(self.id)}>'

  def __hash__(self):
    return hash(self.id)

  def __eq__(self, other):
    return type(other) is _TempAxisName and self.id == other.id


def axis_frame(axis_name):
  """Returns the innermost axis-env frame named `axis_name`.

  Raises:
    NameError: if no frame matches. The message lists the bound names,
      excluding auto-generated `_TempAxisName`s.
  """
  frames = thread_local_state.trace_state.axis_env
  for frame in reversed(frames):
    if frame.name == axis_name:
      return frame
  named_axes = [frame.name for frame in reversed(frames)
                if not isinstance(frame.name, _TempAxisName)]
  raise NameError(
      f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
      f'by pmap) are available to collective operations: {named_axes}')


# ------------------- Jaxpr checking -------------------

def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
  """Removes dimension `axis` (which must have length `size`) from `aval`.

  abstract_unit passes through. Any other non-ShapedArray aval raises
  TypeError.
  """
  if aval is abstract_unit:
    return aval
  elif isinstance(aval, ShapedArray):
    # might be raising abstraction level from Concrete here
    assert aval.shape[axis] == size
    return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype)
  else:
    raise TypeError(f"Mapped operand {aval}")

def unmapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
  """Inserts a dimension of length `size` at position `axis` into `aval`.

  abstract_unit passes through. Any other non-ShapedArray aval raises
  TypeError.
  """
  if aval is abstract_unit:
    return aval
  elif isinstance(aval, ShapedArray):
    return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype)
  else:
    raise TypeError(f"Mapped output {aval}")

def typecheck(aval: AbstractValue, x) -> bool:
  """Returns whether value `x` conforms to `aval` (see `typecompat`)."""
  return typecompat(aval, get_aval(x))

def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`.

  Both sides are compared after raising to shaped and stripping weak types.
  A TypeError from `lattice_join` means the two are incompatible.
  """
  aval_ref = raise_to_shaped(aval_ref).strip_weak_type()
  try:
    return aval_ref == lattice_join(aval_ref, aval).strip_weak_type()
  except TypeError:
    return False

def typematch(aval1: UnshapedArray, aval2: UnshapedArray) -> bool:
  """Returns whether the avals are equal once raised to shaped, ignoring
  weak_type."""
  return raise_to_shaped(aval1, weak_type=False) == raise_to_shaped(aval2, weak_type=False)

class JaxprTypeError(TypeError): pass

def typecheck_assert(pred, msg):
  """Raises JaxprTypeError(msg) if `pred` is falsy."""
  if not pred:
    raise JaxprTypeError(msg)

# Primitive -> extra checker called with (*in_avals, **params) during checking.
custom_typechecks: Dict[Primitive, Callable] = {}

def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
  otherwise.
  """
  try:
    _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
  except JaxprTypeError as e:
    if len(e.args) == 2:
      # Error tagged with an equation index: print the equations around it.
      msg, eqn_idx = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10))
    else:
      msg, = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20))
    msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
    raise JaxprTypeError(msg) from None

def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
  """Type-checks `jaxpr` with its invars bound to `in_avals`.

  Raises:
    JaxprTypeError: `(msg,)` for errors outside equations, or
      `(msg, eqn_idx)` for an error in equation `eqn_idx`.
  """

  def read(v: Atom) -> AbstractValue:
    if isinstance(v, Literal):
      return raise_to_shaped(get_aval(v.val))
    else:
      typecheck_assert(v in env, f"Variable '{v}' not defined")
      return env[v]

  def write(v: Var, a: AbstractValue) -> None:
    typecheck_assert(v not in env, f"Variable '{v}' already bound")
    if v is not dropvar:
      typecheck_assert(typecompat(v.aval, a),
                       f"Variable '{v}' inconsistently typed as {a}, "
                       f"bound as {v.aval}")
      # `dropvar` is one shared object (compared by identity above), so it is
      # not recorded; otherwise a second dropped output would be reported as
      # "already bound".
      env[v] = a

  env : Dict[Var, AbstractValue] = {}

  write(unitvar, abstract_unit)
  map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
  map(write, jaxpr.invars, in_avals)

  for eqn_idx, eqn in enumerate(jaxpr.eqns):
    prim = eqn.primitive
    try:
      in_avals = map(read, eqn.invars)
      typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
                       "Equation given ConcreteArray type inputs")
      if prim in custom_typechecks:
        custom_typechecks[prim](*in_avals, **eqn.params)
      if prim.call_primitive:
        out_avals = check_call(prim, in_avals, eqn.params)
      elif prim.map_primitive:
        out_avals = check_map(prim, in_avals, eqn.params)
      else:
        out_avals = check_eqn(prim, in_avals, eqn.params)
      map(write, eqn.outvars, out_avals)
    except JaxprTypeError as e:
      # check_call/check_map recurse into `_check_jaxpr` directly, whose errors
      # carry `(msg, inner_eqn_idx)`; keep only the message and re-tag it
      # with this level's equation index.
      msg = e.args[0]
      src = source_info_util.summarize(eqn.source_info)
      msg = "\n\n".join([msg, "in equation:", str(pp_eqn(eqn).indent(2)),
                         f"from source: {src}"])
      raise JaxprTypeError(msg, eqn_idx) from None

  map(read, jaxpr.outvars)

def check_eqn(prim, in_avals, params):
  """Checks any jaxprs in `params`, then returns `prim`'s abstract outputs
  as a list."""
  for jaxpr in jaxprs_in_params(params):
    check_jaxpr(jaxpr)

  out_avals = prim.abstract_eval(*in_avals, **params)
  if not prim.multiple_results:
    out_avals = [out_avals]
  return out_avals

def check_call(prim, in_avals, params):
  """Type-checks a call-primitive application; returns the output avals."""
  typecheck_assert("call_jaxpr" in params,
                   f"Call primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]

  # These checks also happen in recursive call, but give better errors here.
  typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
                   f"Call primitive {prim} with {len(in_avals)} "
                   f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
                   f"inputs")
  binder_avals = [v.aval for v in call_jaxpr.invars]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Call primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  _check_jaxpr(call_jaxpr, in_avals)

  out_avals = [v.aval for v in call_jaxpr.outvars]
  return out_avals

def check_map(prim, in_avals, params):
  """Type-checks a map-primitive application; returns the output avals.

  Operands with a non-None in_axis must match the body's binder avals with a
  dimension of length `axis_size` inserted at that axis. The body is then
  checked on the mapped (axis-removed) operand avals, and outputs with a
  non-None out_axis get `axis_size` inserted at that axis.
  """
  typecheck_assert("call_jaxpr" in params,
                   f"Map primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]
  typecheck_assert("axis_size" in params,
                   f"Map primitive {prim} missing 'axis_size' parameter")
  axis_size = params["axis_size"]
  typecheck_assert("in_axes" in params,
                   f"Map primitive {prim} missing 'in_axes' parameter")
  in_axes = params["in_axes"]
  typecheck_assert("out_axes" in params,
                   f"Map primitive {prim} missing 'out_axes' parameter")
  out_axes = params["out_axes"]

  # Report arity mismatches as JaxprTypeErrors rather than letting safe_zip
  # fail with an AssertionError.
  typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
                   f"Map primitive {prim} with {len(in_avals)} "
                   f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
                   f"inputs")
  typecheck_assert(len(in_axes) == len(in_avals),
                   f"Map primitive {prim} has {len(in_axes)} in_axes "
                   f"for {len(in_avals)} operands")
  typecheck_assert(len(out_axes) == len(call_jaxpr.outvars),
                   f"Map primitive {prim} has {len(out_axes)} out_axes "
                   f"for {len(call_jaxpr.outvars)} jaxpr outputs")

  binder_avals = [unmapped_aval(axis_size, in_axis, v.aval)
                  if in_axis is not None else v.aval
                  for v, in_axis in zip(call_jaxpr.invars, in_axes)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Map primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  mapped_avals = [mapped_aval(axis_size, in_axis, aval)
                  if in_axis is not None else aval
                  for aval, in_axis in zip(in_avals, in_axes)]
  _check_jaxpr(call_jaxpr, mapped_avals)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [unmapped_aval(axis_size, out_axis, aval) if out_axis is not None else aval
               for aval, out_axis in zip(mapped_out_avals, out_axes)]
  return out_avals


# ------------------- Jaxpr printed representation -------------------

def pp_vars(vs: Sequence[Any], print_shapes: bool = False) -> str:
  """Space-joins `vs`, optionally as `name:type` using each var's aval."""
  if print_shapes:
    return ' '.join(f'{v}:{v.aval.str_short()}' for v in vs)
  else:
    return ' '.join(map(str, vs))

def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
  """Pretty-prints a primitive name with its params, omitting jaxpr-valued
  params and 'branches'."""
  filtered_params = {k: v for k, v in params.items()
                     if (k != 'branches' and
                         not isinstance(v, (Jaxpr, ClosedJaxpr)))}
  return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items()))

def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
  """Pretty-prints `outvars = prim[params] invars`.

  The right-hand side goes on a new indented line when the LHS text is
  longer than 6 characters and shapes are not printed.
  """
  lhs = pp_vars(eqn.outvars, print_shapes)
  pp_lhs = pp(f'{lhs} =')
  pp_rhs = (pp(eqn.primitive.name) >>
            pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
            pp(pp_vars(eqn.invars, print_shapes)))
  if len(lhs) <= 6 or print_shapes:
    return pp_lhs >> pp(' ') >> pp_rhs
  else:
    return pp_lhs + pp_rhs.indent(2)

def pp_eqns(eqns: Sequence[JaxprEqn],
            source_info: bool = False) -> Sequence[PrettyPrint]:
  """Pretty-prints each equation, optionally annotated with its source info.

  Annotations are placed at a column computed from the widest printed line.
  """
  pps = map(pp_eqn, eqns)
  if source_info:
    width = max((i + len(s) for x in pps for i, s in x.lines), default=None)
    if width is not None:
      return [p.annotate(width, source_info_util.summarize(e.source_info))
              for e, p in zip(eqns, pps)]
  return pps

def pp_jaxpr(jaxpr: Jaxpr, source_info: bool = False) -> PrettyPrint:
  """Pretty-prints `{ lambda consts ; invars. let eqns in outvars }`."""
  pps = pp_eqns(jaxpr.eqns, source_info=source_info)
  str_outvars = str(tuple(jaxpr.outvars))
  return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                         pp_vars(jaxpr.invars))) +
          ((pp('let ') >> vcat(pps))
           + pp('in {} }}'.format(str_outvars))).indent(2))

def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int,
                       source_info: bool = False) -> PrettyPrint:
  """Like `pp_jaxpr`, but prints only equations [lo, hi).

  The range is clamped to the jaxpr's equations, and '...' marks any
  omitted equations.
  """
  lo = max(lo, 0)
  hi = max(lo, min(hi, len(jaxpr.eqns)))
  eqns = jaxpr.eqns[lo:hi]
  pps = []
  if len(eqns) == 0 and len(jaxpr.eqns) != 0:
    pps.append(pp('...'))
  else:
    if lo != 0:
      pps.append(pp('...'))
    pps.extend(pp_eqns(eqns, source_info=source_info))
    if hi != len(jaxpr.eqns):
      pps.append(pp('...'))
  str_outvars = str(tuple(jaxpr.outvars))
  return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                         pp_vars(jaxpr.invars))) +
          ((pp('let ') >> vcat(pps))
           + pp('in {} }}'.format(str_outvars))).indent(2))

def pp_jaxprs(jaxprs) -> PrettyPrint:
  """Pretty-prints a sequence of (Closed)Jaxprs inside `( ... )`."""
  jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
  return pp('( ') >> vcat(map(pp_jaxpr, jaxprs)) >> pp(' )')

def pp_kv_pair(k, v):
  """Pretty-prints `k=v`; a tuple of jaxprs is printed with `pp_jaxprs`."""
  if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
    pp_v = pp_jaxprs(v)
  else:
    pp_v = pp(v)
  return pp(f'{k}=') >> pp_v

def pp_kv_pairs(kv_pairs):
  """Pretty-prints `[ k=v ... ]`, or an empty string when there are no pairs."""
  if kv_pairs:
    return pp('[ ') >> vcat([pp_kv_pair(k, v) for k, v in kv_pairs]) >> pp(' ]')
  else:
    return pp('')

@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebinds this module's tracing machinery to the non-omnistaging versions.

  Registered via `config.register_omnistaging_disabler` (presumably invoked
  when omnistaging is turned off -- confirm in config). It replaces the
  globals listed below and `Primitive.bind` with implementations that use an
  upward/downward TraceStack.
  """
  global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
      new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
      eval_context

  class TraceStack:
    """Two stacks of MainTraces: `upward` has levels 0, 1, ...; `downward`
    (pushed with bottom=True) has levels -1, -2, ..."""
    upward: List[MainTrace]
    downward: List[MainTrace]

    def __init__(self):
      self.upward = []
      self.downward = []

    def next_level(self, bottom: bool) -> int:
      if bottom:
        return - (len(self.downward) + 1)
      else:
        return len(self.upward)

    def push(self, main_trace: MainTrace, bottom: bool) -> None:
      if bottom:
        self.downward.append(main_trace)
      else:
        self.upward.append(main_trace)

    def pop(self, bottom: bool) -> None:
      if bottom:
        self.downward.pop()
      else:
        self.upward.pop()

    def __repr__(self) -> str:
      # Join the per-trace lines; formatting the (safe_)map list directly
      # would print a Python list repr with escaped newlines.
      return 'Trace stack\n{} ---\n{}'.format(
        ''.join(map(' {}\n'.format, self.upward[::-1])),
        ''.join(map(' {}\n'.format, self.downward)))

    def copy(self):
      new = TraceStack()
      new.upward = self.upward[:]
      new.downward = self.downward[:]
      return new

  class TraceState:
    trace_stack: TraceStack
    substack: List[Sublevel]
    initial_style: bool

    def __init__(self) -> None:
      self.trace_stack = TraceStack()  # type: ignore
      self.substack = [Sublevel(0)]
      self.initial_style = False

    def copy(self):
      new = TraceState()
      new.trace_stack = self.trace_stack.copy()
      new.substack = self.substack[:]
      new.initial_style = self.initial_style
      return new

  thread_local_state = ThreadLocalState()

  def reset_trace_state() -> bool:
    "Reset the global trace state and return True if it was already clean."
    if (thread_local_state.trace_state.substack != [Sublevel(0)] or
        thread_local_state.trace_state.trace_stack.downward or
        thread_local_state.trace_state.trace_stack.upward):
      thread_local_state.trace_state.__init__()  # type: ignore
      return False
    else:
      return True

  @contextmanager
  def new_main(trace_type: Type[Trace], bottom=False, **payload) -> Generator[MainTrace, None, None]:
    """Pushes a new MainTrace for the duration of the block.

    When leak checking is enabled, raises if the MainTrace is still
    referenced after the block exits.
    """
    level = thread_local_state.trace_state.trace_stack.next_level(bottom)
    main = MainTrace(level, trace_type, **payload)
    thread_local_state.trace_state.trace_stack.push(main, bottom)

    try:
      yield main
    finally:
      thread_local_state.trace_state.trace_stack.pop(bottom)

    if debug_state.check_leaks:
      t = ref(main)
      del main
      if t() is not None:
        print(thread_local_state.trace_state.trace_stack)
        raise Exception('Leaked trace {}'.format(t()))

  def find_top_trace(xs) -> Optional[Trace]:
    """Returns the highest-level trace among tracers in `xs`, or None."""
    top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
                    key=attrgetter('level'), default=None)
    return top_trace and top_trace.main.with_cur_sublevel()

  @contextmanager
  def eval_context():
    yield  # dummy implementation for forward compatibility

  def bind(self, *args, **kwargs):
    """Applies the primitive: `impl` with no tracers, else the top trace."""
    assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args
    top_trace = find_top_trace(args)
    if top_trace is None:
      return self.impl(*args, **kwargs)

    tracers = map(top_trace.full_raise, args)
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    if self.multiple_results:
      return map(full_lower, out_tracer)
    else:
      return full_lower(out_tracer)
  Primitive.bind = bind  # type: ignore

  def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
                fun: lu.WrappedFun, *args, **params):
    """Non-omnistaging call_bind: with no tracers among `args`, runs
    `primitive.impl` in a new sublevel instead of processing via a trace."""
    out_axes_transforms = _IgnoreElemList()
    if primitive.map_primitive:
      out_axes_thunk = params['out_axes_thunk']
      # The new thunk depends deterministically on the old thunk and the wrapped function.
      # Any caching already has to include the wrapped function as part of the key, so we
      # only use the previous thunk for equality checks.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        out_axes = out_axes_thunk()
        for t in out_axes_transforms:
          out_axes = t(out_axes)
        return out_axes
      params = dict(params, out_axes_thunk=new_out_axes_thunk)
    top_trace = find_top_trace(args)
    level = (thread_local_state.trace_state.trace_stack.next_level(True)
             if top_trace is None else top_trace.level)
    params_tuple = tuple(params.items())
    fun, env_trace_todo = process_env_traces(
        fun, primitive, level, params_tuple, out_axes_transforms)
    if top_trace is None:
      with new_sublevel():
        outs = primitive.impl(fun, *args, **params)
    else:
      tracers = map(top_trace.full_raise, args)
      outs = primitive.process(top_trace, fun, tracers, params)
    return apply_todos(env_trace_todo(), map(full_lower, outs))

  @contextmanager
  def extend_axis_env(axis_name, size: int, tag: Any):
    yield  # no axis env is tracked without omnistaging

  @contextmanager
  def initial_style_staging():
    """Sets `initial_style` on the trace state for the block, then restores
    the previous value."""
    trace_state = thread_local_state.trace_state
    prev, trace_state.initial_style = trace_state.initial_style, True
    try:
      yield
    finally:
      trace_state.initial_style = prev

# Casting float0 array to a float-valued zero array.
def zeros_like_float0(array, dtype=None):
  """Returns zeros shaped like `array` with `dtype` (default: Python float).

  Args:
    array: anything with a `.shape`; typically a float0 array.
    dtype: target dtype; None selects the builtin `float` (float64).
  """
  # `is None` rather than truthiness: some dtype objects are falsy.
  # `np.float` was only an alias of the builtin `float` and is removed in
  # NumPy >= 1.24.
  if dtype is None:
    dtype = float
  return np.zeros(array.shape, dtype)