1package main
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"io"
8	"log"
9	"net/http"
10	"strconv"
11	"time"
12
13	"github.com/centrifugal/centrifuge"
14	"github.com/gin-contrib/sessions"
15	"github.com/gin-contrib/sessions/cookie"
16	"github.com/gin-gonic/gin"
17)
18
// clientMessage is the JSON payload that publish events are decoded into.
type clientMessage struct {
	// Timestamp is read from the "timestamp" JSON key.
	Timestamp int64 `json:"timestamp"`
	// Input is read from the "input" JSON key.
	Input string `json:"input"`
}
23
24func handleLog(e centrifuge.LogEntry) {
25	log.Printf("%s: %v", e.Message, e.Fields)
26}
27
// connectData is the JSON document attached to the connect reply.
type connectData struct {
	// Email is serialized under the "email" JSON key.
	Email string `json:"email"`
}
31
// contextKey is a private type for context.Context keys, so that keys defined
// here cannot collide with keys of other types.
type contextKey int

// ginContextKey is the key under which the *gin.Context is stored in the
// request context.
var ginContextKey = contextKey(0)
35
36// GinContextToContextMiddleware - at the resolver level we only have access
37// to context.Context inside centrifuge, but we need the gin context. So we
38// create a gin middleware to add its context to the context.Context used by
39// centrifuge websocket server.
40func GinContextToContextMiddleware() gin.HandlerFunc {
41	return func(c *gin.Context) {
42		ctx := context.WithValue(c.Request.Context(), ginContextKey, c)
43		c.Request = c.Request.WithContext(ctx)
44		c.Next()
45	}
46}
47
48// GinContextFromContext - we recover the gin context from the context.Context
49// struct where we added it just above
50func GinContextFromContext(ctx context.Context) (*gin.Context, error) {
51	ginContext := ctx.Value(ginContextKey)
52	if ginContext == nil {
53		err := fmt.Errorf("could not retrieve gin.Context")
54		return nil, err
55	}
56	gc, ok := ginContext.(*gin.Context)
57	if !ok {
58		err := fmt.Errorf("gin.Context has wrong type")
59		return nil, err
60	}
61	return gc, nil
62}
63
64// Finally we can use gin context in the auth middleware of centrifuge.
65func authMiddleware(h http.Handler) http.Handler {
66	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
67		ctx := r.Context()
68		// We get gin ctx from context.Context struct.
69		gc, err := GinContextFromContext(ctx)
70		if err != nil {
71			fmt.Printf("Failed to retrieve gin context")
72			fmt.Print(err.Error())
73			return
74		}
75		// And now we can access gin session.
76		s := sessions.Default(gc)
77		username := s.Get("user").(string)
78		if username != "" {
79			fmt.Printf("Successful websocket auth for user %s\n", username)
80		} else {
81			fmt.Printf("Failed websocket auth for user %s\n", username)
82			return
83		}
84		newCtx := centrifuge.SetCredentials(ctx, &centrifuge.Credentials{
85			UserID: s.Get("user").(string),
86		})
87		r = r.WithContext(newCtx)
88		h.ServeHTTP(w, r)
89	})
90}
91
92func main() {
93	cfg := centrifuge.DefaultConfig
94	cfg.LogLevel = centrifuge.LogLevelDebug
95	cfg.LogHandler = handleLog
96
97	node, _ := centrifuge.New(cfg)
98
99	node.OnConnecting(func(ctx context.Context, event centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
100		// Let's include user email into connect reply, so we can display user name in chat.
101		// This is an optional step actually.
102		cred, ok := centrifuge.GetCredentials(ctx)
103		if !ok {
104			return centrifuge.ConnectReply{}, centrifuge.DisconnectServerError
105		}
106		data, _ := json.Marshal(connectData{
107			Email: cred.UserID,
108		})
109		return centrifuge.ConnectReply{
110			Data: data,
111		}, nil
112	})
113
114	node.OnConnect(func(client *centrifuge.Client) {
115		transport := client.Transport()
116		log.Printf("user %s connected via %s.", client.UserID(), transport.Name())
117
118		// Connect handler should not block, so start separate goroutine to
119		// periodically send messages to client.
120		go func() {
121			for {
122				select {
123				case <-client.Context().Done():
124					return
125				case <-time.After(5 * time.Second):
126					err := client.Send([]byte(`{"time": "` + strconv.FormatInt(time.Now().Unix(), 10) + `"}`))
127					if err != nil {
128						if err == io.EOF {
129							return
130						}
131						log.Println(err.Error())
132					}
133				}
134			}
135		}()
136
137		client.OnRefresh(func(e centrifuge.RefreshEvent, cb centrifuge.RefreshCallback) {
138			log.Printf("user %s connection is going to expire, refreshing", client.UserID())
139			cb(centrifuge.RefreshReply{
140				ExpireAt: time.Now().Unix() + 10,
141			}, nil)
142		})
143
144		client.OnSubscribe(func(e centrifuge.SubscribeEvent, cb centrifuge.SubscribeCallback) {
145			log.Printf("user %s subscribes on %s", client.UserID(), e.Channel)
146			cb(centrifuge.SubscribeReply{}, nil)
147		})
148
149		client.OnUnsubscribe(func(e centrifuge.UnsubscribeEvent) {
150			log.Printf("user %s unsubscribed from %s", client.UserID(), e.Channel)
151		})
152
153		client.OnPublish(func(e centrifuge.PublishEvent, cb centrifuge.PublishCallback) {
154			log.Printf("user %s publishes into channel %s: %s", client.UserID(), e.Channel, string(e.Data))
155			var msg clientMessage
156			err := json.Unmarshal(e.Data, &msg)
157			if err != nil {
158				cb(centrifuge.PublishReply{}, centrifuge.ErrorBadRequest)
159				return
160			}
161			cb(centrifuge.PublishReply{}, nil)
162		})
163
164		client.OnRPC(func(e centrifuge.RPCEvent, cb centrifuge.RPCCallback) {
165			log.Printf("RPC from user: %s, data: %s", client.UserID(), string(e.Data))
166			cb(centrifuge.RPCReply{
167				Data: []byte(`{"year": "2020"}`),
168			}, nil)
169		})
170
171		client.OnMessage(func(e centrifuge.MessageEvent) {
172			log.Printf("Message from user: %s, data: %s", client.UserID(), string(e.Data))
173		})
174
175		client.OnDisconnect(func(e centrifuge.DisconnectEvent) {
176			log.Printf("user %s disconnected, disconnect: %s", client.UserID(), e.Disconnect)
177		})
178	})
179
180	// We also start a separate goroutine for centrifuge itself, since we
181	// still need to run gin web server.
182	go func() {
183		if err := node.Run(); err != nil {
184			log.Fatal(err)
185		}
186	}()
187
188	r := gin.Default()
189	store := cookie.NewStore([]byte("secret_string"))
190	r.Use(sessions.Sessions("session_name", store))
191	r.LoadHTMLFiles("./login_form.html", "./chat.html")
192	// Here we tell gin to use the middleware we created just above
193	r.Use(GinContextToContextMiddleware())
194
195	r.GET("/login", func(c *gin.Context) {
196		s := sessions.Default(c)
197		if s.Get("user") != nil && s.Get("user").(string) == "email@email.com" {
198			c.Redirect(http.StatusMovedPermanently, "/chat")
199			c.Abort()
200		} else {
201			c.HTML(200, "login_form.html", gin.H{})
202		}
203	})
204
205	r.POST("/login", func(c *gin.Context) {
206		email := c.PostForm("email")
207		passwd := c.PostForm("password")
208		s := sessions.Default(c)
209		if email == "email@email.com" && passwd == "password" {
210			s.Set("user", email)
211			_ = s.Save()
212			c.Redirect(http.StatusMovedPermanently, "/chat")
213			c.Abort()
214		} else {
215			c.JSON(403, gin.H{
216				"message": "Bad email/password combination",
217			})
218		}
219	})
220
221	r.GET("/connection/websocket", gin.WrapH(authMiddleware(centrifuge.NewWebsocketHandler(node, centrifuge.WebsocketConfig{}))))
222	r.GET("/connection/sockjs", gin.WrapH(authMiddleware(centrifuge.NewSockjsHandler(node, centrifuge.SockjsConfig{
223		URL:           "https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js",
224		HandlerPrefix: "/connection/sockjs",
225	}))))
226
227	r.GET("/chat", func(c *gin.Context) {
228		s := sessions.Default(c)
229		if s.Get("user") != nil {
230			c.HTML(200, "chat.html", gin.H{})
231		} else {
232			c.JSON(403, gin.H{
233				"message": "Not logged in!",
234			})
235		}
236		c.Abort()
237	})
238
239	_ = r.Run() // listen and serve on 0.0.0.0:8080
240}
241