1import asyncio
2from unittest import mock
3
4import pytest
5from async_generator import async_generator, yield_
6
7from aiohttp import log, web
8from aiohttp.abc import AbstractAccessLogger, AbstractRouter
9from aiohttp.helpers import DEBUG, PY_36
10from aiohttp.test_utils import make_mocked_coro
11
12
async def test_app_ctor() -> None:
    """Application(loop=...) is deprecated but still records the given loop."""
    running_loop = asyncio.get_event_loop()
    with pytest.warns(DeprecationWarning):
        application = web.Application(loop=running_loop)
    # Reading the ``loop`` property is deprecated as well.
    with pytest.warns(DeprecationWarning):
        assert application.loop is running_loop
    assert application.logger is log.web_logger
20
21
def test_app_call() -> None:
    """Calling an application instance returns that same instance."""
    application = web.Application()
    result = application()
    assert result is application
25
26
def test_app_default_loop() -> None:
    """A fresh application has no loop; reading the property warns."""
    application = web.Application()
    with pytest.warns(DeprecationWarning):
        current = application.loop
    assert current is None
31
32
async def test_set_loop() -> None:
    """_set_loop() binds an explicit loop that the deprecated property returns."""
    expected = asyncio.get_event_loop()
    application = web.Application()
    application._set_loop(expected)
    with pytest.warns(DeprecationWarning):
        bound = application.loop
    assert bound is expected
39
40
def test_set_loop_default_loop() -> None:
    """_set_loop(None) falls back to the currently installed event loop.

    The loop created here is installed as the current loop only for the
    duration of the test. It is uninstalled and closed in ``finally``, even
    when an assertion fails. This keeps the global loop setting and an
    unclosed loop from leaking into later tests.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        app = web.Application()
        app._set_loop(None)
        with pytest.warns(DeprecationWarning):
            assert app.loop is loop
    finally:
        asyncio.set_event_loop(None)
        loop.close()
49
50
def test_set_loop_with_different_loops() -> None:
    """Rebinding an application to a different loop raises RuntimeError."""
    loop = asyncio.new_event_loop()
    try:
        app = web.Application()
        app._set_loop(loop)
        with pytest.warns(DeprecationWarning):
            assert app.loop is loop

        with pytest.raises(RuntimeError):
            app._set_loop(loop=object())
    finally:
        # The loop is never run here; close it so it is not left unclosed.
        loop.close()
60
61
@pytest.mark.parametrize("debug", [True, False])
async def test_app_make_handler_debug_exc(mocker, debug) -> None:
    """The (deprecated) ``debug`` constructor flag reaches the Server factory."""
    with pytest.warns(DeprecationWarning):
        application = web.Application(debug=debug)
    server_cls = mocker.patch("aiohttp.web_app.Server")

    with pytest.warns(DeprecationWarning):
        assert application.debug == debug

    application._make_handler()
    expected_kwargs = dict(
        request_factory=application._make_request,
        access_log_class=mock.ANY,
        loop=asyncio.get_event_loop(),
        debug=debug,
    )
    server_cls.assert_called_with(application._handle, **expected_kwargs)
79
80
async def test_app_make_handler_args(mocker) -> None:
    """Extra ``handler_args`` are passed through to the Server factory."""
    application = web.Application(handler_args={"test": True})
    server_cls = mocker.patch("aiohttp.web_app.Server")

    application._make_handler()

    expected_kwargs = {
        "request_factory": application._make_request,
        "access_log_class": mock.ANY,
        "loop": asyncio.get_event_loop(),
        "debug": mock.ANY,
        "test": True,
    }
    server_cls.assert_called_with(application._handle, **expected_kwargs)
94
95
async def test_app_make_handler_access_log_class(mocker) -> None:
    """Validate and forward ``access_log_class`` to the Server factory.

    Three cases are covered:
    - a plain class (not derived from AbstractAccessLogger) is rejected
      with TypeError;
    - a valid subclass passed explicitly reaches the Server factory;
    - a valid subclass supplied only via ``handler_args`` reaches the
      Server factory as well.
    """

    class NotAnAccessLogger:
        pass

    app = web.Application()

    with pytest.raises(TypeError):
        app._make_handler(access_log_class=NotAnAccessLogger)

    class Logger(AbstractAccessLogger):
        def log(self, request, response, time):
            self.logger.info("msg")

    srv = mocker.patch("aiohttp.web_app.Server")

    app._make_handler(access_log_class=Logger)
    srv.assert_called_with(
        app._handle,
        access_log_class=Logger,
        request_factory=app._make_request,
        loop=asyncio.get_event_loop(),
        debug=mock.ANY,
    )

    app = web.Application(handler_args={"access_log_class": Logger})
    # Deliberately omit the keyword here. Passing it explicitly as well would
    # satisfy the assertion without ever exercising ``handler_args``.
    app._make_handler()
    srv.assert_called_with(
        app._handle,
        access_log_class=Logger,
        request_factory=app._make_request,
        loop=asyncio.get_event_loop(),
        debug=mock.ANY,
    )
129
130
async def test_app_make_handler_raises_deprecation_warning() -> None:
    """The public make_handler() entry point emits a DeprecationWarning."""
    application = web.Application()

    with pytest.warns(DeprecationWarning):
        application.make_handler()
136
137
async def test_app_register_on_finish() -> None:
    """Every on_cleanup callback is awaited exactly once with the app."""
    application = web.Application()
    callbacks = [make_mocked_coro(None) for _ in range(2)]
    for callback in callbacks:
        application.on_cleanup.append(callback)
    application.freeze()
    await application.cleanup()
    for callback in callbacks:
        callback.assert_called_once_with(application)
148
149
async def test_app_register_coro() -> None:
    """A coroutine on_cleanup handler runs to completion inside cleanup()."""
    application = web.Application()
    marker = asyncio.get_event_loop().create_future()

    async def on_cleanup(app):
        await asyncio.sleep(0.001)
        marker.set_result(123)

    application.on_cleanup.append(on_cleanup)
    application.freeze()
    await application.cleanup()
    assert marker.done()
    assert marker.result() == 123
163
164
def test_non_default_router() -> None:
    """A router passed to the (deprecated) constructor argument is used as-is."""
    custom_router = mock.Mock(spec=AbstractRouter)
    with pytest.warns(DeprecationWarning):
        application = web.Application(router=custom_router)
    assert application.router is custom_router
170
171
def test_logging() -> None:
    """The ``logger`` attribute can be replaced after construction."""
    replacement = mock.Mock()
    application = web.Application()
    application.logger = replacement
    assert application.logger is replacement
177
178
async def test_on_shutdown() -> None:
    """on_shutdown handlers are invoked with the application by shutdown()."""
    app = web.Application()
    received = []

    async def on_shutdown(app_param):
        received.append(app_param)

    app.on_shutdown.append(on_shutdown)
    app.freeze()
    await app.shutdown()
    assert received
    assert all(param is app for param in received)
192
193
async def test_on_startup() -> None:
    """An on_startup handler may itself await other coroutines via gather()."""
    app = web.Application()
    called = set()

    async def long_running1(app_param):
        assert app is app_param
        called.add("long_running1")

    async def long_running2(app_param):
        assert app is app_param
        called.add("long_running2")

    async def on_startup_all_long_running(app_param):
        assert app is app_param
        called.add("all_long_running")
        return await asyncio.gather(long_running1(app_param), long_running2(app_param))

    app.on_startup.append(on_startup_all_long_running)
    app.freeze()

    await app.startup()
    assert called == {"long_running1", "long_running2", "all_long_running"}
224
225
def test_app_delitem() -> None:
    """Items stored on an application can be removed with ``del``."""
    application = web.Application()
    application["key"] = "value"
    assert len(application) == 1
    del application["key"]
    # Check the length explicitly: an empty application is still truthy.
    assert len(application) == 0
232
233
def test_app_freeze() -> None:
    """freeze() freezes registered sub-apps, and only on the first call."""
    application = web.Application()
    child = mock.Mock()
    # Presumably freeze() inspects sub-app middlewares; give the mock an
    # empty sequence so that works. TODO confirm against web_app.freeze.
    child._middlewares = ()
    application._subapps.append(child)

    application.freeze()
    assert child.freeze.called

    # A repeated freeze() must not freeze the sub-app again.
    application.freeze()
    assert child.freeze.call_count == 1
245
246
def test_equality() -> None:
    """An application equals itself and differs from another instance."""
    first = web.Application()
    second = web.Application()

    assert first == first
    assert first != second
253
254
def test_app_run_middlewares() -> None:
    """_run_middlewares is set iff the root app or a sub-app has middlewares."""

    @web.middleware
    async def middleware(request, handler):
        return await handler(request)

    def frozen_flag(root, sub):
        # Mount ``sub`` under ``root``, freeze, and report the computed flag.
        root.add_subapp("/sub", sub)
        root.freeze()
        return root._run_middlewares

    assert frozen_flag(web.Application(), web.Application()) is False
    assert (
        frozen_flag(web.Application(middlewares=[middleware]), web.Application())
        is True
    )
    assert (
        frozen_flag(web.Application(), web.Application(middlewares=[middleware]))
        is True
    )
278
279
def test_subapp_pre_frozen_after_adding() -> None:
    """Mounting a sub-app pre-freezes it without fully freezing it."""
    parent = web.Application()
    child = web.Application()

    parent.add_subapp("/prefix", child)
    assert child.pre_frozen
    assert not child.frozen
287
288
@pytest.mark.skipif(not PY_36, reason="Python 3.6+ required")
def test_app_inheritance() -> None:
    """Defining a subclass of web.Application emits a DeprecationWarning."""
    with pytest.warns(DeprecationWarning):

        class Derived(web.Application):
            """Throwaway subclass; defining it is what should warn."""
295
296
@pytest.mark.skipif(not DEBUG, reason="The check is applied in DEBUG mode only")
def test_app_custom_attr() -> None:
    """In DEBUG mode, setting an unknown attribute emits a DeprecationWarning."""
    application = web.Application()
    with pytest.warns(DeprecationWarning):
        setattr(application, "custom", None)
302
303
async def test_cleanup_ctx() -> None:
    """cleanup_ctx entries are set up in order and torn down in reverse."""
    app = web.Application()
    out = []

    def make_ctx(num):
        @async_generator
        async def cleanup(app):
            out.append("pre_{}".format(num))
            await yield_(None)
            out.append("post_{}".format(num))

        return cleanup

    for num in (1, 2):
        app.cleanup_ctx.append(make_ctx(num))
    app.freeze()
    await app.startup()
    assert out == ["pre_1", "pre_2"]
    await app.cleanup()
    assert out == ["pre_1", "pre_2", "post_2", "post_1"]
324
325
async def test_cleanup_ctx_exception_on_startup() -> None:
    """A cleanup_ctx that fails before its yield aborts startup().

    Contexts registered after the failing one are never entered. cleanup()
    then tears down only the contexts that did reach their yield.
    """
    app = web.Application()
    out = []
    exc = Exception("fail")

    def make_ctx(num, fail=False):
        @async_generator
        async def cleanup(app):
            out.append("pre_{}".format(num))
            if fail:
                raise exc
            await yield_(None)
            out.append("post_{}".format(num))

        return cleanup

    for num, fail in ((1, False), (2, True), (3, False)):
        app.cleanup_ctx.append(make_ctx(num, fail))
    app.freeze()
    with pytest.raises(Exception) as exc_info:
        await app.startup()
    assert exc_info.value is exc
    assert out == ["pre_1", "pre_2"]
    await app.cleanup()
    assert out == ["pre_1", "pre_2", "post_1"]
353
354
async def test_cleanup_ctx_exception_on_cleanup() -> None:
    """A single failing teardown is re-raised unchanged by cleanup().

    Teardown continues in reverse order past the failure, so every context
    still records its ``post_`` entry.
    """
    app = web.Application()
    out = []
    exc = Exception("fail")

    def make_ctx(num, fail=False):
        @async_generator
        async def cleanup(app):
            out.append("pre_{}".format(num))
            await yield_(None)
            out.append("post_{}".format(num))
            if fail:
                raise exc

        return cleanup

    for num, fail in ((1, False), (2, True), (3, False)):
        app.cleanup_ctx.append(make_ctx(num, fail))
    app.freeze()
    await app.startup()
    assert out == ["pre_1", "pre_2", "pre_3"]
    with pytest.raises(Exception) as exc_info:
        await app.cleanup()
    assert exc_info.value is exc
    assert out == ["pre_1", "pre_2", "pre_3", "post_3", "post_2", "post_1"]
382
383
async def test_cleanup_ctx_exception_on_cleanup_multiple() -> None:
    """Several failing teardowns are collected into one web.CleanupError.

    The collected exceptions appear in teardown order (the reverse of
    registration), and every context is still torn down.
    """
    app = web.Application()
    out = []

    def make_ctx(num, fail=False):
        @async_generator
        async def cleanup(app):
            out.append("pre_{}".format(num))
            await yield_(None)
            out.append("post_{}".format(num))
            if fail:
                raise Exception("fail_{}".format(num))

        return cleanup

    for num, fail in ((1, False), (2, True), (3, True)):
        app.cleanup_ctx.append(make_ctx(num, fail))
    app.freeze()
    await app.startup()
    assert out == ["pre_1", "pre_2", "pre_3"]
    with pytest.raises(web.CleanupError) as exc_info:
        await app.cleanup()
    collected = exc_info.value.exceptions
    assert [str(err) for err in collected] == ["fail_3", "fail_2"]
    assert out == ["pre_1", "pre_2", "pre_3", "post_3", "post_2", "post_1"]
412
413
async def test_cleanup_ctx_multiple_yields() -> None:
    """A cleanup_ctx generator that yields twice is rejected during cleanup()."""
    app = web.Application()
    out = []

    def make_ctx(num):
        @async_generator
        async def cleanup(app):
            out.append("pre_{}".format(num))
            await yield_(None)
            out.append("post_{}".format(num))
            await yield_(None)

        return cleanup

    app.cleanup_ctx.append(make_ctx(1))
    app.freeze()
    await app.startup()
    assert out == ["pre_1"]
    with pytest.raises(RuntimeError) as exc_info:
        await app.cleanup()
    assert "has more than one 'yield'" in str(exc_info.value)
    assert out == ["pre_1", "post_1"]
436
437
async def test_subapp_chained_config_dict_visibility(aiohttp_client) -> None:
    """A sub-app request sees parent keys; a root request does not see sub keys."""

    async def main_handler(request):
        assert request.config_dict["key1"] == "val1"
        assert "key2" not in request.config_dict
        return web.Response(status=200)

    async def sub_handler(request):
        assert request.config_dict["key1"] == "val1"
        assert request.config_dict["key2"] == "val2"
        return web.Response(status=201)

    root = web.Application()
    root["key1"] = "val1"
    root.add_routes([web.get("/", main_handler)])

    sub = web.Application()
    sub["key2"] = "val2"
    sub.add_routes([web.get("/", sub_handler)])
    root.add_subapp("/sub", sub)

    client = await aiohttp_client(root)

    # Each handler signals success through its own distinct status code.
    for path, expected_status in (("/", 200), ("/sub/", 201)):
        resp = await client.get(path)
        assert resp.status == expected_status
464
465
async def test_subapp_chained_config_dict_overriding(aiohttp_client) -> None:
    """A sub-app key shadows the parent's value only for sub-app requests."""

    async def main_handler(request):
        assert request.config_dict["key"] == "val1"
        return web.Response(status=200)

    async def sub_handler(request):
        assert request.config_dict["key"] == "val2"
        return web.Response(status=201)

    root = web.Application()
    root["key"] = "val1"
    root.add_routes([web.get("/", main_handler)])

    sub = web.Application()
    sub["key"] = "val2"
    sub.add_routes([web.get("/", sub_handler)])
    root.add_subapp("/sub", sub)

    client = await aiohttp_client(root)

    # Each handler signals success through its own distinct status code.
    for path, expected_status in (("/", 200), ("/sub/", 201)):
        resp = await client.get(path)
        assert resp.status == expected_status
490
491
async def test_subapp_on_startup(aiohttp_client) -> None:
    """Sub-app hooks are frozen on mount and run with the parent's lifecycle.

    Mounting freezes the sub-app's signals and router without running any
    hook. Creating the test client runs on_startup and the setup half of
    cleanup_ctx. Closing the client runs the remaining hooks.
    """
    subapp = web.Application()
    called = dict.fromkeys(
        ("startup", "ctx_pre", "ctx_post", "shutdown", "cleanup"), False
    )

    async def on_startup(app):
        called["startup"] = True
        app["startup"] = True

    @async_generator
    async def cleanup_ctx(app):
        called["ctx_pre"] = True
        app["cleanup"] = True
        await yield_(None)
        called["ctx_post"] = True

    async def on_shutdown(app):
        called["shutdown"] = True

    async def on_cleanup(app):
        called["cleanup"] = True

    subapp.on_startup.append(on_startup)
    subapp.cleanup_ctx.append(cleanup_ctx)
    subapp.on_shutdown.append(on_shutdown)
    subapp.on_cleanup.append(on_cleanup)

    app = web.Application()
    app.add_subapp("/subapp", subapp)

    assert not any(called.values())

    hook_lists = (
        subapp.on_startup,
        subapp.cleanup_ctx,
        subapp.on_shutdown,
        subapp.on_cleanup,
    )
    for hooks in hook_lists:
        assert hooks.frozen
    assert subapp.router.frozen

    client = await aiohttp_client(app)

    assert called == {
        "startup": True,
        "ctx_pre": True,
        "ctx_post": False,
        "shutdown": False,
        "cleanup": False,
    }

    await client.close()

    assert all(called.values())
565
566
def test_app_iter() -> None:
    """Iterating an application yields the keys stored on it."""
    app = web.Application()
    app["a"] = "1"
    app["b"] = "2"
    # sorted() accepts any iterable; the intermediate list() was redundant.
    assert sorted(app) == ["a", "b"]
572
573
def test_app_boolean() -> None:
    """A freshly constructed application (holding no items) is truthy."""
    application = web.Application()
    assert bool(application)
577