1package resolver_test
2
3import (
4	"bytes"
5	"context"
6	"errors"
7	"io"
8	"net/http"
9	"testing"
10	"time"
11
12	"github.com/miekg/dns"
13	"github.com/ooni/probe-engine/legacy/netx/dialid"
14	"github.com/ooni/probe-engine/legacy/netx/handlers"
15	"github.com/ooni/probe-engine/legacy/netx/modelx"
16	"github.com/ooni/probe-engine/legacy/netx/transactionid"
17	"github.com/ooni/probe-engine/netx/resolver"
18)
19
20func TestEmitterTransportSuccess(t *testing.T) {
21	ctx := context.Background()
22	ctx = dialid.WithDialID(ctx)
23	handler := &handlers.SavingHandler{}
24	root := &modelx.MeasurementRoot{
25		Beginning: time.Now(),
26		Handler:   handler,
27	}
28	ctx = modelx.WithMeasurementRoot(ctx, root)
29	txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
30		Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
31	}}
32	e := resolver.MiekgEncoder{}
33	querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
34	if err != nil {
35		t.Fatal(err)
36	}
37	replydata, err := txp.RoundTrip(ctx, querydata)
38	if err != nil {
39		t.Fatal(err)
40	}
41	events := handler.Read()
42	if len(events) != 2 {
43		t.Fatal("unexpected number of events")
44	}
45	if events[0].DNSQuery == nil {
46		t.Fatal("missing DNSQuery field")
47	}
48	if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
49		t.Fatal("invalid query data")
50	}
51	if events[0].DNSQuery.DialID == 0 {
52		t.Fatal("invalid query DialID")
53	}
54	if events[0].DNSQuery.DurationSinceBeginning <= 0 {
55		t.Fatal("invalid duration since beginning")
56	}
57	if events[1].DNSReply == nil {
58		t.Fatal("missing DNSReply field")
59	}
60	if !bytes.Equal(events[1].DNSReply.Data, replydata) {
61		t.Fatal("missing reply data")
62	}
63	if events[1].DNSReply.DialID != 1 {
64		t.Fatal("invalid query DialID")
65	}
66	if events[1].DNSReply.DurationSinceBeginning <= 0 {
67		t.Fatal("invalid duration since beginning")
68	}
69}
70
71func TestEmitterTransportFailure(t *testing.T) {
72	ctx := context.Background()
73	ctx = dialid.WithDialID(ctx)
74	handler := &handlers.SavingHandler{}
75	root := &modelx.MeasurementRoot{
76		Beginning: time.Now(),
77		Handler:   handler,
78	}
79	ctx = modelx.WithMeasurementRoot(ctx, root)
80	mocked := errors.New("mocked error")
81	txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
82		Err: mocked,
83	}}
84	e := resolver.MiekgEncoder{}
85	querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
86	if err != nil {
87		t.Fatal(err)
88	}
89	replydata, err := txp.RoundTrip(ctx, querydata)
90	if !errors.Is(err, mocked) {
91		t.Fatal("not the error we expected")
92	}
93	if replydata != nil {
94		t.Fatal("expected nil replydata")
95	}
96	events := handler.Read()
97	if len(events) != 1 {
98		t.Fatal("unexpected number of events")
99	}
100	if events[0].DNSQuery == nil {
101		t.Fatal("missing DNSQuery field")
102	}
103	if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
104		t.Fatal("invalid query data")
105	}
106	if events[0].DNSQuery.DialID == 0 {
107		t.Fatal("invalid query DialID")
108	}
109	if events[0].DNSQuery.DurationSinceBeginning <= 0 {
110		t.Fatal("invalid duration since beginning")
111	}
112}
113
114func TestEmitterResolverFailure(t *testing.T) {
115	ctx := context.Background()
116	ctx = dialid.WithDialID(ctx)
117	ctx = transactionid.WithTransactionID(ctx)
118	handler := &handlers.SavingHandler{}
119	root := &modelx.MeasurementRoot{
120		Beginning: time.Now(),
121		Handler:   handler,
122	}
123	ctx = modelx.WithMeasurementRoot(ctx, root)
124	r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
125		resolver.DNSOverHTTPS{
126			Do: func(req *http.Request) (*http.Response, error) {
127				return nil, io.EOF
128			},
129			URL: "https://dns.google.com/",
130		},
131	)}
132	replies, err := r.LookupHost(ctx, "www.google.com")
133	if !errors.Is(err, io.EOF) {
134		t.Fatal("not the error we expected")
135	}
136	if replies != nil {
137		t.Fatal("expected nil replies")
138	}
139	events := handler.Read()
140	if len(events) != 2 {
141		t.Fatal("unexpected number of events")
142	}
143	if events[0].ResolveStart == nil {
144		t.Fatal("missing ResolveStart field")
145	}
146	if events[0].ResolveStart.DialID == 0 {
147		t.Fatal("invalid DialID")
148	}
149	if events[0].ResolveStart.DurationSinceBeginning <= 0 {
150		t.Fatal("invalid duration since beginning")
151	}
152	if events[0].ResolveStart.Hostname != "www.google.com" {
153		t.Fatal("invalid Hostname")
154	}
155	if events[0].ResolveStart.TransactionID == 0 {
156		t.Fatal("invalid TransactionID")
157	}
158	if events[0].ResolveStart.TransportAddress != "https://dns.google.com/" {
159		t.Fatal("invalid TransportAddress")
160	}
161	if events[0].ResolveStart.TransportNetwork != "doh" {
162		t.Fatal("invalid TransportNetwork")
163	}
164	if events[1].ResolveDone == nil {
165		t.Fatal("missing ResolveDone field")
166	}
167	if events[1].ResolveDone.DialID == 0 {
168		t.Fatal("invalid DialID")
169	}
170	if events[1].ResolveDone.DurationSinceBeginning <= 0 {
171		t.Fatal("invalid duration since beginning")
172	}
173	if events[1].ResolveDone.Error != io.EOF {
174		t.Fatal("invalid Error")
175	}
176	if events[1].ResolveDone.Hostname != "www.google.com" {
177		t.Fatal("invalid Hostname")
178	}
179	if events[1].ResolveDone.TransactionID == 0 {
180		t.Fatal("invalid TransactionID")
181	}
182	if events[1].ResolveDone.TransportAddress != "https://dns.google.com/" {
183		t.Fatal("invalid TransportAddress")
184	}
185	if events[1].ResolveDone.TransportNetwork != "doh" {
186		t.Fatal("invalid TransportNetwork")
187	}
188}
189
190func TestEmitterResolverSuccess(t *testing.T) {
191	ctx := context.Background()
192	ctx = dialid.WithDialID(ctx)
193	ctx = transactionid.WithTransactionID(ctx)
194	handler := &handlers.SavingHandler{}
195	root := &modelx.MeasurementRoot{
196		Beginning: time.Now(),
197		Handler:   handler,
198	}
199	ctx = modelx.WithMeasurementRoot(ctx, root)
200	r := resolver.EmitterResolver{Resolver: resolver.NewFakeResolverWithResult(
201		[]string{"8.8.8.8"},
202	)}
203	replies, err := r.LookupHost(ctx, "dns.google.com")
204	if err != nil {
205		t.Fatal(err)
206	}
207	if len(replies) != 1 {
208		t.Fatal("expected a single replies")
209	}
210	events := handler.Read()
211	if len(events) != 2 {
212		t.Fatal("unexpected number of events")
213	}
214	if events[1].ResolveDone == nil {
215		t.Fatal("missing ResolveDone field")
216	}
217	if events[1].ResolveDone.Addresses[0] != "8.8.8.8" {
218		t.Fatal("invalid Addresses")
219	}
220}
221