1"""Contains the classes that are used to define the dependency network.
2
3Copyright (C) 2013, Joshua More and Michele Ceriotti
4
5This program is free software: you can redistribute it and/or modify
6it under the terms of the GNU General Public License as published by
7the Free Software Foundation, either version 3 of the License, or
8(at your option) any later version.
9
10This program is distributed in the hope that it will be useful,
11but WITHOUT ANY WARRANTY; without even the implied warranty of
12MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13GNU General Public License for more details.
14
15You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
19The classes defined in this module overload the standard __get__ and __set__
20routines of the numpy ndarray class and standard library object class so that
21they automatically keep track of whether anything they depend on has been
22altered, and so only recalculate their value when necessary.
23
24Basic quantities that depend on nothing else can be manually altered in the
25usual way, all other quantities are updated automatically and cannot be changed
26directly.
27
28The exceptions to this are synchronized properties, which are in effect
29multiple basic quantities all related to each other, for example the bead and
30normal mode representations of the positions and momenta. In this case any of
31the representations can be set manually, and all the other representations
32must keep in step.
33
34For a more detailed discussion, see the reference manual.
35
36Classes:
37   depend_base: Base depend class with the generic methods and attributes.
38   depend_value: Depend class for scalar objects.
39   depend_array: Depend class for arrays.
40   synchronizer: Class that holds the different objects that are related to each
41      other and keeps track of which property has been set manually.
   dobject: An extension of the standard library object that overloads
      __getattribute__ and __setattr__, so that we can use the
      standard syntax for setting and getting the depend object,
      i.e. foo = value, not foo.set(value).
46
47Functions:
   dget: Gets a depend object itself (rather than its value) from the object
      that holds it.
   dset: Sets a depend object itself (rather than its value) on the object
      that holds it.
50   depstrip: Used on a depend_array object, to access its value without
51      needing the depend machinery, and so much more quickly. Must not be used
52      if the value of the array is to be changed.
53   depcopy: Copies the dependencies from one object to another
54   deppipe: Used to make two objects be synchronized to the same value.
55"""
56
# Public API of this module, in the same order as the module docstring.
__all__ = [
   'depend_base', 'depend_value', 'depend_array', 'synchronizer', 'dobject',
   'dget', 'dset', 'depstrip', 'depcopy', 'deppipe',
]
59
60import numpy as np
61from ipi.utils.messages import verbosity, warning
62
class synchronizer(object):
   """Class to implement synched objects.

   Holds the objects used to keep two or more objects in step with each other.
   This is shared between all the synched objects.

   Attributes:
      synced: A dictionary containing all the synched objects, of the form
         {"name": depend object}.
      manual: A string containing the name of the object being manually
         changed, or None if no object has been registered yet.
   """

   def __init__(self, deps=None):
      """Initializes synchronizer.

      Args:
         deps: Optional dictionary giving the synched objects of the form
            {"name": depend object}. It is stored as-is (not copied), so the
            caller's dictionary is shared with the synchronizer. When omitted,
            a fresh empty dictionary is created for each synchronizer.
      """

      # A fresh dict per instance: never share a default container.
      self.synced = dict() if deps is None else deps
      self.manual = None
89
90
91#TODO put some error checks in the init to make sure that the object is initialized from consistent synchro and func states
class depend_base(object):
   """Base class for dependency handling.

   Builds the majority of the machinery required for the different depend
   objects. Contains functions to add dependencies, the tainting mechanism by
   which information about which objects have been updated is passed around
   the dependency network, and the manual and automatic update functions to
   check that depend objects with functions are not manually updated and that
   synchronized objects are kept in step with the one manually changed.

   Attributes:
      _tainted: An array containing one boolean, which is True if one of the
         dependencies has been changed since the last time the value was
         cached. It is always modified in place, so objects sharing the same
         array (e.g. slices of a depend_array) see the same flag.
      _func: A callable computing the value, if required. For synchronized
         objects, a dictionary {"name": callable}, where each callable
         computes self's value from the synched object "name". None otherwise.
      _name: The name of the depend base object.
      _synchro: A synchronizer object to deal with synched objects, if
         required. None otherwise.
      _dependants: A list containing all objects dependent on self.
   """

   def __init__(self, name, synchro=None, func=None, dependants=None, dependencies=None, tainted=None):
      """Initializes depend_base.

      An unusual initialization routine, as it has to be able to deal with the
      depend array mechanism for returning slices as new depend arrays. Such
      slices pass in the parent's tainted array and dependants list, which are
      then shared rather than copied.

      The final if statement taints objects created from scratch (or from a
      tainted parent) but does nothing to slices of an untainted parent. A
      primitive object (no func) is never left tainted itself, but its
      dependants are still tainted.

      add_synchro() only makes self the manually-updated property when self is
      not already registered with the synchronizer. As a result, slicing a
      synchronized property does not change which property counts as manually
      updated.

      Args:
         name: A string giving the name of self.
         synchro: An optional synchronizer object.
         func: An optional argument that can be specified either by a
            callable, or for synchronized values a dictionary of the form
            {"name": callable}; where "name" is one of the other synched
            objects and the callable computes self from the object "name".
         dependants: An optional list containing objects that depend on self.
         dependencies: An optional list containing objects that self
            depends upon.
         tainted: An optional array containing one boolean which is True if
            one of the dependencies has been changed. Defaults to [True].
      """

      self._dependants = []
      if tainted is None:
         tainted = np.array([True], bool)
      if dependants is None:
         dependants = []
      if dependencies is None:
         dependencies = []
      self._tainted = tainted
      self._func = func
      self._name = name

      self.add_synchro(synchro)

      # Register with every dependency. While this loop runs, self._dependants
      # is still the empty list assigned above, so tainting triggered here does
      # not reach the `dependants` passed in.
      for item in dependencies:
         item.add_dependant(self, tainted)

      self._dependants = dependants

      # Don't taint self if the object is a primitive one. However, do
      # propagate tainting to dependants if required.
      if tainted:
         if self._func is None:
            self.taint(taintme=False)
         else:
            self.taint(taintme=tainted)

   def add_synchro(self, synchro=None):
      """Links depend object to a synchronizer.

      If self is not yet registered under its name, it is added to
      synchro.synced and becomes the manually-updated property.

      Args:
         synchro: A synchronizer object, or None to unlink.
      """

      self._synchro = synchro
      if self._synchro is not None and self._name not in self._synchro.synced:
         self._synchro.synced[self._name] = self
         self._synchro.manual = self._name

   def add_dependant(self, newdep, tainted=True):
      """Adds a dependant property.

      Args:
         newdep: The depend object to be added to the dependency list.
         tainted: A boolean (or one-element boolean array) that decides
            whether newdep should be tainted. True by default.
      """

      self._dependants.append(newdep)
      if tainted:
         newdep.taint(taintme=True)

   def add_dependency(self, newdep, tainted=True):
      """Adds a dependency.

      Args:
         newdep: The depend object self now depends upon.
         tainted: A boolean that decides whether self should
            be tainted. True by default.
      """

      newdep._dependants.append(self)
      if tainted:
         self.taint(taintme=True)

   def taint(self, taintme=True):
      """Recursively sets tainted flag on dependent objects.

      The main function dealing with the dependencies. Taints all objects
      further down the dependency tree until either all objects have been
      tainted, or it reaches only objects that have already been tainted. Note
      that in the case of a dependency loop the initial setting of _tainted to
      True prevents an infinite loop occurring.

      Also, in the case of a synchro object, the manually set quantity is not
      tainted, as it is assumed that synchro objects only depend on each other.

      Args:
         taintme: A boolean giving whether self should be tainted at the end.
            True by default.
      """

      # Mark self first so that dependency loops terminate.
      self._tainted[:] = True
      for item in self._dependants:
         if not item._tainted[0]:
            item.taint()
      if self._synchro is not None:
         for v in self._synchro.synced.values():
            if (not v._tainted[0]) and (v is not self):
               v.taint(taintme=True)
         # The manually-set representation holds the authoritative value.
         self._tainted[:] = (taintme and (self._name != self._synchro.manual))
      else:
         self._tainted[:] = taintme

   def tainted(self):
      """Returns tainted flag."""

      return self._tainted[0]

   def update_auto(self):
      """Automatic update routine.

      Updates the value when get has been called and self has been tainted.
      Synchronized objects are recomputed from the manually-set object via
      _func[manual]; other objects are recomputed via _func(). Objects with
      nothing to recompute from only emit a warning.
      """

      if self._synchro is not None:
         if self._name != self._synchro.manual:
            self.set(self._func[self._synchro.manual](), manual=False)
         else:
            warning(self._name + " probably shouldn't be tainted (synchro)", verbosity.low)
      elif self._func is not None:
         self.set(self._func(), manual=False)
      else:
         warning(self._name + " probably shouldn't be tainted (value)", verbosity.low)

   def update_man(self):
      """Manual update routine.

      Updates the value when the value has been manually set. Also raises an
      exception if a calculated quantity has been manually set. Also starts the
      tainting routine.

      Raises:
         NameError: If a calculated quantity has been manually set.
      """

      if self._synchro is not None:
         # self becomes the authoritative representation; taint the others.
         self._synchro.manual = self._name
         for v in self._synchro.synced.values():
            v.taint(taintme=True)
         self._tainted[:] = False
      elif self._func is not None:
         raise NameError("Cannot set manually the value of the automatically-computed property <" + self._name + ">")
      else:
         self.taint(taintme=False)

   def set(self, value, manual=False):
      """Dummy setting routine."""

      pass

   def get(self):
      """Dummy getting routine."""

      pass
285
class depend_value(depend_base):
   """Scalar class for dependency handling.

   Attributes:
      _value: The value associated with self.
   """

   def __init__(self, name, value=None, synchro=None, func=None, dependants=None, dependencies=None, tainted=None):
      """Initializes depend_value.

      Args:
         name: A string giving the name of self.
         value: The value of the object. Optional.
         tainted: An optional array giving the tainted flag. Default is [True].
         func: An optional argument that can be specified either by a
            callable, or for synchronized values a dictionary of the form
            {"name": callable}; where "name" is one of the other synched
            objects and the callable computes self from the object "name".
         synchro: An optional synchronizer object.
         dependants: An optional list containing objects that depend on self.
         dependencies: An optional list containing objects that self
            depends upon.
      """

      self._value = value
      super(depend_value,self).__init__(name, synchro, func, dependants, dependencies, tainted)

   def get(self):
      """Returns value, after recalculating if necessary.

      Overwrites the standard method of getting value, so that value
      is recalculated if tainted.
      """

      if self._tainted[0]:
         self.update_auto()
         self.taint(taintme=False)

      return self._value

   def __get__(self, instance, owner):
      """Overwrites standard get function."""

      return self.get()

   def set(self, value, manual=True):
      """Alters value and taints dependencies.

      Overwrites the standard method of setting value, so that dependent
      quantities are tainted, and so we check that computed quantities are not
      manually updated.

      Args:
         value: The new value.
         manual: Whether the value is being set by the user (True, default)
            or by the automatic update machinery (False).

      Raises:
         NameError: If a computed (non-synchronized) quantity is set manually.
            In that case self is left unchanged.
      """

      # Reject before mutating anything: previously the rejected value was
      # stored and the tainted flag cleared before update_man() raised, so a
      # later get() returned the rejected value instead of recomputing it.
      if manual and self._synchro is None and self._func is not None:
         raise NameError("Cannot set manually the value of the automatically-computed property <" + self._name + ">")

      self._value = value
      self.taint(taintme=False)
      if manual:
         self.update_man()

   def __set__(self, instance, value):
      """Overwrites standard set function."""

      self.set(value)
349
350
class depend_array(np.ndarray, depend_base):
   """Array class for dependency handling.

   Differs from depend_value as arrays handle getting items in a different
   way to scalar quantities, and as there needs to be support for slicing an
   array. Initialisation is also done in a different way for ndarrays.

   Attributes:
      _bval: The base deparray storage space. Equal to depstrip(self) unless
         self is a slice.
   """

   def __new__(cls, value, name, synchro=None, func=None, dependants=None, dependencies=None, tainted=None, base=None):
      """Creates a new array from a template.

      Called whenever a new instance of depend_array is created. Casts the
      array base into an appropriate form before passing it to
      __array_finalize__().

      Args:
         See __init__().
      """

      obj = np.asarray(value).view(cls)
      return obj

   def __init__(self, value, name, synchro=None, func=None, dependants=None, dependencies=None, tainted=None, base=None):
      """Initializes depend_array.

      Note that this is only called when a new array is created by an
      explicit constructor.

      Args:
         name: A string giving the name of self.
         value: The (numpy) array to serve as the memory base.
         tainted: An optional array giving the tainted flag. Default is [True].
         func: An optional argument that can be specified either by a
            callable, or for synchronized values a dictionary of the form
            {"name": callable}; where "name" is one of the other synched
            objects and the callable computes self from the object "name".
         synchro: An optional synchronizer object.
         dependants: An optional list containing objects that depend on self.
         dependencies: An optional list containing objects that self
            depends upon.
         base: Optional storage of the parent array when self is a slice or
            reshaped view of it. Defaults to value.
      """

      super(depend_array,self).__init__(name, synchro, func, dependants, dependencies, tainted)

      if base is None:
         self._bval = value
      else:
         self._bval = base

   def copy(self, order='C', maskna=None):
      """Wrapper for numpy copy mechanism.

      maskna is accepted for signature compatibility and ignored.
      """

      # Sets a flag and hands control to the numpy copy; __array_finalize__
      # consumes the flag to know that it is building an independent copy.
      self._fcopy = True
      return super(depend_array,self).copy(order)

   def __array_finalize__(self, obj):
      """Deals with properly creating some arrays.

      In the case where a function acting on a depend array returns a ndarray,
      this casts it into the correct form and gives it the
      depend machinery for other methods to be able to act upon it. New
      depend_arrays will next be passed to __init__() to be properly
      initialized, but some ways of creating arrays do not call __new__() or
      __init__(), so need to be initialized.
      """

      depend_base.__init__(self, name="")

      if type(obj) is depend_array:
         # We are in a view cast or in new from template. Unfortunately
         # there is no sure way to tell (or so it seems). Hence we need to
         # handle special cases, and hope we are in a view cast otherwise.
         if hasattr(obj,"_fcopy"):
            del obj._fcopy # removes the "copy flag"
            self._bval = depstrip(self)
         else:
            # Assumes we are in view cast, so copy over the attributes from the
            # parent object. Typical case: when transpose is performed as a
            # view.
            super(depend_array,self).__init__(obj._name, obj._synchro, obj._func, obj._dependants, None, obj._tainted)
            self._bval = obj._bval
      else:
         # Most likely we came here on the way to init.
         # Just sets a defaults for safety
         self._bval = depstrip(self)

   def __array_prepare__(self, arr, context=None):
      """Prepare output array for ufunc.

      Depending on the context we try to understand if we are doing an
      in-place operation (in which case we want to keep the return value a
      deparray) or we are generating a new array as a result of the ufunc.
      In this case there is no way to know if dependencies should be copied,
      so we strip and return a ndarray.
      """

      # NOTE(review): newer NumPy releases no longer call __array_prepare__;
      # confirm which NumPy versions this hook is expected to serve.
      if context is None or len(context) < 2 or not type(context[0]) is np.ufunc:
         # It is not clear what we should do. If in doubt, strip dependencies.
         return np.ndarray.__array_prepare__(self.view(np.ndarray),arr.view(np.ndarray),context)
      elif len(context[1]) > context[0].nin and context[0].nout > 0:
         # We are being called by a ufunc with a output argument, which is being
         # actually used. Most likely, something like an increment,
         # so we pass on a deparray
         return super(depend_array,self).__array_prepare__(arr,context)
      else:
         # Apparently we are generating a new array.
         # We have no way of knowing its
         # dependencies, so we'd better return a ndarray view!
         return np.ndarray.__array_prepare__(self.view(np.ndarray),arr.view(np.ndarray),context)

   def __array_wrap__(self, arr, context=None):
      """ Wraps up output array from ufunc.

      See docstring of __array_prepare__().
      """

      if context is None or len(context) < 2 or not type(context[0]) is np.ufunc:
         return np.ndarray.__array_wrap__(self.view(np.ndarray),arr.view(np.ndarray),context)
      elif len(context[1]) > context[0].nin and context[0].nout > 0:
         return super(depend_array,self).__array_wrap__(arr,context)
      else:
         return np.ndarray.__array_wrap__(self.view(np.ndarray),arr.view(np.ndarray),context)

   # whenever possible in compound operations just return a regular ndarray
   __array_priority__ = -1.0

   def reshape(self, newshape):
      """Changes the shape of the base array.

      The result shares data, dependants, tainted flag and base storage with
      self.

      Args:
         newshape: A tuple giving the desired shape of the new array.

      Returns:
         A depend_array with the dimensions given by newshape.
      """

      return depend_array(depstrip(self).reshape(newshape), name=self._name, synchro=self._synchro, func=self._func, dependants=self._dependants, tainted=self._tainted, base=self._bval)

   def flatten(self):
      """Makes the base array one dimensional.

      Returns:
         A flattened array (a reshaped depend_array, see reshape()).
      """

      return self.reshape(self.size)

   @staticmethod
   def __scalarindex(index, depth=1):
      """Checks if an index points at a scalar value.

      Used so that looking up one item in an array returns a scalar, whereas
      looking up a slice of the array returns a new array with the same
      dependencies as the original, so that changing the slice also taints
      the global array.

      Arguments:
         index: the index to be checked.
         depth: the rank of the array which is being accessed. Default value
            is 1.

      Returns:
         A logical stating whether a __get__ instruction based
         on index would return a scalar.
      """

      if (np.isscalar(index) and depth <= 1):
         return True
      elif (isinstance(index, tuple) and len(index)==depth):
         # if the index is a tuple check it does not contain slices
         for i in index:
            if not np.isscalar(i): return False
         return True
      return False

   def __getitem__(self,index):
      """Returns value[index], after recalculating if necessary.

      Overwrites the standard method of getting value, so that value
      is recalculated if tainted. Scalar slices are returned as an ndarray,
      so without depend machinery. If you need a "scalar depend" which
      behaves as a slice, just take a one-element slice, e.g. b = a[7,1:2].

      Args:
         index: A slice variable giving the appropriate slice to be read.
      """

      if self._tainted[0]:
         self.update_auto()
         self.taint(taintme=False)

      if (self.__scalarindex(index, self.ndim)):
         return depstrip(self)[index]
      else:
         return depend_array(depstrip(self)[index], name=self._name, synchro=self._synchro, func=self._func, dependants=self._dependants, tainted=self._tainted, base=self._bval)

   def __getslice__(self,i,j):
      """Overwrites standard get function."""

      return self.__getitem__(slice(i,j,None))

   def get(self):
      """Alternative to standard get function.

      Returns self, after recalculating it if tainted (see __get__()).
      """

      # __get__ takes (instance, owner) and ignores both. The previous call
      # passed a single slice argument and therefore always raised TypeError.
      return self.__get__(None, None)

   def __get__(self, instance, owner):
      """Overwrites standard get function."""

      # It is worth duplicating this code that is also used in __getitem__ as this
      # is called most of the time, and we avoid creating a load of copies pointing to the same depend_array

      if self._tainted[0]:
         self.update_auto()
         self.taint(taintme=False)

      return self

   def __setitem__(self,index,value,manual=True):
      """Alters value[index] and taints dependencies.

      Overwrites the standard method of setting value, so that dependent
      quantities are tainted, and so we check that computed quantities are not
      manually updated.

      Args:
         index: A slice variable giving the appropriate slice to be written.
         value: The new value of the slice.
         manual: Optional boolean giving whether the value has been changed
            manually. True by default.

      Raises:
         NameError: If a computed (non-synchronized) array is set manually.
            The array is left unchanged in that case.
         IndexError: If an automatic update does not span the whole array.
      """

      # Reject manual writes to computed arrays before touching the data or
      # the tainted flag; otherwise the rejected data would stay cached as if
      # it were up to date.
      if manual and self._synchro is None and self._func is not None:
         raise NameError("Cannot set manually the value of the automatically-computed property <" + self._name + ">")

      self.taint(taintme=False)
      if manual:
         depstrip(self)[index] = value
         self.update_man()
      elif index == slice(None,None,None):
         self._bval[index] = value
      else:
         raise IndexError("Automatically computed arrays should span the whole parent")

   def __setslice__(self,i,j,value):
      """Overwrites standard set function."""

      return self.__setitem__(slice(i,j),value)

   def set(self, value, manual=True):
      """Alterative to standard set function.

      Args:
         See __setitem__().
      """

      self.__setitem__(slice(None,None),value=value,manual=manual)

   def __set__(self, instance, value):
      """Overwrites standard set function."""

      self.__setitem__(slice(None,None),value=value)
618
619
# np.dot and other numpy.linalg functions have the nasty habit to
# view cast to generate the output. Since we don't want to pass on
# dependencies to the result of these functions, and we can't use
# the ufunc mechanism to demote the class type to ndarray, we must
# overwrite np.dot and other similar functions.
# BEGINS NUMPY FUNCTIONS OVERRIDE
# ** np.dot
__dp_dot = np.dot

def dep_dot(da, db, *args, **kwargs):
   """Dependency-stripping replacement for np.dot.

   Both operands are passed through depstrip() before being handed to the
   original numpy dot. As a result, numpy never sees a depend_array operand
   whose dependencies it could view cast into the result.

   Args:
      da: First operand. A depend_array is stripped; anything else is passed
         through unchanged.
      db: Second operand, treated like da.
      *args, **kwargs: Forwarded unchanged to the original np.dot (e.g. out=).

   Returns:
      Whatever the original np.dot returns for the stripped operands.
   """

   # Previously the stripped operands were computed but the original,
   # unstripped arrays were passed on, defeating the purpose of the override.
   return __dp_dot(depstrip(da), depstrip(db), *args, **kwargs)

np.dot = dep_dot
# ENDS NUMPY FUNCTIONS OVERRIDE
637
def dget(obj,member):
   """Fetches a depend object itself from the instance that owns it.

   Reading obj.member normally would trigger member's __get__() and return
   its value. This helper instead looks the name up directly in the
   instance dictionary, handing back the depend object unchanged.

   Args:
      obj: A user defined class instance.
      member: A string giving the name of an attribute of obj.

   Exceptions:
      KeyError: If member is not stored in the instance dictionary of obj.

   Returns:
      The object stored under member, bypassing __get__().
   """

   instance_attrs = obj.__dict__
   found = instance_attrs[member]
   return found
656
def dset(obj,member,value,name=None):
   """Takes an object and sets one of its attributes.

   Necessary for editing any depend object, and should be used for
   initialising them as well, as often initialization occurs more than once,
   with the second time effectively being an edit. The value is stored
   directly in the instance dictionary, so any __set__() of an existing
   member is bypassed.

   Args:
      obj: A user defined class instance.
      member: A string giving the name of an attribute of obj. It need not
         exist yet.
      value: The new value of member.
      name: Optional new name for the stored object; when given, it is
         written to value._name.
   """

   obj.__dict__[member] = value
   if name is not None:
      # value is exactly what was just stored under member.
      value._name = name
677
def depstrip(da):
   """Removes dependencies from a depend_array.

   Takes a depend_array and returns its value as a ndarray view, effectively
   stripping the dependencies. This speeds up a lot of calculations involving
   these arrays. Must only be used if the value of the array is not going to
   be changed: writes through the returned plain ndarray do not go through
   depend_array.__setitem__, so nothing gets tainted.

   The tainted flag is not checked here. It is assumed that callers have
   already brought the array up to date; TODO confirm for all call sites.

   Args:
      da: A depend_array, or any other object.

   Returns:
      If da is a depend_array, a ndarray view sharing memory with da;
      otherwise da itself, unchanged.
   """

   # Only bother to strip dependencies if the array actually IS a depend_array.
   if isinstance(da, depend_array):
      return da.view(np.ndarray)
   return da
699
def deppipe(objfrom,memberfrom,objto,memberto):
   """Slaves one depend object to another so both carry the same value.

   The target's computing function is replaced by one that reads the source,
   and the target is registered as a dependant of the source. This is meant
   for quantities such as the temperature that live as separate depend
   objects in several modules but must always agree.

   Args:
      objfrom: The object holding the source depend object.
      memberfrom: Attribute name of the source depend object in objfrom.
      objto: The object holding the target depend object.
      memberto: Attribute name of the depend object in objto that should
         follow the source.
   """

   source = dget(objfrom, memberfrom)
   target = dget(objto, memberto)

   def _pull():
      # Re-read the source on every evaluation of the target.
      return source.get()

   target._func = _pull
   target.add_dependency(source)
719
def depcopy(objfrom,memberfrom,objto,memberto):
   """Makes one depend object share the dependency state of another.

   The target takes over the source's dependants list, synchronizer, tainted
   array, computing function and, when present, base storage. Shared
   containers are shared by reference, not copied.

   Args:
      objfrom: The object holding the source depend object.
      memberfrom: Attribute name of the source depend object in objfrom.
      objto: The object holding the target depend object.
      memberto: Attribute name of the target depend object in objto.
   """

   source = dget(objfrom, memberfrom)
   target = dget(objto, memberto)
   target._dependants = source._dependants
   # add_synchro() assigns target._synchro itself before registering.
   target.add_synchro(source._synchro)
   target._tainted = source._tainted
   target._func = source._func
   if hasattr(source, "_bval"):
      target._bval = source._bval
735
736
class dobject(object):
   """Base class that lets depend objects be used with plain attribute syntax.

   Instance attributes implementing __get__ / __set__ are routed through
   those methods, even though ordinary Python only honours them on class
   attributes. So `foo.x = v` calls x.__set__ and `foo.x` calls x.__get__.
   """

   def __getattribute__(self, name):
      """Looks up name, then defers to its __get__() when it has one.

      This changes the standard __getattribute__() function of any class that
      subclasses dobject such that depend objects are called with their own
      __get__() function rather than the standard one.
      """

      attr = object.__getattribute__(self, name)
      if not hasattr(attr, '__get__'):
         return attr
      return attr.__get__(self, self.__class__)

   def __setattr__(self, name, value):
      """Assigns name, deferring to an existing attribute's __set__() if any.

      This changes the standard __setattr__() function of any class that
      subclasses dobject such that depend objects are called with their own
      __set__() function rather than the standard one.
      """

      try:
         existing = object.__getattribute__(self, name)
      except AttributeError:
         has_setter = False
      else:
         has_setter = hasattr(existing, '__set__')

      if has_setter:
         return existing.__set__(self, value)
      return object.__setattr__(self, name, value)
769