1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, unused-import, protected-access
18"""Symbolic graph construction API.
19
20This namespace contains most of the registered operators.
For a detailed list of operators, check out ``Core Tensor Operators``
22"""
23from __future__ import absolute_import as _abs
24import sys as _sys
25import os as _os
26import ctypes as _ctypes
27from numbers import Number as _Number
28
29import numpy as np
30
31from . import _base
32from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
33from .attribute import AttrScope
34from . import _symbol_internal as _internal
35from . import contrib
36
# Use a different version of SymbolBase depending on the FFI mode.
# When possible, use cython to speed up part of the computation.

# Exception type that triggers the fallback to the ctypes implementation
# below. When _FFI_MODE is "cython" only RuntimeError is caught, so an
# ImportError raised while importing the Cython module propagates to the
# caller instead of silently falling back.
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError

try:
    if _FFI_MODE == "ctypes":
        # Deliberately jump to the ctypes fallback in the except clause.
        raise ImportError()
    # Pick the Cython extension package matching the running Python major version.
    if _sys.version_info >= (3, 0):
        from ._cy3.symbol import SymbolBase, _init_symbol_module
    else:
        from ._cy2.symbol import SymbolBase, _init_symbol_module
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
    from ._ctypes.symbol import SymbolBase, _init_symbol_module
52
53
class Symbol(SymbolBase):
    """Symbol is basic operation unit for symbolic graph composition.

    The arithmetic operators are overloaded so that symbols can be combined
    with other symbols or with plain numbers. Each operator dispatches to a
    ``__<op>_symbol__`` or ``__<op>_scalar__`` function that is not defined
    in this file.
    NOTE(review): those functions are presumably installed as module globals
    by ``_init_symbol_module`` -- confirm against the FFI backend.
    """
    # disable dictionary storage, also do not have parent type.
    __slots__ = []

    # NOTE(review): presumably the type code used when passing this object
    # through the TVM FFI -- confirm against the TVM type-code table.
    _tvm_tcode = 16

    @property
    def _tvm_handle(self):
        """Raw integer value of the underlying ctypes handle."""
        return self.handle.value

    def __add__(self, other):
        """x.__add__(y) <=> x+y"""
        if isinstance(other, Symbol):
            return __add_symbol__(self, other)
        if isinstance(other, _Number):
            return __add_scalar__(self, scalar=other)
        raise TypeError("type %s not supported" % str(type(other)))

    def __radd__(self, other):
        """x.__radd__(y) <=> y+x; delegates to ``__add__``."""
        return self.__add__(other)

    def __sub__(self, other):
        """x.__sub__(y) <=> x-y"""
        if isinstance(other, Symbol):
            return __sub_symbol__(self, other)
        if isinstance(other, _Number):
            return __sub_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rsub__(self, other):
        """x.__rsub__(y) <=> y-x, only supported for a numeric ``y``."""
        if isinstance(other, _Number):
            return __rsub_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __mul__(self, other):
        """x.__mul__(y) <=> x*y"""
        if isinstance(other, Symbol):
            return __mul_symbol__(self, other)
        if isinstance(other, _Number):
            return __mul_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rmul__(self, other):
        """x.__rmul__(y) <=> y*x; delegates to ``__mul__``."""
        return self.__mul__(other)

    def __div__(self, other):
        """x.__div__(y) <=> x/y"""
        if isinstance(other, Symbol):
            return __div_symbol__(self, other)
        if isinstance(other, _Number):
            return __div_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rdiv__(self, other):
        """x.__rdiv__(y) <=> y/x, only supported for a numeric ``y``."""
        if isinstance(other, _Number):
            return __rdiv_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __lshift__(self, other):
        """x.__lshift__(y) <=> x << y"""
        if isinstance(other, _Number):
            return __lshift_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rshift__(self, other):
        """x.__rshift__(y) <=> x >> y"""
        if isinstance(other, _Number):
            return __rshift_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __truediv__(self, other):
        """x.__truediv__(y) <=> x/y; delegates to ``__div__``."""
        return self.__div__(other)

    def __rtruediv__(self, other):
        """x.__rtruediv__(y) <=> y/x; delegates to ``__rdiv__``."""
        return self.__rdiv__(other)

    def __pow__(self, other):
        """x.__pow__(y) <=> x**y"""
        if isinstance(other, Symbol):
            return __pow_symbol__(self, other)
        if isinstance(other, _Number):
            return __pow_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __rpow__(self, other):
        """x.__rpow__(y) <=> y**x, only supported for a numeric ``y``."""
        if isinstance(other, _Number):
            return __rpow_scalar__(self, scalar=other)
        raise TypeError('type %s not supported' % str(type(other)))

    def __neg__(self):
        """x.__neg__() <=> -x"""
        # Negation is expressed as multiplication by the scalar -1.0.
        return self.__mul__(-1.0)

    def __copy__(self):
        """Shallow copy; implemented as a deep copy of the symbol."""
        return self.__deepcopy__()

    def __deepcopy__(self, _=None):
        """Returns a deep copy of the input object."""
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolCopy(self.handle,
                                      _ctypes.byref(handle)))
        return Symbol(handle)

    def __getitem__(self, index):
        """Fetch one output of this symbol by position or by name.

        Parameters
        ----------
        index : int or str
            Position of the output, or the name of the output as reported
            by ``list_output_names``.

        Returns
        -------
        output : Symbol
            Symbol wrapping the selected output.

        Raises
        ------
        ValueError
            If a name is given and it matches no output or more than one.
        TypeError
            If `index` is neither a string nor an int.
        """
        if isinstance(index, _base.string_types):
            # Resolve the name to a position, insisting on a unique match.
            idx = None
            for i, name in enumerate(self.list_output_names()):
                if name == index:
                    if idx is not None:
                        raise ValueError('There are multiple outputs with name \"%s\"' % index)
                    idx = i
            if idx is None:
                raise ValueError('Cannot find output that matches name \"%s\"' % index)
            index = idx
        if not isinstance(index, int):
            raise TypeError('Symbol only support integer index to fetch i-th output')
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolGetOutput(
            self.handle, _base.nn_uint(index), _ctypes.byref(handle)))
        return Symbol(handle=handle)

    def __iter__(self):
        """Iterate over the outputs of this symbol, yielding ``self[i]``.

        Outputs are fetched by position rather than by name: a name lookup
        in ``__getitem__`` raises ValueError when two outputs share a name,
        and would re-list all output names for every element.
        """
        num_outputs = len(self.list_output_names())
        return (self[i] for i in range(num_outputs))

    def attr(self, key):
        """Get attribute string from the symbol, this function only works for non-grouped symbol.

        Parameters
        ----------
        key : str
            The key to get attribute from.

        Returns
        -------
        value : str
            The attribute value of the key, returns None if attribute do not exist.
        """
        ret = _ctypes.c_char_p()
        success = _ctypes.c_int()
        _check_call(_LIB.NNSymbolGetAttr(
            self.handle, _base.c_str(key), _ctypes.byref(ret), _ctypes.byref(success)))
        # A zero `success` flag means the attribute was not found.
        if success.value != 0:
            return _base.py_str(ret.value)
        return None

    def list_attr(self, recursive=False):
        """Get all attributes from the symbol.

        Parameters
        ----------
        recursive : bool
            Default `False`. When `recursive` is `True`, list recursively all the
            attributes in the descendents. The attribute names are pre-pended with
            the symbol names to avoid conflicts. If `False`, then only attributes
            that belongs to this symbol is returned, and the attribute names will
            **not** be pre-pended with the symbol name.

        Returns
        -------
        attrs : dict of str to str
            Mapping from attribute name to attribute value.
        """
        size = _base.nn_uint()
        pairs = _ctypes.POINTER(_ctypes.c_char_p)()
        # The native call is passed 0 for a recursive listing and 1 otherwise.
        option = _ctypes.c_int(0) if recursive else _ctypes.c_int(1)
        _check_call(_LIB.NNSymbolListAttrs(
            self.handle, option, _ctypes.byref(size), _ctypes.byref(pairs)))
        # `pairs` is a flat array laid out as key0, value0, key1, value1, ...
        return {_base.py_str(pairs[i*2]): _base.py_str(pairs[i*2+1]) for i in range(size.value)}

    def get_internals(self):
        """Get a new grouped symbol whose output contains all the internal outputs of this symbol.

        Returns
        -------
        sgroup : Symbol
            The internal of the symbol.
        """
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolGetInternals(
            self.handle, _ctypes.byref(handle)))
        return Symbol(handle=handle)

    def get_children(self):
        """Gets a new grouped symbol whose output contains
        inputs to output nodes of the original symbol.

        Returns
        -------
        children : Symbol or None
            The grouped children, or None when that group has no outputs.
        """
        handle = _base.SymbolHandle()
        _check_call(_LIB.NNSymbolGetChildren(
            self.handle, _ctypes.byref(handle)))
        ret = Symbol(handle=handle)
        if not ret.list_output_names():
            return None
        return ret

    def _get_list_copt(self, option):
        """Internal function mapping a listing option name to its C integer code.

        Parameters
        ----------
        option : {'all', 'read_only', 'aux_state'}
            The listing option.

        Returns
        -------
        code : ctypes.c_int
            0 for 'all', 1 for 'read_only', 2 for 'aux_state'.

        Raises
        ------
        ValueError
            If `option` is not one of the supported names.
        """
        if option == 'all':
            return _ctypes.c_int(0)
        if option == 'read_only':
            return _ctypes.c_int(1)
        if option == 'aux_state':
            return _ctypes.c_int(2)
        raise ValueError("option need to be in {'all', 'read_only', 'aux_state'}")

    def list_input_variables(self, option='all'):
        """List all the input variables in the symbol.

        Parameters
        ----------
        option : {'all', 'read_only', 'aux_state'}, optional
           The listing option
           - 'all' will list all the arguments.
           - 'read_only' lists arguments that are read by the graph.
           - 'aux_state' lists arguments that are mutated by the graph as state.

        Returns
        -------
        vars : list of symbol
            List of all the variables
        """
        size = _ctypes.c_uint()
        sarr = _ctypes.POINTER(_base.SymbolHandle)()
        _check_call(_LIB.NNSymbolListInputVariables(
            self.handle, self._get_list_copt(option),
            _ctypes.byref(size), _ctypes.byref(sarr)))
        return [Symbol(_base.SymbolHandle(sarr[i])) for i in range(size.value)]

    def list_input_names(self, option='all'):
        """List all the inputs in the symbol.

        Parameters
        ----------
        option : {'all', 'read_only', 'aux_state'}, optional
           The listing option
           - 'all' will list all the arguments.
           - 'read_only' lists arguments that are read by the graph.
           - 'aux_state' lists arguments that are mutated by the graph as state.

        Returns
        -------
        args : list of string
            List of all the arguments.
        """
        size = _ctypes.c_uint()
        sarr = _ctypes.POINTER(_ctypes.c_char_p)()
        _check_call(_LIB.NNSymbolListInputNames(
            self.handle, self._get_list_copt(option),
            _ctypes.byref(size), _ctypes.byref(sarr)))
        return [_base.py_str(sarr[i]) for i in range(size.value)]

    def list_output_names(self):
        """List all outputs in the symbol.

        Returns
        -------
        returns : list of string
            List of all the outputs.
        """
        size = _ctypes.c_uint()
        sarr = _ctypes.POINTER(_ctypes.c_char_p)()
        _check_call(_LIB.NNSymbolListOutputNames(
            self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
        return [_base.py_str(sarr[i]) for i in range(size.value)]

    def debug_str(self):
        """Get a debug string.

        Returns
        -------
        debug_str : string
            Debug string of the symbol.
        """
        debug_str = _ctypes.c_char_p()
        _check_call(_LIB.NNSymbolPrint(
            self.handle, _ctypes.byref(debug_str)))
        return _base.py_str(debug_str.value)

    def _add_control_deps(self, deps):
        """Add control flow dependencies.
        This makes current op depend on the deps.
        Only use when necessary,
        this function mutate the current symbol node.

        Parameters
        ----------
        deps : Symbol or list of Symbol
            The dependencies. A list (or tuple) is first combined into a
            single symbol with ``Group``.
        """
        if isinstance(deps, (list, tuple)):
            deps = Group(deps)
        _check_call(_LIB.NNAddControlDeps(
            self.handle, deps.handle))
339
340
def Variable(name, init=None, **kwargs):
    """Create a symbolic variable with specified name.

    Parameters
    ----------
    name : str
        Name of the variable.
    init : Symbol or numpy.ndarray
        Symbol or numpy ndarray of initial value for the variable.
        Note that for symbolic initialization value, it must be able
        to be defined through InferShape, such as sym.zeros_like(v),
        in which v is an input or parameter. Otherwise, pass a numpy
        ndarray instead.
    kwargs : dict of string -> string
        Additional attributes to set on the variable.

    Returns
    -------
    variable : Symbol
        The created variable symbol.

    Raises
    ------
    TypeError
        If `name` is not a string, or if `init` is given but is neither a
        Symbol nor a numpy ndarray.
    """
    if not isinstance(name, _base.string_types):
        raise TypeError('Expect a string for variable `name`')
    # Validate `init` before touching the native library so that a bad
    # argument does not create a variable handle that is then thrown away.
    if init is not None and not isinstance(init, (Symbol, np.ndarray)):
        raise TypeError('Expect a Symbol or numpy ndarray '
                        'for variable `init`')
    handle = _base.SymbolHandle()
    _check_call(_LIB.NNSymbolCreateVariable(
        _base.c_str(name), _ctypes.byref(handle)))
    ret = Symbol(handle)
    # Merge the caller's attributes with those of the current AttrScope.
    attr = AttrScope.current.get(kwargs)
    if attr:
        ret._set_attr(**attr)
    if init is not None:
        # Record the initial value in the module-level registry, keyed by
        # the variable name.
        _all_var_init[name] = init
    return ret
377
378
def Group(symbols):
    """Bundle several symbols into one grouped symbol.

    Parameters
    ----------
    symbols : list
        The symbols to combine; every element must be a Symbol.

    Returns
    -------
    sym : Symbol
        The newly created group symbol.

    Raises
    ------
    TypeError
        If any element of `symbols` is not a Symbol.
    """
    member_handles = []
    for member in symbols:
        if isinstance(member, Symbol):
            member_handles.append(member.handle)
            continue
        raise TypeError('Expect Symbols in the list input')

    group_handle = _base.SymbolHandle()
    count = _base.nn_uint(len(member_handles))
    packed_handles = _base.c_array(_base.SymbolHandle, member_handles)
    status = _LIB.NNSymbolCreateGroup(
        count, packed_handles, _ctypes.byref(group_handle))
    _check_call(status)
    return Symbol(group_handle)
403
# Set the real symbol class to Symbol: hand the class defined above to the
# `_init_symbol_module` hook of whichever FFI backend was imported at the top
# of this file.
# NOTE(review): "nnvm" is presumably the namespace/root name the backend uses
# when registering operators -- confirm against `_init_symbol_module`.
_init_symbol_module(Symbol, "nnvm")
406