1/*
2Copyright 2013 Google Inc.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8     http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package groupcache
18
19import (
20	"bytes"
21	"fmt"
22	"io"
23	"net/http"
24	"net/url"
25	"strings"
26	"sync"
27
28	"github.com/golang/groupcache/consistenthash"
29	pb "github.com/golang/groupcache/groupcachepb"
30	"github.com/golang/protobuf/proto"
31)
32
const (
	// defaultBasePath is the URL path prefix used for peer requests when
	// HTTPPoolOptions.BasePath is empty.
	defaultBasePath = "/_groupcache/"

	// defaultReplicas is the number of consistent-hash replicas used when
	// HTTPPoolOptions.Replicas is zero.
	defaultReplicas = 50
)
36
// HTTPPool implements PeerPicker for a pool of HTTP peers.
//
// It is also an http.Handler: its ServeHTTP method serves values from this
// process's groups to other peers. Pools are created with NewHTTPPool or
// NewHTTPPoolOpts, which allow only one pool per process.
type HTTPPool struct {
	// Context optionally specifies a context for the server to use when it
	// receives a request.
	// If nil, the server uses a nil Context.
	Context func(*http.Request) Context

	// Transport optionally specifies an http.RoundTripper for the client
	// to use when it makes a request.
	// If nil, the client uses http.DefaultTransport.
	//
	// Set copies this value into every peer getter it builds. It should
	// therefore be assigned before calling Set; a later change takes effect
	// only after the next Set.
	Transport func(Context) http.RoundTripper

	// this peer's base URL, e.g. "https://example.net:8000".
	// PickPeer compares a key's consistent-hash owner against this value to
	// decide whether the key belongs to this process.
	self string

	// opts specifies the options, with BasePath and Replicas already
	// defaulted by NewHTTPPoolOpts.
	opts HTTPPoolOptions

	mu          sync.Mutex // guards peers and httpGetters
	peers       *consistenthash.Map
	httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008"
}
59
// HTTPPoolOptions are the configurations of an HTTPPool.
// NewHTTPPoolOpts treats a nil *HTTPPoolOptions like the zero value, so
// every field is optional.
type HTTPPoolOptions struct {
	// BasePath specifies the HTTP path that will serve groupcache requests.
	// If empty, it defaults to "/_groupcache/".
	// Set appends it directly to each peer's base URL when building request
	// URLs. It presumably should begin and end with '/' like the default
	// (TODO confirm before relying on other forms).
	BasePath string

	// Replicas specifies the number of key replicas on the consistent hash.
	// If zero, it defaults to 50.
	Replicas int

	// HashFn specifies the hash function of the consistent hash.
	// If nil, it is passed through to consistenthash.New, which defaults to
	// crc32.ChecksumIEEE.
	HashFn consistenthash.Hash
}
74
75// NewHTTPPool initializes an HTTP pool of peers, and registers itself as a PeerPicker.
76// For convenience, it also registers itself as an http.Handler with http.DefaultServeMux.
77// The self argument should be a valid base URL that points to the current server,
78// for example "http://example.net:8000".
79func NewHTTPPool(self string) *HTTPPool {
80	p := NewHTTPPoolOpts(self, nil)
81	http.Handle(p.opts.BasePath, p)
82	return p
83}
84
85var httpPoolMade bool
86
87// NewHTTPPoolOpts initializes an HTTP pool of peers with the given options.
88// Unlike NewHTTPPool, this function does not register the created pool as an HTTP handler.
89// The returned *HTTPPool implements http.Handler and must be registered using http.Handle.
90func NewHTTPPoolOpts(self string, o *HTTPPoolOptions) *HTTPPool {
91	if httpPoolMade {
92		panic("groupcache: NewHTTPPool must be called only once")
93	}
94	httpPoolMade = true
95
96	p := &HTTPPool{
97		self:        self,
98		httpGetters: make(map[string]*httpGetter),
99	}
100	if o != nil {
101		p.opts = *o
102	}
103	if p.opts.BasePath == "" {
104		p.opts.BasePath = defaultBasePath
105	}
106	if p.opts.Replicas == 0 {
107		p.opts.Replicas = defaultReplicas
108	}
109	p.peers = consistenthash.New(p.opts.Replicas, p.opts.HashFn)
110
111	RegisterPeerPicker(func() PeerPicker { return p })
112	return p
113}
114
115// Set updates the pool's list of peers.
116// Each peer value should be a valid base URL,
117// for example "http://example.net:8000".
118func (p *HTTPPool) Set(peers ...string) {
119	p.mu.Lock()
120	defer p.mu.Unlock()
121	p.peers = consistenthash.New(p.opts.Replicas, p.opts.HashFn)
122	p.peers.Add(peers...)
123	p.httpGetters = make(map[string]*httpGetter, len(peers))
124	for _, peer := range peers {
125		p.httpGetters[peer] = &httpGetter{transport: p.Transport, baseURL: peer + p.opts.BasePath}
126	}
127}
128
129func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) {
130	p.mu.Lock()
131	defer p.mu.Unlock()
132	if p.peers.IsEmpty() {
133		return nil, false
134	}
135	if peer := p.peers.Get(key); peer != p.self {
136		return p.httpGetters[peer], true
137	}
138	return nil, false
139}
140
141func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
142	// Parse request.
143	if !strings.HasPrefix(r.URL.Path, p.opts.BasePath) {
144		panic("HTTPPool serving unexpected path: " + r.URL.Path)
145	}
146	parts := strings.SplitN(r.URL.Path[len(p.opts.BasePath):], "/", 2)
147	if len(parts) != 2 {
148		http.Error(w, "bad request", http.StatusBadRequest)
149		return
150	}
151	groupName := parts[0]
152	key := parts[1]
153
154	// Fetch the value for this group/key.
155	group := GetGroup(groupName)
156	if group == nil {
157		http.Error(w, "no such group: "+groupName, http.StatusNotFound)
158		return
159	}
160	var ctx Context
161	if p.Context != nil {
162		ctx = p.Context(r)
163	}
164
165	group.Stats.ServerRequests.Add(1)
166	var value []byte
167	err := group.Get(ctx, key, AllocatingByteSliceSink(&value))
168	if err != nil {
169		http.Error(w, err.Error(), http.StatusInternalServerError)
170		return
171	}
172
173	// Write the value to the response body as a proto message.
174	body, err := proto.Marshal(&pb.GetResponse{Value: value})
175	if err != nil {
176		http.Error(w, err.Error(), http.StatusInternalServerError)
177		return
178	}
179	w.Header().Set("Content-Type", "application/x-protobuf")
180	w.Write(body)
181}
182
183type httpGetter struct {
184	transport func(Context) http.RoundTripper
185	baseURL   string
186}
187
// bufferPool holds reusable scratch buffers. httpGetter.Get uses them to
// read peer response bodies, so buffers can be recycled across calls.
var bufferPool = sync.Pool{
	New: func() interface{} {
		return &bytes.Buffer{}
	},
}
191
192func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error {
193	u := fmt.Sprintf(
194		"%v%v/%v",
195		h.baseURL,
196		url.QueryEscape(in.GetGroup()),
197		url.QueryEscape(in.GetKey()),
198	)
199	req, err := http.NewRequest("GET", u, nil)
200	if err != nil {
201		return err
202	}
203	tr := http.DefaultTransport
204	if h.transport != nil {
205		tr = h.transport(context)
206	}
207	res, err := tr.RoundTrip(req)
208	if err != nil {
209		return err
210	}
211	defer res.Body.Close()
212	if res.StatusCode != http.StatusOK {
213		return fmt.Errorf("server returned: %v", res.Status)
214	}
215	b := bufferPool.Get().(*bytes.Buffer)
216	b.Reset()
217	defer bufferPool.Put(b)
218	_, err = io.Copy(b, res.Body)
219	if err != nil {
220		return fmt.Errorf("reading response body: %v", err)
221	}
222	err = proto.Unmarshal(b.Bytes(), out)
223	if err != nil {
224		return fmt.Errorf("decoding response body: %v", err)
225	}
226	return nil
227}
228