package main

import (
	"context"
	"flag"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"runtime"
	"strings"
	"syscall"
	"time"

	_ "net/http/pprof"

	"github.com/centrifugal/centrifuge"
	"github.com/gobwas/ws"
	"github.com/mailru/easygo/netpoll"
)

var (
	workers   = flag.Int("workers", 128, "max workers count")
	queue     = flag.Int("queue", 1, "workers task queue size")
	ioTimeout = flag.Duration("io_timeout", time.Second, "i/o operations timeout")
)

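// handleLog adapts centrifuge log entries to the standard library logger.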
func handleLog(e centrifuge.LogEntry) {
	log.Printf("%s: %v", e.Message, e.Fields)
}

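// waitExitSignal blocks until SIGINT or SIGTERM is received and then
// gracefully shuts down the node.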
func waitExitSignal(n *centrifuge.Node) {
	sigCh := make(chan os.Signal, 1)
	done := make(chan bool, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-sigCh
		_ = n.Shutdown(context.Background())
		done <- true
	}()
	<-done
}

func main() {
	log.Printf("NumCPU: %d", runtime.NumCPU())

	flag.Parse()

	cfg := centrifuge.DefaultConfig
	cfg.LogLevel = centrifuge.LogLevelError
	cfg.LogHandler = handleLog

	node, err := centrifuge.New(cfg)
	if err != nil {
		log.Fatal(err)
	}

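	// Authenticate every incoming connection as the same static "bench"
	// user. A real application would inspect the connect request or token
	// in this callback.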
	node.OnConnecting(func(ctx context.Context, e centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
		return centrifuge.ConnectReply{
			Credentials: &centrifuge.Credentials{
				UserID: "bench",
			},
		}, nil
	})

	node.OnConnect(func(client *centrifuge.Client) {
		client.OnSubscribe(func(e centrifuge.SubscribeEvent, cb centrifuge.SubscribeCallback) {
			cb(centrifuge.SubscribeReply{}, nil)
		})

		client.OnPublish(func(e centrifuge.PublishEvent, cb centrifuge.PublishCallback) {
			cb(centrifuge.PublishReply{}, nil)
		})

		client.OnMessage(func(e centrifuge.MessageEvent) {
			err := client.Send(e.Data)
			if err != nil && err != io.EOF {
				// Do not kill the whole server because of one broken client
				// connection; just log the error.
				log.Printf("error sending to client: %v", err)
			}
		})
	})

	if err := node.Run(); err != nil {
		log.Fatal(err)
	}

	// Initialize netpoll instance. We will use it to be notified about
	// incoming events from the listener and from established client
	// connections.
	poller, err := netpoll.New(nil)
	if err != nil {
		log.Fatal(err)
	}

	// Create a goroutine pool sized by the -workers flag, with a task queue
	// sized by -queue and one pre-spawned goroutine.
	pool := NewPool(*workers, *queue, 1)

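	// NOTE: Pool, NewPool, Schedule, ScheduleTimeout and ErrScheduleTimeout
	// are assumed to be defined in a sibling file of this example (e.g. a
	// pool.go next to this main.go): a bounded worker pool whose Schedule
	// blocks until a worker is free and whose ScheduleTimeout gives up with
	// ErrScheduleTimeout when no worker picks the task up in time.
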
	// handle is a new incoming connection handler.
	// It upgrades the TCP connection to WebSocket, registers a netpoll
	// listener on it and creates a centrifuge client over the custom
	// transport.
	//
	// We will call it below within the accept() loop.
	handle := func(conn net.Conn) {
		// NOTE: we wrap conn here to show that ws can work with any kind of
		// io.ReadWriter.
		safeConn := deadliner{conn, *ioTimeout}

		protoType := centrifuge.ProtocolTypeJSON

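		// Choose the centrifuge protocol based on a query parameter: clients
		// can request Protobuf framing by adding "format=protobuf" to the
		// connection URL.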
		upgrader := ws.Upgrader{
			OnRequest: func(uri []byte) error {
				if strings.Contains(string(uri), "format=protobuf") {
					protoType = centrifuge.ProtocolTypeProtobuf
				}
				return nil
			},
		}

		// Zero-copy upgrade to WebSocket connection.
		hs, err := upgrader.Upgrade(safeConn)
		if err != nil {
			log.Printf("%s: upgrade error: %v", nameConn(conn), err)
			_ = conn.Close()
			return
		}

		log.Printf("%s: established websocket connection: %+v", nameConn(conn), hs)

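		// newWebsocketTransport is assumed to be defined in a sibling file of
		// this example: it implements the centrifuge.Transport interface on
		// top of the raw connection using gobwas/ws frame reading and writing.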
		transport := newWebsocketTransport(safeConn, protoType)
		client, closeFn, err := centrifuge.NewClient(context.Background(), node, transport)
		if err != nil {
			log.Printf("%s: client create error: %v", nameConn(conn), err)
			_ = conn.Close()
			return
		}

		// Create netpoll event descriptor for conn.
		// We want to handle only read events on it.
		desc := netpoll.Must(netpoll.HandleReadOnce(conn))

		// Subscribe to events about conn.
		_ = poller.Start(desc, func(ev netpoll.Event) {
			if ev&(netpoll.EventReadHup|netpoll.EventHup) != 0 {
				// Receiving ReadHup or Hup means the client has closed at
				// least the write end of the connection, or the connection
				// itself. Stop receiving events for this conn and close the
				// centrifuge client.
				_ = poller.Stop(desc)
				_ = closeFn()
				return
			}
			// Here we can read a new message from the connection.
			// We cannot read it right here in the callback, because that
			// would block the poller's inner loop.
			// We also do not want to spawn a new goroutine per message, so
			// we reuse a previously spawned goroutine from the pool.
			pool.Schedule(func() {
				if data, isControl, err := transport.read(); err != nil {
					// When the read fails we can only close the broken
					// connection and stop receiving events for it.
					_ = poller.Stop(desc)
					_ = closeFn()
				} else {
					if !isControl {
						ok := client.Handle(data)
						if !ok {
							_ = poller.Stop(desc)
							return
						}
					}
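					// The descriptor was registered with HandleReadOnce, so
					// it must be resumed to receive the next read event.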
					_ = poller.Resume(desc)
				}
			})
		})
	}

	// Create incoming connections listener.
	ln, err := net.Listen("tcp", ":8000")
	if err != nil {
		log.Fatal(err)
	}

	log.Printf("websocket is listening on %s", ln.Addr().String())

	// Create netpoll descriptor for the listener.
	// We use EventOneShot here so we can manually resume the event stream
	// when we are ready.
	acceptDesc := netpoll.Must(netpoll.HandleListener(
		ln, netpoll.EventRead|netpoll.EventOneShot,
	))

	// accept is a channel used to signal the result of the next Accept()
	// call.
	accept := make(chan error, 1)

	// Subscribe to events about the listener.
	_ = poller.Start(acceptDesc, func(e netpoll.Event) {
		// We do not want to accept a new connection while the goroutine pool
		// is busy. So if no worker becomes free within 1ms we cool the
		// server down and stop accepting connections for a short time.
		err := pool.ScheduleTimeout(time.Millisecond, func() {
			conn, err := ln.Accept()
			if err != nil {
				accept <- err
				return
			}

			accept <- nil
			handle(conn)
		})
		if err == nil {
			err = <-accept
		}

		if err != nil {
			cooldown := func() {
				delay := 5 * time.Millisecond
				log.Printf("accept error: %v; retrying in %s", err, delay)
				time.Sleep(delay)
			}
			if err == ErrScheduleTimeout {
				cooldown()
			} else if ne, ok := err.(net.Error); ok && ne.Temporary() {
				// net.Error.Temporary is deprecated since Go 1.18, but it is
				// still a convenient way to back off on transient accept
				// errors in this example.
				cooldown()
			} else {
				log.Fatalf("accept error: %v", err)
			}
		}

		_ = poller.Resume(acceptDesc)
	})

	go func() {
		http.Handle("/", http.FileServer(http.Dir("./")))
		log.Printf("run http server on :9000")
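		// A nil handler means http.DefaultServeMux is used, so this server
		// also exposes the pprof handlers registered by the net/http/pprof
		// import under /debug/pprof/.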
		if err := http.ListenAndServe(":9000", nil); err != nil {
			log.Fatal(err)
		}
	}()

	waitExitSignal(node)
	log.Println("bye!")
}

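// nameConn returns a human-readable "local > remote" address pair for a
// connection, used as a log prefix.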
func nameConn(conn net.Conn) string {
	return conn.LocalAddr().String() + " > " + conn.RemoteAddr().String()
}

// deadliner is a wrapper around net.Conn that sets read/write deadlines before
// every Read() or Write() call.
type deadliner struct {
	net.Conn
	t time.Duration
}

func (d deadliner) Write(p []byte) (int, error) {
	if err := d.Conn.SetWriteDeadline(time.Now().Add(d.t)); err != nil {
		return 0, err
	}
	return d.Conn.Write(p)
}

func (d deadliner) Read(p []byte) (int, error) {
	if err := d.Conn.SetReadDeadline(time.Now().Add(d.t)); err != nil {
		return 0, err
	}
	return d.Conn.Read(p)
}