1package main
2
3import (
4	"crypto/tls"
5	"flag"
6	"fmt"
7	"io"
8	"log"
9	"net/http"
10	"os"
11	"sync"
12
13	google_protobuf "github.com/golang/protobuf/ptypes/empty"
14	"github.com/improbable-eng/grpc-web/go/grpcweb"
15	testproto "github.com/improbable-eng/grpc-web/integration_test/go/_proto/improbable/grpcweb/test"
16	"golang.org/x/net/context"
17	"google.golang.org/grpc"
18	"google.golang.org/grpc/codes"
19	"google.golang.org/grpc/grpclog"
20	"google.golang.org/grpc/metadata"
21)
22
// Command-line configuration for the integration test server.
var (
	// TLS material shared by every listener.
	tlsCertFilePath = flag.String("tls_cert_file", "../../../misc/localhost.crt", "Path to the CRT/PEM file.")
	tlsKeyFilePath  = flag.String("tls_key_file", "../../../misc/localhost.key", "Path to the private key file.")

	// Listeners backed by the gRPC server that has the test services registered.
	http1Port = flag.Int("http1_port", 9090, "Port to listen with HTTP1.1 with TLS on.")
	http2Port = flag.Int("http2_port", 9100, "Port to listen with HTTP2 with TLS on.")

	// Listeners backed by a gRPC server with no services registered.
	http1EmptyPort = flag.Int("http1_empty_port", 9095, "Port to listen with HTTP1.1 with TLS on with a grpc server that has no services.")
	http2EmptyPort = flag.Int("http2_empty_port", 9105, "Port to listen with HTTP2 with TLS on with a grpc server that has no services.")
)
31
32func main() {
33	flag.Parse()
34
35	grpcServer := grpc.NewServer()
36	testServer := &testSrv{
37		streamsMutex: &sync.Mutex{},
38		streams:      map[string]chan bool{},
39	}
40	testproto.RegisterTestServiceServer(grpcServer, testServer)
41	testproto.RegisterTestUtilServiceServer(grpcServer, testServer)
42	grpclog.SetLogger(log.New(os.Stdout, "testserver: ", log.LstdFlags))
43
44	websocketOriginFunc := grpcweb.WithWebsocketOriginFunc(func(req *http.Request) bool {
45		return true
46	})
47	httpOriginFunc := grpcweb.WithOriginFunc(func(origin string) bool {
48		return true
49	})
50
51	wrappedServer := grpcweb.WrapServer(
52		grpcServer,
53		grpcweb.WithWebsockets(true),
54		httpOriginFunc,
55		websocketOriginFunc,
56	)
57	handler := func(resp http.ResponseWriter, req *http.Request) {
58		wrappedServer.ServeHTTP(resp, req)
59	}
60
61	emptyGrpcServer := grpc.NewServer()
62	emptyWrappedServer := grpcweb.WrapServer(
63		emptyGrpcServer,
64		grpcweb.WithWebsockets(true),
65		grpcweb.WithCorsForRegisteredEndpointsOnly(false),
66		httpOriginFunc,
67		websocketOriginFunc,
68	)
69	emptyHandler := func(resp http.ResponseWriter, req *http.Request) {
70		emptyWrappedServer.ServeHTTP(resp, req)
71	}
72
73	http1Server := http.Server{
74		Addr:    fmt.Sprintf(":%d", *http1Port),
75		Handler: http.HandlerFunc(handler),
76	}
77	http1Server.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} // Disable HTTP2
78	http1EmptyServer := http.Server{
79		Addr: fmt.Sprintf(":%d", *http1EmptyPort),
80		Handler: http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
81			emptyHandler(res, req)
82		}),
83	}
84	http1EmptyServer.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} // Disable HTTP2
85
86	http2Server := http.Server{
87		Addr:    fmt.Sprintf(":%d", *http2Port),
88		Handler: http.HandlerFunc(handler),
89	}
90	http2EmptyServer := http.Server{
91		Addr:    fmt.Sprintf(":%d", *http2EmptyPort),
92		Handler: http.HandlerFunc(emptyHandler),
93	}
94
95	grpclog.Printf("Starting servers. http1.1 port: %d, http1.1 empty port: %d, http2 port: %d, http2 empty port: %d", *http1Port, *http1EmptyPort, *http2Port, *http2EmptyPort)
96
97	// Start the empty Http1.1 server
98	go func() {
99		if err := http1EmptyServer.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil {
100			grpclog.Fatalf("failed starting http1.1 empty server: %v", err)
101		}
102	}()
103
104	// Start the Http1.1 server
105	go func() {
106		if err := http1Server.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil {
107			grpclog.Fatalf("failed starting http1.1 server: %v", err)
108		}
109	}()
110
111	// Start the empty Http2 server
112	go func() {
113		if err := http2EmptyServer.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil {
114			grpclog.Fatalf("failed starting http2 empty server: %v", err)
115		}
116	}()
117
118	// Start the Http2 server
119	if err := http2Server.ListenAndServeTLS(*tlsCertFilePath, *tlsKeyFilePath); err != nil {
120		grpclog.Fatalf("failed starting http2 server: %v", err)
121	}
122}
123
// testSrv implements both the TestService and TestUtilService servers; main
// registers a single instance for both services.
type testSrv struct {
	// streamsMutex guards streams.
	streamsMutex *sync.Mutex
	// streams maps a PingList stream identifier to the channel ContinueStream
	// sends on to let that stream emit its next response. PingList adds the
	// entry when it starts and removes it when it returns.
	streams      map[string]chan bool
}
128
129func (s *testSrv) PingEmpty(ctx context.Context, _ *google_protobuf.Empty) (*testproto.PingResponse, error) {
130	grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
131	grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
132	return &testproto.PingResponse{Value: "foobar"}, nil
133}
134
135func (s *testSrv) Ping(ctx context.Context, ping *testproto.PingRequest) (*testproto.PingResponse, error) {
136	if ping.GetCheckMetadata() {
137		md, ok := metadata.FromIncomingContext(ctx)
138		if !ok || md["headertestkey1"][0] != "ClientValue1" || md["headertestkey2"][0] != "ClientValue2" {
139			return nil, grpc.Errorf(codes.InvalidArgument, "Metadata was invalid")
140		}
141	}
142	if ping.GetSendHeaders() {
143		grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
144	}
145	if ping.GetSendTrailers() {
146		grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
147	}
148	return &testproto.PingResponse{Value: ping.Value, Counter: 252}, nil
149}
150
151func (s *testSrv) Echo(ctx context.Context, text *testproto.TextMessage) (*testproto.TextMessage, error) {
152	if text.GetSendHeaders() {
153		grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
154	}
155	if text.GetSendTrailers() {
156		grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
157	}
158	return text, nil
159}
160
161func (s *testSrv) PingError(ctx context.Context, ping *testproto.PingRequest) (*google_protobuf.Empty, error) {
162	if ping.GetSendHeaders() {
163		grpc.SendHeader(ctx, metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
164	}
165	if ping.GetSendTrailers() {
166		grpc.SetTrailer(ctx, metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
167	}
168	var msg = "��"
169	if ping.FailureType == testproto.PingRequest_CODE {
170		msg = "Intentionally returning error for PingError"
171	}
172	return nil, grpc.Errorf(codes.Code(ping.ErrorCodeReturned), msg)
173}
174
175func (s *testSrv) ContinueStream(ctx context.Context, req *testproto.ContinueStreamRequest) (*google_protobuf.Empty, error) {
176	s.streamsMutex.Lock()
177	defer s.streamsMutex.Unlock()
178	channel, ok := s.streams[req.GetStreamIdentifier()]
179	if !ok {
180		return nil, grpc.Errorf(codes.NotFound, "stream identifier not found")
181	}
182	channel <- true
183	return &google_protobuf.Empty{}, nil
184}
185
186func (s *testSrv) CheckStreamClosed(ctx context.Context, req *testproto.CheckStreamClosedRequest) (*testproto.CheckStreamClosedResponse, error) {
187	s.streamsMutex.Lock()
188	defer s.streamsMutex.Unlock()
189	_, ok := s.streams[req.GetStreamIdentifier()]
190	if !ok {
191		return &testproto.CheckStreamClosedResponse{
192			Closed: true,
193		}, nil
194	}
195	return &testproto.CheckStreamClosedResponse{
196		Closed: false,
197	}, nil
198}
199
200func (s *testSrv) PingList(ping *testproto.PingRequest, stream testproto.TestService_PingListServer) error {
201	if ping.GetSendHeaders() {
202		stream.SendHeader(metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
203	}
204	if ping.GetSendTrailers() {
205		stream.SetTrailer(metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
206	}
207
208	var channel chan bool
209	useChannel := ping.GetStreamIdentifier() != ""
210	if useChannel {
211		channel = make(chan bool)
212		s.streamsMutex.Lock()
213		s.streams[ping.GetStreamIdentifier()] = channel
214		s.streamsMutex.Unlock()
215
216		defer func() {
217			// When this stream has ended
218			s.streamsMutex.Lock()
219			delete(s.streams, ping.GetStreamIdentifier())
220			close(channel)
221			s.streamsMutex.Unlock()
222		}()
223	}
224
225	for i := int32(0); i < ping.ResponseCount; i++ {
226		if i != 0 && useChannel {
227			shouldContinue := <-channel
228			if !shouldContinue {
229				return grpc.Errorf(codes.OK, "stream was cancelled by side-channel")
230			}
231		}
232		err := stream.Context().Err()
233		if err != nil {
234			return grpc.Errorf(codes.Canceled, "client cancelled stream")
235		}
236		stream.Send(&testproto.PingResponse{Value: fmt.Sprintf("%s %d", ping.Value, i), Counter: i})
237	}
238	return nil
239}
240
241func (s *testSrv) PingStream(stream testproto.TestService_PingStreamServer) error {
242	allValues := ""
243	for {
244		in, err := stream.Recv()
245		if err == io.EOF {
246			stream.SendAndClose(&testproto.PingResponse{
247				Value: allValues,
248			})
249			return nil
250		}
251		if err != nil {
252			return err
253		}
254		if in.GetSendHeaders() {
255			stream.SendHeader(metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
256		}
257		if in.GetSendTrailers() {
258			stream.SetTrailer(metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
259		}
260		if allValues == "" {
261			allValues = in.GetValue()
262		} else {
263			allValues = allValues + "," + in.GetValue()
264		}
265		if in.FailureType == testproto.PingRequest_CODE {
266			return grpc.Errorf(codes.Code(in.ErrorCodeReturned), "Intentionally returning status code: %d", in.ErrorCodeReturned)
267		}
268	}
269}
270
271func (s *testSrv) PingPongBidi(stream testproto.TestService_PingPongBidiServer) error {
272	for {
273		in, err := stream.Recv()
274		if err == io.EOF {
275			return nil
276		}
277		if err != nil {
278			return err
279		}
280		if in.GetSendHeaders() {
281			stream.SendHeader(metadata.Pairs("HeaderTestKey1", "ServerValue1", "HeaderTestKey2", "ServerValue2"))
282		}
283		if in.GetSendTrailers() {
284			stream.SetTrailer(metadata.Pairs("TrailerTestKey1", "ServerValue1", "TrailerTestKey2", "ServerValue2"))
285		}
286		if in.FailureType == testproto.PingRequest_CODE {
287			if in.ErrorCodeReturned == 0 {
288				return nil
289			}
290			return grpc.Errorf(codes.Code(in.ErrorCodeReturned), "Intentionally returning status code: %d", in.ErrorCodeReturned)
291		}
292		stream.Send(&testproto.PingResponse{
293			Value: in.Value,
294		})
295	}
296}
297