1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file transsparse_lib.cc
 * \brief Sample custom operator that transposes a 2D CSR sparse matrix.
23  */
24 
#include <cstdint>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "mxnet/lib_api.h"
27 
28 using namespace mxnet::ext;
29 
transpose(MXTensor & src,MXTensor & dst,const OpResource & res)30 void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
31   MXSparse* A = src.data<MXSparse>();
32   MXSparse* B = dst.data<MXSparse>();
33   std::vector<int64_t> shape = src.shape;
34   int64_t h = shape[0];
35   int64_t w = shape[1];
36   if(src.stype == kCSRStorage) {
37     float *Aval = (float*) (A->data);
38     // Here we need one more element to help calculate index(line 57).
39     std::vector<int64_t> rowPtr(w + 2, 0);
40     // count column
41     for(int i = 0; i < A->data_len; i++) {
42       rowPtr[A->indices[i] + 2]++;
43     }
44     // Accumulated sum. After this for loop, rowPtr[1:w+2) stores the correct
45     // result of transposed rowPtr.
46     for(int i = 2; i < rowPtr.size(); i++) {
47       rowPtr[i] += rowPtr[i - 1];
48     }
49 
50     // Alloc memory for sparse data, where 0 is the index
51     // of B in output vector.
52     res.alloc_sparse(B, 0, A->data_len, w + 1);
53     float *Bval = (float*) (B->data);
54     for(int i = 0; i < h; i++) {
55       for(int j = A->indptr[i]; j < A->indptr[i + 1]; j++) {
56         // Helps calculate index and after that rowPtr[0:w+1) stores the
57         // correct result of transposed rowPtr.
58         int index = rowPtr[A->indices[j] + 1]++;
59         Bval[index] = Aval[j];
60         B->indices[index] = i;
61       }
62     }
63     memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * (w + 1));
64   }
65 }
66 
forward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)67 MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
68                       std::vector<MXTensor>* inputs,
69                       std::vector<MXTensor>* outputs,
70                       const OpResource& res) {
71   // The data types and storage types of inputs and outputs should be the same.
72   if(inputs->at(0).dtype != outputs->at(0).dtype ||
73      inputs->at(0).stype != outputs->at(0).stype) {
74     MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
75                  << "Found input storage type:" << inputs->at(0).stype
76                  << " Found output storage type:" << outputs->at(0).stype
77                  << " Found input data type:" << inputs->at(0).dtype
78                  << " Found output data type:" << outputs->at(0).dtype;
79     return MX_FAIL;
80   }
81 
82   transpose(inputs->at(0), outputs->at(0), res);
83   return MX_SUCCESS;
84 }
85 
backward(const std::unordered_map<std::string,std::string> & attrs,std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & res)86 MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs,
87                        std::vector<MXTensor>* inputs,
88                        std::vector<MXTensor>* outputs,
89                        const OpResource& res) {
90   return MX_SUCCESS;
91 }
92 
parseAttrs(const std::unordered_map<std::string,std::string> & attrs,int * num_in,int * num_out)93 MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
94                          int* num_in, int* num_out) {
95   *num_in = 1;
96   *num_out = 1;
97   return MX_SUCCESS;
98 }
99 
inferType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * intypes,std::vector<int> * outtypes)100 MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
101                         std::vector<int>* intypes,
102                         std::vector<int>* outtypes) {
103   // validate inputs
104   if (intypes->size() != 1) {
105     MX_ERROR_MSG << "Expected 1 inputs to inferType";
106     return MX_FAIL;
107   }
108   if (intypes->at(0) != kFloat32) {
109     MX_ERROR_MSG << "Expected input to have float32 type";
110     return MX_FAIL;
111   }
112 
113   outtypes->at(0) = intypes->at(0);
114   return MX_SUCCESS;
115 }
116 
inferSType(const std::unordered_map<std::string,std::string> & attrs,std::vector<int> * instypes,std::vector<int> * outstypes)117 MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& attrs,
118                          std::vector<int>* instypes,
119                          std::vector<int>* outstypes) {
120   if (instypes->at(0) != kCSRStorage) {
121     MX_ERROR_MSG << "Expected storage type is kCSRStorage";
122     return MX_FAIL;
123   }
124   outstypes->at(0) = instypes->at(0);
125   return MX_SUCCESS;
126 }
127 
inferShape(const std::unordered_map<std::string,std::string> & attrs,std::vector<std::vector<unsigned int>> * inshapes,std::vector<std::vector<unsigned int>> * outshapes)128 MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
129                          std::vector<std::vector<unsigned int>>* inshapes,
130                          std::vector<std::vector<unsigned int>>* outshapes) {
131   // validate inputs
132   if (inshapes->size() != 1) {
133     MX_ERROR_MSG << "Expected 1 inputs to inferShape";
134     return MX_FAIL;
135   }
136 
137   outshapes->at(0).push_back(inshapes->at(0)[1]);
138   outshapes->at(0).push_back(inshapes->at(0)[0]);
139   return MX_SUCCESS;
140 }
141 
// Register the stateless CSR transpose operator "my_transposecsr".
// Forward and backward are bound to the "cpu" context. The parse/infer
// callbacks fix the arity (1 in / 1 out) and require a float32 CSR input. The
// output shape is the input shape with its two dimensions swapped.
REGISTER_OP(my_transposecsr)
.setForward(forward, "cpu")
.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape);
149 
150 /* ------------------------------------------------------------------------- */
151 
152 class MyStatefulTransposeCSR : public CustomStatefulOp {
153   public:
MyStatefulTransposeCSR(int count,const std::unordered_map<std::string,std::string> & attrs)154     explicit MyStatefulTransposeCSR(int count,
155                                     const std::unordered_map<std::string, std::string>& attrs)
156       : count(count), attrs_(attrs) {}
157 
Forward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)158     MXReturnValue Forward(std::vector<MXTensor>* inputs,
159                           std::vector<MXTensor>* outputs,
160                           const OpResource& op_res) {
161       std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
162       return forward(attrs_, inputs, outputs, op_res);
163     }
164 
Backward(std::vector<MXTensor> * inputs,std::vector<MXTensor> * outputs,const OpResource & op_res)165     MXReturnValue Backward(std::vector<MXTensor>* inputs,
166                            std::vector<MXTensor>* outputs,
167                            const OpResource& op_res) {
168       return backward(attrs_, inputs, outputs, op_res);
169     }
170 
171   private:
172     int count;
173     const std::unordered_map<std::string, std::string> attrs_;
174 };
175 
createOpState(const std::unordered_map<std::string,std::string> & attrs,const MXContext & ctx,const std::vector<std::vector<unsigned int>> & in_shapes,const std::vector<int> in_types,CustomStatefulOp ** op_inst)176 MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
177                             const MXContext& ctx,
178                             const std::vector<std::vector<unsigned int> >& in_shapes,
179                             const std::vector<int> in_types,
180                             CustomStatefulOp** op_inst) {
181   // testing passing of keyword arguments
182   int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
183   // creating stateful operator instance
184   *op_inst = new MyStatefulTransposeCSR(count, attrs);
185   std::cout << "Info: stateful operator created" << std::endl;
186   return MX_SUCCESS;
187 }
188 
// Register the stateful CSR transpose operator "my_state_transposecsr".
// It shares the parse/infer callbacks with "my_transposecsr". On "cpu",
// createOpState builds a MyStatefulTransposeCSR instance that performs the
// forward/backward work.
REGISTER_OP(my_state_transposecsr)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape)
.setCreateOpState(createOpState, "cpu");
195 
initialize(int version)196 MXReturnValue initialize(int version) {
197   if (version >= 10900) {
198     std::cout << "MXNet version " << version << " supported" << std::endl;
199     return MX_SUCCESS;
200   } else {
201     MX_ERROR_MSG << "MXNet version " << version << " not supported";
202     return MX_FAIL;
203   }
204 }
205