1from __future__ import unicode_literals
2try:
3    import unittest2 as unittest
4except ImportError:
5    import unittest
6
7import os
8import datetime
9import time
10import subprocess
11import warnings
12import tempfile
13import pickle
14
15
class WarningTestMixin(object):
    """
    Mixin that provides an ``assertWarns`` method.

    The host class must provide ``assertTrue`` (as ``unittest.TestCase``
    does), which is used to report the result.
    """
    # Based on https://stackoverflow.com/a/12935176/467366
    class _AssertWarnsContext(warnings.catch_warnings):
        """
        ``catch_warnings`` subclass that records every warning shown inside
        its ``with`` body and, if the body completes without an exception,
        asserts that at least one recorded warning is a subclass of one of
        the expected categories.
        """
        def __init__(self, expected_warnings, parent, **kwargs):
            super(WarningTestMixin._AssertWarnsContext, self).__init__(**kwargs)

            # Test case whose assertTrue() reports the outcome.
            self.parent = parent
            # Accept either a single warning class or an iterable of them.
            try:
                self.expected_warnings = list(expected_warnings)
            except TypeError:
                self.expected_warnings = [expected_warnings]

            self._warning_log = []

        def __enter__(self, *args, **kwargs):
            rv = super(WarningTestMixin._AssertWarnsContext, self).__enter__(*args, **kwargs)

            # catch_warnings has just saved the filter list and restores it
            # in __exit__, so this 'always' filter stays confined to the
            # with-block. 'always' keeps repeated warnings from being
            # suppressed by the warning registry.
            self._module.simplefilter('always')

            if self._showwarning is not self._module.showwarning:
                super_showwarning = self._module.showwarning
            else:
                super_showwarning = None

            def showwarning(*args, **kwargs):
                if super_showwarning is not None:
                    super_showwarning(*args, **kwargs)

                self._warning_log.append(warnings.WarningMessage(*args, **kwargs))

            self._module.showwarning = showwarning
            return rv

        def __exit__(self, exc_type, exc_value, traceback):
            super(WarningTestMixin._AssertWarnsContext, self).__exit__(
                exc_type, exc_value, traceback)

            if exc_type is not None:
                # Let the body's exception propagate instead of masking it
                # with an unrelated "no warning" assertion failure.
                return False

            expected = self.expected_warnings
            seen = any(issubclass(item.category, warning)
                       for warning in expected
                       for item in self._warning_log)
            self.parent.assertTrue(
                seen,
                'None of the expected warnings were issued: ' +
                ', '.join(getattr(w, '__name__', repr(w)) for w in expected))
            return False

    def assertWarns(self, warning, callable=None, *args, **kwargs):
        """
        Assert that a warning of the given category is issued.

        :param warning:
            A warning class, or an iterable of warning classes. The assertion
            passes if any recorded warning is a subclass of any of them.
        :param callable:
            If ``None``, return a context manager for use in a ``with``
            block. Otherwise call ``callable(*args, **kwargs)`` inside that
            context.
        """
        context = self.__class__._AssertWarnsContext(warning, self)
        if callable is None:
            return context

        with context:
            callable(*args, **kwargs)
62
63
class PicklableMixin(object):
    """
    Mixin adding a pickle round-trip assertion to a test case.

    The host class must supply ``assertIsNot`` and ``assertEqual`` (as
    ``unittest.TestCase`` does).
    """
    def _get_nobj_bytes(self, obj, dump_kwargs, load_kwargs):
        """
        Round-trip *obj* through an in-memory pickle (``pickle.dumps`` then
        ``pickle.loads``) and return the reconstructed object.
        """
        return pickle.loads(pickle.dumps(obj, **dump_kwargs), **load_kwargs)

    def _get_nobj_file(self, obj, dump_kwargs, load_kwargs):
        """
        Round-trip *obj* through ``pickle.dump`` / ``pickle.load`` using an
        anonymous temporary binary file, and return the reconstructed object.
        """
        with tempfile.TemporaryFile('w+b') as scratch:
            pickle.dump(obj, scratch, **dump_kwargs)
            scratch.seek(0)     # rewind so load() reads what dump() wrote
            return pickle.load(scratch, **load_kwargs)

    def assertPicklable(self, obj, asfile=False,
                        dump_kwargs=None, load_kwargs=None):
        """
        Check that *obj* survives pickling: the unpickled copy must compare
        equal to *obj* while being a distinct object.

        :param asfile:
            Round-trip through a temporary file instead of a bytes string.
        :param dump_kwargs:
            Extra keyword arguments for the pickling call.
        :param load_kwargs:
            Extra keyword arguments for the unpickling call.
        """
        if asfile:
            roundtrip = self._get_nobj_file
        else:
            roundtrip = self._get_nobj_bytes

        copy = roundtrip(obj, dump_kwargs or {}, load_kwargs or {})

        self.assertIsNot(obj, copy)
        self.assertEqual(obj, copy)
98
99
class TZContextBase(object):
    """
    Base class for a context manager which allows changing of time zones.

    Subclasses may define a guard variable to either block or allow time
    zone changes by redefining ``_guard_var_name`` and ``_guard_allows_change``.
    The default is that the guard variable must be affirmatively set.

    Subclasses must define ``get_current_tz`` and ``set_current_tz``.
    """
    # Environment variable consulted by tz_change_allowed().
    _guard_var_name = "DATEUTIL_MAY_CHANGE_TZ"
    # True: a non-empty guard variable *permits* changes (blocked by default).
    # False: a non-empty guard variable *blocks* changes (allowed by default).
    _guard_allows_change = True

    def __init__(self, tzval):
        self.tzval = tzval
        self._old_tz = None

    @classmethod
    def tz_change_allowed(cls):
        """
        Class method used to query whether or not this class allows time zone
        changes.
        """
        # Any non-empty value of the environment variable counts as "set".
        guard = bool(os.environ.get(cls._guard_var_name, False))

        # _guard_allows_change gives the "default" behavior - if True, the
        # guard is overcoming a block. If false, the guard is causing a block.
        # Whether tz_change is allowed is therefore the XNOR of the two.
        return guard == cls._guard_allows_change

    @classmethod
    def tz_change_disallowed_message(cls):
        """ Generate instructions on how to allow tz changes """
        if cls._guard_allows_change:
            msg = ('Changing time zone not allowed. Set {envar} to {gval} '
                   'if you would like to allow this behavior')
        else:
            # The guard is tested for truthiness, so setting it to "False"
            # would still block; it has to be unset (or empty) instead.
            msg = ('Changing time zone not allowed. Unset {envar} (or set it '
                   'to an empty string) if you would like to allow this '
                   'behavior')

        return msg.format(envar=cls._guard_var_name,
                          gval=cls._guard_allows_change)

    def __enter__(self):
        """
        Remember the current time zone and switch to ``self.tzval``.

        :raises ValueError: if the guard variable forbids time zone changes.
        """
        if not self.tz_change_allowed():
            raise ValueError(self.tz_change_disallowed_message())

        self._old_tz = self.get_current_tz()
        self.set_current_tz(self.tzval)

    def __exit__(self, type, value, traceback):
        """Restore the time zone saved by ``__enter__`` (if any)."""
        if self._old_tz is not None:
            self.set_current_tz(self._old_tz)

        self._old_tz = None

    def get_current_tz(self):
        """Return the current time zone; must be implemented by subclasses."""
        raise NotImplementedError

    def set_current_tz(self, tzval):
        """Switch to ``tzval``; must be implemented by subclasses."""
        raise NotImplementedError
157
158
class TZEnvContext(TZContextBase):
    """
    Context manager that temporarily sets the `TZ` variable (for use on
    *nix-like systems). Because the effect is local to the current process
    anyway, this will apply *unless* a guard is set.

    If you do not want the TZ environment variable set, you may set the
    ``DATEUTIL_MAY_NOT_CHANGE_TZ_VAR`` variable to a truthy value.
    """
    _guard_var_name = "DATEUTIL_MAY_NOT_CHANGE_TZ_VAR"
    _guard_allows_change = False

    def get_current_tz(self):
        """Return the current ``TZ`` value, or ``UnsetTz`` if it is not set."""
        return os.environ.get('TZ', UnsetTz)

    def set_current_tz(self, tzval):
        """
        Set ``TZ`` to ``tzval`` (or remove it when ``tzval`` is ``UnsetTz``)
        and call ``time.tzset()`` so the new value takes effect.
        """
        if tzval is UnsetTz:
            # Restoring the "no TZ" state; TZ may already be absent, in which
            # case there is nothing to delete (and UnsetTz must never be
            # written into os.environ).
            if 'TZ' in os.environ:
                del os.environ['TZ']
        else:
            os.environ['TZ'] = tzval

        # Make the time-conversion routines re-read TZ.
        time.tzset()
181
182
class TZWinContext(TZContextBase):
    """
    Context manager for changing local time zone on Windows.

    Because the effect of this is system-wide and global, it may have
    unintended side effects. Set the ``DATEUTIL_MAY_CHANGE_TZ`` environment
    variable to a truthy value before using this context manager.
    """
    @staticmethod
    def _format_error(out, err):
        """Best-effort text of tzutil's diagnostic output for an error message."""
        # Pipes yield bytes; prefer stderr, fall back to stdout.
        msg = (err or out or b'').decode('utf-8', 'replace').strip()
        return msg or 'Unknown error.'

    def get_current_tz(self):
        """
        Return the current Windows time zone name as printed by ``tzutil /g``.

        :raises OSError: if ``tzutil`` exits with a non-zero status.
        """
        p = subprocess.Popen(['tzutil', '/g'], stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)

        out, err = p.communicate()

        if p.returncode:
            raise OSError('Failed to get current time zone: ' +
                          self._format_error(out, err))

        # Popen returns bytes; strip any trailing newline so the name can be
        # passed back to ``tzutil /s`` unchanged.
        return out.decode().strip()

    def set_current_tz(self, tzname):
        """
        Set the Windows time zone with ``tzutil /s``.

        :raises OSError: if ``tzutil`` exits with a non-zero status.
        """
        # An argument list lets subprocess quote names containing spaces,
        # instead of hand-building a quoted command string.
        p = subprocess.Popen(['tzutil', '/s', tzname],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        out, err = p.communicate()

        if p.returncode:
            raise OSError('Failed to set current time zone: ' +
                          self._format_error(out, err))
210
211
212###
213# Compatibility functions
214
215def _total_seconds(td):
216    # Python 2.6 doesn't have a total_seconds() method on timedelta objects
217    return ((td.seconds + td.days * 86400) * 1000000 +
218            td.microseconds) // 1000000
219
220total_seconds = getattr(datetime.timedelta, 'total_seconds', _total_seconds)
221
222
223###
224# Utility classes
class NotAValueClass(object):
    """
    A NaN-like value that absorbs operations of any type.

    Whenever Python dispatches an arithmetic or rich-comparison operator to
    an instance, directly or as the reflected operand, the result is that
    same instance.
    """
    def _op(self, other):
        # Absorbing element: the result is always NotAValue itself.
        return self

    def _cmp(self, other):
        # Not referenced by the operator table below.
        return False

    # Arithmetic, forward and reflected (__div__/__rdiv__ serve Python 2).
    __add__ = _op
    __radd__ = _op
    __sub__ = _op
    __rsub__ = _op
    __mul__ = _op
    __rmul__ = _op
    __div__ = _op
    __rdiv__ = _op
    __truediv__ = _op
    __rtruediv__ = _op
    __floordiv__ = _op
    __rfloordiv__ = _op

    # Rich comparisons. Python never looks up the __r*__ names here
    # (reflected comparisons use the mirrored operator instead).
    __lt__ = _op
    __rlt__ = _op
    __gt__ = _op
    __rgt__ = _op
    __eq__ = _op
    __req__ = _op
    __le__ = _op
    __rle__ = _op
    __ge__ = _op
    __rge__ = _op

NotAValue = NotAValueClass()
249
250
class ComparesEqualClass(object):
    """
    A value that compares equal to anything.

    ``==``, ``<=`` and ``>=`` always yield True; ``!=``, ``<`` and ``>``
    always yield False.
    """

    def _true(self, other):
        return True

    def _false(self, other):
        return False

    # Comparisons that hold for equal operands succeed...
    __eq__ = __le__ = __ge__ = _true
    # ...and those that require inequality fail.
    __ne__ = __lt__ = __gt__ = _false

    # Python does not consult these names (reflected comparisons use the
    # mirrored operator); they simply alias the forward methods.
    __req__ = __eq__
    __rne__ = __ne__
    __rle__ = __le__
    __rge__ = __ge__
    __rlt__ = __lt__
    __rgt__ = __gt__

ComparesEqual = ComparesEqualClass()
282
class UnsetTzClass(object):
    """
    Marker type. Its module-level instance, ``UnsetTz``, stands for "the TZ
    environment variable is not set".
    """


UnsetTz = UnsetTzClass()
288
289