1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# Base.Array related interface
19
20import Base: reshape
21
"""
    reshape(sym::SymbolicNode, dim; reverse=false, name)
    reshape(sym::SymbolicNode, dim...; reverse=false, name)

Reshape SymbolicNode operator.

The target shape `dim` is given either as a tuple of integers or as separate
integer arguments.

Some dimensions of the shape can take special values from the set
{0, -1, -2, -3, -4}.
The significance of each is explained below:

- `0`  copy this dimension from the input to the output shape.

  Example:

  - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
  - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)

- `-1` infers the dimension of the output shape by using the remainder of the
  input dimensions keeping the size of the new array same as that of the input
  array. At most one dimension of shape can be -1.

  Example:

  - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
  - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
  - input shape = (2,3,4), shape = (-1,), output shape = (24,)

- `-2` copy all/remainder of the input dimensions to the output shape.

  Example:

  - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
  - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
  - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)

- `-3` use the product of two consecutive dimensions of the input shape as the
  output dimension.

  Example:

  - input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
  - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
  - input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
  - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)

- `-4` split one dimension of the input into two dimensions passed subsequent
  to -4 in shape (can contain -1).

  Example:

  - input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4)
  - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)

If the keyword argument `reverse` is set to `true`, then the special values are
inferred from right to left.

Example:

- with `reverse=false`, for input shape = (10,5,4), shape = (-1,0),
  output shape would be (40,5)
- with `reverse=true`, output shape will be (50,4).

# Keywords

- `reverse::Bool=false`: infer the special values from right to left instead of
  from left to right. It must be a `Bool`; an integer such as `1` is not
  accepted.
- `name::String=""`: name used for the resulting node. It is resolved through
  `DEFAULT_NAME_MANAGER` with the hint "reshape" before the node is composed.
"""
reshape(sym::SymbolicNode, dim::NTuple{N, Integer}; kwargs...) where {N} =
  _reshape(sym, dim; kwargs...)
# Varargs form: the separate integers arrive as a tuple and take the same path.
reshape(sym::SymbolicNode, dim::Integer...; kwargs...) =
  _reshape(sym, dim; kwargs...)
88
# Internal worker behind the public `reshape` methods for `SymbolicNode`.
# Creates an atomic symbol for the cached "reshape" operator, configured with the
# "shape" and "reverse" parameters, and composes it with `sym` as its `data` input.
@inline function _reshape(sym::SymbolicNode, dim::NTuple{N,Integer};
                          reverse::Bool=false, name::String="") where N
  op_value = _get_cached_libmx_op_handle("reshape").value
  param_keys = ["shape", "reverse"]
  # NOTE(review): the flag is sent negated. Presumably this compensates for the
  # dimension order differing between Julia and the backend -- confirm.
  param_vals = [dump_mx_param(dim), dump_mx_param(!reverse)]
  atomic = _create_atomic_symbol(op_value, param_keys, param_vals)
  # Resolve the node name through the default name manager, hinting "reshape".
  node_name = get!(DEFAULT_NAME_MANAGER, name, "reshape")
  return _compose!(atomic, name=node_name, data=sym)
end
97
98################################################################################
99# Base.getindex
100################################################################################
101
"""
    getindex(self :: SymbolicNode, idx :: Union{Int, Base.Symbol, AbstractString})

Get a node representing the specified output of this node. The index could be
a symbol or string indicating the name of the output, or a 1-based integer
indicating the index, as in the list of [`list_outputs`](@ref).

When `idx` is a name, it is converted to a `Symbol` and looked up in
`list_outputs(self)`. An `AssertionError` is thrown if no output carries that
name, or if more than one output does.
"""
function Base.getindex(self :: SymbolicNode, idx :: Union{Base.Symbol, AbstractString})
  key     = Symbol(idx)
  matches = findall(key .== list_outputs(self))
  # Explicit throws instead of `@assert`: these checks validate caller input and
  # must keep running even if assertions are ever compiled out. The exception
  # type and messages are exactly what `@assert` produced, so callers see no
  # difference.
  isempty(matches) &&
    throw(AssertionError("Cannot find output with name '$key'"))
  length(matches) > 1 &&
    throw(AssertionError("Found duplicated output with name '$key'"))
  # Exactly one match: delegate to integer indexing with its 1-based position.
  return Base.getindex(self, matches[1])
end
# Integer form: fetch the output at 1-based position `idx` via
# `MXSymbolGetOutput` and wrap the returned handle in a new `SymbolicNode`.
function Base.getindex(self :: SymbolicNode, idx :: Int)
  # Julia counts outputs from 1 while MXNet counts from 0, so shift by one.
  zero_based = idx - 1
  out_handle = Ref{MX_handle}(0)
  @mxcall(:MXSymbolGetOutput, (MX_handle, MX_uint, Ref{MX_handle}),
          self, zero_based, out_handle)
  return SymbolicNode(MX_SymbolHandle(out_handle[]))
end
122
123