1package main
2
3import (
4	"context"
5	"flag"
6	"fmt"
7	"log"
8	"net"
9	"net/http"
10	"os"
11	"os/signal"
12	"strconv"
13	"sync"
14	"syscall"
15	"time"
16
17	_ "net/http/pprof"
18
19	"github.com/centrifugal/centrifuge/_examples/unidirectional_grpc/clientproto"
20
21	"github.com/centrifugal/centrifuge"
22	"google.golang.org/grpc"
23	"google.golang.org/protobuf/proto"
24)
25
// httpPort is the TCP port of the plain HTTP server; it serves the
// net/http/pprof handlers registered on http.DefaultServeMux.
var httpPort = flag.Int("http_port", 8000, "Port to bind HTTP server to")

// grpcPort is the TCP port the GRPC server listens on.
var grpcPort = flag.Int("grpc_port", 10000, "Port to bind GRPC server to")

// redis switches the node to Redis-backed broker and presence manager.
var redis = flag.Bool("redis", false, "Use Redis")
31
32func grpcAuthInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
33	// You probably want to authenticate user by information included in stream metadata.
34	// meta, ok := metadata.FromIncomingContext(ss.Context())
35	// But here we skip it for simplicity and just always authenticate user with ID 42.
36	ctx := ss.Context()
37	newCtx := centrifuge.SetCredentials(ctx, &centrifuge.Credentials{
38		UserID: "42",
39	})
40
41	// GRPC has no builtin method to add data to context so here we use small
42	// wrapper over ServerStream.
43	wrapped := WrapServerStream(ss)
44	wrapped.WrappedContext = newCtx
45	return handler(srv, wrapped)
46}
47
48// WrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
49// This can be replaced by analogue from github.com/grpc-ecosystem/go-grpc-middleware
50// package - https://github.com/grpc-ecosystem/go-grpc-middleware/blob/master/wrappers.go.
51// You most probably will have dependency to it in your application as it has lots of
52// useful features to deal with GRPC.
53type WrappedServerStream struct {
54	grpc.ServerStream
55	// WrappedContext is the wrapper's own Context. You can assign it.
56	WrappedContext context.Context
57}
58
59// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
60func (w *WrappedServerStream) Context() context.Context {
61	return w.WrappedContext
62}
63
64// WrapServerStream returns a ServerStream that has the ability to overwrite context.
65func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream {
66	if existing, ok := stream.(*WrappedServerStream); ok {
67		return existing
68	}
69	return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()}
70}
71
72func waitExitSignal(n *centrifuge.Node, server *grpc.Server) {
73	sigCh := make(chan os.Signal, 1)
74	done := make(chan bool, 1)
75	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
76	go func() {
77		<-sigCh
78		_ = n.Shutdown(context.Background())
79		server.GracefulStop()
80		done <- true
81	}()
82	<-done
83}
84
85// RegisterGRPCServerClient ...
86func RegisterGRPCServerClient(n *centrifuge.Node, server *grpc.Server, config GRPCClientServiceConfig) error {
87	clientproto.RegisterCentrifugeUniServer(server, newGRPCClientService(n, config))
88	return nil
89}
90
91// GRPCClientServiceConfig for GRPC client Service.
92type GRPCClientServiceConfig struct{}
93
94// GRPCClientService can work with client GRPC connections.
95type grpcClientService struct {
96	clientproto.UnimplementedCentrifugeUniServer
97	config GRPCClientServiceConfig
98	node   *centrifuge.Node
99}
100
101// newGRPCClientService creates new Service.
102func newGRPCClientService(n *centrifuge.Node, c GRPCClientServiceConfig) *grpcClientService {
103	return &grpcClientService{
104		config: c,
105		node:   n,
106	}
107}
108
109// Consume is a unidirectional server->client stream wit real-time data.
110func (s *grpcClientService) Consume(req *clientproto.ConnectRequest, stream clientproto.CentrifugeUni_ConsumeServer) error {
111	streamDataCh := make(chan []byte)
112	transport := newGRPCTransport(stream, streamDataCh)
113
114	connectRequest := centrifuge.ConnectRequest{
115		Token:   req.Token,
116		Data:    req.Data,
117		Name:    req.Name,
118		Version: req.Version,
119	}
120	if req.Subs != nil {
121		subs := make(map[string]centrifuge.SubscribeRequest)
122		for k, v := range connectRequest.Subs {
123			subs[k] = centrifuge.SubscribeRequest{
124				Recover: v.Recover,
125				Offset:  v.Offset,
126				Epoch:   v.Epoch,
127			}
128		}
129	}
130	c, closeFn, err := centrifuge.NewClient(stream.Context(), s.node, transport)
131	if err != nil {
132		log.Printf("client create error: %v", err)
133		return err
134	}
135	defer func() { _ = closeFn() }()
136
137	log.Printf("client connected (id %s)", c.ID())
138	defer func(started time.Time) {
139		log.Printf("client disconnected (id %s, duration %s)", c.ID(), time.Since(started))
140	}(time.Now())
141
142	c.Connect(connectRequest)
143
144	for {
145		select {
146		case streamData := <-streamDataCh:
147			err := stream.SendMsg(rawFrame(streamData))
148			if err != nil {
149				log.Printf("stream send error: %v", err)
150				return err
151			}
152		case <-transport.closeCh:
153			return nil
154		}
155	}
156}
157
158// grpcTransport wraps a stream.
159type grpcTransport struct {
160	mu           sync.RWMutex
161	stream       clientproto.CentrifugeUni_ConsumeServer
162	closed       bool
163	closeCh      chan struct{}
164	streamDataCh chan []byte
165}
166
167func newGRPCTransport(stream clientproto.CentrifugeUni_ConsumeServer, streamDataCh chan []byte) *grpcTransport {
168	return &grpcTransport{
169		stream:       stream,
170		streamDataCh: streamDataCh,
171		closeCh:      make(chan struct{}),
172	}
173}
174
175func (t *grpcTransport) Name() string {
176	return "grpc"
177}
178
179func (t *grpcTransport) Protocol() centrifuge.ProtocolType {
180	return centrifuge.ProtocolTypeProtobuf
181}
182
183// Unidirectional returns whether transport is unidirectional.
184func (t *grpcTransport) Unidirectional() bool {
185	return true
186}
187
188// DisabledPushFlags ...
189func (t *grpcTransport) DisabledPushFlags() uint64 {
190	return 0
191}
192
193func (t *grpcTransport) Write(message []byte) error {
194	return t.WriteMany(message)
195}
196
197func (t *grpcTransport) WriteMany(messages ...[]byte) error {
198	t.mu.RLock()
199	if t.closed {
200		t.mu.RUnlock()
201		return nil
202	}
203	t.mu.RUnlock()
204	for i := 0; i < len(messages); i++ {
205		select {
206		case t.streamDataCh <- messages[i]:
207		case <-t.closeCh:
208			return nil
209		}
210	}
211	return nil
212}
213
214func (t *grpcTransport) Close(_ *centrifuge.Disconnect) error {
215	t.mu.Lock()
216	defer t.mu.Unlock()
217	if t.closed {
218		return nil
219	}
220	t.closed = true
221	close(t.closeCh)
222	return nil
223}
224
225func handleLog(e centrifuge.LogEntry) {
226	log.Printf("%s: %v", e.Message, e.Fields)
227}
228
229var exampleChannel = "unidirectional"
230
231func main() {
232	flag.Parse()
233
234	cfg := centrifuge.DefaultConfig
235	cfg.LogLevel = centrifuge.LogLevelDebug
236	cfg.LogHandler = handleLog
237
238	node, _ := centrifuge.New(cfg)
239
240	if *redis {
241		redisShardConfigs := []centrifuge.RedisShardConfig{
242			{Address: "localhost:6379"},
243		}
244		var redisShards []*centrifuge.RedisShard
245		for _, redisConf := range redisShardConfigs {
246			redisShard, err := centrifuge.NewRedisShard(node, redisConf)
247			if err != nil {
248				log.Fatal(err)
249			}
250			redisShards = append(redisShards, redisShard)
251		}
252		// Using Redis Broker here to scale nodes.
253		broker, err := centrifuge.NewRedisBroker(node, centrifuge.RedisBrokerConfig{
254			Shards: redisShards,
255		})
256		if err != nil {
257			log.Fatal(err)
258		}
259		node.SetBroker(broker)
260
261		presenceManager, err := centrifuge.NewRedisPresenceManager(node, centrifuge.RedisPresenceManagerConfig{
262			Shards: redisShards,
263		})
264		if err != nil {
265			log.Fatal(err)
266		}
267		node.SetPresenceManager(presenceManager)
268	}
269
270	node.OnConnecting(func(ctx context.Context, e centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
271		return centrifuge.ConnectReply{
272			Subscriptions: map[string]centrifuge.SubscribeOptions{
273				exampleChannel: {},
274			},
275		}, nil
276	})
277
278	node.OnConnect(func(client *centrifuge.Client) {
279		client.OnUnsubscribe(func(e centrifuge.UnsubscribeEvent) {
280			log.Printf("user %s unsubscribed from %s", client.UserID(), e.Channel)
281		})
282		client.OnDisconnect(func(e centrifuge.DisconnectEvent) {
283			log.Printf("user %s disconnected, disconnect: %s", client.UserID(), e.Disconnect)
284		})
285		transport := client.Transport()
286		log.Printf("user %s connected via %s", client.UserID(), transport.Name())
287	})
288
289	// Publish to a channel periodically.
290	go func() {
291		for {
292			currentTime := strconv.FormatInt(time.Now().Unix(), 10)
293			_, err := node.Publish(exampleChannel, []byte(`{"server_time": "`+currentTime+`"}`))
294			if err != nil {
295				log.Println(err.Error())
296			}
297			time.Sleep(5 * time.Second)
298		}
299	}()
300
301	if err := node.Run(); err != nil {
302		log.Fatal(err)
303	}
304
305	grpcServer := grpc.NewServer(
306		grpc.StreamInterceptor(grpcAuthInterceptor),
307		grpc.CustomCodec(&rawCodec{}),
308	)
309	err := RegisterGRPCServerClient(node, grpcServer, GRPCClientServiceConfig{})
310	if err != nil {
311		log.Fatal(err)
312	}
313	go func() {
314		log.Println("starting GRPC server on :" + strconv.Itoa(*grpcPort))
315		listener, err := net.Listen("tcp", ":"+strconv.Itoa(*grpcPort))
316		if err != nil {
317			log.Fatal(err)
318		}
319		if err := grpcServer.Serve(listener); err != nil {
320			log.Fatalf("Serve GRPC: %v", err)
321		}
322	}()
323
324	go func() {
325		if err := http.ListenAndServe(":"+strconv.Itoa(*httpPort), nil); err != nil {
326			log.Fatal(err)
327		}
328	}()
329
330	waitExitSignal(node, grpcServer)
331	log.Println("bye!")
332}
333
334type rawFrame []byte
335
336type rawCodec struct{}
337
338func (c *rawCodec) Marshal(v interface{}) ([]byte, error) {
339	out, ok := v.(rawFrame)
340	if !ok {
341		vv, ok := v.(proto.Message)
342		if !ok {
343			return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
344		}
345		return proto.Marshal(vv)
346	}
347	return out, nil
348}
349
350func (c *rawCodec) Unmarshal(data []byte, v interface{}) error {
351	vv, ok := v.(proto.Message)
352	if !ok {
353		return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
354	}
355	return proto.Unmarshal(data, vv)
356}
357
358func (c *rawCodec) String() string {
359	return "proto"
360}
361