# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as np

from . import ad_util
from . import core
from . import dtypes

from ._src import traceback_util
traceback_util.register_exclusion(__file__)

_DIMENSION_TYPES = core._DIMENSION_TYPES

# Names defined in `core`, re-bound here so they can be reached through this
# module as well.
UnshapedArray = core.UnshapedArray
ShapedArray = core.ShapedArray
ConcreteArray = core.ConcreteArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped


def make_shaped_array(x):
  """Returns a `ShapedArray` built from `np.shape(x)` and x's canonical dtype.

  The dtype is `dtypes.result_type(x)` passed through
  `dtypes.canonicalize_dtype`.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  return ShapedArray(np.shape(x), canonical_dtype)


def zeros_like_array(x):
  """Returns zeros with the shape and canonical dtype of the array-like `x`."""
  return zeros_like_shaped_array(make_shaped_array(x))


# NumPy array and scalar types (plus bfloat16) that get the registrations
# below.
array_types = {
    np.ndarray, np.bool_,
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
    dtypes.bfloat16, np.float16, np.float32, np.float64,
    np.complex64, np.complex128,
    np.longlong, np.intc,
}

# Register, per array type, its abstract-value constructor and its
# zeros-like function.
for t in array_types:
  core.pytype_aval_mappings[t] = ConcreteArray
  ad_util.jaxval_zeros_likers[t] = zeros_like_array


def zeros_like_shaped_array(aval):
  """Returns an array of zeros with the shape and dtype of `aval`.

  Args:
    aval: a `ShapedArray`; anything else trips the assertion.

  Returns:
    For `dtypes.float0`, a freshly allocated `np.zeros` array. For every other
    dtype, a single scalar zero broadcast to `aval.shape` via
    `np.broadcast_to`.
  """
  assert isinstance(aval, ShapedArray)
  # float0 is built with np.zeros rather than by broadcasting a scalar.
  # NOTE(review): presumably because np.array(0, float0) is not usable -- confirm.
  if aval.dtype == dtypes.float0:
    return np.zeros(aval.shape, dtypes.float0)
  scalar_zero = np.array(0, aval.dtype)
  return np.broadcast_to(scalar_zero, aval.shape)


ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

core.literalable_types.update(array_types)


def _zeros_like_python_scalar(t, x):
  """Returns a 0-d zero array of the dtype mapped to Python scalar type `t`.

  `x` is not read; it is only part of the call signature, since the function
  is registered with `t` pre-bound through `partial`.
  """
  return np.array(0, dtypes.python_scalar_dtypes[t])


def _make_concrete_python_scalar(t, x):
  """Wraps Python scalar `x` of type `t` in a weakly-typed `ConcreteArray`."""
  scalar_dtype = dtypes.python_scalar_dtypes[t]
  value = np.array(x, dtype=scalar_dtype)
  return ConcreteArray(value, weak_type=True)


# Python scalar types get the same two registrations as the array types, with
# the scalar type pre-bound as the first argument of each handler.
for t in dtypes.python_scalar_dtypes:
  core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
  ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())