1# -*- coding: utf-8 -*-
2from __future__ import unicode_literals
3
4import os
5import curses
6import logging
7import threading
8from functools import partial
9
10import pytest
11from vcr import VCR
12from six.moves.urllib.parse import urlparse, parse_qs
13
14from tuir.oauth import OAuthHelper, OAuthHandler, OAuthHTTPServer
15from tuir.content import RequestHeaderRateLimiter
16from tuir.config import Config
17from tuir.packages import praw
18from tuir.terminal import Terminal
19from tuir.subreddit_page import SubredditPage
20from tuir.submission_page import SubmissionPage
21from tuir.subscription_page import SubscriptionPage
22from tuir.inbox_page import InboxPage
23
24try:
25    from unittest import mock
26except ImportError:
27    import mock
28
# Every patch created through this helper is autospec'd, so mocked callables
# keep the signature of the object they replace.
patch = partial(mock.patch, autospec=True)

# Log everything at DEBUG level for the test run, but silence the vcr loggers
# that would otherwise flood the output.
logging.basicConfig(
    format='%(asctime)s:%(levelname)s:%(filename)s:%(lineno)d:%(message)s',
    level=logging.DEBUG)
for name in ('vcr.matchers', 'vcr.stubs'):
    logging.getLogger(name).disabled = True
38
39
def pytest_addoption(parser):
    """
    Register the command line options used by this test suite.

    --record-mode    stored as ``record_mode``, defaults to 'none'
    --refresh-token  stored as ``refresh_token``, defaults to a file under
                     ~/.local/share/tuir
    """
    option_table = (
        ('--record-mode', 'record_mode', 'none'),
        ('--refresh-token', 'refresh_token',
         '~/.local/share/tuir/refresh-token'),
    )
    for flag, dest, default in option_table:
        parser.addoption(flag, dest=dest, default=default)
44
45
class MockStdscr(mock.MagicMock):
    """
    MagicMock that imitates the curses ``stdscr`` window.

    The mock is expected to be configured with ``nlines``, ``ncols``, ``x``
    and ``y`` attributes. It reports them through the usual curses accessors
    and can hand out a sub-window that exposes the same interface.
    """

    def getyx(self):
        # Cursor position, in curses (row, column) order
        return (self.y, self.x)

    def getbegyx(self):
        # Every mock window claims to start at the origin
        return (0, 0)

    def getmaxyx(self):
        # Window size, in curses (rows, columns) order
        return (self.nlines, self.ncols)

    def derwin(self, *args):
        """
        Supported call signatures:

            derwin()
            derwin(begin_y, begin_x)
            derwin(nlines, ncols, begin_y, begin_x)

        A single child window is attached on first use and re-used (with its
        size and cursor reset) on every later call.
        """

        # hasattr() is useless on a MagicMock because it auto-creates
        # attributes, so look at dir() to see if the child was attached yet.
        if 'subwin' not in dir(self):
            self.attach_mock(MockStdscr(), 'subwin')

        argc = len(args)
        if argc == 0:
            # Same size as the parent
            size = (self.nlines, self.ncols)
        elif argc == 2:
            # Everything below / right of the given origin
            size = (self.nlines - args[0], self.ncols - args[1])
        else:
            # Requested size, clipped to what fits inside the parent
            size = (min(self.nlines - args[2], args[0]),
                    min(self.ncols - args[3], args[1]))

        child = self.subwin
        child.nlines, child.ncols = size
        child.x = 0
        child.y = 0
        return child
87
88
@pytest.fixture(scope='session')
def vcr(request):
    """
    Session-wide VCR instance used to record and replay HTTP traffic.

    Cassettes are stored in a ``cassettes`` directory next to this file
    (created on demand). Requests are matched on method, URI (query
    parameters may be in any order), authorization header and body. The
    Authorization header and the ``refresh_token`` post parameter are masked
    in the recordings.
    """

    # `none` replays the recorded requests, `once` deletes existing
    # cassettes and re-records them.
    record_mode = request.config.option.record_mode
    assert record_mode in ('once', 'none')

    def auth_matcher(r1, r2):
        "Requests match when their authorization headers are equal"
        first = r1.headers.get('authorization')
        second = r2.headers.get('authorization')
        return first == second

    def uri_with_query_matcher(r1, r2):
        "URI matcher that allows query params to appear in any order"
        left = urlparse(r1.uri)
        right = urlparse(r2.uri)
        # Scheme, netloc and path must be identical
        if left[:3] != right[:3]:
            return False
        # Compare the parsed queries, keeping blank values
        return parse_qs(left.query, True) == parse_qs(right.query, True)

    here = os.path.dirname(__file__)
    cassette_dir = os.path.join(here, 'cassettes')
    if not os.path.exists(cassette_dir):
        os.makedirs(cassette_dir)

    # https://github.com/kevin1024/vcrpy/pull/196
    recorder = VCR(
        record_mode=record_mode,
        filter_headers=[('Authorization', '**********')],
        filter_post_data_parameters=[('refresh_token', '**********')],
        match_on=['method', 'uri_with_query', 'auth', 'body'],
        cassette_library_dir=cassette_dir)
    recorder.register_matcher('auth', auth_matcher)
    recorder.register_matcher('uri_with_query', uri_with_query_matcher)
    return recorder
121
122
@pytest.fixture(scope='session')
def refresh_token(request):
    """
    Return the OAuth refresh token to use for the test session.

    When replaying cassettes (record mode ``none``) a placeholder token is
    returned. Otherwise the token is read from the file given by the
    ``--refresh-token`` option (``~`` is expanded).

    Surrounding whitespace is stripped so that a trailing newline in the
    token file (e.g. left by a text editor) does not become part of the
    token.
    """
    if request.config.option.record_mode == 'none':
        return 'mock_refresh_token'

    token_file = os.path.expanduser(request.config.option.refresh_token)
    with open(token_file) as fp:
        return fp.read().strip()
131
132
@pytest.fixture()
def config():
    """
    Yield a fresh ``Config`` whose persistence methods are mocked.

    ``save_history``, ``delete_history``, ``save_refresh_token`` and
    ``delete_refresh_token`` are replaced with mocks for the duration of the
    test (presumably so that tests do not touch files on disk).
    ``delete_refresh_token`` still clears ``conf.refresh_token`` in memory.
    """
    # pytest.fixture handles generator fixtures directly;
    # pytest.yield_fixture is a deprecated alias for it.
    conf = Config()
    with mock.patch.object(conf, 'save_history'),          \
            mock.patch.object(conf, 'delete_history'),     \
            mock.patch.object(conf, 'save_refresh_token'), \
            mock.patch.object(conf, 'delete_refresh_token'):

        def delete_refresh_token():
            # Skip the os.remove
            conf.refresh_token = None
        conf.delete_refresh_token.side_effect = delete_refresh_token

        yield conf
147
148
@pytest.fixture()
def stdscr():
    """
    Yield a ``MockStdscr`` (40 lines x 80 columns) with curses patched.

    The listed curses functions are autospec-patched for the duration of the
    test. ``curses.initscr`` returns the mock screen and ``curses.newwin`` is
    redirected to the mock's ``derwin``.
    """
    # pytest.fixture handles generator fixtures directly;
    # pytest.yield_fixture is a deprecated alias for it.
    with patch('curses.initscr'),               \
            patch('curses.echo'),               \
            patch('curses.flash'),              \
            patch('curses.endwin'),             \
            patch('curses.newwin'),             \
            patch('curses.noecho'),             \
            patch('curses.cbreak'),             \
            patch('curses.doupdate'),           \
            patch('curses.nocbreak'),           \
            patch('curses.curs_set'),           \
            patch('curses.init_pair'),          \
            patch('curses.color_pair'),         \
            patch('curses.has_colors'),         \
            patch('curses.start_color'),        \
            patch('curses.use_default_colors'):
        out = MockStdscr(nlines=40, ncols=80, x=0, y=0)
        curses.initscr.return_value = out
        curses.newwin.side_effect = lambda *args: out.derwin(*args)
        curses.color_pair.return_value = 23
        curses.has_colors.return_value = True
        # NOTE: these constants are assigned directly on the curses module
        # and are not restored when the fixture is torn down.
        curses.ACS_VLINE = 0
        curses.COLORS = 256
        curses.COLOR_PAIRS = 256
        yield out
175
176
@pytest.fixture()
def reddit(vcr, request):
    """
    Yield a ``praw.Reddit`` session whose HTTP traffic goes through a cassette.

    The cassette is named after the requesting test (``<test name>.yaml``).
    In record mode ``once`` an existing cassette is deleted first so that it
    is recorded again. ``Reddit.get_access_information`` is autospec-patched
    while the fixture is active.
    """
    # pytest.fixture handles generator fixtures directly;
    # pytest.yield_fixture is a deprecated alias for it.
    cassette_name = '%s.yaml' % request.node.name
    # Clear the cassette before running the test
    if request.config.option.record_mode == 'once':
        filename = os.path.join(vcr.cassette_library_dir, cassette_name)
        if os.path.exists(filename):
            os.remove(filename)

    with vcr.use_cassette(cassette_name):
        with patch('tuir.packages.praw.Reddit.get_access_information'):
            handler = RequestHeaderRateLimiter()
            reddit = praw.Reddit(user_agent='tuir test suite',
                                 decode_html_entities=False,
                                 disable_update_check=True,
                                 handler=handler)
            # praw uses a global cache for requests, so we need to clear it
            # before each unit test. Otherwise we may fail to generate new
            # cassettes.
            reddit.handler.clear_cache()
            if request.config.option.record_mode == 'none':
                # Turn off praw rate limiting when using cassettes
                reddit.config.api_request_delay = 0
            yield reddit
201
202
@pytest.fixture()
def terminal(stdscr, config):
    """
    Return a themed ``Terminal`` built on the mock screen and test config.
    """

    def plain_addch(window, *args):
        # Forward straight to the window so that the mock stdscr always
        # records addch calls the same way, bypassing the python 3.4
        # addch patch.
        return window.addch(*args)

    term = Terminal(stdscr, config=config)
    term.set_theme()
    term.addch = plain_addch
    return term
211
212
@pytest.fixture()
def oauth(reddit, terminal, config):
    """
    Return an ``OAuthHelper`` built from the reddit, terminal and config
    fixtures.
    """
    helper = OAuthHelper(reddit, terminal, config)
    return helper
216
217
@pytest.fixture()
def oauth_server():
    """
    Yield an ``OAuthHTTPServer`` running in a background thread.

    The server is bound to a port chosen by the OS and its base address is
    exposed as ``server.url``. On teardown the server is shut down, the
    thread is joined and the socket is closed.
    """
    # pytest.fixture handles generator fixtures directly;
    # pytest.yield_fixture is a deprecated alias for it.

    # Start the OAuth server on a random port in the background
    server = OAuthHTTPServer(('', 0), OAuthHandler)
    server.url = 'http://{0}:{1}/'.format(*server.server_address)
    thread = threading.Thread(target=server.serve_forever)
    thread.start()
    try:
        yield server
    finally:
        # Stop serve_forever() first so that the thread can be joined
        server.shutdown()
        thread.join()
        server.server_close()
231
232
@pytest.fixture()
def submission_page(reddit, terminal, config, oauth):
    """
    Return a drawn ``SubmissionPage`` for a fixed /r/Python submission.
    """
    with terminal.loader():
        page = SubmissionPage(
            reddit, terminal, config, oauth,
            url='https://www.reddit.com/r/Python/comments/2xmo63')

    # Fail right away if the loader recorded an exception
    assert terminal.loader.exception is None
    page.draw()
    return page
242
243
@pytest.fixture()
def subreddit_page(reddit, terminal, config, oauth):
    """
    Return a drawn ``SubredditPage`` for /r/python.
    """
    subreddit = '/r/python'

    with terminal.loader():
        page = SubredditPage(reddit, terminal, config, oauth, subreddit)
    # Same explicit check as the other page fixtures in this file: fail
    # right away if the loader recorded an exception.
    assert terminal.loader.exception is None
    page.draw()
    return page
253
254
@pytest.fixture()
def subscription_page(reddit, terminal, config, oauth):
    """
    Return a drawn ``SubscriptionPage`` for the 'popular' content type.
    """
    with terminal.loader():
        page = SubscriptionPage(reddit, terminal, config, oauth, 'popular')

    # Fail right away if the loader recorded an exception
    assert terminal.loader.exception is None
    page.draw()
    return page
264
265
@pytest.fixture()
def inbox_page(reddit, terminal, config, oauth, refresh_token):
    """
    Return a drawn ``InboxPage`` for an authorized session.
    """
    # The inbox page requires logging in on an account with at least one
    # message, so authorize with the refresh token before loading it.
    config.refresh_token = refresh_token
    oauth.authorize()

    with terminal.loader():
        page = InboxPage(reddit, terminal, config, oauth)

    # Fail right away if the loader recorded an exception
    assert terminal.loader.exception is None
    page.draw()
    return page
277