1package mockserver
2
3import (
4	"context"
5	"fmt"
6	"io"
7	"log"
8	"net"
9
10	"github.com/dustin/gomemcached"
11	"github.com/dustin/gomemcached/server"
12)
13
// Command opcodes that tests can pass to RegisterHandler. The values are the
// memcached binary-protocol opcodes for NOOP and STAT.
const (
	Noop = 0x0A
	Stat = 0x10
)

// Response status codes that handlers can place in an MCResponse. The values
// are the memcached binary-protocol "no error" and "unknown command" statuses.
const (
	Success        = 0x00
	UnknownCommand = 0x81
)
23
// debug enables the output of printDebugf. It is a compile-time constant, so
// flipping it requires a rebuild.
const debug = false

// printDebugf forwards format and v to log.Printf when debug is true and
// otherwise does nothing.
func printDebugf(format string, v ...interface{}) {
	if !debug {
		return
	}
	log.Printf(format, v...)
}
31
32type MCRequest = gomemcached.MCRequest
33type MCResponse = gomemcached.MCResponse
34
35type HandlerFunc func(req *MCRequest, w io.Writer) *MCResponse
36
// MockServer is a mock memcached server. Responses come from handlers
// registered per opcode with RegisterHandler; opcodes without a handler are
// answered with an UNKNOWN_COMMAND status.
//
// Create it with NewMockServer: the zero value has a nil handlers map, nil ctx
// and nil stop function, so it cannot be used.
type MockServer struct {
	// handlers maps opcodes to their handler; read by the dispatch goroutine.
	handlers map[gomemcached.CommandCode]HandlerFunc
	// listener is created by ListenAndServe and closed by Stop.
	listener net.Listener
	// ctx is cancelled by Stop so the accept loop in ListenAndServe returns.
	// NOTE(review): storing a context in a struct is discouraged by the
	// context package docs; kept for compatibility.
	ctx      context.Context
	// port is the localhost port chosen by NewMockServer.
	port     int
	// stop cancels ctx.
	stop     context.CancelFunc
}
44
45func (s *MockServer) RegisterHandler(code uint8, fn HandlerFunc) {
46	s.handlers[gomemcached.CommandCode(code)] = fn
47}
48
// chanReq carries one request from a connection goroutine
// (reqHandler.HandleMessage) to the dispatch goroutine. It also carries the
// channel the reply is sent back on.
type chanReq struct {
	// req is the request read from the connection.
	req *MCRequest
	// res receives exactly one response from dispatch.
	res chan *MCResponse
	// w is the connection's writer, passed through to the handler.
	w   io.Writer
}
54
55func notFound(_ *MCRequest) *MCResponse {
56	return &MCResponse{
57		Status: gomemcached.UNKNOWN_COMMAND,
58	}
59}
60
61type reqHandler struct {
62	ch chan chanReq
63}
64
65func (rh *reqHandler) HandleMessage(w io.Writer, req *MCRequest) *MCResponse {
66	cr := chanReq{
67		req,
68		make(chan *MCResponse),
69		w,
70	}
71
72	rh.ch <- cr
73
74	return <-cr.res
75}
76
77func (s *MockServer) handle(req *MCRequest, w io.Writer) (rv *MCResponse) {
78	if h, ok := s.handlers[req.Opcode]; ok {
79		rv = h(req, w)
80	} else {
81		return notFound(req)
82	}
83
84	return rv
85}
86
87func (s *MockServer) dispatch(input chan chanReq) {
88	// TODO: stop goroutine
89	for {
90		req := <-input
91		printDebugf("Got a request: %s", req.req)
92		req.res <- s.handle(req.req, req.w)
93	}
94}
95
96func handleIO(conn net.Conn, rh memcached.RequestHandler) {
97	// Explicitly ignoring errors since they all result in the
98	// client getting hung up on and many are common.
99	_ = memcached.HandleIO(conn, rh)
100}
101
102func (s *MockServer) ListenAndServe() {
103	var err error
104
105	s.listener, err = net.Listen("tcp", s.GetAddr())
106	if err != nil {
107		panic(err)
108	}
109
110	reqChannel := make(chan chanReq)
111
112	go s.dispatch(reqChannel)
113	rh := &reqHandler{reqChannel}
114
115	printDebugf("Listening on %s", s.listener.Addr())
116
117	for {
118		conn, err := s.listener.Accept()
119		select {
120		case <-s.ctx.Done():
121			printDebugf("Server stopped")
122			return
123		default:
124			if err != nil {
125				printDebugf("Error accepting from %s", s.listener)
126			} else {
127				printDebugf("Got a connection from %v", conn.RemoteAddr())
128				go handleIO(conn, rh)
129			}
130		}
131	}
132}
133
134func (s *MockServer) Stop() {
135	// TODO: finish client requests and close connections
136	if s.listener != nil {
137		s.stop()
138		_ = s.listener.Close()
139	}
140}
141
142func (s *MockServer) GetAddr() string {
143	return fmt.Sprintf("localhost:%d", s.port)
144}
145
// getFreePort asks the OS for an unused TCP port on localhost by binding to
// port 0, then releases the socket and returns the port number.
//
// The port is not held afterwards, so another listener may claim it before
// the caller binds it.
func getFreePort() (int, error) {
	probeAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}

	probe, err := net.ListenTCP("tcp", probeAddr)
	if err != nil {
		return 0, err
	}

	// Read the kernel-assigned port before releasing the socket. The close
	// error is irrelevant because the listener was only a probe.
	port := probe.Addr().(*net.TCPAddr).Port
	_ = probe.Close()

	return port, nil
}
161
162func NewMockServer() (srv *MockServer, err error) {
163	ctx, cancel := context.WithCancel(context.Background())
164
165	srv = &MockServer{
166		handlers: make(map[gomemcached.CommandCode]HandlerFunc),
167		ctx:      ctx,
168		stop:     cancel,
169	}
170
171	srv.port, err = getFreePort()
172	if err != nil {
173		return nil, err
174	}
175
176	return srv, nil
177}
178