1package winio
2
3import (
4	"fmt"
5	"io"
6	"net"
7	"os"
8	"syscall"
9	"time"
10	"unsafe"
11
12	"github.com/Microsoft/go-winio/pkg/guid"
13)
14
//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind

// Winsock-level constants used by the hvsock implementation.
const (
	// socketError is the return value the generated bind wrapper treats as
	// failure (see the failretval clause above): every bit set, i.e. -1
	// reinterpreted as an unsigned word.
	socketError = ^uintptr(0)

	// afHvSock is the AF_HYPERV address family number. It is passed to
	// socket() and stored in rawHvsockAddr.Family.
	afHvSock = 34
)
22
23// An HvsockAddr is an address for a AF_HYPERV socket.
24type HvsockAddr struct {
25	VMID      guid.GUID
26	ServiceID guid.GUID
27}
28
29type rawHvsockAddr struct {
30	Family    uint16
31	_         uint16
32	VMID      guid.GUID
33	ServiceID guid.GUID
34}
35
36// Network returns the address's network name, "hvsock".
37func (addr *HvsockAddr) Network() string {
38	return "hvsock"
39}
40
41func (addr *HvsockAddr) String() string {
42	return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
43}
44
45// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
46func VsockServiceID(port uint32) guid.GUID {
47	g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3")
48	g.Data1 = port
49	return g
50}
51
52func (addr *HvsockAddr) raw() rawHvsockAddr {
53	return rawHvsockAddr{
54		Family:    afHvSock,
55		VMID:      addr.VMID,
56		ServiceID: addr.ServiceID,
57	}
58}
59
60func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
61	addr.VMID = raw.VMID
62	addr.ServiceID = raw.ServiceID
63}
64
65// HvsockListener is a socket listener for the AF_HYPERV address family.
66type HvsockListener struct {
67	sock *win32File
68	addr HvsockAddr
69}
70
71// HvsockConn is a connected socket of the AF_HYPERV address family.
72type HvsockConn struct {
73	sock          *win32File
74	local, remote HvsockAddr
75}
76
77func newHvSocket() (*win32File, error) {
78	fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1)
79	if err != nil {
80		return nil, os.NewSyscallError("socket", err)
81	}
82	f, err := makeWin32File(fd)
83	if err != nil {
84		syscall.Close(fd)
85		return nil, err
86	}
87	f.socket = true
88	return f, nil
89}
90
91// ListenHvsock listens for connections on the specified hvsock address.
92func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
93	l := &HvsockListener{addr: *addr}
94	sock, err := newHvSocket()
95	if err != nil {
96		return nil, l.opErr("listen", err)
97	}
98	sa := addr.raw()
99	err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa)))
100	if err != nil {
101		return nil, l.opErr("listen", os.NewSyscallError("socket", err))
102	}
103	err = syscall.Listen(sock.handle, 16)
104	if err != nil {
105		return nil, l.opErr("listen", os.NewSyscallError("listen", err))
106	}
107	return &HvsockListener{sock: sock, addr: *addr}, nil
108}
109
110func (l *HvsockListener) opErr(op string, err error) error {
111	return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
112}
113
114// Addr returns the listener's network address.
115func (l *HvsockListener) Addr() net.Addr {
116	return &l.addr
117}
118
119// Accept waits for the next connection and returns it.
120func (l *HvsockListener) Accept() (_ net.Conn, err error) {
121	sock, err := newHvSocket()
122	if err != nil {
123		return nil, l.opErr("accept", err)
124	}
125	defer func() {
126		if sock != nil {
127			sock.Close()
128		}
129	}()
130	c, err := l.sock.prepareIo()
131	if err != nil {
132		return nil, l.opErr("accept", err)
133	}
134	defer l.sock.wg.Done()
135
136	// AcceptEx, per documentation, requires an extra 16 bytes per address.
137	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
138	var addrbuf [addrlen * 2]byte
139
140	var bytes uint32
141	err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o)
142	_, err = l.sock.asyncIo(c, nil, bytes, err)
143	if err != nil {
144		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
145	}
146	conn := &HvsockConn{
147		sock: sock,
148	}
149	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
150	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
151	sock = nil
152	return conn, nil
153}
154
155// Close closes the listener, causing any pending Accept calls to fail.
156func (l *HvsockListener) Close() error {
157	return l.sock.Close()
158}
159
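/* TODO: DialHvsock is disabled until ConnectEx handling is finished. The draft
   below does not compile as written: `sa` is never declared (presumably it
   should be addr.raw()), and the context and windows packages it uses are
   not imported by this file.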
161func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) {
162	sock, err := newHvSocket()
163	if err != nil {
164		return nil, err
165	}
166	defer func() {
167		if sock != nil {
168			sock.Close()
169		}
170	}()
171	c, err := sock.prepareIo()
172	if err != nil {
173		return nil, err
174	}
175	defer sock.wg.Done()
176	var bytes uint32
177	err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o)
178	_, err = sock.asyncIo(ctx, c, nil, bytes, err)
179	if err != nil {
180		return nil, err
181	}
182	conn := &HvsockConn{
183		sock:   sock,
184		remote: *addr,
185	}
186	sock = nil
187	return conn, nil
188}
189*/
190
191func (conn *HvsockConn) opErr(op string, err error) error {
192	return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
193}
194
195func (conn *HvsockConn) Read(b []byte) (int, error) {
196	c, err := conn.sock.prepareIo()
197	if err != nil {
198		return 0, conn.opErr("read", err)
199	}
200	defer conn.sock.wg.Done()
201	buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
202	var flags, bytes uint32
203	err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
204	n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
205	if err != nil {
206		if _, ok := err.(syscall.Errno); ok {
207			err = os.NewSyscallError("wsarecv", err)
208		}
209		return 0, conn.opErr("read", err)
210	} else if n == 0 {
211		err = io.EOF
212	}
213	return n, err
214}
215
216func (conn *HvsockConn) Write(b []byte) (int, error) {
217	t := 0
218	for len(b) != 0 {
219		n, err := conn.write(b)
220		if err != nil {
221			return t + n, err
222		}
223		t += n
224		b = b[n:]
225	}
226	return t, nil
227}
228
229func (conn *HvsockConn) write(b []byte) (int, error) {
230	c, err := conn.sock.prepareIo()
231	if err != nil {
232		return 0, conn.opErr("write", err)
233	}
234	defer conn.sock.wg.Done()
235	buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
236	var bytes uint32
237	err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
238	n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
239	if err != nil {
240		if _, ok := err.(syscall.Errno); ok {
241			err = os.NewSyscallError("wsasend", err)
242		}
243		return 0, conn.opErr("write", err)
244	}
245	return n, err
246}
247
248// Close closes the socket connection, failing any pending read or write calls.
249func (conn *HvsockConn) Close() error {
250	return conn.sock.Close()
251}
252
253func (conn *HvsockConn) shutdown(how int) error {
254	err := syscall.Shutdown(conn.sock.handle, syscall.SHUT_RD)
255	if err != nil {
256		return os.NewSyscallError("shutdown", err)
257	}
258	return nil
259}
260
261// CloseRead shuts down the read end of the socket.
262func (conn *HvsockConn) CloseRead() error {
263	err := conn.shutdown(syscall.SHUT_RD)
264	if err != nil {
265		return conn.opErr("close", err)
266	}
267	return nil
268}
269
270// CloseWrite shuts down the write end of the socket, notifying the other endpoint that
271// no more data will be written.
272func (conn *HvsockConn) CloseWrite() error {
273	err := conn.shutdown(syscall.SHUT_WR)
274	if err != nil {
275		return conn.opErr("close", err)
276	}
277	return nil
278}
279
280// LocalAddr returns the local address of the connection.
281func (conn *HvsockConn) LocalAddr() net.Addr {
282	return &conn.local
283}
284
285// RemoteAddr returns the remote address of the connection.
286func (conn *HvsockConn) RemoteAddr() net.Addr {
287	return &conn.remote
288}
289
290// SetDeadline implements the net.Conn SetDeadline method.
291func (conn *HvsockConn) SetDeadline(t time.Time) error {
292	conn.SetReadDeadline(t)
293	conn.SetWriteDeadline(t)
294	return nil
295}
296
297// SetReadDeadline implements the net.Conn SetReadDeadline method.
298func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
299	return conn.sock.SetReadDeadline(t)
300}
301
302// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
303func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
304	return conn.sock.SetWriteDeadline(t)
305}
306