1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# NDArray functions dynamically imported from libmxnet
19
# Call the libmxnet function behind `func_handle` through `MXFuncInvokeEx`.
#
# `use_vars`, `scalars` and `mut_vars` are forwarded unchanged as the
# `Ptr{MX_handle}`, `Ptr{MX_float}` and `Ptr{MX_handle}` arguments of the C call.
# Every keyword argument is stringified (name and value alike) and handed over as
# two parallel string arrays together with their count.
# The value of the `@mxcall` expression is returned.
function _invoke_mxfunction(func_handle::MX_handle, use_vars, scalars, mut_vars; kwargs...)
  # Parallel arrays: param_keys[i] is the name of the value in param_vals[i].
  param_keys = String[string(key) for (key, val) in kwargs]
  param_vals = String[string(val) for (key, val) in kwargs]
  num_params = length(param_keys)

  @mxcall(:MXFuncInvokeEx,
          (MX_handle, Ptr{MX_handle}, Ptr{MX_float}, Ptr{MX_handle}, Cint, char_pp, char_pp),
          func_handle, use_vars, scalars, mut_vars, num_params, param_keys, param_vals)
end
27
# Flag values of a libmxnet function's type mask.
#   NDARRAY_ARG_BEFORE_SCALAR  (1)      - NDArray arguments are listed before scalars
#   ACCEPT_EMPTY_MUTATE_TARGET (1 << 2) - the output arguments may be omitted
# (meanings as described in the docstring of `_get_ndarray_function_def`)
@enum LIBMX_FUNC_TYPE_MASK NDARRAY_ARG_BEFORE_SCALAR=1 ACCEPT_EMPTY_MUTATE_TARGET=(1 << 2)
32
33# Import corresponding math functions from base so the automatically defined libmxnet
34# functions can overload them
35import Base: sqrt
36
37"""
The libmxnet APIs are automatically imported from `libmxnet.so`. The functions listed
39here operate on `NDArray` objects. The arguments to the functions are typically ordered
40as
41
42```julia
43  func_name(arg_in1, arg_in2, ..., scalar1, scalar2, ..., arg_out1, arg_out2, ...)
44```
45
when `NDARRAY_ARG_BEFORE_SCALAR` is set. If it is not set, the scalars are put before the input arguments:
47
48```julia
49  func_name(scalar1, scalar2, ..., arg_in1, arg_in2, ..., arg_out1, arg_out2, ...)
50```
51
If `ACCEPT_EMPTY_MUTATE_TARGET` is set, an overloaded function without the output arguments will also be defined:
53
54```julia
55  func_name(arg_in1, arg_in2, ..., scalar1, scalar2, ...)
56```
57
58Upon calling, the output arguments will be automatically initialized with empty NDArrays.
59
60Those functions always return the output arguments. If there is only one output (the typical situation), that
61object (`NDArray`) is returned. Otherwise, a tuple containing all the outputs will be returned.
62"""
function _get_ndarray_function_def(name::String)
  # Build (without evaluating) the two method definitions for the libmxnet operator
  # `name` and return them as the tuple `(func_def, func_def2)`:
  #
  #   * `func_def`  defines `name(::Type{<:NDArray}, args::NDArray...; out, kwargs...)`,
  #     the method that performs the `MXImperativeInvoke` call;
  #   * `func_def2` defines `name(args::NDArray...; out, kwargs...)`, which forwards to
  #     the first method with `NDArray` as the leading argument.
  #
  # The generated method returns `out` unchanged when the caller supplied it.
  # Otherwise it wraps the handles written back by the call into `NDArray`s and
  # returns the single `NDArray` if there is exactly one, else the vector of them.
  func_name = Symbol(name)

  func_def = quote
    function $func_name(::Type{<:NDArray}, args::NDArray...; out=nothing, kwargs...)
      # `nothing` is tested by identity (`===` / `!==`) so the check can never be
      # routed through an `==` method overloaded for the type of `out`.
      if out !== nothing
        # A lone NDArray is shorthand for a one-element output list.
        output_vars = isa(out, NDArray) ? NDArray[out] : out
        num_outputs = length(output_vars)
      else
        output_vars = NDArray[]
        num_outputs = 0
      end

      args = collect(args)  # tuple to list
      if isempty(args)
        args = MX_handle[]
      end

      # One-element containers passed as in/out pointer arguments of the C call:
      # either the caller's output handles or a NULL pointer to be filled in.
      output_handles_pp = if num_outputs > 0
        [map(x -> x.handle, output_vars)]
      else
        [Ptr{MX_handle}(C_NULL)]
      end
      num_outputs_p = [convert(Cint, num_outputs)]

      # Keyword arguments are handed over as parallel arrays of strings.
      kw_keys_str = String[string(x[1]) for x in kwargs]
      kw_vals_str = String[dump_mx_param(x[2]) for x in kwargs]

      op_handle = _get_cached_libmx_op_handle($(name))
      @mxcall(:MXImperativeInvoke,
              (MX_handle, Cint, Ptr{MX_handle},
               Ptr{Cint}, Ptr{Ptr{MX_handle}},
               Cint, char_pp, char_pp),
              op_handle, length(args), args,
              num_outputs_p, output_handles_pp,
              length(kw_keys_str), kw_keys_str, kw_vals_str)

      if out !== nothing
        return out
      end

      # No `out` was given: read back the output count and handle array that the
      # call stored through the two pointer arguments.
      n = num_outputs_p[]
      hdls = unsafe_wrap(Array{MX_handle}, output_handles_pp[], n)
      xs = NDArray[NDArray(MX_NDArrayHandle(x)) for x in hdls]
      return n == 1 ? xs[] : xs
    end
  end

  func_def2 = quote
    function $func_name(args::NDArray...; out=nothing, kwargs...)
      $func_name(NDArray, args...; out=out, kwargs...)
    end
  end

  return func_def, func_def2
end
126
# Import black list: libmxnet operators that must NOT be imported automatically.
# `@_import_ndarray_functions` lower-cases each operator name before looking it up
# here, so every entry has to be written in lower case.
const _op_import_bl = String[
    # creation: we already have `mx.fill`, `mx.ones` and `mx.zeros`
    "_full", "_ones", "_zeros",
    "clip", "expand_dims",

    # arithmetic
    "_plus", "_minus", "_mod", "_mod_scalar", "_rmod_scalar",

    "dot", "max", "max_axis", "mean", "min", "min_axis", "prod", "reshape", "sum",
    "transpose",

    # trigonometric
    "sin", "cos", "tan", "arcsin", "arccos", "arctan",

    # hyperbolic
    "sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh",

    # activation
    "sigmoid", "relu", "softmax", "log_softmax",

    # broadcast
    "broadcast_add", "broadcast_plus", "broadcast_minus", "broadcast_sub",
    "broadcast_mul", "broadcast_div", "broadcast_mod", "broadcast_power",
    "broadcast_equal", "broadcast_not_equal",
    "broadcast_greater", "broadcast_greater_equal",
    "broadcast_lesser", "broadcast_lesser_equal",
    "broadcast_maximum", "broadcast_minimum",
    "broadcast_to", "broadcast_axis", "broadcast_axes", "broadcast_hypot",

    # reduction
    "argmax", "argmin",
]
200
# Generate the NDArray method definitions for every libmxnet operator whose
# lower-cased name is not listed in `_op_import_bl`.
#
# For each remaining operator the expansion contains, in this order: the import
# expression produced by `_import_expr`, the `::Type{<:NDArray}` method, and the
# plain method with the operator description attached as its documentation.
# The whole result is escaped, so the definitions land in the calling module.
macro _import_ndarray_functions()
  op_names = filter(op -> !(lowercase(op) in _op_import_bl), _get_libmx_op_names())

  definitions = Expr[]
  for op in op_names
    handle = _get_libmx_op_handle(op)

    # Only the description is needed here; the second returned value is unused.
    docstring, _ = _get_libmx_op_description(op, handle)
    typed_method, plain_method = _get_ndarray_function_def(op)

    import_stmt = _import_expr(Symbol(op))

    # `@doc` with a single argument documents the expression on the next line,
    # so `@doc $docstring` and `$plain_method` must stay directly adjacent.
    push!(definitions, quote
      $import_stmt
      $typed_method
      @doc $docstring
      $plain_method
    end)
  end

  esc(quote
    $(definitions...)
  end)
end
226
227@_import_ndarray_functions
228