1 /** @file traindata.h
2
3 @brief CTrainingData header - Neural Network Training Data
4
5 @author Jakub Adámek
6 Last modified $Id: traindata.h,v 1.21 2002/04/22 15:43:23 jakubadamek Exp $
7 */
8
9 #ifndef __BANG_NNDATA_H__
10 #define __BANG_NNDATA_H__
11
12 #include "utils.h"
13 #include "xmlstream.h"
14 #include "assert.h"
15
/* Container types used throughout this header.
   NOTE(review): SMALLVECTOR / VECTOR / MAP are project macros (presumably from utils.h);
   their exact expansion is not visible here. */
/// TSmallVectorFloat: vector of one pattern's values (see TPattern::data)
16 SMALLVECTOR (double, CFloat, TSmallVectorFloat)
/// TVectorFloat: presumably a vector of CFloat
17 VECTOR (CFloat, TVectorFloat)
// disabled alternative definition of TSmallVectorFloat, kept for reference:
18 //typedef TVectorFloat TSmallVectorFloat;
/// TVectorInt: presumably a vector of CInt (used e.g. for CTrainingData::colIndexes)
19 VECTOR (CInt, TVectorInt)
/// TMapStringFloat: presumably a map from CString to CFloat
20 MAP (CString, CFloat, TMapStringFloat)
21
22 /** @brief training pattern - inputs, desired outputs and set */
23
24 enum enumSetType { ST_NULL=-1, ST_EVAL=0, ST_TRAIN=1 };
25 static const TEnumString SetTypes = "eval;train";
26
VECTOR(TStream,TStreams)27 VECTOR (TStream, TStreams)
28
29 STRUCT_BEGIN (TPattern)
30 public:
31 void Init () { set = ST_NULL; series=0; }
32 /// data[CU_INPUT], data[CU_OUTPUT]
33 TSmallVectorFloat data[2];
34 /// which set contains this pattern - see enumSetType
35 CInt set;
36 /// which series contains this pattern
37 CInt series;
38 /// the rows are sorted by sets
39 bool operator < (const TPattern &pat) const
40 { return pat.set < set; }
41 STRUCT_END (TPattern)
42
43 VECTOR (TPattern, TPatterns)
44
/** @brief One training data source, assembled from one or more streams. */
45 STRUCT_BEGIN (TTrainingDataFile)
46 public:
47 /// streams containing data (to be joined from left to right)
48 TStreams streams;
49 STRUCT_END (TTrainingDataFile)
50
/// list of training data sources (see CTrainingData::files / CTrainingData::addFile)
51 VECTOR (TTrainingDataFile, TTrainingDataFiles)
52
53 /** C O L U M N S T U F F */
54
55 /** Types of values */
56 enum filterType {
57 /// boolean - two different values, e.g. 0/1,yes/no
58 FT_BOOL,
59 /// integer number
60 FT_INT,
61 /// floating point number
62 FT_FLOAT,
63 /// category (text value), each row contains one
64 FT_CATEGORY,
65 /// category (text value), each row contains any number of them (0-n)
66 FT_MULTICATEGORY
67 };
68
69 static const TEnumString FTs = "bool;int;float;category;multicategory";
70
71 /** How to translate values in column */
72 enum filterTranslate {
73 /** Default for everything except multi-category
74 - linearly transform from [min;max] into [-1;1]
75 For categories: assign order values (0,1,2,...) to them and linearly transform.
76 Not usable for multi-categories.
77 */
78 FTR_FLOAT,
79 /** uses the parameters "avg" and "stdev": x -> (x - avg) / stdev */
80 FTR_LINEAR,
81 /** For category / multi-category - assign one input 0/1 to every possible value */
82 FTR_BOOLS,
83 /** Don't translate - use values as they are */
84 FTR_NONE
85 };
86
87 static const TEnumString FTRs = "float;linear;bools;none";
88
89
STRUCT_BEGIN(TFilter)90 STRUCT_BEGIN (TFilter)
91 void Init () { type=FT_FLOAT; translate=FTR_FLOAT; minR=0; maxR=0;
92 toMinR=-1; toMaxR=1; toAvg=0; toStdev=1; avg=0; stdev=1;
93 digits=-1; }
94
95
96 /// see enum filterType
97 CInt type;
98 /// see enum filterTranslate
99 CInt translate;
100 /// interval to transform from by FLOAT
101 CFloat minR, maxR;
102 /// interval to transform to by FLOAT (usually [-1;+1])
103 CFloat toMinR, toMaxR;
104 /// average and standard deviation for LINEAR transformation
105 CFloat avg, stdev;
106 /// average and standard deviation after LINEAR transformation (usually 0,1)
107 CFloat toAvg, toStdev;
108 /// categories
109 TEnumString categories;
110 /// for doubles used as categories - number of decimal digits when converting to string
111 CInt digits;
112 /// number of outputs when preprocessing
113 inline int getWidth() const;
114
115 /** PRE- and POST- PROCESSING
116 Following functions all use the last argument to place results.
117 That's because some of them may place more than one results and also
118 the result type may be specified by overloading, not by function name. */
119
120 /// returns preprocessed value (or array of values on FTR_BOOLS)
121 inline void preprocess (double value, double &result);
122 /// returns preprocessed value (or array of values on FTR_BOOLS)
123 inline void preprocess (const CString &value, double &result);
124
125 /** returns postprocessed double value. Inverse to preprocess. */
126 inline void postprocess (const double &value, double &result) const;
127 /** returns postprocessed string value. Inverse to preprocess. */
128 inline void postprocess (const double &value, CString &result) const;
129
130 /** returns postprocessed double value, but doesn't add the average - only multiplies */
131 inline void postprocessError (const double &value, double &result) const;
132 STRUCT_END (TFilter)
133
134 /** Usage of column */
135 enum colUse {
136 /// input
137 CU_INPUT,
138 /// output
139 CU_OUTPUT,
140 /// not used
141 CU_NO
142 };
143
144 static const TEnumString CUs = "input;output;no";
145
STRUCT_BEGIN_FROM(TColumn,TFilter)146 STRUCT_BEGIN_FROM (TColumn, TFilter)
147 void Init () { use=CU_NO; pos = 0; }
148 /// see enum colUse
149 CInt use;
150 /// starting position in the input or output vector
151 CInt pos;
152
153 /// window - use only for series: take appropriate values from neighbours
STRUCT_BEGIN(TWindow)154 STRUCT_BEGIN (TWindow)
155 void Init() { left = right = 0; empty = "0"; }
156
157 CInt left, right;
158 /// default value when no such neighbours exist (e.g. left from 1st)
159 CString empty;
160 STRUCT_END (TWindow)
161 TWindow window;
162
163 inline int getColWidth() const;
164
165 /** Returns appropriate input or output column value. For CTR_BOOLS returns the first value.
166 "shift" relates to window, ranges in [-left;+right] */
167 inline double &val (TPattern &pat, int shift=0) const;
168 inline double val (const TPattern &pat, int shift=0) const;
169 inline void copy (const TPattern &src, int shiftSrc, TPattern &dst, int shiftDst);
STRUCT_END_FROM(TColumn,TFilter)170 STRUCT_END_FROM (TColumn, TFilter)
171 VECTOR (TColumn, TColumns)
172
173
174 /** @brief Controls manipulation with training data. Allows to define sets of training data.
175
176 This class has an iterator holding current position in the data. You can move it with move....
177 */
178
179 STRUCT_BEGIN (CTrainingData)
180 public:
181 void Init () { pos=data.end(); series=false; colCount[CU_INPUT]=colCount[CU_OUTPUT]=0; }
182
183 /// Moves position to beginning. Returns: true if position valid.
moveFirst()184 bool moveFirst () const { return (pos = data.begin()) != data.end(); }
moveNext()185 bool moveNext () const { return ++pos != data.end(); }
186 /** Moves position to next row. Returns: true if position valid.
187 Inline to be quickly. */
188 bool moveNext (CInt set) const;
189 bool movePrev (CInt set) const;
190
191 /// Moves position to start of given set. Returns: true if position valid
192 bool moveToSetStart (CInt setType) const;
193
194 /// Moves position to the row of a given index. Returns: true if position valid
195 bool moveToRow (CInt row) const;
196
197 /// getSet returns: set of current row or ST_NULL if position not valid
getSet()198 inline CInt getSet () const
199 { if (pos == data.end()) return ST_NULL; else return (*pos).set; }
200
201 /// getInputs returns: inputs on current row
getInputs()202 inline const TSmallVectorFloat &getInputs () const { return pos->data[CU_INPUT]; }
203 /// getOutputs returns: desired outputs on current row
getOutputs()204 inline const TSmallVectorFloat &getOutputs () const { return pos->data[CU_OUTPUT]; }
modifyOutputs()205 inline TSmallVectorFloat &modifyOutputs () { return pos->data[CU_OUTPUT]; }
206
getPattern()207 inline const TPattern &getPattern () const { return *pos; }
getSeries()208 inline int getSeries () const
209 { if (pos == data.end()) return -1; else return (int)pos->series; }
getPos()210 inline TPatterns::iterator getPos () const { return pos; }
setPos(TPatterns::iterator x)211 inline void setPos (TPatterns::iterator x) const { pos = x; }
212
STRUCT_BEGIN(TRange)213 STRUCT_BEGIN (TRange)
214 void Init () { from=0; to=0; random=0; rest=false; type=ST_NULL; }
215 /** negative values mean size in percent; to = 0 means until the end of file */
216 CInt from, to;
217 /** Number of lines to be randomly chosen from this range. Zero means don't use random. */
218 CInt random;
219 /** Bool: Use the rest of the range - useful especially when another set uses random lines */
220 CInt rest;
221 /// type of data set - see enumSetType
222 CInt type;
223 STRUCT_END (TRange)
224 VECTORC (TRange, TRanges)
225
226 /// goes through ranges and sets the set info to rows. Returns error or empty string.
227 CString processRanges ();
228
229 /// goes through all data and columns and sets the window results
230 void processWindows ();
231
getColumns()232 const TColumns &getColumns () const { return columns; }
getColCount(colUse use)233 CInt getColCount(colUse use) const { assert (use==CU_INPUT || use==CU_OUTPUT); return colCount[use]; }
getRowCount(enumSetType set)234 CInt getRowCount(enumSetType set) const { return rowCount[set]; }
235 void addColumn (TColumn col);
addFile(const TTrainingDataFile & f)236 void addFile (const TTrainingDataFile &f){ files.push_back (f); }
237
238 /// sets directly the column value (caller must know about columns), uses the column translation
239 //inline void setColDirect (colUse use, int icol, double value);
240 /** returns double column value translated from direct values */
241 inline void getPostprocessed (colUse use, int col, double &result);
242 /** returns string column value translated from direct values - not yet implemented */
243 inline void getPostprocessed (colUse use, int col, CString &result);
244
245 /// returns input or output column of given index (starting with 0)
246 const TColumn &getColumn (colUse use, int col) const;
247
248 /// where to dump configuration (including columns) after processing data - if empty, nowhere
249 CString dumpCfg;
250 /// bool: go to through columns and set their translations or not?
251 CInt examineColumns;
252
253 /// finds minR, maxR, avg, stdev for all columns
254 void examine ();
255
256 /// sets the colIndexes array
257 void setColIndexes ();
258
259 int getSeriesCount(enumSetType set) const;
260
261 /// deletes all rows
clear()262 void clear () { data.DeleteAll(); pos = data.end(); }
263 /// adds new row - set values by fillColumn functions
264 void addRow (enumSetType set, int series = 0);
265 /// preprocesses and sets value on current row
266 inline void fillColumn (colUse use, int col, double value);
267 /// preprocesses and sets value on current row
268 inline void fillColumn (colUse use, int col, const CString &value);
269
270 TRanges ranges;
271
272 /** indexes of columns - e.g. columns' use is CU_NO,CU_INPUT,CU_OUTPUT,...
273 => colIndexes[CU_INPUT][0]=1, colIndexes[CU_OUTPUT][0]=2 */
274 TVectorInt colIndexes[2];
275
276 /// holds all the training data
277 TPatterns data;
278 /// guards the data when a neural network uses them
279 Mutex dataMutex;
280 /// row count for training and eval set
281 CInt rowCount [2];
282 /// columns count for input and output cols
283 CInt colCount[2];
284 TTrainingDataFiles files;
285 /// columns description
286 TColumns columns;
287 /** Bool: Are data organised into series? */
288 CInt series;
289 /** If a row begins with this, it separates two series */
290 CString seriesSeparator;
291
292 private:
293 /** reads data from files - returns errors if any occur. The first param is the config file name. */
294 CString readFromFiles (CString filename, bool onlyCategories = false);
295 mutable TPatterns::iterator pos;
296
297 public:
298 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * */
299 /* Configuration READing and PRINTing */
300
301 /// dumps all data to the string
302 void dumpData (CString &out);
303
304 /** reads all data and sets columns - returns error description or empty string.
305 First param is the config file name. */
306 CString readData(CString filename);
307
308 /// prints the cfg and the data as local in the file
309 CRox *printAllLocal ();
310
311 CRox *printColumns () const;
312 CString readColumns (const CRox *xml);
313
STRUCT_END(CTrainingData)314 STRUCT_END (CTrainingData)
315
316 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
317 * *
318 * C T r a i n i n g D a t a *
319 * *
320 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
321
322 inline bool CTrainingData::moveNext (CInt set) const
323 {
324 if (pos == data.end()) return false;
325 do ++pos;
326 while (pos != data.end() && pos->set != set);
327 return pos != data.end();
328 }
329
movePrev(CInt set)330 inline bool CTrainingData::movePrev (CInt set) const
331 {
332 if (pos == data.begin()) return false;
333 do --pos;
334 while (pos != data.begin() && pos->set != set);
335 return pos->set == set;
336 }
337
fillColumn(colUse use,int col,double value)338 inline void CTrainingData::fillColumn (colUse use, int col, double value)
339 {
340 TColumns::iterator icol = columns.begin()+colIndexes[use][col];
341 icol->preprocess (value, icol->val( *pos ));
342 }
343
getPostprocessed(colUse use,int col,double & result)344 inline void CTrainingData::getPostprocessed (colUse use, int col, double &result)
345 {
346 TColumns::iterator icol = columns.begin()+colIndexes[use][col];
347 icol->postprocess (icol->val (*pos), result);
348 }
349
350 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
351 * *
352 * T F i l t e r *
353 * *
354 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
355
preprocess(double value,double & result)356 inline void TFilter::preprocess (double value, double &result)
357 {
358 switch ((int)translate) {
359 case FTR_FLOAT: // transform into [toMinR;toMaxR]
360 if (maxR == minR) result = 0;
361 else result = (value - minR) * (toMaxR - toMinR) / (maxR - minR) + toMinR; break;
362 case FTR_LINEAR:
363 if (stdev == 0) result = 0;
364 else result = (value - avg) * toStdev / stdev + toAvg; break;
365 case FTR_NONE:
366 result = value; break;
367 case FTR_BOOLS:
368 if ((int)type == FT_CATEGORY) {
369 preprocess (toString (value,digits), result);
370 break;
371 }
372 default: // not yet implemented
373 assert (false);
374 result = 0; break;
375 }
376 }
377
preprocess(const CString & value,double & result)378 inline void TFilter::preprocess (const CString &value, double &result)
379 {
380 switch ((int)type) {
381 case FT_INT:
382 case FT_FLOAT:
383 preprocess (atof (value.c_str()), result); break;
384 case FT_CATEGORY: {
385 int category = categories[value];
386 switch ((int)translate) {
387 case FTR_FLOAT:
388 case FTR_LINEAR:
389 preprocess (category, result);
390 break;
391 case FTR_BOOLS: {
392 for (int i=0; i < categories.size(); ++i)
393 (&result)[i] = i == category ? 1 : 0;
394 break; }
395 default: assert (false); //not yet implemented
396 }
397 break; }
398 default: assert (false); //not yet implemented
399 }
400 }
401
postprocessError(const double & value,double & result)402 inline void TFilter::postprocessError (const double &value, double &result) const
403 {
404 switch ((int)translate) {
405 case FTR_FLOAT:
406 result = value * (maxR - minR) / (toMaxR - toMinR); break;
407 case FTR_LINEAR:
408 result = value * stdev / toStdev; break;
409 case FTR_NONE:
410 result = value; break;
411 default: // not yet implemented
412 assert (false);
413 result = 0; break;
414 }
415 }
416
postprocess(const double & value,double & result)417 inline void TFilter::postprocess (const double &value, double &result) const
418 {
419 switch ((int)translate) {
420 case FTR_FLOAT:
421 result = (value - toMinR) * (maxR - minR) / (toMaxR - toMinR) + minR; break;
422 case FTR_LINEAR:
423 result = (value - toAvg) * stdev / toStdev + avg; break;
424 case FTR_NONE:
425 result = value; break;
426 default: // not yet implemented
427 assert (false);
428 result = 0; break;
429 }
430 }
431
postprocess(const double & value,CString & result)432 inline void TFilter::postprocess (const double &value, CString &result) const
433 {
434 switch ((int)translate) {
435 case FTR_FLOAT:
436 case FTR_LINEAR: {
437 double dresult;
438 postprocess (value, dresult);
439 result = toString (dresult);
440 break; }
441 case FTR_BOOLS: {
442 assert (type == FT_CATEGORY);
443 int i;
444 for (i=0; i < categories.size(); ++i)
445 if ((&value)[i] == 1) {
446 result = categories[i];
447 return;
448 }
449 result = -1;
450 break; }
451 case FTR_NONE:
452 result = toString (value);
453 break;
454 default: // not yet implemented
455 assert (false);
456 result = "";
457 break;
458 }
459 }
460
getWidth()461 inline int TFilter::getWidth() const
462 {
463 if (translate != FTR_BOOLS)
464 return 1;
465 else return categories.size();
466 }
467
468 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
469 * *
470 * T C o l u m n *
471 * *
472 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
473
getColumn(colUse use,int icol)474 inline const TColumn &CTrainingData::getColumn (colUse use, int icol) const
475 {
476 assert (use == CU_INPUT || use == CU_OUTPUT);
477 return columns[colIndexes[use][icol]];
478 }
479
val(TPattern & pat,int shift)480 inline double &TColumn::val (TPattern &pat, int shift) const
481 {
482 assert (shift >= -window.left && shift <= window.right && use != CU_NO);
483 return pat.data[use][pos + (window.left+shift) * TFilter::getWidth()];
484 }
485
val(const TPattern & pat,int shift)486 inline double TColumn::val (const TPattern &pat, int shift) const
487 {
488 assert (shift >= -window.left && shift <= window.right && use != CU_NO);
489 return pat.data[use][pos + (window.left+shift) * TFilter::getWidth()];
490 }
491
copy(const TPattern & src,int shiftSrc,TPattern & dst,int shiftDst)492 inline void TColumn::copy (const TPattern &src, int shiftSrc, TPattern &dst, int shiftDst)
493 {
494 assert (shiftSrc >= -window.left && shiftSrc <= window.right && use != CU_NO
495 && shiftDst >= -window.left && shiftDst <= window.right);
496 if (use == CU_NO) return;
497 for (int i=0; i < TFilter::getWidth(); ++i)
498 dst.data[use][pos + (window.left+shiftDst) * TFilter::getWidth() + i]
499 = src.data[use][pos + (window.left+shiftSrc) * TFilter::getWidth() + i];
500 }
501
getColWidth()502 inline int TColumn::getColWidth() const
503 {
504 return (window.left + window.right + 1) * TFilter::getWidth();
505 }
506
507 #endif
508
509