1# -*- coding: utf-8 -*-
2# This file is part of beets.
3# Copyright 2016, Adrian Sampson.
4#
5# Permission is hereby granted, free of charge, to any person obtaining
6# a copy of this software and associated documentation files (the
7# "Software"), to deal in the Software without restriction, including
8# without limitation the rights to use, copy, modify, merge, publish,
9# distribute, sublicense, and/or sell copies of the Software, and to
10# permit persons to whom the Software is furnished to do so, subject to
11# the following conditions:
12#
13# The above copyright notice and this permission notice shall be
14# included in all copies or substantial portions of the Software.
15
16"""Some common functionality for beets' test cases."""
17from __future__ import division, absolute_import, print_function
18
19import time
20import sys
21import os
22import tempfile
23import shutil
24import six
25import unittest
26from contextlib import contextmanager
27
28
29# Mangle the search path to include the beets sources.
30sys.path.insert(0, '..')
31import beets.library  # noqa: E402
32from beets import importer, logging  # noqa: E402
33from beets.ui import commands  # noqa: E402
34from beets import util  # noqa: E402
35import beets  # noqa: E402
36
37# Make sure the development versions of the plugins are used
38import beetsplug  # noqa: E402
# `__file__` is this module, so joining '..' twice and normalising with
# abspath yields `<parent of this file's directory>/beetsplug`. Pointing the
# package's search path there makes plugin imports resolve to the plugins
# in this source tree (the development versions).
beetsplug.__path__ = [os.path.abspath(
    os.path.join(__file__, '..', '..', 'beetsplug')
)]

# Test resources path. RSRC is converted with util.bytestring_path, while
# PLUGINPATH (the `beetsplug` directory under `rsrc`) is left as a native
# string.
RSRC = util.bytestring_path(os.path.join(os.path.dirname(__file__), 'rsrc'))
PLUGINPATH = os.path.join(os.path.dirname(__file__), 'rsrc', 'beetsplug')

# Propagate to the root logger so the test runner (e.g. nose) can capture
# beets' log output, and record everything down to DEBUG level.
log = logging.getLogger('beets')
log.propagate = True
log.setLevel(logging.DEBUG)

# Dummy item creation: counter incremented by item() so that each dummy
# item gets a distinct path.
_item_ident = 0

# OS feature test: symlinks and hard links are assumed to be available
# everywhere except Windows.
HAVE_SYMLINK = sys.platform != 'win32'
HAVE_HARDLINK = sys.platform != 'win32'
58
59
def item(lib=None):
    """Build a dummy `Item` with every common field set to a fixed value.

    Each call increments the module counter `_item_ident` and embeds it in
    the item's `path`, so successive dummy items have distinct paths. If
    `lib` is given (and truthy), the new item is also added to it.
    """
    global _item_ident
    _item_ident += 1
    attrs = dict(
        title=u'the title',
        artist=u'the artist',
        albumartist=u'the album artist',
        album=u'the album',
        genre=u'the genre',
        lyricist=u'the lyricist',
        composer=u'the composer',
        arranger=u'the arranger',
        grouping=u'the grouping',
        year=1,
        month=2,
        day=3,
        track=4,
        tracktotal=5,
        disc=6,
        disctotal=7,
        lyrics=u'the lyrics',
        comments=u'the comments',
        bpm=8,
        comp=True,
        path='somepath{0}'.format(_item_ident),
        length=60.0,
        bitrate=128000,
        format='FLAC',
        mb_trackid='someID-1',
        mb_albumid='someID-2',
        mb_artistid='someID-3',
        mb_albumartistid='someID-4',
        mb_releasetrackid='someID-5',
        album_id=None,
        mtime=12345,
    )
    new_item = beets.library.Item(**attrs)
    if lib:
        lib.add(new_item)
    return new_item
99
# Counter of dummy albums created by album(); bumped once per call.
_album_ident = 0


def album(lib=None):
    """Build a dummy `Album` with a fixed set of common fields.

    Each call increments the module counter `_album_ident`. If `lib` is
    given (and truthy), the new album is also added to it.
    """
    # Count albums with their own counter. Bumping `_item_ident` here (a
    # copy-paste slip) left `_album_ident` unused and shifted the numbered
    # paths of dummy items created afterwards.
    global _album_ident
    _album_ident += 1
    new_album = beets.library.Album(
        artpath=None,
        albumartist=u'some album artist',
        albumartist_sort=u'some sort album artist',
        albumartist_credit=u'some album artist credit',
        album=u'the album',
        genre=u'the genre',
        year=2014,
        month=2,
        day=5,
        tracktotal=0,
        disctotal=1,
        comp=False,
        mb_albumid='someID-1',
        mb_albumartistid='someID-1'
    )
    if lib:
        lib.add(new_album)
    return new_album
125
126
# Dummy import session.

# Sentinel default for import_session()'s `paths` and `query`: a call that
# omits them gets a fresh empty list, while any explicit argument (including
# None) is passed through to the session unchanged.
_UNSET = object()


def import_session(lib=None, loghandler=None, paths=_UNSET, query=_UNSET,
                   cli=False):
    """Create an import session for tests.

    `lib`, `loghandler`, `paths` and `query` are passed positionally to the
    session constructor. `paths` and `query` default to a new empty list on
    every call. When `cli` is true a `commands.TerminalImportSession` is
    built, otherwise a plain `importer.ImportSession`.
    """
    # A literal `[]` default would be a single list shared by every call, so
    # any mutation made through one session would leak into later tests.
    if paths is _UNSET:
        paths = []
    if query is _UNSET:
        query = []
    cls = commands.TerminalImportSession if cli else importer.ImportSession
    return cls(lib, loghandler, paths, query)
131
132
class Assertions(object):
    """Mixin adding filesystem-path assertions to a unittest test case."""

    def assertExists(self, path):  # noqa
        """Fail unless `path` exists on disk."""
        present = os.path.exists(util.syspath(path))
        self.assertTrue(present, u'file does not exist: {!r}'.format(path))

    def assertNotExists(self, path):  # noqa
        """Fail if `path` exists on disk."""
        present = os.path.exists(util.syspath(path))
        self.assertFalse(present, u'file exists: {!r}'.format(path))

    def assert_equal_path(self, a, b):
        """Fail unless `a` and `b` are equal after `util.normpath`."""
        norm_a, norm_b = util.normpath(a), util.normpath(b)
        self.assertEqual(
            norm_a, norm_b,
            u'paths are not equal: {!r} and {!r}'.format(a, b),
        )
148
149
150# A test harness for all beets tests.
151# Provides temporary, isolated configuration.
class TestCase(unittest.TestCase, Assertions):
    """A unittest.TestCase subclass that saves and restores beets'
    global configuration. This allows tests to make temporary
    modifications that will then be automatically removed when the test
    completes. Also provides some additional assertion methods, a
    temporary directory (`self.temp_dir`), and a DummyIO (`self.io`) that
    is created but not installed.
    """
    def setUp(self):
        # A "clean" source list including only the defaults.
        beets.config.sources = []
        beets.config.read(user=False, defaults=True)

        # Direct paths to a temporary directory. Tests can also use this
        # temporary directory.
        self.temp_dir = util.bytestring_path(tempfile.mkdtemp())

        beets.config['statefile'] = \
            util.py3_path(os.path.join(self.temp_dir, b'state.pickle'))
        beets.config['library'] = \
            util.py3_path(os.path.join(self.temp_dir, b'library.db'))
        beets.config['directory'] = \
            util.py3_path(os.path.join(self.temp_dir, b'libdir'))

        # Set $HOME, which is used by confit's `config_dir()` to create
        # directories. Remember the previous value (None when unset) so
        # tearDown can put it back.
        self._old_home = os.environ.get('HOME')
        os.environ['HOME'] = util.py3_path(self.temp_dir)

        # Initialize, but don't install, a DummyIO.
        self.io = DummyIO()

    def tearDown(self):
        """Undo the global changes made by setUp.

        $HOME, the standard streams and the beets config are restored
        before the temporary directory is deleted, so a failing rmtree
        cannot leave that process-wide state modified for later tests.
        """
        if self._old_home is None:
            # HOME was unset before the test. Use pop() rather than `del`
            # so a test that already removed HOME doesn't raise KeyError.
            os.environ.pop('HOME', None)
        else:
            os.environ['HOME'] = self._old_home
        self.io.restore()

        beets.config.clear()
        beets.config._materialized = False

        if os.path.isdir(self.temp_dir):
            shutil.rmtree(self.temp_dir)
194
195
class LibTestCase(TestCase):
    """TestCase variant that also provides `self.lib`, an in-memory beets
    library, already holding one dummy item, `self.i`.
    """
    def setUp(self):
        super(LibTestCase, self).setUp()
        library = beets.library.Library(':memory:')
        self.lib = library
        self.i = item(library)

    def tearDown(self):
        # Close the library's connection before the base class restores
        # the global state and removes the temporary directory.
        connection = self.lib._connection()
        connection.close()
        super(LibTestCase, self).tearDown()
208
209
210# Mock timing.
211
class Timecop(object):
    """Mocks the timing system (namely time() and sleep()) for testing.
    Inspired by the Ruby timecop library.

    The fake clock starts at the real time when the instance is created and
    advances only when `sleep()` is called. Use `install()`/`restore()` or
    a `with` block to patch `time.time` and `time.sleep` temporarily.
    """
    def __init__(self):
        self.now = time.time()
        # The genuine time()/sleep() captured by install(); None while the
        # mock is not installed.
        self.orig = None

    def time(self):
        """Return the current fake time."""
        return self.now

    def sleep(self, amount):
        """Advance the fake clock by `amount` seconds without blocking."""
        self.now += amount

    def install(self):
        """Replace `time.time` and `time.sleep` with this fake clock.

        Calling this again while installed does nothing. Otherwise the
        second call would capture the mocks themselves as the "originals",
        and restore() could never put the real functions back.
        """
        if getattr(self, 'orig', None) is not None:
            return
        self.orig = {
            'time': time.time,
            'sleep': time.sleep,
        }
        time.time = self.time
        time.sleep = self.sleep

    def restore(self):
        """Put back the real `time.time` and `time.sleep`.

        Does nothing if the mock is not currently installed.
        """
        orig = getattr(self, 'orig', None)
        if orig is None:
            return
        time.time = orig['time']
        time.sleep = orig['sleep']
        self.orig = None

    def __enter__(self):
        self.install()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.restore()
        return False
236
237
238# Mock I/O.
239
class InputException(Exception):
    """Raised when code under test reads input but none has been queued.

    `output`, when not None, holds text captured so far and is appended to
    the message.
    """
    def __init__(self, output=None):
        self.output = output

    def __str__(self):
        parts = ["Attempt to read with no input provided."]
        if self.output is not None:
            parts.append("Output: {!r}".format(self.output))
        return " ".join(parts)
249
250
class DummyOut(object):
    """In-memory stand-in for `sys.stdout` that records every write.

    Note that `flush()` discards everything recorded so far (it delegates
    to `clear()`), unlike a real stream.
    """
    encoding = 'utf-8'

    def __init__(self):
        self.buf = []

    def write(self, s):
        """Record `s`."""
        self.buf.append(s)

    def get(self):
        """Return all recorded writes concatenated: bytes on Python 2, text
        on Python 3.
        """
        joiner = b'' if six.PY2 else ''
        return joiner.join(self.buf)

    def flush(self):
        self.clear()

    def clear(self):
        """Forget everything recorded so far."""
        self.buf = []
271
272
class DummyIn(object):
    """In-memory stand-in for `sys.stdin`, fed line by line via `add()`.

    Reading with nothing queued raises InputException. If an output stream
    `out` was given, the exception carries the text from `out.get()`.
    """
    encoding = 'utf-8'

    def __init__(self, out=None):
        self.buf = []
        self.reads = 0
        self.out = out

    def add(self, s):
        """Queue `s` as one input line (a newline is appended)."""
        newline = b'\n' if six.PY2 else '\n'
        self.buf.append(s + newline)

    def close(self):
        pass

    def readline(self):
        """Pop and return the oldest queued line, counting the read."""
        if self.buf:
            self.reads += 1
            return self.buf.pop(0)
        if self.out:
            raise InputException(self.out.get())
        raise InputException()
298
299
class DummyIO(object):
    """Mocks input and output streams for testing UI code.

    `install()` swaps `sys.stdin`/`sys.stdout` for the dummy streams, and
    `restore()` puts back the streams that were active at `install()` time.
    """
    def __init__(self):
        self.stdout = DummyOut()
        self.stdin = DummyIn(self.stdout)
        # (stdin, stdout) active before install(); None while not installed.
        self._saved_streams = None

    def addinput(self, s):
        """Queue `s` as one line of input for the code under test."""
        self.stdin.add(s)

    def getoutput(self):
        """Return everything written so far and empty the output buffer."""
        res = self.stdout.get()
        self.stdout.clear()
        return res

    def readcount(self):
        """Return how many lines have been read from the dummy stdin."""
        return self.stdin.reads

    def install(self):
        """Replace `sys.stdin`/`sys.stdout` with the dummy streams."""
        # Remember the streams only on the first install. A repeated
        # install would otherwise "save" the dummies themselves.
        if getattr(self, '_saved_streams', None) is None:
            self._saved_streams = (sys.stdin, sys.stdout)
        sys.stdin = self.stdin
        sys.stdout = self.stdout

    def restore(self):
        """Undo `install()`.

        Restores the streams that were active before `install()`, which
        may be a test runner's capture objects rather than the interpreter
        originals. If the dummies are not installed, this falls back to
        `sys.__stdin__`/`sys.__stdout__`, as before.
        """
        saved = getattr(self, '_saved_streams', None)
        if saved is None:
            sys.stdin = sys.__stdin__
            sys.stdout = sys.__stdout__
        else:
            sys.stdin, sys.stdout = saved
            self._saved_streams = None
324
325
326# Utility.
327
def touch(path):
    """Create `path` as an empty file if it is missing.

    Append mode never truncates, so an existing file keeps its contents.
    """
    with open(path, 'a'):
        pass
330
331
class Bag(object):
    """An object that exposes a set of fields given as keyword
    arguments. Any field not found in the dictionary appears to be None.
    Used for mocking Album objects and the like.
    """
    def __init__(self, **fields):
        self.fields = fields

    def __getattr__(self, key):
        # __getattr__ runs only after normal lookup fails. Read `fields`
        # from the instance dict directly: on an instance created without
        # __init__ (as copy and pickle do), `self.fields` would itself miss
        # and re-enter __getattr__ until RecursionError. Report such a
        # half-built instance as having no attributes instead.
        fields = self.__dict__.get('fields')
        if fields is None:
            raise AttributeError(key)
        return fields.get(key)
342
343
344# Convenience methods for setting up a temporary sandbox directory for tests
345# that need to interact with the filesystem.
346
class TempDirMixin(object):
    """Test mixin for creating and deleting a temporary directory.
    """

    def create_temp_dir(self):
        """Create a temporary directory and assign its bytestring path to
        `self.temp_dir`. Call `remove_temp_dir` later to delete it.
        """
        # Convert with util.bytestring_path, as TestCase.setUp does for its
        # temp dir, instead of a hard-coded UTF-8 encode, so both helpers
        # produce paths encoded the same way.
        self.temp_dir = util.bytestring_path(tempfile.mkdtemp())

    def remove_temp_dir(self):
        """Delete the temporary directory created by `create_temp_dir`.

        Does nothing if the directory was never created or is already gone.
        """
        temp_dir = getattr(self, 'temp_dir', None)
        if temp_dir and os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir)
365
366
367# Platform mocking.
368
@contextmanager
def platform_windows():
    """Temporarily make `os.path` the Windows path module (`ntpath`)."""
    import ntpath
    saved = os.path
    os.path = ntpath
    try:
        yield
    finally:
        os.path = saved
378
379
@contextmanager
def platform_posix():
    """Temporarily make `os.path` the POSIX path module (`posixpath`)."""
    import posixpath
    saved = os.path
    os.path = posixpath
    try:
        yield
    finally:
        os.path = saved
389
390
@contextmanager
def system_mock(name):
    """Temporarily make `platform.system()` return `name`."""
    import platform
    saved = platform.system

    def fake_system():
        return name

    platform.system = fake_system
    try:
        yield
    finally:
        platform.system = saved
400
401
def slow_test(unused=None):
    """Return a decorator for slow tests.

    When the SKIP_SLOW_TESTS environment variable is set (checked at call
    time), the decorator is `unittest.skip` with the reason 'test is slow'.
    Otherwise it returns the decorated object unchanged. The argument is
    ignored.
    """
    if 'SKIP_SLOW_TESTS' in os.environ:
        return unittest.skip(u'test is slow')
    return lambda obj: obj
408