1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #ifndef MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
21 #define MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
22
23 #include <vector>
24 #include <utility>
25 #include "../mshadow_op.h"
26 #include "../tensor/init_op.h"
27
28 namespace mxnet {
29 namespace op {
30
namespace index_array_enum {
// Slot of the operator's (only) input.
enum IndexArrayOpInputs { kIn = 0 };
// Slot of the operator's (only) output.
enum IndexArrayOpOutputs { kOut = 0 };
// Slot of the requested temporary-space resource.
enum IndexArrayOpResource { kTempSpace = 0 };
}  // namespace index_array_enum
36
37 template<int req>
38 struct IndexArrayKernel {
MapIndexArrayKernel39 MSHADOW_XINLINE static void Map(int i,
40 int64_t* out_data,
41 const int n,
42 const int64_t* workspace) {
43 for (ptrdiff_t j = 0; j < n; j++) {
44 int64_t upper = workspace[ptrdiff_t(2) * j];
45 int64_t lower = workspace[ptrdiff_t(2) * j + ptrdiff_t(1)];
46 KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(n) + j], req, (i % upper) / lower);
47 }
48 }
49 };
50
51 template<int req>
52 struct IndexArrayDefaultKernel {
MapIndexArrayDefaultKernel53 MSHADOW_XINLINE static void Map(int i,
54 int64_t* out_data,
55 const int ndim,
56 const dim_t* shape) {
57 int64_t index = i;
58 for (ptrdiff_t j = ndim - 1; j >= 0; j--) {
59 KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(ndim) + j], req, index % shape[j]);
60 index /= shape[j];
61 }
62 }
63 };
64
IndexArrayComputeIndexProducts(const TShape & inshape)65 inline std::vector<int64_t> IndexArrayComputeIndexProducts(const TShape &inshape) {
66 const int ndim = inshape.ndim();
67
68 std::vector<int64_t> index_products(static_cast<size_t>(ndim + 1));
69
70 index_products[ndim] = 1;
71
72 for (int i = ndim - 1; i >= 0; i--) {
73 index_products[i] = index_products[i + 1] * inshape[i];
74 }
75
76 return index_products;
77 }
78
IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple<int> & axes,const std::vector<int64_t> & index_products,int64_t * workspace,const int ndim)79 inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple<int> &axes,
80 const std::vector<int64_t> &index_products,
81 int64_t* workspace,
82 const int ndim) {
83 for (int i = 0; i < axes.ndim(); i++) {
84 // Make sure that the axis is between 0 and ndim.
85 const int axis = ((axes[i] % ndim) + ndim) % ndim;
86
87 workspace[ptrdiff_t(2) * ptrdiff_t(i)] = index_products[axis];
88 workspace[ptrdiff_t(2) * ptrdiff_t(i) + ptrdiff_t(1)] = index_products[axis + 1];
89 }
90 }
91
92 struct IndexArrayParam : public dmlc::Parameter<IndexArrayParam> {
93 dmlc::optional<mxnet::Tuple<int>> axes;
DMLC_DECLARE_PARAMETERIndexArrayParam94 DMLC_DECLARE_PARAMETER(IndexArrayParam) {
95 DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional<mxnet::Tuple<int>>())
96 .describe("The axes to include in the index array. Supports negative values.");
97 }
98 }; // struct IndexArrayParam
99
100 } // namespace op
101 } // namespace mxnet
102
103 #endif // MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
104