1package grpc 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "net" 9 "strconv" 10 "sync/atomic" 11 "testing" 12 "time" 13 14 "github.com/stretchr/testify/require" 15 "golang.org/x/sync/errgroup" 16 "google.golang.org/grpc" 17 18 "github.com/hashicorp/consul/agent/grpc/internal/testservice" 19 "github.com/hashicorp/consul/agent/metadata" 20 "github.com/hashicorp/consul/agent/pool" 21 "github.com/hashicorp/consul/sdk/freeport" 22 "github.com/hashicorp/consul/tlsutil" 23) 24 25type testServer struct { 26 addr net.Addr 27 name string 28 dc string 29 shutdown func() 30 rpc *fakeRPCListener 31} 32 33func (s testServer) Metadata() *metadata.Server { 34 return &metadata.Server{ 35 ID: s.name, 36 Name: s.name + "." + s.dc, 37 ShortName: s.name, 38 Datacenter: s.dc, 39 Addr: s.addr, 40 UseTLS: s.rpc.tlsConf != nil, 41 } 42} 43 44func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer { 45 addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} 46 handler := NewHandler(addr, func(server *grpc.Server) { 47 testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc}) 48 }) 49 50 ports := freeport.MustTake(1) 51 t.Cleanup(func() { 52 freeport.Return(ports) 53 }) 54 55 lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0]))) 56 require.NoError(t, err) 57 58 rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf} 59 60 g := errgroup.Group{} 61 g.Go(func() error { 62 if err := rpc.listen(lis); err != nil { 63 return fmt.Errorf("fake rpc listen error: %w", err) 64 } 65 return nil 66 }) 67 g.Go(func() error { 68 if err := handler.Run(); err != nil { 69 return fmt.Errorf("grpc server error: %w", err) 70 } 71 return nil 72 }) 73 return testServer{ 74 addr: lis.Addr(), 75 name: name, 76 dc: dc, 77 rpc: rpc, 78 shutdown: func() { 79 rpc.shutdown = true 80 if err := lis.Close(); err != nil { 81 t.Logf("listener closed with error: %v", err) 82 } 83 if err := handler.Shutdown(); err != nil { 84 t.Logf("grpc server shutdown: %v", err) 85 } 86 if err := g.Wait(); err != nil { 87 t.Log(err) 88 } 89 }, 90 } 91} 92 93type simple struct { 94 name string 95 dc string 96} 97 98func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error { 99 for flow.Context().Err() == nil { 100 resp := &testservice.Resp{ServerName: "one", Datacenter: s.dc} 101 if err := flow.Send(resp); err != nil { 102 return err 103 } 104 time.Sleep(time.Millisecond) 105 } 106 return nil 107} 108 109func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { 110 return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil 111} 112 113// fakeRPCListener mimics agent/consul.Server.listen to handle the RPCType byte. 114// In the future we should be able to refactor Server and extract this RPC 115// handling logic so that we don't need to use a fake. 116// For now, since this logic is in agent/consul, we can't easily use Server.listen 117// so we fake it. 118type fakeRPCListener struct { 119 t *testing.T 120 handler *Handler 121 shutdown bool 122 tlsConf *tlsutil.Configurator 123 tlsConnEstablished int32 124 alpnConnEstablished int32 125} 126 127func (f *fakeRPCListener) listen(listener net.Listener) error { 128 for { 129 conn, err := listener.Accept() 130 if err != nil { 131 if f.shutdown { 132 return nil 133 } 134 return err 135 } 136 137 go f.handleConn(conn) 138 } 139} 140 141func (f *fakeRPCListener) handleConn(conn net.Conn) { 142 if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() { 143 // See if actually this is native TLS multiplexed onto the old 144 // "type-byte" system. 145 146 peekedConn, nativeTLS, err := pool.PeekForTLS(conn) 147 if err != nil { 148 if err != io.EOF { 149 fmt.Printf("ERROR: failed to read first byte: %v\n", err) 150 } 151 conn.Close() 152 return 153 } 154 155 if nativeTLS { 156 f.handleNativeTLSConn(peekedConn) 157 return 158 } 159 conn = peekedConn 160 } 161 162 buf := make([]byte, 1) 163 164 if _, err := conn.Read(buf); err != nil { 165 if err != io.EOF { 166 fmt.Println("ERROR", err.Error()) 167 } 168 conn.Close() 169 return 170 } 171 typ := pool.RPCType(buf[0]) 172 173 switch typ { 174 175 case pool.RPCGRPC: 176 f.handler.Handle(conn) 177 return 178 179 case pool.RPCTLS: 180 // occasionally we see a test client connecting to an rpc listener that 181 // was created as part of another test, despite none of the tests running 182 // in parallel. 183 // Maybe some strange grpc behaviour? I'm not sure. 184 if f.tlsConf == nil { 185 fmt.Println("ERROR: tls is not configured") 186 conn.Close() 187 return 188 } 189 190 atomic.AddInt32(&f.tlsConnEstablished, 1) 191 conn = tls.Server(conn, f.tlsConf.IncomingRPCConfig()) 192 f.handleConn(conn) 193 194 default: 195 fmt.Println("ERROR: unexpected byte", typ) 196 conn.Close() 197 } 198} 199 200func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) { 201 tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos) 202 tlsConn := tls.Server(conn, tlscfg) 203 204 // Force the handshake to conclude. 205 if err := tlsConn.Handshake(); err != nil { 206 fmt.Printf("ERROR: TLS handshake failed: %v", err) 207 conn.Close() 208 return 209 } 210 211 conn.SetReadDeadline(time.Time{}) 212 213 var ( 214 cs = tlsConn.ConnectionState() 215 nextProto = cs.NegotiatedProtocol 216 ) 217 218 switch nextProto { 219 case pool.ALPN_RPCGRPC: 220 atomic.AddInt32(&f.alpnConnEstablished, 1) 221 f.handler.Handle(tlsConn) 222 223 default: 224 fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto) 225 conn.Close() 226 } 227} 228