1package dbus
2
3import (
4	"bytes"
5	"encoding/binary"
6	"errors"
7	"io"
8	"net"
9	"syscall"
10)
11
// oobReader is an io.Reader over a unix socket connection. Besides handing
// out the regular payload, it accumulates in oob every byte of out-of-band
// (ancillary) data that arrives with that payload, such as SCM_RIGHTS control
// messages carrying file descriptors. conn is the socket read from and buf is
// scratch space for the ancillary data of a single read.
type oobReader struct {
	conn *net.UnixConn
	oob  []byte
	buf  [4096]byte
}

// Read reads payload bytes into b and appends the ancillary data received
// alongside them to o.oob. It returns an error if the kernel flagged the
// control data as truncated (MSG_CTRUNC), because part of it was then dropped.
func (o *oobReader) Read(b []byte) (int, error) {
	n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
	switch {
	case err != nil:
		return n, err
	case flags&syscall.MSG_CTRUNC != 0:
		return n, errors.New("dbus: control data truncated (too many fds received)")
	default:
		o.oob = append(o.oob, o.buf[:oobn]...)
		return n, nil
	}
}
29
// unixTransport is a D-Bus transport over a unix domain socket. The embedded
// *net.UnixConn provides the underlying Read/Write methods; hasUnixFDs
// records whether unix fd passing has been switched on via EnableUnixFDs
// and is consulted before fds are accepted or sent.
type unixTransport struct {
	*net.UnixConn
	hasUnixFDs bool
}
34
35func newUnixTransport(keys string) (transport, error) {
36	var err error
37
38	t := new(unixTransport)
39	abstract := getKey(keys, "abstract")
40	path := getKey(keys, "path")
41	switch {
42	case abstract == "" && path == "":
43		return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
44	case abstract != "" && path == "":
45		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
46		if err != nil {
47			return nil, err
48		}
49		return t, nil
50	case abstract == "" && path != "":
51		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
52		if err != nil {
53			return nil, err
54		}
55		return t, nil
56	default:
57		return nil, errors.New("dbus: invalid address (both path and abstract set)")
58	}
59}
60
61func (t *unixTransport) EnableUnixFDs() {
62	t.hasUnixFDs = true
63}
64
65func (t *unixTransport) ReadMessage() (*Message, error) {
66	var (
67		blen, hlen uint32
68		csheader   [16]byte
69		headers    []header
70		order      binary.ByteOrder
71		unixfds    uint32
72	)
73	// To be sure that all bytes of out-of-band data are read, we use a special
74	// reader that uses ReadUnix on the underlying connection instead of Read
75	// and gathers the out-of-band data in a buffer.
76	rd := &oobReader{conn: t.UnixConn}
77	// read the first 16 bytes (the part of the header that has a constant size),
78	// from which we can figure out the length of the rest of the message
79	if _, err := io.ReadFull(rd, csheader[:]); err != nil {
80		return nil, err
81	}
82	switch csheader[0] {
83	case 'l':
84		order = binary.LittleEndian
85	case 'B':
86		order = binary.BigEndian
87	default:
88		return nil, InvalidMessageError("invalid byte order")
89	}
90	// csheader[4:8] -> length of message body, csheader[12:16] -> length of
91	// header fields (without alignment)
92	binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
93	binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
94	if hlen%8 != 0 {
95		hlen += 8 - (hlen % 8)
96	}
97
98	// decode headers and look for unix fds
99	headerdata := make([]byte, hlen+4)
100	copy(headerdata, csheader[12:])
101	if _, err := io.ReadFull(t, headerdata[4:]); err != nil {
102		return nil, err
103	}
104	dec := newDecoder(bytes.NewBuffer(headerdata), order)
105	dec.pos = 12
106	vs, err := dec.Decode(Signature{"a(yv)"})
107	if err != nil {
108		return nil, err
109	}
110	Store(vs, &headers)
111	for _, v := range headers {
112		if v.Field == byte(FieldUnixFDs) {
113			unixfds, _ = v.Variant.value.(uint32)
114		}
115	}
116	all := make([]byte, 16+hlen+blen)
117	copy(all, csheader[:])
118	copy(all[16:], headerdata[4:])
119	if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil {
120		return nil, err
121	}
122	if unixfds != 0 {
123		if !t.hasUnixFDs {
124			return nil, errors.New("dbus: got unix fds on unsupported transport")
125		}
126		// read the fds from the OOB data
127		scms, err := syscall.ParseSocketControlMessage(rd.oob)
128		if err != nil {
129			return nil, err
130		}
131		if len(scms) != 1 {
132			return nil, errors.New("dbus: received more than one socket control message")
133		}
134		fds, err := syscall.ParseUnixRights(&scms[0])
135		if err != nil {
136			return nil, err
137		}
138		msg, err := DecodeMessage(bytes.NewBuffer(all))
139		if err != nil {
140			return nil, err
141		}
142		// substitute the values in the message body (which are indices for the
143		// array receiver via OOB) with the actual values
144		for i, v := range msg.Body {
145			if j, ok := v.(UnixFDIndex); ok {
146				if uint32(j) >= unixfds {
147					return nil, InvalidMessageError("invalid index for unix fd")
148				}
149				msg.Body[i] = UnixFD(fds[j])
150			}
151		}
152		return msg, nil
153	}
154	return DecodeMessage(bytes.NewBuffer(all))
155}
156
157func (t *unixTransport) SendMessage(msg *Message) error {
158	fds := make([]int, 0)
159	for i, v := range msg.Body {
160		if fd, ok := v.(UnixFD); ok {
161			msg.Body[i] = UnixFDIndex(len(fds))
162			fds = append(fds, int(fd))
163		}
164	}
165	if len(fds) != 0 {
166		if !t.hasUnixFDs {
167			return errors.New("dbus: unix fd passing not enabled")
168		}
169		msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
170		oob := syscall.UnixRights(fds...)
171		buf := new(bytes.Buffer)
172		msg.EncodeTo(buf, binary.LittleEndian)
173		n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
174		if err != nil {
175			return err
176		}
177		if n != buf.Len() || oobn != len(oob) {
178			return io.ErrShortWrite
179		}
180	} else {
181		if err := msg.EncodeTo(t, binary.LittleEndian); err != nil {
182			return nil
183		}
184	}
185	return nil
186}
187
188func (t *unixTransport) SupportsUnixFDs() bool {
189	return true
190}
191