1package main 2 3import ( 4 "crypto/tls" 5 "flag" 6 "fmt" 7 "io" 8 "log" 9 "net/http" 10 "os" 11 "sync" 12 13 google_protobuf "github.com/golang/protobuf/ptypes/empty" 14 "github.com/improbable-eng/grpc-web/go/grpcweb" 15 testproto "github.com/improbable-eng/grpc-web/integration_test/go/_proto/improbable/grpcweb/test" 16 "golang.org/x/net/context" 17 "google.golang.org/grpc" 18 "google.golang.org/grpc/codes" 19 "google.golang.org/grpc/grpclog" 20 "google.golang.org/grpc/metadata" 21) 22 23var ( 24 http1Port = flag.Int("http1_port", 9090, "Port to listen with HTTP1.1 with TLS on.") 25 http1EmptyPort = flag.Int("http1_empty_port", 9095, "Port to listen with HTTP1.1 with TLS on with a grpc server that has no services.") 26 http2Port = flag.Int("http2_port", 9100, "Port to listen with HTTP2 with TLS on.") 27 http2EmptyPort = flag.Int("http2_empty_port", 9105, "Port to listen with HTTP2 with TLS on with a grpc server that has no services.") 28 tlsCertFilePath = flag.String("tls_cert_file", "../../../misc/localhost.crt", "Path to the CRT/PEM file.") 29 tlsKeyFilePath = flag.String("tls_key_file", "../../../misc/localhost.key", "Path to the private key file.") 30) 31 32func main() { 33 flag.Parse() 34 35 grpcServer := grpc.NewServer() 36 testServer := &testSrv{ 37 streamsMutex: &sync.Mutex{}, 38 streams: map[string]chan bool{}, 39 } 40 testproto.RegisterTestServiceServer(grpcServer, testServer) 41 testproto.RegisterTestUtilServiceServer(grpcServer, testServer) 42 grpclog.SetLogger(log.New(os.Stdout, "testserver: ", log.LstdFlags)) 43 44 websocketOriginFunc := grpcweb.WithWebsocketOriginFunc(func(req *http.Request) bool { 45 return true 46 }) 47 httpOriginFunc := grpcweb.WithOriginFunc(func(origin string) bool { 48 return true 49 }) 50 51 wrappedServer := grpcweb.WrapServer( 52 grpcServer, 53 grpcweb.WithWebsockets(true), 54 httpOriginFunc, 55 websocketOriginFunc, 56 ) 57 handler := func(resp http.ResponseWriter, req *http.Request) { 58 wrappedServer.ServeHTTP(resp, req) 59 } 60 61 emptyGrpcServer := grpc.NewServer() 62 emptyWrappedServer := grpcweb.WrapServer( 63 emptyGrpcServer, 64 grpcweb.WithWebsockets(true), 65 grpcweb.WithCorsForRegisteredEndpointsOnly(false), 66 httpOriginFunc, 67 websocketOriginFunc, 68 ) 69 emptyHandler := func(resp http.ResponseWriter, req *http.Request) { 70 emptyWrappedServer.ServeHTTP(resp, req) 71 } 72 73 http1Server := http.Server{ 74 Addr: fmt.Sprintf(":%d", *http1Port), 75 Handler: http.HandlerFunc(handler), 76 } 77 http1Server.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} // Disable HTTP2 78 http1EmptyServer := http.Server{ 79 Addr: fmt.Sprintf(":%d", *http1EmptyPort), 80 Handler: http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 81 emptyHandler(res, req) 82 }), 83 } 84 http1EmptyServer.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} // Disable HTTP2 85 86 http2Server := http.Server{ 87 Addr: fmt.Sprintf(":%d", *http2Port), 88 Handler: http.HandlerFunc(handler), 89 } 90 http2EmptyServer := http.Server{ 91 Addr: fmt.Sprintf(":%d", *http2EmptyPort), 92 Handler: http.HandlerFunc(emptyHandler), 93 } 94 95 grpclog.Printf("Starting servers. http1.1 port: %d, http1.1 empty port: %d, http2 port: %d, http2 empty port: %d", *http1Port, *http1EmptyPort, *http2Port, *http2EmptyPort) 96 97 // Start the empty Http1.1 server 98 go func() { 99 if err := http1EmptyServer.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil { 100 grpclog.Fatalf("failed starting http1.1 empty server: %v", err) 101 } 102 }() 103 104 // Start the Http1.1 server 105 go func() { 106 if err := http1Server.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil { 107 grpclog.Fatalf("failed starting http1.1 server: %v", err) 108 } 109 }() 110 111 // Start the empty Http2 server 112 go func() { 113 if err := http2EmptyServer.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil { 114 grpclog.Fatalf("failed starting http2 empty server: %v", err) 115 } 116 }() 117 118 // Start the Http2 server 119 if err := http2Server.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil { 120 grpclog.Fatalf("failed starting http2 server: %v", err) 121 } 122} 123 124type testSrv struct { 125 streamsMutex *sync.Mutex 126 streams map[string]chan bool 127} 128 129func (s *testSrv) PingEmpty(ctx context.Context, _ *google_protobuf.Empty) (*testproto.PingResponse, error) { 130 grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 131 grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 132 return &testproto.PingResponse{Value: "foobar"}, nil 133} 134 135func (s *testSrv) Ping(ctx context.Context, ping *testproto.PingRequest) (*testproto.PingResponse, error) { 136 if ping.GetCheckMetadata() { 137 md, ok := metadata.FromIncomingContext(ctx) 138 if !ok || md["headertestkey1"][0] != "ClientValue1" || md["headertestkey2"][0] != "ClientValue2" { 139 return nil, grpc.Errorf(codes.InvalidArgument, "Metadata was invalid") 140 } 141 } 142 if ping.GetSendHeaders() { 143 grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 144 } 145 if ping.GetSendTrailers() { 146 grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 147 } 148 return &testproto.PingResponse{Value: ping.Value, Counter: 252}, nil 149} 150 151func (s *testSrv) Echo(ctx context.Context, text *testproto.TextMessage) (*testproto.TextMessage, error) { 152 if text.GetSendHeaders() { 153 grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 154 } 155 if text.GetSendTrailers() { 156 grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 157 } 158 return text, nil 159} 160 161func (s *testSrv) PingError(ctx context.Context, ping *testproto.PingRequest) (*google_protobuf.Empty, error) { 162 if ping.GetSendHeaders() { 163 grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 164 } 165 if ping.GetSendTrailers() { 166 grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 167 } 168 var msg = "" 169 if ping.FailureType == testproto.PingRequest_CODE { 170 msg = "Intentionally returning error for PingError" 171 } 172 return nil, grpc.Errorf(codes.Code(ping.ErrorCodeReturned), msg) 173} 174 175func (s *testSrv) ContinueStream(ctx context.Context, req *testproto.ContinueStreamRequest) (*google_protobuf.Empty, error) { 176 s.streamsMutex.Lock() 177 defer s.streamsMutex.Unlock() 178 channel, ok := s.streams[req.GetStreamIdentifier()] 179 if !ok { 180 return nil, grpc.Errorf(codes.NotFound, "stream identifier not found") 181 } 182 channel <- true 183 return &google_protobuf.Empty{}, nil 184} 185 186func (s *testSrv) CheckStreamClosed(ctx context.Context, req *testproto.CheckStreamClosedRequest) (*testproto.CheckStreamClosedResponse, error) { 187 s.streamsMutex.Lock() 188 defer s.streamsMutex.Unlock() 189 _, ok := s.streams[req.GetStreamIdentifier()] 190 if !ok { 191 return &testproto.CheckStreamClosedResponse{ 192 Closed: true, 193 }, nil 194 } 195 return &testproto.CheckStreamClosedResponse{ 196 Closed: false, 197 }, nil 198} 199 200func (s *testSrv) PingList(ping *testproto.PingRequest, stream testproto.TestService_PingListServer) error { 201 if ping.GetSendHeaders() { 202 stream.SendHeader(metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 203 } 204 if ping.GetSendTrailers() { 205 stream.SetTrailer(metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 206 } 207 208 var channel chan bool 209 useChannel := ping.GetStreamIdentifier() != "" 210 if useChannel { 211 channel = make(chan bool) 212 s.streamsMutex.Lock() 213 s.streams[ping.GetStreamIdentifier()] = channel 214 s.streamsMutex.Unlock() 215 216 defer func() { 217 // When this stream has ended 218 s.streamsMutex.Lock() 219 delete(s.streams, ping.GetStreamIdentifier()) 220 close(channel) 221 s.streamsMutex.Unlock() 222 }() 223 } 224 225 for i := int32(0); i < ping.ResponseCount; i++ { 226 if i != 0 && useChannel { 227 shouldContinue := <-channel 228 if !shouldContinue { 229 return grpc.Errorf(codes.OK, "stream was cancelled by side-channel") 230 } 231 } 232 err := stream.Context().Err() 233 if err != nil { 234 return grpc.Errorf(codes.Canceled, "client cancelled stream") 235 } 236 stream.Send(&testproto.PingResponse{Value: fmt.Sprintf("%s %d", ping.Value, i), Counter: i}) 237 } 238 return nil 239} 240 241func (s *testSrv) PingStream(stream testproto.TestService_PingStreamServer) error { 242 allValues := "" 243 for { 244 in, err := stream.Recv() 245 if err == io.EOF { 246 stream.SendAndClose(&testproto.PingResponse{ 247 Value: allValues, 248 }) 249 return nil 250 } 251 if err != nil { 252 return err 253 } 254 if in.GetSendHeaders() { 255 stream.SendHeader(metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 256 } 257 if in.GetSendTrailers() { 258 stream.SetTrailer(metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 259 } 260 if allValues == "" { 261 allValues = in.GetValue() 262 } else { 263 allValues = allValues + "," + in.GetValue() 264 } 265 if in.FailureType == testproto.PingRequest_CODE { 266 return grpc.Errorf(codes.Code(in.ErrorCodeReturned), "Intentionally returning status code: %d", in.ErrorCodeReturned) 267 } 268 } 269} 270 271func (s *testSrv) PingPongBidi(stream testproto.TestService_PingPongBidiServer) error { 272 for { 273 in, err := stream.Recv() 274 if err == io.EOF { 275 return nil 276 } 277 if err != nil { 278 return err 279 } 280 if in.GetSendHeaders() { 281 stream.SendHeader(metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2")) 282 } 283 if in.GetSendTrailers() { 284 stream.SetTrailer(metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2")) 285 } 286 if in.FailureType == testproto.PingRequest_CODE { 287 if in.ErrorCodeReturned == 0 { 288 return nil 289 } 290 return grpc.Errorf(codes.Code(in.ErrorCodeReturned), "Intentionally returning status code: %d", in.ErrorCodeReturned) 291 } 292 stream.Send(&testproto.PingResponse{ 293 Value: in.Value, 294 }) 295 } 296} 297