///////////////////////////////////////////////////////////////////////
// File:        lstm.cpp
// Description: Long Short-Term Memory (LSTM) recurrent neural network.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "lstm.h"

#ifdef _OPENMP
#  include <omp.h>
#endif
#include <climits> // for UINT_MAX
#include <cstdio>
#include <cstdlib>
#include <sstream> // for std::ostringstream

#if !defined(__GNUC__) && defined(_MSC_VER)
#  include <intrin.h> // _BitScanReverse
#endif

#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"

// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#  define PARALLEL_IF_OPENMP(__num_threads)                                  \
    PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
      PRAGMA(omp sections nowait) {                                          \
        PRAGMA(omp section) {
#  define SECTION_IF_OPENMP \
    }                       \
    PRAGMA(omp section) {
#  define END_PARALLEL_IF_OPENMP \
    }                            \
    } /* end of sections */      \
    } /* end of parallel section */

// Define the portable PRAGMA macro.
#  ifdef _MSC_VER // Different _Pragma
#    define PRAGMA(x) __pragma(x)
#  else
#    define PRAGMA(x) _Pragma(#x)
#  endif // _MSC_VER

#else // _OPENMP
#  define PARALLEL_IF_OPENMP(__num_threads)
#  define SECTION_IF_OPENMP
#  define END_PARALLEL_IF_OPENMP
#endif // _OPENMP

namespace tesseract {

// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const TFloat kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const TFloat kErrClip = 1.0f;

// Calculate ceil(log2(n)).
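// Used to size the binary-coded softmax feedback (nf_): e.g. ceil_log2(8) == 3,
// ceil_log2(9) == 4.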
static inline uint32_t ceil_log2(uint32_t n) {
  // l2 = (unsigned)log2(n).
#if defined(__GNUC__)
  // Use the fast builtin for gcc or clang.
  uint32_t l2 = 31 - __builtin_clz(n);
#elif defined(_MSC_VER)
  // Use fast intrinsic function for MS compiler.
  unsigned long l2 = 0;
  _BitScanReverse(&l2, n);
#else
  if (n == 0)
    return UINT_MAX;
  if (n == 1)
    return 0;
  uint32_t val = n;
  uint32_t l2 = 0;
  while (val > 1) {
    val >>= 1;
    l2++;
  }
#endif
  // Round up if n is not a power of 2.
  return (n == (1u << l2)) ? l2 : l2 + 1;
}

LSTM::LSTM(const std::string &name, int ni, int ns, int no, bool two_dimensional, NetworkType type)
    : Network(type, name, ni, no)
    , na_(ni + ns)
    , ns_(ns)
    , nf_(0)
    , is_2d_(two_dimensional)
    , softmax_(nullptr)
    , input_width_(0) {
  if (two_dimensional) {
    na_ += ns_;
  }
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is invalid type of LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
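  // na_ is now the full width of the concatenated gate input: ni_ image inputs
  // + nf_ softmax feedback + ns_ 1-D recurrent outputs (+ ns_ more 2-D
  // recurrent outputs when two_dimensional), matching the layout written into
  // source_ in Forward().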
}

LSTM::~LSTM() {
  delete softmax_;
}

// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape LSTM::OutputShape(const StaticShape &input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) {
    result.set_width(1);
  }
  if (softmax_ != nullptr) {
    return softmax_->OutputShape(result);
  }
  return result;
}

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training is disabled.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) {
          continue;
        }
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != nullptr) {
    softmax_->SetEnableTraining(state);
  }
}

// Sets up the network for training. Initializes weights to random values of
// scale `range` picked using the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand *randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    num_weights_ +=
        gate_weights_[w].InitWeightsFloat(ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != nullptr) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int> &code_map) {
  if (softmax_ != nullptr) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != nullptr) {
    softmax_->ConvertToInt();
  }
}

// Provides debug output on the weights.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    std::ostringstream msg;
    msg << name_ << " Gate weights " << w;
    gate_weights_[w].Debug2D(msg.str().c_str());
  }
  if (softmax_ != nullptr) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile *fp) const {
  if (!Network::Serialize(fp)) {
    return false;
  }
  if (!fp->Serialize(&na_)) {
    return false;
  }
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) {
      return false;
    }
  }
  if (softmax_ != nullptr && !softmax_->Serialize(fp)) {
    return false;
  }
  return true;
}

// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile *fp) {
  if (!fp->DeSerialize(&na_)) {
    return false;
  }
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = ceil_log2(no_);
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) {
      return false;
    }
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
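      // The serialized na_ includes a second ns_-wide recurrent block only for
      // 2-D LSTMs, so the 2-D flag can be recovered from the deserialized sizes.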
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected *>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) {
      return false;
    }
  } else {
    softmax_ = nullptr;
  }
  return true;
}

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
                   NetworkScratch *scratch, NetworkIO *output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != nullptr) {
    output->ResizeFloat(input, no_);
  } else if (type_ == NT_LSTM_SUMMARY) {
    output->ResizeXTo1(input, no_);
  } else {
    output->Resize(input, no_);
  }
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  int ro = ns_;
  if (source_.int_mode() && IntSimdMatrix::intSimdMatrix) {
    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
  }
  for (auto &temp_line : temp_lines) {
    temp_line.Init(ns_, ro, scratch);
  }
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<TFloat>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<TFloat>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  std::vector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.resize(buf_width);
    outputs.resize(buf_width);
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<TFloat>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<TFloat>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != nullptr) {
    softmax_output.Init(no_, scratch);
    ZeroVector<TFloat>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode()) {
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    }
    softmax_->SetupForward(input, nullptr);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) {
        valid_2d = false;
      }
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
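    // For the softmax variants, the previous timestep's softmax output (binary
    // coded for NT_LSTM_SOFTMAX_ENCODED) is fed back as nf_ extra inputs.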
    if (softmax_ != nullptr) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D()) {
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    }
    if (!source_.int_mode()) {
      source_.ReadTimeStep(t, curr_input);
    }
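    // One cell update is computed gate by gate below. With x the na_-wide
    // concatenated input assembled in source_ above:
    //   ci  = GFunc(W_ci . x)   cell input
    //   gi  = FFunc(W_gi . x)   input gate
    //   gf1 = FFunc(W_gf1 . x)  1-D forget gate (plus GFS in 2-D mode)
    //   go  = FFunc(W_go . x)   output gate
    //   state = gf1 * state(t-1) + gi * ci
    //   output = go * HFunc(state)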
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel region outside the t loop, with an
    // omp single around the t-loop and tasks in place of the sections, is a
    // *lot* slower.
    // Cell inputs.
    if (source_.int_mode()) {
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    } else {
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    }
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode()) {
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    } else {
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode()) {
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    } else {
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode()) {
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      } else {
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      }
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode()) {
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    } else {
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      int8_t *which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
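      // which_fg_ records, per cell, which forget gate won the max-pool
      // (1 = 1-D, 2 = 2-D) so that Backward() can route the state error to the
      // matching direction.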
      if (valid_2d) {
        const TFloat *stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<TFloat>(ns_, -kStateClip, kStateClip, curr_state);
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) {
        node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
      }
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) {
      state_.WriteTimeStep(t, curr_state);
    }
    if (softmax_ != nullptr) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
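        // Re-encode the no_-way distribution into nf_ = ceil_log2(no_) binary
        // bits; this coded form is what gets fed back on the next timestep.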
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<TFloat>(ns_, curr_state);
      ZeroVector<TFloat>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.c_str());
  source_.Print(10);
  tprintf("State:%s\n", name_.c_str());
  state_.Print(10);
  tprintf("Output:%s\n", name_.c_str());
  output->Print(10);
#endif
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayForward(*output);
  }
#endif
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                    NetworkIO *back_deltas) {
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayBackward(fwd_deltas);
  }
#endif
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<TFloat>(ns_, curr_stateerr);
  ZeroVector<TFloat>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (auto &gate_error : gate_errors) {
    gate_error.Init(ns_, scratch);
  }
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.resize(buf_width);
    sourceerr.resize(buf_width);
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<TFloat>(ns_, stateerr[t]);
      ZeroVector<TFloat>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (auto &sourceerr_temp : sourceerr_temps) {
    sourceerr_temp.Init(na_, scratch);
  }
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (auto &w : gate_errors_t) {
    w.Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  TFloat state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.c_str());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) {
          up_pos = up_index.t();
        }
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) {
          down_pos = down_index.t();
        }
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<TFloat>(na_, curr_sourceerr);
      ZeroVector<TFloat>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<TFloat>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x) {
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    }
    if (down_pos >= 0) {
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    }
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float *next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) {
          curr_stateerr[i] = 0.0;
        }
      }
      if (down_pos >= 0) {
        const float *right_node_gfs = node_values_[GFS].f(down_pos);
        const TFloat *right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr, curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
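    // The FuncMultiply3 calls below form, per cell (derivatives evaluated at
    // the saved forward activations):
    //   err_ci  = gi * GPrime(ci) * stateerr
    //   err_gi  = ci * FPrime(gi) * stateerr
    //   err_gf1 = state(t-1) * FPrime(gf1) * stateerr  (zero at t == 0)
    //   err_gfs = state(up_pos) * FPrime(gfs) * stateerr  (2-D only)
    //   err_go  = HFunc(state) * FPrime(go) * outputerr
    // each clipped to +/- kErrClip, then pushed back through the gate weights
    // to accumulate the source (input) errors.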
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t, curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t, curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr, gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr, gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) {
      gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr, gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], sourceerr_temps[GF1],
               sourceerr_temps[GO], sourceerr_temps[GFS], curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.c_str(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
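  // Accumulate the weight gradients: for each gate, sum over all timesteps the
  // outer product of that gate's error vector with the source (input) vector,
  // using the transposed stores built above.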
#ifdef _OPENMP
#  pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != nullptr) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM *lstm = static_cast<const LSTM *>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != nullptr) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

#if DEBUG_DETAIL > 3

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      }
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      }
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s) {
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    }
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      }
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      }
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s) {
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    }
    tprintf("\n");
  }
}

#endif

// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO &input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) {
        continue;
      }
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}

} // namespace tesseract.