1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file ordering_op-inl.h
22  * \brief Function definition of ordering operators
23  */
24 #ifndef MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_
25 #define MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_
26 
27 #include <mxnet/operator_util.h>
28 #include <dmlc/optional.h>
29 #include <mshadow/tensor.h>
30 #include <algorithm>
31 #include <vector>
32 #include <type_traits>
33 #include "../mshadow_op.h"
34 #include "../elemwise_op_common.h"
35 #include "./sort_op.h"
36 #include "./indexing_op.h"
37 
38 namespace mshadow {
39 template<typename xpu, int src_dim, typename DType, int dst_dim>
40 inline Tensor<xpu, dst_dim, DType> inplace_reshape(Tensor<xpu, src_dim, DType> src,
41                                                    Shape<dst_dim> target_shape) {
42   CHECK_EQ(src.CheckContiguous(), true);
43   return Tensor<xpu, dst_dim, DType>(src.dptr_, target_shape, src.stream_);
44 }
45 }  // namespace mshadow
46 
47 
48 namespace mxnet {
49 namespace op {
50 // These enums are only visible within this header
51 namespace topk_enum {
52 enum TopKReturnType {kReturnValue, kReturnIndices, kReturnMask, kReturnBoth};
53 }  // namespace topk_enum
54 
55 struct TopKParam : public dmlc::Parameter<TopKParam> {
56   dmlc::optional<int> axis;
57   int k;
58   int ret_typ;
59   bool is_ascend;
60   int dtype;
61   DMLC_DECLARE_PARAMETER(TopKParam) {
62     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
63     .describe("Axis along which to choose the top k indices."
64               " If not given, the flattened array is used. Default is -1.");
65     DMLC_DECLARE_FIELD(k).set_default(1)
66     .describe("Number of top elements to select,"
67               " should be always smaller than or equal to the element number in the given axis."
68               " A global sort is performed if set k < 1.");
69     DMLC_DECLARE_FIELD(ret_typ).set_default(topk_enum::kReturnIndices)
70     .add_enum("value", topk_enum::kReturnValue)
71     .add_enum("indices", topk_enum::kReturnIndices)
72     .add_enum("mask", topk_enum::kReturnMask)
73     .add_enum("both", topk_enum::kReturnBoth)
74     .describe("The return type.\n"
75         " \"value\" means to return the top k values,"
76         " \"indices\" means to return the indices of the top k values,"
77         " \"mask\" means to return a mask array containing 0 and 1. 1 means the top k values."
78         " \"both\" means to return a list of both values and indices of top k elements.");
79     DMLC_DECLARE_FIELD(is_ascend).set_default(false)
80       .describe("Whether to choose k largest or k smallest elements."
81                 " Top K largest elements will be chosen if set to false.");
82     DMLC_DECLARE_FIELD(dtype)
83     // TODO(srivrohi): remove support for real data type in mxnet-2.0
84     .add_enum("uint8", mshadow::kUint8)
85     .add_enum("int32", mshadow::kInt32)
86     .add_enum("int64", mshadow::kInt64)
87     .add_enum("float16", mshadow::kFloat16)
88     .add_enum("float32", mshadow::kFloat32)
89     .add_enum("float64", mshadow::kFloat64)
90     .set_default(mshadow::kFloat32)
91     .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
92               "An error will be raised if the selected data type cannot precisely represent the "
93               "indices.");
94   }
95 };
96 
97 struct SortParam : public dmlc::Parameter<SortParam> {
98   dmlc::optional<int> axis;
99   bool is_ascend;
100   DMLC_DECLARE_PARAMETER(SortParam) {
101     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
102     .describe("Axis along which to choose sort the input tensor."
103               " If not given, the flattened array is used. Default is -1.");
104     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
105       .describe("Whether to sort in ascending or descending order.");
106   }
107 };
108 
109 struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
110   dmlc::optional<int> axis;
111   bool is_ascend;
112   int dtype;
113   DMLC_DECLARE_PARAMETER(ArgSortParam) {
114     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
115     .describe("Axis along which to sort the input tensor."
116               " If not given, the flattened array is used. Default is -1.");
117     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
118       .describe("Whether to sort in ascending or descending order.");
119     DMLC_DECLARE_FIELD(dtype)
120     // TODO(srivrohi): remove support for real data type in mxnet-2.0
121     .add_enum("uint8", mshadow::kUint8)
122     .add_enum("int32", mshadow::kInt32)
123     .add_enum("int64", mshadow::kInt64)
124     .add_enum("float16", mshadow::kFloat16)
125     .add_enum("float32", mshadow::kFloat32)
126     .add_enum("float64", mshadow::kFloat64)
127     .set_default(mshadow::kFloat32)
128     .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or"
129               " \"both\". An error will be raised if the selected data type cannot precisely "
130               "represent the indices.");
131   }
132 };
133 
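// Normalizes the TopK parameters against the input shape: resolves a possibly negative axis,
// computes the batch size (product of the remaining axes) and the number of elements along the
// sorted axis, decides whether the data must be transposed so the sorted axis becomes the
// innermost one, and derives the output shape. k <= 0 is mapped to a full sort along the axis.
// Illustrative example: src_shape = (2, 3, 4), axis = 1, k = 2, ret_typ = "both" gives
// batch_size = 8, element_num = 3, do_transpose = true and target_shape = (2, 2, 4).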
134 inline void ParseTopKParam(const TShape& src_shape,
135                            const TopKParam& param,
136                            TShape *target_shape,
137                            size_t *batch_size,
138                            index_t *element_num,
139                            int *axis,
140                            index_t *k,
141                            bool *do_transpose,
142                            bool *is_ascend) {
143   *do_transpose = false;
144   *k = param.k;
145   *is_ascend = param.is_ascend;
146   // get batch_size, axis and element_num
147   if (!static_cast<bool>(param.axis)) {  // No axis given
148     *axis = 0;
149     *batch_size = 1;
150     *element_num = src_shape.Size();
151   } else {
152     *axis = param.axis.value();
153     if (*axis < 0) {
154       *axis += src_shape.ndim();
155     }
156     CHECK(*axis >= 0 && *axis < static_cast<int>(src_shape.ndim()))
157                                                   << "Invalid axis! axis should be between 0 and "
158                                                   << src_shape.ndim() - 1 << ", found axis=" << *axis;
159     if (src_shape[*axis] != 0) {
160       *batch_size = src_shape.Size() / src_shape[*axis];
161     }
162     *element_num = src_shape[*axis];
163     if (*axis != src_shape.ndim() - 1) {
164       *do_transpose = true;
165     }
166   }
167   // get k
168   if (param.k <= 0) {
169     *k = *element_num;
170   }
171   // get target_shape
172   if (!static_cast<bool>(param.axis)) {
173     if (param.ret_typ != topk_enum::kReturnMask) {
174       *target_shape = mshadow::Shape1(*k);
175     } else {
176       *target_shape = src_shape;
177     }
178   } else {
179     *target_shape = src_shape;
180     if (param.ret_typ != topk_enum::kReturnMask) {
181       (*target_shape)[*axis] = *k;
182     }
183   }
184   CHECK(*k >= 0 && *k <= *element_num) << "k must be between 0 and "
185                                        << *element_num << ", got k = " << *k;
186 }
187 
188 using namespace mshadow;
189 
190 
191 struct fill_ind_to_one {
192   template<typename DType>
193   MSHADOW_XINLINE static void Map(int i, const index_t* indices, DType* out) {
194     out[indices[i]] = static_cast<DType>(1);
195   }
196 };
197 
198 struct fill_ind {
199   template<typename DType>
200   MSHADOW_XINLINE static void Map(int i, const index_t* indices, const DType* val,
201                                   int req, DType* out) {
202     KERNEL_ASSIGN(out[indices[i]], req, val[i]);
203   }
204 };
205 
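// CPU top-k: an argsort over each batch's index range followed by a gather of the first K
// values. `work` holds the flattened source values, `ind` holds the indices to be permuted,
// and the first K entries of each batch in `dat` receive the selected values. The heuristic
// K*8 > N switches to a full std::sort; otherwise a std::partial_sort of the first K
// positions is used. Batches are processed in parallel with OpenMP.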
206 template<typename DType>
207 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
208                                    const Tensor<cpu, 1, index_t>& ind,
209                                    const Tensor<cpu, 1, char>& work,
210                                    index_t K, index_t N, bool is_ascend,
211                                    Stream<cpu> *s) {
212   // Use full sort when K is relatively large.
213   const bool full_sort(K*8 > N);
214   // Batch size.
215   const index_t M(work.size(0)/(sizeof(DType)*N));
216   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
217   #pragma omp parallel for num_threads(omp_threads)
218   for (index_t i = 0; i < M; ++i) {
219     // Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
220     DType *vals = reinterpret_cast<DType*>(work.dptr_);
221     DType *sorted_vals = dat.dptr_+i*N;
222     index_t *indices = ind.dptr_+i*N;
223     if (is_ascend) {
224       if (full_sort) {
225         std::sort(indices, indices+N,
226                   [&](const index_t& i1, const index_t& i2){
227           return vals[i1] < vals[i2]; });
228       } else {
229         std::partial_sort(indices, indices+K, indices+N,
230                           [&](const index_t& i1, const index_t& i2){
231           return vals[i1] < vals[i2]; });
232       }
233     } else {
234       if (full_sort) {
235         std::sort(indices, indices+N,
236                   [&](const index_t& i1, const index_t& i2){
237           return vals[i1] > vals[i2]; });
238       } else {
239         std::partial_sort(indices, indices+K, indices+N,
240                           [&](const index_t& i1, const index_t& i2){
241           return vals[i1] > vals[i2]; });
242       }
243     }
244     for (index_t j = 0; j < K; ++j) {
245       sorted_vals[j] = vals[indices[j]];
246     }
247   }
248 }
249 
250 #ifdef __CUDACC__
251 
252 template<typename DType>
253 MSHADOW_XINLINE bool TopKCompare(DType val1, index_t ind1, DType val2, index_t ind2,
254                                  bool is_ascend) {
255   // Negative indices denote undefined values, which always lose the comparison (arbitrarily small/large).
256   return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2)));
257 }
258 
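// Merges two sorted top-K lists (val1/ind1 and val2/ind2) in place into val1/ind1. Entries
// with a negative index are treated as empty slots. Illustrative example with K = 2 and
// is_ascend = false: val1 = {9, 7}, val2 = {8, 3} leaves val1 = {9, 8}.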
259 template<typename DType>
260 MSHADOW_XINLINE void MergeTopK(index_t K, DType *val1, index_t *ind1, DType *val2, index_t *ind2,
261                                bool is_ascend) {
262   // In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals
263   // [0,..,i1], [0,..i2] of the two lists that will be part of the merged list.
264   index_t i1(K-1), i2(K-1);
265   for (index_t i = 0; i < K; ++i) {
266     if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
267       --i2;
268     } else {
269       --i1;
270     }
271   }
272   // Now merge the lists from back to front.
273   for (index_t i = K; i--;) {
274     if (i2 < 0 || (i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], ind1[i1], is_ascend))) {
275       val1[i] = val1[i1];
276       ind1[i] = ind1[i1];
277       --i1;
278     } else {
279       val1[i] = val2[i2];
280       ind1[i] = ind2[i2];
281       --i2;
282     }
283   }
284 }
285 
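// One thread block per batch item. Shared memory holds blockDim.x private top-K buffers: the
// first blockDim.x*K entries store indices (index_t), the following blockDim.x*K entries store
// values (DType). Every thread scans a strided slice of its batch, keeps its own sorted top-K
// by linear insertion, and the per-thread buffers are then combined with a tree-style MergeTopK
// reduction. Thread 0 finally writes the block's result to the first K positions of the batch.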
286 template<typename DType>
287 __global__ void PartialSortSmallK(index_t K, index_t N, DType *val, index_t *ind, bool is_ascend) {
288   // Buffer for pairwise reduction.
289   extern __shared__ index_t buff[];
290   // Start of buffer sections associated with this thread.
291   const index_t offset(threadIdx.x*K);
292   index_t *ind_buff = &buff[offset];
293   DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
294   // Initialize top-K values for this thread.
295   for (index_t i = 0; i < K; ++i) {
296     ind_buff[i] = -1;
297   }
298   // Range of values this thread cares about. Each thread block processes
299   // a different batch item (i.e. a different set of ind/val where we
300   // have to select the top-K elements). All threads within the same
301   // block work on the same batch item.
302   const index_t first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
303   // Select top-K from this range and store it sorted in the buffer.
304   // We assume a small K, so linear insertion is o.k.
305   for (index_t i = first; i < last; i += blockDim.x) {
306     DType cur_val(val[i]);
307     index_t cur_ind(ind[i]);
308     for (index_t j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j],
309                                            ind_buff[j], is_ascend); ) {
310       if (j+1 < K) {
311         val_buff[j+1] = val_buff[j];
312         ind_buff[j+1] = ind_buff[j];
313       }
314       val_buff[j] = cur_val;
315       ind_buff[j] = cur_ind;
316     }
317   }
318   // Recursive merge of sorted lists for this thread block. Note that blockDim.x is not
319   // necessarily a power of two, hence the additional checks against last_s.
320   for (index_t s = (blockDim.x+1)/2, last_s = blockDim.x;
321        last_s > 1; last_s = s, s = (s+1)/2) {
322     __syncthreads();
323     if (threadIdx.x < s && threadIdx.x+s < last_s) {
324       MergeTopK(K, val_buff, ind_buff, val_buff+s*K, ind_buff+s*K, is_ascend);
325     }
326   }
327   // Final updates on master thread.
328   if (threadIdx.x == 0) {
329     for (index_t i = 0; i < K; ++i) {
330       ind[blockIdx.x*N+i] = ind_buff[i];
331       val[blockIdx.x*N+i] = val_buff[i];
332     }
333   }
334 }
335 
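// GPU top-k: for K > 5 the whole batched array is sorted with SortByKey. The workspace is split
// into a batch-id buffer and the SortByKey scratch space. A sort by value is followed by stable
// sorts on the batch ids, so that within each batch the elements stay ordered by value; this
// restores the per-batch grouping of both the data and the index arrays. For very small K the
// PartialSortSmallK kernel above is used instead and works entirely in shared memory.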
336 template<typename DType>
337 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
338                                    const Tensor<gpu, 1, index_t>& ind,
339                                    const Tensor<gpu, 1, char>& work,
340                                    index_t K, index_t N, bool is_ascend,
341                                    Stream<gpu> *s) {
342   // Use full sort for all but very small K for which we
343   // can do a partial sort entirely within shared memory.
344   const bool full_sort(K > 5);
345   // Batch size.
346   const index_t M(dat.size(0)/N);
347   if (full_sort) {
348     // Divide workspace into two parts. The first one is needed to store batch ids.
349     size_t alignment = std::max(sizeof(DType), sizeof(index_t));
350     size_t id_size = PadBytes(sizeof(index_t) * ind.size(0), alignment);
351     Tensor<gpu, 1, index_t> batch_id(reinterpret_cast<index_t*>(work.dptr_),
352                                      Shape1(ind.size(0)), s);
353     Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
354     mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
355     if (M > 1) {
356       // Back to back sorting. Note that mxnet::op::SortByKey is a stable sort.
357       batch_id = ind / N;
358       mxnet::op::SortByKey(batch_id, dat, true, &sort_work);
359       batch_id = ind / N;
360       mxnet::op::SortByKey(batch_id, ind, true, &sort_work);
361     }
362   } else {
363     const int nthreads(mshadow::cuda::kBaseThreadNum);
364     PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(index_t)+sizeof(DType)),
365                         mshadow::Stream<gpu>::GetStream(s)>>>
366                         (K, N, dat.dptr_, ind.dptr_, is_ascend);
367   }
368 }
369 
370 #endif
371 
372 
373 /*!
374    * \brief Implementation of the TopK operation
375    *
376    *
377    * \param ctx the running context
378    * \param resource temporary resource handler
379    * \param src the Source blob
380    * \param ret the destination blobs
381    * \param param the topk parameters
382    * \tparam xpu the device type.
383    * \tparam DType type of the output value/mask.
384    * \tparam IDType type of the output indices.
385    */
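// Illustrative semantics, assuming ret_typ == "both", axis = 1, k = 2, is_ascend = false:
// src = [[3, 1, 2], [4, 6, 5]] yields values = [[3, 2], [6, 5]] and indices = [[0, 2], [1, 2]].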
386 template<typename xpu, typename DType, typename IDType>
387 void TopKImpl(const RunContext &ctx,
388               const Resource &resource,
389               const std::vector<OpReqType>& req,
390               const TBlob& src,
391               const std::vector<TBlob>& ret,
392               const TopKParam& param) {
393   using namespace mshadow;
394   using namespace mshadow::expr;
395   // 0. If input shape is 0-shape, directly return
396   if (src.Size() == 0) return;
397   // 1. Parse and initialize information
398   Stream<xpu> *s = ctx.get_stream<xpu>();
399   Tensor<xpu, 1, char> workspace;
400   Tensor<xpu, 1, char> temp_workspace;
401   Tensor<xpu, 1, DType> sorted_dat;
402   Tensor<xpu, 1, index_t> indices, sel_indices;
403   size_t batch_size = 0;
404   index_t element_num = 0;  // number of elements along the sorted axis (the size of each batch)
405   int axis = 0;
406   bool do_transpose = false;
407   bool is_ascend = false;
408   index_t k = 0;
409   size_t alignment = std::max(sizeof(DType), sizeof(index_t));
410   mxnet::TShape target_shape;
411   ParseTopKParam(src.shape_, param,
412                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
413   CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
414     << "'index_t' does not have a sufficient precision to represent "
415     << "the indices of the input array. The total element_num is "
416     << element_num << ", but the selected index_t can only represent "
417     << mxnet::common::MaxIntegerValue<index_t>() << " elements";
418   Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
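  // Workspace layout (one contiguous allocation, carved up below):
  //   [ sorted_dat : src.Size() DType   ]  sorted values
  //   [ indices    : src.Size() index_t ]  positions in the flattened input
  //   [ sel_indices: batch_size*k index_t, only for ret_typ == "mask" ]
  //   [ temp       : temp_size bytes    ]  sort scratch (GPU) / transposed copy of src (CPU)
  // Segments are padded to the common alignment of DType and index_t.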
419   // Temp space needed by the full sorts.
420   size_t temp_size = std::max(
421       mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
422       mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));
423 
424   temp_size = std::max(temp_size,
425       mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
426   // Additional temp space for gpu full sorts for batch ids.
427   temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
428   // Temp space for cpu sorts.
429   temp_size = std::max(temp_size, sizeof(DType) * src.Size());
430 
431   size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
432                                     + PadBytes(sizeof(index_t) * src.Size(), alignment);
433   if (param.ret_typ == topk_enum::kReturnMask) {
434     workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
435   }
436   workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
437   char* workspace_curr_ptr = workspace.dptr_;
438   sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
439       Shape1(src.Size()), s);  // contain sorted dat
440   workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
441   indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
442       Shape1(src.Size()), s);  // indices in the original matrix
443   workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);
444 
445   if (param.ret_typ == topk_enum::kReturnMask) {
446     sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
447                                       Shape1(batch_size * k), s);
448     workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, alignment);
449     CHECK_EQ(sel_indices.CheckContiguous(), true);
450   }
451 
452   if (std::is_same<xpu, cpu>::value) {
453     Tensor<xpu, 1, DType> flattened_data;
454     if (do_transpose) {
455       flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
456                                               Shape1(src.Size()), s);
457       workspace_curr_ptr += sizeof(DType) * src.Size();
458       flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
459       CHECK_EQ(flattened_data.CheckContiguous(), true);
460     } else {
461       flattened_data = src.FlatTo1D<xpu, DType>(s);
462     }
463     // `temp_workspace` stores the flattened data
464     temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
465                                           Shape1(sizeof(DType)*src.Size()), s);
466     CHECK_EQ(temp_workspace.CheckContiguous(), true);
467   } else {
468     if (do_transpose) {
469       sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
470     } else {
471       sorted_dat = reshape(dat, Shape1(src.Size()));
472     }
473     CHECK_EQ(sorted_dat.CheckContiguous(), true);
474     temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
475     workspace_curr_ptr += temp_size;
476   }
477 
478   mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, index_t{0}, index_t{1},
479     kWriteTo, indices.dptr_);
480   CHECK_EQ(indices.CheckContiguous(), true);
481 
482   // 2. Perform inplace batch sort.
483   // After sorting, each batch in `sorted_dat` is sorted in the requested order up to the
484   // k-th element, and `indices` holds the corresponding positions in the flattened input.
485   // `temp_workspace` stores the flattened source data on the CPU and serves as a temporary
486   // buffer on the GPU.
487   TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);
488 
489   // 3. Assign results to the ret blob
490   // When returning indices, only the required elements are updated (via modulo) instead of
491   // all elements, to avoid redundant computation.
492   // Casting `ret_indices` from int to real_t can introduce conversion errors when element_num
493   // is large enough.
494   if (param.ret_typ == topk_enum::kReturnMask) {
495     Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
496     ret_mask = scalar<DType>(0);
497     sel_indices = reshape(slice<1>(
498                               inplace_reshape(indices,
499                                               Shape2(batch_size,
500                                                      element_num)), 0, k),
501                               Shape1(batch_size * k));
502     if (do_transpose) {
503       mxnet::TShape src_shape = src.shape_.FlatTo3D(axis);
504       CHECK_EQ(sel_indices.CheckContiguous(), true);
505       sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
506                                       Shape3(0, 2, 1));
507     }
508     if (req[0] == kNullOp) {
509       return;
510     } else if (req[0] == kWriteTo) {
511       mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
512                                                      sel_indices.dptr_, ret_mask.dptr_);
513     } else {
514       LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
515     }
516   } else if (param.ret_typ == topk_enum::kReturnIndices) {
517     if (do_transpose) {
518       Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
519       ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
520                       slice<2>(inplace_reshape(indices,
521                                                Shape3(ret_indices.shape_[0],
522                                                       ret_indices.shape_[2],
523                                                       element_num)),
524                                0, k),
525                       Shape3(0, 2, 1)), element_num)));
526     } else {
527       Tensor<xpu, 2, IDType> ret_indices =
528         ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
529       ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
530                       inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
531                       element_num)));
532     }
533   } else {
534     if (do_transpose) {
535       Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
536       Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
537       ASSIGN_DISPATCH(ret_value, req[0], transpose(
538                    slice<2>(inplace_reshape(sorted_dat,
539                                     Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
540                             0, k), Shape3(0, 2, 1)));
541       ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
542                       slice<2>(inplace_reshape(indices,
543                                                Shape3(ret_indices.shape_[0],
544                                                       ret_indices.shape_[2],
545                                                       element_num)),
546                                0, k), Shape3(0, 2, 1)), element_num)));
547     } else {
548       Tensor<xpu, 2, DType> ret_value =
549         ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
550       Tensor<xpu, 2, IDType> ret_indices =
551         ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
552       ASSIGN_DISPATCH(ret_value, req[0],
553              slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
554       ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
555                  inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
556     }
557   }
558 }
559 
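// Computes the workspace size that TopKImplwithWorkspace expects for a given input and parameter
// set, mirroring the allocation logic at the top of TopKImpl. The size of the internal sort
// scratch area is also returned through temp_size_ptr so the caller can pass it back in.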
560 template<typename xpu, typename DType>
561 size_t TopKWorkspaceSize(const TBlob& src,
562                          const TopKParam& param,
563                          size_t *temp_size_ptr) {
564   using namespace mshadow;
565   using namespace mshadow::expr;
566 
567   size_t batch_size = 0;
568   size_t temp_size;
569   index_t element_num = 0;  // number of elements along the sorted axis (the size of each batch)
570   int axis = 0;
571   bool do_transpose = false;
572   bool is_ascend = false;
573   index_t k = 0;
574   size_t alignment = std::max(sizeof(DType), sizeof(index_t));
575   mxnet::TShape target_shape;
576   ParseTopKParam(src.shape_, param,
577                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
578 
579   // Temp space needed by the full sorts.
580   temp_size = std::max(
581       mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
582       mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));
583 
584   temp_size = std::max(temp_size,
585       mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
586   // Additional temp space for gpu full sorts for batch ids.
587   temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
588   // Temp space for cpu sorts.
589   temp_size = std::max(temp_size, sizeof(DType) * src.Size());
590   *temp_size_ptr = temp_size;
591 
592   size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
593                                     + PadBytes(sizeof(index_t) * src.Size(), alignment);
594   if (param.ret_typ == topk_enum::kReturnMask) {
595     workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
596   }
597   return workspace_size;
598 }
599 
600 template<typename xpu, typename DType, typename IDType>
601 void TopKImplwithWorkspace(const RunContext &ctx,
602                            const std::vector<OpReqType>& req,
603                            const TBlob& src,
604                            const std::vector<TBlob>& ret,
605                            const TopKParam& param,
606                            char* workspace_curr_ptr,
607                            const size_t &temp_size,
608                            Stream<xpu>* s) {
609   using namespace mshadow;
610   using namespace mshadow::expr;
611   // 0. If input shape is 0-shape, directly return
612   if (src.Size() == 0) return;
613   // 1. Parse and initialize information
614   Tensor<xpu, 1, char> workspace;
615   Tensor<xpu, 1, char> temp_workspace;
616   Tensor<xpu, 1, DType> sorted_dat;
617   Tensor<xpu, 1, index_t> indices, sel_indices;
618   size_t batch_size = 0;
619   index_t element_num = 0;  // number of elements along the sorted axis (the size of each batch)
620   int axis = 0;
621   bool do_transpose = false;
622   bool is_ascend = false;
623   index_t k = 0;
624   size_t alignment = std::max(sizeof(DType), sizeof(index_t));
625   mxnet::TShape target_shape;
626   ParseTopKParam(src.shape_, param,
627                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
628   CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
629     << "'index_t' does not have a sufficient precision to represent "
630     << "the indices of the input array. The total element_num is "
631     << element_num << ", but the selected index_t can only represent "
632     << mxnet::common::MaxIntegerValue<index_t>() << " elements";
633   Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
634   sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
635       Shape1(src.Size()), s);  // contain sorted dat
636   workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
637   indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
638       Shape1(src.Size()), s);  // indices in the original matrix
639   workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);
640 
641   if (param.ret_typ == topk_enum::kReturnMask) {
642     sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
643                                       Shape1(batch_size * k), s);
644     workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, alignment);
645     CHECK_EQ(sel_indices.CheckContiguous(), true);
646   }
647 
648   if (std::is_same<xpu, cpu>::value) {
649     Tensor<xpu, 1, DType> flattened_data;
650     if (do_transpose) {
651       flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
652                                               Shape1(src.Size()), s);
653       workspace_curr_ptr += sizeof(DType) * src.Size();
654       flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
655       CHECK_EQ(flattened_data.CheckContiguous(), true);
656     } else {
657       flattened_data = src.FlatTo1D<xpu, DType>(s);
658     }
659     // `temp_workspace` stores the flattened data
660     temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
661                                           Shape1(sizeof(DType)*src.Size()), s);
662     CHECK_EQ(temp_workspace.CheckContiguous(), true);
663   } else {
664     if (do_transpose) {
665       sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
666     } else {
667       sorted_dat = reshape(dat, Shape1(src.Size()));
668     }
669     CHECK_EQ(sorted_dat.CheckContiguous(), true);
670     temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
671     workspace_curr_ptr += temp_size;
672   }
673 
674   mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, index_t{0}, index_t{1},
675     kWriteTo, indices.dptr_);
676   CHECK_EQ(indices.CheckContiguous(), true);
677 
678   // 2. Perform inplace batch sort.
679   // After sorting, each batch in `sorted_dat` is sorted in the requested order up to the
680   // k-th element, and `indices` holds the corresponding positions in the flattened input.
681   // `temp_workspace` stores the flattened source data on the CPU and serves as a temporary
682   // buffer on the GPU.
683   TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);
684 
685   // 3. Assign results to the ret blob
686   // When returning indices, only the required elements are updated (via modulo) instead of
687   // all elements, to avoid redundant computation.
688   // Casting `ret_indices` from int to real_t can introduce conversion errors when element_num
689   // is large enough.
690   if (param.ret_typ == topk_enum::kReturnMask) {
691     Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
692     ret_mask = scalar<DType>(0);
693     sel_indices = reshape(slice<1>(
694                               inplace_reshape(indices,
695                                               Shape2(batch_size,
696                                                      element_num)), 0, k),
697                               Shape1(batch_size * k));
698     if (do_transpose) {
699       mxnet::TShape src_shape = src.shape_.FlatTo3D(axis);
700       CHECK_EQ(sel_indices.CheckContiguous(), true);
701       sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
702                                       Shape3(0, 2, 1));
703     }
704     if (req[0] == kNullOp) {
705       return;
706     } else if (req[0] == kWriteTo) {
707       mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
708                                                      sel_indices.dptr_, ret_mask.dptr_);
709     } else {
710       LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
711     }
712   } else if (param.ret_typ == topk_enum::kReturnIndices) {
713     if (do_transpose) {
714       Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
715       ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
716                       slice<2>(inplace_reshape(indices,
717                                                Shape3(ret_indices.shape_[0],
718                                                       ret_indices.shape_[2],
719                                                       element_num)),
720                                0, k),
721                       Shape3(0, 2, 1)), element_num)));
722     } else {
723       Tensor<xpu, 2, IDType> ret_indices =
724         ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
725       ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
726                       inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
727                       element_num)));
728     }
729   } else {
730     if (do_transpose) {
731       Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
732       Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
733       ASSIGN_DISPATCH(ret_value, req[0], transpose(
734                    slice<2>(inplace_reshape(sorted_dat,
735                                     Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
736                             0, k), Shape3(0, 2, 1)));
737       ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
738                       slice<2>(inplace_reshape(indices,
739                                                Shape3(ret_indices.shape_[0],
740                                                       ret_indices.shape_[2],
741                                                       element_num)),
742                                0, k), Shape3(0, 2, 1)), element_num)));
743     } else {
744       Tensor<xpu, 2, DType> ret_value =
745         ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
746       Tensor<xpu, 2, IDType> ret_indices =
747         ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
748       ASSIGN_DISPATCH(ret_value, req[0],
749              slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
750       ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
751                  inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
752     }
753   }
754 }
755 
756 template<typename xpu>
757 void TopK(const nnvm::NodeAttrs& attrs,
758           const OpContext& ctx,
759           const std::vector<TBlob>& inputs,
760           const std::vector<OpReqType>& req,
761           const std::vector<TBlob>& outputs) {
762   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
763   if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
764     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
765       MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
766         TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
767       })
768     });
769   } else {
770     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
771       TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
772     });
773   }
774 }
775 
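// Sort and ArgSort (below) are thin wrappers over TopKImpl: k = 0 requests a full sort along
// the axis, and ret_typ selects whether the sorted values or the sorted indices are returned.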
776 template<typename xpu>
777 void Sort(const nnvm::NodeAttrs& attrs,
778           const OpContext& ctx,
779           const std::vector<TBlob>& inputs,
780           const std::vector<OpReqType>& req,
781           const std::vector<TBlob>& outputs) {
782   const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
783   TopKParam topk_param;
784   topk_param.axis = param.axis;
785   topk_param.is_ascend = param.is_ascend;
786   topk_param.k = 0;
787   topk_param.ret_typ = topk_enum::kReturnValue;
788   MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
789     TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0],
790                                   outputs, topk_param);
791   });
792 }
793 
794 template<typename xpu>
795 void ArgSort(const nnvm::NodeAttrs& attrs,
796              const OpContext& ctx,
797              const std::vector<TBlob>& inputs,
798              const std::vector<OpReqType>& req,
799              const std::vector<TBlob>& outputs) {
800   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
801   TopKParam topk_param;
802   topk_param.axis = param.axis;
803   topk_param.is_ascend = param.is_ascend;
804   topk_param.k = 0;
805   topk_param.dtype = param.dtype;
806   topk_param.ret_typ = topk_enum::kReturnIndices;
807   MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
808     MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
809       TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
810                                    ctx.requested[0], req, inputs[0], outputs, topk_param);
811     });
812   });
813 }
814 
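// Backward pass for the "value" and "both" return types: the incoming gradient of the selected
// elements is scattered back to their original positions. batch_shift holds the flat offset of
// each batch (i * element_num); adding the indices stored by the forward pass yields flat
// positions into the input gradient, which the fill_ind kernel then fills (the rest is zeroed
// for kWriteTo).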
815 template<typename xpu, typename DType, typename IDType>
816 void TopKBackwardImpl(const OpContext &ctx,
817                       const std::vector<TBlob>& inputs,
818                       const std::vector<OpReqType>& req,
819                       const std::vector<TBlob>& outputs,
820                       const TopKParam& param) {
821   CHECK_NE(req[0], kWriteInplace);
822   using namespace mshadow;
823   using namespace mshadow::expr;
824   Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
825   CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
826   size_t batch_size = 0;
827   index_t element_num = 0;  // number of elements along the sorted axis (the size of each batch)
828   int axis = 0;
829   bool do_transpose = false;
830   bool is_ascend = false;
831   index_t k = 0;
832   mxnet::TShape target_shape;
833   ParseTopKParam(outputs[0].shape_, param,
834                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
835   CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
836     << "'IDType' does not have a sufficient precision to represent "
837     << "the indices of the input array. The total element_num is " << element_num
838     << ", but the selected index_t can only represent "
839     << mxnet::common::MaxIntegerValue<IDType>() << " elements";
840   Tensor<xpu, 1, index_t> workspace =
841     ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(batch_size * k + batch_size), s);
842   Tensor<xpu, 1, index_t> sel_indices =
843     Tensor<xpu, 1, index_t>(workspace.dptr_, Shape1(batch_size * k), s);
844   Tensor<xpu, 1, index_t> batch_shift =
845     Tensor<xpu, 1, index_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
846 
847   Tensor<xpu, 2, DType> out_grad =
848     inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
849   Tensor<xpu, 2, DType> in_grad =
850     outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
851   mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, index_t{0}, element_num, kWriteTo,
852                                            batch_shift.dptr_);
853   if (do_transpose) {
854     Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
855     mxnet::TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
856     sel_indices = reshape(transpose(
857                             broadcast_to(inplace_reshape(batch_shift,
858                                                          Shape3(src_shape[0], src_shape[2], 1)),
859                                          mxnet::TShape(Shape3(src_shape[0], src_shape[2], k))),
860                             Shape3(0, 2, 1)),
861                           Shape1(batch_size * k));
862     sel_indices += tcast<index_t>(indices);
863     sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
864                                     Shape3(0, 2, 1));
865   } else {
866     Tensor<xpu, 2, IDType> indices =
867       inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
868     sel_indices = reshape(tcast<index_t>(indices) +
869                           broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
870                                        mxnet::TShape(Shape2(batch_size, k))),
871                           Shape1(batch_size * k));
872   }
873   CHECK_EQ(sel_indices.CheckContiguous(), true);
874   if (kWriteTo == req[0] || kAddTo == req[0]) {
875     if (kWriteTo == req[0]) {
876       in_grad = scalar<DType>(0);
877     }
878     mxnet_op::Kernel<fill_ind, xpu>::Launch(s, batch_size * k,
879                                             sel_indices.dptr_,
880                                             out_grad.dptr_,
881                                             req[0],
882                                             in_grad.dptr_);
883   } else {
884     LOG(FATAL) << "Not Implemented!";
885   }
886 }
887 
888 template<typename xpu>
889 void TopKBackward_(const nnvm::NodeAttrs& attrs,
890                    const OpContext& ctx,
891                    const std::vector<TBlob>& inputs,
892                    const std::vector<OpReqType>& req,
893                    const std::vector<TBlob>& outputs) {
894   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
895   if (param.ret_typ == topk_enum::kReturnBoth) {
896     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
897       MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
898         TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
899       });
900     });
901   } else if (param.ret_typ == topk_enum::kReturnValue) {
902     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
903       TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
904     });
905   } else {
906     LOG(FATAL) << "Not Implemented";
907   }
908 }
909 
910 inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) {
911   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
912   if (param.ret_typ == topk_enum::kReturnIndices ||
913     param.ret_typ == topk_enum::kReturnMask) {
914     return static_cast<uint32_t>(1);
915   } else {
916     return static_cast<uint32_t>(2);
917   }
918 }
919 
920 inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) {
921   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
922   if (param.ret_typ == topk_enum::kReturnBoth) {
923     return static_cast<uint32_t>(2);
924   } else {
925     return static_cast<uint32_t>(1);
926   }
927 }
928 
929 inline bool TopKType(const nnvm::NodeAttrs& attrs,
930                      std::vector<int> *in_attrs,
931                      std::vector<int> *out_attrs) {
932   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
933   size_t in_size = in_attrs->size();
934   size_t out_size = out_attrs->size();
935   CHECK_EQ(in_size, 1);
936   CHECK(out_size == 1 || out_size == 2);
937   //  out_attrs[0] -> stores values
938   //  out_attrs[1] -> stores indices
939   if (out_size > 1) {
940     if (param.ret_typ == topk_enum::kReturnValue) {
941 #if MXNET_USE_INT64_TENSOR_SIZE == 1
942       CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
943 #else
944       CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
945 #endif
946           << "Failed to set the type of ret_indices.";
947     } else {
948       CHECK(type_assign(&(*out_attrs)[1], param.dtype))
949           << "Failed to set the type of ret_indices.";
950     }
951   }
952   if (param.ret_typ == topk_enum::kReturnIndices) {
953     CHECK(type_assign(&(*out_attrs)[0], param.dtype))
954             << "Failed to set the type of ret_indices.";
955   } else {
956     TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
957     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
958     return out_attrs->at(0) != -1;
959   }
960   return true;
961 }
962 
963 inline bool TopKShapeImpl(const TopKParam& param,
964                           mxnet::ShapeVector *in_attrs,
965                           mxnet::ShapeVector *out_attrs) {
966   CHECK_EQ(in_attrs->size(), 1U);
967   if (param.ret_typ == topk_enum::kReturnIndices ||
968     param.ret_typ == topk_enum::kReturnMask) {
969     CHECK_EQ(out_attrs->size(), 1U);
970   } else {
971     CHECK_EQ(out_attrs->size(), 2U);
972   }
973   mxnet::TShape& in_shape = (*in_attrs)[0];
974   size_t batch_size = 0;
975   index_t element_num = 0;  // number of elements along the sorted axis (the size of each batch)
976   int axis = 0;
977   bool do_transpose = false;
978   bool is_ascend = false;
979   index_t k = 0;
980   mxnet::TShape target_shape;
981   ParseTopKParam(in_shape, param,
982     &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
983   if (param.ret_typ == topk_enum::kReturnIndices ||
984     param.ret_typ == topk_enum::kReturnMask) {
985     SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
986   } else {
987     SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
988     SHAPE_ASSIGN_CHECK(*out_attrs, 1, target_shape);
989   }
990   return true;
991 }
992 
993 inline bool TopKShape(const nnvm::NodeAttrs& attrs,
994                       mxnet::ShapeVector *in_attrs,
995                       mxnet::ShapeVector *out_attrs) {
996   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
997   return TopKShapeImpl(param, in_attrs, out_attrs);
998 }
999 
1000 inline bool SortType(const nnvm::NodeAttrs& attrs,
1001                      std::vector<int> *in_attrs,
1002                      std::vector<int> *out_attrs) {
1003   int data_type = -1;
1004   size_t in_size = in_attrs->size();
1005   size_t out_size = out_attrs->size();
1006   CHECK_EQ(in_size, 1);
1007   CHECK_EQ(out_size, 2);
1008 #if MXNET_USE_INT64_TENSOR_SIZE == 1
1009   CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
1010 #else
1011   CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
1012 #endif
1013       << "Failed to set the type of ret_indices";
1014   CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
1015                                                  << (*in_attrs)[0];
1016   CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
1017                                                   << (*out_attrs)[0];
1018   CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
1019                                                  << (*in_attrs)[0];
1020   CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
1021                                                   << (*out_attrs)[0];
1022   if (data_type == -1) return false;
1023   return true;
1024 }
1025 
1026 inline bool SortShape(const nnvm::NodeAttrs& attrs,
1027                       mxnet::ShapeVector *in_attrs,
1028                       mxnet::ShapeVector *out_attrs) {
1029   const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
1030   TopKParam topk_param;
1031   topk_param.axis = param.axis;
1032   topk_param.is_ascend = param.is_ascend;
1033   topk_param.k = 0;
1034   topk_param.ret_typ = topk_enum::kReturnValue;
1035   return TopKShapeImpl(topk_param, in_attrs, out_attrs);
1036 }
1037 
1038 inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
1039                         std::vector<int> *in_attrs,
1040                         std::vector<int> *out_attrs) {
1041   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
1042   CHECK(type_assign(&(*out_attrs)[0], param.dtype))
1043       << "Failed to set the type of ret_indices.";
1044   return true;
1045 }
1046 
1047 inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
1048                          mxnet::ShapeVector *in_attrs,
1049                          mxnet::ShapeVector *out_attrs) {
1050   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
1051   TopKParam topk_param;
1052   topk_param.axis = param.axis;
1053   topk_param.is_ascend = param.is_ascend;
1054   topk_param.k = 0;
1055   topk_param.ret_typ = topk_enum::kReturnIndices;
1056   return TopKShapeImpl(topk_param, in_attrs, out_attrs);
1057 }
1058 }  // namespace op
1059 }  // namespace mxnet
1060 #endif  // MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_
1061