# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License


import sys as _sys
import ctypes as _ctypes
import numpy as _numpy

from numbers import Number as _Number
from ..name import NameManager
from ..attribute import AttrScope
from ..symbol_doc import _build_doc

include "./base.pyi"

cdef class SymbolBase:
    """Symbol is symbolic graph."""
    # Raw native handle of the symbolic operator owned by this object; NULL
    # when no symbol is attached. Released in __dealloc__.
    cdef SymbolHandle chandle

    cdef _set_handle(self, handle):
        """Attach ``handle`` to this object, or detach when it is ``None``.

        ``handle`` is expected to expose the raw pointer address through
        ``.value`` (e.g. a ``ctypes.c_void_p``). NOTE(review): a handle that
        was attached before is overwritten without being freed; the new
        pointer becomes owned by this object and is released in
        ``__dealloc__``.
        """
        cdef unsigned long long ptr
        if handle is None:
            self.chandle = NULL
        else:
            ptr = handle.value
            self.chandle = <SymbolHandle>(ptr)

    property handle:
        """The native symbol handle as a ``ctypes.c_void_p``, or ``None``."""
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle):
        self._set_handle(handle)

    def __dealloc__(self):
        # Release the native symbol owned by this object. Nothing to free
        # when no handle was ever attached.
        if self.chandle != NULL:
            CALL(NNSymbolFree(self.chandle))

    def _set_attr(self, **kwargs):
        """Set the attribute of the symbol.

        Parameters
        ----------
        **kwargs
            The attributes to set. Values are converted with ``str()``
            before being passed to the native library.
        """
        SymbolSetAttr(self.chandle, kwargs)

    def __reduce__(self):
        # Pickle support: rebuild an empty instance of the registered symbol
        # class, then restore it from ``__getstate__``. NOTE(review):
        # ``__getstate__`` is not defined on SymbolBase here; presumably it is
        # supplied by the class registered via _set_symbol_class.
        return (_symbol_cls, (None,), self.__getstate__())


cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
    """Set every ``key=value`` pair of ``kwargs`` as an attribute of ``handle``.

    Keys and ``str(value)`` are encoded with ``c_str`` and passed to
    ``MXSymbolSetAttr`` one pair at a time; failures are raised by ``CALL``.
    """
    cdef string sparam_key
    cdef string sparam_val
    cdef const char* param_key
    cdef const char* param_val
    for k, v in kwargs.items():
        sparam_key = c_str(k)
        sparam_val = c_str(str(v))
        param_key = sparam_key.c_str()
        param_val = sparam_val.c_str()
        CALL(MXSymbolSetAttr(handle, param_key, param_val))


# Classes instantiated by NewSymbol for regular and numpy-style symbols;
# replaceable through the two setters below.
_symbol_cls = SymbolBase
_np_symbol_cls = None

def _set_symbol_class(cls):
    """Register ``cls`` as the class used to wrap newly created symbols."""
    global _symbol_cls
    _symbol_cls = cls


def _set_np_symbol_class(cls):
    """Register ``cls`` as the class used to wrap numpy-operator symbols."""
    global _np_symbol_cls
    _np_symbol_cls = cls


cdef NewSymbol(SymbolHandle handle, int is_np_sym=0):
    """Create a new symbol given handle.

    Instantiates the registered numpy symbol class when ``is_np_sym`` is
    non-zero, otherwise the registered regular symbol class, and attaches
    ``handle`` to it. The returned object owns the handle and frees it in
    ``SymbolBase.__dealloc__``. NOTE(review): the cast below is unchecked,
    so the registered classes are assumed to subclass SymbolBase.
    """
    create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls
    sym = create_symbol_fn(None)
    (<SymbolBase>sym).chandle = handle
    return sym


def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0, output_is_list=0):
    """Create an atomic operator symbol and compose it with its inputs.

    Parameters
    ----------
    handle : int
        Address of the operator's ``OpHandle``.
    args : sequence
        Positional input symbols (``SymbolBase`` instances).
    kwargs : dict
        Keyword input symbols, mapping input name to ``SymbolBase``.
    keys, vals : sequence
        Operator attribute names and values; values are converted with
        ``str()``.
    name : str
        Name given to the composed symbol.
    is_np_op : int, optional
        If true, wrap the result with the registered numpy symbol class and
        return a list when the operator has more than one output.
    output_is_list : int, optional
        Only consulted when ``is_np_op`` is true: return a one-element list
        for a single-output operator.

    Returns
    -------
    Symbol or list of Symbol

    Raises
    ------
    TypeError
        If inputs are given both positionally and by keyword, or an input
        is not a ``SymbolBase`` instance.
    """
    cdef unsigned long long ihandle = handle
    cdef OpHandle chandle = <OpHandle>ihandle
    cdef vector[string] ckeys
    cdef vector[string] cvals
    cdef vector[string] sym_keys
    cdef vector[SymbolHandle] sym_args
    cdef SymbolHandle ret_handle
    cdef string cname = c_str(name)
    cdef nn_uint nout

    if args and kwargs:
        raise TypeError(
            'Operators with variable length input can only accept input '
            'Symbols either as positional or keyword arguments, not both')

    # Validate and collect the input handles *before* any native symbol is
    # allocated, so a bad argument cannot leak ``ret_handle``. The explicit
    # isinstance checks also make the casts below safe: an unchecked cast
    # would read ``chandle`` out of arbitrary objects (including None).
    if args:
        for i in args:
            if not isinstance(i, SymbolBase):
                raise TypeError('Expected Symbol as positional input, got %s'
                                % type(i).__name__)
            sym_args.push_back((<SymbolBase>i).chandle)
    elif kwargs:
        for k, v in kwargs.items():
            if not isinstance(v, SymbolBase):
                raise TypeError('Expected Symbol for input %r, got %s'
                                % (k, type(v).__name__))
            sym_keys.push_back(c_str(k))
            sym_args.push_back((<SymbolBase>v).chandle)

    for i in keys:
        ckeys.push_back(c_str(i))
    for i in vals:
        cvals.push_back(c_str(str(i)))

    cdef vector[const char*] param_keys = SVec2Ptr(ckeys)
    cdef vector[const char*] param_vals = SVec2Ptr(cvals)
    cdef vector[const char*] csym_keys = SVec2Ptr(sym_keys)

    CALL(MXSymbolCreateAtomicSymbol(
        chandle,
        <nn_uint>param_keys.size(),
        CBeginPtr(param_keys),
        CBeginPtr(param_vals),
        &ret_handle))

    # Hand ownership of the new handle to a Python object immediately: if
    # composition below fails, that object's __dealloc__ frees the native
    # symbol instead of leaking it.
    sym = NewSymbol(ret_handle, is_np_op)

    CALL(NNSymbolCompose(
        ret_handle,
        cname.c_str(),
        <nn_uint>sym_args.size(),
        &csym_keys[0] if csym_keys.size() != 0 else NULL,
        &sym_args[0] if sym_args.size() != 0 else NULL))

    if is_np_op:
        CALL(NNSymbolGetNumOutputs(ret_handle, &nout))
        if nout > 1:
            # Presumably splits into per-output symbols via the registered
            # class's iteration protocol — TODO confirm.
            return list(sym)
        elif output_is_list:
            return [sym]
    return sym