1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
21 #define MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
22 
23 #include <vector>
24 #include <utility>
25 #include "../mshadow_op.h"
26 #include "../tensor/init_op.h"
27 
28 namespace mxnet {
29 namespace op {
30 
namespace index_array_enum {
// Slot index of the operator's input tensor.
enum IndexArrayOpInputs {
  kIn = 0
};

// Slot index of the operator's output tensor.
enum IndexArrayOpOutputs {
  kOut = 0
};

// Slot index of the requested temporary-space resource.
enum IndexArrayOpResource {
  kTempSpace = 0
};
}  // namespace index_array_enum
36 
37 template<int req>
38 struct IndexArrayKernel {
MapIndexArrayKernel39   MSHADOW_XINLINE static void Map(int i,
40                                   int64_t* out_data,
41                                   const int n,
42                                   const int64_t* workspace) {
43     for (ptrdiff_t j = 0; j < n; j++) {
44       int64_t upper = workspace[ptrdiff_t(2) * j];
45       int64_t lower = workspace[ptrdiff_t(2) * j + ptrdiff_t(1)];
46       KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(n) + j], req, (i % upper) / lower);
47     }
48   }
49 };
50 
51 template<int req>
52 struct IndexArrayDefaultKernel {
MapIndexArrayDefaultKernel53   MSHADOW_XINLINE static void Map(int i,
54                                   int64_t* out_data,
55                                   const int ndim,
56                                   const dim_t* shape) {
57     int64_t index = i;
58     for (ptrdiff_t j = ndim - 1; j >= 0; j--) {
59       KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(ndim) + j], req, index % shape[j]);
60       index /= shape[j];
61     }
62   }
63 };
64 
IndexArrayComputeIndexProducts(const TShape & inshape)65 inline std::vector<int64_t> IndexArrayComputeIndexProducts(const TShape &inshape) {
66   const int ndim = inshape.ndim();
67 
68   std::vector<int64_t> index_products(static_cast<size_t>(ndim + 1));
69 
70   index_products[ndim] = 1;
71 
72   for (int i = ndim - 1; i >= 0; i--) {
73     index_products[i] = index_products[i + 1] * inshape[i];
74   }
75 
76   return index_products;
77 }
78 
IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple<int> & axes,const std::vector<int64_t> & index_products,int64_t * workspace,const int ndim)79 inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple<int> &axes,
80                                                  const std::vector<int64_t> &index_products,
81                                                  int64_t* workspace,
82                                                  const int ndim) {
83   for (int i = 0; i < axes.ndim(); i++) {
84     // Make sure that the axis is between 0 and ndim.
85     const int axis = ((axes[i] % ndim) + ndim) % ndim;
86 
87     workspace[ptrdiff_t(2) * ptrdiff_t(i)] = index_products[axis];
88     workspace[ptrdiff_t(2) * ptrdiff_t(i) + ptrdiff_t(1)] = index_products[axis + 1];
89   }
90 }
91 
92 struct IndexArrayParam : public dmlc::Parameter<IndexArrayParam> {
93   dmlc::optional<mxnet::Tuple<int>> axes;
DMLC_DECLARE_PARAMETERIndexArrayParam94   DMLC_DECLARE_PARAMETER(IndexArrayParam) {
95     DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional<mxnet::Tuple<int>>())
96       .describe("The axes to include in the index array. Supports negative values.");
97   }
98 };  // struct IndexArrayParam
99 
100 }  // namespace op
101 }  // namespace mxnet
102 
103 #endif  // MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
104