"""Contains the classes that are used to define the dependency network.

Copyright (C) 2013, Joshua More and Michele Ceriotti

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.


The classes defined in this module overload the standard __get__ and __set__
routines of the numpy ndarray class and standard library object class so that
they automatically keep track of whether anything they depend on has been
altered, and so only recalculate their value when necessary.

Basic quantities that depend on nothing else can be manually altered in the
usual way, all other quantities are updated automatically and cannot be changed
directly.

The exceptions to this are synchronized properties, which are in effect
multiple basic quantities all related to each other, for example the bead and
normal mode representations of the positions and momenta. In this case any of
the representations can be set manually, and all the other representations
must keep in step.

For a more detailed discussion, see the reference manual.

Classes:
    depend_base: Base depend class with the generic methods and attributes.
    depend_value: Depend class for scalar objects.
    depend_array: Depend class for arrays.
    synchronizer: Class that holds the different objects that are related to
        each other and keeps track of which property has been set manually.
    dobject: An extension of the standard library object that overloads
        __getattribute__ and __setattr__, so that we can use the
        standard syntax for setting and getting the depend object,
        i.e. foo = value, not foo.set(value).

Functions:
    dget: Gets the dependencies of a depend object.
    dset: Sets the dependencies of a depend object.
    depstrip: Used on a depend_array object, to access its value without
        needing the depend machinery, and so much more quickly. Must not be
        used if the value of the array is to be changed.
    depcopy: Copies the dependencies from one object to another
    deppipe: Used to make two objects be synchronized to the same value.
"""

__all__ = ['depend_base', 'depend_value', 'depend_array', 'synchronizer',
           'dobject', 'dget', 'dset', 'depstrip', 'depcopy', 'deppipe']

import numpy as np
from ipi.utils.messages import verbosity, warning


class synchronizer(object):
    """Class to implement synched objects.

    Holds the objects used to keep two or more objects in step with each
    other. This is shared between all the synched objects.

    Attributes:
        synced: A dictionary containing all the synched objects, of the form
            {"name": depend object}.
        manual: A string containing the name of the object being manually
            changed.
    """

    def __init__(self, deps=None):
        """Initializes synchronizer.

        Args:
            deps: Optional dictionary giving the synched objects of the form
                {"name": depend object}.
        """

        if deps is None:
            self.synced = dict()
        else:
            self.synced = deps

        self.manual = None


# TODO put some error checks in the init to make sure that the object is
# initialized from consistent synchro and func states
class depend_base(object):
    """Base class for dependency handling.

    Builds the majority of the machinery required for the different depend
    objects. Contains functions to add and remove dependencies, the tainting
    mechanism by which information about which objects have been updated is
    passed around the dependency network, and the manual and automatic update
    functions to check that depend objects with functions are not manually
    updated and that synchronized objects are kept in step with the one
    manually changed.

    Attributes:
        _tainted: An array containing one boolean, which is True if one of the
            dependencies has been changed since the last time the value was
            cached.
        _func: A function name giving the method of calculating the value,
            if required. None otherwise.
        _name: The name of the depend base object.
        _synchro: A synchronizer object to deal with synched objects, if
            required. None otherwise.
        _dependants: A list containing all objects dependent on the self.
    """

    def __init__(self, name, synchro=None, func=None, dependants=None,
                 dependencies=None, tainted=None):
        """Initializes depend_base.

        An unusual initialization routine, as it has to be able to deal with
        the depend array mechanism for returning slices as new depend arrays.

        This is the reason for the penultimate if statement; it automatically
        taints objects created from scratch but does nothing to slices which
        are not tainted.

        Also, the last if statement makes sure that if a synchronized property
        is sliced, this initialization routine does not automatically set it
        to the manually updated property.

        Args:
            name: A string giving the name of self.
            tainted: An optional array containing one boolean which is True if
                one of the dependencies has been changed.
            func: An optional argument that can be specified either by a
                function name, or for synchronized values a dictionary of the
                form {"name": function name}; where "name" is one of the other
                synched objects and function name is the name of a function to
                get the object "name" from self.
            synchro: An optional synchronizer object.
            dependants: An optional list containing objects that depend on
                self.
            dependencies: An optional list containing objects that self
                depends upon.
        """

        # Start from an empty list so that any taint() triggered while the
        # dependencies are being registered has something to iterate over.
        self._dependants = []
        if tainted is None:
            tainted = np.array([True], bool)
        if dependants is None:
            dependants = []
        if dependencies is None:
            dependencies = []
        self._tainted = tainted
        self._func = func
        self._name = name

        self.add_synchro(synchro)

        for item in dependencies:
            item.add_dependant(self, tainted)

        self._dependants = dependants

        # Don't taint self if the object is a primitive one. However, do
        # propagate tainting to dependants if required.
        if tainted:
            if self._func is None:
                self.taint(taintme=False)
            else:
                self.taint(taintme=tainted)

    def add_synchro(self, synchro=None):
        """Links depend object to a synchronizer.

        If a synchronizer is given and self is not yet registered in it,
        self is registered under its name and becomes the manually set
        member of the group.

        Args:
            synchro: An optional synchronizer object.
        """

        self._synchro = synchro
        if self._synchro is not None and self._name not in self._synchro.synced:
            self._synchro.synced[self._name] = self
            self._synchro.manual = self._name

    def add_dependant(self, newdep, tainted=True):
        """Adds a dependant property.

        Args:
            newdep: The depend object to be added to the dependency list.
            tainted: A boolean that decides whether newdep should be tainted.
                True by default.
        """

        self._dependants.append(newdep)
        if tainted:
            newdep.taint(taintme=True)

    def add_dependency(self, newdep, tainted=True):
        """Adds a dependency.

        Args:
            newdep: The depend object self now depends upon.
            tainted: A boolean that decides whether self should
                be tainted. True by default.
        """

        newdep._dependants.append(self)
        if tainted:
            self.taint(taintme=True)

    def taint(self, taintme=True):
        """Recursively sets tainted flag on dependent objects.

        The main function dealing with the dependencies. Taints all objects
        further down the dependency tree until either all objects have been
        tainted, or it reaches only objects that have already been tainted.
        Note that in the case of a dependency loop the initial setting of
        _tainted to True prevents an infinite loop occurring.

        Also, in the case of a synchro object, the manually set quantity is
        not tainted, as it is assumed that synchro objects only depend on each
        other.

        Args:
            taintme: A boolean giving whether self should be tainted at the
                end. True by default.
        """

        # Mark self first: this is what breaks dependency loops.
        self._tainted[:] = True
        for item in self._dependants:
            if not item._tainted[0]:
                item.taint()
        if self._synchro is not None:
            for v in self._synchro.synced.values():
                if (not v._tainted[0]) and (v is not self):
                    v.taint(taintme=True)
            self._tainted[:] = (taintme and (not self._name == self._synchro.manual))
        else:
            self._tainted[:] = taintme

    def tainted(self):
        """Returns tainted flag."""

        return self._tainted[0]

    def update_auto(self):
        """Automatic update routine.

        Updates the value when get has been called and self has been tainted.
        """

        if self._synchro is not None:
            if not self._name == self._synchro.manual:
                # Recompute from whichever synched object was set by hand.
                self.set(self._func[self._synchro.manual](), manual=False)
            else:
                warning(self._name + " probably shouldn't be tainted (synchro)", verbosity.low)
        elif self._func is not None:
            self.set(self._func(), manual=False)
        else:
            warning(self._name + " probably shouldn't be tainted (value)", verbosity.low)

    def update_man(self):
        """Manual update routine.

        Updates the value when the value has been manually set. Also raises an
        exception if a calculated quantity has been manually set. Also starts
        the tainting routine.

        Raises:
            NameError: If a calculated quantity has been manually set.
        """

        if self._synchro is not None:
            self._synchro.manual = self._name
            for v in self._synchro.synced.values():
                v.taint(taintme=True)
            self._tainted[:] = False
        elif self._func is not None:
            raise NameError("Cannot set manually the value of the automatically-computed property <" + self._name + ">")
        else:
            self.taint(taintme=False)

    def set(self, value, manual=False):
        """Dummy setting routine."""

        pass

    def get(self):
        """Dummy getting routine."""

        pass


class depend_value(depend_base):
    """Scalar class for dependency handling.

    Attributes:
        _value: The value associated with self.
    """

    def __init__(self, name, value=None, synchro=None, func=None,
                 dependants=None, dependencies=None, tainted=None):
        """Initializes depend_value.

        Args:
            name: A string giving the name of self.
            value: The value of the object. Optional.
            tainted: An optional array giving the tainted flag. Default is
                [True].
            func: An optional argument that can be specified either by a
                function name, or for synchronized values a dictionary of the
                form {"name": function name}; where "name" is one of the other
                synched objects and function name is the name of a function to
                get the object "name" from self.
            synchro: An optional synchronizer object.
            dependants: An optional list containing objects that depend on
                self.
            dependencies: An optional list containing objects that self
                depends upon.
        """

        self._value = value
        super(depend_value, self).__init__(name, synchro, func, dependants, dependencies, tainted)

    def get(self):
        """Returns value, after recalculating if necessary.

        Overwrites the standard method of getting value, so that value
        is recalculated if tainted.
        """

        if self._tainted[0]:
            self.update_auto()
            self.taint(taintme=False)

        return self._value

    def __get__(self, instance, owner):
        """Overwrites standard get function."""

        return self.get()

    def set(self, value, manual=True):
        """Alters value and taints dependencies.

        Overwrites the standard method of setting value, so that dependent
        quantities are tainted, and so we check that computed quantities are
        not manually updated.

        Args:
            value: The new value.
            manual: Optional boolean giving whether the value has been changed
                manually. True by default.
        """

        self._value = value
        self.taint(taintme=False)
        if manual:
            self.update_man()

    def __set__(self, instance, value):
        """Overwrites standard set function."""

        self.set(value)


class depend_array(np.ndarray, depend_base):
    """Array class for dependency handling.

    Differs from depend_value as arrays handle getting items in a different
    way to scalar quantities, and as there needs to be support for slicing an
    array. Initialisation is also done in a different way for ndarrays.

    Attributes:
        _bval: The base deparray storage space. Equal to depstrip(self) unless
            self is a slice.
    """

    def __new__(cls, value, name, synchro=None, func=None, dependants=None,
                dependencies=None, tainted=None, base=None):
        """Creates a new array from a template.

        Called whenever a new instance of depend_array is created. Casts the
        array base into an appropriate form before passing it to
        __array_finalize__().

        Args:
            See __init__().
        """

        obj = np.asarray(value).view(cls)
        return obj

    def __init__(self, value, name, synchro=None, func=None, dependants=None,
                 dependencies=None, tainted=None, base=None):
        """Initializes depend_array.

        Note that this is only called when a new array is created by an
        explicit constructor.

        Args:
            name: A string giving the name of self.
            value: The (numpy) array to serve as the memory base.
            tainted: An optional array giving the tainted flag. Default is
                [True].
            func: An optional argument that can be specified either by a
                function name, or for synchronized values a dictionary of the
                form {"name": function name}; where "name" is one of the other
                synched objects and function name is the name of a function to
                get the object "name" from self.
            synchro: An optional synchronizer object.
            dependants: An optional list containing objects that depend on
                self.
            dependencies: An optional list containing objects that self
                depends upon.
            base: An optional array to use as the base storage instead of
                value (used when self is a slice or reshape of another
                depend_array).
        """

        super(depend_array, self).__init__(name, synchro, func, dependants, dependencies, tainted)

        if base is None:
            self._bval = value
        else:
            self._bval = base

    def copy(self, order='C', maskna=None):
        """Wrapper for numpy copy mechanism."""

        # Sets a flag and hands control to the numpy copy
        self._fcopy = True
        return super(depend_array, self).copy(order)

    def __array_finalize__(self, obj):
        """Deals with properly creating some arrays.

        In the case where a function acting on a depend array returns a
        ndarray, this casts it into the correct form and gives it the
        depend machinery for other methods to be able to act upon it. New
        depend_arrays will next be passed to __init__() to be properly
        initialized, but some ways of creating arrays do not call __new__() or
        __init__(), so need to be initialized.
        """

        depend_base.__init__(self, name="")

        if type(obj) is depend_array:
            # We are in a view cast or in new from template. Unfortunately
            # there is no sure way to tell (or so it seems). Hence we need to
            # handle special cases, and hope we are in a view cast otherwise.
            if hasattr(obj, "_fcopy"):
                del(obj._fcopy)  # removes the "copy flag"
                self._bval = depstrip(self)
            else:
                # Assumes we are in view cast, so copy over the attributes
                # from the parent object. Typical case: when transpose is
                # performed as a view.
                super(depend_array, self).__init__(obj._name, obj._synchro, obj._func, obj._dependants, None, obj._tainted)
                self._bval = obj._bval
        else:
            # Most likely we came here on the way to init.
            # Just sets a defaults for safety
            self._bval = depstrip(self)

    # NOTE(review): __array_prepare__ appears to have been removed from
    # ndarray in NumPy 2.0 -- confirm before running against a modern NumPy.
    def __array_prepare__(self, arr, context=None):
        """Prepare output array for ufunc.

        Depending on the context we try to understand if we are doing an
        in-place operation (in which case we want to keep the return value a
        deparray) or we are generating a new array as a result of the ufunc.
        In this case there is no way to know if dependencies should be copied,
        so we strip and return a ndarray.
        """

        if context is None or len(context) < 2 or not type(context[0]) is np.ufunc:
            # It is not clear what we should do. If in doubt, strip
            # dependencies.
            return np.ndarray.__array_prepare__(self.view(np.ndarray), arr.view(np.ndarray), context)
        elif len(context[1]) > context[0].nin and context[0].nout > 0:
            # We are being called by a ufunc with a output argument, which is
            # being actually used. Most likely, something like an increment,
            # so we pass on a deparray
            return super(depend_array, self).__array_prepare__(arr, context)
        else:
            # Apparently we are generating a new array.
            # We have no way of knowing its
            # dependencies, so we'd better return a ndarray view!
            return np.ndarray.__array_prepare__(self.view(np.ndarray), arr.view(np.ndarray), context)

    def __array_wrap__(self, arr, context=None):
        """Wraps up output array from ufunc.

        See docstring of __array_prepare__().
        """

        if context is None or len(context) < 2 or not type(context[0]) is np.ufunc:
            return np.ndarray.__array_wrap__(self.view(np.ndarray), arr.view(np.ndarray), context)
        elif len(context[1]) > context[0].nin and context[0].nout > 0:
            return super(depend_array, self).__array_wrap__(arr, context)
        else:
            return np.ndarray.__array_wrap__(self.view(np.ndarray), arr.view(np.ndarray), context)

    # whenever possible in compound operations just return a regular ndarray
    __array_priority__ = -1.0

    def reshape(self, newshape):
        """Changes the shape of the base array.

        Args:
            newshape: A tuple giving the desired shape of the new array.

        Returns:
            A depend_array with the dimensions given by newshape.
        """

        return depend_array(depstrip(self).reshape(newshape), name=self._name, synchro=self._synchro, func=self._func, dependants=self._dependants, tainted=self._tainted, base=self._bval)

    def flatten(self):
        """Makes the base array one dimensional.

        Returns:
            A flattened array.
        """

        return self.reshape(self.size)

    @staticmethod
    def __scalarindex(index, depth=1):
        """Checks if an index points at a scalar value.

        Used so that looking up one item in an array returns a scalar, whereas
        looking up a slice of the array returns a new array with the same
        dependencies as the original, so that changing the slice also taints
        the global array.

        Arguments:
            index: the index to be checked.
            depth: the rank of the array which is being accessed. Default
                value is 1.

        Returns:
            A logical stating whether a __get__ instruction based
            on index would return a scalar.
        """

        if np.isscalar(index) and depth <= 1:
            return True
        elif isinstance(index, tuple) and len(index) == depth:
            # if the index is a tuple check it does not contain slices
            for i in index:
                if not np.isscalar(i):
                    return False
            return True
        return False

    def __getitem__(self, index):
        """Returns value[index], after recalculating if necessary.

        Overwrites the standard method of getting value, so that value
        is recalculated if tainted. Scalar slices are returned as an ndarray,
        so without depend machinery. If you need a "scalar depend" which
        behaves as a slice, just create a 1x1 matrix, e.g. b = a[7, 1:2]

        Args:
            index: A slice variable giving the appropriate slice to be read.
        """

        if self._tainted[0]:
            self.update_auto()
            self.taint(taintme=False)

        if self.__scalarindex(index, self.ndim):
            return depstrip(self)[index]
        else:
            # The slice shares name, taint flag, dependants and base storage
            # with its parent, so changing it taints the parent as well.
            return depend_array(depstrip(self)[index], name=self._name, synchro=self._synchro, func=self._func, dependants=self._dependants, tainted=self._tainted, base=self._bval)

    def __getslice__(self, i, j):
        """Overwrites standard get function."""

        return self.__getitem__(slice(i, j, None))

    def get(self):
        """Alternative to standard get function.

        Returns:
            self, after recalculating its value if it was tainted.
        """

        # __get__ follows the descriptor signature (instance, owner) and uses
        # neither argument; the previous single-argument call raised
        # TypeError.
        return self.__get__(None, self.__class__)

    def __get__(self, instance, owner):
        """Overwrites standard get function."""

        # It is worth duplicating this code that is also used in __getitem__
        # as this is called most of the time, and we avoid creating a load of
        # copies pointing to the same depend_array

        if self._tainted[0]:
            self.update_auto()
            self.taint(taintme=False)

        return self

    def __setitem__(self, index, value, manual=True):
        """Alters value[index] and taints dependencies.

        Overwrites the standard method of setting value, so that dependent
        quantities are tainted, and so we check that computed quantities are
        not manually updated.

        Args:
            index: A slice variable giving the appropriate slice to be read.
            value: The new value of the slice.
            manual: Optional boolean giving whether the value has been changed
                manually. True by default.

        Raises:
            IndexError: If an automatic update does not span the whole array.
        """

        self.taint(taintme=False)
        if manual:
            depstrip(self)[index] = value
            self.update_man()
        elif index == slice(None, None, None):
            # Automatic updates always overwrite the whole base storage.
            self._bval[index] = value
        else:
            raise IndexError("Automatically computed arrays should span the whole parent")

    def __setslice__(self, i, j, value):
        """Overwrites standard set function."""

        return self.__setitem__(slice(i, j), value)

    def set(self, value, manual=True):
        """Alternative to standard set function.

        Args:
            See __setitem__().
        """

        self.__setitem__(slice(None, None), value=value, manual=manual)

    def __set__(self, instance, value):
        """Overwrites standard set function."""

        self.__setitem__(slice(None, None), value=value)


# np.dot and other numpy.linalg functions have the nasty habit to
# view cast to generate the output. Since we don't want to pass on
# dependencies to the result of these functions, and we can't use
# the ufunc mechanism to demote the class type to ndarray, we must
# overwrite np.dot and other similar functions.
# BEGINS NUMPY FUNCTIONS OVERRIDE
# ** np.dot
__dp_dot = np.dot


def dep_dot(da, db):
    """Replacement for np.dot that never returns a depend_array.

    Both operands are stripped of their depend machinery before being handed
    to the original np.dot, so the result is a plain ndarray (or scalar).

    Args:
        da: The first operand; a depend_array or anything np.dot accepts.
        db: The second operand; a depend_array or anything np.dot accepts.

    Returns:
        The result of the original np.dot applied to the stripped operands.
    """

    a = depstrip(da)
    b = depstrip(db)

    # Use the stripped views: passing da/db here defeated the override.
    return __dp_dot(a, b)


np.dot = dep_dot
# ENDS NUMPY FUNCTIONS OVERRIDE


def dget(obj, member):
    """Takes an object and retrieves one of its attributes.

    Note that this is necessary as calling it in the standard way calls the
    __get__() function of member.

    Args:
        obj: A user defined class.
        member: A string giving the name of an attribute of obj.

    Exceptions:
        KeyError: If member is not an attribute of obj.

    Returns:
        obj.member.
    """

    return obj.__dict__[member]


def dset(obj, member, value, name=None):
    """Takes an object and sets one of its attributes.

    Necessary for editing any depend object, and should be used for
    initialising them as well, as often initialization occurs more than once,
    with the second time effectively being an edit.

    Args:
        obj: A user defined class.
        member: A string giving the name of an attribute of obj.
        value: The new value of member.
        name: New name of member.
    """

    obj.__dict__[member] = value
    if name is not None:
        obj.__dict__[member]._name = name


def depstrip(da):
    """Removes dependencies from a depend_array.

    Takes a depend_array and returns its value as a ndarray, effectively
    stripping the dependencies from the ndarray. This speeds up a lot of
    calculations involving these arrays. Must only be used if the value of the
    array is not going to be changed.

    Args:
        da: A depend_array. Any other object is returned unchanged.

    Returns:
        A ndarray with the same value as da.
    """

    # only bother to strip dependencies if the array actually IS a
    # depend_array
    if isinstance(da, depend_array):
        # NOTE: the taint flag is not checked here. It is assumed that the
        # array has already been updated by the time depstrip is called, but
        # this is not 100% certain.
        return da.view(np.ndarray)
    else:
        return da


def deppipe(objfrom, memberfrom, objto, memberto):
    """Synchronizes two depend objects.

    Takes two depend objects, and makes one of them depend on the other in
    such a way that both keep the same value. Used for attributes such as
    temperature that are used in many different modules, and so need different
    depend objects in each, but which should all have the same value.

    Args:
        objfrom: An object containing memberfrom.
        memberfrom: The base depend object.
        objto: An object containing memberto.
        memberto: The depend object that should be equal to memberfrom.
    """

    dfrom = dget(objfrom, memberfrom)
    dto = dget(objto, memberto)
    dto._func = lambda: dfrom.get()
    dto.add_dependency(dfrom)


def depcopy(objfrom, memberfrom, objto, memberto):
    """Copies the dependencies of one depend object to another.

    Args:
        See deppipe.
    """

    dfrom = dget(objfrom, memberfrom)
    dto = dget(objto, memberto)
    dto._dependants = dfrom._dependants
    dto._synchro = dfrom._synchro
    dto.add_synchro(dfrom._synchro)
    dto._tainted = dfrom._tainted
    dto._func = dfrom._func
    if hasattr(dfrom, "_bval"):
        dto._bval = dfrom._bval


class dobject(object):
    """Class that allows standard notation to be used for depend objects."""

    def __getattribute__(self, name):
        """Overwrites standard __getattribute__().

        This changes the standard __getattribute__() function of any class
        that subclasses dobject such that depend objects are called with their
        own __get__() function rather than the standard one.
        """

        value = object.__getattribute__(self, name)
        if hasattr(value, '__get__'):
            value = value.__get__(self, self.__class__)
        return value

    def __setattr__(self, name, value):
        """Overwrites standard __setattr__().

        This changes the standard __setattr__() function of any class that
        subclasses dobject such that depend objects are called with their own
        __set__() function rather than the standard one.
        """

        try:
            obj = object.__getattribute__(self, name)
        except AttributeError:
            # Attribute does not exist yet: fall through to a plain set.
            pass
        else:
            if hasattr(obj, '__set__'):
                return obj.__set__(self, value)
        return object.__setattr__(self, name, value)