1 /////////////////////////////////////////////////////////////////////// 2 // File: weightmatrix.h 3 // Description: Hides distinction between float/int implementations. 4 // Author: Ray Smith 5 // 6 // (C) Copyright 2014, Google Inc. 7 // Licensed under the Apache License, Version 2.0 (the "License"); 8 // you may not use this file except in compliance with the License. 9 // You may obtain a copy of the License at 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 /////////////////////////////////////////////////////////////////////// 17 18 #ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_ 19 #define TESSERACT_LSTM_WEIGHTMATRIX_H_ 20 21 #include <memory> 22 #include <vector> 23 #include "intsimdmatrix.h" 24 #include "matrix.h" 25 #include "tesstypes.h" 26 #include "tprintf.h" 27 28 namespace tesseract { 29 30 // Convenience instantiation of GENERIC_2D_ARRAY<TFloat> with additional 31 // operations to write a strided vector, so the transposed form of the input 32 // is memory-contiguous. 33 class TransposedArray : public GENERIC_2D_ARRAY<TFloat> { 34 public: 35 // Copies the whole input transposed, converted to TFloat, into *this. 36 void Transpose(const GENERIC_2D_ARRAY<TFloat> &input); 37 // Writes a vector of data representing a timestep (gradients or sources). 38 // The data is assumed to be of size1 in size (the strided dimension). 39 ~TransposedArray() override; WriteStrided(int t,const float * data)40 void WriteStrided(int t, const float *data) { 41 int size1 = dim1(); 42 for (int i = 0; i < size1; ++i) { 43 put(i, t, data[i]); 44 } 45 } WriteStrided(int t,const double * data)46 void WriteStrided(int t, const double *data) { 47 int size1 = dim1(); 48 for (int i = 0; i < size1; ++i) { 49 put(i, t, data[i]); 50 } 51 } 52 // Prints the first and last num elements of the un-transposed array. PrintUnTransposed(int num)53 void PrintUnTransposed(int num) { 54 int num_features = dim1(); 55 int width = dim2(); 56 for (int y = 0; y < num_features; ++y) { 57 for (int t = 0; t < width; ++t) { 58 if (num == 0 || t < num || t + num >= width) { 59 tprintf(" %g", (*this)(y, t)); 60 } 61 } 62 tprintf("\n"); 63 } 64 } 65 }; // class TransposedArray 66 67 // Generic weight matrix for network layers. Can store the matrix as either 68 // an array of floats or int8_t. Provides functions to compute the forward and 69 // backward steps with the matrix and updates to the weights. 70 class WeightMatrix { 71 public: WeightMatrix()72 WeightMatrix() : int_mode_(false), use_adam_(false) {} 73 // Sets up the network for training. Initializes weights using weights of 74 // scale `range` picked according to the random number generator `randomizer`. 75 // Note the order is outputs, inputs, as this is the order of indices to 76 // the matrix, so the adjacent elements are multiplied by the input during 77 // a forward operation. 78 int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer); 79 // Changes the number of outputs to the size of the given code_map, copying 80 // the old weight matrix entries for each output from code_map[output] where 81 // non-negative, and uses the mean (over all outputs) of the existing weights 82 // for all outputs with negative code_map entries. Returns the new number of 83 // weights. 84 int RemapOutputs(const std::vector<int> &code_map); 85 86 // Converts a float network to an int network. Each set of input weights that 87 // corresponds to a single output weight is converted independently: 88 // Compute the max absolute value of the weight set. 89 // Scale so the max absolute value becomes INT8_MAX. 90 // Round to integer. 91 // Store a multiplicative scale factor (as a float) that will reproduce 92 // the original value, subject to rounding errors. 93 void ConvertToInt(); 94 // Returns the size rounded up to an internal factor used by the SIMD 95 // implementation for its input. RoundInputs(int size)96 int RoundInputs(int size) const { 97 if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) { 98 return size; 99 } 100 return IntSimdMatrix::intSimdMatrix->RoundInputs(size); 101 } 102 103 // Accessors. is_int_mode()104 bool is_int_mode() const { 105 return int_mode_; 106 } NumOutputs()107 int NumOutputs() const { 108 return int_mode_ ? wi_.dim1() : wf_.dim1(); 109 } 110 // Provides one set of weights. Only used by peep weight maxpool. GetWeights(int index)111 const TFloat *GetWeights(int index) const { 112 return wf_[index]; 113 } 114 // Provides access to the deltas (dw_). GetDW(int i,int j)115 TFloat GetDW(int i, int j) const { 116 return dw_(i, j); 117 } 118 119 // Allocates any needed memory for running Backward, and zeroes the deltas, 120 // thus eliminating any existing momentum. 121 void InitBackward(); 122 123 // Writes to the given file. Returns false in case of error. 124 bool Serialize(bool training, TFile *fp) const; 125 // Reads from the given file. Returns false in case of error. 126 bool DeSerialize(bool training, TFile *fp); 127 // As DeSerialize, but reads an old (float) format WeightMatrix for 128 // backward compatibility. 129 bool DeSerializeOld(bool training, TFile *fp); 130 131 // Computes matrix.vector v = Wu. 132 // u is of size W.dim2() - 1 and the output v is of size W.dim1(). 133 // u is imagined to have an extra element at the end with value 1, to 134 // implement the bias, but it doesn't actually have it. 135 // Asserts that the call matches what we have. 136 void MatrixDotVector(const TFloat *u, TFloat *v) const; 137 void MatrixDotVector(const int8_t *u, TFloat *v) const; 138 // MatrixDotVector for peep weights, MultiplyAccumulate adds the 139 // component-wise products of *this[0] and v to inout. 140 void MultiplyAccumulate(const TFloat *v, TFloat *inout); 141 // Computes vector.matrix v = uW. 142 // u is of size W.dim1() and the output v is of size W.dim2() - 1. 143 // The last result is discarded, as v is assumed to have an imaginary 144 // last value of 1, as with MatrixDotVector. 145 void VectorDotMatrix(const TFloat *u, TFloat *v) const; 146 // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements 147 // from u and v, starting with u[i][offset] and v[j][offset]. 148 // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0. 149 // Runs parallel if requested. Note that inputs must be transposed. 150 void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel); 151 // Updates the weights using the given learning rate, momentum and adam_beta. 152 // num_samples is used in the Adam correction factor. 153 void Update(float learning_rate, float momentum, float adam_beta, int num_samples); 154 // Adds the dw_ in other to the dw_ is *this. 155 void AddDeltas(const WeightMatrix &other); 156 // Sums the products of weight updates in *this and other, splitting into 157 // positive (same direction) in *same and negative (different direction) in 158 // *changed. 159 void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const; 160 161 void Debug2D(const char *msg); 162 163 private: 164 // Choice between float and 8 bit int implementations. 165 GENERIC_2D_ARRAY<TFloat> wf_; 166 GENERIC_2D_ARRAY<int8_t> wi_; 167 // Transposed copy of wf_, used only for Backward, and set with each Update. 168 TransposedArray wf_t_; 169 // Which of wf_ and wi_ are we actually using. 170 bool int_mode_; 171 // True if we are running adam in this weight matrix. 172 bool use_adam_; 173 // If we are using wi_, then scales_ is a factor to restore the row product 174 // with a vector to the correct range. 175 std::vector<TFloat> scales_; 176 // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying 177 // amount to be added to wf_/wi_. 178 GENERIC_2D_ARRAY<TFloat> dw_; 179 GENERIC_2D_ARRAY<TFloat> updates_; 180 // Iff use_adam_, the sum of squares of dw_. The number of samples is 181 // given to Update(). Serialized iff use_adam_. 182 GENERIC_2D_ARRAY<TFloat> dw_sq_sum_; 183 // The weights matrix reorganized in whatever way suits this instance. 184 std::vector<int8_t> shaped_w_; 185 }; 186 187 } // namespace tesseract. 188 189 #endif // TESSERACT_LSTM_WEIGHTMATRIX_H_ 190