/** @file traindata.h

  @brief CTrainingData header - Neural Network Training Data

  @author Jakub Adámek
  Last modified $Id: traindata.h,v 1.21 2002/04/22 15:43:23 jakubadamek Exp $
*/

#ifndef __BANG_NNDATA_H__
#define __BANG_NNDATA_H__

#include <stdlib.h>

#include "utils.h"
#include "xmlstream.h"
#include "assert.h"

16 SMALLVECTOR (double, CFloat, TSmallVectorFloat)
17 VECTOR (CFloat, TVectorFloat)
18 //typedef TVectorFloat TSmallVectorFloat;
19 VECTOR (CInt, TVectorInt)
20 MAP (CString, CFloat, TMapStringFloat)
21 
22 /** @brief training pattern - inputs, desired outputs and set */
23 
24 enum enumSetType { ST_NULL=-1, ST_EVAL=0, ST_TRAIN=1 };
25 static const TEnumString SetTypes = "eval;train";
26 
/// collection of TStream objects (held by TTrainingDataFile::streams)
VECTOR (TStream, TStreams)

29 STRUCT_BEGIN (TPattern)
30 public:
31 	void Init ()	{ set = ST_NULL; series=0; }
32 	/// data[CU_INPUT], data[CU_OUTPUT]
33 	TSmallVectorFloat data[2];
34 	/// which set contains this pattern - see enumSetType
35 	CInt set;
36 	/// which series contains this pattern
37 	CInt series;
38 	/// the rows are sorted by sets
39 	bool operator < (const TPattern &pat) const
40 	{ return pat.set < set; }
41 STRUCT_END (TPattern)
42 
43 VECTOR (TPattern, TPatterns)
44 
45 STRUCT_BEGIN (TTrainingDataFile)
46 public:
47 	/// streams containing data (to be joined from left to right)
48 	TStreams streams;
49 STRUCT_END (TTrainingDataFile)
50 
51 VECTOR (TTrainingDataFile, TTrainingDataFiles)
52 
/** C O L U M N   S T U F F */

55 /** Types of values */
56 enum filterType {
57 	/// boolean - two different values, e.g. 0/1,yes/no
58 	FT_BOOL,
59 	/// integer number
60 	FT_INT,
61 	/// floating point number
62 	FT_FLOAT,
63 	/// category (text value), each row contains one
64 	FT_CATEGORY,
65 	/// category (text value), each row contains any number of them (0-n)
66 	FT_MULTICATEGORY
67 };
68 
69 static const TEnumString FTs = "bool;int;float;category;multicategory";
70 
71 /** How to translate values in column */
72 enum filterTranslate {
73 	/** Default for everything except multi-category
74 		- linearly transform from [min;max] into [-1;1]
75 		For categories: assign order values (0,1,2,...) to them and linearly transform.
76 		Not usable for multi-categories.
77 	*/
78 	FTR_FLOAT,
79 	/** uses the parameters "avg" and "stdev": x -> (x - avg) / stdev */
80 	FTR_LINEAR,
81 	/** For category / multi-category - assign one input 0/1 to every possible value */
82 	FTR_BOOLS,
83 	/** Don't translate - use values as they are */
84 	FTR_NONE
85 };
86 
87 static const TEnumString FTRs = "float;linear;bools;none";
88 
89 
STRUCT_BEGIN(TFilter)90 STRUCT_BEGIN (TFilter)
91 	void Init ()	{ type=FT_FLOAT; translate=FTR_FLOAT; minR=0; maxR=0;
92 			  toMinR=-1; toMaxR=1; toAvg=0; toStdev=1; avg=0; stdev=1;
93  			  digits=-1; }
94 
95 
96 	/// see enum filterType
97 	CInt type;
98 	/// see enum filterTranslate
99 	CInt translate;
100 	/// interval to transform from by FLOAT
101 	CFloat minR, maxR;
102 	/// interval to transform to by FLOAT (usually [-1;+1])
103 	CFloat toMinR, toMaxR;
104 	/// average and standard deviation for LINEAR transformation
105 	CFloat avg, stdev;
106 	/// average and standard deviation after LINEAR transformation (usually 0,1)
107 	CFloat toAvg, toStdev;
108 	/// categories
109 	TEnumString categories;
110 	/// for doubles used as categories - number of decimal digits when converting to string
111 	CInt digits;
112 	/// number of outputs when preprocessing
113 	inline int getWidth() const;
114 
115 	/** PRE- and POST- PROCESSING
116 		Following functions all use the last argument to place results.
117 		That's because some of them may place more than one results and also
118 		the result type may be specified by overloading, not by function name. */
119 
120 	/// returns preprocessed value (or array of values on FTR_BOOLS)
121 	inline void preprocess (double value, double &result);
122 	/// returns preprocessed value (or array of values on FTR_BOOLS)
123 	inline void preprocess (const CString &value, double &result);
124 
125 	/** returns postprocessed double value.	Inverse to preprocess. */
126 	inline void postprocess (const double &value, double &result) const;
127 	/** returns postprocessed string value. Inverse to preprocess. */
128 	inline void postprocess (const double &value, CString &result) const;
129 
130 	/** returns postprocessed double value, but doesn't add the average - only multiplies */
131 	inline void postprocessError (const double &value, double &result) const;
132 STRUCT_END (TFilter)
133 
134 /** Usage of column */
135 enum colUse {
136 	/// input
137 	CU_INPUT,
138 	/// output
139 	CU_OUTPUT,
140 	/// not used
141 	CU_NO
142 };
143 
144 static const TEnumString CUs = "input;output;no";
145 
STRUCT_BEGIN_FROM(TColumn,TFilter)146 STRUCT_BEGIN_FROM (TColumn, TFilter)
147 	void Init ()	{ use=CU_NO; pos = 0; }
148 	/// see enum colUse
149 	CInt use;
150 	/// starting position in the input or output vector
151 	CInt pos;
152 
153 	/// window - use only for series: take appropriate values from neighbours
STRUCT_BEGIN(TWindow)154 	STRUCT_BEGIN (TWindow)
155 		void Init()	{ left = right = 0; empty = "0"; }
156 
157 		CInt left, right;
158 		/// default value when no such neighbours exist (e.g. left from 1st)
159 		CString empty;
160 	STRUCT_END (TWindow)
161 	TWindow window;
162 
163 	inline int getColWidth() const;
164 
165 	/** Returns appropriate input or output column value. For CTR_BOOLS returns the first value.
166 		"shift" relates to window, ranges in [-left;+right] */
167 	inline double &val (TPattern &pat, int shift=0) const;
168 	inline double val (const TPattern &pat, int shift=0) const;
169 	inline void copy (const TPattern &src, int shiftSrc, TPattern &dst, int shiftDst);
STRUCT_END_FROM(TColumn,TFilter)170 STRUCT_END_FROM (TColumn, TFilter)
171 VECTOR (TColumn, TColumns)
172 
173 
174 /** @brief Controls manipulation with training data. Allows to define sets of training data.
175 
176 	This class has an iterator holding current position in the data. You can move it with move....
177 */
178 
179 STRUCT_BEGIN (CTrainingData)
180 public:
181 	void Init ()	{ pos=data.end(); series=false; colCount[CU_INPUT]=colCount[CU_OUTPUT]=0; }
182 
183 	/// Moves position to beginning. Returns: true if position valid.
moveFirst()184 	bool moveFirst () const					{ return (pos = data.begin()) != data.end(); }
moveNext()185 	bool moveNext () const					{ return ++pos != data.end(); }
186 	/** Moves position to next row. Returns: true if position valid.
187 		Inline to be quickly. */
188 	bool moveNext (CInt set) const;
189 	bool movePrev (CInt set) const;
190 
191 	/// Moves position to start of given set. Returns: true if position valid
192 	bool moveToSetStart (CInt setType) const;
193 
194 	/// Moves position to the row of a given index. Returns: true if position valid
195 	bool moveToRow (CInt row) const;
196 
197 	/// getSet returns: set of current row or ST_NULL if position not valid
getSet()198 	inline CInt getSet () const
199 	{ if (pos == data.end()) return ST_NULL; else return (*pos).set; }
200 
201 	/// getInputs returns: inputs on current row
getInputs()202 	inline const TSmallVectorFloat &getInputs () const	{ return pos->data[CU_INPUT]; }
203 	/// getOutputs returns: desired outputs on current row
getOutputs()204 	inline const TSmallVectorFloat &getOutputs () const	{ return pos->data[CU_OUTPUT]; }
modifyOutputs()205 	inline TSmallVectorFloat &modifyOutputs ()		{ return pos->data[CU_OUTPUT]; }
206 
getPattern()207 	inline const TPattern &getPattern () const		{ return *pos; }
getSeries()208 	inline int getSeries () const
209 	{ if (pos == data.end()) return -1; else return (int)pos->series; }
getPos()210 	inline TPatterns::iterator getPos () const		{ return pos; }
setPos(TPatterns::iterator x)211 	inline void setPos (TPatterns::iterator x) const	{ pos = x; }
212 
STRUCT_BEGIN(TRange)213 	STRUCT_BEGIN (TRange)
214 		void Init ()		{ from=0; to=0; random=0; rest=false; type=ST_NULL; }
215 		/** negative values mean size in percent; to = 0 means until the end of file */
216 		CInt from, to;
217 		/** Number of lines to be randomly chosen from this range. Zero means don't use random. */
218 		CInt random;
219 		/** Bool: Use the rest of the range - useful especially when another set uses random lines */
220 		CInt rest;
221 		/// type of data set - see enumSetType
222 		CInt type;
223 	STRUCT_END (TRange)
224 	VECTORC (TRange, TRanges)
225 
226 	/// goes through ranges and sets the set info to rows. Returns error or empty string.
227 	CString processRanges ();
228 
229 	/// goes through all data and columns and sets the window results
230 	void processWindows ();
231 
getColumns()232 	const TColumns &getColumns () const	{ return columns; }
getColCount(colUse use)233 	CInt getColCount(colUse use) const	{ assert (use==CU_INPUT || use==CU_OUTPUT); return colCount[use]; }
getRowCount(enumSetType set)234 	CInt getRowCount(enumSetType set) const	{ return rowCount[set]; }
235 	void addColumn (TColumn col);
addFile(const TTrainingDataFile & f)236 	void addFile (const TTrainingDataFile &f){ files.push_back (f); }
237 
238 	/// sets directly the column value (caller must know about columns), uses the column translation
239 	//inline void setColDirect (colUse use, int icol, double value);
240 	/** returns double column value translated from direct values */
241 	inline void getPostprocessed (colUse use, int col, double &result);
242 	/** returns string column value translated from direct values - not yet implemented */
243 	inline void getPostprocessed (colUse use, int col, CString &result);
244 
245 	/// returns input or output column of given index (starting with 0)
246 	const TColumn &getColumn (colUse use, int col) const;
247 
248 	/// where to dump configuration (including columns) after processing data - if empty, nowhere
249 	CString dumpCfg;
250 	/// bool: go to through columns and set their translations or not?
251 	CInt examineColumns;
252 
253 	/// finds minR, maxR, avg, stdev for all columns
254 	void examine ();
255 
256 	/// sets the colIndexes array
257 	void setColIndexes ();
258 
259 	int getSeriesCount(enumSetType set) const;
260 
261 	/// deletes all rows
clear()262 	void clear ()	{ data.DeleteAll(); pos = data.end(); }
263 	/// adds new row - set values by fillColumn functions
264 	void addRow (enumSetType set, int series = 0);
265 	/// preprocesses and sets value on current row
266 	inline void fillColumn (colUse use, int col, double value);
267 	/// preprocesses and sets value on current row
268 	inline void fillColumn (colUse use, int col, const CString &value);
269 
270 	TRanges ranges;
271 
272 	/** indexes of columns - e.g. columns' use is CU_NO,CU_INPUT,CU_OUTPUT,...
273 		=> colIndexes[CU_INPUT][0]=1, colIndexes[CU_OUTPUT][0]=2 */
274 	TVectorInt colIndexes[2];
275 
276 	/// holds all the training data
277 	TPatterns data;
278 	/// guards the data when a neural network uses them
279 	Mutex dataMutex;
280 	/// row count for training and eval set
281 	CInt rowCount [2];
282 	/// columns count for input and output cols
283 	CInt colCount[2];
284 	TTrainingDataFiles files;
285 	/// columns description
286 	TColumns columns;
287 	/** Bool: Are data organised into series? */
288 	CInt series;
289 	/** If a row begins with this, it separates two series */
290 	CString seriesSeparator;
291 
292 private:
293 	/** reads data from files - returns errors if any occur. The first param is the config file name. */
294 	CString readFromFiles (CString filename, bool onlyCategories = false);
295 	mutable TPatterns::iterator pos;
296 
297 public:
298 	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * */
299 	/*       Configuration READing and PRINTing              */
300 
301 	/// dumps all data to the string
302 	void dumpData (CString &out);
303 
304 	/** reads all data and sets columns - returns error description or empty string.
305 		First param is the config file name. */
306 	CString readData(CString filename);
307 
308 	/// prints the cfg and the data as local in the file
309 	CRox *printAllLocal ();
310 
311 	CRox *printColumns () const;
312 	CString readColumns (const CRox *xml);
313 
STRUCT_END(CTrainingData)314 STRUCT_END (CTrainingData)
315 
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                               *
 *                  C T r a i n i n g D a t a                    *
 *                                                               *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

322 inline bool CTrainingData::moveNext (CInt set) const
323 {
324 	if (pos == data.end()) return false;
325 	do ++pos;
326 	while (pos != data.end() && pos->set != set);
327 	return pos != data.end();
328 }
329 
movePrev(CInt set)330 inline bool CTrainingData::movePrev (CInt set) const
331 {
332 	if (pos == data.begin()) return false;
333 	do --pos;
334 	while (pos != data.begin() && pos->set != set);
335 	return pos->set == set;
336 }
337 
fillColumn(colUse use,int col,double value)338 inline void CTrainingData::fillColumn (colUse use, int col, double value)
339 {
340 	TColumns::iterator icol = columns.begin()+colIndexes[use][col];
341 	icol->preprocess (value, icol->val( *pos ));
342 }
343 
getPostprocessed(colUse use,int col,double & result)344 inline void CTrainingData::getPostprocessed (colUse use, int col, double &result)
345 {
346 	TColumns::iterator icol = columns.begin()+colIndexes[use][col];
347 	icol->postprocess (icol->val (*pos), result);
348 }
349 
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                               *
 *                        T F i l t e r                          *
 *                                                               *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

preprocess(double value,double & result)356 inline void TFilter::preprocess (double value, double &result)
357 {
358 	switch ((int)translate) {
359 	case FTR_FLOAT: // transform into [toMinR;toMaxR]
360 		if (maxR == minR) result = 0;
361 		else result = (value - minR) * (toMaxR - toMinR) / (maxR - minR) + toMinR; break;
362 	case FTR_LINEAR:
363 		if (stdev == 0) result = 0;
364 		else result = (value - avg) * toStdev / stdev + toAvg; break;
365 	case FTR_NONE:
366 		result = value; break;
367 	case FTR_BOOLS:
368 		if ((int)type == FT_CATEGORY) {
369 			preprocess (toString (value,digits), result);
370 			break;
371 		}
372 	default: // not yet implemented
373 		assert (false);
374 		result = 0; break;
375 	}
376 }
377 
preprocess(const CString & value,double & result)378 inline void TFilter::preprocess (const CString &value, double &result)
379 {
380 	switch ((int)type) {
381 	case FT_INT:
382 	case FT_FLOAT:
383 		preprocess (atof (value.c_str()), result); break;
384 	case FT_CATEGORY: {
385 		int category = categories[value];
386 		switch ((int)translate) {
387 		case FTR_FLOAT:
388 		case FTR_LINEAR:
389 			preprocess (category, result);
390 			break;
391 		case FTR_BOOLS: {
392 			for (int i=0; i < categories.size(); ++i)
393 				(&result)[i] = i == category ? 1 : 0;
394 			break;	}
395 		default: assert (false); //not yet implemented
396 		}
397 		break;	  }
398 	default: assert (false); //not yet implemented
399 	}
400 }
401 
postprocessError(const double & value,double & result)402 inline void TFilter::postprocessError (const double &value, double &result) const
403 {
404 	switch ((int)translate) {
405 	case FTR_FLOAT:
406 		result = value * (maxR - minR) / (toMaxR - toMinR); break;
407 	case FTR_LINEAR:
408 		result = value * stdev / toStdev; break;
409 	case FTR_NONE:
410 		result = value; break;
411 	default: // not yet implemented
412 		assert (false);
413 		result = 0; break;
414 	}
415 }
416 
postprocess(const double & value,double & result)417 inline void TFilter::postprocess (const double &value, double &result) const
418 {
419 	switch ((int)translate) {
420 	case FTR_FLOAT:
421 		result = (value - toMinR) * (maxR - minR) / (toMaxR - toMinR) + minR; break;
422 	case FTR_LINEAR:
423 		result = (value - toAvg) * stdev / toStdev + avg; break;
424 	case FTR_NONE:
425 		result = value; break;
426 	default: // not yet implemented
427 		assert (false);
428 		result = 0; break;
429 	}
430 }
431 
postprocess(const double & value,CString & result)432 inline void TFilter::postprocess (const double &value, CString &result) const
433 {
434 	switch ((int)translate) {
435 	case FTR_FLOAT:
436 	case FTR_LINEAR: {
437 		double dresult;
438 		postprocess (value, dresult);
439 		result = toString (dresult);
440 		break; }
441 	case FTR_BOOLS: {
442 		assert (type == FT_CATEGORY);
443 		int i;
444 		for (i=0; i < categories.size(); ++i)
445 			if ((&value)[i] == 1) {
446 				result = categories[i];
447 				return;
448 			}
449 		result = -1;
450 		break; }
451 	case FTR_NONE:
452 		result = toString (value);
453 		break;
454 	default: // not yet implemented
455 		assert (false);
456 		result = "";
457 		break;
458 	}
459 }
460 
getWidth()461 inline int TFilter::getWidth() const
462 {
463 	if (translate != FTR_BOOLS)
464 		return 1;
465 	else return categories.size();
466 }
467 
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                               *
 *                         T C o l u m n                         *
 *                                                               *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

getColumn(colUse use,int icol)474 inline const TColumn &CTrainingData::getColumn (colUse use, int icol) const
475 {
476 	assert (use == CU_INPUT || use == CU_OUTPUT);
477 	return columns[colIndexes[use][icol]];
478 }
479 
val(TPattern & pat,int shift)480 inline double &TColumn::val (TPattern &pat, int shift) const
481 {
482 	assert (shift >= -window.left && shift <= window.right && use != CU_NO);
483 	return pat.data[use][pos + (window.left+shift) * TFilter::getWidth()];
484 }
485 
val(const TPattern & pat,int shift)486 inline double TColumn::val (const TPattern &pat, int shift) const
487 {
488 	assert (shift >= -window.left && shift <= window.right && use != CU_NO);
489 	return pat.data[use][pos + (window.left+shift) * TFilter::getWidth()];
490 }
491 
copy(const TPattern & src,int shiftSrc,TPattern & dst,int shiftDst)492 inline void TColumn::copy (const TPattern &src, int shiftSrc, TPattern &dst, int shiftDst)
493 {
494 	assert (shiftSrc >= -window.left && shiftSrc <= window.right && use != CU_NO
495 		&& shiftDst >= -window.left && shiftDst <= window.right);
496 	if (use == CU_NO) return;
497 	for (int i=0; i < TFilter::getWidth(); ++i)
498 		dst.data[use][pos + (window.left+shiftDst) * TFilter::getWidth() + i]
499 		= src.data[use][pos + (window.left+shiftSrc) * TFilter::getWidth() + i];
500 }
501 
getColWidth()502 inline int TColumn::getColWidth() const
503 {
504 	return (window.left + window.right + 1) * TFilter::getWidth();
505 }
506 
#endif