1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Util functions for the numpy module."""
19
20
21
22import numpy as onp
23
__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64',
           'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO', 'newaxis', 'finfo',
           'e', 'NINF', 'PINF', 'NAN', 'NaN',
           '_STR_2_DTYPE_']

# Scalar types re-exported from official NumPy.
float16 = onp.float16
float32 = onp.float32
float64 = onp.float64
uint8 = onp.uint8
int32 = onp.int32
int8 = onp.int8
int64 = onp.int64
bool_ = onp.bool_
# `onp.bool` was an alias of the builtin `bool`. NumPy deprecated it in 1.20, removed it
# in 1.24 and re-added it in 2.0 as an alias of `onp.bool_`. Reading it unconditionally
# makes this module fail to import on NumPy 1.24-1.26, so fall back to the builtin, which
# is what `onp.bool` used to be. On first import the right-hand `bool` still resolves to
# the builtin because this module-level name has not been bound yet.
bool = getattr(onp, 'bool', bool)  # pylint: disable=redefined-builtin

# Mathematical constants re-exported from official NumPy.
pi = onp.pi
inf = onp.inf
nan = onp.nan
e = onp.e
# NumPy 2.0 removed the aliases PZERO, NZERO, NINF, PINF, NAN and NaN. Build them from
# primitives that exist on every NumPy version so their values (and `NAN is nan`) are
# unchanged.
PZERO = 0.0
NZERO = -0.0
NINF = -inf
PINF = inf
NAN = nan
NaN = nan

# Indexing with `newaxis` inserts a length-1 axis; like `onp.newaxis`, it is simply None.
newaxis = None
finfo = onp.finfo

# Maps dtype names given as strings to the type objects above. The generic names
# 'float' and 'int' resolve to the 64-bit types, and the string 'None' maps to None.
_STR_2_DTYPE_ = {'float16': float16, 'float32': float32, 'float64': float64, 'float': float64,
                 'uint8': uint8, 'int8': int8, 'int32': int32, 'int64': int64, 'int': int64,
                 'bool': bool, 'bool_': bool_, 'None': None}
56
# Official NumPy namespaces that are searched, in this priority order, when an operator
# name is resolved to its official implementation.
_ONP_OP_MODULES = [
    onp,
    onp.linalg,
    onp.random,
    onp.fft,
]


def _get_np_op(name):
    """Look up the official NumPy operator called `name`.

    The modules in ``_ONP_OP_MODULES`` are searched in order, and the first
    attribute named `name` whose value is not ``None`` is returned. Later
    modules are not consulted once a match is found.

    Parameters
    ----------
    name : str
        Attribute name to resolve, e.g. ``'sum'`` or ``'svd'``.

    Returns
    -------
    object
        The matching attribute from the first module that defines it.

    Raises
    ------
    ValueError
        If no searched module has a non-``None`` attribute called `name`.
    """
    # Lazily evaluated, so getattr stops at the first module that has a match.
    lookups = (getattr(module, name, None) for module in _ONP_OP_MODULES)
    found = next((candidate for candidate in lookups if candidate is not None), None)
    if found is None:
        raise ValueError('Operator `{}` is not supported by `mxnet.numpy`.'.format(name))
    return found
67