1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file transsparse_lib.cc
22 * \brief Sample 2D transpose custom operator.
23 */
24
#include <cstring>   // std::memcpy (used in transpose)
#include <iostream>

#include "mxnet/lib_api.h"
27
28 using namespace mxnet::ext;
29
transpose(MXTensor & src,MXTensor & dst,const OpResource & res)30 void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
31 MXSparse* A = src.data<MXSparse>();
32 MXSparse* B = dst.data<MXSparse>();
33 std::vector<int64_t> shape = src.shape;
34 int64_t h = shape[0];
35 int64_t w = shape[1];
36 if(src.stype == kCSRStorage) {
37 float *Aval = (float*) (A->data);
38 // Here we need one more element to help calculate index(line 57).
39 std::vector<int64_t> rowPtr(w + 2, 0);
40 // count column
41 for(int i = 0; i < A->data_len; i++) {
42 rowPtr[A->indices[i] + 2]++;
43 }
44 // Accumulated sum. After this for loop, rowPtr[1:w+2) stores the correct
45 // result of transposed rowPtr.
46 for(int i = 2; i < rowPtr.size(); i++) {
47 rowPtr[i] += rowPtr[i - 1];
48 }
49
50 // Alloc memory for sparse data, where 0 is the index
51 // of B in output vector.
52 res.alloc_sparse(B, 0, A->data_len, w + 1);
53 float *Bval = (float*) (B->data);
54 for(int i = 0; i < h; i++) {
55 for(int j = A->indptr[i]; j < A->indptr[i + 1]; j++) {
56 // Helps calculate index and after that rowPtr[0:w+1) stores the
57 // correct result of transposed rowPtr.
58 int index = rowPtr[A->indices[j] + 1]++;
59 Bval[index] = Aval[j];
60 B->indices[index] = i;
61 }
62 }
63 memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * (w + 1));
64 }
65 }
66
forward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)67 MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
68 std::vector<MXTensor>* inputs,
69 std::vector<MXTensor>* outputs,
70 const OpResource& res) {
71 // The data types and storage types of inputs and outputs should be the same.
72 if(inputs->at(0).dtype != outputs->at(0).dtype ||
73 inputs->at(0).stype != outputs->at(0).stype) {
74 MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
75 << "Found input storage type:" << inputs->at(0).stype
76 << " Found output storage type:" << outputs->at(0).stype
77 << " Found input data type:" << inputs->at(0).dtype
78 << " Found output data type:" << outputs->at(0).dtype;
79 return MX_FAIL;
80 }
81
82 transpose(inputs->at(0), outputs->at(0), res);
83 return MX_SUCCESS;
84 }
85
backward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)86 MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs,
87 std::vector<MXTensor>* inputs,
88 std::vector<MXTensor>* outputs,
89 const OpResource& res) {
90 return MX_SUCCESS;
91 }
92
parseAttrs(const std::unordered_map<std::string,std::string> & attrs,int * num_in,int * num_out)93 MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
94 int* num_in, int* num_out) {
95 *num_in = 1;
96 *num_out = 1;
97 return MX_SUCCESS;
98 }
99
inferType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * intypes,std::vector<int> * outtypes)100 MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
101 std::vector<int>* intypes,
102 std::vector<int>* outtypes) {
103 // validate inputs
104 if (intypes->size() != 1) {
105 MX_ERROR_MSG << "Expected 1 inputs to inferType";
106 return MX_FAIL;
107 }
108 if (intypes->at(0) != kFloat32) {
109 MX_ERROR_MSG << "Expected input to have float32 type";
110 return MX_FAIL;
111 }
112
113 outtypes->at(0) = intypes->at(0);
114 return MX_SUCCESS;
115 }
116
inferSType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * instypes,std::vector<int> * outstypes)117 MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& attrs,
118 std::vector<int>* instypes,
119 std::vector<int>* outstypes) {
120 if (instypes->at(0) != kCSRStorage) {
121 MX_ERROR_MSG << "Expected storage type is kCSRStorage";
122 return MX_FAIL;
123 }
124 outstypes->at(0) = instypes->at(0);
125 return MX_SUCCESS;
126 }
127
inferShape(const std::unordered_map<std::string,std::string> & attrs,std::vector<std::vector<unsigned int>> * inshapes,std::vector<std::vector<unsigned int>> * outshapes)128 MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
129 std::vector<std::vector<unsigned int>>* inshapes,
130 std::vector<std::vector<unsigned int>>* outshapes) {
131 // validate inputs
132 if (inshapes->size() != 1) {
133 MX_ERROR_MSG << "Expected 1 inputs to inferShape";
134 return MX_FAIL;
135 }
136
137 outshapes->at(0).push_back(inshapes->at(0)[1]);
138 outshapes->at(0).push_back(inshapes->at(0)[0]);
139 return MX_SUCCESS;
140 }
141
// Register the stateless CSR-transpose operator for CPU.
REGISTER_OP(my_transposecsr)
.setForward(forward, "cpu")
.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape);
149
150 /* ------------------------------------------------------------------------- */
151
152 class MyStatefulTransposeCSR : public CustomStatefulOp {
153 public:
MyStatefulTransposeCSR(int count,const std::unordered_map<std::string,std::string> & attrs)154 explicit MyStatefulTransposeCSR(int count,
155 const std::unordered_map<std::string, std::string>& attrs)
156 : count(count), attrs_(attrs) {}
157
Forward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)158 MXReturnValue Forward(std::vector<MXTensor>* inputs,
159 std::vector<MXTensor>* outputs,
160 const OpResource& op_res) {
161 std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
162 return forward(attrs_, inputs, outputs, op_res);
163 }
164
Backward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)165 MXReturnValue Backward(std::vector<MXTensor>* inputs,
166 std::vector<MXTensor>* outputs,
167 const OpResource& op_res) {
168 return backward(attrs_, inputs, outputs, op_res);
169 }
170
171 private:
172 int count;
173 const std::unordered_map<std::string, std::string> attrs_;
174 };
175
createOpState(const std::unordered_map<std::string,std::string> & attrs,const MXContext & ctx,const std::vector<std::vector<unsigned int>> & in_shapes,const std::vector<int> in_types,CustomStatefulOp ** op_inst)176 MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
177 const MXContext& ctx,
178 const std::vector<std::vector<unsigned int> >& in_shapes,
179 const std::vector<int> in_types,
180 CustomStatefulOp** op_inst) {
181 // testing passing of keyword arguments
182 int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
183 // creating stateful operator instance
184 *op_inst = new MyStatefulTransposeCSR(count, attrs);
185 std::cout << "Info: stateful operator created" << std::endl;
186 return MX_SUCCESS;
187 }
188
// Register the stateful CSR-transpose operator; Forward/Backward come from
// the CustomStatefulOp instance built by createOpState.
REGISTER_OP(my_state_transposecsr)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape)
.setCreateOpState(createOpState, "cpu");
195
initialize(int version)196 MXReturnValue initialize(int version) {
197 if (version >= 10900) {
198 std::cout << "MXNet version " << version << " supported" << std::endl;
199 return MX_SUCCESS;
200 } else {
201 MX_ERROR_MSG << "MXNet version " << version << " not supported";
202 return MX_FAIL;
203 }
204 }
205