1package main
2
3import (
4	"context"
5	"flag"
6	"log"
7	"net/http"
8	"os"
9	"os/signal"
10	"strconv"
11	"sync"
12	"syscall"
13	"time"
14
15	_ "net/http/pprof"
16
17	"github.com/centrifugal/centrifuge"
18)
19
20var (
21	port     = flag.Int("port", 8000, "Port to bind app to")
22	redis    = flag.Bool("redis", false, "Use Redis")
23	tls      = flag.Bool("tls", false, "Use TLS")
24	keyFile  = flag.String("key_file", "server.key", "path to TLS key file")
25	certFile = flag.String("cert_file", "server.crt", "path to TLS crt file")
26	// Was not able to make it work with HTTP/3 yet.
27	//useHttp3 = flag.Bool("http3", false, "Use HTTP/3")
28)
29
30func handleLog(e centrifuge.LogEntry) {
31	log.Printf("%s: %v", e.Message, e.Fields)
32}
33
34func authMiddleware(h http.Handler) http.Handler {
35	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36		ctx := r.Context()
37		newCtx := centrifuge.SetCredentials(ctx, &centrifuge.Credentials{
38			UserID: "42",
39		})
40		r = r.WithContext(newCtx)
41		h.ServeHTTP(w, r)
42	})
43}
44
45func waitExitSignal(n *centrifuge.Node) {
46	sigCh := make(chan os.Signal, 1)
47	done := make(chan bool, 1)
48	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
49	go func() {
50		<-sigCh
51		_ = n.Shutdown(context.Background())
52		done <- true
53	}()
54	<-done
55}
56
// exampleChannel is the channel every connection is subscribed to on connect
// and into which the server periodically publishes the current time. It
// never changes, so it is a constant rather than a mutable package variable.
const exampleChannel = "unidirectional"
58
59func main() {
60	flag.Parse()
61
62	cfg := centrifuge.DefaultConfig
63	cfg.LogLevel = centrifuge.LogLevelDebug
64	cfg.LogHandler = handleLog
65
66	node, _ := centrifuge.New(cfg)
67
68	if *redis {
69		redisShardConfigs := []centrifuge.RedisShardConfig{
70			{Address: "localhost:6379"},
71		}
72		var redisShards []*centrifuge.RedisShard
73		for _, redisConf := range redisShardConfigs {
74			redisShard, err := centrifuge.NewRedisShard(node, redisConf)
75			if err != nil {
76				log.Fatal(err)
77			}
78			redisShards = append(redisShards, redisShard)
79		}
80		// Using Redis Broker here to scale nodes.
81		broker, err := centrifuge.NewRedisBroker(node, centrifuge.RedisBrokerConfig{
82			Shards: redisShards,
83		})
84		if err != nil {
85			log.Fatal(err)
86		}
87		node.SetBroker(broker)
88
89		presenceManager, err := centrifuge.NewRedisPresenceManager(node, centrifuge.RedisPresenceManagerConfig{
90			Shards: redisShards,
91		})
92		if err != nil {
93			log.Fatal(err)
94		}
95		node.SetPresenceManager(presenceManager)
96	}
97
98	node.OnConnecting(func(ctx context.Context, e centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
99		return centrifuge.ConnectReply{
100			Subscriptions: map[string]centrifuge.SubscribeOptions{
101				exampleChannel: {
102					Recover:  true,
103					Position: true,
104					Data:     []byte(`{"message": "welcome to a channel"}`),
105				},
106			},
107		}, nil
108	})
109
110	node.OnConnect(func(client *centrifuge.Client) {
111		client.OnUnsubscribe(func(e centrifuge.UnsubscribeEvent) {
112			log.Printf("user %s unsubscribed from %s", client.UserID(), e.Channel)
113		})
114		client.OnDisconnect(func(e centrifuge.DisconnectEvent) {
115			log.Printf("user %s disconnected, disconnect: %s", client.UserID(), e.Disconnect)
116		})
117		transport := client.Transport()
118		log.Printf("user %s connected via %s", client.UserID(), transport.Name())
119	})
120
121	// Publish to a channel periodically.
122	go func() {
123		for {
124			currentTime := strconv.FormatInt(time.Now().Unix(), 10)
125			_, err := node.Publish(exampleChannel, []byte(`{"server_time": "`+currentTime+`"}`), centrifuge.WithHistory(10, time.Minute))
126			if err != nil {
127				log.Println(err.Error())
128			}
129			time.Sleep(5 * time.Second)
130		}
131	}()
132
133	if err := node.Run(); err != nil {
134		log.Fatal(err)
135	}
136
137	http.Handle("/connection/stream", authMiddleware(handleStream(node)))
138	http.Handle("/subscribe", handleSubscribe(node))
139	http.Handle("/unsubscribe", handleUnsubscribe(node))
140	http.Handle("/", http.FileServer(http.Dir("./")))
141
142	go func() {
143		if *tls {
144			//if *useHttp3 {
145			//	if err := http3.ListenAndServe("0.0.0.0:443", *certFile, *keyFile, nil); err != nil {
146			//		log.Fatal(err)
147			//	}
148			//} else {
149			if err := http.ListenAndServeTLS(":"+strconv.Itoa(*port), *certFile, *keyFile, nil); err != nil {
150				log.Fatal(err)
151			}
152			//}
153		} else {
154			if err := http.ListenAndServe(":"+strconv.Itoa(*port), nil); err != nil {
155				log.Fatal(err)
156			}
157		}
158	}()
159
160	waitExitSignal(node)
161	log.Println("bye!")
162}
163
164func handleStream(node *centrifuge.Node) http.HandlerFunc {
165	return func(w http.ResponseWriter, req *http.Request) {
166		w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
167		w.Header().Set("Connection", "keep-alive")
168		w.WriteHeader(http.StatusOK)
169
170		transport := newStreamTransport(req)
171
172		c, closeFn, err := centrifuge.NewClient(req.Context(), node, transport)
173		if err != nil {
174			log.Printf("error creating client: %v", err)
175			return
176		}
177		defer func() { _ = closeFn() }()
178		defer close(transport.closedCh) // need to execute this after client closeFn.
179
180		c.Connect(centrifuge.ConnectRequest{})
181
182		flusher, ok := w.(http.Flusher)
183		if !ok {
184			log.Printf("ResponseWriter does not support Flusher")
185			return
186		}
187
188		pingInterval := 25 * time.Second
189		tick := time.NewTicker(pingInterval)
190		defer tick.Stop()
191
192		for {
193			select {
194			case <-req.Context().Done():
195				return
196			case <-transport.disconnectCh:
197				return
198			case <-tick.C:
199				_, err = w.Write([]byte("null\n"))
200				if err != nil {
201					log.Printf("error write: %v", err)
202					return
203				}
204				flusher.Flush()
205			case data, ok := <-transport.messages:
206				if !ok {
207					return
208				}
209				tick.Reset(pingInterval)
210				_, err = w.Write(data)
211				if err != nil {
212					log.Printf("error write: %v", err)
213					return
214				}
215				_, err = w.Write([]byte("\n"))
216				if err != nil {
217					log.Printf("error write: %v", err)
218					return
219				}
220				flusher.Flush()
221			}
222		}
223	}
224}
225
226func handleSubscribe(node *centrifuge.Node) http.HandlerFunc {
227	return func(w http.ResponseWriter, req *http.Request) {
228		clientID := req.URL.Query().Get("client")
229		if clientID == "" {
230			w.WriteHeader(http.StatusBadRequest)
231			return
232		}
233		err := node.Subscribe(
234			"42", exampleChannel,
235			centrifuge.WithSubscribeClient(clientID),
236			centrifuge.WithSubscribeData([]byte(`{"message": "welcome to a channel"}`)),
237		)
238		if err != nil {
239			w.WriteHeader(http.StatusInternalServerError)
240			return
241		}
242		w.WriteHeader(http.StatusOK)
243	}
244}
245
246func handleUnsubscribe(node *centrifuge.Node) http.HandlerFunc {
247	return func(w http.ResponseWriter, req *http.Request) {
248		clientID := req.URL.Query().Get("client")
249		if clientID == "" {
250			w.WriteHeader(http.StatusBadRequest)
251			return
252		}
253		err := node.Unsubscribe("42", exampleChannel, centrifuge.WithUnsubscribeClient(clientID))
254		if err != nil {
255			w.WriteHeader(http.StatusInternalServerError)
256			return
257		}
258		w.WriteHeader(http.StatusOK)
259	}
260}
261
262type streamTransport struct {
263	mu           sync.Mutex
264	req          *http.Request
265	messages     chan []byte
266	disconnectCh chan *centrifuge.Disconnect
267	closedCh     chan struct{}
268	closed       bool
269}
270
271func newStreamTransport(req *http.Request) *streamTransport {
272	return &streamTransport{
273		messages:     make(chan []byte),
274		disconnectCh: make(chan *centrifuge.Disconnect),
275		closedCh:     make(chan struct{}),
276		req:          req,
277	}
278}
279
280func (t *streamTransport) Name() string {
281	return "stream"
282}
283
284func (t *streamTransport) Protocol() centrifuge.ProtocolType {
285	return centrifuge.ProtocolTypeJSON
286}
287
288// Unidirectional returns whether transport is unidirectional.
289func (t *streamTransport) Unidirectional() bool {
290	return true
291}
292
293// DisabledPushFlags ...
294func (t *streamTransport) DisabledPushFlags() uint64 {
295	return 0
296}
297
298func (t *streamTransport) Write(message []byte) error {
299	return t.WriteMany(message)
300}
301
302func (t *streamTransport) WriteMany(messages ...[]byte) error {
303	t.mu.Lock()
304	defer t.mu.Unlock()
305	if t.closed {
306		return nil
307	}
308	for i := 0; i < len(messages); i++ {
309		select {
310		case t.messages <- messages[i]:
311		case <-t.closedCh:
312			return nil
313		}
314	}
315	return nil
316}
317
318func (t *streamTransport) Close(_ *centrifuge.Disconnect) error {
319	t.mu.Lock()
320	defer t.mu.Unlock()
321	if t.closed {
322		return nil
323	}
324	t.closed = true
325	close(t.disconnectCh)
326	<-t.closedCh
327	return nil
328}
329