1//+build !windows,!solaris
2
3package dbus
4
5import (
6	"bytes"
7	"encoding/binary"
8	"errors"
9	"io"
10	"net"
11	"syscall"
12)
13
// oobReader adapts a *net.UnixConn to io.Reader while also capturing the
// ancillary (out-of-band) data delivered alongside each read.
//
// Fields:
//   - conn: the socket that is read from.
//   - oob:  raw control-message bytes appended by every successful Read; the
//     owner clears it when it wants to start collecting afresh.
//   - buf:  scratch space that ReadMsgUnix fills with control data.
type oobReader struct {
	conn *net.UnixConn
	oob  []byte
	buf  [4096]byte
}

// Read fills b with payload bytes using ReadMsgUnix and appends any control
// data received with them to o.oob.
//
// An error from ReadMsgUnix is returned together with the byte count. If the
// kernel flags the control data as truncated (MSG_CTRUNC), the byte count is
// returned with an error and the partial control data is not recorded.
func (o *oobReader) Read(b []byte) (int, error) {
	nread, noob, msgFlags, _, readErr := o.conn.ReadMsgUnix(b, o.buf[:])
	switch {
	case readErr != nil:
		return nread, readErr
	case msgFlags&syscall.MSG_CTRUNC != 0:
		return nread, errors.New("dbus: control data truncated (too many fds received)")
	}
	o.oob = append(o.oob, o.buf[:noob]...)
	return nread, nil
}
31
// unixTransport is the transport built by newUnixTransport. It embeds the
// underlying *net.UnixConn, so the connection's Read/Write/Close methods are
// available directly on the transport, and adds the state needed for unix fd
// passing.
type unixTransport struct {
	*net.UnixConn
	// rdr is created lazily by ReadMessage and gathers the out-of-band data
	// received while a message is being read.
	rdr        *oobReader
	// hasUnixFDs is set by EnableUnixFDs. While it is false, ReadMessage
	// rejects messages carrying unix fds and SendMessage refuses to send them.
	hasUnixFDs bool
}
37
38func newUnixTransport(keys string) (transport, error) {
39	var err error
40
41	t := new(unixTransport)
42	abstract := getKey(keys, "abstract")
43	path := getKey(keys, "path")
44	switch {
45	case abstract == "" && path == "":
46		return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
47	case abstract != "" && path == "":
48		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
49		if err != nil {
50			return nil, err
51		}
52		return t, nil
53	case abstract == "" && path != "":
54		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
55		if err != nil {
56			return nil, err
57		}
58		return t, nil
59	default:
60		return nil, errors.New("dbus: invalid address (both path and abstract set)")
61	}
62}
63
// init registers newUnixTransport as the constructor for the "unix" entry of
// the package-level transports table.
func init() {
	transports["unix"] = newUnixTransport
}
67
// EnableUnixFDs turns on unix fd passing for this transport. Until it is
// called, ReadMessage rejects messages that declare unix fds and SendMessage
// rejects messages whose body contains UnixFD values.
// NOTE(review): presumably called once the peer has agreed to fd passing
// during authentication — confirm at the call site.
func (t *unixTransport) EnableUnixFDs() {
	t.hasUnixFDs = true
}
71
72func (t *unixTransport) ReadMessage() (*Message, error) {
73	var (
74		blen, hlen uint32
75		csheader   [16]byte
76		headers    []header
77		order      binary.ByteOrder
78		unixfds    uint32
79	)
80	// To be sure that all bytes of out-of-band data are read, we use a special
81	// reader that uses ReadUnix on the underlying connection instead of Read
82	// and gathers the out-of-band data in a buffer.
83	if t.rdr == nil {
84		t.rdr = &oobReader{conn: t.UnixConn}
85	} else {
86		t.rdr.oob = nil
87	}
88
89	// read the first 16 bytes (the part of the header that has a constant size),
90	// from which we can figure out the length of the rest of the message
91	if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
92		return nil, err
93	}
94	switch csheader[0] {
95	case 'l':
96		order = binary.LittleEndian
97	case 'B':
98		order = binary.BigEndian
99	default:
100		return nil, InvalidMessageError("invalid byte order")
101	}
102	// csheader[4:8] -> length of message body, csheader[12:16] -> length of
103	// header fields (without alignment)
104	binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
105	binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
106	if hlen%8 != 0 {
107		hlen += 8 - (hlen % 8)
108	}
109
110	// decode headers and look for unix fds
111	headerdata := make([]byte, hlen+4)
112	copy(headerdata, csheader[12:])
113	if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
114		return nil, err
115	}
116	dec := newDecoder(bytes.NewBuffer(headerdata), order)
117	dec.pos = 12
118	vs, err := dec.Decode(Signature{"a(yv)"})
119	if err != nil {
120		return nil, err
121	}
122	Store(vs, &headers)
123	for _, v := range headers {
124		if v.Field == byte(FieldUnixFDs) {
125			unixfds, _ = v.Variant.value.(uint32)
126		}
127	}
128	all := make([]byte, 16+hlen+blen)
129	copy(all, csheader[:])
130	copy(all[16:], headerdata[4:])
131	if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {
132		return nil, err
133	}
134	if unixfds != 0 {
135		if !t.hasUnixFDs {
136			return nil, errors.New("dbus: got unix fds on unsupported transport")
137		}
138		// read the fds from the OOB data
139		scms, err := syscall.ParseSocketControlMessage(t.rdr.oob)
140		if err != nil {
141			return nil, err
142		}
143		if len(scms) != 1 {
144			return nil, errors.New("dbus: received more than one socket control message")
145		}
146		fds, err := syscall.ParseUnixRights(&scms[0])
147		if err != nil {
148			return nil, err
149		}
150		msg, err := DecodeMessage(bytes.NewBuffer(all))
151		if err != nil {
152			return nil, err
153		}
154		// substitute the values in the message body (which are indices for the
155		// array receiver via OOB) with the actual values
156		for i, v := range msg.Body {
157			switch v.(type) {
158			case UnixFDIndex:
159				j := v.(UnixFDIndex)
160				if uint32(j) >= unixfds {
161					return nil, InvalidMessageError("invalid index for unix fd")
162				}
163				msg.Body[i] = UnixFD(fds[j])
164			case []UnixFDIndex:
165				idxArray := v.([]UnixFDIndex)
166				fdArray := make([]UnixFD, len(idxArray))
167				for k, j := range idxArray {
168					if uint32(j) >= unixfds {
169						return nil, InvalidMessageError("invalid index for unix fd")
170					}
171					fdArray[k] = UnixFD(fds[j])
172				}
173				msg.Body[i] = fdArray
174			}
175		}
176		return msg, nil
177	}
178	return DecodeMessage(bytes.NewBuffer(all))
179}
180
181func (t *unixTransport) SendMessage(msg *Message) error {
182	fds := make([]int, 0)
183	for i, v := range msg.Body {
184		if fd, ok := v.(UnixFD); ok {
185			msg.Body[i] = UnixFDIndex(len(fds))
186			fds = append(fds, int(fd))
187		}
188	}
189	if len(fds) != 0 {
190		if !t.hasUnixFDs {
191			return errors.New("dbus: unix fd passing not enabled")
192		}
193		msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
194		oob := syscall.UnixRights(fds...)
195		buf := new(bytes.Buffer)
196		msg.EncodeTo(buf, nativeEndian)
197		n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
198		if err != nil {
199			return err
200		}
201		if n != buf.Len() || oobn != len(oob) {
202			return io.ErrShortWrite
203		}
204	} else {
205		if err := msg.EncodeTo(t, nativeEndian); err != nil {
206			return err
207		}
208	}
209	return nil
210}
211
// SupportsUnixFDs reports whether this transport can pass unix fds. It
// unconditionally returns true; whether fd passing is actually enabled is
// tracked separately by EnableUnixFDs (the hasUnixFDs field).
func (t *unixTransport) SupportsUnixFDs() bool {
	return true
}
215