1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file transsparse_lib.cc
 * \brief Sample custom operator that transposes a 2D row-sparse tensor.
23  */
24 
#include <algorithm>
#include <exception>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "mxnet/lib_api.h"
27 
28 using namespace mxnet::ext;
29 
transpose(MXTensor & src,MXTensor & dst,const OpResource & res)30 void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
31   MXSparse* A = src.data<MXSparse>();
32   MXSparse* B = dst.data<MXSparse>();
33 
34   std::vector<int64_t> shape = src.shape;
35   int64_t h = shape[0];
36   int64_t w = shape[1];
37   if(src.stype == kRowSparseStorage) {
38     // Keys of the map is the row index of transposed tensors.
39     // Values of the map is the rows which have non-zero elements.
40     std::map<int, std::vector<float>> mp;
41     float *Aval = (float*) (A->data);
42     for(int i = 0; i < A->data_len; i++) {
43       int row = i / w;
44       int col = i % w;
45       row = A->indices[row];
46       if(Aval[i] != 0) {
47         if(mp.find(col) == mp.end()) {
48           mp[col] = std::vector<float>(h, 0);
49           mp[col][row] = Aval[i];
50         }
51         else {
52           mp[col][row] = Aval[i];
53         }
54       }
55     }
56 
57     // Alloc memory for output tensors.
58     res.alloc_sparse(B, 0, mp.size());
59     float *Bval = (float*) (B->data);
60     int didx = 0, iidx = 0;
61     for(auto i : mp) {
62       B->indices[iidx++] = i.first;
63       for(auto j : i.second) {
64         Bval[didx++] = j;
65       }
66     }
67   }
68 }
69 
forward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)70 MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
71                       std::vector<MXTensor>* inputs,
72                       std::vector<MXTensor>* outputs,
73                       const OpResource& res) {
74   // The data types and storage types of inputs and outputs should be the same.
75   if(inputs->at(0).dtype != outputs->at(0).dtype ||
76      inputs->at(0).stype != outputs->at(0).stype) {
77     MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
78                  << "Found input storage type:" << inputs->at(0).stype
79                  << " Found output storage type:" << outputs->at(0).stype
80                  << " Found input data type:" << inputs->at(0).dtype
81                  << " Found output data type:" << outputs->at(0).dtype;
82     return MX_FAIL;
83   }
84   transpose(inputs->at(0), outputs->at(0), res);
85   return MX_SUCCESS;
86 }
87 
backward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)88 MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs,
89                        std::vector<MXTensor>* inputs,
90                        std::vector<MXTensor>* outputs,
91                        const OpResource& res) {
92   return MX_SUCCESS;
93 }
94 
parseAttrs(const std::unordered_map<std::string,std::string> & attrs,int * num_in,int * num_out)95 MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
96                          int* num_in, int* num_out) {
97   *num_in = 1;
98   *num_out = 1;
99   return MX_SUCCESS;
100 }
101 
inferType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * intypes,std::vector<int> * outtypes)102 MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
103                         std::vector<int>* intypes,
104                         std::vector<int>* outtypes) {
105   // validate inputs
106   if (intypes->size() != 1) {
107     MX_ERROR_MSG << "Expected 1 inputs to inferType";
108     return MX_FAIL;
109   }
110   if (intypes->at(0) != kFloat32) {
111     MX_ERROR_MSG << "Expected input to have float32 type";
112     return MX_FAIL;
113   }
114 
115   outtypes->at(0) = intypes->at(0);
116   return MX_SUCCESS;
117 }
118 
inferSType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * instypes,std::vector<int> * outstypes)119 MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& attrs,
120                          std::vector<int>* instypes,
121                          std::vector<int>* outstypes) {
122   if (instypes->at(0) != kRowSparseStorage) {
123     MX_ERROR_MSG << "Expected storage type is kRowSparseStorage";
124     return MX_FAIL;
125   }
126   outstypes->at(0) = instypes->at(0);
127   return MX_SUCCESS;
128 }
129 
inferShape(const std::unordered_map<std::string,std::string> & attrs,std::vector<std::vector<unsigned int>> * inshapes,std::vector<std::vector<unsigned int>> * outshapes)130 MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
131                          std::vector<std::vector<unsigned int>>* inshapes,
132                          std::vector<std::vector<unsigned int>>* outshapes) {
133   // validate inputs
134   if (inshapes->size() != 1) {
135     MX_ERROR_MSG << "Expected 1 inputs to inferShape";
136     return MX_FAIL;
137   }
138 
139   outshapes->at(0).push_back(inshapes->at(0)[1]);
140   outshapes->at(0).push_back(inshapes->at(0)[0]);
141   return MX_SUCCESS;
142 }
143 
// Register the stateless operator "my_transposerowsp": CPU forward/backward
// functions plus the attribute, dtype, storage-type and shape inference
// callbacks defined above.
REGISTER_OP(my_transposerowsp)
.setForward(forward, "cpu")
.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape);
151 
152 /* ------------------------------------------------------------------------- */
153 
154 class MyStatefulTransposeRowSP : public CustomStatefulOp {
155   public:
MyStatefulTransposeRowSP(int count,const std::unordered_map<std::string,std::string> & attrs)156     explicit MyStatefulTransposeRowSP(int count,
157                                       const std::unordered_map<std::string, std::string>& attrs)
158       : count(count), attrs_(attrs) {}
159 
Forward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)160     MXReturnValue Forward(std::vector<MXTensor>* inputs,
161                           std::vector<MXTensor>* outputs,
162                           const OpResource& op_res) {
163       std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
164       return forward(attrs_, inputs, outputs, op_res);
165     }
166 
Backward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)167     MXReturnValue Backward(std::vector<MXTensor>* inputs,
168                            std::vector<MXTensor>* outputs,
169                            const OpResource& op_res) {
170       return backward(attrs_, inputs, outputs, op_res);
171     }
172 
173   private:
174     int count;
175     const std::unordered_map<std::string, std::string> attrs_;
176 };
177 
createOpState(const std::unordered_map<std::string,std::string> & attrs,const MXContext & ctx,const std::vector<std::vector<unsigned int>> & in_shapes,const std::vector<int> in_types,CustomStatefulOp ** op_inst)178 MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
179                             const MXContext& ctx,
180                             const std::vector<std::vector<unsigned int> >& in_shapes,
181                             const std::vector<int> in_types,
182                             CustomStatefulOp** op_inst) {
183   // testing passing of keyword arguments
184   int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
185   // creating stateful operator instance
186   *op_inst = new MyStatefulTransposeRowSP(count, attrs);
187   (*op_inst)->ignore_warn = true;
188   std::cout << "Info: stateful operator created" << std::endl;
189   return MX_SUCCESS;
190 }
191 
// Register the stateful variant "my_state_transposerowsp". It reuses the same
// attribute, dtype, storage-type and shape inference callbacks; on CPU the
// operator instance is created by createOpState.
REGISTER_OP(my_state_transposerowsp)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape)
.setCreateOpState(createOpState, "cpu");
198 
initialize(int version)199 MXReturnValue initialize(int version) {
200   if (version >= 10900) {
201     std::cout << "MXNet version " << version << " supported" << std::endl;
202     return MX_SUCCESS;
203   } else {
204     MX_ERROR_MSG << "MXNet version " << version << " not supported";
205     return MX_FAIL;
206   }
207 }
208