1from types import SimpleNamespace
2from typing import TYPE_CHECKING, Awaitable, Optional, Type, TypeVar
3
4import attr
5from multidict import CIMultiDict
6from yarl import URL
7
8from .client_reqrep import ClientResponse
9from .signals import Signal
10
if TYPE_CHECKING:  # pragma: no cover
    # Everything in this branch is seen only by static type checkers. At
    # runtime TYPE_CHECKING is False, so these names do not exist and the
    # rest of the module refers to them only via string annotations or
    # type comments.
    from typing_extensions import Protocol

    from .client import ClientSession

    # Contravariant: a callback accepting a wider params type can be used
    # wherever a callback for a narrower params type is expected.
    _ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)

    class _SignalCallback(Protocol[_ParamT_contra]):
        """Call signature required of trace signal callbacks.

        A callback receives the ClientSession, the trace_config_ctx
        namespace and an event-specific params object, and returns an
        awaitable.
        """

        def __call__(
            self,
            # Leading double underscores mark these parameters as
            # positional-only for type checkers (PEP 484 convention), so
            # callbacks are free to choose their own parameter names.
            __client_session: ClientSession,
            __trace_config_ctx: SimpleNamespace,
            __params: _ParamT_contra,
        ) -> Awaitable[None]:
            ...
26
27
# Public names of this module. ``Trace`` is not listed: its own docstring
# describes it as an internal class.
__all__ = (
    "TraceConfig",
    "TraceRequestStartParams",
    "TraceRequestEndParams",
    "TraceRequestExceptionParams",
    "TraceConnectionQueuedStartParams",
    "TraceConnectionQueuedEndParams",
    "TraceConnectionCreateStartParams",
    "TraceConnectionCreateEndParams",
    "TraceConnectionReuseconnParams",
    "TraceDnsResolveHostStartParams",
    "TraceDnsResolveHostEndParams",
    "TraceDnsCacheHitParams",
    "TraceDnsCacheMissParams",
    "TraceRequestRedirectParams",
    "TraceRequestChunkSentParams",
    "TraceResponseChunkReceivedParams",
)
46
47
class TraceConfig:
    """Collection of signals used to trace requests made by ClientSession.

    Each ``on_*`` property exposes a Signal whose callbacks must match
    ``_SignalCallback``: they take ``(session, trace_config_ctx, params)``
    and return an awaitable. ``freeze()`` freezes every signal at once.
    """

    # Signal types are declared once here (as strings, so nothing is
    # resolved at runtime) to keep __init__ free of long type comments.
    _on_request_start: "Signal[_SignalCallback[TraceRequestStartParams]]"
    _on_request_chunk_sent: "Signal[_SignalCallback[TraceRequestChunkSentParams]]"
    _on_response_chunk_received: (
        "Signal[_SignalCallback[TraceResponseChunkReceivedParams]]"
    )
    _on_request_end: "Signal[_SignalCallback[TraceRequestEndParams]]"
    _on_request_exception: "Signal[_SignalCallback[TraceRequestExceptionParams]]"
    _on_request_redirect: "Signal[_SignalCallback[TraceRequestRedirectParams]]"
    _on_connection_queued_start: (
        "Signal[_SignalCallback[TraceConnectionQueuedStartParams]]"
    )
    _on_connection_queued_end: (
        "Signal[_SignalCallback[TraceConnectionQueuedEndParams]]"
    )
    _on_connection_create_start: (
        "Signal[_SignalCallback[TraceConnectionCreateStartParams]]"
    )
    _on_connection_create_end: (
        "Signal[_SignalCallback[TraceConnectionCreateEndParams]]"
    )
    _on_connection_reuseconn: (
        "Signal[_SignalCallback[TraceConnectionReuseconnParams]]"
    )
    _on_dns_resolvehost_start: (
        "Signal[_SignalCallback[TraceDnsResolveHostStartParams]]"
    )
    _on_dns_resolvehost_end: "Signal[_SignalCallback[TraceDnsResolveHostEndParams]]"
    _on_dns_cache_hit: "Signal[_SignalCallback[TraceDnsCacheHitParams]]"
    _on_dns_cache_miss: "Signal[_SignalCallback[TraceDnsCacheMissParams]]"

    def __init__(
        self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
    ) -> None:
        # One Signal per traced event; each is exposed by an on_* property.
        self._on_request_start = Signal(self)
        self._on_request_chunk_sent = Signal(self)
        self._on_response_chunk_received = Signal(self)
        self._on_request_end = Signal(self)
        self._on_request_exception = Signal(self)
        self._on_request_redirect = Signal(self)
        self._on_connection_queued_start = Signal(self)
        self._on_connection_queued_end = Signal(self)
        self._on_connection_create_start = Signal(self)
        self._on_connection_create_end = Signal(self)
        self._on_connection_reuseconn = Signal(self)
        self._on_dns_resolvehost_start = Signal(self)
        self._on_dns_resolvehost_end = Signal(self)
        self._on_dns_cache_hit = Signal(self)
        self._on_dns_cache_miss = Signal(self)

        self._trace_config_ctx_factory = trace_config_ctx_factory

    def trace_config_ctx(
        self, trace_request_ctx: Optional[SimpleNamespace] = None
    ) -> SimpleNamespace:
        """Build and return a fresh trace_config_ctx object.

        The configured factory is called with ``trace_request_ctx`` passed
        as a keyword argument.
        """
        factory = self._trace_config_ctx_factory
        return factory(trace_request_ctx=trace_request_ctx)

    def freeze(self) -> None:
        """Call ``freeze()`` on every signal, in declaration order."""
        for name in (
            "_on_request_start",
            "_on_request_chunk_sent",
            "_on_response_chunk_received",
            "_on_request_end",
            "_on_request_exception",
            "_on_request_redirect",
            "_on_connection_queued_start",
            "_on_connection_queued_end",
            "_on_connection_create_start",
            "_on_connection_create_end",
            "_on_connection_reuseconn",
            "_on_dns_resolvehost_start",
            "_on_dns_resolvehost_end",
            "_on_dns_cache_hit",
            "_on_dns_cache_miss",
        ):
            getattr(self, name).freeze()

    @property
    def on_request_start(self) -> "Signal[_SignalCallback[TraceRequestStartParams]]":
        """Signal whose callbacks receive :class:`TraceRequestStartParams`."""
        return self._on_request_start

    @property
    def on_request_chunk_sent(
        self,
    ) -> "Signal[_SignalCallback[TraceRequestChunkSentParams]]":
        """Signal whose callbacks receive :class:`TraceRequestChunkSentParams`."""
        return self._on_request_chunk_sent

    @property
    def on_response_chunk_received(
        self,
    ) -> "Signal[_SignalCallback[TraceResponseChunkReceivedParams]]":
        """Signal whose callbacks receive ``TraceResponseChunkReceivedParams``."""
        return self._on_response_chunk_received

    @property
    def on_request_end(self) -> "Signal[_SignalCallback[TraceRequestEndParams]]":
        """Signal whose callbacks receive :class:`TraceRequestEndParams`."""
        return self._on_request_end

    @property
    def on_request_exception(
        self,
    ) -> "Signal[_SignalCallback[TraceRequestExceptionParams]]":
        """Signal whose callbacks receive :class:`TraceRequestExceptionParams`."""
        return self._on_request_exception

    @property
    def on_request_redirect(
        self,
    ) -> "Signal[_SignalCallback[TraceRequestRedirectParams]]":
        """Signal whose callbacks receive :class:`TraceRequestRedirectParams`."""
        return self._on_request_redirect

    @property
    def on_connection_queued_start(
        self,
    ) -> "Signal[_SignalCallback[TraceConnectionQueuedStartParams]]":
        """Signal whose callbacks receive ``TraceConnectionQueuedStartParams``."""
        return self._on_connection_queued_start

    @property
    def on_connection_queued_end(
        self,
    ) -> "Signal[_SignalCallback[TraceConnectionQueuedEndParams]]":
        """Signal whose callbacks receive ``TraceConnectionQueuedEndParams``."""
        return self._on_connection_queued_end

    @property
    def on_connection_create_start(
        self,
    ) -> "Signal[_SignalCallback[TraceConnectionCreateStartParams]]":
        """Signal whose callbacks receive ``TraceConnectionCreateStartParams``."""
        return self._on_connection_create_start

    @property
    def on_connection_create_end(
        self,
    ) -> "Signal[_SignalCallback[TraceConnectionCreateEndParams]]":
        """Signal whose callbacks receive ``TraceConnectionCreateEndParams``."""
        return self._on_connection_create_end

    @property
    def on_connection_reuseconn(
        self,
    ) -> "Signal[_SignalCallback[TraceConnectionReuseconnParams]]":
        """Signal whose callbacks receive ``TraceConnectionReuseconnParams``."""
        return self._on_connection_reuseconn

    @property
    def on_dns_resolvehost_start(
        self,
    ) -> "Signal[_SignalCallback[TraceDnsResolveHostStartParams]]":
        """Signal whose callbacks receive ``TraceDnsResolveHostStartParams``."""
        return self._on_dns_resolvehost_start

    @property
    def on_dns_resolvehost_end(
        self,
    ) -> "Signal[_SignalCallback[TraceDnsResolveHostEndParams]]":
        """Signal whose callbacks receive ``TraceDnsResolveHostEndParams``."""
        return self._on_dns_resolvehost_end

    @property
    def on_dns_cache_hit(self) -> "Signal[_SignalCallback[TraceDnsCacheHitParams]]":
        """Signal whose callbacks receive :class:`TraceDnsCacheHitParams`."""
        return self._on_dns_cache_hit

    @property
    def on_dns_cache_miss(self) -> "Signal[_SignalCallback[TraceDnsCacheMissParams]]":
        """Signal whose callbacks receive :class:`TraceDnsCacheMissParams`."""
        return self._on_dns_cache_miss
207
208
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceRequestStartParams:
    """Payload delivered with the ``on_request_start`` signal."""

    # Field order is the positional __init__ order (Trace builds these
    # positionally).
    method: str
    url: URL
    headers: "CIMultiDict[str]"
216
217
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceRequestChunkSentParams:
    """Payload delivered with the ``on_request_chunk_sent`` signal."""

    # Field order is the positional __init__ order (Trace builds these
    # positionally).
    method: str
    url: URL
    chunk: bytes
225
226
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceResponseChunkReceivedParams:
    """Payload delivered with the ``on_response_chunk_received`` signal."""

    # Field order is the positional __init__ order (Trace builds these
    # positionally).
    method: str
    url: URL
    chunk: bytes
234
235
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceRequestEndParams:
    """Payload delivered with the ``on_request_end`` signal."""

    # Field order is the positional __init__ order (Trace builds these
    # positionally).
    method: str
    url: URL
    headers: "CIMultiDict[str]"
    response: ClientResponse
244
245
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceRequestExceptionParams:
    """Payload delivered with the ``on_request_exception`` signal."""

    # Field order is the positional __init__ order (Trace builds these
    # positionally).
    method: str
    url: URL
    headers: "CIMultiDict[str]"
    exception: BaseException
254
255
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceRequestRedirectParams:
    """Payload delivered with the ``on_request_redirect`` signal."""

    # Field order is the positional __init__ order (Trace builds these
    # positionally).
    method: str
    url: URL
    headers: "CIMultiDict[str]"
    response: ClientResponse
264
265
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceConnectionQueuedStartParams:
    """Payload delivered with the ``on_connection_queued_start`` signal.

    This event carries no data, so the class declares no fields.
    """
269
270
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceConnectionQueuedEndParams:
    """Payload delivered with the ``on_connection_queued_end`` signal.

    This event carries no data, so the class declares no fields.
    """
274
275
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceConnectionCreateStartParams:
    """Payload delivered with the ``on_connection_create_start`` signal.

    This event carries no data, so the class declares no fields.
    """
279
280
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceConnectionCreateEndParams:
    """Payload delivered with the ``on_connection_create_end`` signal.

    This event carries no data, so the class declares no fields.
    """
284
285
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceConnectionReuseconnParams:
    """Payload delivered with the ``on_connection_reuseconn`` signal.

    This event carries no data, so the class declares no fields.
    """
289
290
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceDnsResolveHostStartParams:
    """Payload delivered with the ``on_dns_resolvehost_start`` signal."""

    host: str
296
297
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceDnsResolveHostEndParams:
    """Payload delivered with the ``on_dns_resolvehost_end`` signal."""

    host: str
303
304
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceDnsCacheHitParams:
    """Payload delivered with the ``on_dns_cache_hit`` signal."""

    host: str
310
311
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TraceDnsCacheMissParams:
    """Payload delivered with the ``on_dns_cache_miss`` signal."""

    host: str
317
318
class Trace:
    """Internal helper that bundles everything needed to send trace signals.

    It pairs a ClientSession with a TraceConfig and the trace_config_ctx
    object for that config, so each ``send_*`` method only takes the
    event-specific arguments. Every ``send_*`` method wraps its arguments in
    the matching ``Trace*Params`` object and sends it, together with the
    session and trace_config_ctx, through the TraceConfig's public ``on_*``
    signal property.
    """

    def __init__(
        self,
        session: "ClientSession",
        trace_config: TraceConfig,
        trace_config_ctx: SimpleNamespace,
    ) -> None:
        self._trace_config = trace_config
        self._trace_config_ctx = trace_config_ctx
        self._session = session

    async def send_request_start(
        self, method: str, url: URL, headers: "CIMultiDict[str]"
    ) -> None:
        """Send ``on_request_start`` with a TraceRequestStartParams."""
        await self._trace_config.on_request_start.send(
            self._session,
            self._trace_config_ctx,
            TraceRequestStartParams(method, url, headers),
        )

    async def send_request_chunk_sent(
        self, method: str, url: URL, chunk: bytes
    ) -> None:
        """Send ``on_request_chunk_sent`` with a TraceRequestChunkSentParams."""
        await self._trace_config.on_request_chunk_sent.send(
            self._session,
            self._trace_config_ctx,
            TraceRequestChunkSentParams(method, url, chunk),
        )

    async def send_response_chunk_received(
        self, method: str, url: URL, chunk: bytes
    ) -> None:
        """Send ``on_response_chunk_received`` with its params object."""
        await self._trace_config.on_response_chunk_received.send(
            self._session,
            self._trace_config_ctx,
            TraceResponseChunkReceivedParams(method, url, chunk),
        )

    async def send_request_end(
        self,
        method: str,
        url: URL,
        headers: "CIMultiDict[str]",
        response: ClientResponse,
    ) -> None:
        """Send ``on_request_end`` with a TraceRequestEndParams."""
        await self._trace_config.on_request_end.send(
            self._session,
            self._trace_config_ctx,
            TraceRequestEndParams(method, url, headers, response),
        )

    async def send_request_exception(
        self,
        method: str,
        url: URL,
        headers: "CIMultiDict[str]",
        exception: BaseException,
    ) -> None:
        """Send ``on_request_exception`` with a TraceRequestExceptionParams."""
        await self._trace_config.on_request_exception.send(
            self._session,
            self._trace_config_ctx,
            TraceRequestExceptionParams(method, url, headers, exception),
        )

    async def send_request_redirect(
        self,
        method: str,
        url: URL,
        headers: "CIMultiDict[str]",
        response: ClientResponse,
    ) -> None:
        """Send ``on_request_redirect`` with a TraceRequestRedirectParams."""
        # Use the public property like every other sender rather than the
        # TraceConfig's private ``_on_request_redirect`` attribute.
        await self._trace_config.on_request_redirect.send(
            self._session,
            self._trace_config_ctx,
            TraceRequestRedirectParams(method, url, headers, response),
        )

    async def send_connection_queued_start(self) -> None:
        """Send ``on_connection_queued_start`` (the params object is empty)."""
        await self._trace_config.on_connection_queued_start.send(
            self._session, self._trace_config_ctx, TraceConnectionQueuedStartParams()
        )

    async def send_connection_queued_end(self) -> None:
        """Send ``on_connection_queued_end`` (the params object is empty)."""
        await self._trace_config.on_connection_queued_end.send(
            self._session, self._trace_config_ctx, TraceConnectionQueuedEndParams()
        )

    async def send_connection_create_start(self) -> None:
        """Send ``on_connection_create_start`` (the params object is empty)."""
        await self._trace_config.on_connection_create_start.send(
            self._session, self._trace_config_ctx, TraceConnectionCreateStartParams()
        )

    async def send_connection_create_end(self) -> None:
        """Send ``on_connection_create_end`` (the params object is empty)."""
        await self._trace_config.on_connection_create_end.send(
            self._session, self._trace_config_ctx, TraceConnectionCreateEndParams()
        )

    async def send_connection_reuseconn(self) -> None:
        """Send ``on_connection_reuseconn`` (the params object is empty)."""
        await self._trace_config.on_connection_reuseconn.send(
            self._session, self._trace_config_ctx, TraceConnectionReuseconnParams()
        )

    async def send_dns_resolvehost_start(self, host: str) -> None:
        """Send ``on_dns_resolvehost_start`` for ``host``."""
        await self._trace_config.on_dns_resolvehost_start.send(
            self._session, self._trace_config_ctx, TraceDnsResolveHostStartParams(host)
        )

    async def send_dns_resolvehost_end(self, host: str) -> None:
        """Send ``on_dns_resolvehost_end`` for ``host``."""
        await self._trace_config.on_dns_resolvehost_end.send(
            self._session, self._trace_config_ctx, TraceDnsResolveHostEndParams(host)
        )

    async def send_dns_cache_hit(self, host: str) -> None:
        """Send ``on_dns_cache_hit`` for ``host``."""
        await self._trace_config.on_dns_cache_hit.send(
            self._session, self._trace_config_ctx, TraceDnsCacheHitParams(host)
        )

    async def send_dns_cache_miss(self, host: str) -> None:
        """Send ``on_dns_cache_miss`` for ``host``."""
        await self._trace_config.on_dns_cache_miss.send(
            self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host)
        )
443