1package rest
2
3import (
4	"errors"
5	"fmt"
6	"log"
7	"net/http"
8	"path"
9	"reflect"
10	"strconv"
11	"strings"
12	"time"
13
14	"github.com/pascaldekloe/goe/el"
15)
16
// ErrNotFound signals that no resource exists for the requested ID.
// Read, update and delete functions must return it on a miss.
// See CRUDRepo's SetReadFunc, SetUpdateFunc and SetDeleteFunc for the details.
var ErrNotFound = errors.New("no such entry")

// ErrOptimisticLock signals that the version in the request is not the
// latest one available, in which case the operation must not take place.
// See CRUDRepo's SetUpdateFunc and SetDeleteFunc for the details.
var ErrOptimisticLock = errors.New("lost optimistic lock")
26
var (
	// keyType is the reflected type of both identifiers and versions.
	keyType = reflect.TypeOf((*int64)(nil)).Elem()

	// errorType is the reflected type of the built-in error interface.
	errorType = reflect.TypeOf((*error)(nil)).Elem()
)
31
// CRUDRepo is a REST repository.
// Each operation is enabled individually with SetCreateFunc, SetReadFunc,
// SetUpdateFunc or SetDeleteFunc; ServeHTTP does not serve a nil operation.
type CRUDRepo struct {
	// mountLoc is the root path of this repository (path.Clean applied by NewCRUD).
	mountLoc string

	// versionPath is the GoEL expression to the data's version int64.
	versionPath string

	// The reflected operation functions, or nil when not enabled.
	create *reflect.Value
	read   *reflect.Value
	update *reflect.Value
	delete *reflect.Value

	// dataType is the pointer type exchanged by the create, read and
	// update operations, as registered through setDataType.
	dataType reflect.Type
}
44
45// NewCRUD returns a new REST repository for the CRUD operations.
46// The mountLocation specifies the root for CRUDRepo.ServeHTTP.
47// The versionPath is a GoEL expression to the date version (in the data type).
48//
49// It's operation is based on two assumptions.
50// 1) Identifiers are int64.
51// 2) Versions are int64 unix timestamps in nanoseconds.
52func NewCRUD(mountLocation, versionPath string) *CRUDRepo {
53	return &CRUDRepo{
54		mountLoc:    path.Clean(mountLocation),
55		versionPath: versionPath,
56	}
57}
58
59// SetCreateFunc enables create support.
60// The method panics on any of the following conditions.
61// 1) f does not match signature func(data T) (id int64, err error)
62// 2) Data type T does not match the other CRUD operations.
63// 3) Data type T is not a pointer.
64//
65// It is the responsibility of f to set the version.
66func (repo *CRUDRepo) SetCreateFunc(f interface{}) {
67	v := reflect.ValueOf(f)
68	repo.create = &v
69
70	t := v.Type()
71	if t.Kind() != reflect.Func || t.NumIn() != 1 || t.NumOut() != 2 || t.Out(0) != keyType || !t.Out(1).Implements(errorType) {
72		log.Panic("create is not a func(data T, id int64) error")
73	}
74	repo.setDataType(t.In(0))
75}
76
77// SetReadFunc enables read support.
78// The method panics on any of the following conditions.
79// 1) f does not match signature func(id, version int64) (hit T, err error)
80// 2) Data type T does not match the other CRUD operations.
81// 3) Data type T is not a pointer.
82//
83// When the id is not found f must return ErrNotFound.
84// The version must be honored and the latest version should be served as a fallback.
85func (repo *CRUDRepo) SetReadFunc(f interface{}) {
86	v := reflect.ValueOf(f)
87	repo.read = &v
88
89	t := v.Type()
90	if t.Kind() != reflect.Func || t.NumIn() != 2 || t.In(0) != keyType || t.In(1) != keyType || t.NumOut() != 2 || !t.Out(1).Implements(errorType) {
91		log.Panic("read is not a func(id, version int64) (T, error)")
92	}
93	repo.setDataType(t.Out(0))
94}
95
96// SetUpdateFunc enables update support.
97// The method panics on any of the following conditions.
98// 1) f does not match signature func(id int64, data T) (error)
99// 2) Data type T does not match the other CRUD operations.
100// 3) Data type T is not a pointer.
101//
102// When the id is not found f must return ErrNotFound.
103// When the data's version is not equal to 0 and version does not match the latest
104// one available then f must skip normal operation and return ErrOptimisticLock.
105// It is the responsibility of f to set the new version.
106func (repo *CRUDRepo) SetUpdateFunc(f interface{}) {
107	v := reflect.ValueOf(f)
108	repo.update = &v
109
110	t := v.Type()
111	if t.Kind() != reflect.Func || t.NumIn() != 2 || t.In(0) != keyType || t.NumOut() != 1 || !t.Out(0).Implements(errorType) {
112		log.Panic("update is not a func(id int64, data T) error")
113	}
114	repo.setDataType(t.In(1))
115
116}
117
118// SetUpdateFunc enables update support.
119// The method panics when f does not match signature func(id, version int64) error.
120//
121// When the id is not found f must return ErrNotFound.
122// When the version is not equal to 0 and version does not match the latest
123// one available then f must skip normal operation and return ErrOptimisticLock.
124func (repo *CRUDRepo) SetDeleteFunc(f interface{}) {
125	v := reflect.ValueOf(f)
126	repo.delete = &v
127
128	t := v.Type()
129	if t.Kind() != reflect.Func || t.NumIn() != 2 || t.In(0) != keyType || t.In(1) != keyType && t.NumOut() != 1 || !t.Out(0).Implements(errorType) {
130		log.Panic("delete is not a func(id, version int64) error")
131	}
132}
133
134func (repo *CRUDRepo) setDataType(t reflect.Type) {
135	if t.Kind() != reflect.Ptr {
136		log.Panicf("goe rest: CRUD operation's data type %s must be a pointer", t)
137	}
138
139	switch repo.dataType {
140	case nil:
141		repo.dataType = t
142
143		if n := el.Assign(reflect.New(t).Interface(), repo.versionPath, 99); n != 1 {
144			log.Panicf("goe rest: version path %q matches %d element on type %s", repo.versionPath, n, t)
145		}
146	case t:
147		// do nothing
148	default:
149		log.Panicf("goe rest: CRUD operation's data type %s does not match %s", t, repo.dataType)
150	}
151}
152
153// ServeHTTP honors the http.Handler interface for the mount point provided with NewCRUD.
154// For now only JSON is supported.
155func (repo *CRUDRepo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
156	p := path.Clean(r.URL.Path)
157	if !strings.HasPrefix(p, repo.mountLoc) {
158		log.Printf("goe rest: basepath %q mismatch with %q from %q", repo.mountLoc, p, r.URL.String())
159		http.Error(w, "path mismatch", http.StatusNotFound)
160		return
161	}
162
163	if len(p) == len(repo.mountLoc) {
164		switch r.Method {
165		default:
166			if repo.create != nil {
167				w.Header().Set("Allow", "POST")
168				w.WriteHeader(http.StatusMethodNotAllowed)
169			} else {
170				http.Error(w, "", http.StatusNotFound)
171			}
172		case "POST":
173			if repo.create != nil {
174				repo.serveCreate(w, r)
175			} else {
176				http.Error(w, "", http.StatusNotFound)
177			}
178		}
179		return
180	}
181
182	if i := len(repo.mountLoc); p[i] == '/' {
183		p = p[i+1:]
184	} else {
185		p = p[i:]
186	}
187	if i := strings.IndexByte(p, '/'); i >= 0 {
188		http.Error(w, fmt.Sprintf("goe rest: no such subdirectory: %q", p), http.StatusNotFound)
189		return
190	}
191
192	id, err := strconv.ParseInt(p, 10, 64)
193	if err != nil {
194		http.Error(w, fmt.Sprintf("goe rest: malformed ID: %s", err), http.StatusNotFound)
195		return
196	}
197
198	switch r.Method {
199	case "GET", "HEAD":
200		if repo.read != nil {
201			repo.serveRead(w, r, id)
202			return
203		}
204	case "PUT":
205		if repo.update != nil {
206			repo.serveUpdate(w, r, id)
207			return
208		}
209	case "DELETE":
210		if repo.delete != nil {
211			repo.serveDelete(w, r, id)
212			return
213		}
214	case "OPTIONS":
215		w.Header().Set("Allow", repo.resourceMethods())
216		w.WriteHeader(http.StatusOK)
217		return
218	}
219	w.Header().Set("Allow", repo.resourceMethods())
220	w.WriteHeader(http.StatusMethodNotAllowed)
221}
222
223// resourceMethods lists the HTTP methods served on r's resources.
224func (r *CRUDRepo) resourceMethods() string {
225	a := make([]string, 1, 5)
226	a[0] = "OPTIONS"
227	if r.read != nil {
228		a = append(a, "GET", "HEAD")
229	}
230	if r.update != nil {
231		a = append(a, "PUT")
232	}
233	if r.delete != nil {
234		a = append(a, "DELETE")
235	}
236	return strings.Join(a, ", ")
237}
238
239func (repo *CRUDRepo) serveCreate(w http.ResponseWriter, r *http.Request) {
240	v := reflect.New(repo.dataType)
241	if !ReceiveJSON(v.Interface(), r, w) {
242		return
243	}
244
245	result := repo.create.Call([]reflect.Value{v.Elem()})
246	if !result[1].IsNil() {
247		err := result[1].Interface().(error)
248		log.Print("goe/rest: create: ", err)
249		http.Error(w, err.Error(), http.StatusInternalServerError)
250		return
251	}
252
253	loc := *r.URL // copy
254	loc.Path = path.Join(loc.Path, strconv.FormatInt(result[0].Int(), 10))
255	loc.RawQuery = ""
256	loc.Fragment = ""
257
258	h := w.Header()
259	h.Set("Location", loc.String())
260
261	version, _ := el.Int(repo.versionPath, v.Interface())
262	h.Set("ETag", fmt.Sprintf(`"%d"`, version))
263
264	w.WriteHeader(http.StatusCreated)
265}
266
267func (repo *CRUDRepo) serveRead(w http.ResponseWriter, r *http.Request, id int64) {
268	versionReq, ok := versionQuery(r, w)
269	if !ok {
270		return
271	}
272
273	result := repo.read.Call([]reflect.Value{reflect.ValueOf(id), reflect.ValueOf(int64(versionReq))})
274	if !result[1].IsNil() {
275		switch err := result[1].Interface().(error); err {
276		case ErrNotFound:
277			http.Error(w, fmt.Sprintf("ID %d not found", id), http.StatusNotFound)
278		default:
279			log.Print("goe/rest: read: ", err)
280			http.Error(w, err.Error(), http.StatusInternalServerError)
281		}
282		return
283	}
284
285	version, _ := el.Int(repo.versionPath, result[0].Interface())
286	if versionReq != 0 && version != versionReq {
287		http.Error(w, fmt.Sprintf("version %d not found (latest is %d)", versionReq, version), http.StatusNotFound)
288		return
289	}
290
291	h := w.Header()
292
293	etag := fmt.Sprintf(`"%d"`, version)
294	h.Set("ETag", etag)
295
296	loc := *r.URL // copy
297	loc.RawQuery = fmt.Sprintf("v=%d", version)
298	loc.Fragment = ""
299	h.Set("Content-Location", loc.String())
300
301	h.Set("Allow", repo.resourceMethods())
302
303	// BUG(pascaldekloe): No support for multiple entity tags in If-None-Match header.
304	for _, s := range r.Header["If-None-Match"] {
305		if s == etag {
306			w.WriteHeader(http.StatusNotModified)
307			return
308		}
309	}
310
311	for _, s := range r.Header["If-Modified-Since"] {
312		t, err := time.Parse(time.RFC1123, s)
313		if err != nil {
314			http.Error(w, fmt.Sprintf("If-Unmodified-Since header %q not RFC1123 compliant: %s", s, err), http.StatusBadRequest)
315			return
316		}
317		// Round down to RFC 1123 resolution:
318		resolution := int64(time.Second)
319		if t.After(time.Unix(0, (version/resolution)*resolution)) {
320			w.WriteHeader(http.StatusNotModified)
321			return
322		}
323	}
324
325	timestamp := time.Unix(0, version)
326	h.Set("Last-Modified", timestamp.In(time.UTC).Format(time.RFC1123))
327
328	if r.Method != "HEAD" {
329		ServeJSON(w, http.StatusOK, result[0].Interface())
330	}
331}
332
333func (repo *CRUDRepo) serveUpdate(w http.ResponseWriter, r *http.Request, id int64) {
334	v := reflect.New(repo.dataType)
335	if !ReceiveJSON(v.Interface(), r, w) {
336		return
337	}
338
339	queryVersion, ok := versionQuery(r, w)
340	if !ok {
341		return
342	}
343
344	matchVersion, ok := versionMatch(r, w)
345	if !ok {
346		return
347	}
348
349	var version int64
350	switch {
351	case queryVersion == 0:
352		version = matchVersion
353	case matchVersion == 0:
354		version = queryVersion
355	case queryVersion == matchVersion:
356		version = matchVersion
357	default:
358		http.Error(w, fmt.Sprintf("query parameter v %d does not match If-Match header %d", queryVersion, matchVersion), http.StatusPreconditionFailed)
359		return
360	}
361	if version != 0 {
362		el.Assign(v.Interface(), repo.versionPath, version)
363	}
364
365	result := repo.update.Call([]reflect.Value{reflect.ValueOf(id), v.Elem()})
366	if !result[0].IsNil() {
367		switch err := result[0].Interface().(error); err {
368		case ErrNotFound:
369			http.Error(w, "", http.StatusNotFound)
370		case ErrOptimisticLock:
371			if matchVersion != 0 {
372				http.Error(w, err.Error(), http.StatusPreconditionFailed)
373				return
374			}
375			http.Error(w, "not the latest version", http.StatusMethodNotAllowed)
376		default:
377			log.Printf("goe rest: update %d v%d: %s", id, version, err)
378			http.Error(w, err.Error(), http.StatusInternalServerError)
379		}
380		return
381	}
382
383	h := w.Header()
384
385	version, _ = el.Int(repo.versionPath, v.Interface())
386	h.Set("ETag", fmt.Sprintf(`"%d"`, version))
387	h.Set("Last-Modified", time.Unix(0, version).In(time.UTC).Format(time.RFC1123))
388
389	h.Set("Allow", repo.resourceMethods())
390
391	loc := *r.URL // copy
392	loc.RawQuery = fmt.Sprintf("v=%d", version)
393	loc.Fragment = ""
394	h.Set("Content-Location", loc.String())
395
396	ServeJSON(w, http.StatusOK, v.Interface())
397}
398
399func (repo *CRUDRepo) serveDelete(w http.ResponseWriter, r *http.Request, id int64) {
400	queryVersion, ok := versionQuery(r, w)
401	if !ok {
402		return
403	}
404
405	matchVersion, ok := versionMatch(r, w)
406	if !ok {
407		return
408	}
409
410	var version int64
411	switch {
412	case queryVersion == 0:
413		version = matchVersion
414	case matchVersion == 0:
415		version = queryVersion
416	case queryVersion != matchVersion:
417		http.Error(w, fmt.Sprintf("query parameter v %d does not match If-Match header %d", queryVersion, matchVersion), http.StatusPreconditionFailed)
418		return
419	default:
420		version = matchVersion
421	}
422
423	result := repo.delete.Call([]reflect.Value{reflect.ValueOf(id), reflect.ValueOf(version)})
424	if !result[0].IsNil() {
425		switch err := result[0].Interface().(error); err {
426		case ErrNotFound:
427			http.Error(w, "", http.StatusNotFound)
428		case ErrOptimisticLock:
429			if matchVersion != 0 {
430				http.Error(w, err.Error(), http.StatusPreconditionFailed)
431				return
432			}
433			http.Error(w, "not the latest version", http.StatusMethodNotAllowed)
434		default:
435			log.Printf("goe rest: delete %d v%d: %s", id, version, err)
436			http.Error(w, err.Error(), http.StatusInternalServerError)
437		}
438		return
439	}
440
441	w.WriteHeader(http.StatusNoContent)
442
443}
444
// versionQuery parses URL parameter v or it returns ok false on error.
// An absent parameter yields version 0. On error the response is written
// already: 404 for a malformed number and 400 for multiple parameters.
func versionQuery(r *http.Request, w http.ResponseWriter) (version int64, ok bool) {
	params := r.URL.Query()["v"]
	if len(params) > 1 {
		http.Error(w, "multiple version query parameters", http.StatusBadRequest)
		return 0, false
	}
	if len(params) == 0 {
		return 0, true
	}

	n, err := strconv.ParseInt(params[0], 10, 64)
	if err != nil {
		http.Error(w, "query parameter v: malformed version number", http.StatusNotFound)
		return 0, false
	}
	return n, true
}
462
// versionMatch parses the If-Match header or it returns ok false on error.
// An absent or blank header and the wildcard "*" yield version 0, meaning
// no precondition. Otherwise the header must hold a single entity tag as
// served by this package: a quoted decimal version number. On error the
// response is written already: 400 for an unquoted value and 412 for a
// quoted value that is not a version number.
func versionMatch(r *http.Request, w http.ResponseWriter) (version int64, ok bool) {
	// BUG(pascaldekloe): No support for multiple entity tags in If-Match header.

	const linearWhiteSpace = " \t"
	tag := strings.Trim(strings.Join(r.Header["If-Match"], ", "), linearWhiteSpace)
	if tag == "" || tag == "*" {
		return 0, true
	}

	// A lone `"` is both the first and the last byte; require two quotes.
	if len(tag) < 2 || tag[0] != '"' || tag[len(tag)-1] != '"' {
		http.Error(w, fmt.Sprintf("need opaque tags in If-Match header %q", tag), http.StatusBadRequest)
		return 0, false
	}

	i, err := strconv.ParseInt(tag[1:len(tag)-1], 10, 64)
	if err != nil {
		http.Error(w, fmt.Sprintf("malformed or unknown tag in If-Match header %q", tag), http.StatusPreconditionFailed)
		return 0, false
	}
	return i, true
}
487