1package mockserver 2 3import ( 4 "context" 5 "fmt" 6 "io" 7 "log" 8 "net" 9 10 "github.com/dustin/gomemcached" 11 "github.com/dustin/gomemcached/server" 12) 13 14const ( 15 Noop = 0x0a 16 Stat = 0x10 17) 18 19const ( 20 Success = 0x00 21 UnknownCommand = 0x81 22) 23 24const debug = false 25 26func printDebugf(format string, v ...interface{}) { 27 if debug { 28 log.Printf(format, v...) 29 } 30} 31 32type MCRequest = gomemcached.MCRequest 33type MCResponse = gomemcached.MCResponse 34 35type HandlerFunc func(req *MCRequest, w io.Writer) *MCResponse 36 37type MockServer struct { 38 handlers map[gomemcached.CommandCode]HandlerFunc 39 listener net.Listener 40 ctx context.Context 41 port int 42 stop context.CancelFunc 43} 44 45func (s *MockServer) RegisterHandler(code uint8, fn HandlerFunc) { 46 s.handlers[gomemcached.CommandCode(code)] = fn 47} 48 49type chanReq struct { 50 req *MCRequest 51 res chan *MCResponse 52 w io.Writer 53} 54 55func notFound(_ *MCRequest) *MCResponse { 56 return &MCResponse{ 57 Status: gomemcached.UNKNOWN_COMMAND, 58 } 59} 60 61type reqHandler struct { 62 ch chan chanReq 63} 64 65func (rh *reqHandler) HandleMessage(w io.Writer, req *MCRequest) *MCResponse { 66 cr := chanReq{ 67 req, 68 make(chan *MCResponse), 69 w, 70 } 71 72 rh.ch <- cr 73 74 return <-cr.res 75} 76 77func (s *MockServer) handle(req *MCRequest, w io.Writer) (rv *MCResponse) { 78 if h, ok := s.handlers[req.Opcode]; ok { 79 rv = h(req, w) 80 } else { 81 return notFound(req) 82 } 83 84 return rv 85} 86 87func (s *MockServer) dispatch(input chan chanReq) { 88 // TODO: stop goroutine 89 for { 90 req := <-input 91 printDebugf("Got a request: %s", req.req) 92 req.res <- s.handle(req.req, req.w) 93 } 94} 95 96func handleIO(conn net.Conn, rh memcached.RequestHandler) { 97 // Explicitly ignoring errors since they all result in the 98 // client getting hung up on and many are common. 99 _ = memcached.HandleIO(conn, rh) 100} 101 102func (s *MockServer) ListenAndServe() { 103 var err error 104 105 s.listener, err = net.Listen("tcp", s.GetAddr()) 106 if err != nil { 107 panic(err) 108 } 109 110 reqChannel := make(chan chanReq) 111 112 go s.dispatch(reqChannel) 113 rh := &reqHandler{reqChannel} 114 115 printDebugf("Listening on %s", s.listener.Addr()) 116 117 for { 118 conn, err := s.listener.Accept() 119 select { 120 case <-s.ctx.Done(): 121 printDebugf("Server stopped") 122 return 123 default: 124 if err != nil { 125 printDebugf("Error accepting from %s", s.listener) 126 } else { 127 printDebugf("Got a connection from %v", conn.RemoteAddr()) 128 go handleIO(conn, rh) 129 } 130 } 131 } 132} 133 134func (s *MockServer) Stop() { 135 // TODO: finish client requests and close connections 136 if s.listener != nil { 137 s.stop() 138 _ = s.listener.Close() 139 } 140} 141 142func (s *MockServer) GetAddr() string { 143 return fmt.Sprintf("localhost:%d", s.port) 144} 145 146func getFreePort() (int, error) { 147 addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 148 if err != nil { 149 return 0, err 150 } 151 152 ls, err := net.ListenTCP("tcp", addr) 153 if err != nil { 154 return 0, err 155 } 156 157 _ = ls.Close() 158 159 return ls.Addr().(*net.TCPAddr).Port, nil 160} 161 162func NewMockServer() (srv *MockServer, err error) { 163 ctx, cancel := context.WithCancel(context.Background()) 164 165 srv = &MockServer{ 166 handlers: make(map[gomemcached.CommandCode]HandlerFunc), 167 ctx: ctx, 168 stop: cancel, 169 } 170 171 srv.port, err = getFreePort() 172 if err != nil { 173 return nil, err 174 } 175 176 return srv, nil 177} 178