# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
"""Symbolic configuration API."""

import ctypes
from ..base import _LIB
from ..base import c_str_array, c_handle_array, c_str, mx_uint
from ..base import SymbolHandle
from ..base import check_call

# The symbol class to be used (Cython or Ctypes)
_symbol_cls = None
_np_symbol_cls = None

class SymbolBase(object):
    """Symbol is symbolic graph.

    Thin Python wrapper around a C++ Symbol handle. The handle is released
    through ``NNSymbolFree`` when the Python object is finalized.
    """
    __slots__ = ["handle"]
    # pylint: disable=no-member
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        self.handle = handle

    def __del__(self):
        # Release the native symbol owned by this wrapper.
        check_call(_LIB.NNSymbolFree(self.handle))

    def _compose(self, *args, **kwargs):
        """Compose symbol on inputs.

        This call mutates the current symbol in place via ``NNSymbolCompose``.

        Parameters
        ----------
        args:
            provide positional arguments; every element must be a
            ``SymbolBase`` instance.

        kwargs:
            provide keyword arguments; every value must be a ``SymbolBase``
            instance. The special key ``name`` is removed before validation
            and, when truthy, used as the name of the composed symbol.

        Returns
        -------
        None. The current symbol is modified in place.

        Raises
        ------
        TypeError
            If inputs are given both positionally and by keyword, or if any
            input is not a ``SymbolBase``.
        """
        name = kwargs.pop('name', None)

        if name:
            name = c_str(name)
        if args and kwargs:
            raise TypeError('compose only accept input Symbols '
                            'either as positional or keyword arguments, not both')

        for arg in args:
            if not isinstance(arg, SymbolBase):
                raise TypeError('Compose expect `Symbol` as arguments')
        for val in kwargs.values():
            if not isinstance(val, SymbolBase):
                raise TypeError('Compose expect `Symbol` as arguments')

        num_args = len(args) + len(kwargs)
        if kwargs:
            # Keyword form: keys and handles are taken from the same dict,
            # so their iteration orders match.
            keys = c_str_array(kwargs.keys())
            handles = c_handle_array(kwargs.values())
        else:
            # Positional form: no keys; handles come from the positional
            # inputs (previously this wrongly used the empty kwargs).
            keys = None
            handles = c_handle_array(args)
        check_call(_LIB.NNSymbolCompose(
            self.handle, name, num_args, keys, handles))

    def _set_attr(self, **kwargs):
        """Set the attribute of the symbol.

        Values are converted with ``str()`` before being passed to
        ``MXSymbolSetAttrs``.

        Parameters
        ----------
        **kwargs
            The attributes to set
        """
        keys = c_str_array(kwargs.keys())
        vals = c_str_array([str(s) for s in kwargs.values()])
        num_args = mx_uint(len(kwargs))
        check_call(_LIB.MXSymbolSetAttrs(
            self.handle, num_args, keys, vals))

    def _set_handle(self, handle):
        """Set handle."""
        self.handle = handle

    def __reduce__(self):
        # Pickle support: rebuild an instance of the registered symbol class
        # with a None handle, then restore state from __getstate__.
        # NOTE(review): __getstate__ is not defined in this class; presumably
        # the concrete subclass provides it — confirm.
        return (_symbol_cls, (None,), self.__getstate__())


def _set_symbol_class(cls):
    """Set the symbolic class to be cls"""
    global _symbol_cls
    _symbol_cls = cls


def _set_np_symbol_class(cls):
    """Set the numpy-compatible symbolic class to be cls"""
    global _np_symbol_cls
    _np_symbol_cls = cls


def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op, output_is_list=False):
    """Create an atomic symbol for an operator and compose it on inputs.

    Parameters
    ----------
    handle :
        Operator creator handle, passed to ``MXSymbolCreateAtomicSymbol``
        as a ``ctypes.c_void_p``.
    args : sequence
        Positional input symbols.
    kwargs : dict
        Keyword input symbols. Must not be combined with ``args``.
    keys : sequence of str
        Attribute names for the atomic symbol.
    vals : sequence
        Attribute values; each is converted with ``str()``.
    name : str or None
        Name given to the composed symbol.
    is_np_op : bool
        If true, the registered numpy-compatible symbol class is used.
    output_is_list : bool
        For numpy ops with a single output, wrap the result in a list.

    Returns
    -------
    A symbol, or (numpy ops only) a list of symbols.

    Raises
    ------
    TypeError
        If inputs are given both positionally and by keyword.
    """
    # Validate before allocating the native handle so an invalid call does
    # not leak a symbol that is never wrapped (and therefore never freed).
    if args and kwargs:
        raise TypeError(
            'Operators with variable length input can only accept input '
            'Symbols either as positional or keyword arguments, not both')

    sym_handle = SymbolHandle()
    check_call(_LIB.MXSymbolCreateAtomicSymbol(
        ctypes.c_void_p(handle),
        mx_uint(len(keys)),
        c_str_array(keys),
        c_str_array([str(v) for v in vals]),
        ctypes.byref(sym_handle)))

    create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls
    s = create_symbol_fn(sym_handle)
    # At most one of args/kwargs is non-empty here, so a single call covers
    # the positional, keyword and no-input cases.
    s._compose(*args, name=name, **kwargs)
    if is_np_op:
        # Determine whether the symbol is a list.
        # NOTE(review): num_outputs and iteration come from the concrete
        # numpy symbol class, not from SymbolBase — confirm.
        if s.num_outputs > 1:
            return list(s)
        elif output_is_list:
            return [s]
    return s