"""Tests for HTTPS CONNECT tunneling through an authenticating mitmproxy instance."""

import json
import os
import re
import sys
from subprocess import Popen, PIPE
from urllib.parse import urlsplit, urlunsplit

from testfixtures import LogCapture
from twisted.internet import defer
from twisted.trial.unittest import TestCase

from scrapy.http import Request
from scrapy.utils.test import get_crawler, get_testenv

from tests.mockserver import MockServer
from tests.spiders import SimpleSpider, SingleRequestSpider


class MitmProxy:
    """A mitmdump subprocess configured as an authenticating proxy."""

    auth_user = 'scrapy'
    auth_pass = 'scrapy'

    def start(self):
        script = """
import sys
from mitmproxy.tools.main import mitmdump
sys.argv[0] = "mitmdump"
sys.exit(mitmdump())
        """
        cert_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                                 'keys', 'mitmproxy-ca.pem')
        self.proc = Popen([sys.executable,
                           '-c', script,
                           '--listen-host', '127.0.0.1',
                           '--listen-port', '0',  # let the OS pick a free port
                           '--proxyauth', f'{self.auth_user}:{self.auth_pass}',
                           '--certs', cert_path,
                           '--ssl-insecure',
                           ],
                          stdout=PIPE, env=get_testenv())
        # mitmdump prints the address it is listening on when it starts up;
        # parse the actual host:port from that line.
        line = self.proc.stdout.readline().decode('utf-8')
        host_port = re.search(r'listening at http://([^:]+:\d+)', line).group(1)
        address = f'http://{self.auth_user}:{self.auth_pass}@{host_port}'
        return address

    def stop(self):
        self.proc.kill()
        self.proc.communicate()


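# Illustrative use of MitmProxy (a sketch, not executed directly by the tests;
# setUp() below does the equivalent; the port shown is an example value):
#
#     proxy = MitmProxy()
#     proxy_url = proxy.start()  # e.g. 'http://scrapy:scrapy@127.0.0.1:54321'
#     os.environ['https_proxy'] = proxy_url
#     ...
#     proxy.stop()
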
def _wrong_credentials(proxy_url):
    """Return ``proxy_url`` with its credentials replaced by invalid ones.

    For example (host and port illustrative):
    'http://scrapy:scrapy@127.0.0.1:54321' -> 'http://wrong:wronger@127.0.0.1:54321'
    """
    bad_auth_proxy = list(urlsplit(proxy_url))
    bad_auth_proxy[1] = bad_auth_proxy[1].replace('scrapy:scrapy@', 'wrong:wronger@')
    return urlunsplit(bad_auth_proxy)


class ProxyConnectTestCase(TestCase):
    """Crawl an HTTPS mockserver through the authenticating proxy started in setUp()."""

    def setUp(self):
        try:
            import mitmproxy  # noqa: F401
        except ImportError:
            self.skipTest('mitmproxy is not installed')

        self.mockserver = MockServer()
        self.mockserver.__enter__()
        self._oldenv = os.environ.copy()

        self._proxy = MitmProxy()
        proxy_url = self._proxy.start()
        os.environ['https_proxy'] = proxy_url
        os.environ['http_proxy'] = proxy_url

    def tearDown(self):
        self.mockserver.__exit__(None, None, None)
        self._proxy.stop()
        os.environ = self._oldenv

    @defer.inlineCallbacks
    def test_https_connect_tunnel(self):
        crawler = get_crawler(SimpleSpider)
        with LogCapture() as log:
            yield crawler.crawl(self.mockserver.url("/status?n=200", is_secure=True))
        self._assert_got_response_code(200, log)

    @defer.inlineCallbacks
    def test_https_tunnel_auth_error(self):
        os.environ['https_proxy'] = _wrong_credentials(os.environ['https_proxy'])
        crawler = get_crawler(SimpleSpider)
        with LogCapture() as log:
            yield crawler.crawl(self.mockserver.url("/status?n=200", is_secure=True))
        # The proxy returns a 407 status code, but it never reaches the client;
        # the client only sees a TunnelError.
        self._assert_got_tunnel_error(log)

    @defer.inlineCallbacks
    def test_https_tunnel_without_leak_proxy_authorization_header(self):
        request = Request(self.mockserver.url("/echo", is_secure=True))
        crawler = get_crawler(SingleRequestSpider)
        with LogCapture() as log:
            yield crawler.crawl(seed=request)
        self._assert_got_response_code(200, log)
        echo = json.loads(crawler.spider.meta['responses'][0].text)
        self.assertNotIn('Proxy-Authorization', echo['headers'])

    def _assert_got_response_code(self, code, log):
        print(log)
        self.assertEqual(str(log).count(f'Crawled ({code})'), 1)

    def _assert_got_tunnel_error(self, log):
        print(log)
        self.assertIn('TunnelError', str(log))