1package dbus
2
3import (
4	"encoding/binary"
5	"io"
6	"io/ioutil"
7	"testing"
8	"time"
9)
10
11func TestSessionBus(t *testing.T) {
12	_, err := SessionBus()
13	if err != nil {
14		t.Error(err)
15	}
16}
17
18func TestSystemBus(t *testing.T) {
19	_, err := SystemBus()
20	if err != nil {
21		t.Error(err)
22	}
23}
24
25func ExampleSystemBusPrivate() {
26	setupPrivateSystemBus := func() (conn *Conn, err error) {
27		conn, err = SystemBusPrivate()
28		if err != nil {
29			return nil, err
30		}
31		if err = conn.Auth(nil); err != nil {
32			conn.Close()
33			conn = nil
34			return
35		}
36		if err = conn.Hello(); err != nil {
37			conn.Close()
38			conn = nil
39		}
40		return conn, nil // success
41	}
42	_, _ = setupPrivateSystemBus()
43}
44
45func TestSend(t *testing.T) {
46	bus, err := SessionBus()
47	if err != nil {
48		t.Fatal(err)
49	}
50	ch := make(chan *Call, 1)
51	msg := &Message{
52		Type:  TypeMethodCall,
53		Flags: 0,
54		Headers: map[HeaderField]Variant{
55			FieldDestination: MakeVariant(bus.Names()[0]),
56			FieldPath:        MakeVariant(ObjectPath("/org/freedesktop/DBus")),
57			FieldInterface:   MakeVariant("org.freedesktop.DBus.Peer"),
58			FieldMember:      MakeVariant("Ping"),
59		},
60	}
61	call := bus.Send(msg, ch)
62	<-ch
63	if call.Err != nil {
64		t.Error(call.Err)
65	}
66}
67
68func TestFlagNoReplyExpectedSend(t *testing.T) {
69	bus, err := SessionBus()
70	if err != nil {
71		t.Fatal(err)
72	}
73	done := make(chan struct{})
74	go func() {
75		bus.BusObject().Call("org.freedesktop.DBus.ListNames", FlagNoReplyExpected)
76		close(done)
77	}()
78	select {
79	case <-done:
80	case <-time.After(1 * time.Second):
81		t.Error("Failed to announce that the call was done")
82	}
83}
84
85func TestRemoveSignal(t *testing.T) {
86	bus, err := NewConn(nil)
87	if err != nil {
88		t.Error(err)
89	}
90	signals := bus.signalHandler.(*defaultSignalHandler).signals
91	ch := make(chan *Signal)
92	ch2 := make(chan *Signal)
93	for _, ch := range []chan *Signal{ch, ch2, ch, ch2, ch2, ch} {
94		bus.Signal(ch)
95	}
96	signals = bus.signalHandler.(*defaultSignalHandler).signals
97	if len(signals) != 6 {
98		t.Errorf("remove signal: signals length not equal: got '%d', want '6'", len(signals))
99	}
100	bus.RemoveSignal(ch)
101	signals = bus.signalHandler.(*defaultSignalHandler).signals
102	if len(signals) != 3 {
103		t.Errorf("remove signal: signals length not equal: got '%d', want '3'", len(signals))
104	}
105	signals = bus.signalHandler.(*defaultSignalHandler).signals
106	for _, scd := range signals {
107		if scd.ch != ch2 {
108			t.Errorf("remove signal: removed signal present: got '%v', want '%v'", scd.ch, ch2)
109		}
110	}
111}
112
// rwc combines an independent io.Reader and io.Writer into a single value
// with a Close method, so tests can pass in-memory streams to NewConn.
type rwc struct {
	io.Reader
	io.Writer
}

// Close does nothing and always reports success; neither the embedded Reader
// nor the Writer is closed.
func (rwc) Close() error {
	var noErr error
	return noErr
}
119
120type fakeAuth struct {
121}
122
123func (fakeAuth) FirstData() (name, resp []byte, status AuthStatus) {
124	return []byte("name"), []byte("resp"), AuthOk
125}
126
127func (fakeAuth) HandleData(data []byte) (resp []byte, status AuthStatus) {
128	return nil, AuthOk
129}
130
131func TestCloseBeforeSignal(t *testing.T) {
132	reader, pipewriter := io.Pipe()
133	defer pipewriter.Close()
134	defer reader.Close()
135
136	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
137	if err != nil {
138		t.Fatal(err)
139	}
140	// give ch a buffer so sends won't block
141	ch := make(chan *Signal, 1)
142	bus.Signal(ch)
143
144	go func() {
145		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
146		if err != nil {
147			t.Errorf("error writing to pipe: %v", err)
148		}
149	}()
150
151	err = bus.Auth([]Auth{fakeAuth{}})
152	if err != nil {
153		t.Fatal(err)
154	}
155
156	err = bus.Close()
157	if err != nil {
158		t.Fatal(err)
159	}
160
161	msg := &Message{
162		Type: TypeSignal,
163		Headers: map[HeaderField]Variant{
164			FieldInterface: MakeVariant("foo.bar"),
165			FieldMember:    MakeVariant("bar"),
166			FieldPath:      MakeVariant(ObjectPath("/baz")),
167		},
168	}
169	err = msg.EncodeTo(pipewriter, binary.LittleEndian)
170	if err != nil {
171		t.Fatal(err)
172	}
173}
174
175func TestCloseChannelAfterRemoveSignal(t *testing.T) {
176	bus, err := NewConn(nil)
177	if err != nil {
178		t.Fatal(err)
179	}
180
181	// Add an unbuffered signal channel
182	ch := make(chan *Signal)
183	bus.Signal(ch)
184
185	// Send a signal
186	msg := &Message{
187		Type: TypeSignal,
188		Headers: map[HeaderField]Variant{
189			FieldInterface: MakeVariant("foo.bar"),
190			FieldMember:    MakeVariant("bar"),
191			FieldPath:      MakeVariant(ObjectPath("/baz")),
192		},
193	}
194	bus.handleSignal(msg)
195
196	// Remove and close the signal channel
197	bus.RemoveSignal(ch)
198	close(ch)
199}
200
201func TestAddAndRemoveMatchSignal(t *testing.T) {
202	conn, err := SessionBusPrivate()
203	if err != nil {
204		t.Fatal(err)
205	}
206	defer conn.Close()
207
208	if err = conn.Auth(nil); err != nil {
209		t.Fatal(err)
210	}
211	if err = conn.Hello(); err != nil {
212		t.Fatal(err)
213	}
214
215	sigc := make(chan *Signal, 1)
216	conn.Signal(sigc)
217
218	// subscribe to a made up signal name and emit one of the type
219	if err = conn.AddMatchSignal(
220		WithMatchInterface("org.test"),
221		WithMatchMember("Test"),
222	); err != nil {
223		t.Fatal(err)
224	}
225	if err = conn.Emit("/", "org.test.Test"); err != nil {
226		t.Fatal(err)
227	}
228	if sig := waitSignal(sigc, "org.test.Test", time.Second); sig == nil {
229		t.Fatal("signal receive timed out")
230	}
231
232	// unsubscribe from the signal and check that is not delivered anymore
233	if err = conn.RemoveMatchSignal(
234		WithMatchInterface("org.test"),
235		WithMatchMember("Test"),
236	); err != nil {
237		t.Fatal(err)
238	}
239	if err = conn.Emit("/", "org.test.Test"); err != nil {
240		t.Fatal(err)
241	}
242	if sig := waitSignal(sigc, "org.test.Test", time.Second); sig != nil {
243		t.Fatalf("unsubscribed from %q signal, but received %#v", "org.test.Test", sig)
244	}
245}
246
247func waitSignal(sigc <-chan *Signal, name string, timeout time.Duration) *Signal {
248	for {
249		select {
250		case sig := <-sigc:
251			if sig.Name == name {
252				return sig
253			}
254		case <-time.After(timeout):
255			return nil
256		}
257	}
258}
259
260type server struct{}
261
262func (server) Double(i int64) (int64, *Error) {
263	return 2 * i, nil
264}
265
266func BenchmarkCall(b *testing.B) {
267	b.StopTimer()
268	b.ReportAllocs()
269	var s string
270	bus, err := SessionBus()
271	if err != nil {
272		b.Fatal(err)
273	}
274	name := bus.Names()[0]
275	obj := bus.BusObject()
276	b.StartTimer()
277	for i := 0; i < b.N; i++ {
278		err := obj.Call("org.freedesktop.DBus.GetNameOwner", 0, name).Store(&s)
279		if err != nil {
280			b.Fatal(err)
281		}
282		if s != name {
283			b.Errorf("got %s, wanted %s", s, name)
284		}
285	}
286}
287
288func BenchmarkCallAsync(b *testing.B) {
289	b.StopTimer()
290	b.ReportAllocs()
291	bus, err := SessionBus()
292	if err != nil {
293		b.Fatal(err)
294	}
295	name := bus.Names()[0]
296	obj := bus.BusObject()
297	c := make(chan *Call, 50)
298	done := make(chan struct{})
299	go func() {
300		for i := 0; i < b.N; i++ {
301			v := <-c
302			if v.Err != nil {
303				b.Error(v.Err)
304			}
305			s := v.Body[0].(string)
306			if s != name {
307				b.Errorf("got %s, wanted %s", s, name)
308			}
309		}
310		close(done)
311	}()
312	b.StartTimer()
313	for i := 0; i < b.N; i++ {
314		obj.Go("org.freedesktop.DBus.GetNameOwner", 0, c, name)
315	}
316	<-done
317}
318
319func BenchmarkServe(b *testing.B) {
320	b.StopTimer()
321	srv, err := SessionBus()
322	if err != nil {
323		b.Fatal(err)
324	}
325	cli, err := SessionBusPrivate()
326	if err != nil {
327		b.Fatal(err)
328	}
329	if err = cli.Auth(nil); err != nil {
330		b.Fatal(err)
331	}
332	if err = cli.Hello(); err != nil {
333		b.Fatal(err)
334	}
335	benchmarkServe(b, srv, cli)
336}
337
338func BenchmarkServeAsync(b *testing.B) {
339	b.StopTimer()
340	srv, err := SessionBus()
341	if err != nil {
342		b.Fatal(err)
343	}
344	cli, err := SessionBusPrivate()
345	if err != nil {
346		b.Fatal(err)
347	}
348	if err = cli.Auth(nil); err != nil {
349		b.Fatal(err)
350	}
351	if err = cli.Hello(); err != nil {
352		b.Fatal(err)
353	}
354	benchmarkServeAsync(b, srv, cli)
355}
356
357func BenchmarkServeSameConn(b *testing.B) {
358	b.StopTimer()
359	bus, err := SessionBus()
360	if err != nil {
361		b.Fatal(err)
362	}
363
364	benchmarkServe(b, bus, bus)
365}
366
367func BenchmarkServeSameConnAsync(b *testing.B) {
368	b.StopTimer()
369	bus, err := SessionBus()
370	if err != nil {
371		b.Fatal(err)
372	}
373
374	benchmarkServeAsync(b, bus, bus)
375}
376
377func benchmarkServe(b *testing.B, srv, cli *Conn) {
378	var r int64
379	var err error
380	dest := srv.Names()[0]
381	srv.Export(server{}, "/org/guelfey/DBus/Test", "org.guelfey.DBus.Test")
382	obj := cli.Object(dest, "/org/guelfey/DBus/Test")
383	b.StartTimer()
384	for i := 0; i < b.N; i++ {
385		err = obj.Call("org.guelfey.DBus.Test.Double", 0, int64(i)).Store(&r)
386		if err != nil {
387			b.Fatal(err)
388		}
389		if r != 2*int64(i) {
390			b.Errorf("got %d, wanted %d", r, 2*int64(i))
391		}
392	}
393}
394
395func benchmarkServeAsync(b *testing.B, srv, cli *Conn) {
396	dest := srv.Names()[0]
397	srv.Export(server{}, "/org/guelfey/DBus/Test", "org.guelfey.DBus.Test")
398	obj := cli.Object(dest, "/org/guelfey/DBus/Test")
399	c := make(chan *Call, 50)
400	done := make(chan struct{})
401	go func() {
402		for i := 0; i < b.N; i++ {
403			v := <-c
404			if v.Err != nil {
405				b.Fatal(v.Err)
406			}
407			i, r := v.Args[0].(int64), v.Body[0].(int64)
408			if 2*i != r {
409				b.Errorf("got %d, wanted %d", r, 2*i)
410			}
411		}
412		close(done)
413	}()
414	b.StartTimer()
415	for i := 0; i < b.N; i++ {
416		obj.Go("org.guelfey.DBus.Test.Double", 0, c, int64(i))
417	}
418	<-done
419}
420
421func TestGetKey(t *testing.T) {
422	keys := "host=1.2.3.4,port=5678,family=ipv4"
423	if host := getKey(keys, "host"); host != "1.2.3.4" {
424		t.Error(`Expected "1.2.3.4", got`, host)
425	}
426	if port := getKey(keys, "port"); port != "5678" {
427		t.Error(`Expected "5678", got`, port)
428	}
429	if family := getKey(keys, "family"); family != "ipv4" {
430		t.Error(`Expected "ipv4", got`, family)
431	}
432}
433