1package httptransport
2
3import (
4	"io"
5	"net/http"
6
7	"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
8)
9
// ByteCountingTransport is a RoundTripper that counts bytes.
//
// It delegates the actual exchange to the embedded RoundTripper and reports
// to Counter an estimate of the request and response metadata sizes plus the
// number of bytes read from the request and response bodies.
type ByteCountingTransport struct {
	// RoundTripper is the underlying transport that performs the round trip.
	RoundTripper
	// Counter receives the sent and received byte counts. It is used without
	// a nil check, so it must be set before RoundTrip is called.
	Counter *bytecounter.Counter
}
15
16// RoundTrip implements RoundTripper.RoundTrip
17func (txp ByteCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
18	if req.Body != nil {
19		req.Body = byteCountingBody{
20			ReadCloser: req.Body, Account: txp.Counter.CountBytesSent}
21	}
22	txp.estimateRequestMetadata(req)
23	resp, err := txp.RoundTripper.RoundTrip(req)
24	if err != nil {
25		return nil, err
26	}
27	txp.estimateResponseMetadata(resp)
28	resp.Body = byteCountingBody{
29		ReadCloser: resp.Body, Account: txp.Counter.CountBytesReceived}
30	return resp, nil
31}
32
33func (txp ByteCountingTransport) estimateRequestMetadata(req *http.Request) {
34	txp.Counter.CountBytesSent(len(req.Method))
35	txp.Counter.CountBytesSent(len(req.URL.String()))
36	for key, values := range req.Header {
37		for _, value := range values {
38			txp.Counter.CountBytesSent(len(key))
39			txp.Counter.CountBytesSent(len(": "))
40			txp.Counter.CountBytesSent(len(value))
41			txp.Counter.CountBytesSent(len("\r\n"))
42		}
43	}
44	txp.Counter.CountBytesSent(len("\r\n"))
45}
46
47func (txp ByteCountingTransport) estimateResponseMetadata(resp *http.Response) {
48	txp.Counter.CountBytesReceived(len(resp.Status))
49	for key, values := range resp.Header {
50		for _, value := range values {
51			txp.Counter.CountBytesReceived(len(key))
52			txp.Counter.CountBytesReceived(len(": "))
53			txp.Counter.CountBytesReceived(len(value))
54			txp.Counter.CountBytesReceived(len("\r\n"))
55		}
56	}
57	txp.Counter.CountBytesReceived(len("\r\n"))
58}
59
// byteCountingBody wraps a body and reports the size of each read that
// returned data to the Account callback.
type byteCountingBody struct {
	// ReadCloser is the wrapped body; its Close method is promoted as-is.
	io.ReadCloser
	// Account is called with the byte count of every Read that returns at
	// least one byte.
	Account func(int)
}

// Read reads from the wrapped body into p. When at least one byte was read,
// the byte count is passed to Account before returning. The count and error
// produced by the wrapped Read are returned unchanged.
func (b byteCountingBody) Read(p []byte) (int, error) {
	n, err := b.ReadCloser.Read(p)
	if n <= 0 {
		// Nothing was read, so there is nothing to account for.
		return n, err
	}
	b.Account(n)
	return n, err
}
72
// Compile-time check that ByteCountingTransport (as a value, not only as a
// pointer) satisfies the RoundTripper interface.
var _ RoundTripper = ByteCountingTransport{}
74