1package dbus 2 3import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "io" 8 "net" 9 "syscall" 10) 11 12type oobReader struct { 13 conn *net.UnixConn 14 oob []byte 15 buf [4096]byte 16} 17 18func (o *oobReader) Read(b []byte) (n int, err error) { 19 n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:]) 20 if err != nil { 21 return n, err 22 } 23 if flags&syscall.MSG_CTRUNC != 0 { 24 return n, errors.New("dbus: control data truncated (too many fds received)") 25 } 26 o.oob = append(o.oob, o.buf[:oobn]...) 27 return n, nil 28} 29 30type unixTransport struct { 31 *net.UnixConn 32 hasUnixFDs bool 33} 34 35func newUnixTransport(keys string) (transport, error) { 36 var err error 37 38 t := new(unixTransport) 39 abstract := getKey(keys, "abstract") 40 path := getKey(keys, "path") 41 switch { 42 case abstract == "" && path == "": 43 return nil, errors.New("dbus: invalid address (neither path nor abstract set)") 44 case abstract != "" && path == "": 45 t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"}) 46 if err != nil { 47 return nil, err 48 } 49 return t, nil 50 case abstract == "" && path != "": 51 t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"}) 52 if err != nil { 53 return nil, err 54 } 55 return t, nil 56 default: 57 return nil, errors.New("dbus: invalid address (both path and abstract set)") 58 } 59} 60 61func (t *unixTransport) EnableUnixFDs() { 62 t.hasUnixFDs = true 63} 64 65func (t *unixTransport) ReadMessage() (*Message, error) { 66 var ( 67 blen, hlen uint32 68 csheader [16]byte 69 headers []header 70 order binary.ByteOrder 71 unixfds uint32 72 ) 73 // To be sure that all bytes of out-of-band data are read, we use a special 74 // reader that uses ReadUnix on the underlying connection instead of Read 75 // and gathers the out-of-band data in a buffer. 76 rd := &oobReader{conn: t.UnixConn} 77 // read the first 16 bytes (the part of the header that has a constant size), 78 // from which we can figure out the length of the rest of the message 79 if _, err := io.ReadFull(rd, csheader[:]); err != nil { 80 return nil, err 81 } 82 switch csheader[0] { 83 case 'l': 84 order = binary.LittleEndian 85 case 'B': 86 order = binary.BigEndian 87 default: 88 return nil, InvalidMessageError("invalid byte order") 89 } 90 // csheader[4:8] -> length of message body, csheader[12:16] -> length of 91 // header fields (without alignment) 92 binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen) 93 binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen) 94 if hlen%8 != 0 { 95 hlen += 8 - (hlen % 8) 96 } 97 98 // decode headers and look for unix fds 99 headerdata := make([]byte, hlen+4) 100 copy(headerdata, csheader[12:]) 101 if _, err := io.ReadFull(t, headerdata[4:]); err != nil { 102 return nil, err 103 } 104 dec := newDecoder(bytes.NewBuffer(headerdata), order) 105 dec.pos = 12 106 vs, err := dec.Decode(Signature{"a(yv)"}) 107 if err != nil { 108 return nil, err 109 } 110 Store(vs, &headers) 111 for _, v := range headers { 112 if v.Field == byte(FieldUnixFDs) { 113 unixfds, _ = v.Variant.value.(uint32) 114 } 115 } 116 all := make([]byte, 16+hlen+blen) 117 copy(all, csheader[:]) 118 copy(all[16:], headerdata[4:]) 119 if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil { 120 return nil, err 121 } 122 if unixfds != 0 { 123 if !t.hasUnixFDs { 124 return nil, errors.New("dbus: got unix fds on unsupported transport") 125 } 126 // read the fds from the OOB data 127 scms, err := syscall.ParseSocketControlMessage(rd.oob) 128 if err != nil { 129 return nil, err 130 } 131 if len(scms) != 1 { 132 return nil, errors.New("dbus: received more than one socket control message") 133 } 134 fds, err := syscall.ParseUnixRights(&scms[0]) 135 if err != nil { 136 return nil, err 137 } 138 msg, err := DecodeMessage(bytes.NewBuffer(all)) 139 if err != nil { 140 return nil, err 141 } 142 // substitute the values in the message body (which are indices for the 143 // array receiver via OOB) with the actual values 144 for i, v := range msg.Body { 145 if j, ok := v.(UnixFDIndex); ok { 146 if uint32(j) >= unixfds { 147 return nil, InvalidMessageError("invalid index for unix fd") 148 } 149 msg.Body[i] = UnixFD(fds[j]) 150 } 151 } 152 return msg, nil 153 } 154 return DecodeMessage(bytes.NewBuffer(all)) 155} 156 157func (t *unixTransport) SendMessage(msg *Message) error { 158 fds := make([]int, 0) 159 for i, v := range msg.Body { 160 if fd, ok := v.(UnixFD); ok { 161 msg.Body[i] = UnixFDIndex(len(fds)) 162 fds = append(fds, int(fd)) 163 } 164 } 165 if len(fds) != 0 { 166 if !t.hasUnixFDs { 167 return errors.New("dbus: unix fd passing not enabled") 168 } 169 msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds))) 170 oob := syscall.UnixRights(fds...) 171 buf := new(bytes.Buffer) 172 msg.EncodeTo(buf, binary.LittleEndian) 173 n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil) 174 if err != nil { 175 return err 176 } 177 if n != buf.Len() || oobn != len(oob) { 178 return io.ErrShortWrite 179 } 180 } else { 181 if err := msg.EncodeTo(t, binary.LittleEndian); err != nil { 182 return nil 183 } 184 } 185 return nil 186} 187 188func (t *unixTransport) SupportsUnixFDs() bool { 189 return true 190} 191