# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import, protected-access
"""Symbolic graph construction API.

This namespace contains most of the registered operators.
For detailed list of operators, checkout ``Core Tensor Operators``
"""
from __future__ import absolute_import as _abs
import sys as _sys
import os as _os
import ctypes as _ctypes
from numbers import Integral as _Integral
from numbers import Number as _Number

import numpy as np

from . import _base
from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
from .attribute import AttrScope
from . import _symbol_internal as _internal
from . import contrib

# Use different version of SymbolBase
# When possible, use cython to speedup part of computation.
# Exception type that triggers the pure-ctypes fallback below. In "cython"
# mode only RuntimeError is caught, so a failed Cython import (ImportError)
# propagates instead of silently falling back to ctypes.
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError

try:
    if _FFI_MODE == "ctypes":
        # ctypes explicitly requested: jump straight to the fallback branch.
        raise ImportError()
    if _sys.version_info >= (3, 0):
        from ._cy3.symbol import SymbolBase, _init_symbol_module
    else:
        from ._cy2.symbol import SymbolBase, _init_symbol_module
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
    from ._ctypes.symbol import SymbolBase, _init_symbol_module


class Symbol(SymbolBase):
    """Symbol is basic operation unit for symbolic graph composition.

    Arithmetic operators dispatch to ``__<op>_symbol__`` (Symbol op Symbol)
    and ``__<op>_scalar__`` (Symbol op Number) operator functions. Those are
    not defined in this file; presumably they are registered into this module
    by ``_init_symbol_module`` at import time -- confirm against the backend.
    Unsupported operand types raise ``TypeError``.
    """
    # Disable per-instance dictionary storage; the ``handle`` attribute used
    # throughout this class is provided by SymbolBase.
    __slots__ = []

    # Type code reported to TVM's FFI when a Symbol is passed as an argument.
    # NOTE(review): must agree with the code the TVM side expects for NNVM
    # symbols -- not verifiable from this file.
    _tvm_tcode = 16

    @property
    def _tvm_handle(self):
        """The ``.value`` of the underlying C handle, consumed by TVM's FFI."""
        return self.handle.value

    def __add__(self, other):
        """x.__add__(y) <=> x+y"""
        if isinstance(other, Symbol):
            return __add_symbol__(self, other)
        if isinstance(other, _Number):
            return __add_scalar__(self, scalar=other)
        raise TypeError("type %s not supported" % str(type(other)))

    def __radd__(self, other):
        """x.__radd__(y) <=> y+x (addition is commutative, reuse __add__)."""
        return self.__add__(other)

    def __sub__(self, other):
        """x.__sub__(y) <=> x-y"""
        if isinstance(other, Symbol):
            return __sub_symbol__(self, other)
        if isinstance(other, _Number):
            return __sub_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rsub__(self, other):
        """x.__rsub__(y) <=> y-x, only for a Number ``y``."""
        if isinstance(other, _Number):
            return __rsub_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __mul__(self, other):
        """x.__mul__(y) <=> x*y"""
        if isinstance(other, Symbol):
            return __mul_symbol__(self, other)
        if isinstance(other, _Number):
            return __mul_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rmul__(self, other):
        """x.__rmul__(y) <=> y*x (multiplication is commutative, reuse __mul__)."""
        return self.__mul__(other)

    def __div__(self, other):
        """x.__div__(y) <=> x/y"""
        if isinstance(other, Symbol):
            return __div_symbol__(self, other)
        if isinstance(other, _Number):
            return __div_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rdiv__(self, other):
        """x.__rdiv__(y) <=> y/x, only for a Number ``y``."""
        if isinstance(other, _Number):
            return __rdiv_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __lshift__(self, other):
        """x.__lshift__(y) <=> x << y"""
        if isinstance(other, _Number):
            return __lshift_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rshift__(self, other):
        """x.__rshift__(y) <=> x >> y"""
        if isinstance(other, _Number):
            return __rshift_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __truediv__(self, other):
        """Python 3 ``/``; same semantics as ``__div__``."""
        return self.__div__(other)

    def __rtruediv__(self, other):
        """Python 3 reflected ``/``; same semantics as ``__rdiv__``."""
        return self.__rdiv__(other)

    def __pow__(self, other):
        """x.__pow__(y) <=> x**y"""
        if isinstance(other, Symbol):
            return __pow_symbol__(self, other)
        if isinstance(other, _Number):
            return __pow_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rpow__(self, other):
        """x.__rpow__(y) <=> y**x, only for a Number ``y``."""
        if isinstance(other, _Number):
            return __rpow_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __neg__(self):
        """x.__neg__() <=> -x, implemented as x * -1.0."""
        return self.__mul__(-1.0)

    def __copy__(self):
        """Shallow copy is implemented as a deep copy of the graph."""
        return self.__deepcopy__()

    def __deepcopy__(self, _=None):
        """Returns a deep copy of the input object (via NNSymbolCopy)."""
        handle = _base.SymbolHandle()
        _base.check_call(_LIB.NNSymbolCopy(self.handle,
                                           _ctypes.byref(handle)))
        return Symbol(handle)

    def __getitem__(self, index):
        """Get one output of this symbol.

        Parameters
        ----------
        index : int or str
            Position of the output, or its name. Any ``numbers.Integral``
            (e.g. a numpy integer scalar, or a Python 2 ``long``) is accepted.

        Returns
        -------
        sym : Symbol
            A symbol referring to the selected output.

        Raises
        ------
        ValueError
            If a name matches no output, or matches more than one output.
        TypeError
            If ``index`` is neither a string nor an integer.
        """
        if isinstance(index, _base.string_types):
            idx = None
            for i, name in enumerate(self.list_output_names()):
                if name == index:
                    if idx is not None:
                        raise ValueError('There are multiple outputs with name \"%s\"' % index)
                    idx = i
            if idx is None:
                raise ValueError('Cannot find output that matches name \"%s\"' % index)
            index = idx
        # numbers.Integral also covers numpy integer scalars (which are not
        # ``int`` instances on Python 3) and Python 2 ``long``.
        if not isinstance(index, _Integral):
            raise TypeError('Symbol only support integer index to fetch i-th output')
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolGetOutput(
            self.handle, _base.nn_uint(int(index)), _ctypes.byref(handle)))
        return Symbol(handle=handle)

    def __iter__(self):
        """Iterate over the outputs of this symbol, in output order.

        Iterates by position rather than by name, so symbols whose outputs
        share a name (which a string lookup in ``__getitem__`` rejects) can
        still be iterated, and each step does not re-scan the name list.
        """
        return (self[i] for i in range(len(self.list_output_names())))

    def attr(self, key):
        """Get attribute string from the symbol, this function only works for non-grouped symbol.

        Parameters
        ----------
        key : str
            The key to get attribute from.

        Returns
        -------
        value : str
            The attribute value of the key, returns None if attribute do not exist.
        """
        ret = _ctypes.c_char_p()
        success = _ctypes.c_int()
        _check_call(_LIB.NNSymbolGetAttr(
            self.handle, _base.c_str(key), _ctypes.byref(ret), _ctypes.byref(success)))
        if success.value != 0:
            return _base.py_str(ret.value)
        return None

    def list_attr(self, recursive=False):
        """Get all attributes from the symbol.

        Parameters
        ----------
        recursive : bool
            Default `False`. When `recursive` is `True`, list recursively all the
            attributes in the descendents. The attribute names are pre-pended with
            the symbol names to avoid conflicts. If `False`, then only attributes
            that belongs to this symbol is returned, and the attribute names will
            **not** be pre-pended with the symbol name.

        Returns
        -------
        attrs : dict of str -> str
            Attribute names mapped to their values.
        """
        size = _base.nn_uint()
        pairs = _ctypes.POINTER(_ctypes.c_char_p)()
        # NNSymbolListAttrs option: 0 lists recursively, 1 lists this node only.
        option = _ctypes.c_int(0) if recursive else _ctypes.c_int(1)
        _check_call(_LIB.NNSymbolListAttrs(
            self.handle, option, _ctypes.byref(size), _ctypes.byref(pairs)))
        # ``pairs`` is a flat [key0, value0, key1, value1, ...] array of
        # ``size`` pairs.
        return {_base.py_str(pairs[i*2]): _base.py_str(pairs[i*2+1]) for i in range(size.value)}

    def get_internals(self):
        """Get a new grouped symbol whose output contains all the internal outputs of this symbol.

        Returns
        -------
        sgroup : Symbol
            The internal of the symbol.
        """
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolGetInternals(
            self.handle, _ctypes.byref(handle)))
        return Symbol(handle=handle)

    def get_children(self):
        """Gets a new grouped symbol whose output contains
        inputs to output nodes of the original symbol.

        Returns
        -------
        sgroup : Symbol or None
            The grouped children, or None when the result has no outputs.
        """
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolGetChildren(
            self.handle, _ctypes.byref(handle)))
        ret = Symbol(handle=handle)
        if not ret.list_output_names():
            return None
        return ret

    def _get_list_copt(self, option):
        """Map a listing option name to the ctypes int code used by the C API.

        Raises
        ------
        ValueError
            If ``option`` is not one of 'all', 'read_only', 'aux_state'.
        """
        if option == 'all':
            return _ctypes.c_int(0)
        if option == 'read_only':
            return _ctypes.c_int(1)
        if option == 'aux_state':
            return _ctypes.c_int(2)
        raise ValueError("option need to be in {'all', 'read_only', 'aux_state'}")

    def list_input_variables(self, option='all'):
        """List all the input variables in the symbol.

        Parameters
        ----------
        option : {'all', 'read_only', 'aux_state'}, optional
           The listing option
           - 'all' will list all the arguments.
           - 'read_only' lists arguments that are readed by the graph.
           - 'aux_state' lists arguments that are mutated by the graph as state.
        Returns
        -------
        vars : list of symbol
            List of all the variables
        """
        size = _ctypes.c_uint()
        sarr = _ctypes.POINTER(_base.SymbolHandle)()
        _check_call(_LIB.NNSymbolListInputVariables(
            self.handle, self._get_list_copt(option),
            _ctypes.byref(size), _ctypes.byref(sarr)))
        return [Symbol(_base.SymbolHandle(sarr[i])) for i in range(size.value)]

    def list_input_names(self, option='all'):
        """List all the inputs in the symbol.

        Parameters
        ----------
        option : {'all', 'read_only', 'aux_state'}, optional
           The listing option
           - 'all' will list all the arguments.
           - 'read_only' lists arguments that are readed by the graph.
           - 'aux_state' lists arguments that are mutated by the graph as state.
        Returns
        -------
        args : list of string
            List of all the arguments.
        """
        size = _ctypes.c_uint()
        sarr = _ctypes.POINTER(_ctypes.c_char_p)()
        _check_call(_LIB.NNSymbolListInputNames(
            self.handle, self._get_list_copt(option),
            _ctypes.byref(size), _ctypes.byref(sarr)))
        return [_base.py_str(sarr[i]) for i in range(size.value)]

    def list_output_names(self):
        """List all outputs in the symbol.

        Returns
        -------
        returns : list of string
            List of all the outputs.
        """
        size = _ctypes.c_uint()
        sarr = _ctypes.POINTER(_ctypes.c_char_p)()
        _check_call(_LIB.NNSymbolListOutputNames(
            self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
        return [_base.py_str(sarr[i]) for i in range(size.value)]

    def debug_str(self):
        """Get a debug string.

        Returns
        -------
        debug_str : string
            Debug string of the symbol.
        """
        debug_str = _ctypes.c_char_p()
        _check_call(_LIB.NNSymbolPrint(
            self.handle, _ctypes.byref(debug_str)))
        return _base.py_str(debug_str.value)

    def _add_control_deps(self, deps):
        """Add control flow dependencies.
        This makes current op depend on the deps.
        Only use when necessary,
        this function mutate the current symbol node.

        Parameters
        ----------
        deps : Symbol or list of Symbol
            The dependencies. A list is first combined with ``Group``.
        """
        if isinstance(deps, list):
            deps = Group(deps)
        _check_call(_LIB.NNAddControlDeps(
            self.handle, deps.handle))


def Variable(name, init=None, **kwargs):
    """Create a symbolic variable with specified name.

    Parameters
    ----------
    name : str
        Name of the variable.
    init : Symbol or numpy.ndarray
        Symbol or numpy ndarray of initial value for the variable.
        Note that for symbolic initialization value, it must be able
        to be defined through InferShape, such as sym.zeros_like(v),
        in which v is an input or parameter. Otherwise, pass a numpy
        ndarray instead.
    kwargs : dict of string -> string
        Additional attributes to set on the variable. They are passed
        through ``AttrScope.current.get`` before being applied.

    Returns
    -------
    variable : Symbol
        The created variable symbol.

    Raises
    ------
    TypeError
        If ``name`` is not a string, or ``init`` is neither a Symbol nor a
        numpy ndarray.
    """
    if not isinstance(name, _base.string_types):
        raise TypeError('Expect a string for variable `name`')
    # Validate ``init`` up front so a bad argument is rejected before any
    # native variable is created or attributes are set.
    if init is not None and not isinstance(init, (Symbol, np.ndarray)):
        raise TypeError('Expect a Symbol or numpy ndarray '
                        'for variable `init`')
    handle = _base.SymbolHandle()
    _base.check_call(_LIB.NNSymbolCreateVariable(
        _base.c_str(name), _ctypes.byref(handle)))
    ret = Symbol(handle)
    attr = AttrScope.current.get(kwargs)
    if attr:
        ret._set_attr(**attr)
    if init is not None:
        # Recorded in the module-level ``_all_var_init`` registry keyed by
        # variable name (presumably a dict -- a later Variable with the same
        # name would overwrite this entry).
        _all_var_init[name] = init
    return ret


def Group(symbols):
    """Create a symbol that groups symbols together.

    Parameters
    ----------
    symbols : list
        List of symbols to be grouped.

    Returns
    -------
    sym : Symbol
        The created group symbol.

    Raises
    ------
    TypeError
        If any element of ``symbols`` is not a Symbol.
    """
    ihandles = []
    for sym in symbols:
        if not isinstance(sym, Symbol):
            raise TypeError('Expect Symbols in the list input')
        ihandles.append(sym.handle)
    handle = _base.SymbolHandle()
    _check_call(_LIB.NNSymbolCreateGroup(
        _base.nn_uint(len(ihandles)),
        _base.c_array(_base.SymbolHandle, ihandles),
        _ctypes.byref(handle)))
    return Symbol(handle)

# Set the real symbol class to Symbol (hands the class defined above to the
# selected FFI backend under the "nnvm" module name).
_init_symbol_module(Symbol, "nnvm")