1# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
2# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
3#
4# MDAnalysis --- https://www.mdanalysis.org
5# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
6# (see the file AUTHORS for the full list of names)
7#
8# Released under the GNU Public Licence, v2 or any higher version
9#
10# Please cite your use of MDAnalysis in published work:
11#
12# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler,
13# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein.
14# MDAnalysis: A Python package for the rapid analysis of molecular dynamics
15# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th
16# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy.
17#
18# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein.
19# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
20# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
21#
22"""
23Analysis building blocks --- :mod:`MDAnalysis.analysis.base`
24============================================================
25
26A collection of useful building blocks for creating Analysis
27classes.
28
29"""
30from __future__ import absolute_import
31import six
32from six.moves import range, zip
33import inspect
34import logging
35import warnings
36
37import numpy as np
38from MDAnalysis import coordinates
39from MDAnalysis.core.groups import AtomGroup
40from MDAnalysis.lib.log import ProgressMeter
41
42logger = logging.getLogger(__name__)
43
44
class AnalysisBase(object):
    """Base class for defining multi frame analysis

    The class is designed as a template for creating multiframe analyses.
    It takes care of setting up the trajectory reader for iterating, and it
    offers to show a progress meter.

    To define a new analysis, `AnalysisBase` needs to be subclassed and
    `_single_frame` must be defined. It is also possible to define
    `_prepare` and `_conclude` for pre and post processing. See the example
    below.

    .. code-block:: python

       class NewAnalysis(AnalysisBase):
           def __init__(self, atomgroup, parameter, **kwargs):
               super(NewAnalysis, self).__init__(atomgroup.universe.trajectory,
                                                 **kwargs)
               self._parameter = parameter
               self._ag = atomgroup

           def _prepare(self):
               # OPTIONAL
               # Called before iteration on the trajectory has begun.
               # Data structures can be set up at this time
               self.result = []

           def _single_frame(self):
               # REQUIRED
               # Called after the trajectory is moved onto each new frame.
               # store result of `some_function` for a single frame
               self.result.append(some_function(self._ag, self._parameter))

           def _conclude(self):
               # OPTIONAL
               # Called once iteration on the trajectory is finished.
               # Apply normalisation and averaging to results here.
               self.result = np.asarray(self.result) / np.sum(self.result)

    Afterwards the new analysis can be run like this.

    .. code-block:: python

       na = NewAnalysis(u.select_atoms('name CA'), 35).run(start=10, stop=20)
       print(na.result)

    While :meth:`run` is iterating, the current timestep is available as
    ``self._ts`` and the zero-based index of the analysed frame as
    ``self._frame_index``.

    """

    def __init__(self, trajectory, verbose=False, **kwargs):
        """
        Parameters
        ----------
        trajectory : mda.Reader
            A trajectory Reader
        verbose : bool, optional
           Turn on more logging and debugging, default ``False``
        **kwargs : dict
           Only the deprecated ``start``, ``stop`` and ``step`` keywords are
           looked at; a value that is not ``None`` is stored as an attribute
           of the same name and a :class:`DeprecationWarning` is issued.
           All other keywords are ignored.
        """
        self._trajectory = trajectory
        self._verbose = verbose
        # do deprecated kwargs
        # remove in 1.0
        deps = []
        for arg in ['start', 'stop', 'step']:
            if kwargs.get(arg) is not None:
                deps.append(arg)
                setattr(self, arg, kwargs[arg])
        if deps:
            warnings.warn('Setting the following kwargs should be '
                          'done in the run() method: {}'.format(
                              ', '.join(deps)),
                          DeprecationWarning)

    def _setup_frames(self, trajectory, start=None, stop=None, step=None,
                      verbose=None):
        """
        Pass a Reader object and define the desired iteration pattern
        through the trajectory

        Sets ``self.start``, ``self.stop``, ``self.step`` (as returned by
        ``trajectory.check_slice_indices``), ``self.n_frames`` and the
        progress meter ``self._pm``.

        Parameters
        ----------
        trajectory : mda.Reader
            A trajectory Reader
        start : int, optional
            start frame of analysis
        stop : int, optional
            stop frame of analysis
        step : int, optional
            number of frames to skip between each analysed frame
        verbose : bool, optional
            verbosity of the progress meter; ``None`` (default) falls back
            to the value given when the instance was created

        Note
        ----
        An argument that is ``None`` falls back to the attribute of the same
        name if it exists (set through the deprecated constructor keywords
        or by an earlier call of this method). Explicitly given values
        always take precedence, so that an analysis can be re-run on a
        different range of frames.
        """
        self._trajectory = trajectory
        # TODO: Remove once start/stop/step are deprecated from init
        # Only fall back to stored attributes for values that were not
        # given explicitly; otherwise a second run() could never change
        # the frames selected by the first one.
        if start is None:
            start = getattr(self, 'start', None)
        if stop is None:
            stop = getattr(self, 'stop', None)
        if step is None:
            step = getattr(self, 'step', None)
        start, stop, step = trajectory.check_slice_indices(start, stop, step)
        self.start = start
        self.stop = stop
        self.step = step
        self.n_frames = len(range(start, stop, step))
        # report progress roughly every percent, but at least every frame
        interval = max(self.n_frames // 100, 1)

        if verbose is None:
            verbose = getattr(self, '_verbose', False)
        # the meter is always created for at least one step, even when no
        # frame is selected
        self._pm = ProgressMeter(self.n_frames if self.n_frames else 1,
                                 interval=interval, verbose=verbose)

    def _single_frame(self):
        """Calculate data from a single frame of trajectory

        Don't worry about normalising, just deal with a single frame.

        Raises
        ------
        NotImplementedError
            always; child classes have to provide the implementation
        """
        raise NotImplementedError("Only implemented in child classes")

    def _prepare(self):
        """Set things up before the analysis loop begins"""
        pass

    def _conclude(self):
        """Finalise the results you've gathered.

        Called at the end of the run() method to finish everything up.
        """
        pass

    def run(self, start=None, stop=None, step=None, verbose=None):
        """Perform the calculation

        Parameters
        ----------
        start : int, optional
            start frame of analysis
        stop : int, optional
            stop frame of analysis
        step : int, optional
            number of frames to skip between each analysed frame
        verbose : bool, optional
            Turn on verbosity of the progress meter for this run; ``None``
            (default) uses the value given when the instance was created

        Returns
        -------
        self
            the instance itself, so that calls can be chained
        """
        logger.info("Choosing frames to analyze")
        # if verbose unchanged, use class default
        verbose = getattr(self, '_verbose', False) if verbose is None else verbose

        self._setup_frames(self._trajectory, start, stop, step,
                           verbose=verbose)
        logger.info("Starting preparation")
        self._prepare()
        for i, ts in enumerate(
                self._trajectory[self.start:self.stop:self.step]):
            # expose the position in the analysis and the current timestep
            # to _single_frame()
            self._frame_index = i
            self._ts = ts
            self._single_frame()
            self._pm.echo(self._frame_index)
        logger.info("Finishing up")
        self._conclude()
        return self
201
202
class AnalysisFromFunction(AnalysisBase):
    """
    Create an analysis from a function working on AtomGroups

    The function is called once per analysed frame with the stored
    positional and keyword arguments; the return values are collected.

    Attributes
    ----------
    results : ndarray
        results of calculation are stored after call to ``run``

    Example
    -------
    >>> def rotation_matrix(mobile, ref):
    >>>     return mda.analysis.align.rotation_matrix(mobile, ref)[0]

    >>> rot = AnalysisFromFunction(rotation_matrix, trajectory, mobile, ref).run()
    >>> print(rot.results)

    Raises
    ------
    ValueError : if no trajectory is given and none can be found through an
                 AtomGroup in the arguments
    """

    def __init__(self, function, trajectory=None, *args, **kwargs):
        """Parameters
        ----------
        function : callable
            function to evaluate at each frame
        trajectory : mda.coordinates.Reader (optional)
            trajectory to iterate over. If ``None`` the first AtomGroup found in
            args and kwargs is used as a source for the trajectory. If the
            value is not a Reader it is taken to be the first positional
            argument of ``function``.
        *args : list
           arguments for ``function``
        **kwargs : dict
           arguments for ``function``; the deprecated ``start``, ``stop``
           and ``step`` keywords are removed and passed on to
           ``AnalysisBase`` instead

        """
        if (trajectory is not None) and (not isinstance(
                trajectory, coordinates.base.ProtoReader)):
            # The second positional argument is not a Reader, so it is
            # really the first argument of `function`: put it back in
            # front to keep the argument order given by the caller.
            args = (trajectory,) + args
            trajectory = None

        if trajectory is None:
            # positional arguments are searched before keyword arguments
            for candidate in list(args) + list(kwargs.values()):
                if isinstance(candidate, AtomGroup):
                    trajectory = candidate.universe.trajectory
                    break

        if trajectory is None:
            raise ValueError("Couldn't find a trajectory")

        self.function = function
        self.args = args

        # TODO: Remove in 1.0
        my_kwargs = {}
        for depped_arg in ['start', 'stop', 'step']:
            if depped_arg in kwargs:
                my_kwargs[depped_arg] = kwargs.pop(depped_arg)
        # everything that is left belongs to `function`
        self.kwargs = kwargs

        super(AnalysisFromFunction, self).__init__(trajectory, **my_kwargs)

    def _prepare(self):
        """Start with an empty list of per-frame results"""
        self.results = []

    def _single_frame(self):
        """Evaluate ``function`` and store its return value"""
        self.results.append(self.function(*self.args, **self.kwargs))

    def _conclude(self):
        """Convert the collected per-frame results to an array"""
        self.results = np.asarray(self.results)
277
278
def analysis_class(function):
    """
    Build an analysis class around a function that works on a single frame

    The returned class derives from :class:`AnalysisFromFunction` and always
    passes `function` as the callable, so that instances are created from
    the (optional) trajectory and the arguments of `function` only.

    When used inside a library we recommend the following style:

    >>> def rotation_matrix(mobile, ref):
    >>>     return mda.analysis.align.rotation_matrix(mobile, ref)[0]
    >>> RotationMatrix = analysis_class(rotation_matrix)

    The same can be written with decorator syntax:

    >>> @analysis_class
    >>> def RotationMatrix(mobile, ref):
    >>>     return mda.analysis.align.rotation_matrix(mobile, ref)[0]

    >>> rot = RotationMatrix(u.trajectory, mobile, ref).run(step=2)
    >>> print(rot.results)

    Parameters
    ----------
    function : callable
        function to evaluate at each frame

    Returns
    -------
    class
        subclass of :class:`AnalysisFromFunction` bound to `function`
    """

    class WrapperClass(AnalysisFromFunction):
        def __init__(self, trajectory=None, *args, **kwargs):
            # `function` comes from the enclosing scope and is handed to
            # the parent constructor in front of the caller's arguments
            parent = super(WrapperClass, self)
            parent.__init__(function, trajectory, *args, **kwargs)

    return WrapperClass
305
306
def _filter_baseanalysis_kwargs(function, kwargs):
    """
    create two dictionaries with kwargs separated for function and AnalysisBase

    The keyword arguments of ``AnalysisBase.__init__`` that have a default
    value are looked up by introspection. Each of them is removed from
    `kwargs` (the default is used if it is not present) and collected in a
    separate dictionary.

    Parameters
    ----------
    function : callable
        function to be called
    kwargs : dict
        keyword argument dictionary; modified in place

    Returns
    -------
    base_args : dict
        dictionary of AnalysisBase kwargs
    kwargs : dict
        kwargs without AnalysisBase kwargs (the same object that was
        passed in)

    Raises
    ------
    ValueError : if ``function`` has the same kwargs as ``AnalysisBase``
    """
    try:
        # pylint: disable=deprecated-method
        getargspec = inspect.getfullargspec
    except AttributeError:
        # Python 2 only provides getargspec
        # pylint: disable=deprecated-method
        getargspec = inspect.getargspec

    base_argspec = getargspec(AnalysisBase.__init__)
    # `defaults` is None for a signature without any default values
    base_defaults = base_argspec.defaults or ()
    base_kwargs = {}
    if base_defaults:
        # defaults belong to the last len(defaults) positional arguments
        base_kwargs = dict(zip(base_argspec.args[-len(base_defaults):],
                               base_defaults))

    argspec = getargspec(function)
    function_argnames = set(argspec.args)
    # keyword-only arguments (Python 3) are reported separately
    function_argnames.update(getattr(argspec, 'kwonlyargs', None) or ())

    for base_kw in base_kwargs:
        if base_kw in function_argnames:
            raise ValueError(
                "argument name '{}' clashes with AnalysisBase argument. "
                "Not allowed are: {}".format(base_kw, sorted(base_kwargs)))

    base_args = {}
    for argname, default in base_kwargs.items():
        base_args[argname] = kwargs.pop(argname, default)

    return base_args, kwargs
359