1 /////////////////////////////////////////////////////////////////////// 2 // File: network.h 3 // Description: Base class for neural network implementations. 4 // Author: Ray Smith 5 // 6 // (C) Copyright 2013, Google Inc. 7 // Licensed under the Apache License, Version 2.0 (the "License"); 8 // you may not use this file except in compliance with the License. 9 // You may obtain a copy of the License at 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 /////////////////////////////////////////////////////////////////////// 17 18 #ifndef TESSERACT_LSTM_NETWORK_H_ 19 #define TESSERACT_LSTM_NETWORK_H_ 20 21 #include "helpers.h" 22 #include "matrix.h" 23 #include "networkio.h" 24 #include "serialis.h" 25 #include "static_shape.h" 26 #include "tprintf.h" 27 28 #include <cmath> 29 #include <cstdio> 30 31 struct Pix; 32 33 namespace tesseract { 34 35 class ScrollView; 36 class TBOX; 37 class ImageData; 38 class NetworkScratch; 39 40 // Enum to store the run-time type of a Network. Keep in sync with kTypeNames. 41 enum NetworkType { 42 NT_NONE, // The naked base class. 43 NT_INPUT, // Inputs from an image. 44 // Plumbing networks combine other networks or rearrange the inputs. 45 NT_CONVOLVE, // Duplicates inputs in a sliding window neighborhood. 46 NT_MAXPOOL, // Chooses the max result from a rectangle. 47 NT_PARALLEL, // Runs networks in parallel. 48 NT_REPLICATED, // Runs identical networks in parallel. 49 NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel. 50 NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel. 51 NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel. 52 NT_SERIES, // Executes a sequence of layers. 53 NT_RECONFIG, // Scales the time/y size but makes the output deeper. 54 NT_XREVERSED, // Reverses the x direction of the inputs/outputs. 55 NT_YREVERSED, // Reverses the y-direction of the inputs/outputs. 56 NT_XYTRANSPOSE, // Transposes x and y (for just a single op). 57 // Functional networks actually calculate stuff. 58 NT_LSTM, // Long-Short-Term-Memory block. 59 NT_LSTM_SUMMARY, // LSTM that only keeps its last output. 60 NT_LOGISTIC, // Fully connected logistic nonlinearity. 61 NT_POSCLIP, // Fully connected rect lin version of logistic. 62 NT_SYMCLIP, // Fully connected rect lin version of tanh. 63 NT_TANH, // Fully connected with tanh nonlinearity. 64 NT_RELU, // Fully connected with rectifier nonlinearity. 65 NT_LINEAR, // Fully connected with no nonlinearity. 66 NT_SOFTMAX, // Softmax uses exponential normalization, with CTC. 67 NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC. 68 // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with 69 // the outputs fed back to the input of the LSTM at the next timestep. 70 // The ENCODED version binary encodes the softmax outputs, providing log2 of 71 // the number of outputs as additional inputs, and the other version just 72 // provides all the softmax outputs as additional inputs. 73 NT_LSTM_SOFTMAX, // 1-d LSTM with built-in fully connected softmax. 74 NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax. 75 // A TensorFlow graph encapsulated as a Tesseract network. 76 NT_TENSORFLOW, 77 78 NT_COUNT // Array size. 79 }; 80 81 // Enum of Network behavior flags. Can in theory be set for each individual 82 // network element. 83 enum NetworkFlags { 84 // Network forward/backprop behavior. 85 NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer. 86 NF_ADAM = 128, // Weight-specific learning rate. 87 }; 88 89 // State of training and desired state used in SetEnableTraining. 90 enum TrainingState { 91 // Valid states of training_. 92 TS_DISABLED, // Disabled permanently. 93 TS_ENABLED, // Enabled for backprop and to write a training dump. 94 // Re-enable from ANY disabled state. 95 TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump. 96 // Valid only for SetEnableTraining. 97 TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED. 98 }; 99 100 // Base class for network types. Not quite an abstract base class, but almost. 101 // Most of the time no isolated Network exists, except prior to 102 // deserialization. 103 class TESS_API Network { 104 public: 105 Network(); 106 Network(NetworkType type, const std::string &name, int ni, int no); 107 virtual ~Network() = default; 108 109 // Accessors. type()110 NetworkType type() const { 111 return type_; 112 } IsTraining()113 bool IsTraining() const { 114 return training_ == TS_ENABLED; 115 } needs_to_backprop()116 bool needs_to_backprop() const { 117 return needs_to_backprop_; 118 } num_weights()119 int num_weights() const { 120 return num_weights_; 121 } NumInputs()122 int NumInputs() const { 123 return ni_; 124 } NumOutputs()125 int NumOutputs() const { 126 return no_; 127 } 128 // Returns the required shape input to the network. InputShape()129 virtual StaticShape InputShape() const { 130 StaticShape result; 131 return result; 132 } 133 // Returns the shape output from the network given an input shape (which may 134 // be partially unknown ie zero). OutputShape(const StaticShape & input_shape)135 virtual StaticShape OutputShape(const StaticShape &input_shape) const { 136 StaticShape result(input_shape); 137 result.set_depth(no_); 138 return result; 139 } name()140 const std::string &name() const { 141 return name_; 142 } spec()143 virtual std::string spec() const { 144 return "?"; 145 } TestFlag(NetworkFlags flag)146 bool TestFlag(NetworkFlags flag) const { 147 return (network_flags_ & flag) != 0; 148 } 149 150 // Initialization and administrative functions that are mostly provided 151 // by Plumbing. 152 // Returns true if the given type is derived from Plumbing, and thus contains 153 // multiple sub-networks that can have their own learning rate. IsPlumbingType()154 virtual bool IsPlumbingType() const { 155 return false; 156 } 157 158 // Suspends/Enables/Permanently disables training by setting the training_ 159 // flag. Serialize and DeSerialize only operate on the run-time data if state 160 // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will 161 // temporarily disable layers in state TS_ENABLED, allowing a trainer to 162 // serialize as if it were a recognizer. 163 // TS_RE_ENABLE will re-enable layers that were previously in any disabled 164 // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in 165 // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a 166 // recognizer can be converted back to a trainer. 167 virtual void SetEnableTraining(TrainingState state); 168 169 // Sets flags that control the action of the network. See NetworkFlags enum 170 // for bit values. 171 virtual void SetNetworkFlags(uint32_t flags); 172 173 // Sets up the network for training. Initializes weights using weights of 174 // scale `range` picked according to the random number generator `randomizer`. 175 // Note that randomizer is a borrowed pointer that should outlive the network 176 // and should not be deleted by any of the networks. 177 // Returns the number of weights initialized. 178 virtual int InitWeights(float range, TRand *randomizer); 179 // Changes the number of outputs to the outside world to the size of the given 180 // code_map. Recursively searches the entire network for Softmax layers that 181 // have exactly old_no outputs, and operates only on those, leaving all others 182 // unchanged. This enables networks with multiple output layers to get all 183 // their softmaxes updated, but if an internal layer, uses one of those 184 // softmaxes for input, then the inputs will effectively be scrambled. 185 // TODO(rays) Fix this before any such network is implemented. 186 // The softmaxes are resized by copying the old weight matrix entries for each 187 // output from code_map[output] where non-negative, and uses the mean (over 188 // all outputs) of the existing weights for all outputs with negative code_map 189 // entries. Returns the new number of weights. RemapOutputs(int old_no,const std::vector<int> & code_map)190 virtual int RemapOutputs([[maybe_unused]] int old_no, 191 [[maybe_unused]] const std::vector<int> &code_map) { 192 return 0; 193 } 194 195 // Converts a float network to an int network. ConvertToInt()196 virtual void ConvertToInt() {} 197 198 // Provides a pointer to a TRand for any networks that care to use it. 199 // Note that randomizer is a borrowed pointer that should outlive the network 200 // and should not be deleted by any of the networks. 201 virtual void SetRandomizer(TRand *randomizer); 202 203 // Sets needs_to_backprop_ to needs_backprop and returns true if 204 // needs_backprop || any weights in this network so the next layer forward 205 // can be told to produce backprop for this layer if needed. 206 virtual bool SetupNeedsBackprop(bool needs_backprop); 207 208 // Returns the most recent reduction factor that the network applied to the 209 // time sequence. Assumes that any 2-d is already eliminated. Used for 210 // scaling bounding boxes of truth data and calculating result bounding boxes. 211 // WARNING: if GlobalMinimax is used to vary the scale, this will return 212 // the last used scale factor. Call it before any forward, and it will return 213 // the minimum scale factor of the paths through the GlobalMinimax. XScaleFactor()214 virtual int XScaleFactor() const { 215 return 1; 216 } 217 218 // Provides the (minimum) x scale factor to the network (of interest only to 219 // input units) so they can determine how to scale bounding boxes. CacheXScaleFactor(int factor)220 virtual void CacheXScaleFactor([[maybe_unused]] int factor) {} 221 222 // Provides debug output on the weights. 223 virtual void DebugWeights() = 0; 224 225 // Writes to the given file. Returns false in case of error. 226 // Should be overridden by subclasses, but called by their Serialize. 227 virtual bool Serialize(TFile *fp) const; 228 // Reads from the given file. Returns false in case of error. 229 // Should be overridden by subclasses, but NOT called by their DeSerialize. 230 virtual bool DeSerialize(TFile *fp) = 0; 231 232 public: 233 // Updates the weights using the given learning rate, momentum and adam_beta. 234 // num_samples is used in the adam computation iff use_adam_ is true. Update(float learning_rate,float momentum,float adam_beta,int num_samples)235 virtual void Update([[maybe_unused]] float learning_rate, 236 [[maybe_unused]] float momentum, 237 [[maybe_unused]] float adam_beta, 238 [[maybe_unused]] int num_samples) {} 239 // Sums the products of weight updates in *this and other, splitting into 240 // positive (same direction) in *same and negative (different direction) in 241 // *changed. CountAlternators(const Network & other,TFloat * same,TFloat * changed)242 virtual void CountAlternators([[maybe_unused]] const Network &other, 243 [[maybe_unused]] TFloat *same, 244 [[maybe_unused]] TFloat *changed) const {} 245 246 // Reads from the given file. Returns nullptr in case of error. 247 // Determines the type of the serialized class and calls its DeSerialize 248 // on a new object of the appropriate type, which is returned. 249 static Network *CreateFromFile(TFile *fp); 250 251 // Runs forward propagation of activations on the input line. 252 // Note that input and output are both 2-d arrays. 253 // The 1st index is the time element. In a 1-d network, it might be the pixel 254 // position on the textline. In a 2-d network, the linearization is defined 255 // by the stride_map. (See networkio.h). 256 // The 2nd index of input is the network inputs/outputs, and the dimension 257 // of the input must match NumInputs() of this network. 258 // The output array will be resized as needed so that its 1st dimension is 259 // always equal to the number of output values, and its second dimension is 260 // always NumOutputs(). Note that all this detail is encapsulated away inside 261 // NetworkIO, as are the internals of the scratch memory space used by the 262 // network. See networkscratch.h for that. 263 // If input_transpose is not nullptr, then it contains the transpose of input, 264 // and the caller guarantees that it will still be valid on the next call to 265 // backward. The callee is therefore at liberty to save the pointer and 266 // reference it on a call to backward. This is a bit ugly, but it makes it 267 // possible for a replicating parallel to calculate the input transpose once 268 // instead of all the replicated networks having to do it. 269 virtual void Forward(bool debug, const NetworkIO &input, 270 const TransposedArray *input_transpose, 271 NetworkScratch *scratch, NetworkIO *output) = 0; 272 273 // Runs backward propagation of errors on fwdX_deltas. 274 // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward. 275 // Returns false if back_deltas was not set, due to there being no point in 276 // propagating further backwards. Thus most complete networks will always 277 // return false from Backward! 278 virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, 279 NetworkScratch *scratch, NetworkIO *back_deltas) = 0; 280 281 // === Debug image display methods. === 282 // Displays the image of the matrix to the forward window. 283 void DisplayForward(const NetworkIO &matrix); 284 // Displays the image of the matrix to the backward window. 285 void DisplayBackward(const NetworkIO &matrix); 286 287 // Creates the window if needed, otherwise clears it. 288 static void ClearWindow(bool tess_coords, const char *window_name, int width, 289 int height, ScrollView **window); 290 291 // Displays the pix in the given window. and returns the height of the pix. 292 // The pix is pixDestroyed. 293 static int DisplayImage(Image pix, ScrollView *window); 294 295 protected: 296 // Returns a random number in [-range, range]. 297 TFloat Random(TFloat range); 298 299 protected: 300 NetworkType type_; // Type of the derived network class. 301 TrainingState training_; // Are we currently training? 302 bool needs_to_backprop_; // This network needs to output back_deltas. 303 int32_t network_flags_; // Behavior control flags in NetworkFlags. 304 int32_t ni_; // Number of input values. 305 int32_t no_; // Number of output values. 306 int32_t num_weights_; // Number of weights in this and sub-network. 307 std::string name_; // A unique name for this layer. 308 309 // NOT-serialized debug data. 310 ScrollView *forward_win_; // Recognition debug display window. 311 ScrollView *backward_win_; // Training debug display window. 312 TRand *randomizer_; // Random number generator. 313 }; 314 315 } // namespace tesseract. 316 317 #endif // TESSERACT_LSTM_NETWORK_H_ 318