"""
Tests borrowed from the twisted.web.client tests.
"""
import os
import shutil
import sys
from pkg_resources import parse_version

import cryptography
import OpenSSL.SSL
from twisted.trial import unittest
from twisted.web import server, static, util, resource
from twisted.internet import reactor, defer
try:
    from twisted.internet.testing import StringTransport
except ImportError:
    # deprecated in Twisted 19.7.0
    # (remove once we bump our requirement past that version)
    from twisted.test.proto_helpers import StringTransport
from twisted.python.filepath import FilePath
from twisted.protocols.policies import WrappingFactory
from twisted.internet.defer import inlineCallbacks
from twisted.web.test.test_webclient import (
    ForeverTakingResource,
    ErrorResource,
    NoLengthResource,
    HostHeaderResource,
    PayloadResource,
    BrokenDownloadResource,
)

from scrapy.core.downloader import webclient as client
from scrapy.core.downloader.contextfactory import ScrapyClientContextFactory
from scrapy.http import Request, Headers
from scrapy.settings import Settings
from scrapy.utils.misc import create_instance
from scrapy.utils.python import to_bytes, to_unicode
from tests.mockserver import ssl_context_factory


def getPage(url, contextFactory=None, response_transform=None, *args, **kwargs):
    """Adapted version of twisted.web.client.getPage.

    Builds a ``ScrapyHTTPClientFactory`` for *url* and returns its Deferred.
    The Deferred fires with ``response.body``, or with
    ``response_transform(response)`` when a transform is supplied.

    Extra arguments end up in ``_clientfactory`` below, which turns them into
    a ``scrapy.http.Request``. The tests pass e.g. ``body``, ``method`` and
    ``headers`` this way. A ``timeout`` keyword (default ``0``) is popped off
    and given to the client factory instead of the Request.
    """
    def _clientfactory(url, *args, **kwargs):
        url = to_unicode(url)
        timeout = kwargs.pop('timeout', 0)
        f = client.ScrapyHTTPClientFactory(
            Request(url, *args, **kwargs), timeout=timeout)
        f.deferred.addCallback(response_transform or (lambda r: r.body))
        return f

    from twisted.web.client import _makeGetterFactory
    # Pass contextFactory positionally so that any extra positional *args
    # follow it. Written as ``contextFactory=...`` before ``*args``, the first
    # extra positional argument would also bind to contextFactory and raise
    # a TypeError.
    return _makeGetterFactory(
        to_bytes(url), _clientfactory, contextFactory, *args, **kwargs
    ).deferred


class ParseUrlTestCase(unittest.TestCase):
    """Test URL parsing facility and default values."""

    def _parse(self, url):
        f = client.ScrapyHTTPClientFactory(Request(url))
        return (f.scheme, f.netloc, f.host, f.port, f.path)

    def testParse(self):
        lip = '127.0.0.1'
        tests = (
            ("http://127.0.0.1?c=v&c2=v2#fragment", ('http', lip, lip, 80, '/?c=v&c2=v2')),
            ("http://127.0.0.1/?c=v&c2=v2#fragment", ('http', lip, lip, 80, '/?c=v&c2=v2')),
            ("http://127.0.0.1/foo?c=v&c2=v2#frag", ('http', lip, lip, 80, '/foo?c=v&c2=v2')),
            ("http://127.0.0.1:100?c=v&c2=v2#fragment", ('http', lip + ':100', lip, 100, '/?c=v&c2=v2')),
            ("http://127.0.0.1:100/?c=v&c2=v2#frag", ('http', lip + ':100', lip, 100, '/?c=v&c2=v2')),
            ("http://127.0.0.1:100/foo?c=v&c2=v2#frag", ('http', lip + ':100', lip, 100, '/foo?c=v&c2=v2')),

            ("http://127.0.0.1", ('http', lip, lip, 80, '/')),
            ("http://127.0.0.1/", ('http', lip, lip, 80, '/')),
            ("http://127.0.0.1/foo", ('http', lip, lip, 80, '/foo')),
            ("http://127.0.0.1?param=value", ('http', lip, lip, 80, '/?param=value')),
            ("http://127.0.0.1/?param=value", ('http', lip, lip, 80, '/?param=value')),
            ("http://127.0.0.1:12345/foo", ('http', lip + ':12345', lip, 12345, '/foo')),
            ("http://spam:12345/foo", ('http', 'spam:12345', 'spam', 12345, '/foo')),
            ("http://spam.test.org/foo", ('http', 'spam.test.org', 'spam.test.org', 80, '/foo')),

            ("https://127.0.0.1/foo", ('https', lip, lip, 443, '/foo')),
            ("https://127.0.0.1/?param=value", ('https', lip, lip, 443, '/?param=value')),
            ("https://127.0.0.1:12345/", ('https', lip + ':12345', lip, 12345, '/')),

            # trailing whitespace in the URL must not leak into the result
            ("http://scrapytest.org/foo ", ('http', 'scrapytest.org', 'scrapytest.org', 80, '/foo')),
            ("http://egg:7890 ", ('http', 'egg:7890', 'egg', 7890, '/')),
        )

        for url, test in tests:
            # client._parse returns bytes for the textual fields, ints for ports
            test = tuple(
                to_bytes(x) if not isinstance(x, int) else x for x in test)
            self.assertEqual(client._parse(url), test, url)


class ScrapyHTTPPageGetterTests(unittest.TestCase):
    """Check the raw request bytes ScrapyHTTPPageGetter writes to its transport."""

    def test_earlyHeaders(self):
        # basic test stolen from twisted HTTPageGetter
        factory = client.ScrapyHTTPClientFactory(Request(
            url='http://foo/bar',
            body="some data",
            headers={
                'Host': 'example.net',
                'User-Agent': 'fooble',
                'Cookie': 'blah blah',
                'Content-Length': '12981',
                'Useful': 'value'}))

        self._test(
            factory,
            b"GET /bar HTTP/1.0\r\n"
            b"Content-Length: 9\r\n"
            b"Useful: value\r\n"
            b"Connection: close\r\n"
            b"User-Agent: fooble\r\n"
            b"Host: example.net\r\n"
            b"Cookie: blah blah\r\n"
            b"\r\n"
            b"some data")

        # test minimal sent headers
        factory = client.ScrapyHTTPClientFactory(Request('http://foo/bar'))
        self._test(
            factory,
            b"GET /bar HTTP/1.0\r\n"
            b"Host: foo\r\n"
            b"\r\n")

        # test a simple POST with body and content-type
        factory = client.ScrapyHTTPClientFactory(Request(
            method='POST',
            url='http://foo/bar',
            body='name=value',
            headers={'Content-Type': 'application/x-www-form-urlencoded'}))

        self._test(
            factory,
            b"POST /bar HTTP/1.0\r\n"
            b"Host: foo\r\n"
            b"Connection: close\r\n"
            b"Content-Type: application/x-www-form-urlencoded\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
            b"name=value")

        # test a POST method with no body provided
        factory = client.ScrapyHTTPClientFactory(Request(
            method='POST',
            url='http://foo/bar'
        ))

        self._test(
            factory,
            b"POST /bar HTTP/1.0\r\n"
            b"Host: foo\r\n"
            b"Content-Length: 0\r\n"
            b"\r\n")

        # test with single and multivalued headers
        factory = client.ScrapyHTTPClientFactory(Request(
            url='http://foo/bar',
            headers={
                'X-Meta-Single': 'single',
                'X-Meta-Multivalued': ['value1', 'value2'],
            },
        ))

        self._test(
            factory,
            b"GET /bar HTTP/1.0\r\n"
            b"Host: foo\r\n"
            b"X-Meta-Multivalued: value1\r\n"
            b"X-Meta-Multivalued: value2\r\n"
            b"X-Meta-Single: single\r\n"
            b"\r\n")

        # same test with single and multivalued headers but using Headers class
        factory = client.ScrapyHTTPClientFactory(Request(
            url='http://foo/bar',
            headers=Headers({
                'X-Meta-Single': 'single',
                'X-Meta-Multivalued': ['value1', 'value2'],
            }),
        ))

        self._test(
            factory,
            b"GET /bar HTTP/1.0\r\n"
            b"Host: foo\r\n"
            b"X-Meta-Multivalued: value1\r\n"
            b"X-Meta-Multivalued: value2\r\n"
            b"X-Meta-Single: single\r\n"
            b"\r\n")

    def _test(self, factory, testvalue):
        """Connect a page getter for *factory* to a StringTransport and compare output.

        The written bytes and *testvalue* are compared as *sets of lines*.
        Header order is therefore not checked.
        """
        transport = StringTransport()
        protocol = client.ScrapyHTTPPageGetter()
        protocol.factory = factory
        protocol.makeConnection(transport)
        self.assertEqual(
            set(transport.value().splitlines()),
            set(testvalue.splitlines()))
        return testvalue

    def test_non_standard_line_endings(self):
        # regression test for: http://dev.scrapy.org/ticket/258
        factory = client.ScrapyHTTPClientFactory(Request(
            url='http://foo/bar'))
        protocol = client.ScrapyHTTPPageGetter()
        protocol.factory = factory
        protocol.headers = Headers()
        # bare "\n" line endings instead of "\r\n"
        protocol.dataReceived(b"HTTP/1.0 200 OK\n")
        protocol.dataReceived(b"Hello: World\n")
        protocol.dataReceived(b"Foo: Bar\n")
        protocol.dataReceived(b"\n")
        self.assertEqual(protocol.headers, Headers({'Hello': ['World'], 'Foo': ['Bar']}))


class EncodingResource(resource.Resource):
    """Echo the request body re-encoded as ``out_encoding``.

    The chosen encoding is advertised in the Content-Encoding header.
    """
    out_encoding = 'cp1251'

    def render(self, request):
        body = to_unicode(request.content.read())
        request.setHeader(b'content-encoding', self.out_encoding)
        return body.encode(self.out_encoding)


class WebClientTestCase(unittest.TestCase):
    """Functional tests for ScrapyHTTPClientFactory against a local HTTP site."""

    def _listen(self, site):
        return reactor.listenTCP(0, site, interface="127.0.0.1")

    def setUp(self):
        self.tmpname = self.mktemp()
        os.mkdir(self.tmpname)
        FilePath(self.tmpname).child("file").setContent(b"0123456789")
        r = static.File(self.tmpname)
        r.putChild(b"redirect", util.Redirect(b"/file"))
        r.putChild(b"wait", ForeverTakingResource())
        r.putChild(b"error", ErrorResource())
        r.putChild(b"nolength", NoLengthResource())
        r.putChild(b"host", HostHeaderResource())
        r.putChild(b"payload", PayloadResource())
        r.putChild(b"broken", BrokenDownloadResource())
        r.putChild(b"encoding", EncodingResource())
        self.site = server.Site(r, timeout=None)
        # WrappingFactory keeps track of live server protocols; the timeout
        # test uses wrapper.protocols to drop the hanging connection.
        self.wrapper = WrappingFactory(self.site)
        self.port = self._listen(self.wrapper)
        self.portno = self.port.getHost().port

    @inlineCallbacks
    def tearDown(self):
        yield self.port.stopListening()
        shutil.rmtree(self.tmpname)

    def getURL(self, path):
        return f"http://127.0.0.1:{self.portno}/{path}"

    def testPayload(self):
        s = "0123456789" * 10
        return getPage(self.getURL("payload"), body=s).addCallback(
            self.assertEqual, to_bytes(s))

    def testHostHeader(self):
        # if we pass Host header explicitly, it should be used, otherwise
        # it should extract from url
        return defer.gatherResults([
            getPage(self.getURL("host")).addCallback(
                self.assertEqual, to_bytes(f"127.0.0.1:{self.portno}")),
            getPage(self.getURL("host"), headers={"Host": "www.example.com"}).addCallback(
                self.assertEqual, to_bytes("www.example.com"))])

    def test_getPage(self):
        """
        L{client.getPage} returns a L{Deferred} which is called back with
        the body of the response if the default method B{GET} is used.
        """
        d = getPage(self.getURL("file"))
        d.addCallback(self.assertEqual, b"0123456789")
        return d

    def test_getPageHead(self):
        """
        L{client.getPage} returns a L{Deferred} which is called back with
        the empty string if the method is C{HEAD} and there is a successful
        response code.
        """
        def _getPage(method):
            return getPage(self.getURL("file"), method=method)
        return defer.gatherResults([
            _getPage("head").addCallback(self.assertEqual, b""),
            _getPage("HEAD").addCallback(self.assertEqual, b"")])

    def test_timeoutNotTriggering(self):
        """
        When a non-zero timeout is passed to L{getPage} and the page is
        retrieved before the timeout period elapses, the L{Deferred} is
        called back with the contents of the page.
        """
        d = getPage(self.getURL("host"), timeout=100)
        d.addCallback(
            self.assertEqual, to_bytes(f"127.0.0.1:{self.portno}"))
        return d

    def test_timeoutTriggering(self):
        """
        When a non-zero timeout is passed to L{getPage} and that many
        seconds elapse before the server responds to the request, the
        L{Deferred} is errbacked with a L{defer.TimeoutError}.
        """
        finished = self.assertFailure(
            getPage(self.getURL("wait"), timeout=0.000001),
            defer.TimeoutError)

        def cleanup(passthrough):
            # Clean up the server which is hanging around not doing
            # anything.
            connected = list(self.wrapper.protocols.keys())
            # There might be nothing here if the server managed to already see
            # that the connection was lost.
            if connected:
                connected[0].transport.loseConnection()
            return passthrough
        finished.addBoth(cleanup)
        return finished

    def testNotFound(self):
        return getPage(self.getURL('notsuchfile')).addCallback(self._cbNoSuchFile)

    def _cbNoSuchFile(self, pageData):
        self.assertIn(b'404 - No Such Resource', pageData)

    def testFactoryInfo(self):
        url = self.getURL('file')
        _, _, host, port, _ = client._parse(url)
        factory = client.ScrapyHTTPClientFactory(Request(url))
        reactor.connectTCP(to_unicode(host), port, factory)
        return factory.deferred.addCallback(self._cbFactoryInfo, factory)

    def _cbFactoryInfo(self, ignoredResult, factory):
        self.assertEqual(factory.status, b'200')
        self.assertTrue(factory.version.startswith(b'HTTP/'))
        self.assertEqual(factory.message, b'OK')
        self.assertEqual(factory.response_headers[b'content-length'], b'10')

    def testRedirect(self):
        return getPage(self.getURL("redirect")).addCallback(self._cbRedirect)

    def _cbRedirect(self, pageData):
        # Expected body is the HTML page produced by twisted.web.util.redirectTo
        # (used by util.Redirect); the inner indentation is significant.
        self.assertEqual(
            pageData,
            b'\n<html>\n    <head>\n        <meta http-equiv="refresh" content="0;URL=/file">\n'
            b'    </head>\n    <body bgcolor="#FFFFFF" text="#000000">\n    '
            b'<a href="/file">click here</a>\n    </body>\n</html>\n')

    def test_encoding(self):
        """ Test that non-standard body encoding matches
        Content-Encoding header """
        body = b'\xd0\x81\xd1\x8e\xd0\xaf'
        dfd = getPage(self.getURL('encoding'), body=body, response_transform=lambda r: r)
        return dfd.addCallback(self._check_Encoding, body)

    def _check_Encoding(self, response, original_body):
        content_encoding = to_unicode(response.headers[b'Content-Encoding'])
        self.assertEqual(content_encoding, EncodingResource.out_encoding)
        self.assertEqual(
            response.body.decode(content_encoding), to_unicode(original_body))


class WebClientSSLTestCase(unittest.TestCase):
    """Payload round-trip over TLS against a local twisted.web site.

    Subclasses may set ``context_factory`` to customise the server-side TLS
    context. Otherwise ``ssl_context_factory()`` is used.
    """
    context_factory = None

    def _listen(self, site):
        return reactor.listenSSL(
            0, site,
            contextFactory=self.context_factory or ssl_context_factory(),
            interface="127.0.0.1")

    def getURL(self, path):
        return f"https://127.0.0.1:{self.portno}/{path}"

    def setUp(self):
        self.tmpname = self.mktemp()
        os.mkdir(self.tmpname)
        FilePath(self.tmpname).child("file").setContent(b"0123456789")
        r = static.File(self.tmpname)
        r.putChild(b"payload", PayloadResource())
        self.site = server.Site(r, timeout=None)
        self.wrapper = WrappingFactory(self.site)
        self.port = self._listen(self.wrapper)
        self.portno = self.port.getHost().port

    @inlineCallbacks
    def tearDown(self):
        yield self.port.stopListening()
        shutil.rmtree(self.tmpname)

    def testPayload(self):
        s = "0123456789" * 10
        return getPage(self.getURL("payload"), body=s).addCallback(
            self.assertEqual, to_bytes(s))


class WebClientCustomCiphersSSLTestCase(WebClientSSLTestCase):
    # we try to use a cipher that is not enabled by default in OpenSSL
    custom_ciphers = 'CAMELLIA256-SHA'
    context_factory = ssl_context_factory(cipher_string=custom_ciphers)

    def testPayload(self):
        s = "0123456789" * 10
        settings = Settings({'DOWNLOADER_CLIENT_TLS_CIPHERS': self.custom_ciphers})
        client_context_factory = create_instance(ScrapyClientContextFactory, settings=settings, crawler=None)
        return getPage(
            self.getURL("payload"), body=s, contextFactory=client_context_factory
        ).addCallback(self.assertEqual, to_bytes(s))

    def testPayloadDisabledCipher(self):
        if sys.implementation.name == "pypy" and parse_version(cryptography.__version__) <= parse_version("2.3.1"):
            self.skipTest("This test expects a failure, but the code does work in PyPy with cryptography<=2.3.1")
        s = "0123456789" * 10
        # The client only offers a cipher the server was not configured with,
        # so the TLS handshake is expected to fail.
        settings = Settings({'DOWNLOADER_CLIENT_TLS_CIPHERS': 'ECDHE-RSA-AES256-GCM-SHA384'})
        client_context_factory = create_instance(ScrapyClientContextFactory, settings=settings, crawler=None)
        d = getPage(self.getURL("payload"), body=s, contextFactory=client_context_factory)
        return self.assertFailure(d, OpenSSL.SSL.Error)