1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file ordering_op.cc
22  * \brief CPU Implementation of the ordering operations
23  */
// This file is compiled by gcc to build the CPU version of these operators.
25 #include "./ordering_op-inl.h"
26 #include "./elemwise_unary_op.h"
27 
28 
29 namespace mxnet {
30 namespace op {
// Register the dmlc parameter structs that the operator registrations below
// consume through ParamParser<TopKParam>, ParamParser<SortParam> and
// ParamParser<ArgSortParam>.
DMLC_REGISTER_PARAMETER(TopKParam);
DMLC_REGISTER_PARAMETER(SortParam);
DMLC_REGISTER_PARAMETER(ArgSortParam);
34 
35 NNVM_REGISTER_OP(topk)
36 .add_alias("_npx_topk")
37 .describe(R"code(Returns the indices of the top *k* elements in an input array along the given
38  axis (by default).
39  If ret_type is set to 'value' returns the value of top *k* elements (instead of indices).
40  In case of ret_type = 'both', both value and index would be returned.
41  The returned elements will be sorted.
42 
43 Examples::
44 
45   x = [[ 0.3,  0.2,  0.4],
46        [ 0.1,  0.3,  0.2]]
47 
48   // returns an index of the largest element on last axis
49   topk(x) = [[ 2.],
50              [ 1.]]
51 
52   // returns the value of top-2 largest elements on last axis
53   topk(x, ret_typ='value', k=2) = [[ 0.4,  0.3],
54                                    [ 0.3,  0.2]]
55 
56   // returns the value of top-2 smallest elements on last axis
57   topk(x, ret_typ='value', k=2, is_ascend=1) = [[ 0.2 ,  0.3],
58                                                [ 0.1 ,  0.2]]
59 
60   // returns the value of top-2 largest elements on axis 0
61   topk(x, axis=0, ret_typ='value', k=2) = [[ 0.3,  0.3,  0.4],
62                                            [ 0.1,  0.2,  0.2]]
63 
64   // flattens and then returns list of both values and indices
65   topk(x, ret_typ='both', k=2) = [[[ 0.4,  0.3], [ 0.3,  0.2]] ,  [[ 2.,  0.], [ 1.,  2.]]]
66 
67 )code" ADD_FILELINE)
68 .set_num_inputs(1)
69 .set_num_outputs(TopKNumOutputs)
70 .set_attr_parser(ParamParser<TopKParam>)
71 .set_attr<mxnet::FInferShape>("FInferShape", TopKShape)
72 .set_attr<nnvm::FInferType>("FInferType", TopKType)
73 .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", TopKNumVisibleOutputs)
74 .set_attr<FCompute>("FCompute<cpu>", TopK<cpu>)
75 .set_attr<nnvm::FGradient>("FGradient",
__anonead6481a0102(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) 76   [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
77     const TopKParam& param = nnvm::get<TopKParam>(n->attrs.parsed);
78     if (param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth) {
79       std::vector<nnvm::NodeEntry> inputs;
80       uint32_t n_out = n->num_outputs();
81       for (uint32_t i = 0; i < n_out; ++i) {
82         inputs.emplace_back(n, i, 0);
83       }
84       return MakeNonlossGradNode("_backward_topk", n, {ograds[0]}, inputs, n->attrs.dict);
85     } else {
86       return MakeZeroGradNodes(n, ograds);
87     }
88   })
89 .set_attr<FResourceRequest>("FResourceRequest",
__anonead6481a0202(const NodeAttrs& attrs) 90   [](const NodeAttrs& attrs) {
91     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
92   })
93 .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
94 .add_argument("data", "NDArray-or-Symbol", "The input array")
95 .add_arguments(TopKParam::__FIELDS__());
96 
97 NNVM_REGISTER_OP(_backward_topk)
98 .set_num_inputs(3)
99 .set_num_outputs(1)
100 .set_attr_parser(ParamParser<TopKParam>)
101 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
102 .set_attr<FCompute>("FCompute<cpu>", TopKBackward_<cpu>)
103 .set_attr<FResourceRequest>("FResourceRequest",
__anonead6481a0302(const NodeAttrs& attrs) 104   [](const NodeAttrs& attrs) {
105   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
106 });
107 
108 NNVM_REGISTER_OP(sort)
109 .add_alias("_npi_sort")
110 .describe(R"code(Returns a sorted copy of an input array along the given axis.
111 
112 Examples::
113 
114   x = [[ 1, 4],
115        [ 3, 1]]
116 
117   // sorts along the last axis
118   sort(x) = [[ 1.,  4.],
119              [ 1.,  3.]]
120 
121   // flattens and then sorts
122   sort(x, axis=None) = [ 1.,  1.,  3.,  4.]
123 
124   // sorts along the first axis
125   sort(x, axis=0) = [[ 1.,  1.],
126                      [ 3.,  4.]]
127 
128   // in a descend order
129   sort(x, is_ascend=0) = [[ 4.,  1.],
130                           [ 3.,  1.]]
131 
132 )code" ADD_FILELINE)
133 .set_num_inputs(1)
134 .set_num_outputs(2)
135 .set_attr_parser(ParamParser<SortParam>)
136 .set_attr<mxnet::FInferShape>("FInferShape", SortShape)
137 .set_attr<nnvm::FInferType>("FInferType", SortType)
__anonead6481a0402(const NodeAttrs& attrs) 138 .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; })
139 .set_attr<FCompute>("FCompute<cpu>", Sort<cpu>)
140 .set_attr<nnvm::FGradient>("FGradient",
__anonead6481a0502(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) 141   [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
142     const SortParam& param = nnvm::get<SortParam>(n->attrs.parsed);
143     std::vector<nnvm::NodeEntry> inputs;
144     uint32_t n_out = n->num_outputs();
145     for (uint32_t i = 0; i < n_out; ++i) {
146       inputs.emplace_back(n, i, 0);
147     }
148     return MakeNonlossGradNode("_backward_topk", n, {ograds[0]}, inputs,
149                                {{"axis", n->attrs.dict["axis"]},
150                                 {"k", "0"},
151                                 {"ret_typ", "value"},
152                                 {"is_ascend", std::to_string(param.is_ascend)}});
153   })
154 .set_attr<FResourceRequest>("FResourceRequest",
__anonead6481a0602(const NodeAttrs& attrs) 155   [](const NodeAttrs& attrs) {
156     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
157   })
158 .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
159 .add_argument("data", "NDArray-or-Symbol", "The input array")
160 .add_arguments(SortParam::__FIELDS__());
161 
162 NNVM_REGISTER_OP(argsort)
163 .add_alias("_npi_argsort")
164 .describe(R"code(Returns the indices that would sort an input array along the given axis.
165 
166 This function performs sorting along the given axis and returns an array of indices having same shape
167 as an input array that index data in sorted order.
168 
169 Examples::
170 
171   x = [[ 0.3,  0.2,  0.4],
172        [ 0.1,  0.3,  0.2]]
173 
174   // sort along axis -1
175   argsort(x) = [[ 1.,  0.,  2.],
176                 [ 0.,  2.,  1.]]
177 
178   // sort along axis 0
179   argsort(x, axis=0) = [[ 1.,  0.,  1.]
180                         [ 0.,  1.,  0.]]
181 
182   // flatten and then sort
183   argsort(x, axis=None) = [ 3.,  1.,  5.,  0.,  4.,  2.]
184 )code" ADD_FILELINE)
185 .set_num_inputs(1)
186 .set_num_outputs(1)
187 .set_attr_parser(ParamParser<ArgSortParam>)
188 .set_attr<mxnet::FInferShape>("FInferShape", ArgSortShape)
189 .set_attr<nnvm::FInferType>("FInferType", ArgSortType)
190 .set_attr<FCompute>("FCompute<cpu>", ArgSort<cpu>)
191 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
192 .set_attr<FResourceRequest>("FResourceRequest",
__anonead6481a0702(const NodeAttrs& attrs) 193   [](const NodeAttrs& attrs) {
194     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
195   })
196 .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
197 .add_argument("data", "NDArray-or-Symbol", "The input array")
198 .add_arguments(ArgSortParam::__FIELDS__());
199 }  // namespace op
200 }  // namespace mxnet
201