/* -*- Mode: C++; -*- */
// VER: $Id: ANN.h,v 1.3 2005/08/05 09:02:57 berniw Exp $
// copyright (c) 2004 by Christos Dimitrakakis <dimitrak@idiap.ch>
/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify *
 *   it under the terms of the GNU General Public License as published by *
 *   the Free Software Foundation; either version 2 of the License, or    *
 *   (at your option) any later version.                                  *
 *                                                                         *
 ***************************************************************************/

#ifndef ANN_H
#define ANN_H

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <learning/learn_debug.h>
#include <learning/string_utils.h>
#include <learning/List.h>
#include <learning/real.h>

/** \file ANN.h
    \brief A neural network implementation.

    A neural network is a parametric function composed of a number of
    'layers'. Each layer computes \f$g(y)\f$ with
    \f$y = \sum_i w_i f_i(x)\f$, where the \f$w_i\f$ are a set of
    weights and \f$f(\cdot)\f$ is a set of basis functions. The basis
    functions can be fixed or they can be another layer. The neural
    network can be adapted to minimise some cost criterion \f$C\f$
    (defined on some data) via gradient descent on the weights. The
    gradient of the cost with respect to a weight is
    \f$\partial C/\partial w\f$. Expanding it with the chain rule gives
    \f$\partial C/\partial w =
    \frac{\partial C}{\partial g}
    \frac{\partial g}{\partial y}
    \frac{\partial y}{\partial w}\f$.
*/

/** \brief A linear connection between two neural elements.

    It is used to express the weighted sum
    \f$y = \sum_i w_i f_i(x)\f$.
    This type of connection currently also supports eligibility traces,
    gradient-descent updates, batch updates and variance estimates.
*/
typedef struct Connection_ {
	int c;   ///< connected?
	real w;  ///< weight
	real dw; ///< weight change
	real e;  ///< eligibility trace
	real v;  ///< variance estimate
} Connection;

/** \brief An RBF connection between two neural elements.

    It is used to express
    \f$y = \sum_i \big((m_i - f_i(x)) w_i\big)^2\f$.
    This type of connection has no extra features. When an RBF
    connection is created through the standard high-level function
    ANN_AddRBFHiddenLayer(), the layer output is \f$g(y) = e^y\f$.
*/
typedef struct RBFConnection_ {
	real w; ///< weight (=\f$1/\sigma\f$)
	real m; ///< mean
} RBFConnection;

/// \brief A collection of connections from one layer to another, plus
/// management functions and data.
typedef struct Layer_ {
	int n_inputs;       ///< number of inputs
	int n_outputs;      ///< number of outputs
	real* x;            ///< inputs
	real* y;            ///< outputs
	real* z;            ///< activation
	real* d;            ///< derivatives
	Connection* c;      ///< connections
	RBFConnection* rbf; ///< RBF connections (if any)
	real a;             ///< learning rate
	real lambda;        ///< eligibility decay
	real zeta;          ///< variance update smoothness
	bool batch_mode;    ///< do not update weights immediately
	void (*forward) (struct Layer_* current_layer, bool stochastic); ///< forward calculation
	real (*backward) (LISTITEM* p, real* d, bool use_eligibility, real TD); ///< partial derivative calculation
	real (*f) (real x);   ///< activation function
	real (*f_d) (real x); ///< derivative of activation function
} Layer;

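/* Example: a minimal training-loop sketch using the user-level API
   declared below. This is illustrative only; the input/target data,
   the hidden-layer size and the learning rate are arbitrary choices,
   and the call order (add layers, then ANN_Init()) is an assumption
   based on the declarations rather than on the implementation.

       ANN* ann = NewANN(2, 1);            // 2 inputs, 1 output
       ANN_AddHiddenLayer(ann, 4);         // one hidden layer of 4 units
       ANN_Init(ann);                      // initialise the network
       ANN_SetLearningRate(ann, 0.1);      // gradient-descent step size

       real x[2] = {0.0, 1.0};             // example input
       real t[1] = {1.0};                  // example target
       real err = ANN_Train(ann, x, t);    // one forward/backward pass, returns error

       ANN_Input(ann, x);                  // forward pass only
       real* y = ANN_GetOutput(ann);       // pointer to the output vector
       DeleteANN(ann);
*/
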
/// \brief ANN management structure.
typedef struct ANN_ {
	int n_inputs;            ///< number of inputs
	int n_outputs;           ///< number of outputs
	LIST* c;                 ///< connection layers
	real* x;                 ///< unit inputs
	real* y;                 ///< unit activations
	real* t;                 ///< targets
	real* d;                 ///< delta vector
	real a;                  ///< learning rate
	real lambda;             ///< eligibility trace decay
	real zeta;               ///< variance update smoothness
	real* error;             ///< errors
	bool batch_mode;         ///< use batch mode
	bool eligibility_traces; ///< use eligibility traces
} ANN;


/**************** User-level API ******************/
/* Object Management Interface */
extern ANN* NewANN(int n_inputs, int n_outputs);
extern int DeleteANN(ANN* ann);
extern ANN* LoadANN(char* filename);
extern ANN* LoadANN(FILE* f);
extern int SaveANN(ANN* ann, char* filename);
extern int SaveANN(ANN* ann, FILE* f);

/* Setup Interface */
extern int ANN_AddHiddenLayer(ANN* ann, int n_nodes);
extern int ANN_AddRBFHiddenLayer(ANN* ann, int n_nodes);
extern int ANN_Init(ANN* ann);
extern void ANN_SetOutputsToTanH(ANN* ann);
extern void ANN_SetOutputsToLinear(ANN* ann);
extern void ANN_SetLearningRate(ANN* ann, real a);
extern void ANN_SetLambda(ANN* ann, real lambda);
extern void ANN_SetZeta(ANN* ann, real zeta);
extern void ANN_Reset(ANN* ann);

/* Functionality Interface */
extern real ANN_Input(ANN* ann, real* x);
extern real ANN_StochasticInput(ANN* ann, real* x);
extern real ANN_Train(ANN* ann, real* x, real* t);
extern real ANN_Delta_Train(ANN* ann, real* delta, real TD = 0.0);
extern void ANN_SetBatchMode(ANN* ann, bool batch);
extern void ANN_BatchAdapt(ANN* ann);
extern real ANN_Test(ANN* ann, real* x, real* t);
extern real* ANN_GetOutput(ANN* ann);
extern real ANN_GetError(ANN* ann);
extern real* ANN_GetErrorVector(ANN* ann);

/********* Low-level code **********/

/* Sub-object management functions */
extern Layer* ANN_AddLayer(ANN* ann, int n_inputs, int n_outputs, real* x);
extern Layer* ANN_AddRBFLayer(ANN* ann, int n_inputs, int n_outputs, real* x);
extern void ANN_FreeLayer(void* l);
extern void ANN_FreeLayer(Layer* l);

/* Calculations */
extern void ANN_CalculateLayerOutputs(Layer* current_layer, bool stochastic = false);
extern real ANN_Backpropagate(LISTITEM* p, real* d, bool use_eligibility = false, real TD = 0.0);
extern void ANN_RBFCalculateLayerOutputs(Layer* current_layer, bool stochastic = false);
extern real ANN_RBFBackpropagate(LISTITEM* p, real* d, bool use_eligibility = false, real TD = 0.0);
extern void ANN_LayerBatchAdapt(Layer* l);

/* Output functions and derivatives */
extern real Exp(real x);
extern real Exp_d(real x);
extern real htan(real x);
extern real htan_d(real x);
extern real dtan(real x);
extern real dtan_d(real x);
extern real linear(real x);
extern real linear_d(real x);

/* Debugging functions */
extern real ANN_LayerShowWeights(Layer* l);
extern real ANN_ShowWeights(ANN* ann);
extern void ANN_ShowOutputs(ANN* ann);
extern real ANN_ShowInputs(ANN* ann);
extern real ANN_LayerShowInputs(Layer* l);

#endif /* ANN_H */
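
/* Batch-mode sketch (illustrative; the data arrays and sample count are
   hypothetical, and the assumption that ANN_Train() only accumulates
   weight changes while batch mode is on, with ANN_BatchAdapt() applying
   them, is inferred from the declarations and field comments above):

       ANN_SetBatchMode(ann, true);        // defer weight updates
       for (int i = 0; i < n_samples; i++)
           ANN_Train(ann, x[i], t[i]);     // accumulate per-sample weight changes
       ANN_BatchAdapt(ann);                // apply the accumulated changes once
*/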