# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# NDArray functions dynamically imported from libmxnet

"""
    _invoke_mxfunction(func_handle, use_vars, scalars, mut_vars; kwargs...)

Call `MXFuncInvokeEx` on `func_handle`, forwarding `use_vars`, `scalars` and
`mut_vars` unchanged. Every keyword argument is passed as a pair of strings:
its name and `string(value)`.
"""
function _invoke_mxfunction(func_handle::MX_handle, use_vars, scalars, mut_vars; kwargs...)
    names = String[string(entry[1]) for entry in kwargs]
    args = String[string(entry[2]) for entry in kwargs]
    @mxcall(:MXFuncInvokeEx,
            (MX_handle, Ptr{MX_handle}, Ptr{MX_float}, Ptr{MX_handle},
             Cint, char_pp, char_pp),
            func_handle, use_vars, scalars, mut_vars, length(names), names, args)
end

@enum(LIBMX_FUNC_TYPE_MASK,
    NDARRAY_ARG_BEFORE_SCALAR = 1,
    ACCEPT_EMPTY_MUTATE_TARGET = (1 << 2)
)

# Import corresponding math functions from base so the automatically defined libmxnet
# functions can overload them
import Base: sqrt

"""
The libmxnet APIs are automatically imported from `libmxnet.so`. The functions listed
here operate on `NDArray` objects. When `NDARRAY_ARG_BEFORE_SCALAR` is set, the arguments
to the functions are ordered as

```julia
  func_name(arg_in1, arg_in2, ..., scalar1, scalar2, ..., arg_out1, arg_out2, ...)
```

When `NDARRAY_ARG_BEFORE_SCALAR` is not set, the scalars are put before the input
arguments instead:

```julia
  func_name(scalar1, scalar2, ..., arg_in1, arg_in2, ..., arg_out1, arg_out2, ...)
```

If `ACCEPT_EMPTY_MUTATE_TARGET` is set, an overloaded function without the output
arguments will also be defined:

```julia
  func_name(arg_in1, arg_in2, ..., scalar1, scalar2, ...)
```

Upon calling, the output arguments will be automatically initialized with empty NDArrays.

Those functions always return the output arguments. If there is only one output (the
typical situation), that object (`NDArray`) is returned. Otherwise, a tuple containing
all the outputs will be returned.

`_get_ndarray_function_def(name)` itself returns two quoted method definitions for the
operator called `name`: one taking `::Type{<:NDArray}` as its first positional argument,
and a convenience method without it that forwards to the first.
"""
function _get_ndarray_function_def(name::String)
    func_name = Symbol(name)

    func_def = quote
        function $func_name(::Type{<:NDArray}, args::NDArray...; out=nothing, kwargs...)
            # Normalise `out` into a vector of NDArrays; a single NDArray is wrapped.
            if out !== nothing
                output_vars = out
                if isa(output_vars, NDArray)
                    output_vars = NDArray[output_vars]
                end
                num_outputs = length(output_vars)
            else
                output_vars = NDArray[]
                num_outputs = 0
            end

            args = collect(args)  # tuple to list
            if isempty(args)
                args = MX_handle[]
            end

            # One-element wrappers, passed to the call below as the
            # `Ptr{Ptr{MX_handle}}` and `Ptr{Cint}` arguments and read back afterwards.
            output_handles_pp = if length(output_vars) > 0
                [map(x -> x.handle, output_vars)]
            else
                [Ptr{MX_handle}(C_NULL)]
            end
            num_outputs_p = [convert(Cint, num_outputs)]

            kw_keys_str = String[string(x[1]) for x in kwargs]
            kw_vals_str = String[dump_mx_param(x[2]) for x in kwargs]

            op_handle = _get_cached_libmx_op_handle($(name))
            @mxcall(:MXImperativeInvoke,
                    (MX_handle, Cint, Ptr{MX_handle},
                     Ptr{Cint}, Ptr{Ptr{MX_handle}},
                     Cint, char_pp, char_pp),
                    op_handle, length(args), args,
                    num_outputs_p, output_handles_pp,
                    length(kwargs), kw_keys_str, kw_vals_str)

            # Caller-supplied outputs are returned as given.
            if out !== nothing
                return out
            end

            # No `out` given: wrap the handles read back from `output_handles_pp`.
            n = num_outputs_p[]
            hdls = unsafe_wrap(Array{MX_handle}, output_handles_pp[], n)
            xs = NDArray[NDArray(MX_NDArrayHandle(x)) for x in hdls]
            if n == 1
                return xs[]
            else
                return xs
            end
        end
    end

    func_def2 = quote
        function $func_name(args::NDArray...; out=nothing, kwargs...)
            $func_name(NDArray, args...; out=out, kwargs...)
        end
    end

    return func_def, func_def2
end

# Operator names are lowercased before being looked up here, so every entry must be
# written in lowercase.
const _op_import_bl = [  # import black list; do not import these funcs
    "_full",   # we already have `mx.fill`
    "_ones",   # we already have `mx.ones`
    "_zeros",  # we already have `mx.zeros`
    "clip",
    "expand_dims",

    # arithmetic
    "_plus",
    "_minus",
    "_mod",
    "_mod_scalar",
    "_rmod_scalar",

    "dot",
    "max",
    "max_axis",
    "mean",
    "min",
    "min_axis",
    "prod",
    "reshape",
    "sum",
    "transpose",

    # trigonometric
    "sin",
    "cos",
    "tan",
    "arcsin",
    "arccos",
    "arctan",

    # hyperbolic
    "sinh",
    "cosh",
    "tanh",
    "arcsinh",
    "arccosh",
    "arctanh",

    # activation
    "sigmoid",
    "relu",
    "softmax",
    "log_softmax",

    # broadcast
    "broadcast_add",
    "broadcast_plus",
    "broadcast_minus",
    "broadcast_sub",
    "broadcast_mul",
    "broadcast_div",
    "broadcast_mod",
    "broadcast_power",
    "broadcast_equal",
    "broadcast_not_equal",
    "broadcast_greater",
    "broadcast_greater_equal",
    "broadcast_lesser",
    "broadcast_lesser_equal",
    "broadcast_maximum",
    "broadcast_minimum",
    "broadcast_to",
    "broadcast_axis",
    "broadcast_axes",
    "broadcast_hypot",

    # reduction
    "argmax",
    "argmin",
]

"""
    @_import_ndarray_functions

Expand to the definitions of every operator returned by `_get_libmx_op_names()` whose
lowercased name is not listed in `_op_import_bl`. For each operator the expansion
contains the expression from `_import_expr`, the two method definitions from
`_get_ndarray_function_def`, and the description from `_get_libmx_op_description`
attached as documentation to the second definition.
"""
macro _import_ndarray_functions()
    names = filter(n -> lowercase(n) ∉ _op_import_bl, _get_libmx_op_names())

    func_exprs = map(names) do name
        op_handle = _get_libmx_op_handle(name)

        desc, _ = _get_libmx_op_description(name, op_handle)
        func_def, func_def2 = _get_ndarray_function_def(name)

        func_name = Symbol(name)

        import_expr = _import_expr(func_name)

        # `@doc` documents the expression on the following line, so `$func_def2`
        # must stay directly after it.
        quote
            $import_expr
            $func_def
            @doc $desc
            $func_def2
        end
    end

    esc(quote
        $(func_exprs...)
    end)
end

@_import_ndarray_functions