1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16from jax import core
17from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
18                   valid_jaxtype, raise_to_shaped, get_aval)
19from .tree_util import register_pytree_node
20from typing import Any, Dict, Type
21from ._src.util import safe_map
22
23from ._src import traceback_util
# Pass this file's path to traceback_util.register_exclusion at import time.
# NOTE(review): presumably this hides frames from this module in filtered
# tracebacks -- confirm against ._src.traceback_util.
traceback_util.register_exclusion(__file__)
25
# Alias for typing.Any used in the annotations below; it places no constraint
# on the annotated values.
Array = Any

# Rebind the name `map` to safe_map for the rest of this module.
# NOTE(review): safe_map presumably enforces equal-length arguments -- confirm
# in ._src.util.
map = safe_map
29
# Registry of addition handlers keyed by the concrete Python type of the first
# operand: add_impl looks up type(xs) here and calls the handler with (xs, ys).
# The Unit handler ignores both operands and yields the unit value.
jaxval_adders = {Unit: lambda _, __: unit}
32
def add_jaxvals(x, y):
  """Add two values, short-circuiting when both are abstract units.

  If the aval of `x` and the aval of `y` are both `core.abstract_unit`,
  `core.unit` is returned directly; otherwise the addition is delegated to
  `add_jaxvals_p.bind(x, y)`.
  """
  # `and` keeps the original short-circuit: the aval of `y` is only inspected
  # once `x` is known to be an abstract unit.
  both_unit = (core.get_aval(x) is core.abstract_unit
               and core.abstract_unit is core.get_aval(y))
  if both_unit:
    return core.unit
  return add_jaxvals_p.bind(x, y)
38
# Primitive that add_jaxvals binds; it is constructed with the name 'add_any'.
add_jaxvals_p = Primitive('add_any')
# Second module-level name bound to the very same primitive object.
add_any_p = add_jaxvals_p
41
@add_jaxvals_p.def_impl
def add_impl(xs, ys):
  """Implementation rule for add_any: dispatch on the type of `xs`.

  The handler registered in `jaxval_adders` under `type(xs)` is called with
  `(xs, ys)` and its result returned. A type with no registered handler
  surfaces as the KeyError raised by the registry lookup.
  """
  adder = jaxval_adders[type(xs)]
  return adder(xs, ys)
45
@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
  """Abstract-evaluation rule for add_any: the lattice_join of both inputs."""
  joined = lattice_join(xs, ys)
  return joined
49
# Registry keyed by the concrete Python type of a value: zeros_like_impl looks
# up type(example) here and calls the entry with the example itself. It starts
# empty in this module.
jaxval_zeros_likers: Dict[type, Array] = {}
51
def zeros_like_aval(aval):
  """Return the value produced by the zeros handler registered for `aval`.

  The handler is looked up in `aval_zeros_likers` by `type(aval)` and called
  with `aval`; an unregistered type surfaces as the lookup's KeyError.
  """
  make_zeros = aval_zeros_likers[type(aval)]
  return make_zeros(aval)
54
# Registry keyed by abstract-value class: zeros_like_aval looks up type(aval)
# here and calls the entry with the aval. The AbstractUnit entry ignores its
# argument and yields the unit value.
aval_zeros_likers: Dict[Type[core.AbstractValue], Array] = {
    AbstractUnit: lambda _: unit,
}
57
def zeros_like_jaxval(val):
  """Bind the zeros_like primitive to `val` and return the result."""
  zeros = zeros_like_p.bind(val)
  return zeros
60
# Primitive that zeros_like_jaxval binds; constructed with the name
# 'zeros_like'.
zeros_like_p = Primitive('zeros_like')
62
@zeros_like_p.def_impl
def zeros_like_impl(example):
  """Implementation rule for zeros_like: dispatch on the type of `example`.

  The handler registered in `jaxval_zeros_likers` under `type(example)` is
  called with `example`; an unregistered type surfaces as the lookup's
  KeyError.
  """
  make_zeros = jaxval_zeros_likers[type(example)]
  return make_zeros(example)
66
# The abstract-evaluation rule is the identity function: whatever abstract
# value comes in is returned unchanged.
zeros_like_p.def_abstract_eval(lambda x: x)
68
class Zero:
  """Lightweight marker object whose only state is an abstract value.

  Attributes are restricted to `aval` via `__slots__`, so instances carry no
  per-instance `__dict__`.
  NOTE(review): presumably stands for a zero of the given aval in autodiff
  code -- confirm against callers.
  """
  __slots__ = ['aval']

  def __init__(self, aval):
    # `aval` is stored as given; no validation or conversion happens here.
    self.aval = aval

  def __repr__(self):
    return f'Zero({self.aval})'

  @staticmethod
  def from_value(val):
    """Build a Zero from a concrete value.

    The value's aval is obtained with `get_aval`, passed through
    `raise_to_shaped`, and wrapped in a new `Zero`.
    """
    abstract = get_aval(val)
    return Zero(raise_to_shaped(abstract))
78
# Register Zero as a pytree node. Flattening yields an empty tuple alongside
# the aval; unflattening rebuilds a Zero from that aval and ignores its second
# argument.
register_pytree_node(
    Zero,
    lambda zero: ((), zero.aval),
    lambda aux, _children: Zero(aux),
)
80
81
def _stop_gradient_impl(x):
  """Return `x` unchanged after checking that it is a valid JAX type.

  Raises:
    TypeError: if `valid_jaxtype(x)` is falsy.
  """
  if valid_jaxtype(x):
    return x
  raise TypeError("stop_gradient only works on valid JAX arrays, but "
                  f"input argument is: {x}")
87
# Primitive constructed with the name 'stop_gradient'.
stop_gradient_p = Primitive('stop_gradient')
# Its implementation validates the input and returns it unchanged.
stop_gradient_p.def_impl(_stop_gradient_impl)
# Its abstract-evaluation rule is the identity on the input abstract value.
stop_gradient_p.def_abstract_eval(lambda x: x)
91