1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import operator
17from operator import attrgetter
18from contextlib import contextmanager, suppress
19from collections import namedtuple
20from functools import total_ordering
21import itertools as it
22from weakref import ref
23import threading
24import types
25from typing import (Any, Callable, ClassVar, Dict, Generator,
26                    Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
27                    Type, Union, cast, Iterable, Hashable)
28
29import numpy as np
30
31from . import dtypes
32from .config import FLAGS, config
33from . import linear_util as lu
34
35from jax._src import source_info_util
36from ._src.util import (safe_zip, safe_map, partial, curry, prod, partialmethod,
37                   tuple_insert, tuple_delete, as_hashable_function,
38                   HashableFunction)
39from ._src.pprint_util import pp, vcat, PrettyPrint
40
41from ._src import traceback_util
# Register this module's file with traceback_util's exclusion list; presumably
# this hides frames from this JAX-internal file in user-facing tracebacks
# (TODO confirm against jax._src.traceback_util).
traceback_util.register_exclusion(__file__)
43
# Module-level switch read by the `assert skip_checks or ...` guards in this
# file: while True, those internal consistency checks short-circuit. It starts
# as the negation of FLAGS.jax_enable_checks and is temporarily forced to True
# by `skipping_checks()`.
# TODO(mattjj): move this into debug_state
skip_checks = not FLAGS.jax_enable_checks
46
@contextmanager
def skipping_checks():
  """Context manager for temporarily disabling internal checks.

  Forces the module-level ``skip_checks`` flag to True inside the ``with``
  block and puts the previous value back on exit, even if the body raises.
  """
  global skip_checks
  saved = skip_checks
  skip_checks = True
  try:
    yield
  finally:
    skip_checks = saved
56
@contextmanager
def checking_leaks():
  """Context manager for temporarily enabling tracer leak checks.

  Sets ``debug_state.check_leaks`` to True for the duration of the ``with``
  block (for the current thread, since ``debug_state`` is a
  ``threading.local``) and restores the prior value on exit, even on error.
  """
  previous = debug_state.check_leaks
  debug_state.check_leaks = True
  try:
    yield
  finally:
    debug_state.check_leaks = previous
65
class DebugState(threading.local):
  """Per-thread debugging switches.

  As a `threading.local` subclass, each thread sees its own attributes, which
  `__init__` initializes the first time the thread touches the object.
  """
  def __init__(self):
    # Read by `new_main` / `new_base_main` to decide whether to verify that
    # their MainTrace is not still referenced after the `with` block.
    self.check_leaks = FLAGS.jax_check_tracer_leaks
debug_state = DebugState()
70
# Shadow the builtins with the helpers from jax._src.util for the rest of this
# module. By their names, safe_zip/safe_map presumably also check that their
# arguments have equal lengths (TODO confirm). Note that code below relies on
# `map` returning a list (e.g. `subfuns + in_vals` in `eval_jaxpr`).
zip = safe_zip
map = safe_map
73
74
75# -------------------- jaxprs --------------------
76
class Jaxpr:
  """A first-order program: bound constants and inputs, equations, outputs.

  Attributes:
    constvars: variables introduced for constants. Array constants are
      replaced with such variables while scalar constants are kept inline.
    invars: input variables. Together, ``constvars`` and ``invars`` are the
      inputs to the Jaxpr.
    outvars: output atoms (variables or literals).
    eqns: the equations, in evaluation order.
  """
  constvars: List['Var']
  invars: List['Var']
  outvars: List['Atom']
  eqns: List['JaxprEqn']

  def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
               outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
    """Builds a Jaxpr; each argument is copied into a fresh list."""
    self.constvars = list(constvars)
    self.invars = list(invars)
    self.outvars = list(outvars)
    self.eqns = list(eqns)

  def __str__(self):
    return str(pp_jaxpr(self))

  __repr__ = __str__
102
103
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yields every Jaxpr held among the values of the ``params`` dict.

  A value may be a Jaxpr, a ClosedJaxpr (its ``.jaxpr`` is yielded), or a
  tuple whose elements are examined the same way. Tuples are unpacked only
  one level deep, and any other value is ignored.
  """
  groups = (v if isinstance(v, tuple) else (v,) for v in params.values())
  for candidate in it.chain.from_iterable(groups):
    if isinstance(candidate, Jaxpr):
      yield candidate
    elif isinstance(candidate, ClosedJaxpr):
      yield candidate.jaxpr
112
113
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Generator for all subjaxprs found in the params of jaxpr.eqns.

  Only the immediate level is visited: jaxprs nested inside the yielded
  subjaxprs are not searched.
  """
  for equation in jaxpr.eqns:
    for sub in jaxprs_in_params(equation.params):
      yield sub
121
122
class ClosedJaxpr:
  """A Jaxpr together with the values bound to its constvars.

  ``consts[i]`` is the value of ``jaxpr.constvars[i]``; the constructor
  asserts that the two have the same length.
  """
  jaxpr: Jaxpr
  consts: List['Any']

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    assert len(consts) == len(jaxpr.constvars)
    self.jaxpr = jaxpr
    self.consts = list(consts)

  @property
  def in_avals(self):
    """Avals of the wrapped jaxpr's input variables."""
    return [invar.aval for invar in self.jaxpr.invars]

  @property
  def out_avals(self):
    """Avals of the wrapped jaxpr's outputs (variables or literals)."""
    return [outvar.aval for outvar in self.jaxpr.outvars]

  @property
  def literals(self):
    """Backwards compatible alias for ``consts``."""
    return self.consts

  def map_jaxpr(self, f):
    """Returns a ClosedJaxpr over ``f(self.jaxpr)`` with the same const values."""
    return ClosedJaxpr(f(self.jaxpr), self.consts)

  def __str__(self):
    return str(self.jaxpr)

  def __repr__(self):
    return repr(self.jaxpr)
149
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  """Evaluates ``closed_jaxpr`` on ``args`` via ``eval_jaxpr``.

  Wrapped with ``curry`` from jax._src.util, presumably so that
  ``jaxpr_as_fun(closed_jaxpr)`` can be used as a function of the inputs.
  """
  jaxpr = closed_jaxpr.jaxpr
  consts = closed_jaxpr.consts
  return eval_jaxpr(jaxpr, consts, *args)
153
154
class JaxprEqn(NamedTuple):
  """One equation of a Jaxpr: `outvars = primitive[params](*invars)`."""
  invars: List['Atom']  # operands: variables or inline literals
  outvars: List['Var']  # bound results; `dropvar` marks an unused output
  primitive: 'Primitive'
  params: Dict[str, Any]  # primitive parameters; may contain sub-jaxprs
  # Passed to source_info_util.user_context when `eval_jaxpr` binds this eqn.
  source_info: Optional[source_info_util.Traceback]

  def __repr__(self): return str(pp_eqn(self)).rstrip()
163
def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
  """Constructs a JaxprEqn; ``source_info`` is optional and defaults to None."""
  return JaxprEqn(invars=invars, outvars=outvars, primitive=primitive,
                  params=params, source_info=source_info)
166
167
@total_ordering
class Var:
  """A jaxpr variable.

  Variables order by ``(count, suffix)`` and print as ``count`` written in
  base 26 with digits ``a``..``z`` (so 0 -> ``a``, 26 -> ``ba``), followed by
  ``suffix``.
  """
  # TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
  # by object id, but pretty printing might collide.
  count: int
  suffix: str
  aval: 'AbstractValue'

  def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
    self.count = count
    self.suffix = suffix
    self.aval = raise_to_shaped(aval)

  def __lt__(self, other):
    if isinstance(other, Var):
      return (self.count, self.suffix) < (other.count, other.suffix)
    return NotImplemented

  def __repr__(self):
    # Peel off base-26 digits least-significant first, then reverse.
    remaining = self.count
    letters = []
    while True:
      remaining, digit = divmod(remaining, 26)
      letters.append(chr(ord('a') + digit))
      if remaining == 0:
        break
    return ''.join(reversed(letters)) + self.suffix
196
197def _jaxpr_vars(jaxpr):
198  return it.chain(
199      jaxpr.invars, jaxpr.constvars,
200      (v for eqn in jaxpr.eqns for v in eqn.outvars))
201
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
           suffix: str = '') -> Callable[['AbstractValue'], Var]:
  """Produce distinct variables, printed with the optional suffix.

  Returns a function ``aval -> Var``; each call yields a Var with the next
  count from a private counter. If `jaxprs` is provided, counting starts one
  past the largest count of any variable bound at the top level of those
  jaxprs (as enumerated by ``_jaxpr_vars``), so the new variables are distinct
  from them.
  """
  first = 0
  if jaxprs is not None:
    existing_counts = (v.count for j in jaxprs for v in _jaxpr_vars(j))
    first = max(existing_counts, default=-1) + 1
  counter = it.count(start=first)

  def fresh(aval):
    return Var(next(counter), suffix, aval)
  return fresh
216
# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
  # Class attributes stand in for what Var.__init__ would set; __init__ is
  # overridden so no aval needs to be supplied.
  count = -1
  suffix = ''
  def __init__(self): pass
  # Always reports the unit aval, whatever output was dropped.
  @property
  def aval(self): return abstract_unit
  def __repr__(self): return '_'
# The shared instance used wherever an output binding is discarded.
dropvar = DropVar()
229
class Literal:
  """A constant value appearing inline as an operand (atom) of a jaxpr.

  ``hash`` caches a hash of ``val`` when one can be computed. The ``hash``
  slot is left *unset* (not None) when ``val`` is unhashable and its type is
  not in ``literalable_types``; ``__repr__`` relies on that to choose between
  printing the bare value and the ``Literal(val=...)`` form.
  """
  __slots__ = ["val", "hash"]

  val: Any
  hash: Optional[int]

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
    except TypeError:
      # Unhashable value: for registered literal types, fall back to hashing
      # the scalar `.item()` together with `.dtype` (presumably a numpy-style
      # array/scalar); store None if even that fails.
      if type(val) in literalable_types:
        try:
          self.hash = hash((val.item(), val.dtype))
        except (TypeError, AttributeError, ValueError):
          self.hash = None

  @property
  def aval(self):
    return raise_to_shaped(get_aval(self.val))

  # Hashing a Literal always fails: AssertionError normally, or TypeError under
  # `python -O` (the method then returns None). Callers presumably use the
  # cached `hash` attribute instead -- TODO confirm.
  def __hash__(self):
    assert False

  def __repr__(self):
    if hasattr(self, 'hash'):
      return '{}'.format(self.val)
    else:
      return 'Literal(val={})'.format(self.val)
259
# Types whose unhashable instances `Literal.__init__` will still try to hash
# via `(val.item(), val.dtype)`. Populated by registration (this file adds
# `Unit`).
literalable_types: Set[type] = set()

# A jaxpr operand: either a variable or an inline literal.
Atom = Union[Var, Literal]
263
class Primitive:
  """A named operation that can appear in jaxprs.

  Rules are attached with ``def_impl`` (concrete evaluation),
  ``def_abstract_eval`` (abstract evaluation) and ``def_custom_bind``; each
  stores the function on the instance and returns it, so they can be used as
  decorators.
  """
  name: str
  multiple_results = False  # set for multi-output primitives
  call_primitive = False    # set for call primitives processed in final style
  map_primitive = False     # set for map primitives processed in final style

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return f'{self.name}'

  def bind(self, *args, **params):
    """Applies the primitive to ``args`` using the top trace among them."""
    assert skip_checks or all(isinstance(a, Tracer) or valid_jaxtype(a)
                              for a in args), args
    trace = find_top_trace(args)
    raised_args = map(trace.full_raise, args)
    result = trace.process_primitive(self, raised_args, params)
    if not self.multiple_results:
      return full_lower(result)
    return map(full_lower, result)

  def def_impl(self, impl):
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    self.abstract_eval = abstract_eval
    return abstract_eval

  def def_custom_bind(self, bind):
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    raise NotImplementedError(
        f"Evaluation rule for '{self.name}' not implemented")

  def abstract_eval(self, *args, **params):
    raise NotImplementedError(
        f"Abstract evaluation for '{self.name}' not implemented")
304
305
306# -------------------- lifting --------------------
307
308# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
309# core.py. Plan to move all these utilities to jaxpr.py.
def extract_call_jaxpr(
  primitive: Primitive,
  params: Dict[str, Any]) -> Tuple[Optional[Jaxpr], Dict[str, Any]]:
  """Extract the call primitive subjaxpr from the params.

  Returns the subjaxpr and a new params dict without the "call_jaxpr" entry
  (the input dict is not modified). If this is not a call or map primitive,
  returns (None, params) with the original dict.
  """
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
    remaining = {k: v for k, v in params.items() if k != "call_jaxpr"}
    return (params["call_jaxpr"], remaining)
  return (None, params)
325
326
def traverse_jaxpr_params(f, params):
  """Applies f to each jaxpr parameter and returns a tuple of returned values.

  Only values whose exact type is Jaxpr or ClosedJaxpr are visited (for a
  ClosedJaxpr, ``f`` receives its inner jaxpr); results follow the order of
  ``params``.
  """
  results = []
  for param in params.values():
    kind = type(param)
    if kind is Jaxpr:
      results.append(f(param))
    elif kind is ClosedJaxpr:
      results.append(f(param.jaxpr))
  return tuple(results)
332
333
def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
  """Evaluates `jaxpr` by binding each equation's primitive in order.

  Args:
    jaxpr: the program to evaluate.
    consts: values for `jaxpr.constvars`, in order.
    *args: values for `jaxpr.invars`, in order.

  Returns:
    A list with the value of each of `jaxpr.outvars`.

  Every equation goes through `Primitive.bind`, which dispatches on the top
  trace among its arguments, so this also works (re-tracing the jaxpr) when
  the inputs are tracers.
  """
  def read(v):
    # Literals carry their value inline; variables are looked up in `env`.
    if type(v) is Literal:
      return v.val
    else:
      return env[v]

  def write(v, val):
    env[v] = val

  env: Dict[Var, Any] = {}
  # `unitvar` always reads as `unit`.
  write(unitvar, unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.invars)
    call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      # Call/map primitives receive their body as a leading WrappedFun that
      # evaluates the sub-jaxpr recursively (with no consts).
      subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
    else:
      subfuns = []
    if eqn.primitive.map_primitive:
      # Map primitives are bound with a thunk yielding `out_axes` rather than
      # `out_axes` itself; `closure=` presumably makes the thunk hash/compare
      # by that value (see jax._src.util.HashableFunction).
      out_axes_thunk = HashableFunction(lambda: params['out_axes'],
                                        closure=params['out_axes'])
      bind_params = dict(params, out_axes_thunk=out_axes_thunk)
      del bind_params['out_axes']
    else:
      bind_params = params
    with source_info_util.user_context(eqn.source_info):
      ans = eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)
369
370
371# -------------------- tracing --------------------
372
373
class Trace:
  """Base class for an interpreter that processes primitives on its tracers.

  A Trace pairs a MainTrace (which fixes `level`) with a `sublevel`.
  Subclasses must implement `pure`, `lift`, `sublift` and
  `process_primitive`, and override the `process_*` methods for each kind of
  higher-order primitive they support (the defaults raise
  NotImplementedError).
  """
  __slots__ = ['main', 'level', 'sublevel']

  main: 'MainTrace'
  level: int
  sublevel: 'Sublevel'

  def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

  def full_raise(self, val) -> 'Tracer':
    """Converts `val` into a tracer of this trace.

    Non-tracers are wrapped with `pure`. A tracer of the same main is
    returned as is at the same sublevel, or `sublift`ed from a lower one. A
    tracer from a lower level is `lift`ed, provided its sublevel does not
    exceed ours. Every other case raises UnexpectedTracerError (built by
    `escaped_tracer_error`).
    """
    if not isinstance(val, Tracer):
      return self.pure(val)
    val._assert_live()
    level = self.level
    sublevel = self.sublevel
    if val._trace.main is self.main:
      # Same main: only moving up in sublevel is allowed.
      if val._trace.sublevel == sublevel:
        return val
      elif val._trace.sublevel < sublevel:
        return self.sublift(val)
      else:
        raise escaped_tracer_error(
            val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
    elif val._trace.level < level:
      # Tracer of an enclosing (lower-level) trace.
      if val._trace.sublevel > sublevel:
        raise escaped_tracer_error(
            val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
      return self.lift(val)
    elif val._trace.level > level:
      raise escaped_tracer_error(
          val, f"Can't lift level {val} to {self}")
    else:  # val._trace.level == self.level:
      # Same level but a different main object.
      raise escaped_tracer_error(
          val, f"Different traces at same level: {val}, {self}")

  # Wraps a non-tracer value as a tracer of this trace.
  def pure(self, val):
    raise NotImplementedError("must override")

  # Wraps a tracer from a lower-level trace.
  def lift(self, tracer):
    raise NotImplementedError("must override")

  # Wraps a tracer of this same main from a lower sublevel.
  def sublift(self, tracer):
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_call to handle call-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_map(self, call_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_map to handle map-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    msg = (f"{type(self)} must override process_custom_jvp_call "
           "to handle custom_jvp primitives")
    raise NotImplementedError(msg)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    msg = (f"{type(self)} must override process_custom_vjp_call "
           "to handle custom_vjp primitives")
    raise NotImplementedError(msg)
447
def escaped_tracer_error(tracer, detail=None):
  """Builds (does not raise) an UnexpectedTracerError for a leaked tracer.

  The message optionally includes `detail`, where the tracer was created (if
  it recorded `_line_info`), including up to
  FLAGS.jax_tracer_error_num_traceback_frames stack frames, and the source
  info of the function being traced (if its main trace has one).
  """
  num_frames = FLAGS.jax_tracer_error_num_traceback_frames
  pieces = ["Encountered an unexpected tracer. Perhaps this tracer escaped "
            "through global state from a previously traced function.\n"
            "The functions being transformed should not save traced values to "
            "global state."]
  if detail:
    pieces.append(" Detail: {}.".format(detail))

  not_found = object()
  line_info = getattr(tracer, '_line_info', not_found)
  if line_info is not not_found:
    pieces.append('\nThe tracer that caused this error was created on line '
                  f'{source_info_util.summarize(line_info)}.\n')
    if num_frames > 0:
      pieces.append(
          f'When the tracer was created, the final {num_frames} stack '
          'frames (most recent last) excluding JAX-internal frames were:\n'
          f'{source_info_util.summarize(line_info, num_frames=num_frames)}')

  try:
    fun_source_info = tracer._trace.main.source_info
  except AttributeError:
    pass
  else:
    pieces.append('\nThe function being traced when the tracer leaked was '
                  f'{fun_source_info}.')

  pieces.append('\nTo catch the leak earlier, try setting the environment '
                'variable JAX_CHECK_TRACER_LEAKS or using the '
                '`jax.checking_leaks` context manager.')
  return UnexpectedTracerError(''.join(pieces))
478
# Error for a tracer used outside the trace it belongs to (built by
# `escaped_tracer_error` and raised e.g. from `Trace.full_raise`).
class UnexpectedTracerError(Exception): pass
480
class Tracer:
  """Base class for the values that stand in for arrays during tracing.

  A Tracer belongs to the trace `_trace`, and subclasses must provide its
  abstract value via the `aval` property. Python operators, `len`, iteration
  and unknown attribute lookups are all forwarded to that aval.
  """
  __array_priority__ = 1000
  __slots__ = ['_trace', '__weakref__', '_line_info']

  def __array__(self, *args, **kw):
    msg = ("The numpy.ndarray conversion method __array__() was called on "
           f"the JAX Tracer object {self}.\n\n"
           "This error can occur when a JAX Tracer object is passed to a raw "
           "numpy function, or a method on a numpy.ndarray object. You might "
           "want to check that you are using `jnp` together with "
           "`import jax.numpy as jnp` rather than using `np` via "
           "`import numpy as np`. If this error arises on a line that involves "
           "array indexing, like `x[idx]`, it may be that the array being "
           "indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
           "JAX Tracer instance; in that case, you can instead write "
           "`jax.device_put(x)[idx]`.")
    raise Exception(msg)

  def __init__(self, trace: Trace):
    self._trace = trace

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __len__(self):
    return self.aval._len(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def _assert_live(self) -> None:
    pass  # Override for liveness checking

  # Python looks up special methods only on classes, not instances. This means
  # these methods need to be defined explicitly rather than relying on
  # __getattr__. (__nonzero__, __div__, __rdiv__, __long__, __hex__ and __oct__
  # are Python 2 protocol names that Python 3 never invokes implicitly; they
  # are kept for compatibility.)
  def __neg__(self): return self.aval._neg(self)
  def __pos__(self): return self.aval._pos(self)
  def __eq__(self, other): return self.aval._eq(self, other)
  def __ne__(self, other): return self.aval._ne(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __le__(self, other): return self.aval._le(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __ge__(self, other): return self.aval._ge(self, other)
  def __abs__(self): return self.aval._abs(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __sub__(self, other): return self.aval._sub(self, other)
  def __rsub__(self, other): return self.aval._rsub(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __div__(self, other): return self.aval._div(self, other)
  def __rdiv__(self, other): return self.aval._rdiv(self, other)
  def __truediv__(self, other): return self.aval._truediv(self, other)
  def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
  def __floordiv__(self, other): return self.aval._floordiv(self, other)
  def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
  def __divmod__(self, other): return self.aval._divmod(self, other)
  def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
  def __mod__(self, other): return self.aval._mod(self, other)
  def __rmod__(self, other): return self.aval._rmod(self, other)
  def __pow__(self, other): return self.aval._pow(self, other)
  def __rpow__(self, other): return self.aval._rpow(self, other)
  def __matmul__(self, other): return self.aval._matmul(self, other)
  def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
  def __and__(self, other): return self.aval._and(self, other)
  def __rand__(self, other): return self.aval._rand(self, other)
  def __or__(self, other): return self.aval._or(self, other)
  def __ror__(self, other): return self.aval._ror(self, other)
  def __xor__(self, other): return self.aval._xor(self, other)
  def __rxor__(self, other): return self.aval._rxor(self, other)
  def __invert__(self): return self.aval._invert(self)
  def __lshift__(self, other): return self.aval._lshift(self, other)
  def __rlshift__(self, other): return self.aval._rlshift(self, other)
  def __rshift__(self, other): return self.aval._rshift(self, other)
  def __rrshift__(self, other): return self.aval._rrshift(self, other)
  def __getitem__(self, idx): return self.aval._getitem(self, idx)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __bool__(self): return self.aval._bool(self)
  def __int__(self): return self.aval._int(self)
  def __long__(self): return self.aval._long(self)
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)
  def __float__(self): return self.aval._float(self)
  def __complex__(self): return self.aval._complex(self)

  def __setitem__(self, idx, val):
    raise TypeError("JAX 'Tracer' objects do not support item assignment")

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    """Forwards unknown attribute lookups to the aval.

    An `aval_property` found on the aval is evaluated with this tracer, an
    `aval_method` is bound to this tracer, and anything else is returned as
    is. Missing attributes raise AttributeError naming the tracer class.
    """
    # if the aval property raises an AttributeError, gets caught here
    assert skip_checks or name != "aval"

    try:
      attr = getattr(self.aval, name)
    except (AttributeError, KeyError) as err:
      # getattr() signals a missing attribute with AttributeError; re-raise it
      # against the user-visible tracer type rather than the internal aval
      # type. KeyError is still translated, as the previous version did.
      raise AttributeError(
          "{} has no attribute {}".format(self.__class__.__name__, name)
      ) from err
    t = type(attr)
    if t is aval_property:
      return attr.fget(self)
    elif t is aval_method:
      return types.MethodType(attr.fun, self)
    else:
      return attr

  def __repr__(self):
    base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
    contents = self._contents()
    if contents:
      base += pp('  with ') >> vcat(pp('{} = '.format(name)) >> pp_payload
                                    for name, pp_payload in contents)
    return str(base)

  def _contents(self):
    # (name, pretty-printed value) for each slot of the concrete subclass, or
    # () if any slot is unset.
    try:
      return [(name, pp(repr(getattr(self, name)))) for name in self.__slots__]
    except AttributeError:
      return ()

  # Tracers are never copied: copy/deepcopy return the same object.
  def __copy__(self):
    return self

  def __deepcopy__(self, unused_memo):
    return self

  def _origin_msg(self) -> str:
    return ""
614    return ""
615
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals.
# When Tracer.__getattr__ finds an `aval_property(fget)` on the aval it returns
# `fget(tracer)`; for an `aval_method(fun)` it returns `fun` bound to the
# tracer via `types.MethodType`.
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
620
621
class EvalTrace(Trace):
  """Trace that evaluates eagerly by calling each primitive's `impl`.

  Values are used as they are: `pure`, `lift` and `sublift` are identities,
  and custom_jvp/custom_vjp calls simply run the primal function.
  """
  # See comments in https://github.com/google/jax/pull/3370

  def pure(self, x):
    return x

  lift = pure
  sublift = pure

  def process_primitive(self, primitive, tracers, params):
    return primitive.impl(*tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    return primitive.impl(f, *tracers, **params)

  process_map = process_call

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    del primitive, jvp  # the JVP rule is irrelevant for plain evaluation
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    del primitive, fwd, bwd, out_trees  # only the primal function is needed
    return fun.call_wrapped(*tracers)
641
642
class MainTrace:
  """One entry of the trace stack: a level, a Trace subclass and a payload.

  Equality compares level, trace type and payload; the hash uses only level
  and trace type, which is consistent with that equality.
  """
  level: int
  trace_type: Type[Trace]
  payload: Dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level = level
    self.trace_type = trace_type
    self.payload = payload

  def __repr__(self) -> str:
    return f"MainTrace({self.level},{self.trace_type.__name__})"

  def __hash__(self) -> int:
    return hash((self.level, self.trace_type))

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, MainTrace):
      return False
    return (self.level == other.level and
            self.trace_type == other.trace_type and
            self.payload == other.payload)

  def with_cur_sublevel(self):
    """Instantiates `trace_type` for this main at the current sublevel."""
    return self.trace_type(self, cur_sublevel(), **self.payload)
667
class TraceStack:
  """Stack of MainTraces plus the current "dynamic" main.

  `stack[0]` is the base main (initially a level-0 EvalTrace) and
  `next_level()` is the current stack height. `dynamic` starts out as that
  same base main.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack: List['MainTrace']
  dynamic: 'MainTrace'

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    return len(self.stack)

  def push(self, main_trace: 'MainTrace') -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    # One indented line per main, innermost (most recently pushed) first. The
    # old code interpolated a list here, printing brackets and '\n' escapes.
    stack_str = ''.join(f'  {main}\n' for main in reversed(self.stack))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    """Returns a copy with its own stack list (the mains themselves are shared)."""
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new
696
# An entry of TraceState.substack (new_sublevel creates it from the substack's
# length). Being an int subclass, sublevels compare as integers, as
# Trace.full_raise requires.
class Sublevel(int): pass
# One frame of TraceState.axis_env: an axis name, its size, and the associated
# MainTrace.
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
# Axis names may be any hashable value.
AxisName = Hashable
700
class TraceState:
  """The mutable tracing state: trace stack, sublevel stack and axis env."""
  trace_stack: TraceStack
  substack: List[Sublevel]
  axis_env: List[AxisEnvFrame]

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    """Returns a copy whose stacks can be mutated independently.

    The copies are shallow: the MainTrace, Sublevel and AxisEnvFrame entries
    themselves are shared with this state.
    """
    clone = self.__new__(TraceState)
    clone.trace_stack = self.trace_stack.copy()
    clone.substack = list(self.substack)
    clone.axis_env = list(self.axis_env)
    return clone
717
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
  def __init__(self):
    # Run once per thread on first access, so every thread starts with its
    # own fresh TraceState.
    self.trace_state = TraceState()
thread_local_state = ThreadLocalState()
725
def trace_state_clean() -> bool:
  """Returns whether this thread's trace state is as freshly initialized.

  That is: the substack holds only sublevel 0, the axis env is empty, and the
  trace stack holds only a level-0 EvalTrace main that is also the dynamic
  main.
  """
  state = thread_local_state.trace_state
  stack = state.trace_stack
  base = MainTrace(0, EvalTrace)
  return (state.substack == [Sublevel(0)]
          and state.axis_env == []
          and stack.stack == [base]
          and stack.dynamic == base)
732
def reset_trace_state() -> bool:
  """Reset the global trace state and return True if it was already clean."""
  if trace_state_clean():
    return True
  # Re-run __init__ on the existing object so anything holding a reference to
  # this thread's TraceState sees the reset.
  thread_local_state.trace_state.__init__()  # type: ignore
  return False
740
def cur_sublevel() -> Sublevel:
  """Returns the most recently pushed sublevel of this thread's trace state."""
  return thread_local_state.trace_state.substack[-1]
743
@contextmanager
def new_main(trace_type: Type[Trace],
             dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  """Pushes a new MainTrace of `trace_type` onto this thread's trace stack.

  The new main takes the next level (the current stack height) and carries
  `payload`. It is popped when the `with` block exits. If `dynamic` is True it
  is also the stack's dynamic main inside the block, and the previous dynamic
  main is restored on exit.

  If `debug_state.check_leaks` is set and the block exits normally, raises an
  Exception when the MainTrace is still referenced after the block.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main

  try:
    yield main
  finally:
    stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic

  # Only reached on normal exit (an exception from the body propagates out of
  # the `finally` above).
  if debug_state.check_leaks:
    # Drop our own reference; if the main is still alive, something from the
    # finished trace kept hold of it.
    t = ref(main)
    del main
    if t() is not None:
      raise Exception(f'Leaked trace {t()}')
768
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
  """Temporarily installs a fresh level-0 MainTrace of `trace_type`.

  Inside the `with` block the new main replaces `stack[0]` of this thread's
  trace stack and is also the dynamic main. Both previous values are restored
  on exit. If `debug_state.check_leaks` is set and the block exits normally,
  raises an Exception when the MainTrace is still referenced afterwards.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  main = MainTrace(0, trace_type)
  prev_dynamic, stack.dynamic = stack.dynamic, main
  prev_base, stack.stack[0] = stack.stack[0], main
  try:
    yield main
  finally:
    stack.dynamic = prev_dynamic
    stack.stack[0] = prev_base

  if debug_state.check_leaks:
    # Drop our own reference and check that nothing else keeps `main` alive.
    t = ref(main)
    del main
    if t() is not None:
      raise Exception('Leaked trace {}'.format(t()))
787
@contextmanager
def eval_context():
  """Runs the body with a fresh EvalTrace main as the base and dynamic main."""
  with new_base_main(EvalTrace):
    yield
792
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  """Pushes the next sublevel for the duration of the `with` block.

  The new sublevel equals the current length of this thread's substack. It is
  popped on exit, even if the body raises.
  """
  sublevel = Sublevel(len(thread_local_state.trace_state.substack))
  thread_local_state.trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    thread_local_state.trace_state.substack.pop()

  # TODO(mattjj): to check sublevel leaks, we need to make Sublevel weakref-able
  # if debug_state.check_leaks:
  #   t = ref(sublevel)
  #   del sublevel
  #   if t() is not None:
  #     raise Exception('Leaked sublevel {}'.format(t()))
808
def maybe_new_sublevel(trace):
  """Returns a context manager pushing a new sublevel iff `trace` is dynamic.

  Dynamic traces run the WrappedFun, so they get a raised sublevel. For any
  other trace this returns `suppress()`, which does nothing.
  """
  dynamic_main = thread_local_state.trace_state.trace_stack.dynamic
  if trace.main is dynamic_main:
    return new_sublevel()
  return suppress()
813
def full_lower(val):
  """Calls `val.full_lower()` on Tracers; returns any other value unchanged."""
  return val.full_lower() if isinstance(val, Tracer) else val
819
def find_top_trace(xs) -> Trace:
  """Returns the trace that should process an operation on the values `xs`.

  Takes the MainTrace of the highest-level Tracer among `xs` (after checking
  that tracer is live). If there are no Tracers, or the thread's dynamic main
  has a strictly higher level, the dynamic main is used instead. The chosen
  main is instantiated at the current sublevel.
  """
  top_tracer = max((x for x in xs if isinstance(x, Tracer)),
                    default=None, key=attrgetter('_trace.level'))
  if top_tracer is not None:
    top_tracer._assert_live()
    top_main = top_tracer._trace.main  # type: Optional[MainTrace]
  else:
    top_main = None
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  top_main = (dynamic if top_main is None or dynamic.level > top_main.level
              else top_main)
  # `dynamic` is always a MainTrace in the code visible here, so top_main is
  # not None at this point; the `and` only guards the type checker's view.
  return top_main and top_main.with_cur_sublevel()  # type: ignore
832
833
834# -------------------- abstract values --------------------
835
836
class AbstractValue:
  """Base class for abstract values (avals).

  Subclasses must implement `at_least_vspace` and `join`; the defaults raise
  NotImplementedError("must override").
  """
  __slots__: List[str] = []
  _num_buffers: int = 1  # number of buffers used to represent the value.

  def at_least_vspace(self):
    raise NotImplementedError("must override")

  def __repr__(self):
    # Instances with a __dict__ print as `Name(k=v,...)`; slot-only instances
    # (no __dict__) print as just the class name.
    try:
      fields = ','.join('{}={}'.format(k, v) for k, v in self.__dict__.items())
    except AttributeError:
      return self.__class__.__name__
    return '{}({})'.format(self.__class__.__name__, fields)

  def strip_weak_type(self) -> 'AbstractValue':
    return self

  def join(self, other):
    raise NotImplementedError("must override")
856
# Presumably the bottom element of the abstract-value lattice (TODO confirm;
# it is not used elsewhere in the visible part of this file).
class Bot(AbstractValue): pass

bot = Bot()
860
class AbstractUnit(AbstractValue):
  """The abstract value of `unit`, printed as `*`.

  It is its own vector space and joins only with itself (checked unless
  `skip_checks` is set).
  """
  # TODO(jakevdp): make it possible to set zero buffers
  # _num_buffers = 0

  def at_least_vspace(self):
    return self

  def join(self, other):
    assert skip_checks or other is abstract_unit, other
    return self

  def _eq(self, self_traced, other):
    return get_aval(other) is self

  def str_short(self):
    return '*'

abstract_unit = AbstractUnit()
873
def lattice_join(x: Optional[AbstractValue],
                 y: Optional[AbstractValue]) -> AbstractValue:
  """Joins two avals, treating None as the identity.

  The join is computed by the operand whose class is the more general one
  (the one the other operand is an instance of). Raises TypeError((x, y))
  when neither class subsumes the other.
  """
  if x is None:
    return cast(AbstractValue, y)
  if y is None:
    return cast(AbstractValue, x)
  if isinstance(x, type(y)):
    return y.join(x)
  if isinstance(y, type(x)):
    return x.join(y)
  raise TypeError((x, y))
886
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
# It is an alias of Any, so it documents intent but is not checked.
Value = Any
889
def valid_jaxtype(x):
  """Returns whether `concrete_aval` accepts `x` without raising TypeError."""
  try:
    concrete_aval(x)
  except TypeError:
    return False
  return True
897
def check_valid_jaxtype(x):
  """Raises TypeError unless `x` is a valid JAX type (see `valid_jaxtype`)."""
  if valid_jaxtype(x):
    return
  raise TypeError(f"{x} of type {type(x)} is not a valid JAX type")
901
902
def concrete_aval(x):
  """Returns the aval of `x` using the first handler registered for a class in
  `type(x)`'s MRO; raises TypeError if no class has a handler."""
  for cls in type(x).mro():
    aval_fn = pytype_aval_mappings.get(cls)
    if not aval_fn:
      continue
    return aval_fn(x)
  raise TypeError(f"{type(x)} is not a valid JAX type")
908
909
def get_aval(x):
  """Returns `x.aval` for Tracers and `concrete_aval(x)` for anything else."""
  return x.aval if isinstance(x, Tracer) else concrete_aval(x)
915
916
# Registry from Python type to a function computing a value's concrete aval.
# `concrete_aval` consults it along the value type's MRO.
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
918
919
class Unit:
  """Type of the `unit` placeholder value, which prints as '*'."""

  def __repr__(self):
    return '*'

unit = Unit()
literalable_types.add(Unit)
924
class UnitVar(Var):
  """The variable standing for `unit`; its aval is always `abstract_unit`."""
  count = -1
  suffix = ''

  def __init__(self):
    # Intentionally does not call `Var.__init__`.
    pass

  @property
  def aval(self):
    return abstract_unit

  def __repr__(self):
    return '*'

unitvar = UnitVar()
933
pytype_aval_mappings.update({Unit: lambda _: abstract_unit})  # every `unit` has aval `abstract_unit`
935
class ConcretizationTypeError(TypeError):
  """Raised when a concrete value is needed but an abstract tracer was given."""
937
def raise_concretization_error(val: Tracer, context=""):
  """Always raises ConcretizationTypeError describing the abstract tracer `val`.

  The message includes `context`, the tracer's origin description, a link to
  the FAQ entry, and the tracer itself, separated by blank lines.
  """
  faq_url = ("https://jax.readthedocs.io/en/latest/faq.html"
             "#abstract-tracer-value-encountered-where-concrete-value-is-expected-error")
  sections = [
      "Abstract tracer value encountered where concrete value is expected.",
      context,
      val._origin_msg(),
      f"See {faq_url} for more information.",
      f"Encountered tracer value: {val}",
  ]
  raise ConcretizationTypeError("\n\n".join(sections))
945
946
def concretization_function_error(fun, suggest_astype=False):
  """Builds an aval method that always raises a concretization error.

  The returned `error(self, arg)` is installed as a method on aval classes
  (e.g. `UnshapedArray._bool`). It raises `ConcretizationTypeError` for
  tracer `arg`, mentioning `fun`.

  Args:
    fun: the Python conversion function (e.g. `bool`, `int`) that cannot be
      applied to an abstract value; only its name is used in the message.
    suggest_astype: if True, the message also suggests `astype`-style dtype
      conversion as an alternative.
  """
  fname = getattr(fun, "__name__", fun)
  fname_context = f"The problem arose with the `{fname}` function. "
  if suggest_astype:
    # Use `fname` rather than `fun.__name__`, so callables without a
    # `__name__` attribute don't raise AttributeError here.
    fname_context += ("If trying to convert the data type of a value, "
                      f"try using `x.astype({fname})` "
                      f"or `jnp.array(x, {fname})` instead.")

  def error(self, arg):
    raise_concretization_error(arg, fname_context)
  return error
957
958
def concrete_or_error(force: Any, val: Any, context=""):
  """Like force(val), but gives the context in the error message.

  A tracer `val` is accepted only if its aval is a `ConcreteArray`, in which
  case `force` is applied to the concrete value. Otherwise a concretization
  error mentioning `context` is raised. `force=None` means the identity.
  """
  convert = (lambda x: x) if force is None else force
  if not isinstance(val, Tracer):
    return convert(val)
  if isinstance(val.aval, ConcreteArray):
    return convert(val.aval.val)
  raise_concretization_error(val, context)
970
class UnshapedArray(AbstractValue):
  """Array aval that records a dtype and a weak-type flag, but no shape."""
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 2

  def __init__(self, dtype, weak_type=False):
    canonical = dtypes.canonicalize_dtype(dtype)
    self.dtype = np.dtype(canonical)
    self.weak_type = weak_type

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return self.dtype == other.dtype and self.weak_type == other.weak_type

  def __ne__(self, other):
    return not (self == other)

  def __hash__(self):
    # Hashing the dtype object directly works because numpy reuses base dtype
    # objects (e.g. `np.zeros(3).dtype is np.zeros(4).dtype`); hashing
    # `self.dtype.char` would be an alternative.
    return hash((self.dtype, self.weak_type))

  def __repr__(self):
    suffix = ", weak_type=True" if self.weak_type else ""
    return f'{self.__class__.__name__}({self.str_short()}{suffix})'

  # Python-level conversions can't be evaluated on an abstract array; each of
  # these raises a concretization error naming the conversion.
  _bool = _nonzero = concretization_function_error(bool)
  _float = concretization_function_error(float, True)
  _int = concretization_function_error(int, True)
  _complex = concretization_function_error(complex, True)
  _hex = concretization_function_error(hex)
  _oct = concretization_function_error(oct)

  def at_least_vspace(self) -> AbstractValue:
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return UnshapedArray(tangent_dtype, self.weak_type)

  def join(self, other):
    # Same dtype is required; mismatched weak types join to a strong type.
    if self.dtype != other.dtype:
      raise TypeError(self, other)
    if self.weak_type != other.weak_type:
      return UnshapedArray(self.dtype, weak_type=False)
    return self

  def str_short(self) -> str:
    return self.dtype.name

  def strip_weak_type(self) -> 'UnshapedArray':
    """Returns a copy of the aval with weak_type=False."""
    if not self.weak_type:
      return self
    return UnshapedArray(self.dtype)

  @property
  def shape(self):
    msg = ("UnshapedArray has no shape. Please open an issue at "
           "https://github.com/google/jax/issues because it's unexpected for "
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)
1029
class ShapedArray(UnshapedArray):
  """Array aval with a known (canonicalized) shape, dtype, and weak-type flag."""
  __slots__ = ['shape']
  array_abstraction_level = 1

  def __init__(self, shape, dtype, weak_type=False):
    super().__init__(dtype, weak_type=weak_type)
    self.shape = canonicalize_shape(shape)

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    return prod(self.shape)

  broadcast: ClassVar[Optional[aval_method]] = None
  transpose: ClassVar[Optional[aval_method]] = None
  reshape: ClassVar[Optional[aval_method]] = None
  _iter: ClassVar[Optional[staticmethod]] = None

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return (self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    # As in UnshapedArray, hashing the dtype object relies on numpy reusing
    # base dtype objects; `self.dtype.char` would also work.
    return hash((self.shape, self.dtype, self.weak_type))

  def at_least_vspace(self):
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return ShapedArray(self.shape, tangent_dtype, self.weak_type)

  def join(self, other):
    # Equal shape and dtype keep the shape; equal dtype alone falls back to an
    # UnshapedArray; differing dtypes cannot be joined.
    same_shape = self.shape == other.shape
    same_dtype = self.dtype == other.dtype
    if not same_dtype:
      raise TypeError(self, other)
    if not same_shape:
      return UnshapedArray(self.dtype)
    if self.weak_type == other.weak_type:
      return self
    return ShapedArray(self.shape, self.dtype, weak_type=False)

  def str_short(self):
    dims = ','.join(str(d) for d in self.shape)
    return f'{self.dtype.name}[{dims}]'

  def __len__(self):
    try:
      leading_dim = self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error
    return leading_dim

  def _len(self, ignored_tracer):
    return len(self)

  def strip_weak_type(self):
    if not self.weak_type:
      return self
    return ShapedArray(self.shape, self.dtype)
1087
1088
def _forward_to_value(self, fun, ignored_tracer, *args):
  """Applies `fun` to the aval's concrete `val`, followed by any extra `args`.

  Bound via `partialmethod` on `ConcreteArray`; the tracer argument is unused.
  """
  concrete = self.val
  return fun(concrete, *args)
1091
class ConcreteArray(ShapedArray):
  """Shaped array aval that additionally carries the concrete value `val`."""
  __slots__ = ['val']
  array_abstraction_level = 0

  def __init__(self, val, weak_type=False):
    super().__init__(np.shape(val), np.result_type(val), weak_type=weak_type)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    self.val = val
    assert self.dtype != np.dtype('O'), val

  def __eq__(self, other):
    same_aval_type = (type(self) is type(other)
                      and self.dtype == other.dtype
                      and self.shape == other.shape
                      and self.weak_type == other.weak_type)
    if not same_aval_type:
      return False
    with eval_context():  # in case self.val is a DeviceArray
      return (self.val == other.val).all()

  def __hash__(self):
    # Hashes by the identity of `val`, not by its contents.
    return id(self.val)

  def at_least_vspace(self):
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return ShapedArray(self.shape, tangent_dtype, weak_type=self.weak_type)

  def join(self, other) -> UnshapedArray:
    # Identical concrete avals stay concrete; otherwise the value is dropped
    # and the result is weak only if both inputs are weak.
    if self == other:
      return self
    if self.shape == other.shape and self.dtype == other.dtype:
      return ShapedArray(self.shape, self.dtype,
                         weak_type=self.weak_type and other.weak_type)
    if self.dtype == other.dtype:
      return UnshapedArray(self.dtype,
                           weak_type=self.weak_type and other.weak_type)
    raise TypeError(self, other)

  def str_short(self) -> str:
    return str(self.val)

  def strip_weak_type(self) -> 'ConcreteArray':
    if not self.weak_type:
      return self
    return ConcreteArray(self.val)

  # These conversions can be evaluated on the concrete value...
  _bool = _nonzero = partialmethod(_forward_to_value, bool)
  _int = partialmethod(_forward_to_value, int)
  _hex = partialmethod(_forward_to_value, hex)
  _oct = partialmethod(_forward_to_value, oct)

  # ...while float/complex conversion still raise a concretization error.
  _float = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)
1143
def primal_dtype_to_tangent_dtype(primal_dtype):
  """Returns the tangent dtype for `primal_dtype`.

  Inexact (floating or complex) dtypes are their own tangent dtype; every
  other dtype maps to `dtypes.float0`.
  """
  is_inexact = dtypes.issubdtype(primal_dtype, np.inexact)
  return primal_dtype if is_inexact else dtypes.float0
1149
class AbstractToken(AbstractValue):
  """Aval of tokens; any two token avals join to the token aval."""

  def join(self, other):
    if isinstance(other, AbstractToken):
      return self
    # Raise explicitly instead of `assert False`, which is stripped under
    # `python -O` and would make this silently return None. TypeError matches
    # the other `join` implementations in this module.
    raise TypeError(f"Cannot join {self} with {other}")

  def str_short(self): return 'Tok'

  def at_least_vspace(self): return self

abstract_token: AbstractToken = AbstractToken()
1160
1161
def raise_to_shaped(aval: AbstractValue, weak_type=None):
  """Returns `aval` raised to the shaped abstraction level.

  Dispatches along the MRO of `type(aval)` through `raise_to_shaped_mappings`.
  When `weak_type` is None, the aval's own `weak_type` (False if absent) is used.

  Raises:
    TypeError: if no type in the MRO has a registered handler.
  """
  if weak_type is None:
    weak_type = getattr(aval, 'weak_type', False)
  candidates = (raise_to_shaped_mappings.get(t) for t in type(aval).mro())
  handler = next((h for h in candidates if h), None)
  if handler is None:
    raise TypeError(type(aval))
  return handler(aval, weak_type)

raise_to_shaped_mappings : Dict[type, Callable] = {
    # Unit and token avals are returned unchanged.
    AbstractUnit: lambda aval, _: aval,
    AbstractToken: lambda aval, _: aval,
    # Also reached by ConcreteArray via the MRO; the concrete value is dropped.
    ShapedArray: lambda aval, weak_type: ShapedArray(
        aval.shape, aval.dtype, weak_type=weak_type),
}
1175
# Registry for valid dimension types. This is used by masking.Poly.
_DIMENSION_TYPES: Set[type] = {int}

def _canonicalize_dimension(dim):
  """Returns `dim` if its exact type is registered, else `operator.index(dim)`."""
  if type(dim) in _DIMENSION_TYPES:
    return dim
  else:
    return operator.index(dim)

def canonicalize_shape(shape):
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of integers (or of other registered dimension types).

  Raises:
    TypeError: if `shape` is not iterable or a dimension is not a concrete
      integer-like value.
  """
  try:
    return tuple(map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  msg = ("Shapes must be 1D sequences of concrete values of integer type, "
         "got {}.")
  # `shape` may not be iterable at all (e.g. a scalar). Only inspect its
  # elements when possible, so users see the message above rather than
  # "'int' object is not iterable".
  try:
    dims = tuple(shape)
  except TypeError:
    dims = ()
  if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
         and not isinstance(get_aval(x), ConcreteArray) for x in dims):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  raise TypeError(msg.format(shape))
1205
1206
1207# ------------------- Call -------------------
1208
def apply_todos(todos, outs):
  """Applies the `todos` callbacks to `outs`, starting from the last one.

  After each callback, its results are passed through `full_lower`.
  """
  for todo in reversed(list(todos)):
    outs = map(full_lower, todo(outs))
  return outs
1214
class _IgnoreElemList(list):
  """A list that compares equal to, and hashes like, every other _IgnoreElemList.

  Its contents therefore never affect equality or hashing. NOTE(review):
  presumably this lets mutable state ride along in cached transformation
  arguments; confirm against the caching in `linear_util`.
  """
  def __hash__(self): return 0

  def __eq__(self, other):
    return type(other) is _IgnoreElemList

  def __ne__(self, other):
    # `list` defines its own `__ne__`, which compares elements and would
    # contradict `__eq__` above, so derive it from `__eq__` explicitly.
    return not self.__eq__(other)
1220
@lu.transformation_with_aux
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
                       level: Optional[int], params_tuple: tuple, out_axes_transforms, *args):
  """Post-processes outputs of a call/map body that belong to enclosing traces.

  Yields `args` to the wrapped function and receives its outputs. Then it
  repeatedly picks the output tracer with the highest trace level among those
  whose level exceeds `level` (any tracer when `level` is None). All outputs
  are raised into that trace and handed to `primitive.post_process`. The
  todo callbacks collected along the way are yielded, as a tuple, as the
  auxiliary output.

  For map primitives, `post_process` also returns an out-axes transform. It is
  appended to `out_axes_transforms`, which is asserted to be empty before
  post-processing starts.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  todo = []
  assert not out_axes_transforms
  while True:
    # Outputs from traces above `level` still need post-processing.
    tracers = [x for x in outs if isinstance(x, Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = primitive.post_process(trace, outs, params)
    if isinstance(primitive, MapPrimitive):
      # Map primitives return (todo, out_axes_transform) as the second value.
      cur_todo, out_axes_transform = cur_todo
      out_axes_transforms.append(out_axes_transform)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable
1243
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
              fun, *args, **params):
  """Binds a call or map primitive that applies `fun` to `args`.

  Finds the top trace among `args` and wraps `fun` with `process_env_traces`,
  so outputs from traces above the top one get post-processed. The primitive
  is then processed in the top trace (under `maybe_new_sublevel`). Finally the
  collected todos and `full_lower` are applied to the outputs.
  """
  out_axes_transforms = _IgnoreElemList()
  if primitive.map_primitive:
    out_axes_thunk = params['out_axes_thunk']
    # The new thunk depends deterministically on the old thunk and the wrapped function.
    # Any caching already has to include the wrapped function as part of the key, so we
    # only use the previous thunk for equality checks.
    # When called, it applies whatever transforms `process_env_traces` has
    # appended to `out_axes_transforms` by that time.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      for t in out_axes_transforms:
        out_axes = t(out_axes)
      return out_axes
    params = dict(params, out_axes_thunk=new_out_axes_thunk)
  params_tuple = tuple(params.items())
  top_trace = find_top_trace(args)
  fun, env_trace_todo = process_env_traces(
      fun, primitive, top_trace and top_trace.level,
      params_tuple, out_axes_transforms)
  tracers = map(top_trace.full_raise, args)
  with maybe_new_sublevel(top_trace):
    outs = primitive.process(top_trace, fun, tracers, params)
  return map(full_lower, apply_todos(env_trace_todo(), outs))
1268
1269
class CallPrimitive(Primitive):
  """Primitive whose `bind` takes a function to call; it has multiple results."""
  multiple_results = True
  call_primitive = True

  def bind(self, fun, *args, **params):
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    handler = trace.process_call
    return handler(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    handler = trace.post_process_call
    return handler(self, out_tracers, params)
1282
def call_impl(f: lu.WrappedFun, *args, **params):
  """Evaluates a call primitive by invoking the wrapped function on `args`.

  `params` configure the call primitive, not the function, so they are ignored.
  """
  del params
  result = f.call_wrapped(*args)
  return result
1286
# Two call primitives that share `call_impl` as their implementation.
call_p = CallPrimitive('call')
call_p.def_impl(call_impl)
call = call_p.bind

named_call_p = CallPrimitive('named_call')
named_call_p.def_impl(call_impl)
1293
1294# ------------------- Map -------------------
1295
class MapPrimitive(Primitive):
  """Call-like primitive parameterized by per-operand `in_axes`; multiple results."""
  multiple_results = True
  map_primitive = True

  def bind(self, fun, *args, **params):
    # Every positional operand must have exactly one `in_axes` entry.
    assert len(params['in_axes']) == len(args)
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    handler = trace.process_map
    return handler(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    handler = trace.post_process_map
    return handler(self, out_tracers, params)
1309
@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  """Binds `axis_name` (with `size` and `tag`) in the thread-local axis env.

  One frame is pushed on entry and one frame is popped on exit, even if the
  body raises.
  """
  frame = AxisEnvFrame(axis_name, size, tag)
  thread_local_state.trace_state.axis_env.append(frame)
  try:
    yield
  finally:
    # NOTE(review): pops from the axis_env list that is current at exit,
    # which is re-read from thread_local_state here.
    thread_local_state.trace_state.axis_env.pop()
1318
@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
  """Binds several `(axis_name, size)` pairs in the axis env at once.

  Frames are pushed in order with a `None` tag. The same number of frames is
  popped on exit, even if the body raises.
  """
  frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
  thread_local_state.trace_state.axis_env.extend(frames)
  try:
    yield
  finally:
    for _ in frames:
      thread_local_state.trace_state.axis_env.pop()
1328
1329
# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object. We don't want to use the function object itself because that might
# persist references to the function object.
# TODO(mattjj): revisit this unique axis name strategy
class _TempAxisName:
  """Axis name derived from `id(obj)`; two names are equal iff the ids match."""

  def __init__(self, obj):
    # Only the id is stored, so no reference to `obj` is kept.
    self.id = id(obj)

  def __repr__(self):
    return '<axis {}>'.format(hex(self.id))

  def __eq__(self, other):
    if type(other) is not _TempAxisName:
      return False
    return self.id == other.id

  def __hash__(self):
    return hash(self.id)
1349
1350
def axis_frame(axis_name):
  """Returns the innermost axis-env frame bound to `axis_name`.

  Raises:
    NameError: if no frame binds `axis_name`. The message lists the bound
      names that are not `_TempAxisName`s, innermost first.
  """
  innermost_first = list(reversed(thread_local_state.trace_state.axis_env))
  for candidate in innermost_first:
    if candidate.name == axis_name:
      return candidate
  named_axes = [f.name for f in innermost_first
                if not isinstance(f.name, _TempAxisName)]
  raise NameError(
      f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
      f'by pmap) are available to collective operations: {named_axes}')
1361
1362
1363# ------------------- Jaxpr checking -------------------
1364
def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
  """Returns the per-slice aval of `aval` when mapped over `axis` of length `size`.

  The unit aval is returned unchanged. Shaped avals lose dimension `axis`; the
  result is a strongly typed ShapedArray (weak_type is not forwarded).

  Raises:
    TypeError: for avals that are neither the unit nor shaped.
  """
  if aval is abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(f"Mapped operand {aval}")
  # might be raising abstraction level from Concrete here
  assert aval.shape[axis] == size
  return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype)
1374
def unmapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
  """Inverse of `mapped_aval`: inserts a dimension of length `size` at `axis`.

  The unit aval is returned unchanged; the result is not weakly typed.

  Raises:
    TypeError: for avals that are neither the unit nor shaped.
  """
  if aval is abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(f"Mapped output {aval}")
  return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype)
1382
def typecheck(aval: AbstractValue, x) -> bool:
  """Returns whether the aval of value/tracer `x` conforms to `aval`."""
  x_aval = get_aval(x)
  return typecompat(aval, x_aval)
1385
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`.

  `aval_ref` is raised to shaped and stripped of weak typing. `aval` conforms
  when joining it with that reference yields the reference back (after also
  stripping weak typing). A join that raises TypeError means no conformance.
  """
  reference = raise_to_shaped(aval_ref).strip_weak_type()
  try:
    joined = lattice_join(reference, aval)
    return reference == joined.strip_weak_type()
  except TypeError:
    return False
1393
def typematch(aval1: UnshapedArray, aval2: UnshapedArray) -> bool:
  """Returns whether two avals are equal once raised to shaped with weak_type=False."""
  shaped1 = raise_to_shaped(aval1, weak_type=False)
  shaped2 = raise_to_shaped(aval2, weak_type=False)
  return shaped1 == shaped2
1396
class JaxprTypeError(TypeError):
  """Raised when a jaxpr fails type checking."""

def typecheck_assert(pred, msg):
  """Raises JaxprTypeError(msg) unless `pred` is truthy."""
  if pred:
    return
  raise JaxprTypeError(msg)
1402
custom_typechecks: Dict[Primitive, Callable] = dict()  # extra per-primitive checks called by _check_jaxpr as f(*in_avals, **params)
1404
def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
  otherwise.
  """
  try:
    _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
  except JaxprTypeError as e:
    if len(e.args) == 2:
      # The error names a failing equation: show about 10 equations around it.
      msg, eqn_idx = e.args
      lo, hi = eqn_idx - 10, eqn_idx + 10
    else:
      msg, = e.args
      lo, hi = 0, 20
    jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, lo, hi))
    msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
    raise JaxprTypeError(msg) from None
1427
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
  """Type-checks `jaxpr`, assuming its invars have avals `in_avals`.

  Raises:
    JaxprTypeError: with args `(msg, eqn_idx)` when equation `eqn_idx` fails,
      or `(msg,)` for errors in binders or output variables.
  """
  env: Dict[Var, AbstractValue] = {}

  def read(v: Atom) -> AbstractValue:
    if isinstance(v, Literal):
      return raise_to_shaped(get_aval(v.val))
    else:
      typecheck_assert(v in env, f"Variable '{v}' not defined")
      return env[v]

  def write(v: Var, a: AbstractValue) -> None:
    typecheck_assert(v not in env, f"Variable '{v}' already bound")
    if v is not dropvar:
      typecheck_assert(typecompat(v.aval, a),
                       f"Variable '{v}' inconsistently typed as {a}, "
                       f"bound as {v.aval}")
      env[v] = a

  write(unitvar, abstract_unit)
  map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
  map(write, jaxpr.invars, in_avals)

  for eqn_idx, eqn in enumerate(jaxpr.eqns):
    prim = eqn.primitive
    try:
      eqn_in_avals = map(read, eqn.invars)
      typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in eqn_in_avals),
                       "Equation given ConcreteArray type inputs")
      if prim in custom_typechecks:
        custom_typechecks[prim](*eqn_in_avals, **eqn.params)
      if prim.call_primitive:
        out_avals = check_call(prim, eqn_in_avals, eqn.params)
      elif prim.map_primitive:
        out_avals = check_map(prim, eqn_in_avals, eqn.params)
      else:
        out_avals = check_eqn(prim, eqn_in_avals, eqn.params)
      map(write, eqn.outvars, out_avals)
    except JaxprTypeError as e:
      # Errors from nested `_check_jaxpr` calls (made by check_call/check_map)
      # arrive as (msg, inner_eqn_idx), so unpacking exactly one arg would
      # raise ValueError. The inner index refers to the inner jaxpr, so only
      # the message (which already names the inner equation) is kept.
      msg = e.args[0]
      src = source_info_util.summarize(eqn.source_info)
      msg = "\n\n".join([msg, "in equation:", str(pp_eqn(eqn).indent(2)),
                         f"from source: {src}"])
      raise JaxprTypeError(msg, eqn_idx) from None

  map(read, jaxpr.outvars)
1474
def check_eqn(prim, in_avals, params):
  """Type-checks an equation that is neither a call nor a map.

  Jaxprs found in `params` are checked first. The output avals come from
  `prim.abstract_eval` and are always returned as a list.
  """
  for sub_jaxpr in jaxprs_in_params(params):
    check_jaxpr(sub_jaxpr)
  out_avals = prim.abstract_eval(*in_avals, **params)
  return out_avals if prim.multiple_results else [out_avals]
1483
def check_call(prim, in_avals, params):
  """Type-checks a call-primitive equation and returns its output avals.

  Requires a `call_jaxpr` param whose binders accept `in_avals`. That jaxpr is
  type-checked against `in_avals`, and the avals of its outvars are returned.
  """
  typecheck_assert("call_jaxpr" in params,
                   f"Call primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]

  # These checks also happen in recursive call, but give better errors here.
  # (The first count is the number of operands, not of jaxpr inputs.)
  typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
                   f"Call primitive {prim} with {len(in_avals)} "
                   f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
                   f"inputs")
  binder_avals = [v.aval for v in call_jaxpr.invars]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Call primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  _check_jaxpr(call_jaxpr, in_avals)

  out_avals = [v.aval for v in call_jaxpr.outvars]
  return out_avals
1504
def check_map(prim, in_avals, params):
  """Type-checks a map-primitive equation and returns its output avals.

  Requires the `call_jaxpr`, `axis_size`, `in_axes` and `out_axes` params.
  Operands with a non-None in_axis must match their binder with an
  `axis_size` dimension inserted at that axis. The jaxpr is checked against
  the operands with that axis removed. Outputs with a non-None out_axis get an
  `axis_size` dimension inserted at that axis.
  """
  typecheck_assert("call_jaxpr" in params,
                   f"Map primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]
  typecheck_assert("axis_size" in params,
                   f"Map primitive {prim} missing 'axis_size' parameter")
  axis_size = params["axis_size"]
  typecheck_assert("in_axes" in params,
                   f"Map primitive {prim} missing 'in_axes' parameter")
  in_axes = params["in_axes"]
  typecheck_assert("out_axes" in params,
                   f"Map primitive {prim} missing 'out_axes' parameter")
  out_axes = params["out_axes"]

  # Check arities up front (mirroring check_call), so mismatches are reported
  # as a JaxprTypeError instead of depending on how `zip` handles them.
  num_binders = len(call_jaxpr.invars)
  typecheck_assert(len(in_avals) == num_binders,
                   f"Map primitive {prim} with {len(in_avals)} "
                   f"operands cannot call jaxpr with {num_binders} inputs")
  typecheck_assert(len(in_axes) == num_binders,
                   f"Map primitive {prim} has {len(in_axes)} in_axes entries "
                   f"for a jaxpr with {num_binders} inputs")

  binder_avals = [unmapped_aval(axis_size, in_axis, v.aval)
                  if in_axis is not None else v.aval
                  for v, in_axis in zip(call_jaxpr.invars, in_axes)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Map primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  mapped_avals = [mapped_aval(axis_size, in_axis, aval)
                  if in_axis is not None else aval
                  for aval, in_axis in zip(in_avals, in_axes)]
  _check_jaxpr(call_jaxpr, mapped_avals)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [unmapped_aval(axis_size, out_axis, aval) if out_axis is not None else aval
               for aval, out_axis in zip(mapped_out_avals, out_axes)]
  return out_avals
1536
1537
1538# ------------------- Jaxpr printed representation -------------------
1539
def pp_vars(vs: Sequence[Any], print_shapes: bool = False) -> str:
  """Renders `vs` space-separated, as `name:type` pairs when `print_shapes`."""
  if print_shapes:
    render = lambda v: f'{v}:{v.aval.str_short()}'
  else:
    render = str
  return ' '.join(render(v) for v in vs)
1545
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
  """Pretty-prints `primitive_name` with its sorted params.

  The `branches` param and any param whose value is a Jaxpr or ClosedJaxpr
  are omitted.
  """
  def keep(key, value):
    return key != 'branches' and not isinstance(value, (Jaxpr, ClosedJaxpr))
  shown = sorted((k, v) for k, v in params.items() if keep(k, v))
  return pp(primitive_name) >> pp_kv_pairs(shown)
1551
def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
  """Pretty-prints one equation as `outvars = prim[params] invars`.

  The right-hand side stays on the same line when `print_shapes` is set or
  the rendered outvars are at most 6 characters long. Otherwise it goes on
  the next line, indented by 2.
  """
  lhs = pp_vars(eqn.outvars, print_shapes)
  params_pp = pp_kv_pairs(sorted(eqn.params.items()))
  invars_pp = pp(pp_vars(eqn.invars, print_shapes))
  pp_rhs = pp(eqn.primitive.name) >> params_pp >> pp(' ') >> invars_pp
  pp_lhs = pp(f'{lhs} =')
  if print_shapes or len(lhs) <= 6:
    return pp_lhs >> pp(' ') >> pp_rhs
  return pp_lhs + pp_rhs.indent(2)
1562
def pp_eqns(eqns: Sequence[JaxprEqn],
            source_info: bool = False) -> Sequence[PrettyPrint]:
  """Pretty-prints each equation, optionally annotated with its source location.

  Annotations are aligned using the maximum (indent + text length) over all
  printed lines.
  """
  pps = map(pp_eqn, eqns)
  if not source_info:
    return pps
  width = max((indent + len(text) for p in pps for indent, text in p.lines),
              default=None)
  if width is None:
    return pps
  return [p.annotate(width, source_info_util.summarize(e.source_info))
          for e, p in zip(eqns, pps)]
1572
def pp_jaxpr(jaxpr: Jaxpr, source_info: bool = False) -> PrettyPrint:
  """Pretty-prints a whole jaxpr as `{ lambda consts ; invars. let eqns in outs }`.

  Delegates to `pp_jaxpr_eqn_range` over the full equation range. With that
  range no '...' markers are emitted, and the two printers can no longer drift
  apart.
  """
  return pp_jaxpr_eqn_range(jaxpr, 0, len(jaxpr.eqns), source_info=source_info)
1580
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int,
                       source_info: bool = False) -> PrettyPrint:
  """Pretty-prints `jaxpr`, showing only equations with indices in [lo, hi).

  The range is clamped to the existing equations. Equations elided before or
  after the range are marked with '...'.
  """
  num_eqns = len(jaxpr.eqns)
  lo = max(lo, 0)
  hi = max(lo, min(hi, num_eqns))
  shown = jaxpr.eqns[lo:hi]
  if not shown and num_eqns:
    body = [pp('...')]
  else:
    body = [pp('...')] if lo else []
    body += pp_eqns(shown, source_info=source_info)
    if hi != num_eqns:
      body.append(pp('...'))
  header = pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                          pp_vars(jaxpr.invars)))
  footer = pp('in {} }}'.format(str(tuple(jaxpr.outvars))))
  return header + ((pp('let ') >> vcat(body)) + footer).indent(2)
1600
def pp_jaxprs(jaxprs) -> PrettyPrint:
  """Pretty-prints a sequence of (possibly closed) jaxprs between '( ' and ' )'."""
  rendered = []
  for j in jaxprs:
    open_jaxpr = j.jaxpr if isinstance(j, ClosedJaxpr) else j
    rendered.append(pp_jaxpr(open_jaxpr))
  return pp('( ') >> vcat(rendered) >> pp(' )')
1604
def pp_kv_pair(k, v):
  """Pretty-prints `k=v`; a tuple made only of jaxprs is rendered with `pp_jaxprs`."""
  is_jaxpr_tuple = type(v) is tuple and all(
      isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v)
  pp_v = pp_jaxprs(v) if is_jaxpr_tuple else pp(v)
  return pp(f'{k}=') >> pp_v
1611
def pp_kv_pairs(kv_pairs):
  """Pretty-prints `[ k=v ... ]`, or an empty document when there are no pairs."""
  if not kv_pairs:
    return pp('')
  items = vcat([pp_kv_pair(k, v) for k, v in kv_pairs])
  return pp('[ ') >> items >> pp(' ]')
1617
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Installs the legacy (non-omnistaging) tracing machinery.

  Rebinds the module-level names listed in the `global` statement below, plus
  `Primitive.bind`. The replacements use a trace stack with separate
  upward/downward lists. NOTE(review): presumably invoked by `config` when
  omnistaging is disabled; confirm in `config.register_omnistaging_disabler`.
  """
  global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
      new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
      eval_context

  class TraceStack:
    upward: List[MainTrace]
    downward: List[MainTrace]

    def __init__(self):
      self.upward = []
      self.downward = []

    def next_level(self, bottom: bool) -> int:
      if bottom:
        return - (len(self.downward) + 1)
      else:
        return len(self.upward)

    def push(self, main_trace: MainTrace, bottom: bool) -> None:
      if bottom:
        self.downward.append(main_trace)
      else:
        self.upward.append(main_trace)

    def pop(self, bottom: bool) -> None:
      if bottom:
        self.downward.pop()
      else:
        self.upward.pop()

    def __repr__(self) -> str:
      # Join the per-trace lines. Formatting the mapped sequences directly
      # would print a list/iterator repr instead of one trace per line.
      return  'Trace stack\n{} ---\n{}'.format(
        ''.join(map('  {}\n'.format, self.upward[::-1])),
        ''.join(map('  {}\n'.format, self.downward)))

    def copy(self):
      new = TraceStack()
      new.upward = self.upward[:]
      new.downward = self.downward[:]
      return new

  class TraceState:
    trace_stack: TraceStack
    substack: List[Sublevel]
    initial_style: bool

    def __init__(self) -> None:
      self.trace_stack = TraceStack()  # type: ignore
      self.substack = [Sublevel(0)]
      self.initial_style = False

    def copy(self):
      new = TraceState()
      new.trace_stack = self.trace_stack.copy()
      new.substack = self.substack[:]
      new.initial_style = self.initial_style
      return new

  thread_local_state = ThreadLocalState()

  def reset_trace_state() -> bool:
    "Reset the global trace state and return True if it was already clean."
    if (thread_local_state.trace_state.substack != [Sublevel(0)] or
        thread_local_state.trace_state.trace_stack.downward or
        thread_local_state.trace_state.trace_stack.upward):
      thread_local_state.trace_state.__init__()  # type: ignore
      return False
    else:
      return True

  @contextmanager
  def new_main(trace_type: Type[Trace], bottom=False, **payload) -> Generator[MainTrace, None, None]:
    level = thread_local_state.trace_state.trace_stack.next_level(bottom)
    main = MainTrace(level, trace_type, **payload)
    thread_local_state.trace_state.trace_stack.push(main, bottom)

    try:
      yield main
    finally:
      thread_local_state.trace_state.trace_stack.pop(bottom)

    # Leak check: once our reference is dropped, nothing else should keep the
    # main trace alive.
    if debug_state.check_leaks:
      t = ref(main)
      del main
      if t() is not None:
        print(thread_local_state.trace_state.trace_stack)
        raise Exception('Leaked trace {}'.format(t()))

  def find_top_trace(xs) -> Optional[Trace]:
    top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
                    key=attrgetter('level'), default=None)
    return top_trace and top_trace.main.with_cur_sublevel()

  @contextmanager
  def eval_context():
    yield  # dummy implementation for forward compatibility

  def bind(self, *args, **kwargs):
    assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args
    top_trace = find_top_trace(args)
    if top_trace is None:
      return self.impl(*args, **kwargs)

    tracers = map(top_trace.full_raise, args)
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    if self.multiple_results:
      return map(full_lower, out_tracer)
    else:
      return full_lower(out_tracer)
  Primitive.bind = bind  # type: ignore

  def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
                fun: lu.WrappedFun, *args, **params):
    out_axes_transforms = _IgnoreElemList()
    if primitive.map_primitive:
      out_axes_thunk = params['out_axes_thunk']
      # The new thunk depends deterministically on the old thunk and the wrapped function.
      # Any caching already has to include the wrapped function as part of the key, so we
      # only use the previous thunk for equality checks.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        out_axes = out_axes_thunk()
        for t in out_axes_transforms:
          out_axes = t(out_axes)
        return out_axes
      params = dict(params, out_axes_thunk=new_out_axes_thunk)
    params_tuple = tuple(params.items())
    top_trace = find_top_trace(args)
    level = (thread_local_state.trace_state.trace_stack.next_level(True)
            if top_trace is None else top_trace.level)
    fun, env_trace_todo = process_env_traces(
        fun, primitive, level, params_tuple, out_axes_transforms)
    if top_trace is None:
      with new_sublevel():
        outs = primitive.impl(fun, *args, **params)
    else:
      tracers = map(top_trace.full_raise, args)
      outs = primitive.process(top_trace, fun, tracers, params)
    return apply_todos(env_trace_todo(), map(full_lower, outs))

  @contextmanager
  def extend_axis_env(axis_name, size: int, tag: Any):
    yield

  @contextmanager
  def initial_style_staging():
    trace_state = thread_local_state.trace_state
    prev, trace_state.initial_style = trace_state.initial_style, True
    try:
      yield
    finally:
      trace_state.initial_style = prev
1774
# Casting float0 array to a float-valued zero array.
def zeros_like_float0(array, dtype=None):
  """Returns a zero-filled array with the shape of `array`.

  Args:
    array: any object with a `shape` attribute (e.g. a float0-typed array).
    dtype: dtype of the result; float64 when None.
  """
  if dtype is None:
    # `np.float` was only an alias of the builtin `float` (i.e. float64) and
    # has been removed from NumPy. Compare against None explicitly rather than
    # by truthiness, since dtype-like objects may be falsy.
    dtype = np.float64
  return np.zeros(array.shape, dtype)
1780