1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file transsparse_lib.cc
22 * \brief Sample 2D transpose custom operator.
23 */
24
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "mxnet/lib_api.h"
27
28 using namespace mxnet::ext;
29
transpose(MXTensor & src,MXTensor & dst,const OpResource & res)30 void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
31 MXSparse* A = src.data<MXSparse>();
32 MXSparse* B = dst.data<MXSparse>();
33
34 std::vector<int64_t> shape = src.shape;
35 int64_t h = shape[0];
36 int64_t w = shape[1];
37 if(src.stype == kRowSparseStorage) {
38 // Keys of the map is the row index of transposed tensors.
39 // Values of the map is the rows which have non-zero elements.
40 std::map<int, std::vector<float>> mp;
41 float *Aval = (float*) (A->data);
42 for(int i = 0; i < A->data_len; i++) {
43 int row = i / w;
44 int col = i % w;
45 row = A->indices[row];
46 if(Aval[i] != 0) {
47 if(mp.find(col) == mp.end()) {
48 mp[col] = std::vector<float>(h, 0);
49 mp[col][row] = Aval[i];
50 }
51 else {
52 mp[col][row] = Aval[i];
53 }
54 }
55 }
56
57 // Alloc memory for output tensors.
58 res.alloc_sparse(B, 0, mp.size());
59 float *Bval = (float*) (B->data);
60 int didx = 0, iidx = 0;
61 for(auto i : mp) {
62 B->indices[iidx++] = i.first;
63 for(auto j : i.second) {
64 Bval[didx++] = j;
65 }
66 }
67 }
68 }
69
forward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)70 MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
71 std::vector<MXTensor>* inputs,
72 std::vector<MXTensor>* outputs,
73 const OpResource& res) {
74 // The data types and storage types of inputs and outputs should be the same.
75 if(inputs->at(0).dtype != outputs->at(0).dtype ||
76 inputs->at(0).stype != outputs->at(0).stype) {
77 MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
78 << "Found input storage type:" << inputs->at(0).stype
79 << " Found output storage type:" << outputs->at(0).stype
80 << " Found input data type:" << inputs->at(0).dtype
81 << " Found output data type:" << outputs->at(0).dtype;
82 return MX_FAIL;
83 }
84 transpose(inputs->at(0), outputs->at(0), res);
85 return MX_SUCCESS;
86 }
87
backward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)88 MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs,
89 std::vector<MXTensor>* inputs,
90 std::vector<MXTensor>* outputs,
91 const OpResource& res) {
92 return MX_SUCCESS;
93 }
94
parseAttrs(const std::unordered_map<std::string,std::string> & attrs,int * num_in,int * num_out)95 MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
96 int* num_in, int* num_out) {
97 *num_in = 1;
98 *num_out = 1;
99 return MX_SUCCESS;
100 }
101
inferType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * intypes,std::vector<int> * outtypes)102 MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
103 std::vector<int>* intypes,
104 std::vector<int>* outtypes) {
105 // validate inputs
106 if (intypes->size() != 1) {
107 MX_ERROR_MSG << "Expected 1 inputs to inferType";
108 return MX_FAIL;
109 }
110 if (intypes->at(0) != kFloat32) {
111 MX_ERROR_MSG << "Expected input to have float32 type";
112 return MX_FAIL;
113 }
114
115 outtypes->at(0) = intypes->at(0);
116 return MX_SUCCESS;
117 }
118
inferSType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * instypes,std::vector<int> * outstypes)119 MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& attrs,
120 std::vector<int>* instypes,
121 std::vector<int>* outstypes) {
122 if (instypes->at(0) != kRowSparseStorage) {
123 MX_ERROR_MSG << "Expected storage type is kRowSparseStorage";
124 return MX_FAIL;
125 }
126 outstypes->at(0) = instypes->at(0);
127 return MX_SUCCESS;
128 }
129
inferShape(const std::unordered_map<std::string,std::string> & attrs,std::vector<std::vector<unsigned int>> * inshapes,std::vector<std::vector<unsigned int>> * outshapes)130 MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
131 std::vector<std::vector<unsigned int>>* inshapes,
132 std::vector<std::vector<unsigned int>>* outshapes) {
133 // validate inputs
134 if (inshapes->size() != 1) {
135 MX_ERROR_MSG << "Expected 1 inputs to inferShape";
136 return MX_FAIL;
137 }
138
139 outshapes->at(0).push_back(inshapes->at(0)[1]);
140 outshapes->at(0).push_back(inshapes->at(0)[0]);
141 return MX_SUCCESS;
142 }
143
// Register the stateless variant of the operator for CPU. The same shape /
// type / storage-type inference callbacks are shared with the stateful
// registration below.
REGISTER_OP(my_transposerowsp)
.setForward(forward, "cpu")
.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape);
151
152 /* ------------------------------------------------------------------------- */
153
154 class MyStatefulTransposeRowSP : public CustomStatefulOp {
155 public:
MyStatefulTransposeRowSP(int count,const std::unordered_map<std::string,std::string> & attrs)156 explicit MyStatefulTransposeRowSP(int count,
157 const std::unordered_map<std::string, std::string>& attrs)
158 : count(count), attrs_(attrs) {}
159
Forward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)160 MXReturnValue Forward(std::vector<MXTensor>* inputs,
161 std::vector<MXTensor>* outputs,
162 const OpResource& op_res) {
163 std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
164 return forward(attrs_, inputs, outputs, op_res);
165 }
166
Backward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)167 MXReturnValue Backward(std::vector<MXTensor>* inputs,
168 std::vector<MXTensor>* outputs,
169 const OpResource& op_res) {
170 return backward(attrs_, inputs, outputs, op_res);
171 }
172
173 private:
174 int count;
175 const std::unordered_map<std::string, std::string> attrs_;
176 };
177
createOpState(const std::unordered_map<std::string,std::string> & attrs,const MXContext & ctx,const std::vector<std::vector<unsigned int>> & in_shapes,const std::vector<int> in_types,CustomStatefulOp ** op_inst)178 MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
179 const MXContext& ctx,
180 const std::vector<std::vector<unsigned int> >& in_shapes,
181 const std::vector<int> in_types,
182 CustomStatefulOp** op_inst) {
183 // testing passing of keyword arguments
184 int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
185 // creating stateful operator instance
186 *op_inst = new MyStatefulTransposeRowSP(count, attrs);
187 (*op_inst)->ignore_warn = true;
188 std::cout << "Info: stateful operator created" << std::endl;
189 return MX_SUCCESS;
190 }
191
// Register the stateful variant for CPU. Forward/Backward are provided by
// the CustomStatefulOp instance built in createOpState rather than by
// free functions.
REGISTER_OP(my_state_transposerowsp)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape)
.setCreateOpState(createOpState, "cpu");
198
initialize(int version)199 MXReturnValue initialize(int version) {
200 if (version >= 10900) {
201 std::cout << "MXNet version " << version << " supported" << std::endl;
202 return MX_SUCCESS;
203 } else {
204 MX_ERROR_MSG << "MXNet version " << version << " not supported";
205 return MX_FAIL;
206 }
207 }
208