1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 #ifndef MXNET_OPERATOR_FUSION_FUSED_OP_H_ 21 #define MXNET_OPERATOR_FUSION_FUSED_OP_H_ 22 23 #include <mxnet/operator.h> 24 #include <nnvm/graph.h> 25 #include <vector> 26 #include <string> 27 #include <utility> 28 #include <mutex> 29 #include <tuple> 30 31 #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC 32 33 namespace mxnet { 34 35 namespace fusion { 36 enum KernelVariants {kGeneral, kShapeOptimized, 37 kNumKernelVariants // Not a variant- leave this at the end 38 }; 39 } 40 41 struct FusedOpConfig : public dmlc::Parameter<FusedOpConfig> { 42 int num_inputs; 43 int num_outputs; DMLC_DECLARE_PARAMETERFusedOpConfig44 DMLC_DECLARE_PARAMETER(FusedOpConfig) { 45 DMLC_DECLARE_FIELD(num_inputs) 46 .describe("Number of inputs."); 47 DMLC_DECLARE_FIELD(num_outputs) 48 .describe("Number of outputs."); 49 } 50 }; 51 52 struct FusedOpEntry { FusedOpEntryFusedOpEntry53 FusedOpEntry() : dtype(-1), ndim(-1) {} 54 int dtype; 55 int ndim; 56 }; 57 58 class FusedOp { 59 public: 60 static const int NTHREADS = 512; 61 static const int CACHESIZE_WARN_THRESHOLD = 10000; 62 63 explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config); ~FusedOp()64 ~FusedOp() {} num_inputs()65 uint32_t num_inputs() const { 66 return inputs_.size(); 67 } num_outputs()68 uint32_t num_outputs() const { 69 return outputs_.size(); 70 } 71 72 template <typename xpu> 73 void Forward(const nnvm::NodeAttrs& attrs, 74 const OpContext &ctx, 75 const std::vector<TBlob> &inputs, 76 const std::vector<OpReqType> &req, 77 const std::vector<TBlob> &outputs); 78 79 bool InferShape(const nnvm::NodeAttrs &attrs, 80 std::vector<mxnet::TShape> *in_attrs, 81 std::vector<mxnet::TShape> *out_attrs); 82 83 bool InferType(const nnvm::NodeAttrs &attrs, 84 std::vector<int> *in_attrs, 85 std::vector<int> *out_attrs); 86 87 template <typename Attr> 88 std::tuple<const nnvm::ObjectPtr, 89 std::vector<Attr>, 90 std::vector<Attr>> 91 GetAttrs(const std::string& attr_name, 92 const uint32_t node_id); 93 ProvideShape(const std::vector<nnvm::ObjectPtr> & nodes,const std::vector<std::vector<mxnet::TShape>> & in_attrs,const std::vector<std::vector<mxnet::TShape>> & out_attrs)94 void ProvideShape(const std::vector<nnvm::ObjectPtr>& nodes, 95 const std::vector<std::vector<mxnet::TShape>> &in_attrs, 96 const std::vector<std::vector<mxnet::TShape>> &out_attrs) { 97 aux_nodes_ = nodes; 98 aux_in_shapes_ = in_attrs; 99 aux_out_shapes_ = out_attrs; 100 } 101 ProvideType(const std::vector<nnvm::ObjectPtr> & nodes,const std::vector<std::vector<int>> & in_attrs,const std::vector<std::vector<int>> & out_attrs)102 void ProvideType(const std::vector<nnvm::ObjectPtr>& nodes, 103 const std::vector<std::vector<int>> &in_attrs, 104 const std::vector<std::vector<int>> &out_attrs) { 105 aux_nodes_ = nodes; 106 aux_in_types_ = in_attrs; 107 aux_out_types_ = out_attrs; 108 } 109 110 std::tuple<const nnvm::ObjectPtr, 111 std::vector<mxnet::TShape>, 112 std::vector<mxnet::TShape>> GetAuxShape(const int node_id)113 GetAuxShape(const int node_id) const { 114 return std::make_tuple(aux_nodes_[node_id], 115 aux_in_shapes_[node_id], 116 aux_out_shapes_[node_id]); 117 } 118 119 std::tuple<const nnvm::ObjectPtr, 120 std::vector<int>, 121 std::vector<int>> GetAuxType(const int node_id)122 GetAuxType(const int node_id) const { 123 return std::make_tuple(aux_nodes_[node_id], 124 aux_in_types_[node_id], 125 aux_out_types_[node_id]); 126 } 127 128 private: 129 std::string GenerateCode(const std::vector<OpReqType> &req, 130 const std::vector<int> &in_dtypes, 131 const std::vector<int> &out_dtypes, 132 const std::vector<int> &in_ndims, 133 const std::vector<int> &out_ndims, 134 const mxnet::ShapeVector &node_shapes, 135 const std::vector<int> &node_dtypes, 136 const int nvec, 137 const std::string& kernel_name, 138 std::vector<uint32_t> *check_shapes); 139 140 CUfunction CompileCode(const std::string &code, 141 const std::string &kernel_name, int dev_id); 142 143 void CheckShapesAndTypes(const std::vector<TBlob> &inputs, 144 const std::vector<TBlob> &outputs, 145 std::vector<int> *in_dtypes, 146 std::vector<int> *in_ndims, 147 std::vector<int> *out_dtypes, 148 std::vector<int> *out_ndims, 149 int *nvec); 150 151 std::vector<FusedOpEntry> inputs_; 152 std::vector<FusedOpEntry> outputs_; 153 154 nnvm::Graph subgraph_; 155 156 template <typename T> 157 struct IntermediateAttr { 158 std::vector<T> input_attr; 159 std::vector<T> output_attr; 160 std::vector<T> internal_attr; 161 }; 162 163 // Shapes and types inside the subgraph 164 // copied here, because a subsequent call 165 // to InferShape/InferType can overwrite the 166 // original information stored in subgraph_ 167 // attributes while the previous iterations 168 // still need them. 169 std::vector<IntermediateAttr<mxnet::TShape> > intermediate_shapes_; 170 std::vector<IntermediateAttr<int> > intermediate_dtypes_; 171 172 std::vector<nnvm::ObjectPtr> aux_nodes_; 173 std::vector<std::vector<mxnet::TShape>> aux_in_shapes_; 174 std::vector<std::vector<mxnet::TShape>> aux_out_shapes_; 175 std::vector<std::vector<int>> aux_in_types_; 176 std::vector<std::vector<int>> aux_out_types_; 177 std::vector<OpReqType> saved_reqs_; 178 std::vector<uint32_t> extra_shape_args_; 179 std::vector<uint32_t> check_shape_args_; 180 181 CUfunction kernel_functions_[fusion::kNumKernelVariants]; 182 bool initialized_; 183 int kernel_function_dev_id_; 184 185 static std::mutex mutex_; 186 std::mutex my_mutex_; 187 }; 188 189 using FusedOpPtr = std::shared_ptr<FusedOp>; 190 191 struct FusedOpHelperParam { 192 FusedOpPtr op; 193 uint32_t node_id; 194 FusedOpHelperParamFusedOpHelperParam195 FusedOpHelperParam(FusedOpPtr op, uint32_t node_id) : 196 op(op), 197 node_id(node_id) {} 198 }; 199 200 using FusedOpHelperParamPtr = std::shared_ptr<FusedOpHelperParam>; 201 202 } // namespace mxnet 203 204 #endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC 205 206 #endif // MXNET_OPERATOR_FUSION_FUSED_OP_H_ 207