1package dbus
2
3import (
4	"context"
5	"encoding/binary"
6	"io"
7	"io/ioutil"
8	"log"
9	"testing"
10	"time"
11)
12
13func TestSessionBus(t *testing.T) {
14	_, err := SessionBus()
15	if err != nil {
16		t.Error(err)
17	}
18}
19
20func TestSystemBus(t *testing.T) {
21	_, err := SystemBus()
22	if err != nil {
23		t.Error(err)
24	}
25}
26
27func ExampleSystemBusPrivate() {
28	setupPrivateSystemBus := func() (conn *Conn, err error) {
29		conn, err = SystemBusPrivate()
30		if err != nil {
31			return nil, err
32		}
33		if err = conn.Auth(nil); err != nil {
34			conn.Close()
35			conn = nil
36			return
37		}
38		if err = conn.Hello(); err != nil {
39			conn.Close()
40			conn = nil
41		}
42		return conn, nil // success
43	}
44	_, _ = setupPrivateSystemBus()
45}
46
47func TestSend(t *testing.T) {
48	bus, err := SessionBus()
49	if err != nil {
50		t.Fatal(err)
51	}
52	ch := make(chan *Call, 1)
53	msg := &Message{
54		Type:  TypeMethodCall,
55		Flags: 0,
56		Headers: map[HeaderField]Variant{
57			FieldDestination: MakeVariant(bus.Names()[0]),
58			FieldPath:        MakeVariant(ObjectPath("/org/freedesktop/DBus")),
59			FieldInterface:   MakeVariant("org.freedesktop.DBus.Peer"),
60			FieldMember:      MakeVariant("Ping"),
61		},
62	}
63	call := bus.Send(msg, ch)
64	<-ch
65	if call.Err != nil {
66		t.Error(call.Err)
67	}
68}
69
70func TestFlagNoReplyExpectedSend(t *testing.T) {
71	bus, err := SessionBus()
72	if err != nil {
73		t.Fatal(err)
74	}
75	done := make(chan struct{})
76	go func() {
77		bus.BusObject().Call("org.freedesktop.DBus.ListNames", FlagNoReplyExpected)
78		close(done)
79	}()
80	select {
81	case <-done:
82	case <-time.After(1 * time.Second):
83		t.Error("Failed to announce that the call was done")
84	}
85}
86
87func TestRemoveSignal(t *testing.T) {
88	bus, err := NewConn(nil)
89	if err != nil {
90		t.Error(err)
91	}
92	signals := bus.signalHandler.(*defaultSignalHandler).signals
93	ch := make(chan *Signal)
94	ch2 := make(chan *Signal)
95	for _, ch := range []chan *Signal{ch, ch2, ch, ch2, ch2, ch} {
96		bus.Signal(ch)
97	}
98	signals = bus.signalHandler.(*defaultSignalHandler).signals
99	if len(signals) != 6 {
100		t.Errorf("remove signal: signals length not equal: got '%d', want '6'", len(signals))
101	}
102	bus.RemoveSignal(ch)
103	signals = bus.signalHandler.(*defaultSignalHandler).signals
104	if len(signals) != 3 {
105		t.Errorf("remove signal: signals length not equal: got '%d', want '3'", len(signals))
106	}
107	signals = bus.signalHandler.(*defaultSignalHandler).signals
108	for _, scd := range signals {
109		if scd.ch != ch2 {
110			t.Errorf("remove signal: removed signal present: got '%v', want '%v'", scd.ch, ch2)
111		}
112	}
113}
114
// rwc combines a separate io.Reader and io.Writer into one value and adds a
// Close method, so the tests can hand a pipe reader plus an arbitrary writer
// to NewConn as a single transport.
type rwc struct {
	io.Reader
	io.Writer
}

// Close does nothing and always reports success; the embedded reader and
// writer are left untouched.
func (rwc) Close() error {
	return nil
}
121
122type fakeAuth struct {
123}
124
125func (fakeAuth) FirstData() (name, resp []byte, status AuthStatus) {
126	return []byte("name"), []byte("resp"), AuthOk
127}
128
129func (fakeAuth) HandleData(data []byte) (resp []byte, status AuthStatus) {
130	return nil, AuthOk
131}
132
133func TestCloseBeforeSignal(t *testing.T) {
134	reader, pipewriter := io.Pipe()
135	defer pipewriter.Close()
136	defer reader.Close()
137
138	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
139	if err != nil {
140		t.Fatal(err)
141	}
142	// give ch a buffer so sends won't block
143	ch := make(chan *Signal, 1)
144	bus.Signal(ch)
145
146	go func() {
147		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
148		if err != nil {
149			t.Errorf("error writing to pipe: %v", err)
150		}
151	}()
152
153	err = bus.Auth([]Auth{fakeAuth{}})
154	if err != nil {
155		t.Fatal(err)
156	}
157
158	err = bus.Close()
159	if err != nil {
160		t.Fatal(err)
161	}
162
163	msg := &Message{
164		Type: TypeSignal,
165		Headers: map[HeaderField]Variant{
166			FieldInterface: MakeVariant("foo.bar"),
167			FieldMember:    MakeVariant("bar"),
168			FieldPath:      MakeVariant(ObjectPath("/baz")),
169		},
170	}
171	err = msg.EncodeTo(pipewriter, binary.LittleEndian)
172	if err != nil {
173		t.Fatal(err)
174	}
175}
176
177func TestCloseChannelAfterRemoveSignal(t *testing.T) {
178	bus, err := NewConn(nil)
179	if err != nil {
180		t.Fatal(err)
181	}
182
183	// Add an unbuffered signal channel
184	ch := make(chan *Signal)
185	bus.Signal(ch)
186
187	// Send a signal
188	msg := &Message{
189		Type: TypeSignal,
190		Headers: map[HeaderField]Variant{
191			FieldInterface: MakeVariant("foo.bar"),
192			FieldMember:    MakeVariant("bar"),
193			FieldPath:      MakeVariant(ObjectPath("/baz")),
194		},
195	}
196	bus.handleSignal(msg)
197
198	// Remove and close the signal channel
199	bus.RemoveSignal(ch)
200	close(ch)
201}
202
203func TestAddAndRemoveMatchSignal(t *testing.T) {
204	conn, err := SessionBusPrivate()
205	if err != nil {
206		t.Fatal(err)
207	}
208	defer conn.Close()
209
210	if err = conn.Auth(nil); err != nil {
211		t.Fatal(err)
212	}
213	if err = conn.Hello(); err != nil {
214		t.Fatal(err)
215	}
216
217	sigc := make(chan *Signal, 1)
218	conn.Signal(sigc)
219
220	// subscribe to a made up signal name and emit one of the type
221	if err = conn.AddMatchSignal(
222		WithMatchInterface("org.test"),
223		WithMatchMember("Test"),
224	); err != nil {
225		t.Fatal(err)
226	}
227	if err = conn.Emit("/", "org.test.Test"); err != nil {
228		t.Fatal(err)
229	}
230	if sig := waitSignal(sigc, "org.test.Test", time.Second); sig == nil {
231		t.Fatal("signal receive timed out")
232	}
233
234	// unsubscribe from the signal and check that is not delivered anymore
235	if err = conn.RemoveMatchSignal(
236		WithMatchInterface("org.test"),
237		WithMatchMember("Test"),
238	); err != nil {
239		t.Fatal(err)
240	}
241	if err = conn.Emit("/", "org.test.Test"); err != nil {
242		t.Fatal(err)
243	}
244	if sig := waitSignal(sigc, "org.test.Test", time.Second); sig != nil {
245		t.Fatalf("unsubscribed from %q signal, but received %#v", "org.test.Test", sig)
246	}
247}
248
249func waitSignal(sigc <-chan *Signal, name string, timeout time.Duration) *Signal {
250	for {
251		select {
252		case sig := <-sigc:
253			if sig.Name == name {
254				return sig
255			}
256		case <-time.After(timeout):
257			return nil
258		}
259	}
260}
261
262type server struct{}
263
264func (server) Double(i int64) (int64, *Error) {
265	return 2 * i, nil
266}
267
268func BenchmarkCall(b *testing.B) {
269	b.StopTimer()
270	b.ReportAllocs()
271	var s string
272	bus, err := SessionBus()
273	if err != nil {
274		b.Fatal(err)
275	}
276	name := bus.Names()[0]
277	obj := bus.BusObject()
278	b.StartTimer()
279	for i := 0; i < b.N; i++ {
280		err := obj.Call("org.freedesktop.DBus.GetNameOwner", 0, name).Store(&s)
281		if err != nil {
282			b.Fatal(err)
283		}
284		if s != name {
285			b.Errorf("got %s, wanted %s", s, name)
286		}
287	}
288}
289
290func BenchmarkCallAsync(b *testing.B) {
291	b.StopTimer()
292	b.ReportAllocs()
293	bus, err := SessionBus()
294	if err != nil {
295		b.Fatal(err)
296	}
297	name := bus.Names()[0]
298	obj := bus.BusObject()
299	c := make(chan *Call, 50)
300	done := make(chan struct{})
301	go func() {
302		for i := 0; i < b.N; i++ {
303			v := <-c
304			if v.Err != nil {
305				b.Error(v.Err)
306			}
307			s := v.Body[0].(string)
308			if s != name {
309				b.Errorf("got %s, wanted %s", s, name)
310			}
311		}
312		close(done)
313	}()
314	b.StartTimer()
315	for i := 0; i < b.N; i++ {
316		obj.Go("org.freedesktop.DBus.GetNameOwner", 0, c, name)
317	}
318	<-done
319}
320
321func BenchmarkServe(b *testing.B) {
322	b.StopTimer()
323	srv, err := SessionBus()
324	if err != nil {
325		b.Fatal(err)
326	}
327	cli, err := SessionBusPrivate()
328	if err != nil {
329		b.Fatal(err)
330	}
331	if err = cli.Auth(nil); err != nil {
332		b.Fatal(err)
333	}
334	if err = cli.Hello(); err != nil {
335		b.Fatal(err)
336	}
337	benchmarkServe(b, srv, cli)
338}
339
340func BenchmarkServeAsync(b *testing.B) {
341	b.StopTimer()
342	srv, err := SessionBus()
343	if err != nil {
344		b.Fatal(err)
345	}
346	cli, err := SessionBusPrivate()
347	if err != nil {
348		b.Fatal(err)
349	}
350	if err = cli.Auth(nil); err != nil {
351		b.Fatal(err)
352	}
353	if err = cli.Hello(); err != nil {
354		b.Fatal(err)
355	}
356	benchmarkServeAsync(b, srv, cli)
357}
358
359func BenchmarkServeSameConn(b *testing.B) {
360	b.StopTimer()
361	bus, err := SessionBus()
362	if err != nil {
363		b.Fatal(err)
364	}
365
366	benchmarkServe(b, bus, bus)
367}
368
369func BenchmarkServeSameConnAsync(b *testing.B) {
370	b.StopTimer()
371	bus, err := SessionBus()
372	if err != nil {
373		b.Fatal(err)
374	}
375
376	benchmarkServeAsync(b, bus, bus)
377}
378
379func benchmarkServe(b *testing.B, srv, cli *Conn) {
380	var r int64
381	var err error
382	dest := srv.Names()[0]
383	srv.Export(server{}, "/org/guelfey/DBus/Test", "org.guelfey.DBus.Test")
384	obj := cli.Object(dest, "/org/guelfey/DBus/Test")
385	b.StartTimer()
386	for i := 0; i < b.N; i++ {
387		err = obj.Call("org.guelfey.DBus.Test.Double", 0, int64(i)).Store(&r)
388		if err != nil {
389			b.Fatal(err)
390		}
391		if r != 2*int64(i) {
392			b.Errorf("got %d, wanted %d", r, 2*int64(i))
393		}
394	}
395}
396
397func benchmarkServeAsync(b *testing.B, srv, cli *Conn) {
398	dest := srv.Names()[0]
399	srv.Export(server{}, "/org/guelfey/DBus/Test", "org.guelfey.DBus.Test")
400	obj := cli.Object(dest, "/org/guelfey/DBus/Test")
401	c := make(chan *Call, 50)
402	done := make(chan struct{})
403	go func() {
404		for i := 0; i < b.N; i++ {
405			v := <-c
406			if v.Err != nil {
407				b.Fatal(v.Err)
408			}
409			i, r := v.Args[0].(int64), v.Body[0].(int64)
410			if 2*i != r {
411				b.Errorf("got %d, wanted %d", r, 2*i)
412			}
413		}
414		close(done)
415	}()
416	b.StartTimer()
417	for i := 0; i < b.N; i++ {
418		obj.Go("org.guelfey.DBus.Test.Double", 0, c, int64(i))
419	}
420	<-done
421}
422
423func TestGetKey(t *testing.T) {
424	keys := "host=1.2.3.4,port=5678,family=ipv4"
425	if host := getKey(keys, "host"); host != "1.2.3.4" {
426		t.Error(`Expected "1.2.3.4", got`, host)
427	}
428	if port := getKey(keys, "port"); port != "5678" {
429		t.Error(`Expected "5678", got`, port)
430	}
431	if family := getKey(keys, "family"); family != "ipv4" {
432		t.Error(`Expected "ipv4", got`, family)
433	}
434}
435
436func TestInterceptors(t *testing.T) {
437	conn, err := SessionBusPrivate(
438		WithIncomingInterceptor(func(msg *Message) {
439			log.Println("<", msg)
440		}),
441		WithOutgoingInterceptor(func(msg *Message) {
442			log.Println(">", msg)
443		}),
444	)
445	if err != nil {
446		t.Fatal(err)
447	}
448	defer conn.Close()
449
450	if err = conn.Auth(nil); err != nil {
451		t.Fatal(err)
452	}
453	if err = conn.Hello(); err != nil {
454		t.Fatal(err)
455	}
456}
457
458func TestCloseCancelsConnectionContext(t *testing.T) {
459	bus, err := SessionBusPrivate()
460	if err != nil {
461		t.Fatal(err)
462	}
463	defer bus.Close()
464
465	if err = bus.Auth(nil); err != nil {
466		t.Fatal(err)
467	}
468	if err = bus.Hello(); err != nil {
469		t.Fatal(err)
470	}
471	if err != nil {
472		t.Fatal(err)
473	}
474
475	// The context is not done at this point
476	ctx := bus.Context()
477	select {
478	case <-ctx.Done():
479		t.Fatal("context should not be done")
480	default:
481	}
482
483	err = bus.Close()
484	if err != nil {
485		t.Fatal(err)
486	}
487	select {
488	case <-ctx.Done():
489		// expected
490	case <-time.After(5 * time.Second):
491		t.Fatal("context is not done after connection closed")
492	}
493}
494
495func TestDisconnectCancelsConnectionContext(t *testing.T) {
496	reader, pipewriter := io.Pipe()
497	defer pipewriter.Close()
498	defer reader.Close()
499
500	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
501	if err != nil {
502		t.Fatal(err)
503	}
504
505	go func() {
506		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
507		if err != nil {
508			t.Errorf("error writing to pipe: %v", err)
509		}
510	}()
511	err = bus.Auth([]Auth{fakeAuth{}})
512	if err != nil {
513		t.Fatal(err)
514	}
515
516	ctx := bus.Context()
517
518	pipewriter.Close()
519	select {
520	case <-ctx.Done():
521		// expected
522	case <-time.After(5 * time.Second):
523		t.Fatal("context is not done after connection closed")
524	}
525}
526
527func TestCancellingContextClosesConnection(t *testing.T) {
528	ctx, cancel := context.WithCancel(context.Background())
529	defer cancel()
530
531	reader, pipewriter := io.Pipe()
532	defer pipewriter.Close()
533	defer reader.Close()
534
535	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard}, WithContext(ctx))
536	if err != nil {
537		t.Fatal(err)
538	}
539
540	go func() {
541		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
542		if err != nil {
543			t.Errorf("error writing to pipe: %v", err)
544		}
545	}()
546	err = bus.Auth([]Auth{fakeAuth{}})
547	if err != nil {
548		t.Fatal(err)
549	}
550
551	// Cancel the connection's parent context and give time for
552	// other goroutines to schedule.
553	cancel()
554	time.Sleep(50 * time.Millisecond)
555
556	err = bus.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store()
557	if err != ErrClosed {
558		t.Errorf("expected connection to be closed, but got: %v", err)
559	}
560}
561