///////////////////////////////////////////////////////////////////////
// File:        network.h
// Description: Base class for neural network implementations.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_NETWORK_H_
#define TESSERACT_LSTM_NETWORK_H_

#include "helpers.h"
#include "matrix.h"
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "tprintf.h"

#include <cmath>
#include <cstdio>

struct Pix;

namespace tesseract {

class ScrollView;
class TBOX;
class ImageData;
class NetworkScratch;
// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,  // The naked base class.
  NT_INPUT, // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,    // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,     // Chooses the max result from a rectangle.
  NT_PARALLEL,    // Runs networks in parallel.
  NT_REPLICATED,  // Runs identical networks in parallel.
  NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel.
  NT_SERIES,      // Executes a sequence of layers.
  NT_RECONFIG,    // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,   // Reverses the x-direction of the inputs/outputs.
  NT_YREVERSED,   // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE, // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,           // Long Short-Term Memory block.
  NT_LSTM_SUMMARY,   // LSTM that only keeps its last output.
  NT_LOGISTIC,       // Fully connected logistic nonlinearity.
  NT_POSCLIP,        // Fully connected rect lin version of logistic.
  NT_SYMCLIP,        // Fully connected rect lin version of tanh.
  NT_TANH,           // Fully connected with tanh nonlinearity.
  NT_RELU,           // Fully connected with rectifier nonlinearity.
  NT_LINEAR,         // Fully connected with no nonlinearity.
  NT_SOFTMAX,        // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, applied inside
  // the network, with the outputs fed back to the input of the LSTM at the
  // next timestep. The ENCODED version binary-encodes the softmax outputs,
  // providing log2 of the number of outputs as additional inputs, and the
  // other version just provides all the softmax outputs as additional inputs.
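  // For example, with 100 softmax outputs the ENCODED version would feed back
  // ceil(log2(100)) = 7 binary-coded inputs per timestep, where the plain
  // version feeds back all 100 (illustrative arithmetic, not from the code).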
  NT_LSTM_SOFTMAX,         // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT // Array size.
};

// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer.
  NF_ADAM = 128,             // Weight-specific learning rate.
};
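// A minimal usage sketch: the flags are bit values, so they combine with
// bitwise OR and are queried via Network::TestFlag() (declared below). The
// `net` pointer is hypothetical.
//   net->SetNetworkFlags(NF_LAYER_SPECIFIC_LR | NF_ADAM);
//   if (net->TestFlag(NF_ADAM)) { /* per-weight learning rates in use */ }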

// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,     // Disabled permanently.
  TS_ENABLED,      // Enabled for backprop and to write a training dump.
                   // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE, // Re-enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};

// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class TESS_API Network {
public:
  Network();
  Network(NetworkType type, const std::string &name, int ni, int no);
  virtual ~Network() = default;

  // Accessors.
  NetworkType type() const {
    return type_;
  }
  bool IsTraining() const {
    return training_ == TS_ENABLED;
  }
  bool needs_to_backprop() const {
    return needs_to_backprop_;
  }
  int num_weights() const {
    return num_weights_;
  }
  int NumInputs() const {
    return ni_;
  }
  int NumOutputs() const {
    return no_;
  }
  // Returns the required shape input to the network.
  virtual StaticShape InputShape() const {
    StaticShape result;
    return result;
  }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown, i.e. zero).
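  // For example (illustrative values only): a layer with no_ == 111 maps an
  // input shape of (batch=1, height=0, width=0, depth=40) to
  // (batch=1, height=0, width=0, depth=111), leaving unknown dimensions zero.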
  virtual StaticShape OutputShape(const StaticShape &input_shape) const {
    StaticShape result(input_shape);
    result.set_depth(no_);
    return result;
  }
  const std::string &name() const {
    return name_;
  }
  virtual std::string spec() const {
    return "?";
  }
  bool TestFlag(NetworkFlags flag) const {
    return (network_flags_ & flag) != 0;
  }

  // Initialization and administrative functions that are mostly provided
  // by Plumbing.
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  virtual bool IsPlumbingType() const {
    return false;
  }

  // Suspends/Enables/Permanently disables training by setting the training_
  // flag. Serialize and DeSerialize only operate on the run-time data if state
  // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
  // serialize as if it were a recognizer.
  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
  // recognizer can be converted back to a trainer.
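  // A usage sketch (hypothetical trainer code; `net` and `fp` assumed):
  //   net->SetEnableTraining(TS_TEMP_DISABLE); // pause to dump a recognizer.
  //   net->Serialize(fp);                      // omits training-only data.
  //   net->SetEnableTraining(TS_RE_ENABLE);    // resume training afterwards.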
  virtual void SetEnableTraining(TrainingState state);

  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  virtual void SetNetworkFlags(uint32_t flags);

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
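  // A minimal sketch (the TRand instance must outlive the network; set_seed
  // is assumed from helpers.h):
  //   TRand randomizer;
  //   randomizer.set_seed(42);
  //   int n = net->InitWeights(0.1f, &randomizer);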
  virtual int InitWeights(float range, TRand *randomizer);
  // Changes the number of outputs presented to the outside world to the size
  // of the given code_map. Recursively searches the entire network for Softmax
  // layers that have exactly old_no outputs, and operates only on those,
  // leaving all others unchanged. This enables networks with multiple output
  // layers to get all their softmaxes updated, but if an internal layer uses
  // one of those softmaxes for input, then the inputs will effectively be
  // scrambled.
  // TODO(rays) Fix this before any such network is implemented.
  // The softmaxes are resized by copying the old weight matrix entries for
  // each output from code_map[output] where non-negative, using the mean (over
  // all outputs) of the existing weights for all outputs with negative
  // code_map entries. Returns the new number of weights.
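  // Worked example (hypothetical values): with old_no == 3 and
  // code_map == {2, -1, 0}, new output 0 copies the weights of old output 2,
  // new output 1 gets the mean of all old weights, and new output 2 copies
  // old output 0.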
  virtual int RemapOutputs([[maybe_unused]] int old_no,
                           [[maybe_unused]] const std::vector<int> &code_map) {
    return 0;
  }

  // Converts a float network to an int network.
  virtual void ConvertToInt() {}

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  virtual void SetRandomizer(TRand *randomizer);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop is true or this network contains any weights, so that the
  // next layer forward can be told to produce backprop for this layer if
  // needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns the most recent reduction factor that the network applied to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data and calculating result bounding
  // boxes.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const {
    return 1;
  }

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor([[maybe_unused]] int factor) {}

  // Provides debug output on the weights.
  virtual void DebugWeights() = 0;

  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile *fp) const;
  // Reads from the given file. Returns false in case of error.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(TFile *fp) = 0;

public:
  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  virtual void Update([[maybe_unused]] float learning_rate,
                      [[maybe_unused]] float momentum,
                      [[maybe_unused]] float adam_beta,
                      [[maybe_unused]] int num_samples) {}
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators([[maybe_unused]] const Network &other,
                                [[maybe_unused]] TFloat *same,
                                [[maybe_unused]] TFloat *changed) const {}

  // Reads from the given file. Returns nullptr in case of error.
  // Determines the type of the serialized class and calls its DeSerialize
  // on a new object of the appropriate type, which is returned.
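  // A minimal sketch (assumes `fp` is a TFile already open for reading):
  //   Network *net = Network::CreateFromFile(fp);
  //   if (net == nullptr) { /* handle the deserialization error */ }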
  static Network *CreateFromFile(TFile *fp);

  // Runs forward propagation of activations on the input line.
  // Note that input and output are both 2-d arrays.
  // The 1st index is the time element. In a 1-d network, it might be the pixel
  // position on the textline. In a 2-d network, the linearization is defined
  // by the stride_map. (See networkio.h).
  // The 2nd index of input is the network inputs/outputs, and the dimension
  // of the input must match NumInputs() of this network.
  // The output array will be resized as needed so that its 1st dimension is
  // always the number of output time steps, and its 2nd dimension is always
  // NumOutputs(). Note that all this detail is encapsulated away inside
  // NetworkIO, as are the internals of the scratch memory space used by the
  // network. See networkscratch.h for that.
  // If input_transpose is not nullptr, then it contains the transpose of input,
  // and the caller guarantees that it will still be valid on the next call to
  // backward. The callee is therefore at liberty to save the pointer and
  // reference it on a call to backward. This is a bit ugly, but it makes it
  // possible for a replicating parallel to calculate the input transpose once
  // instead of all the replicated networks having to do it.
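  // A minimal calling sketch (assumes `net`, `input` and `scratch` are
  // already set up; passing nullptr skips the transpose optimization):
  //   NetworkIO output;
  //   net->Forward(/*debug=*/false, input, nullptr, &scratch, &output);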
  virtual void Forward(bool debug, const NetworkIO &input,
                       const TransposedArray *input_transpose,
                       NetworkScratch *scratch, NetworkIO *output) = 0;

  // Runs backward propagation of errors on fwd_deltas.
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
  // Returns false if back_deltas was not set, due to there being no point in
  // propagating further backwards. Thus most complete networks will always
  // return false from Backward!
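  // A minimal calling sketch (assumes `net`, `fwd_deltas` and `scratch` are
  // already set up):
  //   NetworkIO back_deltas;
  //   if (net->Backward(/*debug=*/false, fwd_deltas, &scratch, &back_deltas)) {
  //     /* propagate back_deltas to the previous layer */
  //   }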
  virtual bool Backward(bool debug, const NetworkIO &fwd_deltas,
                        NetworkScratch *scratch, NetworkIO *back_deltas) = 0;

  // === Debug image display methods. ===
  // Displays the image of the matrix to the forward window.
  void DisplayForward(const NetworkIO &matrix);
  // Displays the image of the matrix to the backward window.
  void DisplayBackward(const NetworkIO &matrix);

  // Creates the window if needed, otherwise clears it.
  static void ClearWindow(bool tess_coords, const char *window_name, int width,
                          int height, ScrollView **window);

  // Displays the pix in the given window and returns the height of the pix.
  // The pix is pixDestroyed.
  static int DisplayImage(Image pix, ScrollView *window);

protected:
  // Returns a random number in [-range, range].
  TFloat Random(TFloat range);

protected:
  NetworkType type_;       // Type of the derived network class.
  TrainingState training_; // Are we currently training?
  bool needs_to_backprop_; // This network needs to output back_deltas.
  int32_t network_flags_;  // Behavior control flags in NetworkFlags.
  int32_t ni_;             // Number of input values.
  int32_t no_;             // Number of output values.
  int32_t num_weights_;    // Number of weights in this and sub-network.
  std::string name_;       // A unique name for this layer.

  // NOT-serialized debug data.
  ScrollView *forward_win_;  // Recognition debug display window.
  ScrollView *backward_win_; // Training debug display window.
  TRand *randomizer_;        // Random number generator.
};

} // namespace tesseract.

#endif // TESSERACT_LSTM_NETWORK_H_