1"""
2from twisted.internet import defer
3Tests borrowed from the twisted.web.client tests.
4"""
5import os
6import shutil
7import sys
8from pkg_resources import parse_version
9
10import cryptography
11import OpenSSL.SSL
12from twisted.trial import unittest
13from twisted.web import server, static, util, resource
14from twisted.internet import reactor, defer
15try:
16    from twisted.internet.testing import StringTransport
17except ImportError:
18    # deprecated in Twisted 19.7.0
19    # (remove once we bump our requirement past that version)
20    from twisted.test.proto_helpers import StringTransport
21from twisted.python.filepath import FilePath
22from twisted.protocols.policies import WrappingFactory
23from twisted.internet.defer import inlineCallbacks
24from twisted.web.test.test_webclient import (
25    ForeverTakingResource,
26    ErrorResource,
27    NoLengthResource,
28    HostHeaderResource,
29    PayloadResource,
30    BrokenDownloadResource,
31)
32
33from scrapy.core.downloader import webclient as client
34from scrapy.core.downloader.contextfactory import ScrapyClientContextFactory
35from scrapy.http import Request, Headers
36from scrapy.settings import Settings
37from scrapy.utils.misc import create_instance
38from scrapy.utils.python import to_bytes, to_unicode
39from tests.mockserver import ssl_context_factory
40
41
42def getPage(url, contextFactory=None, response_transform=None, *args, **kwargs):
43    """Adapted version of twisted.web.client.getPage"""
44    def _clientfactory(url, *args, **kwargs):
45        url = to_unicode(url)
46        timeout = kwargs.pop('timeout', 0)
47        f = client.ScrapyHTTPClientFactory(
48            Request(url, *args, **kwargs), timeout=timeout)
49        f.deferred.addCallback(response_transform or (lambda r: r.body))
50        return f
51
52    from twisted.web.client import _makeGetterFactory
53    return _makeGetterFactory(
54        to_bytes(url), _clientfactory, contextFactory=contextFactory, *args, **kwargs
55    ).deferred
56
57
58class ParseUrlTestCase(unittest.TestCase):
59    """Test URL parsing facility and defaults values."""
60
61    def _parse(self, url):
62        f = client.ScrapyHTTPClientFactory(Request(url))
63        return (f.scheme, f.netloc, f.host, f.port, f.path)
64
65    def testParse(self):
66        lip = '127.0.0.1'
67        tests = (
68            ("http://127.0.0.1?c=v&c2=v2#fragment", ('http', lip, lip, 80, '/?c=v&c2=v2')),
69            ("http://127.0.0.1/?c=v&c2=v2#fragment", ('http', lip, lip, 80, '/?c=v&c2=v2')),
70            ("http://127.0.0.1/foo?c=v&c2=v2#frag", ('http', lip, lip, 80, '/foo?c=v&c2=v2')),
71            ("http://127.0.0.1:100?c=v&c2=v2#fragment", ('http', lip + ':100', lip, 100, '/?c=v&c2=v2')),
72            ("http://127.0.0.1:100/?c=v&c2=v2#frag", ('http', lip + ':100', lip, 100, '/?c=v&c2=v2')),
73            ("http://127.0.0.1:100/foo?c=v&c2=v2#frag", ('http', lip + ':100', lip, 100, '/foo?c=v&c2=v2')),
74
75            ("http://127.0.0.1", ('http', lip, lip, 80, '/')),
76            ("http://127.0.0.1/", ('http', lip, lip, 80, '/')),
77            ("http://127.0.0.1/foo", ('http', lip, lip, 80, '/foo')),
78            ("http://127.0.0.1?param=value", ('http', lip, lip, 80, '/?param=value')),
79            ("http://127.0.0.1/?param=value", ('http', lip, lip, 80, '/?param=value')),
80            ("http://127.0.0.1:12345/foo", ('http', lip + ':12345', lip, 12345, '/foo')),
81            ("http://spam:12345/foo", ('http', 'spam:12345', 'spam', 12345, '/foo')),
82            ("http://spam.test.org/foo", ('http', 'spam.test.org', 'spam.test.org', 80, '/foo')),
83
84            ("https://127.0.0.1/foo", ('https', lip, lip, 443, '/foo')),
85            ("https://127.0.0.1/?param=value", ('https', lip, lip, 443, '/?param=value')),
86            ("https://127.0.0.1:12345/", ('https', lip + ':12345', lip, 12345, '/')),
87
88            ("http://scrapytest.org/foo ", ('http', 'scrapytest.org', 'scrapytest.org', 80, '/foo')),
89            ("http://egg:7890 ", ('http', 'egg:7890', 'egg', 7890, '/')),
90        )
91
92        for url, test in tests:
93            test = tuple(
94                to_bytes(x) if not isinstance(x, int) else x for x in test)
95            self.assertEqual(client._parse(url), test, url)
96
97
98class ScrapyHTTPPageGetterTests(unittest.TestCase):
99
100    def test_earlyHeaders(self):
101        # basic test stolen from twisted HTTPageGetter
102        factory = client.ScrapyHTTPClientFactory(Request(
103            url='http://foo/bar',
104            body="some data",
105            headers={
106                'Host': 'example.net',
107                'User-Agent': 'fooble',
108                'Cookie': 'blah blah',
109                'Content-Length': '12981',
110                'Useful': 'value'}))
111
112        self._test(
113            factory,
114            b"GET /bar HTTP/1.0\r\n"
115            b"Content-Length: 9\r\n"
116            b"Useful: value\r\n"
117            b"Connection: close\r\n"
118            b"User-Agent: fooble\r\n"
119            b"Host: example.net\r\n"
120            b"Cookie: blah blah\r\n"
121            b"\r\n"
122            b"some data")
123
124        # test minimal sent headers
125        factory = client.ScrapyHTTPClientFactory(Request('http://foo/bar'))
126        self._test(
127            factory,
128            b"GET /bar HTTP/1.0\r\n"
129            b"Host: foo\r\n"
130            b"\r\n")
131
132        # test a simple POST with body and content-type
133        factory = client.ScrapyHTTPClientFactory(Request(
134            method='POST',
135            url='http://foo/bar',
136            body='name=value',
137            headers={'Content-Type': 'application/x-www-form-urlencoded'}))
138
139        self._test(
140            factory,
141            b"POST /bar HTTP/1.0\r\n"
142            b"Host: foo\r\n"
143            b"Connection: close\r\n"
144            b"Content-Type: application/x-www-form-urlencoded\r\n"
145            b"Content-Length: 10\r\n"
146            b"\r\n"
147            b"name=value")
148
149        # test a POST method with no body provided
150        factory = client.ScrapyHTTPClientFactory(Request(
151            method='POST',
152            url='http://foo/bar'
153        ))
154
155        self._test(
156            factory,
157            b"POST /bar HTTP/1.0\r\n"
158            b"Host: foo\r\n"
159            b"Content-Length: 0\r\n"
160            b"\r\n")
161
162        # test with single and multivalued headers
163        factory = client.ScrapyHTTPClientFactory(Request(
164            url='http://foo/bar',
165            headers={
166                'X-Meta-Single': 'single',
167                'X-Meta-Multivalued': ['value1', 'value2'],
168            },
169        ))
170
171        self._test(
172            factory,
173            b"GET /bar HTTP/1.0\r\n"
174            b"Host: foo\r\n"
175            b"X-Meta-Multivalued: value1\r\n"
176            b"X-Meta-Multivalued: value2\r\n"
177            b"X-Meta-Single: single\r\n"
178            b"\r\n")
179
180        # same test with single and multivalued headers but using Headers class
181        factory = client.ScrapyHTTPClientFactory(Request(
182            url='http://foo/bar',
183            headers=Headers({
184                'X-Meta-Single': 'single',
185                'X-Meta-Multivalued': ['value1', 'value2'],
186            }),
187        ))
188
189        self._test(
190            factory,
191            b"GET /bar HTTP/1.0\r\n"
192            b"Host: foo\r\n"
193            b"X-Meta-Multivalued: value1\r\n"
194            b"X-Meta-Multivalued: value2\r\n"
195            b"X-Meta-Single: single\r\n"
196            b"\r\n")
197
198    def _test(self, factory, testvalue):
199        transport = StringTransport()
200        protocol = client.ScrapyHTTPPageGetter()
201        protocol.factory = factory
202        protocol.makeConnection(transport)
203        self.assertEqual(
204            set(transport.value().splitlines()),
205            set(testvalue.splitlines()))
206        return testvalue
207
208    def test_non_standard_line_endings(self):
209        # regression test for: http://dev.scrapy.org/ticket/258
210        factory = client.ScrapyHTTPClientFactory(Request(
211            url='http://foo/bar'))
212        protocol = client.ScrapyHTTPPageGetter()
213        protocol.factory = factory
214        protocol.headers = Headers()
215        protocol.dataReceived(b"HTTP/1.0 200 OK\n")
216        protocol.dataReceived(b"Hello: World\n")
217        protocol.dataReceived(b"Foo: Bar\n")
218        protocol.dataReceived(b"\n")
219        self.assertEqual(protocol.headers, Headers({'Hello': ['World'], 'Foo': ['Bar']}))
220
221
222class EncodingResource(resource.Resource):
223    out_encoding = 'cp1251'
224
225    def render(self, request):
226        body = to_unicode(request.content.read())
227        request.setHeader(b'content-encoding', self.out_encoding)
228        return body.encode(self.out_encoding)
229
230
231class WebClientTestCase(unittest.TestCase):
232    def _listen(self, site):
233        return reactor.listenTCP(0, site, interface="127.0.0.1")
234
235    def setUp(self):
236        self.tmpname = self.mktemp()
237        os.mkdir(self.tmpname)
238        FilePath(self.tmpname).child("file").setContent(b"0123456789")
239        r = static.File(self.tmpname)
240        r.putChild(b"redirect", util.Redirect(b"/file"))
241        r.putChild(b"wait", ForeverTakingResource())
242        r.putChild(b"error", ErrorResource())
243        r.putChild(b"nolength", NoLengthResource())
244        r.putChild(b"host", HostHeaderResource())
245        r.putChild(b"payload", PayloadResource())
246        r.putChild(b"broken", BrokenDownloadResource())
247        r.putChild(b"encoding", EncodingResource())
248        self.site = server.Site(r, timeout=None)
249        self.wrapper = WrappingFactory(self.site)
250        self.port = self._listen(self.wrapper)
251        self.portno = self.port.getHost().port
252
253    @inlineCallbacks
254    def tearDown(self):
255        yield self.port.stopListening()
256        shutil.rmtree(self.tmpname)
257
258    def getURL(self, path):
259        return f"http://127.0.0.1:{self.portno}/{path}"
260
261    def testPayload(self):
262        s = "0123456789" * 10
263        return getPage(self.getURL("payload"), body=s).addCallback(
264            self.assertEqual, to_bytes(s))
265
266    def testHostHeader(self):
267        # if we pass Host header explicitly, it should be used, otherwise
268        # it should extract from url
269        return defer.gatherResults([
270            getPage(self.getURL("host")).addCallback(
271                self.assertEqual, to_bytes(f"127.0.0.1:{self.portno}")),
272            getPage(self.getURL("host"), headers={"Host": "www.example.com"}).addCallback(
273                self.assertEqual, to_bytes("www.example.com"))])
274
275    def test_getPage(self):
276        """
277        L{client.getPage} returns a L{Deferred} which is called back with
278        the body of the response if the default method B{GET} is used.
279        """
280        d = getPage(self.getURL("file"))
281        d.addCallback(self.assertEqual, b"0123456789")
282        return d
283
284    def test_getPageHead(self):
285        """
286        L{client.getPage} returns a L{Deferred} which is called back with
287        the empty string if the method is C{HEAD} and there is a successful
288        response code.
289        """
290        def _getPage(method):
291            return getPage(self.getURL("file"), method=method)
292        return defer.gatherResults([
293            _getPage("head").addCallback(self.assertEqual, b""),
294            _getPage("HEAD").addCallback(self.assertEqual, b"")])
295
296    def test_timeoutNotTriggering(self):
297        """
298        When a non-zero timeout is passed to L{getPage} and the page is
299        retrieved before the timeout period elapses, the L{Deferred} is
300        called back with the contents of the page.
301        """
302        d = getPage(self.getURL("host"), timeout=100)
303        d.addCallback(
304            self.assertEqual, to_bytes(f"127.0.0.1:{self.portno}"))
305        return d
306
307    def test_timeoutTriggering(self):
308        """
309        When a non-zero timeout is passed to L{getPage} and that many
310        seconds elapse before the server responds to the request. the
311        L{Deferred} is errbacked with a L{error.TimeoutError}.
312        """
313        finished = self.assertFailure(
314            getPage(self.getURL("wait"), timeout=0.000001),
315            defer.TimeoutError)
316
317        def cleanup(passthrough):
318            # Clean up the server which is hanging around not doing
319            # anything.
320            connected = list(self.wrapper.protocols.keys())
321            # There might be nothing here if the server managed to already see
322            # that the connection was lost.
323            if connected:
324                connected[0].transport.loseConnection()
325            return passthrough
326        finished.addBoth(cleanup)
327        return finished
328
329    def testNotFound(self):
330        return getPage(self.getURL('notsuchfile')).addCallback(self._cbNoSuchFile)
331
332    def _cbNoSuchFile(self, pageData):
333        self.assertIn(b'404 - No Such Resource', pageData)
334
335    def testFactoryInfo(self):
336        url = self.getURL('file')
337        _, _, host, port, _ = client._parse(url)
338        factory = client.ScrapyHTTPClientFactory(Request(url))
339        reactor.connectTCP(to_unicode(host), port, factory)
340        return factory.deferred.addCallback(self._cbFactoryInfo, factory)
341
342    def _cbFactoryInfo(self, ignoredResult, factory):
343        self.assertEqual(factory.status, b'200')
344        self.assertTrue(factory.version.startswith(b'HTTP/'))
345        self.assertEqual(factory.message, b'OK')
346        self.assertEqual(factory.response_headers[b'content-length'], b'10')
347
348    def testRedirect(self):
349        return getPage(self.getURL("redirect")).addCallback(self._cbRedirect)
350
351    def _cbRedirect(self, pageData):
352        self.assertEqual(
353            pageData,
354            b'\n<html>\n    <head>\n        <meta http-equiv="refresh" content="0;URL=/file">\n'
355            b'    </head>\n    <body bgcolor="#FFFFFF" text="#000000">\n    '
356            b'<a href="/file">click here</a>\n    </body>\n</html>\n')
357
358    def test_encoding(self):
359        """ Test that non-standart body encoding matches
360        Content-Encoding header """
361        body = b'\xd0\x81\xd1\x8e\xd0\xaf'
362        dfd = getPage(self.getURL('encoding'), body=body, response_transform=lambda r: r)
363        return dfd.addCallback(self._check_Encoding, body)
364
365    def _check_Encoding(self, response, original_body):
366        content_encoding = to_unicode(response.headers[b'Content-Encoding'])
367        self.assertEqual(content_encoding, EncodingResource.out_encoding)
368        self.assertEqual(
369            response.body.decode(content_encoding), to_unicode(original_body))
370
371
372class WebClientSSLTestCase(unittest.TestCase):
373    context_factory = None
374
375    def _listen(self, site):
376        return reactor.listenSSL(
377            0, site,
378            contextFactory=self.context_factory or ssl_context_factory(),
379            interface="127.0.0.1")
380
381    def getURL(self, path):
382        return f"https://127.0.0.1:{self.portno}/{path}"
383
384    def setUp(self):
385        self.tmpname = self.mktemp()
386        os.mkdir(self.tmpname)
387        FilePath(self.tmpname).child("file").setContent(b"0123456789")
388        r = static.File(self.tmpname)
389        r.putChild(b"payload", PayloadResource())
390        self.site = server.Site(r, timeout=None)
391        self.wrapper = WrappingFactory(self.site)
392        self.port = self._listen(self.wrapper)
393        self.portno = self.port.getHost().port
394
395    @inlineCallbacks
396    def tearDown(self):
397        yield self.port.stopListening()
398        shutil.rmtree(self.tmpname)
399
400    def testPayload(self):
401        s = "0123456789" * 10
402        return getPage(self.getURL("payload"), body=s).addCallback(
403            self.assertEqual, to_bytes(s))
404
405
406class WebClientCustomCiphersSSLTestCase(WebClientSSLTestCase):
407    # we try to use a cipher that is not enabled by default in OpenSSL
408    custom_ciphers = 'CAMELLIA256-SHA'
409    context_factory = ssl_context_factory(cipher_string=custom_ciphers)
410
411    def testPayload(self):
412        s = "0123456789" * 10
413        settings = Settings({'DOWNLOADER_CLIENT_TLS_CIPHERS': self.custom_ciphers})
414        client_context_factory = create_instance(ScrapyClientContextFactory, settings=settings, crawler=None)
415        return getPage(
416            self.getURL("payload"), body=s, contextFactory=client_context_factory
417        ).addCallback(self.assertEqual, to_bytes(s))
418
419    def testPayloadDisabledCipher(self):
420        if sys.implementation.name == "pypy" and parse_version(cryptography.__version__) <= parse_version("2.3.1"):
421            self.skipTest("This test expects a failure, but the code does work in PyPy with cryptography<=2.3.1")
422        s = "0123456789" * 10
423        settings = Settings({'DOWNLOADER_CLIENT_TLS_CIPHERS': 'ECDHE-RSA-AES256-GCM-SHA384'})
424        client_context_factory = create_instance(ScrapyClientContextFactory, settings=settings, crawler=None)
425        d = getPage(self.getURL("payload"), body=s, contextFactory=client_context_factory)
426        return self.assertFailure(d, OpenSSL.SSL.Error)
427