from __future__ import with_statement

__docformat__ = "restructuredtext en"

import cPickle as _cPickle
import warnings as _warnings
import copy as _copy
import inspect

import mdp
from mdp import numx

class NodeException(mdp.MDPException):
    """Base class for exceptions in `Node` subclasses."""
    pass

class InconsistentDimException(NodeException):
    """Raised when there is a conflict setting the dimensionalities.

    Note that incoming data with conflicting dimensionality raises a normal
    `NodeException`.
    """
    pass

class TrainingException(NodeException):
    """Base class for exceptions in the training phase."""
    pass

class TrainingFinishedException(TrainingException):
    """Raised when the `Node.train` method is called although the
    training phase is closed."""
    pass

class IsNotTrainableException(TrainingException):
    """Raised when the `Node.train` method is called although the
    node is not trainable."""
    pass

class IsNotInvertibleException(NodeException):
    """Raised when the `Node.inverse` method is called although the
    node is not invertible."""
    pass
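
# Illustrative sketch: all node errors derive from ``NodeException``, so
# callers can catch the whole hierarchy at once (``node`` and ``x`` are
# placeholders):
#
#     try:
#         node.train(x)
#     except TrainingFinishedException:
#         pass                    # training was already closed
#     except NodeException, exc:  # any other node failure (Python 2 syntax)
#         raise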

class NodeMetaclass(type):
    """A metaclass which copies docstrings from private to public methods.

    This metaclass is meant to overwrite doc-strings of methods like
    `Node.execute`, `Node.stop_training`, `Node.inverse` with the ones
    defined in the corresponding private methods `Node._execute`,
    `Node._stop_training`, `Node._inverse`, etc.

    This makes it possible for subclasses of `Node` to document the usage
    of public methods, without the need to overwrite the ancestor's methods.
    """

    # methods that can overwrite docs:
    DOC_METHODS = ['_train', '_stop_training', '_execute', '_inverse',
                   '_label', '_prob']

    def __new__(cls, classname, bases, members):
        new_cls = super(NodeMetaclass, cls).__new__(cls, classname,
                                                    bases, members)

        priv_infos = cls._select_private_methods_to_wrap(cls, members)

        # now add the wrappers
        for wrapper_name, priv_info in priv_infos.iteritems():
            # Note: super works because we never wrap in the defining class
            orig_pubmethod = getattr(super(new_cls, new_cls), wrapper_name)

            priv_info['name'] = wrapper_name
            # preserve the last non-empty docstring
            if not priv_info['doc']:
                priv_info['doc'] = orig_pubmethod.__doc__

            recursed = hasattr(orig_pubmethod, '_undecorated_')
            if recursed:
                undec_pubmethod = orig_pubmethod._undecorated_
                priv_info.update(NodeMetaclass._get_infos(undec_pubmethod))
                wrapper_method = cls._wrap_function(undec_pubmethod,
                                                    priv_info)
                wrapper_method._undecorated_ = undec_pubmethod
            else:
                priv_info.update(NodeMetaclass._get_infos(orig_pubmethod))
                wrapper_method = cls._wrap_method(priv_info, new_cls)
                wrapper_method._undecorated_ = orig_pubmethod

            setattr(new_cls, wrapper_name, wrapper_method)
        return new_cls

    @staticmethod
    def _get_infos(pubmethod):
        infos = {}
        wrapped_info = NodeMetaclass._function_infodict(pubmethod)
        # Preserve the signature if it still does not end with kwargs
        # (this is important for binodes).
        if wrapped_info['kwargs_name'] is None:
            infos['signature'] = wrapped_info['signature']
            infos['argnames'] = wrapped_info['argnames']
            infos['defaults'] = wrapped_info['defaults']
        return infos

    @staticmethod
    def _select_private_methods_to_wrap(cls, members):
        """Select private methods that can overwrite the public docstring.

        Return a dictionary ``priv_infos`` whose keys are the public names
        of the private methods to be wrapped and whose values are
        dictionaries with the signature, doc, ... information of the
        private methods (see `_function_infodict`).
        """
        priv_infos = {}
        for privname in cls.DOC_METHODS:
            if privname in members:
                # get the name of the corresponding public method
                pubname = privname[1:]
                # If the public method has been overwritten in this
                # subclass, then keep it.
                # This is also important because we use super in the wrapper
                # (so the public method in this class would be missed).
                if pubname not in members:
                    priv_infos[pubname] = cls._function_infodict(members[privname])
        return priv_infos

    # The next two functions (originally called get_info, wrapper)
    # are adapted versions of functions in the
    # decorator module by Michele Simionato
    # Version: 2.3.1 (25 July 2008)
    # Download page: http://pypi.python.org/pypi/decorator
    # Note: Moving these functions to utils would cause circular import.

    @staticmethod
    def _function_infodict(func):
        """
        Return an info dictionary containing:

        - name (the name of the function : str)
        - argnames (the names of the arguments : list)
        - defaults (the values of the default arguments : tuple)
        - signature (the signature without the defaults : str)
        - doc (the docstring : str)
        - module (the module name : str)
        - dict (the function __dict__ : str)
        - kwargs_name (the name of the kwargs argument, if present, else None)

        >>> def f(self, x=1, y=2, *args, **kw): pass
        >>> info = NodeMetaclass._function_infodict(f)
        >>> info["name"]
        'f'
        >>> info["argnames"]
        ['self', 'x', 'y', 'args', 'kw']
        >>> info["defaults"]
        (1, 2)
        >>> info["signature"]
        'self, x, y, *args, **kw'
        >>> info["kwargs_name"]
        'kw'
        """
        regargs, varargs, varkwargs, defaults = inspect.getargspec(func)
        argnames = list(regargs)
        if varargs:
            argnames.append(varargs)
        if varkwargs:
            argnames.append(varkwargs)
        signature = inspect.formatargspec(regargs,
                                          varargs,
                                          varkwargs,
                                          defaults,
                                          formatvalue=lambda value: "")[1:-1]
        return dict(name=func.__name__,
                    signature=signature,
                    argnames=argnames,
                    kwargs_name=varkwargs,
                    defaults=func.func_defaults,
                    doc=func.__doc__,
                    module=func.__module__,
                    dict=func.__dict__,
                    globals=func.func_globals,
                    closure=func.func_closure)

    @staticmethod
    def _wrap_function(original_func, wrapper_infodict):
        """Return a wrapped version of the given function.

        :param original_func: The function to be wrapped.
        :param wrapper_infodict: The infodict to use for constructing the
            wrapper.
        """
        src = ("lambda %(signature)s: _original_func_(%(signature)s)" %
               wrapper_infodict)
        wrapped_func = eval(src, dict(_original_func_=original_func))
        wrapped_func.__name__ = wrapper_infodict['name']
        wrapped_func.__doc__ = wrapper_infodict['doc']
        wrapped_func.__module__ = wrapper_infodict['module']
        wrapped_func.__dict__.update(wrapper_infodict['dict'])
        wrapped_func.func_defaults = wrapper_infodict['defaults']
        return wrapped_func

    @staticmethod
    def _wrap_method(wrapper_infodict, cls):
        """Return a wrapper method that delegates to the parent class
        implementation through ``super``.

        :param wrapper_infodict: The infodict to be used for constructing
            the wrapper.
        :param cls: Class to which the wrapper method will be added; this
            is used for the super call.
        """
        src = ("lambda %(signature)s: super(_wrapper_class_, _wrapper_class_)."
               "%(name)s(%(signature)s)" % wrapper_infodict)
        wrapped_func = eval(src, {"_wrapper_class_": cls})
        wrapped_func.__name__ = wrapper_infodict['name']
        wrapped_func.__doc__ = wrapper_infodict['doc']
        wrapped_func.__module__ = wrapper_infodict['module']
        wrapped_func.__dict__.update(wrapper_infodict['dict'])
        wrapped_func.func_defaults = wrapper_infodict['defaults']
        return wrapped_func
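
# Illustrative sketch of what the metaclass does (``SquareNode`` is a
# hypothetical subclass, not part of MDP): defining ``_execute`` with a
# docstring is enough for the public ``execute`` to expose both that
# docstring and the private method's signature.
#
#     class SquareNode(Node):
#         def _execute(self, x):
#             """Return the element-wise square of ``x``."""
#             return x**2
#
#     # SquareNode.execute.__doc__ is now the docstring of _execute.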
208 """ 209 src = ("lambda %(signature)s: super(_wrapper_class_, _wrapper_class_)." 210 "%(name)s(%(signature)s)" % wrapper_infodict) 211 wrapped_func = eval(src, {"_wrapper_class_": cls}) 212 wrapped_func.__name__ = wrapper_infodict['name'] 213 wrapped_func.__doc__ = wrapper_infodict['doc'] 214 wrapped_func.__module__ = wrapper_infodict['module'] 215 wrapped_func.__dict__.update(wrapper_infodict['dict']) 216 wrapped_func.func_defaults = wrapper_infodict['defaults'] 217 return wrapped_func 218 219 220class Node(object): 221 """A `Node` is the basic building block of an MDP application. 222 223 It represents a data processing element, like for example a learning 224 algorithm, a data filter, or a visualization step. 225 Each node can have one or more training phases, during which the 226 internal structures are learned from training data (e.g. the weights 227 of a neural network are adapted or the covariance matrix is estimated) 228 and an execution phase, where new data can be processed forwards (by 229 processing the data through the node) or backwards (by applying the 230 inverse of the transformation computed by the node if defined). 231 232 Nodes have been designed to be applied to arbitrarily long sets of data: 233 if the underlying algorithms supports it, the internal structures can 234 be updated incrementally by sending multiple batches of data (this is 235 equivalent to online learning if the chunks consists of single 236 observations, or to batch learning if the whole data is sent in a 237 single chunk). It is thus possible to perform computations on amounts 238 of data that would not fit into memory or to generate data on-the-fly. 239 240 A `Node` also defines some utility methods, like for example 241 `copy` and `save`, that return an exact copy of a node and save it 242 in a file, respectively. Additional methods may be present, depending 243 on the algorithm. 244 245 `Node` subclasses should take care of overwriting (if necessary) 246 the functions `is_trainable`, `_train`, `_stop_training`, `_execute`, 247 `is_invertible`, `_inverse`, `_get_train_seq`, and `_get_supported_dtypes`. 248 If you need to overwrite the getters and setters of the 249 node's properties refer to the docstring of `get_input_dim`/`set_input_dim`, 250 `get_output_dim`/`set_output_dim`, and `get_dtype`/`set_dtype`. 251 """ 252 253 __metaclass__ = NodeMetaclass 254 255 def __init__(self, input_dim=None, output_dim=None, dtype=None): 256 """If the input dimension and the output dimension are 257 unspecified, they will be set when the `train` or `execute` 258 method is called for the first time. 259 If dtype is unspecified, it will be inherited from the data 260 it receives at the first call of `train` or `execute`. 261 262 Every subclass must take care of up- or down-casting the internal 263 structures to match this argument (use `_refcast` private 264 method when possible). 
265 """ 266 # initialize basic attributes 267 self._input_dim = None 268 self._output_dim = None 269 self._dtype = None 270 # call set functions for properties 271 self.set_input_dim(input_dim) 272 self.set_output_dim(output_dim) 273 self.set_dtype(dtype) 274 275 # skip the training phase if the node is not trainable 276 if not self.is_trainable(): 277 self._training = False 278 self._train_phase = -1 279 self._train_phase_started = False 280 else: 281 # this var stores at which point in the training sequence we are 282 self._train_phase = 0 283 # this var is False if the training of the current phase hasn't 284 # started yet, True otherwise 285 self._train_phase_started = False 286 # this var is False if the complete training is finished 287 self._training = True 288 289 ### properties 290 291 def get_input_dim(self): 292 """Return input dimensions.""" 293 return self._input_dim 294 295 def set_input_dim(self, n): 296 """Set input dimensions. 297 298 Perform sanity checks and then calls ``self._set_input_dim(n)``, which 299 is responsible for setting the internal attribute ``self._input_dim``. 300 Note that subclasses should overwrite `self._set_input_dim` 301 when needed. 302 """ 303 if n is None: 304 pass 305 elif (self._input_dim is not None) and (self._input_dim != n): 306 msg = ("Input dim are set already (%d) " 307 "(%d given)!" % (self.input_dim, n)) 308 raise InconsistentDimException(msg) 309 else: 310 self._set_input_dim(n) 311 312 def _set_input_dim(self, n): 313 self._input_dim = n 314 315 input_dim = property(get_input_dim, 316 set_input_dim, 317 doc="Input dimensions") 318 319 def get_output_dim(self): 320 """Return output dimensions.""" 321 return self._output_dim 322 323 def set_output_dim(self, n): 324 """Set output dimensions. 325 326 Perform sanity checks and then calls ``self._set_output_dim(n)``, which 327 is responsible for setting the internal attribute ``self._output_dim``. 328 Note that subclasses should overwrite `self._set_output_dim` 329 when needed. 330 """ 331 if n is None: 332 pass 333 elif (self._output_dim is not None) and (self._output_dim != n): 334 msg = ("Output dim are set already (%d) " 335 "(%d given)!" % (self.output_dim, n)) 336 raise InconsistentDimException(msg) 337 else: 338 self._set_output_dim(n) 339 340 def _set_output_dim(self, n): 341 self._output_dim = n 342 343 output_dim = property(get_output_dim, 344 set_output_dim, 345 doc="Output dimensions") 346 347 def get_dtype(self): 348 """Return dtype.""" 349 return self._dtype 350 351 def set_dtype(self, t): 352 """Set internal structures' dtype. 353 354 Perform sanity checks and then calls ``self._set_dtype(n)``, which 355 is responsible for setting the internal attribute ``self._dtype``. 356 Note that subclasses should overwrite `self._set_dtype` 357 when needed. 358 """ 359 if t is None: 360 return 361 t = numx.dtype(t) 362 if (self._dtype is not None) and (self._dtype != t): 363 errstr = ("dtype is already set to '%s' " 364 "('%s' given)!" 

    _train_seq = property(lambda self: self._get_train_seq(),
                          doc="""\
        List of tuples::

            [(training-phase1, stop-training-phase1),
             (training-phase2, stop-training-phase2),
             ...]

        By default::

            _train_seq = [(self._train, self._stop_training)]
        """)

    def _get_train_seq(self):
        return [(self._train, self._stop_training)]

    def has_multiple_training_phases(self):
        """Return True if the node has multiple training phases."""
        return len(self._train_seq) > 1

    ### Node states
    def is_training(self):
        """Return True if the node is in the training phase,
        False otherwise."""
        return self._training

    def get_current_train_phase(self):
        """Return the index of the current training phase.

        The training phases are defined in the list `self._train_seq`."""
        return self._train_phase

    def get_remaining_train_phase(self):
        """Return the number of training phases still to accomplish.

        If the node is not trainable then return 0.
        """
        if self.is_trainable():
            return len(self._train_seq) - self._train_phase
        else:
            return 0
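
    # A node with multiple training phases overrides ``_get_train_seq``;
    # minimal sketch (``TwoPhaseNode`` and its methods are hypothetical,
    # not part of MDP):
    #
    #     class TwoPhaseNode(Node):
    #         def _get_train_seq(self):
    #             return [(self._train_mean, self._stop_mean),
    #                     (self._train_cov, self._stop_cov)]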
441 """ 442 if self.is_trainable(): 443 return len(self._train_seq) - self._train_phase 444 else: 445 return 0 446 447 ### Node capabilities 448 @staticmethod 449 def is_trainable(): 450 """Return True if the node can be trained, False otherwise.""" 451 return True 452 453 @staticmethod 454 def is_invertible(): 455 """Return True if the node can be inverted, False otherwise.""" 456 return True 457 458 ### check functions 459 def _check_input(self, x): 460 # check input rank 461 if not x.ndim == 2: 462 error_str = "x has rank %d, should be 2" % (x.ndim) 463 raise NodeException(error_str) 464 465 # set the input dimension if necessary 466 if self.input_dim is None: 467 self.input_dim = x.shape[1] 468 469 # set the dtype if necessary 470 if self.dtype is None: 471 self.dtype = x.dtype 472 473 # check the input dimension 474 if not x.shape[1] == self.input_dim: 475 error_str = "x has dimension %d, should be %d" % (x.shape[1], 476 self.input_dim) 477 raise NodeException(error_str) 478 479 if x.shape[0] == 0: 480 error_str = "x must have at least one observation (zero given)" 481 raise NodeException(error_str) 482 483 def _check_output(self, y): 484 # check output rank 485 if not y.ndim == 2: 486 error_str = "y has rank %d, should be 2" % (y.ndim) 487 raise NodeException(error_str) 488 489 # check the output dimension 490 if not y.shape[1] == self.output_dim: 491 error_str = "y has dimension %d, should be %d" % (y.shape[1], 492 self.output_dim) 493 raise NodeException(error_str) 494 495 def _if_training_stop_training(self): 496 if self.is_training(): 497 self.stop_training() 498 # if there is some training phases left we shouldn't be here! 499 if self.get_remaining_train_phase() > 0: 500 error_str = "The training phases are not completed yet." 501 raise TrainingException(error_str) 502 503 def _pre_execution_checks(self, x): 504 """This method contains all pre-execution checks. 505 506 It can be used when a subclass defines multiple execution methods. 507 """ 508 # if training has not started yet, assume we want to train the node 509 if (self.get_current_train_phase() == 0 and 510 not self._train_phase_started): 511 while True: 512 self.train(x) 513 if self.get_remaining_train_phase() > 1: 514 self.stop_training() 515 else: 516 break 517 518 self._if_training_stop_training() 519 520 # control the dimension x 521 self._check_input(x) 522 523 # set the output dimension if necessary 524 if self.output_dim is None: 525 self.output_dim = self.input_dim 526 527 def _pre_inversion_checks(self, y): 528 """This method contains all pre-inversion checks. 529 530 It can be used when a subclass defines multiple inversion methods. 531 """ 532 if not self.is_invertible(): 533 raise IsNotInvertibleException("This node is not invertible.") 534 535 self._if_training_stop_training() 536 537 # set the output dimension if necessary 538 if self.output_dim is None: 539 # if the input_dim is not defined, raise an exception 540 if self.input_dim is None: 541 errstr = ("Number of input dimensions undefined. 
Inversion" 542 "not possible.") 543 raise NodeException(errstr) 544 self.output_dim = self.input_dim 545 546 # control the dimension of y 547 self._check_output(y) 548 549 ### casting helper functions 550 551 def _refcast(self, x): 552 """Helper function to cast arrays to the internal dtype.""" 553 return mdp.utils.refcast(x, self.dtype) 554 555 ### Methods to be implemented by the user 556 557 # this are the methods the user has to overwrite 558 # they receive the data already casted to the correct type 559 560 def _train(self, x): 561 if self.is_trainable(): 562 raise NotImplementedError 563 564 def _stop_training(self, *args, **kwargs): 565 pass 566 567 def _execute(self, x): 568 return x 569 570 def _inverse(self, x): 571 if self.is_invertible(): 572 return x 573 574 def _check_train_args(self, x, *args, **kwargs): 575 # implemented by subclasses if needed 576 pass 577 578 ### User interface to the overwritten methods 579 580 def train(self, x, *args, **kwargs): 581 """Update the internal structures according to the input data `x`. 582 583 `x` is a matrix having different variables on different columns 584 and observations on the rows. 585 586 By default, subclasses should overwrite `_train` to implement their 587 training phase. The docstring of the `_train` method overwrites this 588 docstring. 589 590 Note: a subclass supporting multiple training phases should implement 591 the *same* signature for all the training phases and document the 592 meaning of the arguments in the `_train` method doc-string. Having 593 consistent signatures is a requirement to use the node in a flow. 594 """ 595 596 if not self.is_trainable(): 597 raise IsNotTrainableException("This node is not trainable.") 598 599 if not self.is_training(): 600 err_str = "The training phase has already finished." 601 raise TrainingFinishedException(err_str) 602 603 self._check_input(x) 604 self._check_train_args(x, *args, **kwargs) 605 606 self._train_phase_started = True 607 self._train_seq[self._train_phase][0](self._refcast(x), *args, **kwargs) 608 609 def stop_training(self, *args, **kwargs): 610 """Stop the training phase. 611 612 By default, subclasses should overwrite `_stop_training` to implement 613 this functionality. The docstring of the `_stop_training` method 614 overwrites this docstring. 615 """ 616 if self.is_training() and self._train_phase_started == False: 617 raise TrainingException("The node has not been trained.") 618 619 if not self.is_training(): 620 err_str = "The training phase has already finished." 621 raise TrainingFinishedException(err_str) 622 623 # close the current phase. 624 self._train_seq[self._train_phase][1](*args, **kwargs) 625 self._train_phase += 1 626 self._train_phase_started = False 627 # check if we have some training phase left 628 if self.get_remaining_train_phase() == 0: 629 self._training = False 630 631 def execute(self, x, *args, **kwargs): 632 """Process the data contained in `x`. 633 634 If the object is still in the training phase, the function 635 `stop_training` will be called. 636 `x` is a matrix having different variables on different columns 637 and observations on the rows. 638 639 By default, subclasses should overwrite `_execute` to implement 640 their execution phase. The docstring of the `_execute` method 641 overwrites this docstring. 642 """ 643 self._pre_execution_checks(x) 644 return self._execute(self._refcast(x), *args, **kwargs) 645 646 def inverse(self, y, *args, **kwargs): 647 """Invert `y`. 

    def __call__(self, x, *args, **kwargs):
        """Calling an instance of `Node` is equivalent to calling
        its `execute` method."""
        return self.execute(x, *args, **kwargs)

    ###### adding nodes returns flows

    def __add__(self, other):
        # check that other is a node
        if isinstance(other, Node):
            return mdp.Flow([self, other])
        elif isinstance(other, mdp.Flow):
            flow_copy = other.copy()
            flow_copy.insert(0, self)
            return flow_copy.copy()
        else:
            err_str = ('can only concatenate node'
                       ' (not \'%s\') to node' % (type(other).__name__))
            raise TypeError(err_str)

    ###### string representation

    def __str__(self):
        return str(type(self).__name__)

    def __repr__(self):
        # print input_dim, output_dim, dtype
        name = type(self).__name__
        inp = "input_dim=%s" % str(self.input_dim)
        out = "output_dim=%s" % str(self.output_dim)
        if self.dtype is None:
            typ = 'dtype=None'
        else:
            typ = "dtype='%s'" % self.dtype.name
        args = ', '.join((inp, out, typ))
        return name + '(' + args + ')'

    def copy(self, protocol=None):
        """Return a deep copy of the node.

        :param protocol: the pickle protocol (deprecated)."""
        if protocol is not None:
            _warnings.warn("protocol parameter to copy() is ignored",
                           mdp.MDPDeprecationWarning, stacklevel=2)
        return _copy.deepcopy(self)

    def save(self, filename, protocol=-1):
        """Save a pickled serialization of the node to `filename`.
        If `filename` is None, return a string.

        Note: the pickled `Node` is not guaranteed to be forward- or
        backward-compatible."""
        if filename is None:
            return _cPickle.dumps(self, protocol)
        else:
            # if protocol != 0 open the file in binary mode
            mode = 'wb' if protocol != 0 else 'w'
            with open(filename, mode) as flh:
                _cPickle.dump(self, flh, protocol)
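
# Illustrative sketch: nodes concatenate into flows with ``+`` (see
# ``Node.__add__`` above); ``PCANode`` and ``SFANode`` are standard MDP nodes.
#
#     flow = mdp.nodes.PCANode(output_dim=5) + mdp.nodes.SFANode()
#     # flow is an mdp.Flow; training and execution then run through
#     # both nodes in sequence.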
751 """ 752 753 class Cumulator(Node): 754 def __init__(self, *args, **kwargs): 755 super(Cumulator, self).__init__(*args, **kwargs) 756 self._cumulator_fields = fields 757 for arg in self._cumulator_fields: 758 if hasattr(self, arg): 759 errstr = "Cumulator Error: Property %s already defined" 760 raise mdp.MDPException(errstr % arg) 761 setattr(self, arg, []) 762 self.tlen = 0 763 764 def _train(self, *args): 765 """Collect all input data in a list.""" 766 self.tlen += args[0].shape[0] 767 for field, data in zip(self._cumulator_fields, args): 768 getattr(self, field).append(data) 769 770 def _stop_training(self, *args, **kwargs): 771 """Concatenate the collected data in a single array.""" 772 for field in self._cumulator_fields: 773 data = getattr(self, field) 774 setattr(self, field, numx.concatenate(data, 0)) 775 776 return Cumulator 777 778Cumulator = VariadicCumulator('data') 779Cumulator.__doc__ = """A specialized version of `VariadicCumulator` which only 780 fills the field ``self.data``. 781 """ 782