package grpc

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"

	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
	"github.com/hashicorp/consul/agent/metadata"
	"github.com/hashicorp/consul/agent/pool"
	"github.com/hashicorp/consul/sdk/freeport"
	"github.com/hashicorp/consul/tlsutil"
)

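// testServer wraps a gRPC server running behind a fakeRPCListener, along
// with enough metadata to identify it in tests.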
type testServer struct {
	addr     net.Addr
	name     string
	dc       string
	shutdown func()
	rpc      *fakeRPCListener
}

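// Metadata returns a metadata.Server describing this test server.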
func (s testServer) Metadata() *metadata.Server {
	return &metadata.Server{
		ID:         s.name,
		Name:       s.name + "." + s.dc,
		ShortName:  s.name,
		Datacenter: s.dc,
		Addr:       s.addr,
		UseTLS:     s.rpc.tlsConf != nil,
	}
}

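// newTestServer starts a gRPC server for the Simple test service on a random
// free port, fronted by a fakeRPCListener that understands the agent's RPC
// multiplexing. Pass a non-nil tlsConf to exercise the TLS paths. The
// returned testServer's shutdown func stops both the listener and the gRPC
// server.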
func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer {
	addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
	handler := NewHandler(addr, func(server *grpc.Server) {
		testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
	})

	ports := freeport.MustTake(1)
	t.Cleanup(func() {
		freeport.Return(ports)
	})

	lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
	require.NoError(t, err)

	rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf}

	g := errgroup.Group{}
	g.Go(func() error {
		if err := rpc.listen(lis); err != nil {
			return fmt.Errorf("fake rpc listen error: %w", err)
		}
		return nil
	})
	g.Go(func() error {
		if err := handler.Run(); err != nil {
			return fmt.Errorf("grpc server error: %w", err)
		}
		return nil
	})
	return testServer{
		addr: lis.Addr(),
		name: name,
		dc:   dc,
		rpc:  rpc,
		shutdown: func() {
			// Mark the listener as shutting down before closing it, so the
			// accept loop treats the resulting error as a clean exit. The
			// store is atomic because the accept loop reads the flag from
			// another goroutine.
			atomic.StoreInt32(&rpc.shutdown, 1)
			if err := lis.Close(); err != nil {
				t.Logf("listener closed with error: %v", err)
			}
			if err := handler.Shutdown(); err != nil {
				t.Logf("grpc server shutdown: %v", err)
			}
			if err := g.Wait(); err != nil {
				t.Log(err)
			}
		},
	}
}

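// simple is a fake implementation of testservice.SimpleServer that reports
// the name and datacenter it was created with.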
type simple struct {
	name string
	dc   string
}

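// Flow streams a response roughly every millisecond until the client cancels
// the stream.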
func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error {
	for flow.Context().Err() == nil {
		resp := &testservice.Resp{ServerName: s.name, Datacenter: s.dc}
		if err := flow.Send(resp); err != nil {
			return err
		}
		time.Sleep(time.Millisecond)
	}
	return nil
}

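// Something returns a single response identifying the server that handled
// the request.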
func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) {
	return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil
}

// fakeRPCListener mimics agent/consul.Server.listen to handle the leading
// RPCType byte. Ideally the RPC handling logic would be extracted from Server
// so the real implementation could be used here; until then, because that
// logic lives in agent/consul and can't easily be reused, we fake it.
type fakeRPCListener struct {
	t       *testing.T
	handler *Handler
	tlsConf *tlsutil.Configurator

	// shutdown is set to 1 (atomically) when the listener is being torn down,
	// so that accept errors during shutdown are not reported as failures.
	shutdown int32

	tlsConnEstablished  int32
	alpnConnEstablished int32
}

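// listen accepts connections until the listener is closed, handling each
// connection on its own goroutine. Accept errors are suppressed once shutdown
// has begun.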
func (f *fakeRPCListener) listen(listener net.Listener) error {
	for {
		conn, err := listener.Accept()
		if err != nil {
			if atomic.LoadInt32(&f.shutdown) == 1 {
				return nil
			}
			return err
		}

		go f.handleConn(conn)
	}
}

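// handleConn dispatches a new connection the way agent/consul.Server does:
// native TLS connections are detected by peeking at the first byte, and
// everything else is routed by the leading RPCType byte.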
func (f *fakeRPCListener) handleConn(conn net.Conn) {
	if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() {
		// Check whether this is actually a native TLS connection, rather
		// than TLS multiplexed onto the old "type-byte" system.
		peekedConn, nativeTLS, err := pool.PeekForTLS(conn)
		if err != nil {
			if err != io.EOF {
				fmt.Printf("ERROR: failed to read first byte: %v\n", err)
			}
			conn.Close()
			return
		}

		if nativeTLS {
			f.handleNativeTLSConn(peekedConn)
			return
		}
		conn = peekedConn
	}

	buf := make([]byte, 1)

	if _, err := conn.Read(buf); err != nil {
		if err != io.EOF {
			fmt.Printf("ERROR: failed to read type byte: %v\n", err)
		}
		conn.Close()
		return
	}
	typ := pool.RPCType(buf[0])

	switch typ {

	case pool.RPCGRPC:
		f.handler.Handle(conn)
		return

	case pool.RPCTLS:
		// Occasionally a test client connects to an rpc listener that was
		// created as part of another test, despite none of the tests running
		// in parallel. Possibly some strange gRPC behaviour; the cause is not
		// understood. Guard against that case here rather than panicking on
		// a nil tlsConf.
		if f.tlsConf == nil {
			fmt.Println("ERROR: tls is not configured")
			conn.Close()
			return
		}

		atomic.AddInt32(&f.tlsConnEstablished, 1)
		conn = tls.Server(conn, f.tlsConf.IncomingRPCConfig())
		f.handleConn(conn)

	default:
		fmt.Println("ERROR: unexpected byte", typ)
		conn.Close()
	}
}
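
// dialTypeByte sketches the client side of the type-byte scheme handled
// above. It is illustrative only and not used by the tests in this file:
// the helper name is hypothetical. A client writes a single pool.RPCGRPC
// byte and then carries gRPC traffic on the same connection.
func dialTypeByte(addr string) (net.Conn, error) {
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}
	// The first byte tells the multiplexed listener which protocol follows.
	if _, err := conn.Write([]byte{byte(pool.RPCGRPC)}); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}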
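// handleNativeTLSConn terminates a native TLS connection and dispatches it
// according to the protocol negotiated via ALPN.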
func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) {
	tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos)
	tlsConn := tls.Server(conn, tlscfg)

	// Force the handshake to complete so the negotiated protocol is available.
	if err := tlsConn.Handshake(); err != nil {
		fmt.Printf("ERROR: TLS handshake failed: %v\n", err)
		conn.Close()
		return
	}

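	// Reset the read deadline; what is expected next depends on the
	// negotiated protocol.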
	conn.SetReadDeadline(time.Time{})

	var (
		cs        = tlsConn.ConnectionState()
		nextProto = cs.NegotiatedProtocol
	)

	switch nextProto {
	case pool.ALPN_RPCGRPC:
		atomic.AddInt32(&f.alpnConnEstablished, 1)
		f.handler.Handle(tlsConn)

	default:
		fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto)
		conn.Close()
	}
}
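
// dialALPN sketches the client side of the native TLS path: instead of a
// leading type byte, the desired protocol is negotiated via ALPN. The helper
// is illustrative only, and the InsecureSkipVerify setting is a stand-in for
// the certificate setup a real test would perform via tlsutil.
func dialALPN(addr string) (net.Conn, error) {
	cfg := &tls.Config{
		NextProtos:         []string{pool.ALPN_RPCGRPC},
		InsecureSkipVerify: true, // illustration only; never do this outside tests
	}
	return tls.Dial("tcp", addr, cfg)
}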