1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_OPERATOR_FUSION_FUSED_OP_H_
21 #define MXNET_OPERATOR_FUSION_FUSED_OP_H_
22 
23 #include <mxnet/operator.h>
24 #include <nnvm/graph.h>
25 #include <vector>
26 #include <string>
27 #include <utility>
28 #include <mutex>
29 #include <tuple>
30 
31 #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
32 
33 namespace mxnet {
34 
namespace fusion {

// Identifiers of the kernel variants a fused op keeps a compiled function for.
// The enumerators are implicitly numbered from 0, so the trailing sentinel
// equals the number of real variants and can be used as an array size.
enum KernelVariants {
  kGeneral,
  kShapeOptimized,
  // Not a variant - must stay the last enumerator.
  kNumKernelVariants
};

}  // namespace fusion
40 
41 struct FusedOpConfig : public dmlc::Parameter<FusedOpConfig> {
42   int num_inputs;
43   int num_outputs;
DMLC_DECLARE_PARAMETERFusedOpConfig44   DMLC_DECLARE_PARAMETER(FusedOpConfig) {
45     DMLC_DECLARE_FIELD(num_inputs)
46     .describe("Number of inputs.");
47     DMLC_DECLARE_FIELD(num_outputs)
48     .describe("Number of outputs.");
49   }
50 };
51 
// Per-input / per-output record of a fused op: a type code and a number of
// dimensions. Both fields start out as -1 on default construction.
struct FusedOpEntry {
  int dtype = -1;
  int ndim = -1;

  // User-provided on purpose: keeps the type a non-aggregate, as before.
  FusedOpEntry() {}
};
57 
58 class FusedOp {
59  public:
60   static const int NTHREADS = 512;
61   static const int CACHESIZE_WARN_THRESHOLD = 10000;
62 
63   explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config);
~FusedOp()64   ~FusedOp() {}
num_inputs()65   uint32_t num_inputs() const {
66     return inputs_.size();
67   }
num_outputs()68   uint32_t num_outputs() const {
69     return outputs_.size();
70   }
71 
72   template <typename xpu>
73   void Forward(const nnvm::NodeAttrs& attrs,
74                const OpContext &ctx,
75                const std::vector<TBlob> &inputs,
76                const std::vector<OpReqType> &req,
77                const std::vector<TBlob> &outputs);
78 
79   bool InferShape(const nnvm::NodeAttrs &attrs,
80                   std::vector<mxnet::TShape> *in_attrs,
81                   std::vector<mxnet::TShape> *out_attrs);
82 
83   bool InferType(const nnvm::NodeAttrs &attrs,
84                  std::vector<int> *in_attrs,
85                  std::vector<int> *out_attrs);
86 
87   template <typename Attr>
88   std::tuple<const nnvm::ObjectPtr,
89              std::vector<Attr>,
90              std::vector<Attr>>
91     GetAttrs(const std::string& attr_name,
92              const uint32_t node_id);
93 
ProvideShape(const std::vector<nnvm::ObjectPtr> & nodes,const std::vector<std::vector<mxnet::TShape>> & in_attrs,const std::vector<std::vector<mxnet::TShape>> & out_attrs)94   void ProvideShape(const std::vector<nnvm::ObjectPtr>& nodes,
95                     const std::vector<std::vector<mxnet::TShape>> &in_attrs,
96                     const std::vector<std::vector<mxnet::TShape>> &out_attrs) {
97     aux_nodes_ = nodes;
98     aux_in_shapes_ = in_attrs;
99     aux_out_shapes_ = out_attrs;
100   }
101 
ProvideType(const std::vector<nnvm::ObjectPtr> & nodes,const std::vector<std::vector<int>> & in_attrs,const std::vector<std::vector<int>> & out_attrs)102   void ProvideType(const std::vector<nnvm::ObjectPtr>& nodes,
103                    const std::vector<std::vector<int>> &in_attrs,
104                    const std::vector<std::vector<int>> &out_attrs) {
105     aux_nodes_ = nodes;
106     aux_in_types_ = in_attrs;
107     aux_out_types_ = out_attrs;
108   }
109 
110   std::tuple<const nnvm::ObjectPtr,
111              std::vector<mxnet::TShape>,
112              std::vector<mxnet::TShape>>
GetAuxShape(const int node_id)113     GetAuxShape(const int node_id) const {
114     return std::make_tuple(aux_nodes_[node_id],
115                            aux_in_shapes_[node_id],
116                            aux_out_shapes_[node_id]);
117   }
118 
119   std::tuple<const nnvm::ObjectPtr,
120              std::vector<int>,
121              std::vector<int>>
GetAuxType(const int node_id)122     GetAuxType(const int node_id) const {
123     return std::make_tuple(aux_nodes_[node_id],
124                            aux_in_types_[node_id],
125                            aux_out_types_[node_id]);
126   }
127 
128  private:
129   std::string GenerateCode(const std::vector<OpReqType> &req,
130                            const std::vector<int> &in_dtypes,
131                            const std::vector<int> &out_dtypes,
132                            const std::vector<int> &in_ndims,
133                            const std::vector<int> &out_ndims,
134                            const mxnet::ShapeVector &node_shapes,
135                            const std::vector<int> &node_dtypes,
136                            const int nvec,
137                            const std::string& kernel_name,
138                            std::vector<uint32_t> *check_shapes);
139 
140   CUfunction CompileCode(const std::string &code,
141                          const std::string &kernel_name, int dev_id);
142 
143   void CheckShapesAndTypes(const std::vector<TBlob> &inputs,
144                            const std::vector<TBlob> &outputs,
145                            std::vector<int> *in_dtypes,
146                            std::vector<int> *in_ndims,
147                            std::vector<int> *out_dtypes,
148                            std::vector<int> *out_ndims,
149                            int *nvec);
150 
151   std::vector<FusedOpEntry> inputs_;
152   std::vector<FusedOpEntry> outputs_;
153 
154   nnvm::Graph subgraph_;
155 
156   template <typename T>
157   struct IntermediateAttr {
158     std::vector<T> input_attr;
159     std::vector<T> output_attr;
160     std::vector<T> internal_attr;
161   };
162 
163   // Shapes and types inside the subgraph
164   // copied here, because a subsequent call
165   // to InferShape/InferType can overwrite the
166   // original information stored in subgraph_
167   // attributes while the previous iterations
168   // still need them.
169   std::vector<IntermediateAttr<mxnet::TShape> > intermediate_shapes_;
170   std::vector<IntermediateAttr<int> > intermediate_dtypes_;
171 
172   std::vector<nnvm::ObjectPtr> aux_nodes_;
173   std::vector<std::vector<mxnet::TShape>> aux_in_shapes_;
174   std::vector<std::vector<mxnet::TShape>> aux_out_shapes_;
175   std::vector<std::vector<int>> aux_in_types_;
176   std::vector<std::vector<int>> aux_out_types_;
177   std::vector<OpReqType> saved_reqs_;
178   std::vector<uint32_t> extra_shape_args_;
179   std::vector<uint32_t> check_shape_args_;
180 
181   CUfunction kernel_functions_[fusion::kNumKernelVariants];
182   bool initialized_;
183   int kernel_function_dev_id_;
184 
185   static std::mutex mutex_;
186   std::mutex my_mutex_;
187 };
188 
// Shared-ownership handle to a FusedOp.
using FusedOpPtr = std::shared_ptr<FusedOp>;
190 
191 struct FusedOpHelperParam {
192   FusedOpPtr op;
193   uint32_t node_id;
194 
FusedOpHelperParamFusedOpHelperParam195   FusedOpHelperParam(FusedOpPtr op, uint32_t node_id) :
196     op(op),
197     node_id(node_id) {}
198 };
199 
// Shared-ownership handle to a FusedOpHelperParam.
using FusedOpHelperParamPtr = std::shared_ptr<FusedOpHelperParam>;
201 
202 }  // namespace mxnet
203 
204 #endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
205 
206 #endif  // MXNET_OPERATOR_FUSION_FUSED_OP_H_
207