1package main
2
3import (
4	"context"
5	"flag"
6	"io"
7	"log"
8	"net"
9	"net/http"
10	"os"
11	"os/signal"
12	"strconv"
13	"strings"
14	"syscall"
15	"time"
16
17	_ "net/http/pprof"
18
19	"github.com/centrifugal/centrifuge"
20	"github.com/gobwas/ws"
21	"github.com/mailru/easygo/netpoll"
22)
23
// Command-line tuning knobs for the worker pool that services connection
// events and for the per-call socket deadlines.
var (
	// workers is the maximum worker count handed to NewPool.
	workers = flag.Int("workers", 128, "max workers count")

	// queue is the size of the pool's task queue handed to NewPool.
	queue = flag.Int("queue", 1, "workers task queue size")

	// ioTimeout is the deadline re-armed before every socket Read/Write by
	// the deadliner wrapper.
	ioTimeout = flag.Duration("io_timeout", 500*time.Millisecond, "i/o operations timeout")
)
29
30func handleLog(e centrifuge.LogEntry) {
31	log.Printf("%s: %v", e.Message, e.Fields)
32}
33
34func waitExitSignal(n *centrifuge.Node) {
35	sigCh := make(chan os.Signal, 1)
36	done := make(chan bool, 1)
37	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
38	go func() {
39		<-sigCh
40		_ = n.Shutdown(context.Background())
41		done <- true
42	}()
43	<-done
44}
45
46func main() {
47	flag.Parse()
48
49	cfg := centrifuge.DefaultConfig
50
51	cfg.LogLevel = centrifuge.LogLevelDebug
52	cfg.LogHandler = handleLog
53
54	node, _ := centrifuge.New(cfg)
55
56	node.OnConnecting(func(ctx context.Context, e centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
57		return centrifuge.ConnectReply{
58			Credentials: &centrifuge.Credentials{
59				UserID: "",
60			},
61		}, nil
62	})
63
64	node.OnConnect(func(client *centrifuge.Client) {
65		transport := client.Transport()
66		log.Printf("user %s connected via %s with format: %s", client.UserID(), transport.Name(), transport.Protocol())
67
68		// Connect handler should not block, so start separate goroutine to
69		// periodically send messages to client.
70		go func() {
71			for {
72				err := client.Send([]byte(`{"time": "` + strconv.FormatInt(time.Now().Unix(), 10) + `"}`))
73				if err != nil {
74					if err != io.EOF {
75						log.Println(err.Error())
76					} else {
77						return
78					}
79				}
80				time.Sleep(5 * time.Second)
81			}
82		}()
83
84		client.OnSubscribe(func(e centrifuge.SubscribeEvent, cb centrifuge.SubscribeCallback) {
85			log.Printf("user %s subscribes on %s", client.UserID(), e.Channel)
86			cb(centrifuge.SubscribeReply{}, nil)
87		})
88
89		client.OnUnsubscribe(func(e centrifuge.UnsubscribeEvent) {
90			log.Printf("user %s unsubscribed from %s", client.UserID(), e.Channel)
91		})
92
93		client.OnPublish(func(e centrifuge.PublishEvent, cb centrifuge.PublishCallback) {
94			log.Printf("user %s publishes into channel %s: %s", client.UserID(), e.Channel, string(e.Data))
95			cb(centrifuge.PublishReply{}, nil)
96		})
97
98		client.OnMessage(func(e centrifuge.MessageEvent) {
99			log.Printf("Message from user: %s, data: %s", client.UserID(), string(e.Data))
100		})
101
102		client.OnDisconnect(func(e centrifuge.DisconnectEvent) {
103			log.Printf("user %s disconnected, disconnect: %s", client.UserID(), e.Disconnect)
104		})
105	})
106
107	if err := node.Run(); err != nil {
108		log.Fatal(err)
109	}
110
111	// Initialize netpoll instance. We will use it to be noticed about incoming
112	// events from listener of user connections.
113	poller, err := netpoll.New(nil)
114	if err != nil {
115		log.Fatal(err)
116	}
117
118	var (
119		// Make pool of X size, Y sized work queue and one pre-spawned
120		// goroutine.
121		pool = NewPool(*workers, *queue, 1)
122	)
123
124	// handle is a new incoming connection handler.
125	// It upgrades TCP connection to WebSocket, registers netpoll listener on
126	// it and stores it as a chat user in Chat instance.
127	//
128	// We will call it below within accept() loop.
129	handle := func(conn net.Conn) {
130		// NOTE: we wrap conn here to show that ws could work with any kind of
131		// io.ReadWriter.
132		safeConn := deadliner{conn, *ioTimeout}
133
134		protoType := centrifuge.ProtocolTypeJSON
135
136		up := ws.Upgrader{
137			OnRequest: func(uri []byte) error {
138				if strings.Contains(string(uri), "format=protobuf") {
139					protoType = centrifuge.ProtocolTypeProtobuf
140				}
141				return nil
142			},
143		}
144
145		// Zero-copy upgrade to WebSocket connection.
146		hs, err := up.Upgrade(safeConn)
147		if err != nil {
148			log.Printf("%s: upgrade error: %v", nameConn(conn), err)
149			_ = conn.Close()
150			return
151		}
152
153		log.Printf("%s: established websocket connection: %+v", nameConn(conn), hs)
154
155		transport := newWebsocketTransport(safeConn, protoType)
156		client, closeFn, err := centrifuge.NewClient(context.Background(), node, transport)
157		if err != nil {
158			log.Printf("%s: client create error: %v", nameConn(conn), err)
159			_ = conn.Close()
160			return
161		}
162
163		// Create netpoll event descriptor for conn.
164		// We want to handle only read events of it.
165		desc := netpoll.Must(netpoll.HandleReadOnce(conn))
166
167		// Subscribe to events about conn.
168		_ = poller.Start(desc, func(ev netpoll.Event) {
169			if ev&(netpoll.EventReadHup|netpoll.EventHup) != 0 {
170				// When ReadHup or Hup received, this mean that client has
171				// closed at least write end of the connection or connections
172				// itself. So we want to stop receive events about such conn
173				// and remove it from the chat registry.
174				_ = poller.Stop(desc)
175				_ = closeFn()
176				return
177			}
178			// Here we can read some new message from connection.
179			// We can not read it right here in callback, because then we will
180			// block the poller's inner loop.
181			// We do not want to spawn a new goroutine to read single message.
182			// But we want to reuse previously spawned goroutine.
183			pool.Schedule(func() {
184				if data, isControl, err := transport.read(); err != nil {
185					// When receive failed, we can only disconnect broken
186					// connection and stop to receive events about it.
187					_ = poller.Stop(desc)
188					_ = closeFn()
189				} else {
190					if !isControl {
191						ok := client.Handle(data)
192						if !ok {
193							_ = poller.Stop(desc)
194							return
195						}
196					}
197					_ = poller.Resume(desc)
198				}
199			})
200		})
201	}
202
203	// Create incoming connections listener.
204	ln, err := net.Listen("tcp", ":3333")
205	if err != nil {
206		log.Fatal(err)
207	}
208
209	log.Printf("websocket is listening on %s", ln.Addr().String())
210
211	// Create netpoll descriptor for the listener.
212	// We use OneShot here to manually resume events stream when we want to.
213	acceptDesc := netpoll.Must(netpoll.HandleListener(
214		ln, netpoll.EventRead|netpoll.EventOneShot,
215	))
216
217	// accept is a channel to signal about next incoming connection Accept()
218	// results.
219	accept := make(chan error, 1)
220
221	// Subscribe to events about listener.
222	_ = poller.Start(acceptDesc, func(e netpoll.Event) {
223		// We do not want to accept incoming connection when goroutine pool is
224		// busy. So if there are no free goroutines during 1ms we want to
225		// cooldown the server and do not receive connection for some short
226		// time.
227		err := pool.ScheduleTimeout(time.Millisecond, func() {
228			conn, err := ln.Accept()
229			if err != nil {
230				accept <- err
231				return
232			}
233
234			accept <- nil
235			handle(conn)
236		})
237		if err == nil {
238			err = <-accept
239		}
240		if err != nil {
241			if err != ErrScheduleTimeout {
242				goto cooldown
243			}
244			if ne, ok := err.(net.Error); ok && ne.Temporary() {
245				goto cooldown
246			}
247
248			log.Fatalf("accept error: %v", err)
249
250		cooldown:
251			delay := 5 * time.Millisecond
252			log.Printf("accept error: %v; retrying in %s", err, delay)
253			time.Sleep(delay)
254		}
255
256		_ = poller.Resume(acceptDesc)
257	})
258
259	go func() {
260		http.Handle("/", http.FileServer(http.Dir("./")))
261		log.Printf("run http server on :8000")
262		if err := http.ListenAndServe(":8000", nil); err != nil {
263			log.Fatal(err)
264		}
265	}()
266
267	waitExitSignal(node)
268	log.Println("bye!")
269}
270
// nameConn renders conn as "<local addr> > <remote addr>" for use as a log
// line prefix.
func nameConn(conn net.Conn) string {
	local, remote := conn.LocalAddr(), conn.RemoteAddr()
	return strings.Join([]string{local.String(), remote.String()}, " > ")
}
274
// deadliner is a wrapper around net.Conn that re-arms a deadline of t before
// every Read() or Write() call. Each individual I/O call, rather than the
// whole connection lifetime, is bounded by t.
//
// A non-positive t disables the deadline: the read/write deadline is cleared
// instead of being set to a moment that has already passed. A past deadline
// would make every call fail immediately with a timeout.
type deadliner struct {
	net.Conn
	t time.Duration
}

// deadline returns the absolute deadline for an I/O call starting now, or the
// zero time.Time (which net.Conn treats as "no deadline") when t <= 0.
func (d deadliner) deadline() time.Time {
	if d.t <= 0 {
		return time.Time{}
	}
	return time.Now().Add(d.t)
}

// Write sets the write deadline and then writes p to the wrapped conn.
func (d deadliner) Write(p []byte) (int, error) {
	if err := d.Conn.SetWriteDeadline(d.deadline()); err != nil {
		return 0, err
	}
	return d.Conn.Write(p)
}

// Read sets the read deadline and then reads into p from the wrapped conn.
func (d deadliner) Read(p []byte) (int, error) {
	if err := d.Conn.SetReadDeadline(d.deadline()); err != nil {
		return 0, err
	}
	return d.Conn.Read(p)
}
295