1package winio 2 3import ( 4 "fmt" 5 "io" 6 "net" 7 "os" 8 "syscall" 9 "time" 10 "unsafe" 11 12 "github.com/Microsoft/go-winio/pkg/guid" 13) 14 15//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind 16 17const ( 18 afHvSock = 34 // AF_HYPERV 19 20 socketError = ^uintptr(0) 21) 22 23// An HvsockAddr is an address for a AF_HYPERV socket. 24type HvsockAddr struct { 25 VMID guid.GUID 26 ServiceID guid.GUID 27} 28 29type rawHvsockAddr struct { 30 Family uint16 31 _ uint16 32 VMID guid.GUID 33 ServiceID guid.GUID 34} 35 36// Network returns the address's network name, "hvsock". 37func (addr *HvsockAddr) Network() string { 38 return "hvsock" 39} 40 41func (addr *HvsockAddr) String() string { 42 return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID) 43} 44 45// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. 46func VsockServiceID(port uint32) guid.GUID { 47 g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3") 48 g.Data1 = port 49 return g 50} 51 52func (addr *HvsockAddr) raw() rawHvsockAddr { 53 return rawHvsockAddr{ 54 Family: afHvSock, 55 VMID: addr.VMID, 56 ServiceID: addr.ServiceID, 57 } 58} 59 60func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { 61 addr.VMID = raw.VMID 62 addr.ServiceID = raw.ServiceID 63} 64 65// HvsockListener is a socket listener for the AF_HYPERV address family. 66type HvsockListener struct { 67 sock *win32File 68 addr HvsockAddr 69} 70 71// HvsockConn is a connected socket of the AF_HYPERV address family. 72type HvsockConn struct { 73 sock *win32File 74 local, remote HvsockAddr 75} 76 77func newHvSocket() (*win32File, error) { 78 fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1) 79 if err != nil { 80 return nil, os.NewSyscallError("socket", err) 81 } 82 f, err := makeWin32File(fd) 83 if err != nil { 84 syscall.Close(fd) 85 return nil, err 86 } 87 f.socket = true 88 return f, nil 89} 90 91// ListenHvsock listens for connections on the specified hvsock address. 92func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { 93 l := &HvsockListener{addr: *addr} 94 sock, err := newHvSocket() 95 if err != nil { 96 return nil, l.opErr("listen", err) 97 } 98 sa := addr.raw() 99 err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa))) 100 if err != nil { 101 return nil, l.opErr("listen", os.NewSyscallError("socket", err)) 102 } 103 err = syscall.Listen(sock.handle, 16) 104 if err != nil { 105 return nil, l.opErr("listen", os.NewSyscallError("listen", err)) 106 } 107 return &HvsockListener{sock: sock, addr: *addr}, nil 108} 109 110func (l *HvsockListener) opErr(op string, err error) error { 111 return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err} 112} 113 114// Addr returns the listener's network address. 115func (l *HvsockListener) Addr() net.Addr { 116 return &l.addr 117} 118 119// Accept waits for the next connection and returns it. 120func (l *HvsockListener) Accept() (_ net.Conn, err error) { 121 sock, err := newHvSocket() 122 if err != nil { 123 return nil, l.opErr("accept", err) 124 } 125 defer func() { 126 if sock != nil { 127 sock.Close() 128 } 129 }() 130 c, err := l.sock.prepareIo() 131 if err != nil { 132 return nil, l.opErr("accept", err) 133 } 134 defer l.sock.wg.Done() 135 136 // AcceptEx, per documentation, requires an extra 16 bytes per address. 137 const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) 138 var addrbuf [addrlen * 2]byte 139 140 var bytes uint32 141 err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o) 142 _, err = l.sock.asyncIo(c, nil, bytes, err) 143 if err != nil { 144 return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) 145 } 146 conn := &HvsockConn{ 147 sock: sock, 148 } 149 conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) 150 conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) 151 sock = nil 152 return conn, nil 153} 154 155// Close closes the listener, causing any pending Accept calls to fail. 156func (l *HvsockListener) Close() error { 157 return l.sock.Close() 158} 159 160/* Need to finish ConnectEx handling 161func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) { 162 sock, err := newHvSocket() 163 if err != nil { 164 return nil, err 165 } 166 defer func() { 167 if sock != nil { 168 sock.Close() 169 } 170 }() 171 c, err := sock.prepareIo() 172 if err != nil { 173 return nil, err 174 } 175 defer sock.wg.Done() 176 var bytes uint32 177 err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o) 178 _, err = sock.asyncIo(ctx, c, nil, bytes, err) 179 if err != nil { 180 return nil, err 181 } 182 conn := &HvsockConn{ 183 sock: sock, 184 remote: *addr, 185 } 186 sock = nil 187 return conn, nil 188} 189*/ 190 191func (conn *HvsockConn) opErr(op string, err error) error { 192 return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} 193} 194 195func (conn *HvsockConn) Read(b []byte) (int, error) { 196 c, err := conn.sock.prepareIo() 197 if err != nil { 198 return 0, conn.opErr("read", err) 199 } 200 defer conn.sock.wg.Done() 201 buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} 202 var flags, bytes uint32 203 err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) 204 n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) 205 if err != nil { 206 if _, ok := err.(syscall.Errno); ok { 207 err = os.NewSyscallError("wsarecv", err) 208 } 209 return 0, conn.opErr("read", err) 210 } else if n == 0 { 211 err = io.EOF 212 } 213 return n, err 214} 215 216func (conn *HvsockConn) Write(b []byte) (int, error) { 217 t := 0 218 for len(b) != 0 { 219 n, err := conn.write(b) 220 if err != nil { 221 return t + n, err 222 } 223 t += n 224 b = b[n:] 225 } 226 return t, nil 227} 228 229func (conn *HvsockConn) write(b []byte) (int, error) { 230 c, err := conn.sock.prepareIo() 231 if err != nil { 232 return 0, conn.opErr("write", err) 233 } 234 defer conn.sock.wg.Done() 235 buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} 236 var bytes uint32 237 err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) 238 n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) 239 if err != nil { 240 if _, ok := err.(syscall.Errno); ok { 241 err = os.NewSyscallError("wsasend", err) 242 } 243 return 0, conn.opErr("write", err) 244 } 245 return n, err 246} 247 248// Close closes the socket connection, failing any pending read or write calls. 249func (conn *HvsockConn) Close() error { 250 return conn.sock.Close() 251} 252 253func (conn *HvsockConn) shutdown(how int) error { 254 err := syscall.Shutdown(conn.sock.handle, syscall.SHUT_RD) 255 if err != nil { 256 return os.NewSyscallError("shutdown", err) 257 } 258 return nil 259} 260 261// CloseRead shuts down the read end of the socket. 262func (conn *HvsockConn) CloseRead() error { 263 err := conn.shutdown(syscall.SHUT_RD) 264 if err != nil { 265 return conn.opErr("close", err) 266 } 267 return nil 268} 269 270// CloseWrite shuts down the write end of the socket, notifying the other endpoint that 271// no more data will be written. 272func (conn *HvsockConn) CloseWrite() error { 273 err := conn.shutdown(syscall.SHUT_WR) 274 if err != nil { 275 return conn.opErr("close", err) 276 } 277 return nil 278} 279 280// LocalAddr returns the local address of the connection. 281func (conn *HvsockConn) LocalAddr() net.Addr { 282 return &conn.local 283} 284 285// RemoteAddr returns the remote address of the connection. 286func (conn *HvsockConn) RemoteAddr() net.Addr { 287 return &conn.remote 288} 289 290// SetDeadline implements the net.Conn SetDeadline method. 291func (conn *HvsockConn) SetDeadline(t time.Time) error { 292 conn.SetReadDeadline(t) 293 conn.SetWriteDeadline(t) 294 return nil 295} 296 297// SetReadDeadline implements the net.Conn SetReadDeadline method. 298func (conn *HvsockConn) SetReadDeadline(t time.Time) error { 299 return conn.sock.SetReadDeadline(t) 300} 301 302// SetWriteDeadline implements the net.Conn SetWriteDeadline method. 303func (conn *HvsockConn) SetWriteDeadline(t time.Time) error { 304 return conn.sock.SetWriteDeadline(t) 305} 306