# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Base.Array related interface

import Base: reshape

"""
    reshape(sym::SymbolicNode, dim; reverse=false, name)
    reshape(sym::SymbolicNode, dim...; reverse=false, name)

Reshape SymbolicNode operator.

The target shape may be given either as a single tuple of integers or as
separate integer arguments. All keyword arguments are forwarded unchanged to
`_reshape`, which accepts `reverse::Bool` and `name::String`.

Some dimensions of the shape can take special values from the set
{0, -1, -2, -3, -4}.
The significance of each is explained below:

- `0` copy this dimension from the input to the output shape.

  Example:

  - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
  - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)

- `-1` infers the dimension of the output shape by using the remainder of the
  input dimensions keeping the size of the new array same as that of the input
  array. At most one dimension of shape can be -1.

  Example:

  - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
  - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
  - input shape = (2,3,4), shape=(-1,), output shape = (24,)

- `-2` copy all/remainder of the input dimensions to the output shape.

  Example:

  - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
  - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
  - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)

- `-3` use the product of two consecutive dimensions of the input shape as the
  output dimension.

  Example:

  - input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
  - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
  - input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
  - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)

- `-4` split one dimension of the input into two dimensions passed subsequent
  to -4 in shape (can contain -1).

  Example:

  - input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4)
  - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)

If the argument `reverse` is set to `true`, then the special values are
inferred from right to left.

  Example:

  - with `reverse=false`, for input shape = (10,5,4), shape = (-1,0),
    output shape would be (40,5)
  - with `reverse=true`, output shape will be (50,4).
"""
function reshape(sym::SymbolicNode, dim::NTuple{N, Integer}; kwargs...) where {N}
    return _reshape(sym, dim; kwargs...)
end

# Vararg form: the integers are collected into a tuple and handled exactly
# like the tuple method above.
function reshape(sym::SymbolicNode, dim::Integer...; kwargs...)
    return _reshape(sym, dim; kwargs...)
end

# Build the libmxnet "reshape" operator node for `sym`.
#
# `dim` is the requested shape; `reverse` and `name` are the keywords exposed
# through the public `reshape` methods. The node name is resolved through
# `DEFAULT_NAME_MANAGER` with "reshape" as the hint before composing.
@inline function _reshape(sym::SymbolicNode, dim::NTuple{N,Integer};
                          reverse::Bool=false, name::String="") where N
    handle = _get_cached_libmx_op_handle("reshape")
    param_keys = ["shape", "reverse"]
    # NOTE(review): the flag is forwarded negated on purpose in the original
    # code; presumably this compensates for how `dump_mx_param` serializes the
    # shape tuple for libmxnet -- confirm before changing.
    param_vals = [dump_mx_param(dim), dump_mx_param(!reverse)]
    node = _create_atomic_symbol(handle.value, param_keys, param_vals)
    name = get!(DEFAULT_NAME_MANAGER, name, "reshape")
    return _compose!(node, name=name, data=sym)
end

################################################################################
# Base.getindex
################################################################################

"""
    getindex(self :: SymbolicNode, idx :: Union{Int, Base.Symbol, AbstractString})

Return a node representing one particular output of `self`.

`idx` is either the name of the output (as a symbol or a string) or its 1-based
integer position in the list returned by [`list_outputs`](@ref).

When `idx` is a name, an `AssertionError` is raised if no output, or more than
one output, carries that name.
"""
function Base.getindex(self :: SymbolicNode, idx :: Union{Base.Symbol, AbstractString})
    idx = Symbol(idx)
    matches = findall(out -> out == idx, list_outputs(self))
    @assert !isempty(matches) "Cannot find output with name '$idx'"
    @assert length(matches) < 2 "Found duplicated output with name '$idx'"
    # Exactly one match remains: delegate to the integer-index method.
    return Base.getindex(self, first(matches))
end

function Base.getindex(self :: SymbolicNode, idx :: Int)
    out_handle = Ref{MX_handle}(0)
    # note Julia is 1-based, while MXNet is 0-based
    @mxcall(:MXSymbolGetOutput, (MX_handle, MX_uint, Ref{MX_handle}),
            self, idx-1, out_handle)
    return SymbolicNode(MX_SymbolHandle(out_handle[]))
end
