///////////////////////////////////////////////////////////////////////
// File:        network.cpp
// Description: Base class for neural network implementations.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "network.h"

#include <cstdlib>

// This base class needs to know about all its sub-classes because of the
// factory deserializing method: CreateFromFile.
#include <allheaders.h>
#include "convolve.h"
#include "fullyconnected.h"
#include "input.h"
#include "lstm.h"
#include "maxpool.h"
#include "parallel.h"
#include "reconfig.h"
#include "reversed.h"
#include "scrollview.h"
#include "series.h"
#include "statistc.h"
#ifdef INCLUDE_TENSORFLOW
#  include "tfnetwork.h"
#endif
#include "tprintf.h"

namespace tesseract {

#ifndef GRAPHICS_DISABLED

// Min and max window sizes.
const int kMinWinSize = 500;
const int kMaxWinSize = 2000;
// Window frame sizes need adding on to make the content fit.
const int kXWinFrameSize = 30;
const int kYWinFrameSize = 80;

#endif // !GRAPHICS_DISABLED

// String names corresponding to the NetworkType enum.
// Keep in sync with NetworkType.
// Names used in Serialization to allow re-ordering/addition/deletion of
// layer types in NetworkType without invalidating existing network files.
static char const *const kTypeNames[NT_COUNT] = {
    "Invalid",     "Input",
    "Convolve",    "Maxpool",
    "Parallel",    "Replicated",
    "ParBidiLSTM", "DepParUDLSTM",
    "Par2dLSTM",   "Series",
    "Reconfig",    "RTLReversed",
    "TTBReversed", "XYTranspose",
    "LSTM",        "SummLSTM",
    "Logistic",    "LinLogistic",
    "LinTanh",     "Tanh",
    "Relu",        "Linear",
    "Softmax",     "SoftmaxNoCTC",
    "LSTMSoftmax", "LSTMBinarySoftmax",
    "TensorFlow",
};

Network::Network()
    : type_(NT_NONE)
    , training_(TS_ENABLED)
    , needs_to_backprop_(true)
    , network_flags_(0)
    , ni_(0)
    , no_(0)
    , num_weights_(0)
    , forward_win_(nullptr)
    , backward_win_(nullptr)
    , randomizer_(nullptr) {}

Network::Network(NetworkType type, const std::string &name, int ni, int no)
    : type_(type)
    , training_(TS_ENABLED)
    , needs_to_backprop_(true)
    , network_flags_(0)
    , ni_(ni)
    , no_(no)
    , num_weights_(0)
    , name_(name)
    , forward_win_(nullptr)
    , backward_win_(nullptr)
    , randomizer_(nullptr) {}

// Suspends/Enables/Permanently disables training by setting the training_
// flag. Serialize and DeSerialize only operate on the run-time data if state
// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
// temporarily disable layers in state TS_ENABLED, allowing a trainer to
// serialize as if it were a recognizer.
// TS_RE_ENABLE will re-enable layers that were previously in any disabled
// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
// recognizer can be converted back to a trainer.
void Network::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    training_ = state;
  }
}
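
// Illustrative call sequence (not part of the original code): a trainer that
// wants to write a checkpoint readable by a plain recognizer can do
//   net->SetEnableTraining(TS_TEMP_DISABLE);
//   net->Serialize(fp);
//   net->SetEnableTraining(TS_RE_ENABLE);
// which matches the "serialize as if it were a recognizer" behaviour above.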

// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
void Network::SetNetworkFlags(uint32_t flags) {
  network_flags_ = flags;
}

// Sets up the network for training. Initializes weights to values of scale
// `range`, picked using the random number generator `randomizer`. Subclasses
// override this to initialize their weights; the base class implementation
// only records the randomizer and returns 0.
int Network::InitWeights([[maybe_unused]] float range, TRand *randomizer) {
  randomizer_ = randomizer;
  return 0;
}

// Provides a pointer to a TRand for any networks that care to use it.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
void Network::SetRandomizer(TRand *randomizer) {
  randomizer_ = randomizer;
}

// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop is true or this network has any weights, so that the next
// layer forward can be told to produce backprop for this layer if needed.
bool Network::SetupNeedsBackprop(bool needs_backprop) {
  needs_to_backprop_ = needs_backprop;
  return needs_backprop || num_weights_ > 0;
}
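
// For example (illustrative), a weight-free layer that is passed
// needs_backprop == false returns false here, so the next layer forward need
// not produce back_deltas for it.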

// Writes to the given file. Returns false in case of error.
bool Network::Serialize(TFile *fp) const {
  int8_t data = NT_NONE;
  if (!fp->Serialize(&data)) {
    return false;
  }
  std::string type_name = kTypeNames[type_];
  if (!fp->Serialize(type_name)) {
    return false;
  }
  data = training_;
  if (!fp->Serialize(&data)) {
    return false;
  }
  data = needs_to_backprop_;
  if (!fp->Serialize(&data)) {
    return false;
  }
  if (!fp->Serialize(&network_flags_)) {
    return false;
  }
  if (!fp->Serialize(&ni_)) {
    return false;
  }
  if (!fp->Serialize(&no_)) {
    return false;
  }
  if (!fp->Serialize(&num_weights_)) {
    return false;
  }
  uint32_t length = name_.length();
  if (!fp->Serialize(&length)) {
    return false;
  }
  return fp->Serialize(name_.c_str(), length);
}
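
// For reference, the header record written by Serialize above is, in order:
//   int8_t   NT_NONE sentinel (indicates that a type name string follows)
//   string   kTypeNames[type_]
//   int8_t   training_
//   int8_t   needs_to_backprop_
//   int32_t  network_flags_
//   int32_t  ni_
//   int32_t  no_
//   int32_t  num_weights_
//   uint32_t name length, followed by that many name characters
// getNetworkType and CreateFromFile below read the fields back in this order.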

// Reads the network type from the given file. Serialize (above) writes an
// NT_NONE sentinel byte followed by the type name from kTypeNames; if the
// first byte is not NT_NONE, it is interpreted directly as the NetworkType
// value.
static NetworkType getNetworkType(TFile *fp) {
  int8_t data;
  if (!fp->DeSerialize(&data)) {
    return NT_NONE;
  }
  if (data == NT_NONE) {
    std::string type_name;
    if (!fp->DeSerialize(type_name)) {
      return NT_NONE;
    }
    for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
    }
    if (data == NT_COUNT) {
      tprintf("Invalid network layer type:%s\n", type_name.c_str());
      return NT_NONE;
    }
  }
  return static_cast<NetworkType>(data);
}

// Reads from the given file. Returns nullptr in case of error.
// Determines the type of the serialized class and calls its DeSerialize
// on a new object of the appropriate type, which is returned.
Network *Network::CreateFromFile(TFile *fp) {
  NetworkType type;       // Type of the derived network class.
  TrainingState training; // Are we currently training?
  bool needs_to_backprop; // This network needs to output back_deltas.
  int32_t network_flags;  // Behavior control flags in NetworkFlags.
  int32_t ni;             // Number of input values.
  int32_t no;             // Number of output values.
  int32_t num_weights;    // Number of weights in this and sub-network.
  std::string name;       // A unique name for this layer.
  int8_t data;
  Network *network = nullptr;
  type = getNetworkType(fp);
  if (!fp->DeSerialize(&data)) {
    return nullptr;
  }
  training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
  if (!fp->DeSerialize(&data)) {
    return nullptr;
  }
  needs_to_backprop = data != 0;
  if (!fp->DeSerialize(&network_flags)) {
    return nullptr;
  }
  if (!fp->DeSerialize(&ni)) {
    return nullptr;
  }
  if (!fp->DeSerialize(&no)) {
    return nullptr;
  }
  if (!fp->DeSerialize(&num_weights)) {
    return nullptr;
  }
  if (!fp->DeSerialize(name)) {
    return nullptr;
  }

  switch (type) {
    case NT_CONVOLVE:
      network = new Convolve(name.c_str(), ni, 0, 0);
      break;
    case NT_INPUT:
      network = new Input(name.c_str(), ni, no);
      break;
    case NT_LSTM:
    case NT_LSTM_SOFTMAX:
    case NT_LSTM_SOFTMAX_ENCODED:
    case NT_LSTM_SUMMARY:
      network = new LSTM(name.c_str(), ni, no, no, false, type);
      break;
    case NT_MAXPOOL:
      network = new Maxpool(name.c_str(), ni, 0, 0);
      break;
    // All variants of Parallel.
    case NT_PARALLEL:
    case NT_REPLICATED:
    case NT_PAR_RL_LSTM:
    case NT_PAR_UD_LSTM:
    case NT_PAR_2D_LSTM:
      network = new Parallel(name.c_str(), type);
      break;
    case NT_RECONFIG:
      network = new Reconfig(name.c_str(), ni, 0, 0);
      break;
    // All variants of reversed.
    case NT_XREVERSED:
    case NT_YREVERSED:
    case NT_XYTRANSPOSE:
      network = new Reversed(name.c_str(), type);
      break;
    case NT_SERIES:
      network = new Series(name.c_str());
      break;
    case NT_TENSORFLOW:
#ifdef INCLUDE_TENSORFLOW
      network = new TFNetwork(name.c_str());
#else
      tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
#endif
      break;
    // All variants of FullyConnected.
    case NT_SOFTMAX:
    case NT_SOFTMAX_NO_CTC:
    case NT_RELU:
    case NT_TANH:
    case NT_LINEAR:
    case NT_LOGISTIC:
    case NT_POSCLIP:
    case NT_SYMCLIP:
      network = new FullyConnected(name.c_str(), ni, no, type);
      break;
    default:
      break;
  }
  if (network) {
    network->training_ = training;
    network->needs_to_backprop_ = needs_to_backprop;
    network->network_flags_ = network_flags;
    network->num_weights_ = num_weights;
    if (!network->DeSerialize(fp)) {
      delete network;
      network = nullptr;
    }
  }
  return network;
}
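
// Illustrative usage sketch (assumes `fp` is a TFile already positioned at a
// record written by Serialize; error handling is minimal):
//   Network *net = Network::CreateFromFile(&fp);
//   if (net != nullptr) {
//     // ... use the deserialized layer ...
//     delete net;
//   }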

// Returns a random number in [-range, range].
TFloat Network::Random(TFloat range) {
  ASSERT_HOST(randomizer_ != nullptr);
  return randomizer_->SignedRand(range);
}

#ifndef GRAPHICS_DISABLED

// === Debug image display methods. ===
// Displays the image of the matrix to the forward window.
void Network::DisplayForward(const NetworkIO &matrix) {
  Image image = matrix.ToPix();
  ClearWindow(false, name_.c_str(), pixGetWidth(image), pixGetHeight(image), &forward_win_);
  DisplayImage(image, forward_win_);
  forward_win_->Update();
}

// Displays the image of the matrix to the backward window.
void Network::DisplayBackward(const NetworkIO &matrix) {
  Image image = matrix.ToPix();
  std::string window_name = name_ + "-back";
  ClearWindow(false, window_name.c_str(), pixGetWidth(image), pixGetHeight(image), &backward_win_);
  DisplayImage(image, backward_win_);
  backward_win_->Update();
}

// Creates the window if needed, otherwise clears it.
void Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height,
                          ScrollView **window) {
  if (*window == nullptr) {
    int min_size = std::min(width, height);
    if (min_size < kMinWinSize) {
      if (min_size < 1) {
        min_size = 1;
      }
      width = width * kMinWinSize / min_size;
      height = height * kMinWinSize / min_size;
    }
    width += kXWinFrameSize;
    height += kYWinFrameSize;
    if (width > kMaxWinSize) {
      width = kMaxWinSize;
    }
    if (height > kMaxWinSize) {
      height = kMaxWinSize;
    }
    *window = new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords);
    tprintf("Created window %s of size %d, %d\n", window_name, width, height);
  } else {
    (*window)->Clear();
  }
}
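
// Worked example of the scaling above (illustrative numbers): a 100 x 50
// content area has min_size = 50 < kMinWinSize, so both dimensions are scaled
// by 500/50 to give 1000 x 500; the frame sizes are then added (1030 x 580),
// and each dimension would finally be clamped to kMaxWinSize if it exceeded it.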

// Displays the pix in the given window and returns the height of the pix.
// The pix is pixDestroyed.
int Network::DisplayImage(Image pix, ScrollView *window) {
  int height = pixGetHeight(pix);
  window->Draw(pix, 0, 0);
  pix.destroy();
  return height;
}
#endif // !GRAPHICS_DISABLED

} // namespace tesseract.