# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Util functions for the numpy module."""


import builtins

import numpy as onp

__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64',
           'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO', 'newaxis', 'finfo',
           'e', 'NINF', 'PINF', 'NAN', 'NaN',
           '_STR_2_DTYPE_']

# Scalar dtype aliases re-exported from official NumPy.
float16 = onp.float16
float32 = onp.float32
float64 = onp.float64
uint8 = onp.uint8
int32 = onp.int32
int8 = onp.int8
int64 = onp.int64
bool_ = onp.bool_
# `numpy.bool` was an alias of the builtin `bool` (deprecated in NumPy 1.20,
# removed in 1.24) and was reintroduced in NumPy 2.0 as NumPy's boolean
# scalar type. Fall back to the builtin so this module imports on every
# NumPy version; either object maps to dtype('bool').
# Intentionally shadows the builtin name inside this module to mirror
# NumPy's namespace.
bool = getattr(onp, 'bool', builtins.bool)

# Mathematical constants.
pi = onp.pi
inf = onp.inf
nan = onp.nan
e = onp.e
# NumPy 2.0 removed the PZERO/NZERO/NINF/PINF/NAN/NaN aliases, so they are
# defined here directly with the same float values NumPy 1.x exposed.
PZERO = 0.0
NZERO = -0.0
NINF = -inf
PINF = inf
NAN = nan
NaN = nan

newaxis = None
finfo = onp.finfo

# Maps dtype names (as strings) to dtype objects; 'float' and 'int' resolve
# to the 64-bit types, and 'None' maps to None.
_STR_2_DTYPE_ = {'float16': float16, 'float32': float32, 'float64': float64, 'float': float64,
                 'uint8': uint8, 'int8': int8, 'int32': int32, 'int64': int64, 'int': int64,
                 'bool': bool, 'bool_': bool_, 'None': None}

# Official NumPy namespaces searched, in order, by `_get_np_op`.
_ONP_OP_MODULES = [onp, onp.linalg, onp.random, onp.fft]


def _get_np_op(name):
    """Get official NumPy operator with `name`. If not found, raise ValueError.

    Parameters
    ----------
    name : str
        Attribute name of the operator to look up.

    Returns
    -------
    object
        The first non-None attribute named `name` found in `numpy`,
        `numpy.linalg`, `numpy.random` or `numpy.fft`, searched in that order.

    Raises
    ------
    ValueError
        If none of the searched modules provides `name`.
    """
    for mod in _ONP_OP_MODULES:
        op = getattr(mod, name, None)
        if op is not None:
            return op
    raise ValueError('Operator `{}` is not supported by `mxnet.numpy`.'.format(name))