#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015, 2016, 2017 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from collections import OrderedDict
import itertools
import sys

import backtrader as bt
from .utils.py3 import zip, string_types, with_metaclass
30
31
def findbases(kls, topclass):
    '''
    Return the bases of ``kls`` which are subclasses of ``topclass``.

    The class hierarchy is walked recursively through ``__bases__``. For
    each qualifying direct base, its own qualifying bases are listed first
    and the base itself afterwards, so ancestors precede descendants.
    ``kls`` itself is never part of the result.
    '''
    found = []
    for parent in kls.__bases__:
        if not issubclass(parent, topclass):
            continue  # branch outside of the hierarchy of interest

        found += findbases(parent, topclass)
        found.append(parent)

    return found
40
41
def findowner(owned, cls, startlevel=2, skip=None):
    '''
    Walk up the call stack looking for an instance of ``cls``.

    Starting at frame depth ``startlevel`` (the default of 2 skips this
    function's frame and that of its caller), the locals of each frame are
    inspected under two names, in this order:

      - ``self``: the name used in regular methods
      - ``_obj``: the name used in the metaclass machinery

    The first value found which is an instance of ``cls`` and is neither
    ``owned`` nor ``skip`` is returned. ``None`` is returned if the stack
    is exhausted without a match.
    '''
    level = startlevel
    while True:
        try:
            frame = sys._getframe(level)
        except ValueError:
            # Frame depth exceeded ... there is no owner
            return None

        for localname in ('self', '_obj'):
            candidate = frame.f_locals.get(localname, None)
            if candidate is skip or candidate is owned:
                continue

            if isinstance(candidate, cls):
                return candidate

        level += 1
64
65
class MetaBase(type):
    '''
    Metaclass which splits instance creation into overridable stages.

    Calling a class runs, in this order: ``doprenew``, ``donew``,
    ``dopreinit``, ``doinit`` and ``dopostinit``. Each stage returns the
    (possibly replaced) subject plus the positional and keyword arguments
    to be handed over to the next stage.
    '''

    def doprenew(cls, *args, **kwargs):
        '''Stage before creation. Returns ``cls`` and arguments untouched'''
        return cls, args, kwargs

    def donew(cls, *args, **kwargs):
        '''Creates the instance with ``cls.__new__`` (no ``__init__``)'''
        return cls.__new__(cls, *args, **kwargs), args, kwargs

    def dopreinit(cls, _obj, *args, **kwargs):
        '''Stage before ``__init__``. Returns everything untouched'''
        return _obj, args, kwargs

    def doinit(cls, _obj, *args, **kwargs):
        '''Invokes ``__init__`` on the instance with the arguments'''
        _obj.__init__(*args, **kwargs)
        return _obj, args, kwargs

    def dopostinit(cls, _obj, *args, **kwargs):
        '''Stage after ``__init__``. Returns everything untouched'''
        return _obj, args, kwargs

    def __call__(cls, *args, **kwargs):
        '''Runs all stages and returns the resulting instance'''
        cls, args, kwargs = cls.doprenew(*args, **kwargs)

        # The local is named '_obj' on purpose: findowner looks this name
        # up in the locals of the frames in the call stack
        _obj, args, kwargs = cls.donew(*args, **kwargs)
        for stagename in ('dopreinit', 'doinit', 'dopostinit'):
            stage = getattr(cls, stagename)
            _obj, args, kwargs = stage(_obj, *args, **kwargs)

        return _obj
91
92
class AutoInfoClass(object):
    '''
    Base for dynamically derived classes carrying an ordered set of
    name/value pairs.

    The pairs are available as class attributes of the derived class and
    through the ``_getpairs`` family of classmethods. New classes are
    created with ``_derive``. This root class carries no pairs.
    '''
    # Pairs coming from the bases, without the additions of the class
    _getpairsbase = classmethod(lambda cls: OrderedDict())
    # Complete set of pairs: those of the bases plus the additions
    _getpairs = classmethod(lambda cls: OrderedDict())
    # True if the values have themselves been derived as info classes
    _getrecurse = classmethod(lambda cls: False)

    @classmethod
    def _derive(cls, name, info, otherbases, recurse=False):
        '''
        Create, register and return a subclass of ``cls`` with extra pairs.

          - ``name``: suffix for the name of the new class, which is
            ``cls.__name__ + '_' + name`` (plus a numeric suffix if that
            name is already taken in the module of ``cls``)

          - ``info``: new pairs, in any form accepted by
            ``OrderedDict.update``. They override pairs with the same name
            coming from ``cls`` or ``otherbases``

          - ``otherbases``: iterable in which each element is either a
            tuple/dict of pairs or an object offering ``_getpairs()``

          - ``recurse``: if ``True`` each value to be added is itself
            turned into a derived class (its value being the ``info`` of
            that class) and instances of the new class get an instance of
            each of them (see ``__new__``)

        The new class is also set as an attribute of the module in which
        ``cls`` is defined.
        '''
        # collect the 3 set of infos
        baseinfo = cls._getpairs().copy()
        obasesinfo = OrderedDict()
        for obase in otherbases:
            if isinstance(obase, (tuple, dict)):
                obasesinfo.update(obase)
            else:
                obasesinfo.update(obase._getpairs())

        # update the info of this class (base) with that from the other bases
        baseinfo.update(obasesinfo)

        # The info of the new class is a copy of the full base info
        # plus an update from the parameter
        clsinfo = baseinfo.copy()
        clsinfo.update(info)

        # The new items to update/set are those from the otherbase plus the new
        info2add = obasesinfo.copy()
        info2add.update(info)

        clsmodule = sys.modules[cls.__module__]
        basename = str(cls.__name__ + '_' + name)  # str - Python 2/3 compat

        # This loop makes sure that if the name has already been defined, a new
        # unique name is found. A collision example is in the plotlines names
        # definitions of bt.indicators.MACD and bt.talib.MACD. Both end up
        # defining a MACD_pl_macd and this makes it impossible for the pickle
        # module to send results over a multiprocessing channel.
        # Each candidate is built from the base name, giving name1, name2,
        # name3 ... rather than piling up suffixes (name1, name12, name123)
        newclsname = basename
        namecounter = 1
        while hasattr(clsmodule, newclsname):
            newclsname = basename + str(namecounter)
            namecounter += 1

        newcls = type(newclsname, (cls,), {})
        setattr(clsmodule, newclsname, newcls)

        # The accessors hand out copies to keep the stored pairs unchanged
        setattr(newcls, '_getpairsbase',
                classmethod(lambda cls: baseinfo.copy()))
        setattr(newcls, '_getpairs', classmethod(lambda cls: clsinfo.copy()))
        setattr(newcls, '_getrecurse', classmethod(lambda cls: recurse))

        for infoname, infoval in info2add.items():
            if recurse:
                # Derive from an already existing attribute with the same
                # name (if any) or else from the root info class
                recursecls = getattr(newcls, infoname, AutoInfoClass)
                infoval = recursecls._derive(name + '_' + infoname,
                                             infoval,
                                             [])

            setattr(newcls, infoname, infoval)

        return newcls

    def isdefault(self, pname):
        '''Returns ``True`` if ``pname`` currently equals its default'''
        return self._get(pname) == self._getkwargsdefault()[pname]

    def notdefault(self, pname):
        '''Returns ``True`` if ``pname`` currently differs from its default'''
        return self._get(pname) != self._getkwargsdefault()[pname]

    def _get(self, name, default=None):
        '''Returns attribute ``name`` or ``default`` if it is not present'''
        return getattr(self, name, default)

    @classmethod
    def _getkwargsdefault(cls):
        '''Returns the name/default pairs as an ``OrderedDict``'''
        return cls._getpairs()

    @classmethod
    def _getkeys(cls):
        '''Returns the names of the pairs (``keys()`` of the pairs)'''
        return cls._getpairs().keys()

    @classmethod
    def _getdefaults(cls):
        '''Returns the default values as a list'''
        return list(cls._getpairs().values())

    @classmethod
    def _getitems(cls):
        '''Returns the name/default pairs (``items()`` of the pairs)'''
        return cls._getpairs().items()

    @classmethod
    def _gettuple(cls):
        '''Returns the name/default pairs as a tuple of tuples'''
        return tuple(cls._getpairs().items())

    def _getkwargs(self, skip_=False):
        '''
        Returns the current name/value pairs as an ``OrderedDict``.

        If ``skip_`` is ``True``, names starting with ``_`` are left out.
        '''
        pairs = [
            (key, getattr(self, key))
            for key in self._getkeys()
            if not skip_ or not key.startswith('_')]
        return OrderedDict(pairs)

    def _getvalues(self):
        '''Returns the current values as a list'''
        return [getattr(self, key) for key in self._getkeys()]

    def __new__(cls, *args, **kwargs):
        '''
        Creates the instance. For classes derived with ``recurse=True``
        each name gets an instance of the class stored under that name.
        '''
        obj = super(AutoInfoClass, cls).__new__(cls, *args, **kwargs)

        if cls._getrecurse():
            for infoname in obj._getkeys():
                recursecls = getattr(cls, infoname)
                setattr(obj, infoname, recursecls())

        return obj
201
202
class MetaParams(MetaBase):
    '''
    Metaclass adding ``params``, ``packages`` and ``frompackages`` support.

    At class creation the three declarations are removed from the class
    body and merged with those reachable through the bases. At instance
    creation the declared packages are imported into the module of the
    class and a ``params`` instance is filled from the keyword arguments.
    '''

    def __new__(meta, name, bases, dct):
        '''Creates the class and stores the merged declarations in it'''
        # Remove the declarations from the class definition before creation
        # to avoid inheritance (and hence "repetition")
        ownparams = dct.pop('params', ())
        ownpackages = tuple(dct.pop('packages', ()))
        ownfrompackages = tuple(dct.pop('frompackages', ()))

        # Create the new class - lookups now deliver what the bases define
        cls = super(MetaParams, meta).__new__(meta, name, bases, dct)

        # Inherited definitions - the defaults are the empty ones
        baseparams = getattr(cls, 'params', AutoInfoClass)
        packages = tuple(getattr(cls, 'packages', ()))
        frompackages = tuple(getattr(cls, 'frompackages', ()))

        # Extra (to the right) base classes contribute their definitions
        rightbases = bases[1:]
        extraparams = [b.params for b in rightbases if hasattr(b, 'params')]

        for base in rightbases:
            if hasattr(base, 'packages'):
                packages += tuple(base.packages)

        for base in rightbases:
            if hasattr(base, 'frompackages'):
                frompackages += tuple(base.frompackages)

        cls.packages = packages + ownpackages
        cls.frompackages = frompackages + ownfrompackages

        # Subclass and store the newly derived params class
        cls.params = baseparams._derive(name, ownparams, extraparams)

        return cls

    def donew(cls, *args, **kwargs):
        '''
        Imports the declared packages, creates the instance and sets the
        params on it as ``params`` and ``p``.

        Keyword arguments matching a parameter name are consumed here. The
        remaining ones are returned with the instance and the positional
        arguments.
        '''
        module = sys.modules[cls.__module__]

        # import specified packages - each entry is a name or (name, alias)
        for entry in cls.packages:
            if isinstance(entry, (tuple, list)):
                pkgname, alias = entry
            else:
                pkgname = alias = entry

            imported = __import__(pkgname)  # delivers the top-level package
            levels = pkgname.split('.')

            if pkgname == alias and len(levels) > 1:
                # 'os.path' not aliased -> set 'os' in the module
                setattr(module, imported.__name__, imported)
            else:
                # aliased and/or no dots -> walk down to the final module
                for level in levels[1:]:
                    imported = getattr(imported, level)

                setattr(module, alias, imported)

        # import from specified packages - the 2nd part is a string or iterable
        for pkgname, fromnames in cls.frompackages:
            if isinstance(fromnames, string_types):
                fromnames = (fromnames,)  # make it a tuple

            for entry in fromnames:
                if isinstance(entry, (tuple, list)):
                    attrname, alias = entry
                else:
                    attrname = alias = entry  # assumed to be a string

                # str: fromlist complains "not string" (unicode vs bytes)
                imported = __import__(pkgname, fromlist=[str(attrname)])
                setattr(module, alias, getattr(imported, attrname))

        # Create params and set the values from the kwargs
        params = cls.params()
        for pname, pdefault in cls.params._getitems():
            setattr(params, pname, kwargs.pop(pname, pdefault))

        # Create the object and set the params in place
        _obj, args, kwargs = super(MetaParams, cls).donew(*args, **kwargs)
        _obj.params = params
        _obj.p = params  # shorter alias

        # Parameter values have now been set before __init__
        return _obj, args, kwargs
292
293
class ParamsBase(with_metaclass(MetaParams, object)):
    '''Stub to allow easy subclassing without handling metaclasses'''
296
297
class ItemCollection(object):
    '''
    Holds a collection of items that can be reached by

      - Index
      - Name (if set in the append operation)
    '''
    def __init__(self):
        self._items = list()  # all items in insertion order
        self._names = list()  # names of the items which were given one
        self._nameditems = list()  # named items, parallel to _names

    def __len__(self):
        '''Returns the total number of items (named and unnamed)'''
        return len(self._items)

    def append(self, item, name=None):
        '''
        Adds ``item`` to the collection.

        If ``name`` is given the item is also set as an attribute with
        that name. A non-empty name is additionally registered for
        ``getnames``, ``getitems`` and ``getbyname``.
        '''
        # An unnamed item must not reach setattr: attribute names have to
        # be strings and ``None`` would raise a TypeError
        if name is not None:
            setattr(self, name, item)

        self._items.append(item)
        if name:
            self._names.append(name)
            self._nameditems.append(item)

    def __getitem__(self, key):
        '''Returns the item(s) at index/slice ``key`` in insertion order'''
        return self._items[key]

    def getnames(self):
        '''Returns the list of registered names'''
        return self._names

    def getitems(self):
        '''Returns the (name, item) pairs of the named items'''
        return zip(self._names, self._nameditems)

    def getbyname(self, name):
        '''
        Returns the first item registered as ``name``.

        Raises ``ValueError`` if the name has not been registered.
        '''
        idx = self._names.index(name)
        return self._nameditems[idx]
330