1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable=invalid-name, protected-access, too-many-arguments,  global-statement
20"""Symbolic configuration API."""
21
22import ctypes
23from ..base import _LIB
24from ..base import c_str_array, c_handle_array, c_str, mx_uint
25from ..base import SymbolHandle
26from ..base import check_call
27
# The symbol classes to be used (Cython or Ctypes).
# Both start unset; `_set_symbol_class` / `_set_np_symbol_class` rebind them,
# and `_symbol_creator` picks one of them to wrap each newly created handle.
_symbol_cls = None
_np_symbol_cls = None
31
class SymbolBase(object):
    """Symbol is symbolic graph.

    Thin wrapper that owns a single backend ``handle`` and releases it
    through ``NNSymbolFree`` when the Python object is destroyed.
    """
    __slots__ = ["handle"]
    # pylint: disable=no-member
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        self.handle = handle

    def __del__(self):
        """Release the backend symbol referenced by ``self.handle``."""
        check_call(_LIB.NNSymbolFree(self.handle))

    def _compose(self, *args, **kwargs):
        """Compose symbol on inputs.

        This call mutates the current symbol in place.

        Parameters
        ----------
        args:
            input symbols given positionally

        kwargs:
            input symbols given by keyword. The optional ``name`` keyword is
            removed first and forwarded to the backend as the symbol name.

        Returns
        -------
        None
            The composition is applied to ``self.handle``; nothing is returned.

        Raises
        ------
        TypeError
            If inputs are supplied both positionally and by keyword, or if
            any input is not a ``SymbolBase`` instance.
        """
        name = kwargs.pop('name', None)
        if name:
            name = c_str(name)

        if args and kwargs:
            raise TypeError('compose only accept input Symbols '
                            'either as positional or keyword arguments, not both')

        # At most one of the two containers is non-empty at this point.
        for sym in args + tuple(kwargs.values()):
            if not isinstance(sym, SymbolBase):
                raise TypeError('Compose expect `Symbol` as arguments')

        num_args = len(args) + len(kwargs)
        if kwargs:
            keys = c_str_array(kwargs.keys())
            handles = c_handle_array(kwargs.values())
        else:
            # Positional inputs carry no keys; the handle array must be built
            # from `args` so that it really holds `num_args` entries.
            keys = None
            handles = c_handle_array(args)
        check_call(_LIB.NNSymbolCompose(
            self.handle, name, num_args, keys, handles))

    def _set_attr(self, **kwargs):
        """Set the attribute of the symbol.

        Parameters
        ----------
        **kwargs
            The attributes to set. Every value is converted with ``str``
            before being handed to the backend.
        """
        keys = c_str_array(kwargs.keys())
        vals = c_str_array([str(s) for s in kwargs.values()])
        num_args = mx_uint(len(kwargs))
        check_call(_LIB.MXSymbolSetAttrs(
            self.handle, num_args, keys, vals))

    def _set_handle(self, handle):
        """Set handle.

        The previously stored handle is only replaced, not freed, here.
        """
        self.handle = handle

    def __reduce__(self):
        """Pickle support.

        Rebuilds through the registered ``_symbol_cls`` called with a ``None``
        handle and restores the state returned by ``self.__getstate__()``.
        ``__getstate__`` is not defined in this class.
        """
        return (_symbol_cls, (None,), self.__getstate__())
111
112
def _set_symbol_class(cls):
    """Set the symbolic class to be cls

    Rebinds the module-level ``_symbol_cls``. That object is what
    ``_symbol_creator`` calls with a new handle when ``is_np_op`` is false,
    and what ``SymbolBase.__reduce__`` returns as the reconstructor.

    Parameters
    ----------
    cls : callable
        The class to register; it is invoked with a single handle argument.
    """
    global _symbol_cls
    _symbol_cls = cls
117
118
def _set_np_symbol_class(cls):
    """Set the numpy-compatible symbolic class to be cls

    Rebinds the module-level ``_np_symbol_cls``, which ``_symbol_creator``
    calls with a new handle when ``is_np_op`` is true.

    Parameters
    ----------
    cls : callable
        The class to register; it is invoked with a single handle argument.
    """
    global _np_symbol_cls
    _np_symbol_cls = cls
123
124
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op, output_is_list=False):
    """Create an atomic symbol for an operator and compose it with its inputs.

    Parameters
    ----------
    handle
        Operator creator handle; wrapped in ``ctypes.c_void_p`` and passed to
        ``MXSymbolCreateAtomicSymbol``.
    args : sequence
        Input symbols given positionally.
    kwargs : dict
        Input symbols given by keyword.
    keys : sequence of str
        Names of the operator parameters.
    vals : sequence
        Values of the operator parameters; each is converted with ``str``.
    name : str or None
        Name forwarded to ``_compose``.
    is_np_op : bool
        Selects ``_np_symbol_cls`` instead of ``_symbol_cls`` as the wrapper
        and enables the list-returning behaviour described below.
    output_is_list : bool, optional
        For numpy ops with a single output, return it wrapped in a list.

    Returns
    -------
    The composed symbol. For numpy ops, ``list(s)`` when the symbol reports
    more than one output, or ``[s]`` when ``output_is_list`` is set.

    Raises
    ------
    TypeError
        If input symbols are given both positionally and by keyword.
    """
    # Validate before touching the backend: rejecting the call after the
    # handle was created would leave a handle that no wrapper ever frees.
    if args and kwargs:
        raise TypeError(
            'Operators with variable length input can only accept input '
            'Symbols either as positional or keyword arguments, not both')

    sym_handle = SymbolHandle()
    check_call(_LIB.MXSymbolCreateAtomicSymbol(
        ctypes.c_void_p(handle),
        mx_uint(len(keys)),
        c_str_array(keys),
        c_str_array([str(v) for v in vals]),
        ctypes.byref(sym_handle)))

    create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls
    s = create_symbol_fn(sym_handle)
    if args:
        s._compose(*args, name=name)
    elif kwargs:
        s._compose(name=name, **kwargs)
    else:
        s._compose(name=name)
    if is_np_op:
        # Determine whether the symbol is a list.
        if s.num_outputs > 1:
            return list(s)
        if output_is_list:
            return [s]
    return s
153