1// Copyright 2012 Google Inc. All rights reserved.
2// Use of this source code is governed by the Apache 2.0
3// license that can be found in the LICENSE file.
4
5// +build appengine
6
7package socket
8
9import (
10	"fmt"
11	"io"
12	"net"
13	"strconv"
14	"time"
15
16	"github.com/golang/protobuf/proto"
17	"golang.org/x/net/context"
18	"google.golang.org/appengine/internal"
19
20	pb "google.golang.org/appengine/internal/socket"
21)
22
// Dial connects to the address addr on the network protocol.
// The address format is host:port, where host may be a hostname or an IP address.
// Known protocols are "tcp" and "udp".
// The returned connection satisfies net.Conn, and is valid while ctx is valid;
// if the connection is to be used after ctx becomes invalid, invoke SetContext
// with the new context.
//
// Dial is shorthand for DialTimeout with a zero timeout, which means no
// dial-specific timeout is applied on top of ctx.
func Dial(ctx context.Context, protocol, addr string) (*Conn, error) {
	return DialTimeout(ctx, protocol, addr, 0)
}
32
// ipFamilies lists the address families (IPv4 and IPv6) that DialTimeout and
// LookupIP pass to resolve when a hostname has to be looked up.
var ipFamilies = []pb.CreateSocketRequest_SocketFamily{
	pb.CreateSocketRequest_IPv4,
	pb.CreateSocketRequest_IPv6,
}
37
38// DialTimeout is like Dial but takes a timeout.
39// The timeout includes name resolution, if required.
40func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) {
41	dialCtx := ctx // Used for dialing and name resolution, but not stored in the *Conn.
42	if timeout > 0 {
43		var cancel context.CancelFunc
44		dialCtx, cancel = context.WithTimeout(ctx, timeout)
45		defer cancel()
46	}
47
48	host, portStr, err := net.SplitHostPort(addr)
49	if err != nil {
50		return nil, err
51	}
52	port, err := strconv.Atoi(portStr)
53	if err != nil {
54		return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err)
55	}
56
57	var prot pb.CreateSocketRequest_SocketProtocol
58	switch protocol {
59	case "tcp":
60		prot = pb.CreateSocketRequest_TCP
61	case "udp":
62		prot = pb.CreateSocketRequest_UDP
63	default:
64		return nil, fmt.Errorf("socket: unknown protocol %q", protocol)
65	}
66
67	packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host)
68	if err != nil {
69		return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
70	}
71	if len(packedAddrs) == 0 {
72		return nil, fmt.Errorf("no addresses for %q", host)
73	}
74
75	packedAddr := packedAddrs[0] // use first address
76	fam := pb.CreateSocketRequest_IPv4
77	if len(packedAddr) == net.IPv6len {
78		fam = pb.CreateSocketRequest_IPv6
79	}
80
81	req := &pb.CreateSocketRequest{
82		Family:   fam.Enum(),
83		Protocol: prot.Enum(),
84		RemoteIp: &pb.AddressPort{
85			Port:          proto.Int32(int32(port)),
86			PackedAddress: packedAddr,
87		},
88	}
89	if resolved {
90		req.RemoteIp.HostnameHint = &host
91	}
92	res := &pb.CreateSocketReply{}
93	if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil {
94		return nil, err
95	}
96
97	return &Conn{
98		ctx:    ctx,
99		desc:   res.GetSocketDescriptor(),
100		prot:   prot,
101		local:  res.ProxyExternalIp,
102		remote: req.RemoteIp,
103	}, nil
104}
105
106// LookupIP returns the given host's IP addresses.
107func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) {
108	packedAddrs, _, err := resolve(ctx, ipFamilies, host)
109	if err != nil {
110		return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
111	}
112	addrs = make([]net.IP, len(packedAddrs))
113	for i, pa := range packedAddrs {
114		addrs[i] = net.IP(pa)
115	}
116	return addrs, nil
117}
118
119func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) {
120	// Check if it's an IP address.
121	if ip := net.ParseIP(host); ip != nil {
122		if ip := ip.To4(); ip != nil {
123			return [][]byte{ip}, false, nil
124		}
125		return [][]byte{ip}, false, nil
126	}
127
128	req := &pb.ResolveRequest{
129		Name:            &host,
130		AddressFamilies: fams,
131	}
132	res := &pb.ResolveReply{}
133	if err := internal.Call(ctx, "remote_socket", "Resolve", req, res); err != nil {
134		// XXX: need to map to pb.ResolveReply_ErrorCode?
135		return nil, false, err
136	}
137	return res.PackedAddress, true, nil
138}
139
// withDeadline behaves like context.WithDeadline, but treats the zero
// time.Time as "no deadline": in that case parent is returned unchanged
// together with a cancel function that does nothing.
func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
	if !deadline.IsZero() {
		return context.WithDeadline(parent, deadline)
	}
	noop := func() {}
	return parent, noop
}
147
148// Conn represents a socket connection.
149// It implements net.Conn.
150type Conn struct {
151	ctx    context.Context
152	desc   string
153	offset int64
154
155	prot          pb.CreateSocketRequest_SocketProtocol
156	local, remote *pb.AddressPort
157
158	readDeadline, writeDeadline time.Time // optional
159}
160
// SetContext sets the context that is used by this Conn.
// It is usually used only when using a Conn that was created in a different context,
// such as when a connection is created during a warmup request but used while
// servicing a user request.
//
// Subsequent Read, Write, Close and KeepAlive calls issue their RPCs with ctx.
// Read and write deadlines already set on the Conn are left untouched.
func (cn *Conn) SetContext(ctx context.Context) {
	cn.ctx = ctx
}
168
169func (cn *Conn) Read(b []byte) (n int, err error) {
170	const maxRead = 1 << 20
171	if len(b) > maxRead {
172		b = b[:maxRead]
173	}
174
175	req := &pb.ReceiveRequest{
176		SocketDescriptor: &cn.desc,
177		DataSize:         proto.Int32(int32(len(b))),
178	}
179	res := &pb.ReceiveReply{}
180	if !cn.readDeadline.IsZero() {
181		req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds())
182	}
183	ctx, cancel := withDeadline(cn.ctx, cn.readDeadline)
184	defer cancel()
185	if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil {
186		return 0, err
187	}
188	if len(res.Data) == 0 {
189		return 0, io.EOF
190	}
191	if len(res.Data) > len(b) {
192		return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b))
193	}
194	return copy(b, res.Data), nil
195}
196
197func (cn *Conn) Write(b []byte) (n int, err error) {
198	const lim = 1 << 20 // max per chunk
199
200	for n < len(b) {
201		chunk := b[n:]
202		if len(chunk) > lim {
203			chunk = chunk[:lim]
204		}
205
206		req := &pb.SendRequest{
207			SocketDescriptor: &cn.desc,
208			Data:             chunk,
209			StreamOffset:     &cn.offset,
210		}
211		res := &pb.SendReply{}
212		if !cn.writeDeadline.IsZero() {
213			req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds())
214		}
215		ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline)
216		defer cancel()
217		if err = internal.Call(ctx, "remote_socket", "Send", req, res); err != nil {
218			// assume zero bytes were sent in this RPC
219			break
220		}
221		n += int(res.GetDataSent())
222		cn.offset += int64(res.GetDataSent())
223	}
224
225	return
226}
227
228func (cn *Conn) Close() error {
229	req := &pb.CloseRequest{
230		SocketDescriptor: &cn.desc,
231	}
232	res := &pb.CloseReply{}
233	if err := internal.Call(cn.ctx, "remote_socket", "Close", req, res); err != nil {
234		return err
235	}
236	cn.desc = "CLOSED"
237	return nil
238}
239
240func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr {
241	if ap == nil {
242		return nil
243	}
244	switch prot {
245	case pb.CreateSocketRequest_TCP:
246		return &net.TCPAddr{
247			IP:   net.IP(ap.PackedAddress),
248			Port: int(*ap.Port),
249		}
250	case pb.CreateSocketRequest_UDP:
251		return &net.UDPAddr{
252			IP:   net.IP(ap.PackedAddress),
253			Port: int(*ap.Port),
254		}
255	}
256	panic("unknown protocol " + prot.String())
257}
258
259func (cn *Conn) LocalAddr() net.Addr  { return addr(cn.prot, cn.local) }
260func (cn *Conn) RemoteAddr() net.Addr { return addr(cn.prot, cn.remote) }
261
262func (cn *Conn) SetDeadline(t time.Time) error {
263	cn.readDeadline = t
264	cn.writeDeadline = t
265	return nil
266}
267
// SetReadDeadline sets the deadline applied to subsequent Read calls.
// A zero t means Read applies no deadline. It always returns nil.
func (cn *Conn) SetReadDeadline(t time.Time) error {
	cn.readDeadline = t
	return nil
}
272
// SetWriteDeadline sets the deadline applied to subsequent Write calls.
// A zero t means Write applies no deadline. It always returns nil.
func (cn *Conn) SetWriteDeadline(t time.Time) error {
	cn.writeDeadline = t
	return nil
}
277
278// KeepAlive signals that the connection is still in use.
279// It may be called to prevent the socket being closed due to inactivity.
280func (cn *Conn) KeepAlive() error {
281	req := &pb.GetSocketNameRequest{
282		SocketDescriptor: &cn.desc,
283	}
284	res := &pb.GetSocketNameReply{}
285	return internal.Call(cn.ctx, "remote_socket", "GetSocketName", req, res)
286}
287
// init registers the remote_socket service's error code names with the
// internal package.
// NOTE(review): presumably this lets errors from remote_socket RPCs be
// reported with readable code names — confirm against
// internal.RegisterErrorCodeMap.
func init() {
	internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name)
}
291