# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from jax import core
from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
                   valid_jaxtype, raise_to_shaped, get_aval)
from .tree_util import register_pytree_node
from typing import Any, Callable, Dict, Type
from ._src.util import safe_map

from ._src import traceback_util
traceback_util.register_exclusion(__file__)

Array = Any

map = safe_map

# Registry: concrete value type -> function ``(x, y) -> sum`` used by the
# impl rule of the ``add_any`` primitive (see ``add_impl`` below).
jaxval_adders: Dict[type, Callable[[Any, Any], Any]] = {}
# Adding two units yields a unit.
jaxval_adders[Unit] = lambda _, __: unit


def add_jaxvals(x, y):
  """Adds two JAX values.

  When both ``x`` and ``y`` have the abstract-unit aval, returns ``core.unit``
  directly without binding a primitive; otherwise binds ``add_jaxvals_p``.
  """
  # Chained ``is``: true only when BOTH avals are ``core.abstract_unit``.
  if core.get_aval(x) is core.abstract_unit is core.get_aval(y):
    return core.unit
  else:
    return add_jaxvals_p.bind(x, y)


add_jaxvals_p = Primitive('add_any')
add_any_p = add_jaxvals_p  # Alias: same primitive object, exported under both names.


@add_jaxvals_p.def_impl
def add_impl(xs, ys):
  """Impl rule for ``add_any``.

  Dispatches on ``type(xs)`` through ``jaxval_adders``; an unregistered type
  raises ``KeyError`` from the registry lookup.
  """
  return jaxval_adders[type(xs)](xs, ys)


@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
  """Abstract-eval rule for ``add_any``: returns the lattice join of the avals."""
  return lattice_join(xs, ys)


# Registry: concrete value type -> function that builds a zero from an
# example value of that type; used by the impl rule of ``zeros_like_p``.
jaxval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}


def zeros_like_aval(aval):
  """Returns a zero for ``aval`` using the builder registered for ``type(aval)``.

  Raises ``KeyError`` if no builder is registered in ``aval_zeros_likers``.
  """
  return aval_zeros_likers[type(aval)](aval)


# Registry: abstract value type -> function that builds a zero from an aval
# of that type.
aval_zeros_likers: Dict[Type[core.AbstractValue], Callable[[Any], Array]] = {}
# The zero of an abstract unit is the unit value itself.
aval_zeros_likers[AbstractUnit] = lambda _: unit


def zeros_like_jaxval(val):
  """Returns a zero like ``val`` by binding the ``zeros_like`` primitive."""
  return zeros_like_p.bind(val)


zeros_like_p = Primitive('zeros_like')


@zeros_like_p.def_impl
def zeros_like_impl(example):
  """Impl rule for ``zeros_like``.

  Dispatches on ``type(example)`` through ``jaxval_zeros_likers``.
  """
  return jaxval_zeros_likers[type(example)](example)


# The output aval of ``zeros_like`` is the input aval, unchanged.
zeros_like_p.def_abstract_eval(lambda x: x)


class Zero:
  """A zero value represented only by its abstract value ``aval``.

  Registered below as a pytree node with no children; ``aval`` is carried as
  the node's auxiliary data.
  """
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return f'Zero({self.aval})'

  @staticmethod
  def from_value(val):
    """Builds a ``Zero`` from the shaped abstraction of ``val``'s aval."""
    return Zero(raise_to_shaped(get_aval(val)))


# Flatten to no children, keeping ``aval`` as aux data; unflatten rebuilds
# the ``Zero`` from that aux data alone.
register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))


def _stop_gradient_impl(x):
  """Impl rule for ``stop_gradient``: returns ``x`` unchanged.

  Raises:
    TypeError: if ``valid_jaxtype(x)`` is false.
  """
  if not valid_jaxtype(x):
    raise TypeError("stop_gradient only works on valid JAX arrays, but "
                    f"input argument is: {x}")
  return x


stop_gradient_p = Primitive('stop_gradient')
stop_gradient_p.def_impl(_stop_gradient_impl)
# The output aval of ``stop_gradient`` is the input aval, unchanged.
stop_gradient_p.def_abstract_eval(lambda x: x)