/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file warpctc-inl.h
 * \brief warpctc operator
 * \author Liang Xiang
 */
#ifndef PLUGIN_WARPCTC_WARPCTC_INL_H_
#define PLUGIN_WARPCTC_WARPCTC_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <stdio.h>
#include <ctc.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include "../../src/operator/operator_common.h"

namespace mxnet {
namespace op {

namespace warpctc_enum {
enum CTCOpInputs {kData, kLabel};
enum CTCOpOutputs {kOut};
enum CTCTemp {kTmp};
}  // namespace warpctc_enum

struct WarpCTCParam : public dmlc::Parameter<WarpCTCParam> {
  int label_length;
  int input_length;
  DMLC_DECLARE_PARAMETER(WarpCTCParam) {
    DMLC_DECLARE_FIELD(label_length)
      .set_default(0)
      .describe("Real label length");
    DMLC_DECLARE_FIELD(input_length)
      .set_default(0)
      .describe("Input length");
  }
};

template<typename xpu>
class WarpCTCOp : public Operator {
 private:
  WarpCTCParam param_;

 public:
  explicit WarpCTCOp(WarpCTCParam p) {
    this->param_ = p;
  }

  ~WarpCTCOp() {
  }

  // Throws a std::runtime_error carrying the warp-ctc status string whenever
  // a warp-ctc call does not return CTC_STATUS_SUCCESS.
  inline void throw_on_error(ctcStatus_t status, const char* message) {
    if (status != CTC_STATUS_SUCCESS) {
      throw std::runtime_error(message
                               + (", stat = "
                                  + std::string(ctcGetStatusString(status))));
    }
  }

  // The forward pass only applies a softmax over the input activations;
  // the CTC loss itself is computed in Backward.
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 2) << "CTCOutput Input: [data, label]";
    CHECK_EQ(out_data.size(), 1) << "CTCOutput Output: [output]";

    Stream<xpu> *s = ctx.get_stream<xpu>();
    TBlob data = in_data[warpctc_enum::kData];
    TBlob out = out_data[warpctc_enum::kOut];
    Tensor<xpu, 2, float> data_tensor = data.FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, float> out_tensor = out.FlatTo2D<xpu, float>(s);
    Softmax(out_tensor, data_tensor);
  }

  // Counts the non-blank labels of every sequence in the minibatch and
  // accumulates the overall number of non-blank labels in *total_length.
  std::vector<int> labelLengths(const int * flat_labels, int minibatch,
                                int size, int blank, int * total_length) {
    CHECK_EQ(param_.label_length * minibatch, size)
        << "label size should = label_length * minibatch";
    std::vector<int> ret(minibatch, 0);
    for (int i = 0; i < size; i++) {
      if (flat_labels[i] == blank) {
        continue;
      }
      int b = i / param_.label_length;
      ret[b]++;
      (*total_length)++;
    }
    return ret;
  }

  // Packs all non-blank labels from flat_labels into cpu_labels.
  void removeBlank(const int * flat_labels, int * cpu_labels,
                   int size, int blank) {
    int k = 0;
    for (int i = 0; i < size; i++) {
      if (flat_labels[i] != blank) {
        cpu_labels[k] = flat_labels[i];
        k += 1;
      }
    }
  }

  // The backward pass computes the CTC loss with baidu/warp-ctc and writes
  // the gradient w.r.t. the input activations into in_grad[kData].
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    Stream<xpu> *s = ctx.get_stream<xpu>();
    TBlob data = in_data[warpctc_enum::kData];
    TBlob label = in_data[warpctc_enum::kLabel];
    CHECK_EQ(data.shape_.ndim(), 2) << "input data shape should be 2 (t*n, p)";
    ctcOptions info;  // please update to the latest baidu/warp-ctc NOLINT(*)
    // Device dispatch: when built without CUDA, the GPU branch below collapses
    // into LOG(FATAL), so GPU inputs are rejected at runtime.
    if (data.dev_mask() == cpu::kDevMask) {
      info.loc = CTC_CPU;
      info.num_threads = 1;
    } else if (data.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      info.loc = CTC_GPU;
      info.stream = ctx.get_stream<gpu>()->stream_;
    } else {
#endif
      LOG(FATAL) << "Unknown device type " << data.dev_mask();
    }
    info.blank_label = 0;

    // data is laid out as (input_length * batch, alphabet), so the batch size
    // is the number of rows divided by the input length.
    int T = param_.input_length;
    int minibatch = data.shape_[0] / T;
    int alphabet_size = data.shape_[1];
    std::vector<int> input_lengths;
    for (int i = 0; i < minibatch; i++) {
      input_lengths.push_back(T);
    }

#if MXNET_USE_CUDA
    cudaError_t cuda_status;
#endif
    float* activations = static_cast<float*>(data.dptr_);
    int* flat_labels = static_cast<int*>(label.dptr_);
    int* cpu_raw_labels = flat_labels;
    float* grads = static_cast<float*>(in_grad[warpctc_enum::kData].dptr_);
    if (data.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      // warp-ctc expects the labels on the host, so copy them off the GPU.
      cpu_raw_labels = reinterpret_cast<int*>(malloc(sizeof(int) * label.Size()));
      cuda_status = cudaMemcpyAsync(cpu_raw_labels, flat_labels,
                                    label.Size()*sizeof(int),
                                    cudaMemcpyDeviceToHost,
                                    ctx.get_stream<gpu>()->stream_);
      CHECK_EQ(cuda_status, cudaSuccess) << "cuda memcpy label error";
#endif
    }

    // Drop the blank (0) padding before handing the labels to warp-ctc.
    int total_label_length = 0;
    std::vector<int> label_lengths = labelLengths(cpu_raw_labels,
                                                  minibatch,
                                                  label.Size(),
                                                  0, &total_label_length);
    int* cpu_labels = reinterpret_cast<int*>(
        malloc(sizeof(int) * total_label_length));
    removeBlank(cpu_raw_labels, cpu_labels, label.Size(), 0);

    size_t alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(),
                                      input_lengths.data(),
                                      alphabet_size,
                                      input_lengths.size(), info,
                                      &alloc_bytes),
                   "Error: get_workspace_size in inf_test");

    Tensor<xpu, 1> ctc_workspace = ctx.requested[warpctc_enum::kTmp].get_space<xpu>(
        mshadow::Shape1(alloc_bytes), s);

    std::vector<float> costs(minibatch);
    throw_on_error(compute_ctc_loss(activations,
                                    grads,
                                    cpu_labels,
                                    label_lengths.data(),
                                    input_lengths.data(),
                                    alphabet_size,
                                    minibatch,
                                    costs.data(),
                                    ctc_workspace.dptr_,
                                    info),
                   "Error: compute_ctc_loss");

    if (data.dev_mask() == cpu::kDevMask) {
      free(cpu_labels);
    } else if (data.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      free(cpu_raw_labels);
      free(cpu_labels);
#endif
    }
  }
};

template<typename xpu>
Operator* CreateOp(WarpCTCParam type);


#if DMLC_USE_CXX11
class WarpCTCProp : public OperatorProperty {
 public:
  std::vector<std::string> ListArguments() const override {
    return {"data", "label"};
  }

  virtual std::vector<std::string> ListOutputs() const {
    return {"output"};
  }

  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs)
      override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
    const mxnet::TShape &dshape = in_shape->at(0);
    if (dshape.ndim() == 0) return false;
    mxnet::TShape label_shape(dshape.ndim() - 1, 1);
    label_shape[0] = param_.label_length * (dshape[0] / param_.input_length);
    SHAPE_ASSIGN_CHECK(*in_shape, warpctc_enum::kLabel, label_shape);

    out_shape->clear();
    out_shape->push_back(dshape);
    return true;
  }

  virtual bool InferType(std::vector<int> *in_type,
                         std::vector<int> *out_type,
                         std::vector<int> *aux_type) const {
    CHECK_LE(in_type->size(), this->ListArguments().size());
    in_type->clear();
    in_type->push_back(mshadow::kFloat32);
    in_type->push_back(mshadow::kInt32);
    out_type->clear();
    out_type->push_back(mshadow::kFloat32);
    return true;
  }

  std::vector<ResourceRequest> BackwardResource(
      const mxnet::ShapeVector &in_shape) const override {
    return {ResourceRequest::kTempSpace};
  }

  OperatorProperty* Copy() const override {
    auto ptr = new WarpCTCProp();
    ptr->param_ = param_;
    return ptr;
  }

  std::string TypeString() const override {
    return "WarpCTC";
  }

  std::vector<int> DeclareBackwardDependency(const std::vector<int> &out_grad,
                                             const std::vector<int> &in_data,
                                             const std::vector<int> &out_data)
      const override {
    return {in_data[warpctc_enum::kData],
            in_data[warpctc_enum::kLabel],
            out_data[warpctc_enum::kOut]};
  }

  Operator* CreateOperator(Context ctx) const override;

 private:
  WarpCTCParam param_;
};
#endif  // DMLC_USE_CXX11

}  // namespace op
}  // namespace mxnet

#endif  // PLUGIN_WARPCTC_WARPCTC_INL_H_