1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from functools import partial
16
17import numpy as np
18
19from . import ad_util
20from . import core
21from . import dtypes
22
23from ._src import traceback_util
# Exclude this file's frames from filtered tracebacks.
traceback_util.register_exclusion(__file__)

# Names re-exported from `core`; presumably kept so code that imports them
# from this module keeps working — TODO confirm against callers.

# Abstract-value classes.
UnshapedArray, ShapedArray, ConcreteArray, AbstractToken = (
    core.UnshapedArray, core.ShapedArray, core.ConcreteArray,
    core.AbstractToken)

# Constants and helper functions.
_DIMENSION_TYPES = core._DIMENSION_TYPES
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped
35
36
def make_shaped_array(x):
  """Build a ShapedArray describing `x`.

  Args:
    x: any value accepted by `dtypes.result_type` and `np.shape`.

  Returns:
    A `ShapedArray` with `x`'s shape and its canonicalized result dtype.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  x_shape = np.shape(x)
  return ShapedArray(x_shape, canonical_dtype)
40
def zeros_like_array(x):
  """Return a zero value with the same shape and canonical dtype as `x`.

  The aval is built by `make_shaped_array`, so the shape/dtype derivation
  here always matches the one used for `x`'s abstract value. It is not
  re-implemented inline.

  Args:
    x: any value accepted by `make_shaped_array`.

  Returns:
    Whatever `zeros_like_shaped_array` produces for that aval.
  """
  return zeros_like_shaped_array(make_shaped_array(x))
44
# NumPy array and scalar types whose instances are treated as concrete JAX
# array values.
#
# The C-named scalar types (`np.longlong`/`np.ulonglong`, `np.intc`/`np.uintc`)
# can be distinct type objects from the sized aliases (`np.int64`, `np.uint64`,
# `np.int32`, `np.uint32`) depending on the platform. Registering them
# explicitly keeps such scalars from falling through the type-keyed
# registries. The unsigned variants mirror the signed ones so both
# signednesses behave the same way.
array_types = {np.ndarray, np.bool_,
               np.int8, np.int16, np.int32, np.int64,
               np.uint8, np.uint16, np.uint32, np.uint64,
               dtypes.bfloat16, np.float16, np.float32, np.float64,
               np.complex64, np.complex128,
               np.longlong, np.intc,
               np.ulonglong, np.uintc}

for t in array_types:
  # Instances of these types get a ConcreteArray aval, and their zeros are
  # built by zeros_like_array.
  core.pytype_aval_mappings[t] = ConcreteArray
  ad_util.jaxval_zeros_likers[t] = zeros_like_array
55
56
def zeros_like_shaped_array(aval):
  """Return a zero NumPy array matching `aval`'s shape and dtype.

  Args:
    aval: a `ShapedArray` (checked with an assert).

  Returns:
    For ordinary dtypes, a scalar zero broadcast to `aval.shape`. This is a
    read-only view, not a freshly allocated buffer. For `dtypes.float0`, a
    `np.zeros` array of that dtype.
  """
  assert isinstance(aval, ShapedArray)
  if aval.dtype != dtypes.float0:
    return np.broadcast_to(np.array(0, aval.dtype), aval.shape)
  # float0 takes a separate path. Presumably a scalar 0 cannot be cast to it
  # via np.array(0, float0) — TODO confirm.
  return np.zeros(aval.shape, dtypes.float0)
62
# Register every type in `array_types` in core.literalable_types (presumably
# the set of types that may be inlined as literals — TODO confirm in core).
core.literalable_types.update(array_types)

# ad_util builds zeros for ShapedArray avals with zeros_like_shaped_array.
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
66
def _zeros_like_python_scalar(t, x):
  """Return a 0-d zero array in the dtype mapped to Python scalar type `t`.

  Args:
    t: a key of `dtypes.python_scalar_dtypes` (bound via `functools.partial`).
    x: the scalar value. It is ignored because only its type matters.
  """
  zero_dtype = dtypes.python_scalar_dtypes[t]
  return np.array(0, dtype=zero_dtype)
69
def _make_concrete_python_scalar(t, x):
  """Wrap Python scalar `x` in a weakly-typed ConcreteArray.

  Args:
    t: a key of `dtypes.python_scalar_dtypes` (bound via `functools.partial`).
    x: the Python scalar value to wrap.

  Returns:
    A `ConcreteArray` holding `x` as a NumPy array of the dtype mapped to `t`,
    with `weak_type=True`.
  """
  scalar_dtype = dtypes.python_scalar_dtypes[t]
  value = np.array(x, dtype=scalar_dtype)
  return ConcreteArray(value, weak_type=True)
74
# For each Python scalar type known to `dtypes`, register a zeros-liker and
# an aval constructor, each specialized to that type with functools.partial.
for t in dtypes.python_scalar_dtypes.keys():
  ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
  core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)

# Iterating the mapping yields its keys, i.e. the Python scalar types.
core.literalable_types.update(dtypes.python_scalar_dtypes)
80