/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file warpctc-inl.h
 * \brief warpctc operator
 * \author Liang Xiang
*/
#ifndef PLUGIN_WARPCTC_WARPCTC_INL_H_
#define PLUGIN_WARPCTC_WARPCTC_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <stdio.h>
#include <ctc.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include "../../src/operator/operator_common.h"

namespace mxnet {
namespace op {

namespace warpctc_enum {
  enum CTCOpInputs {kData, kLabel};
  enum CTCOpOutputs {kOut};
  enum CTCTemp {kTmp};
}  // namespace warpctc_enum

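// Hyperparameters of the WarpCTC operator.  Every sequence in a batch is
// assumed to have the same input_length, and labels are padded with the
// blank symbol (0) to a fixed label_length per sequence.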
struct WarpCTCParam : public dmlc::Parameter<WarpCTCParam> {
  int label_length;
  int input_length;
  DMLC_DECLARE_PARAMETER(WarpCTCParam) {
    DMLC_DECLARE_FIELD(label_length)
        .set_default(0)
        .describe("Real label length");
    DMLC_DECLARE_FIELD(input_length)
        .set_default(0)
        .describe("Input length");
  }
};

template<typename xpu>
class WarpCTCOp : public Operator {
 private:
  WarpCTCParam param_;

 public:
  explicit WarpCTCOp(WarpCTCParam p) {
    this->param_ = p;
  }

  ~WarpCTCOp() {
  }

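  // Throws a std::runtime_error, with warp-ctc's status string appended, when
  // a CTC call does not return CTC_STATUS_SUCCESS.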
  inline void throw_on_error(ctcStatus_t status, const char* message) {
    if (status != CTC_STATUS_SUCCESS) {
      throw std::runtime_error(message
                               + (", stat = "
                                  + std::string(ctcGetStatusString(status))));
    }
  }

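  // The forward pass only applies a softmax over the class (alphabet)
  // dimension; the CTC loss and its gradient are computed in Backward via
  // warp-ctc.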
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 2) << "CTCOutput Input: [data, label]";
    CHECK_EQ(out_data.size(), 1) << "CTCOutput Output: [output]";

    Stream<xpu> *s = ctx.get_stream<xpu>();
    TBlob data = in_data[warpctc_enum::kData];
    TBlob out = out_data[warpctc_enum::kOut];
    Tensor<xpu, 2, float> data_tensor = data.FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, float> out_tensor = out.FlatTo2D<xpu, float>(s);
    Softmax(out_tensor, data_tensor);
  }

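  // Counts the number of non-blank entries for each sequence in the
  // minibatch.  Labels are expected to be padded with the blank symbol to a
  // fixed length of param_.label_length per sequence; *total_length
  // accumulates the total number of non-blank labels across the batch.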
  std::vector<int> labelLengths(const int * flat_labels, int minibatch,
                                int size, int blank, int * total_length) {
    CHECK_EQ(param_.label_length * minibatch, size)
        << "label size should equal label_length * minibatch";
    std::vector<int> ret(minibatch, 0);
    for (int i = 0; i < size; i++) {
      if (flat_labels[i] == blank) {
        continue;
      }
      int b = i / param_.label_length;
      ret[b]++;
      (*total_length)++;
    }
    return ret;
  }

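  // Compacts flat_labels into cpu_labels by dropping every blank (padding)
  // entry, producing the concatenated label sequences that warp-ctc expects.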
  void removeBlank(const int * flat_labels, int * cpu_labels,
                   int size, int blank) {
    int k = 0;
    for (int i = 0; i < size; i++) {
      if (flat_labels[i] != blank) {
        cpu_labels[k] = flat_labels[i];
        k += 1;
      }
    }
  }

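  // Computes the CTC loss and its gradient with warp-ctc.  The input data is
  // expected to be laid out as (input_length * batch, alphabet_size) and the
  // labels as a flat int array of length label_length * batch, padded with
  // the blank symbol (0).  Gradients w.r.t. the activations are written to
  // in_grad[kData].  Labels residing on the GPU are first copied to the host,
  // since warp-ctc takes host-side label arrays.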
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    Stream<xpu> *s = ctx.get_stream<xpu>();
    TBlob data = in_data[warpctc_enum::kData];
    TBlob label = in_data[warpctc_enum::kLabel];
    CHECK_EQ(data.shape_.ndim(), 2) << "input data should be 2-D (t*n, p)";
    ctcOptions info;  // requires an up-to-date baidu/warp-ctc NOLINT(*)
    if (data.dev_mask() == cpu::kDevMask) {
      info.loc = CTC_CPU;
      info.num_threads = 1;
    } else if (data.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      info.loc = CTC_GPU;
      info.stream = ctx.get_stream<gpu>()->stream_;
    } else {
#endif
      LOG(FATAL) << "Unknown device type " << data.dev_mask();
    }
    info.blank_label = 0;

    int T = param_.input_length;
    int minibatch = data.shape_[0] / T;
    int alphabet_size = data.shape_[1];
    std::vector<int> input_lengths;
    for (int i = 0; i < minibatch; i++) {
      input_lengths.push_back(T);
    }

#if MXNET_USE_CUDA
    cudaError_t cuda_status;
#endif
    float* activations = static_cast<float*>(data.dptr_);
    int* flat_labels = static_cast<int*>(label.dptr_);
    int* cpu_raw_labels = flat_labels;
    float* grads = static_cast<float*>(in_grad[warpctc_enum::kData].dptr_);
    if (data.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      // Labels live on the GPU: copy them into a host buffer for warp-ctc.
      cpu_raw_labels = reinterpret_cast<int*>(malloc(sizeof(int) * label.Size()));
      cuda_status = cudaMemcpyAsync(cpu_raw_labels, flat_labels,
                                    label.Size()*sizeof(int),
                                    cudaMemcpyDeviceToHost,
                                    ctx.get_stream<gpu>()->stream_);
      CHECK_EQ(cuda_status, cudaSuccess) << "cuda memcpy label error";
#endif
    }

    int total_label_length = 0;
    std::vector<int> label_lengths = labelLengths(cpu_raw_labels,
                                                  minibatch,
                                                  label.Size(),
                                                  0, &total_label_length);
    int* cpu_labels = reinterpret_cast<int*>(
        malloc(sizeof(int) * total_label_length));
    removeBlank(cpu_raw_labels, cpu_labels, label.Size(), 0);

    size_t alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(),
                                      input_lengths.data(),
                                      alphabet_size,
                                      input_lengths.size(), info,
                                      &alloc_bytes),
                   "Error: get_workspace_size");

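    // warp-ctc needs a scratch workspace whose size depends on the label and
    // input lengths; it is requested from MXNet's temporary-space resource
    // (declared in WarpCTCProp::BackwardResource below).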
    Tensor<xpu, 1> ctc_workspace = ctx.requested[warpctc_enum::kTmp].get_space<xpu>(
        mshadow::Shape1(alloc_bytes), s);

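    // compute_ctc_loss writes one loss value per sequence into costs and,
    // because grads is non-null, the gradient w.r.t. the activations into
    // grads.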
    std::vector<float> costs(minibatch);
    throw_on_error(compute_ctc_loss(activations,
                                    grads,
                                    cpu_labels,
                                    label_lengths.data(),
                                    input_lengths.data(),
                                    alphabet_size,
                                    minibatch,
                                    costs.data(),
                                    ctc_workspace.dptr_,
                                    info),
                   "Error: compute_ctc_loss");

    if (data.dev_mask() == cpu::kDevMask) {
      free(cpu_labels);
    } else if (data.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      free(cpu_raw_labels);
      free(cpu_labels);
#endif
    }
  }
};

template<typename xpu>
Operator* CreateOp(WarpCTCParam type);


#if DMLC_USE_CXX11
class WarpCTCProp : public OperatorProperty {
 public:
  std::vector<std::string> ListArguments() const override {
    return {"data", "label"};
  }

  virtual std::vector<std::string> ListOutputs() const {
    return {"output"};
  }

  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs)
      override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

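  // Infers the output shape (same as data) and the label shape from the data
  // shape.  For example, with input_length = 80, label_length = 25 and a data
  // shape of (320, alphabet_size), the batch size is 320 / 80 = 4 and the
  // inferred label shape is (25 * 4,) = (100,).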
  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
    const mxnet::TShape &dshape = in_shape->at(0);
    if (dshape.ndim() == 0) return false;
    mxnet::TShape label_shape(dshape.ndim() - 1, 1);
    label_shape[0] = param_.label_length * (dshape[0] / param_.input_length);
    SHAPE_ASSIGN_CHECK(*in_shape, warpctc_enum::kLabel, label_shape);

    out_shape->clear();
    out_shape->push_back(dshape);
    return true;
  }

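  // Data and output are float32; labels are int32.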
  virtual bool InferType(std::vector<int> *in_type,
                         std::vector<int> *out_type,
                         std::vector<int> *aux_type) const {
    CHECK_LE(in_type->size(), this->ListArguments().size());
    in_type->clear();
    in_type->push_back(mshadow::kFloat32);
    in_type->push_back(mshadow::kInt32);
    out_type->clear();
    out_type->push_back(mshadow::kFloat32);
    return true;
  }

  std::vector<ResourceRequest> BackwardResource(
      const mxnet::ShapeVector &in_shape) const override {
    return {ResourceRequest::kTempSpace};
  }

  OperatorProperty* Copy() const override {
    auto ptr = new WarpCTCProp();
    ptr->param_ = param_;
    return ptr;
  }

  std::string TypeString() const override {
    return "WarpCTC";
  }

  std::vector<int> DeclareBackwardDependency(const std::vector<int> &out_grad,
                                             const std::vector<int> &in_data,
                                             const std::vector<int> &out_data)
      const override {
    return {in_data[warpctc_enum::kData],
          in_data[warpctc_enum::kLabel],
          out_data[warpctc_enum::kOut]};
  }

  Operator* CreateOperator(Context ctx) const override;

 private:
  WarpCTCParam param_;
};
#endif  // DMLC_USE_CXX11

}  // namespace op
}  // namespace mxnet

#endif  // PLUGIN_WARPCTC_WARPCTC_INL_H_