/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file ordering_op-inl.h
 * \brief Function definitions of ordering operators
 */
#ifndef MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_
#define MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_

#include <mxnet/operator_util.h>
#include <dmlc/optional.h>
#include <mshadow/tensor.h>
#include <algorithm>
#include <vector>
#include <type_traits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./sort_op.h"
#include "./indexing_op.h"

namespace mshadow {
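// Reinterpret the contiguous buffer of `src` as a `dst_dim`-dimensional tensor
// with shape `target_shape`, without copying any data.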
template<typename xpu, int src_dim, typename DType, int dst_dim>
inline Tensor<xpu, dst_dim, DType> inplace_reshape(Tensor<xpu, src_dim, DType> src,
                                                   Shape<dst_dim> target_shape) {
  CHECK_EQ(src.CheckContiguous(), true);
  return Tensor<xpu, dst_dim, DType>(src.dptr_, target_shape, src.stream_);
}
}  // namespace mshadow


namespace mxnet {
namespace op {
// These enums are only visible within this header
namespace topk_enum {
enum TopKReturnType {kReturnValue, kReturnIndices, kReturnMask, kReturnBoth};
}  // namespace topk_enum

struct TopKParam : public dmlc::Parameter<TopKParam> {
  dmlc::optional<int> axis;
  int k;
  int ret_typ;
  bool is_ascend;
  int dtype;
  DMLC_DECLARE_PARAMETER(TopKParam) {
    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
    .describe("Axis along which to choose the top k indices."
              " If not given, the flattened array is used. Default is -1.");
    DMLC_DECLARE_FIELD(k).set_default(1)
    .describe("Number of top elements to select. It should always be smaller than"
              " or equal to the number of elements along the given axis."
              " A global sort is performed if k < 1.");
    DMLC_DECLARE_FIELD(ret_typ).set_default(topk_enum::kReturnIndices)
    .add_enum("value", topk_enum::kReturnValue)
    .add_enum("indices", topk_enum::kReturnIndices)
    .add_enum("mask", topk_enum::kReturnMask)
    .add_enum("both", topk_enum::kReturnBoth)
    .describe("The return type.\n"
              " \"value\" means to return the top k values,"
              " \"indices\" means to return the indices of the top k values,"
              " \"mask\" means to return a mask array containing 0 and 1, where 1 marks"
              " the top k values,"
              " \"both\" means to return a list of both values and indices of the top k elements.");
    DMLC_DECLARE_FIELD(is_ascend).set_default(false)
    .describe("Whether to choose the k largest or k smallest elements."
              " The k largest elements will be chosen if set to false.");
    DMLC_DECLARE_FIELD(dtype)
    // TODO(srivrohi): remove support for real data type in mxnet-2.0
    .add_enum("uint8", mshadow::kUint8)
    .add_enum("int32", mshadow::kInt32)
    .add_enum("int64", mshadow::kInt64)
    .add_enum("float16", mshadow::kFloat16)
    .add_enum("float32", mshadow::kFloat32)
    .add_enum("float64", mshadow::kFloat64)
    .set_default(mshadow::kFloat32)
    .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
              "An error will be raised if the selected data type cannot precisely represent the "
              "indices.");
  }
};
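// In the Python frontend, this parameter struct backs calls such as
// mx.nd.topk(data, axis=-1, k=2, ret_typ="both") (illustrative usage only).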

struct SortParam : public dmlc::Parameter<SortParam> {
  dmlc::optional<int> axis;
  bool is_ascend;
  DMLC_DECLARE_PARAMETER(SortParam) {
    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
    .describe("Axis along which to sort the input tensor."
              " If not given, the flattened array is used. Default is -1.");
    DMLC_DECLARE_FIELD(is_ascend).set_default(true)
    .describe("Whether to sort in ascending or descending order.");
  }
};

struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
  dmlc::optional<int> axis;
  bool is_ascend;
  int dtype;
  DMLC_DECLARE_PARAMETER(ArgSortParam) {
    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
    .describe("Axis along which to sort the input tensor."
              " If not given, the flattened array is used. Default is -1.");
    DMLC_DECLARE_FIELD(is_ascend).set_default(true)
    .describe("Whether to sort in ascending or descending order.");
    DMLC_DECLARE_FIELD(dtype)
    // TODO(srivrohi): remove support for real data type in mxnet-2.0
    .add_enum("uint8", mshadow::kUint8)
    .add_enum("int32", mshadow::kInt32)
    .add_enum("int64", mshadow::kInt64)
    .add_enum("float16", mshadow::kFloat16)
    .add_enum("float32", mshadow::kFloat32)
    .add_enum("float64", mshadow::kFloat64)
    .set_default(mshadow::kFloat32)
    .describe("DType of the output indices. An error will be raised if the selected "
              "data type cannot precisely represent the indices.");
  }
};

inline void ParseTopKParam(const TShape& src_shape,
                           const TopKParam& param,
                           TShape *target_shape,
                           size_t *batch_size,
                           index_t *element_num,
                           int *axis,
                           index_t *k,
                           bool *do_transpose,
                           bool *is_ascend) {
  *do_transpose = false;
  *k = param.k;
  *is_ascend = param.is_ascend;
  // get batch_size, axis and element_num
  if (!static_cast<bool>(param.axis)) {  // No axis given, use the flattened array
    *axis = 0;
    *batch_size = 1;
    *element_num = src_shape.Size();
  } else {
    *axis = param.axis.value();
    if (*axis < 0) {
      *axis += src_shape.ndim();
    }
    CHECK(*axis >= 0 && *axis < static_cast<int>(src_shape.ndim()))
      << "Invalid axis! axis should be between 0 and "
      << src_shape.ndim() << ", found axis=" << *axis;
    if (src_shape[*axis] != 0) {
      *batch_size = src_shape.Size() / src_shape[*axis];
    }
    *element_num = src_shape[*axis];
    if (*axis != src_shape.ndim() - 1) {
      *do_transpose = true;
    }
  }
  // get k; k <= 0 requests a global sort over the whole axis
  if (param.k <= 0) {
    *k = *element_num;
  }
  // get target_shape
  if (!static_cast<bool>(param.axis)) {
    if (param.ret_typ != topk_enum::kReturnMask) {
      *target_shape = mshadow::Shape1(*k);
    } else {
      *target_shape = src_shape;
    }
  } else {
    *target_shape = src_shape;
    if (param.ret_typ != topk_enum::kReturnMask) {
      (*target_shape)[*axis] = *k;
    }
  }
  CHECK(*k >= 0 && *k <= *element_num) << "k must be between 0 and "
                                       << *element_num << ", got k = " << *k;
}
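// Worked example: for src_shape = (2, 3, 4), axis = 1, k = 2 and ret_typ = "both",
// this yields batch_size = 8, element_num = 3, do_transpose = true (axis is not the
// last dimension) and target_shape = (2, 2, 4).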

using namespace mshadow;

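// Scatter kernels used when assembling the outputs: fill_ind_to_one writes a 1 at
// every position listed in `indices`, while fill_ind scatters val[i] to
// out[indices[i]], honoring the write request type `req`.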
struct fill_ind_to_one {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, const index_t* indices, DType* out) {
    out[indices[i]] = static_cast<DType>(1);
  }
};

struct fill_ind {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, const index_t* indices, const DType* val,
                                  int req, DType* out) {
    KERNEL_ASSIGN(out[indices[i]], req, val[i]);
  }
};

template<typename DType>
MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
                                   const Tensor<cpu, 1, index_t>& ind,
                                   const Tensor<cpu, 1, char>& work,
                                   index_t K, index_t N, bool is_ascend,
                                   Stream<cpu> *s) {
  // Use a full sort when K is relatively large; std::partial_sort only pays off
  // for small K.
  const bool full_sort(K*8 > N);
  // Batch size.
  const index_t M(work.size(0)/(sizeof(DType)*N));
  const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
  #pragma omp parallel for num_threads(omp_threads)
  for (index_t i = 0; i < M; ++i) {
    // Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
    DType *vals = reinterpret_cast<DType*>(work.dptr_);
    DType *sorted_vals = dat.dptr_+i*N;
    index_t *indices = ind.dptr_+i*N;
    if (is_ascend) {
      if (full_sort) {
        std::sort(indices, indices+N,
                  [&](const index_t& i1, const index_t& i2){
                    return vals[i1] < vals[i2]; });
      } else {
        std::partial_sort(indices, indices+K, indices+N,
                          [&](const index_t& i1, const index_t& i2){
                            return vals[i1] < vals[i2]; });
      }
    } else {
      if (full_sort) {
        std::sort(indices, indices+N,
                  [&](const index_t& i1, const index_t& i2){
                    return vals[i1] > vals[i2]; });
      } else {
        std::partial_sort(indices, indices+K, indices+N,
                          [&](const index_t& i1, const index_t& i2){
                            return vals[i1] > vals[i2]; });
      }
    }
    for (index_t j = 0; j < K; ++j) {
      sorted_vals[j] = vals[indices[j]];
    }
  }
}
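// A minimal standalone sketch of the per-batch argsort pattern used above
// (illustrative only; `vals` and `idx` are hypothetical arrays of length N,
// and std::iota requires <numeric>):
//   std::iota(idx, idx + N, 0);
//   std::partial_sort(idx, idx + K, idx + N,
//                     [&](index_t a, index_t b) { return vals[a] < vals[b]; });
// Afterwards, idx[0..K) holds the positions of the K smallest values in `vals`.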

#ifdef __CUDACC__

template<typename DType>
MSHADOW_XINLINE bool TopKCompare(DType val1, index_t ind1, DType val2, index_t ind2,
                                 bool is_ascend) {
  // Negative indices denote undefined values, which compare as arbitrarily small
  // (resp. arbitrarily large), so any defined value wins against them.
  return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2)));
}

template<typename DType>
MSHADOW_XINLINE void MergeTopK(index_t K, DType *val1, index_t *ind1, DType *val2, index_t *ind2,
                               bool is_ascend) {
  // In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals
  // [0,..,i1], [0,..,i2] of the two lists that will be part of the merged list.
  index_t i1(K-1), i2(K-1);
  for (index_t i = 0; i < K; ++i) {
    if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
      --i2;
    } else {
      --i1;
    }
  }
  // Now merge the lists from back to front.
  for (index_t i = K; i--;) {
    if (i2 < 0 || (i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], ind1[i1], is_ascend))) {
      val1[i] = val1[i1];
      ind1[i] = ind1[i1];
      --i1;
    } else {
      val1[i] = val2[i2];
      ind1[i] = ind2[i2];
      --i2;
    }
  }
}

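// Worked example (K = 3, descending): merging val1 = {9, 5, 2} with val2 = {8, 7, 1}
// keeps the intervals {9} and {8, 7}, then writes {9, 8, 7} back into val1/ind1.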
template<typename DType>
__global__ void PartialSortSmallK(index_t K, index_t N, DType *val, index_t *ind, bool is_ascend) {
  // Buffer for pairwise reduction.
  extern __shared__ index_t buff[];
  // Start of the buffer sections associated with this thread.
  const index_t offset(threadIdx.x*K);
  index_t *ind_buff = &buff[offset];
  DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
  // Initialize top-K values for this thread.
  for (index_t i = 0; i < K; ++i) {
    ind_buff[i] = -1;
  }
  // Range of values this thread cares about. Each thread block processes
  // a different batch item (i.e. a different set of ind/val where we
  // have to select the top-K elements). All threads within the same
  // block work on the same batch item.
  const index_t first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
  // Select the top-K from this range and store them sorted in the buffer.
  // We assume a small K, so linear insertion is OK.
  for (index_t i = first; i < last; i += blockDim.x) {
    DType cur_val(val[i]);
    index_t cur_ind(ind[i]);
    for (index_t j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j],
                                           ind_buff[j], is_ascend); ) {
      if (j+1 < K) {
        val_buff[j+1] = val_buff[j];
        ind_buff[j+1] = ind_buff[j];
      }
      val_buff[j] = cur_val;
      ind_buff[j] = cur_ind;
    }
  }
  // Recursive merge of the sorted per-thread lists within this thread block. Note that
  // blockDim.x is not necessarily a power of two, hence the additional checks for last_s.
  for (index_t s = (blockDim.x+1)/2, last_s = blockDim.x;
       last_s > 1; last_s = s, s = (s+1)/2) {
    __syncthreads();
    if (threadIdx.x < s && threadIdx.x+s < last_s) {
      MergeTopK(K, val_buff, ind_buff, val_buff+s*K, ind_buff+s*K, is_ascend);
    }
  }
  // Final updates on the master thread.
  if (threadIdx.x == 0) {
    for (index_t i = 0; i < K; ++i) {
      ind[blockIdx.x*N+i] = ind_buff[i];
      val[blockIdx.x*N+i] = val_buff[i];
    }
  }
}
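// The kernel above needs blockDim.x * K * (sizeof(index_t) + sizeof(DType)) bytes of
// dynamic shared memory: one index section followed by one value section per thread.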

template<typename DType>
MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
                                   const Tensor<gpu, 1, index_t>& ind,
                                   const Tensor<gpu, 1, char>& work,
                                   index_t K, index_t N, bool is_ascend,
                                   Stream<gpu> *s) {
  // Use a full sort for all but very small K, for which we
  // can do a partial sort entirely within shared memory.
  const bool full_sort(K > 5);
  // Batch size.
  const index_t M(dat.size(0)/N);
  if (full_sort) {
    // Divide the workspace into two parts. The first one is needed to store batch ids.
    size_t alignment = std::max(sizeof(DType), sizeof(index_t));
    size_t id_size = PadBytes(sizeof(index_t) * ind.size(0), alignment);
    Tensor<gpu, 1, index_t> batch_id(reinterpret_cast<index_t*>(work.dptr_),
                                     Shape1(ind.size(0)), s);
    Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
    mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
    if (M > 1) {
      // Back-to-back sorting. Note that mxnet::op::SortByKey is a stable sort:
      // re-sorting by batch id groups each batch together again while preserving,
      // within every batch, the value order established by the first sort.
      batch_id = ind / N;
      mxnet::op::SortByKey(batch_id, dat, true, &sort_work);
      batch_id = ind / N;
      mxnet::op::SortByKey(batch_id, ind, true, &sort_work);
    }
  } else {
    const int nthreads(mshadow::cuda::kBaseThreadNum);
    // Shared memory must cover the per-thread index and value sections
    // (see PartialSortSmallK above).
    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(index_t)+sizeof(DType)),
                        mshadow::Stream<gpu>::GetStream(s)>>>
                        (K, N, dat.dptr_, ind.dptr_, is_ascend);
  }
}

#endif  // __CUDACC__


/*!
 * \brief Implementation of the TopK operation.
 *
 * \param ctx the running context
 * \param resource temporary resource handler
 * \param req the request types for the output blobs
 * \param src the source blob
 * \param ret the destination blobs
 * \param param the topk parameters
 * \tparam xpu the device type
 * \tparam DType type of the output value/mask
 * \tparam IDType type of the output indices
 */
template<typename xpu, typename DType, typename IDType>
void TopKImpl(const RunContext &ctx,
              const Resource &resource,
              const std::vector<OpReqType>& req,
              const TBlob& src,
              const std::vector<TBlob>& ret,
              const TopKParam& param) {
  using namespace mshadow;
  using namespace mshadow::expr;
  // 0. If the input has zero elements, return immediately
  if (src.Size() == 0) return;
  // 1. Parse and initialize information
  Stream<xpu> *s = ctx.get_stream<xpu>();
  Tensor<xpu, 1, char> workspace;
  Tensor<xpu, 1, char> temp_workspace;
  Tensor<xpu, 1, DType> sorted_dat;
  Tensor<xpu, 1, index_t> indices, sel_indices;
  size_t batch_size = 0;
  index_t element_num = 0;  // number of elements in each batch
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  index_t k = 0;
  size_t alignment = std::max(sizeof(DType), sizeof(index_t));
  mxnet::TShape target_shape;
  ParseTopKParam(src.shape_, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
    << "'index_t' does not have sufficient precision to represent "
    << "the indices of the input array. The total element_num is "
    << element_num << ", but the selected index_t can only represent "
    << mxnet::common::MaxIntegerValue<index_t>() << " elements";
  Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
  // Temp space needed by the full sorts.
  size_t temp_size = std::max(
    mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
    mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));

  temp_size = std::max(temp_size,
    mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
  // Additional temp space for gpu full sorts for batch ids.
  temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
  // Temp space for cpu sorts.
  temp_size = std::max(temp_size, sizeof(DType) * src.Size());

  size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
                                    + PadBytes(sizeof(index_t) * src.Size(), alignment);
  if (param.ret_typ == topk_enum::kReturnMask) {
    workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
  }
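  // The workspace is carved up below in this order:
  // [ sorted_dat | indices | sel_indices (mask output only) | temp space / flattened copy ].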
  workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
  char* workspace_curr_ptr = workspace.dptr_;
  sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                     Shape1(src.Size()), s);  // contains the sorted data
  workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
  indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                    Shape1(src.Size()), s);  // indices in the original matrix
  workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);

  if (param.ret_typ == topk_enum::kReturnMask) {
    sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                          Shape1(batch_size * k), s);
    workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, alignment);
    CHECK_EQ(sel_indices.CheckContiguous(), true);
  }

  if (std::is_same<xpu, cpu>::value) {
    Tensor<xpu, 1, DType> flattened_data;
    if (do_transpose) {
      flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                             Shape1(src.Size()), s);
      workspace_curr_ptr += sizeof(DType) * src.Size();
      flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
      CHECK_EQ(flattened_data.CheckContiguous(), true);
    } else {
      flattened_data = src.FlatTo1D<xpu, DType>(s);
    }
    // `temp_workspace` stores the flattened data
    temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
                                          Shape1(sizeof(DType)*src.Size()), s);
    CHECK_EQ(temp_workspace.CheckContiguous(), true);
  } else {
    if (do_transpose) {
      sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
    } else {
      sorted_dat = reshape(dat, Shape1(src.Size()));
    }
    CHECK_EQ(sorted_dat.CheckContiguous(), true);
    temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
    workspace_curr_ptr += temp_size;
  }

  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, index_t{0}, index_t{1},
                                           kWriteTo, indices.dptr_);
  CHECK_EQ(indices.CheckContiguous(), true);

  // 2. Perform the inplace batch sort.
  // After sorting, each batch in `sorted_dat` will be sorted in the requested order up
  // to the k-th element, and `indices` will contain the corresponding indices in `sorted_dat`.
  // `temp_workspace` stores the flattened source data on the CPU and serves as
  // a temporary buffer on the GPU.
  TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);

  // 3. Assign the results to the ret blob.
  // When returning indices, only the required elements are updated (via modulo) instead of
  // all elements, to avoid redundant computation. (`indices` holds positions into the
  // flattened batches, so taking them modulo element_num recovers per-batch positions.)
  // Casting `ret_indices` from int to real_t can introduce conversion errors when
  // element_num is large enough.
  if (param.ret_typ == topk_enum::kReturnMask) {
    Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
    ret_mask = scalar<DType>(0);
    sel_indices = reshape(slice<1>(inplace_reshape(indices,
                                                   Shape2(batch_size, element_num)),
                                   0, k),
                          Shape1(batch_size * k));
    if (do_transpose) {
      mxnet::TShape src_shape = src.shape_.FlatTo3D(axis);
      CHECK_EQ(sel_indices.CheckContiguous(), true);
      sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                      Shape3(0, 2, 1));
    }
    if (req[0] == kNullOp) {
      return;
    } else if (req[0] == kWriteTo) {
      mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
                                                     sel_indices.dptr_, ret_mask.dptr_);
    } else {
      LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
    }
  } else if (param.ret_typ == topk_enum::kReturnIndices) {
    if (do_transpose) {
      Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
          slice<2>(inplace_reshape(indices,
                                   Shape3(ret_indices.shape_[0],
                                          ret_indices.shape_[2],
                                          element_num)),
                   0, k),
          Shape3(0, 2, 1)), element_num)));
    } else {
      Tensor<xpu, 2, IDType> ret_indices =
        ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
          inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
          element_num)));
    }
  } else {
    if (do_transpose) {
      Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
      Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
      ASSIGN_DISPATCH(ret_value, req[0], transpose(
          slice<2>(inplace_reshape(sorted_dat,
                                   Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
                   0, k), Shape3(0, 2, 1)));
      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
          slice<2>(inplace_reshape(indices,
                                   Shape3(ret_indices.shape_[0],
                                          ret_indices.shape_[2],
                                          element_num)),
                   0, k), Shape3(0, 2, 1)), element_num)));
    } else {
      Tensor<xpu, 2, DType> ret_value =
        ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
      Tensor<xpu, 2, IDType> ret_indices =
        ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
      ASSIGN_DISPATCH(ret_value, req[0],
                      slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
          inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
    }
  }
}

template<typename xpu, typename DType>
size_t TopKWorkspaceSize(const TBlob& src,
                         const TopKParam& param,
                         size_t *temp_size_ptr) {
  using namespace mshadow;
  using namespace mshadow::expr;

  size_t batch_size = 0;
  size_t temp_size;
  index_t element_num = 0;  // number of elements in each batch
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  index_t k = 0;
  size_t alignment = std::max(sizeof(DType), sizeof(index_t));
  mxnet::TShape target_shape;
  ParseTopKParam(src.shape_, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);

  // Temp space needed by the full sorts.
  temp_size = std::max(
    mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
    mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));

  temp_size = std::max(temp_size,
    mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
  // Additional temp space for gpu full sorts for batch ids.
  temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
  // Temp space for cpu sorts.
  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
  *temp_size_ptr = temp_size;

  size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
                                    + PadBytes(sizeof(index_t) * src.Size(), alignment);
  if (param.ret_typ == topk_enum::kReturnMask) {
    workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
  }
  return workspace_size;
}
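// Note: this size computation must stay in sync with the pointer arithmetic in
// TopKImplwithWorkspace below, which carves the same regions out of the buffer.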

template<typename xpu, typename DType, typename IDType>
void TopKImplwithWorkspace(const RunContext &ctx,
                           const std::vector<OpReqType>& req,
                           const TBlob& src,
                           const std::vector<TBlob>& ret,
                           const TopKParam& param,
                           char* workspace_curr_ptr,
                           const size_t &temp_size,
                           Stream<xpu>* s) {
  using namespace mshadow;
  using namespace mshadow::expr;
  // 0. If the input has zero elements, return immediately
  if (src.Size() == 0) return;
  // 1. Parse and initialize information
  Tensor<xpu, 1, char> workspace;
  Tensor<xpu, 1, char> temp_workspace;
  Tensor<xpu, 1, DType> sorted_dat;
  Tensor<xpu, 1, index_t> indices, sel_indices;
  size_t batch_size = 0;
  index_t element_num = 0;  // number of elements in each batch
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  index_t k = 0;
  size_t alignment = std::max(sizeof(DType), sizeof(index_t));
  mxnet::TShape target_shape;
  ParseTopKParam(src.shape_, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
    << "'index_t' does not have sufficient precision to represent "
    << "the indices of the input array. The total element_num is "
    << element_num << ", but the selected index_t can only represent "
    << mxnet::common::MaxIntegerValue<index_t>() << " elements";
  Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
  sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                     Shape1(src.Size()), s);  // contains the sorted data
  workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
  indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                    Shape1(src.Size()), s);  // indices in the original matrix
  workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);

  if (param.ret_typ == topk_enum::kReturnMask) {
    sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                          Shape1(batch_size * k), s);
    workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, alignment);
    CHECK_EQ(sel_indices.CheckContiguous(), true);
  }

  if (std::is_same<xpu, cpu>::value) {
    Tensor<xpu, 1, DType> flattened_data;
    if (do_transpose) {
      flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                             Shape1(src.Size()), s);
      workspace_curr_ptr += sizeof(DType) * src.Size();
      flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
      CHECK_EQ(flattened_data.CheckContiguous(), true);
    } else {
      flattened_data = src.FlatTo1D<xpu, DType>(s);
    }
    // `temp_workspace` stores the flattened data
    temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
                                          Shape1(sizeof(DType)*src.Size()), s);
    CHECK_EQ(temp_workspace.CheckContiguous(), true);
  } else {
    if (do_transpose) {
      sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
    } else {
      sorted_dat = reshape(dat, Shape1(src.Size()));
    }
    CHECK_EQ(sorted_dat.CheckContiguous(), true);
    temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
    workspace_curr_ptr += temp_size;
  }

  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, index_t{0}, index_t{1},
                                           kWriteTo, indices.dptr_);
  CHECK_EQ(indices.CheckContiguous(), true);

  // 2. Perform the inplace batch sort.
  // After sorting, each batch in `sorted_dat` will be sorted in the requested order up
  // to the k-th element, and `indices` will contain the corresponding indices in `sorted_dat`.
  // `temp_workspace` stores the flattened source data on the CPU and serves as
  // a temporary buffer on the GPU.
  TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);

  // 3. Assign the results to the ret blob.
  // When returning indices, only the required elements are updated (via modulo) instead of
  // all elements, to avoid redundant computation. (`indices` holds positions into the
  // flattened batches, so taking them modulo element_num recovers per-batch positions.)
  // Casting `ret_indices` from int to real_t can introduce conversion errors when
  // element_num is large enough.
  if (param.ret_typ == topk_enum::kReturnMask) {
    Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
    ret_mask = scalar<DType>(0);
    sel_indices = reshape(slice<1>(inplace_reshape(indices,
                                                   Shape2(batch_size, element_num)),
                                   0, k),
                          Shape1(batch_size * k));
    if (do_transpose) {
      mxnet::TShape src_shape = src.shape_.FlatTo3D(axis);
      CHECK_EQ(sel_indices.CheckContiguous(), true);
      sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                      Shape3(0, 2, 1));
    }
    if (req[0] == kNullOp) {
      return;
    } else if (req[0] == kWriteTo) {
      mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
                                                     sel_indices.dptr_, ret_mask.dptr_);
    } else {
      LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
    }
  } else if (param.ret_typ == topk_enum::kReturnIndices) {
    if (do_transpose) {
      Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
          slice<2>(inplace_reshape(indices,
                                   Shape3(ret_indices.shape_[0],
                                          ret_indices.shape_[2],
                                          element_num)),
                   0, k),
          Shape3(0, 2, 1)), element_num)));
    } else {
      Tensor<xpu, 2, IDType> ret_indices =
        ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
          inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
          element_num)));
    }
  } else {
    if (do_transpose) {
      Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
      Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
      ASSIGN_DISPATCH(ret_value, req[0], transpose(
          slice<2>(inplace_reshape(sorted_dat,
                                   Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
                   0, k), Shape3(0, 2, 1)));
      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
          slice<2>(inplace_reshape(indices,
                                   Shape3(ret_indices.shape_[0],
                                          ret_indices.shape_[2],
                                          element_num)),
                   0, k), Shape3(0, 2, 1)), element_num)));
    } else {
      Tensor<xpu, 2, DType> ret_value =
        ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
      Tensor<xpu, 2, IDType> ret_indices =
        ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
      ASSIGN_DISPATCH(ret_value, req[0],
                      slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
          inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
    }
  }
}

template<typename xpu>
void TopK(const nnvm::NodeAttrs& attrs,
          const OpContext& ctx,
          const std::vector<TBlob>& inputs,
          const std::vector<OpReqType>& req,
          const std::vector<TBlob>& outputs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
        TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
      })
    });
  } else {
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
    });
  }
}

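// Sort and ArgSort below are thin wrappers over TopKImpl: both request a global sort
// along the axis by setting k = 0; Sort returns the sorted values (kReturnValue),
// while ArgSort returns the sorted indices (kReturnIndices).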
template<typename xpu>
void Sort(const nnvm::NodeAttrs& attrs,
          const OpContext& ctx,
          const std::vector<TBlob>& inputs,
          const std::vector<OpReqType>& req,
          const std::vector<TBlob>& outputs) {
  const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
  TopKParam topk_param;
  topk_param.axis = param.axis;
  topk_param.is_ascend = param.is_ascend;
  topk_param.k = 0;
  topk_param.ret_typ = topk_enum::kReturnValue;
  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0],
                                  outputs, topk_param);
  });
}

template<typename xpu>
void ArgSort(const nnvm::NodeAttrs& attrs,
             const OpContext& ctx,
             const std::vector<TBlob>& inputs,
             const std::vector<OpReqType>& req,
             const std::vector<TBlob>& outputs) {
  const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
  TopKParam topk_param;
  topk_param.axis = param.axis;
  topk_param.is_ascend = param.is_ascend;
  topk_param.k = 0;
  topk_param.dtype = param.dtype;
  topk_param.ret_typ = topk_enum::kReturnIndices;
  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
      TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
                                   ctx.requested[0], req, inputs[0], outputs, topk_param);
    });
  });
}

template<typename xpu, typename DType, typename IDType>
void TopKBackwardImpl(const OpContext &ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs,
                      const TopKParam& param) {
  CHECK_NE(req[0], kWriteInplace);
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
  CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
  size_t batch_size = 0;
  index_t element_num = 0;  // number of elements in each batch
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  index_t k = 0;
  mxnet::TShape target_shape;
  ParseTopKParam(outputs[0].shape_, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
    << "'IDType' does not have sufficient precision to represent "
    << "the indices of the input array. The total element_num is " << element_num
    << ", but the selected IDType can only represent "
    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
  Tensor<xpu, 1, index_t> workspace =
    ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(batch_size * k + batch_size), s);
  Tensor<xpu, 1, index_t> sel_indices =
    Tensor<xpu, 1, index_t>(workspace.dptr_, Shape1(batch_size * k), s);
  Tensor<xpu, 1, index_t> batch_shift =
    Tensor<xpu, 1, index_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);

  Tensor<xpu, 2, DType> out_grad =
    inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
  Tensor<xpu, 2, DType> in_grad =
    outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
  // batch_shift = {0, element_num, 2 * element_num, ...}
  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, index_t{0}, element_num, kWriteTo,
                                           batch_shift.dptr_);
  if (do_transpose) {
    Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
    mxnet::TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
    sel_indices = reshape(transpose(
                            broadcast_to(inplace_reshape(batch_shift,
                                                         Shape3(src_shape[0], src_shape[2], 1)),
                                         mxnet::TShape(Shape3(src_shape[0], src_shape[2], k))),
                            Shape3(0, 2, 1)),
                          Shape1(batch_size * k));
    sel_indices += tcast<index_t>(indices);
    sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                    Shape3(0, 2, 1));
  } else {
    Tensor<xpu, 2, IDType> indices =
      inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
    sel_indices = reshape(tcast<index_t>(indices) +
                          broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
                                       mxnet::TShape(Shape2(batch_size, k))),
                          Shape1(batch_size * k));
  }
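  // Worked example (no transpose): batch_size = 2, element_num = 5, k = 2 gives
  // batch_shift = {0, 5}; with per-batch top-k indices {{3, 0}, {4, 1}}, the flat
  // sel_indices become {3, 0, 9, 6}, i.e. positions into the flattened in_grad.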
  CHECK_EQ(sel_indices.CheckContiguous(), true);
  if (kWriteTo == req[0] || kAddTo == req[0]) {
    if (kWriteTo == req[0]) {
      in_grad = scalar<DType>(0);
    }
    // Scatter the output gradient back to the top-k positions of the input gradient.
    mxnet_op::Kernel<fill_ind, xpu>::Launch(s, batch_size * k,
                                            sel_indices.dptr_,
                                            out_grad.dptr_,
                                            req[0],
                                            in_grad.dptr_);
  } else {
    LOG(FATAL) << "Not implemented!";
  }
}

template<typename xpu>
void TopKBackward_(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  if (param.ret_typ == topk_enum::kReturnBoth) {
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
        TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
      });
    });
  } else if (param.ret_typ == topk_enum::kReturnValue) {
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
    });
  } else {
    LOG(FATAL) << "Not Implemented";
  }
}

inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  if (param.ret_typ == topk_enum::kReturnIndices ||
      param.ret_typ == topk_enum::kReturnMask) {
    return static_cast<uint32_t>(1);
  } else {
    return static_cast<uint32_t>(2);
  }
}

inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  if (param.ret_typ == topk_enum::kReturnBoth) {
    return static_cast<uint32_t>(2);
  } else {
    return static_cast<uint32_t>(1);
  }
}

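// Note: with ret_typ = "value" the node still has two outputs (values and indices),
// but only the values are visible to the user; the hidden indices output feeds the
// backward pass (TopKBackwardImpl reads it as inputs[2]).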
inline bool TopKType(const nnvm::NodeAttrs& attrs,
                     std::vector<int> *in_attrs,
                     std::vector<int> *out_attrs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  size_t in_size = in_attrs->size();
  size_t out_size = out_attrs->size();
  CHECK_EQ(in_size, 1);
  CHECK(out_size == 1 || out_size == 2);
  // out_attrs[0] -> stores the values
  // out_attrs[1] -> stores the indices
  if (out_size > 1) {
    if (param.ret_typ == topk_enum::kReturnValue) {
#if MXNET_USE_INT64_TENSOR_SIZE == 1
      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
#else
      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
#endif
      << "Failed to set the type of ret_indices.";
    } else {
      CHECK(type_assign(&(*out_attrs)[1], param.dtype))
        << "Failed to set the type of ret_indices.";
    }
  }
  if (param.ret_typ == topk_enum::kReturnIndices) {
    CHECK(type_assign(&(*out_attrs)[0], param.dtype))
      << "Failed to set the type of ret_indices.";
  } else {
    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
    return out_attrs->at(0) != -1;
  }
  return true;
}

inline bool TopKShapeImpl(const TopKParam& param,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  if (param.ret_typ == topk_enum::kReturnIndices ||
      param.ret_typ == topk_enum::kReturnMask) {
    CHECK_EQ(out_attrs->size(), 1U);
  } else {
    CHECK_EQ(out_attrs->size(), 2U);
  }
  mxnet::TShape& in_shape = (*in_attrs)[0];
  size_t batch_size = 0;
  index_t element_num = 0;  // number of elements in each batch
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  index_t k = 0;
  mxnet::TShape target_shape;
  ParseTopKParam(in_shape, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
  // Both outputs (when present) share the same target shape.
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
  if (param.ret_typ != topk_enum::kReturnIndices &&
      param.ret_typ != topk_enum::kReturnMask) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 1, target_shape);
  }
  return true;
}

inline bool TopKShape(const nnvm::NodeAttrs& attrs,
                      mxnet::ShapeVector *in_attrs,
                      mxnet::ShapeVector *out_attrs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  return TopKShapeImpl(param, in_attrs, out_attrs);
}

inline bool SortType(const nnvm::NodeAttrs& attrs,
                     std::vector<int> *in_attrs,
                     std::vector<int> *out_attrs) {
  int data_type = -1;
  size_t in_size = in_attrs->size();
  size_t out_size = out_attrs->size();
  CHECK_EQ(in_size, 1);
  CHECK_EQ(out_size, 2);
#if MXNET_USE_INT64_TENSOR_SIZE == 1
  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
#else
  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
#endif
  << "Failed to set the type of ret_indices";
  CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
                                                 << (*in_attrs)[0];
  CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
                                                  << (*out_attrs)[0];
  CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
                                                 << (*in_attrs)[0];
  CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
                                                  << (*out_attrs)[0];
  if (data_type == -1) return false;
  return true;
}

inline bool SortShape(const nnvm::NodeAttrs& attrs,
                      mxnet::ShapeVector *in_attrs,
                      mxnet::ShapeVector *out_attrs) {
  const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
  TopKParam topk_param;
  topk_param.axis = param.axis;
  topk_param.is_ascend = param.is_ascend;
  topk_param.k = 0;
  topk_param.ret_typ = topk_enum::kReturnValue;
  return TopKShapeImpl(topk_param, in_attrs, out_attrs);
}

inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
                        std::vector<int> *in_attrs,
                        std::vector<int> *out_attrs) {
  const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
  CHECK(type_assign(&(*out_attrs)[0], param.dtype))
    << "Failed to set the type of ret_indices.";
  return true;
}

inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
                         mxnet::ShapeVector *in_attrs,
                         mxnet::ShapeVector *out_attrs) {
  const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
  TopKParam topk_param;
  topk_param.axis = param.axis;
  topk_param.is_ascend = param.is_ascend;
  topk_param.k = 0;
  topk_param.ret_typ = topk_enum::kReturnIndices;
  return TopKShapeImpl(topk_param, in_attrs, out_attrs);
}
}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_