1package rest 2 3import ( 4 "errors" 5 "fmt" 6 "log" 7 "net/http" 8 "path" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/pascaldekloe/goe/el" 15) 16 17var ( 18 // ErrNotFound signals that the resource is absent. 19 // See CRUDRepo's SetReadFunc, SetUpdateFunc and SetDeleteFunc for the details. 20 ErrNotFound = errors.New("no such entry") 21 22 // ErrOptimisticLock signals that the latest version does not match the request. 23 // See CRUDRepo's SetUpdateFunc and SetDeleteFunc for the details. 24 ErrOptimisticLock = errors.New("lost optimistic lock") 25) 26 27var ( 28 keyType = reflect.TypeOf(int64(0)) 29 errorType = reflect.TypeOf((*error)(nil)).Elem() 30) 31 32// CRUDRepo is a REST repository. 33type CRUDRepo struct { 34 // mountLocation is the root path of this repository. 35 mountLoc string 36 37 // versionPath is the GoEL expression to the data's version int64. 38 versionPath string 39 40 create, read, update, delete *reflect.Value 41 42 dataType reflect.Type 43} 44 45// NewCRUD returns a new REST repository for the CRUD operations. 46// The mountLocation specifies the root for CRUDRepo.ServeHTTP. 47// The versionPath is a GoEL expression to the date version (in the data type). 48// 49// It's operation is based on two assumptions. 50// 1) Identifiers are int64. 51// 2) Versions are int64 unix timestamps in nanoseconds. 52func NewCRUD(mountLocation, versionPath string) *CRUDRepo { 53 return &CRUDRepo{ 54 mountLoc: path.Clean(mountLocation), 55 versionPath: versionPath, 56 } 57} 58 59// SetCreateFunc enables create support. 60// The method panics on any of the following conditions. 61// 1) f does not match signature func(data T) (id int64, err error) 62// 2) Data type T does not match the other CRUD operations. 63// 3) Data type T is not a pointer. 64// 65// It is the responsibility of f to set the version. 66func (repo *CRUDRepo) SetCreateFunc(f interface{}) { 67 v := reflect.ValueOf(f) 68 repo.create = &v 69 70 t := v.Type() 71 if t.Kind() != reflect.Func || t.NumIn() != 1 || t.NumOut() != 2 || t.Out(0) != keyType || !t.Out(1).Implements(errorType) { 72 log.Panic("create is not a func(data T, id int64) error") 73 } 74 repo.setDataType(t.In(0)) 75} 76 77// SetReadFunc enables read support. 78// The method panics on any of the following conditions. 79// 1) f does not match signature func(id, version int64) (hit T, err error) 80// 2) Data type T does not match the other CRUD operations. 81// 3) Data type T is not a pointer. 82// 83// When the id is not found f must return ErrNotFound. 84// The version must be honored and the latest version should be served as a fallback. 85func (repo *CRUDRepo) SetReadFunc(f interface{}) { 86 v := reflect.ValueOf(f) 87 repo.read = &v 88 89 t := v.Type() 90 if t.Kind() != reflect.Func || t.NumIn() != 2 || t.In(0) != keyType || t.In(1) != keyType || t.NumOut() != 2 || !t.Out(1).Implements(errorType) { 91 log.Panic("read is not a func(id, version int64) (T, error)") 92 } 93 repo.setDataType(t.Out(0)) 94} 95 96// SetUpdateFunc enables update support. 97// The method panics on any of the following conditions. 98// 1) f does not match signature func(id int64, data T) (error) 99// 2) Data type T does not match the other CRUD operations. 100// 3) Data type T is not a pointer. 101// 102// When the id is not found f must return ErrNotFound. 103// When the data's version is not equal to 0 and version does not match the latest 104// one available then f must skip normal operation and return ErrOptimisticLock. 105// It is the responsibility of f to set the new version. 106func (repo *CRUDRepo) SetUpdateFunc(f interface{}) { 107 v := reflect.ValueOf(f) 108 repo.update = &v 109 110 t := v.Type() 111 if t.Kind() != reflect.Func || t.NumIn() != 2 || t.In(0) != keyType || t.NumOut() != 1 || !t.Out(0).Implements(errorType) { 112 log.Panic("update is not a func(id int64, data T) error") 113 } 114 repo.setDataType(t.In(1)) 115 116} 117 118// SetUpdateFunc enables update support. 119// The method panics when f does not match signature func(id, version int64) error. 120// 121// When the id is not found f must return ErrNotFound. 122// When the version is not equal to 0 and version does not match the latest 123// one available then f must skip normal operation and return ErrOptimisticLock. 124func (repo *CRUDRepo) SetDeleteFunc(f interface{}) { 125 v := reflect.ValueOf(f) 126 repo.delete = &v 127 128 t := v.Type() 129 if t.Kind() != reflect.Func || t.NumIn() != 2 || t.In(0) != keyType || t.In(1) != keyType && t.NumOut() != 1 || !t.Out(0).Implements(errorType) { 130 log.Panic("delete is not a func(id, version int64) error") 131 } 132} 133 134func (repo *CRUDRepo) setDataType(t reflect.Type) { 135 if t.Kind() != reflect.Ptr { 136 log.Panicf("goe rest: CRUD operation's data type %s must be a pointer", t) 137 } 138 139 switch repo.dataType { 140 case nil: 141 repo.dataType = t 142 143 if n := el.Assign(reflect.New(t).Interface(), repo.versionPath, 99); n != 1 { 144 log.Panicf("goe rest: version path %q matches %d element on type %s", repo.versionPath, n, t) 145 } 146 case t: 147 // do nothing 148 default: 149 log.Panicf("goe rest: CRUD operation's data type %s does not match %s", t, repo.dataType) 150 } 151} 152 153// ServeHTTP honors the http.Handler interface for the mount point provided with NewCRUD. 154// For now only JSON is supported. 155func (repo *CRUDRepo) ServeHTTP(w http.ResponseWriter, r *http.Request) { 156 p := path.Clean(r.URL.Path) 157 if !strings.HasPrefix(p, repo.mountLoc) { 158 log.Printf("goe rest: basepath %q mismatch with %q from %q", repo.mountLoc, p, r.URL.String()) 159 http.Error(w, "path mismatch", http.StatusNotFound) 160 return 161 } 162 163 if len(p) == len(repo.mountLoc) { 164 switch r.Method { 165 default: 166 if repo.create != nil { 167 w.Header().Set("Allow", "POST") 168 w.WriteHeader(http.StatusMethodNotAllowed) 169 } else { 170 http.Error(w, "", http.StatusNotFound) 171 } 172 case "POST": 173 if repo.create != nil { 174 repo.serveCreate(w, r) 175 } else { 176 http.Error(w, "", http.StatusNotFound) 177 } 178 } 179 return 180 } 181 182 if i := len(repo.mountLoc); p[i] == '/' { 183 p = p[i+1:] 184 } else { 185 p = p[i:] 186 } 187 if i := strings.IndexByte(p, '/'); i >= 0 { 188 http.Error(w, fmt.Sprintf("goe rest: no such subdirectory: %q", p), http.StatusNotFound) 189 return 190 } 191 192 id, err := strconv.ParseInt(p, 10, 64) 193 if err != nil { 194 http.Error(w, fmt.Sprintf("goe rest: malformed ID: %s", err), http.StatusNotFound) 195 return 196 } 197 198 switch r.Method { 199 case "GET", "HEAD": 200 if repo.read != nil { 201 repo.serveRead(w, r, id) 202 return 203 } 204 case "PUT": 205 if repo.update != nil { 206 repo.serveUpdate(w, r, id) 207 return 208 } 209 case "DELETE": 210 if repo.delete != nil { 211 repo.serveDelete(w, r, id) 212 return 213 } 214 case "OPTIONS": 215 w.Header().Set("Allow", repo.resourceMethods()) 216 w.WriteHeader(http.StatusOK) 217 return 218 } 219 w.Header().Set("Allow", repo.resourceMethods()) 220 w.WriteHeader(http.StatusMethodNotAllowed) 221} 222 223// resourceMethods lists the HTTP methods served on r's resources. 224func (r *CRUDRepo) resourceMethods() string { 225 a := make([]string, 1, 5) 226 a[0] = "OPTIONS" 227 if r.read != nil { 228 a = append(a, "GET", "HEAD") 229 } 230 if r.update != nil { 231 a = append(a, "PUT") 232 } 233 if r.delete != nil { 234 a = append(a, "DELETE") 235 } 236 return strings.Join(a, ", ") 237} 238 239func (repo *CRUDRepo) serveCreate(w http.ResponseWriter, r *http.Request) { 240 v := reflect.New(repo.dataType) 241 if !ReceiveJSON(v.Interface(), r, w) { 242 return 243 } 244 245 result := repo.create.Call([]reflect.Value{v.Elem()}) 246 if !result[1].IsNil() { 247 err := result[1].Interface().(error) 248 log.Print("goe/rest: create: ", err) 249 http.Error(w, err.Error(), http.StatusInternalServerError) 250 return 251 } 252 253 loc := *r.URL // copy 254 loc.Path = path.Join(loc.Path, strconv.FormatInt(result[0].Int(), 10)) 255 loc.RawQuery = "" 256 loc.Fragment = "" 257 258 h := w.Header() 259 h.Set("Location", loc.String()) 260 261 version, _ := el.Int(repo.versionPath, v.Interface()) 262 h.Set("ETag", fmt.Sprintf(`"%d"`, version)) 263 264 w.WriteHeader(http.StatusCreated) 265} 266 267func (repo *CRUDRepo) serveRead(w http.ResponseWriter, r *http.Request, id int64) { 268 versionReq, ok := versionQuery(r, w) 269 if !ok { 270 return 271 } 272 273 result := repo.read.Call([]reflect.Value{reflect.ValueOf(id), reflect.ValueOf(int64(versionReq))}) 274 if !result[1].IsNil() { 275 switch err := result[1].Interface().(error); err { 276 case ErrNotFound: 277 http.Error(w, fmt.Sprintf("ID %d not found", id), http.StatusNotFound) 278 default: 279 log.Print("goe/rest: read: ", err) 280 http.Error(w, err.Error(), http.StatusInternalServerError) 281 } 282 return 283 } 284 285 version, _ := el.Int(repo.versionPath, result[0].Interface()) 286 if versionReq != 0 && version != versionReq { 287 http.Error(w, fmt.Sprintf("version %d not found (latest is %d)", versionReq, version), http.StatusNotFound) 288 return 289 } 290 291 h := w.Header() 292 293 etag := fmt.Sprintf(`"%d"`, version) 294 h.Set("ETag", etag) 295 296 loc := *r.URL // copy 297 loc.RawQuery = fmt.Sprintf("v=%d", version) 298 loc.Fragment = "" 299 h.Set("Content-Location", loc.String()) 300 301 h.Set("Allow", repo.resourceMethods()) 302 303 // BUG(pascaldekloe): No support for multiple entity tags in If-None-Match header. 304 for _, s := range r.Header["If-None-Match"] { 305 if s == etag { 306 w.WriteHeader(http.StatusNotModified) 307 return 308 } 309 } 310 311 for _, s := range r.Header["If-Modified-Since"] { 312 t, err := time.Parse(time.RFC1123, s) 313 if err != nil { 314 http.Error(w, fmt.Sprintf("If-Unmodified-Since header %q not RFC1123 compliant: %s", s, err), http.StatusBadRequest) 315 return 316 } 317 // Round down to RFC 1123 resolution: 318 resolution := int64(time.Second) 319 if t.After(time.Unix(0, (version/resolution)*resolution)) { 320 w.WriteHeader(http.StatusNotModified) 321 return 322 } 323 } 324 325 timestamp := time.Unix(0, version) 326 h.Set("Last-Modified", timestamp.In(time.UTC).Format(time.RFC1123)) 327 328 if r.Method != "HEAD" { 329 ServeJSON(w, http.StatusOK, result[0].Interface()) 330 } 331} 332 333func (repo *CRUDRepo) serveUpdate(w http.ResponseWriter, r *http.Request, id int64) { 334 v := reflect.New(repo.dataType) 335 if !ReceiveJSON(v.Interface(), r, w) { 336 return 337 } 338 339 queryVersion, ok := versionQuery(r, w) 340 if !ok { 341 return 342 } 343 344 matchVersion, ok := versionMatch(r, w) 345 if !ok { 346 return 347 } 348 349 var version int64 350 switch { 351 case queryVersion == 0: 352 version = matchVersion 353 case matchVersion == 0: 354 version = queryVersion 355 case queryVersion == matchVersion: 356 version = matchVersion 357 default: 358 http.Error(w, fmt.Sprintf("query parameter v %d does not match If-Match header %d", queryVersion, matchVersion), http.StatusPreconditionFailed) 359 return 360 } 361 if version != 0 { 362 el.Assign(v.Interface(), repo.versionPath, version) 363 } 364 365 result := repo.update.Call([]reflect.Value{reflect.ValueOf(id), v.Elem()}) 366 if !result[0].IsNil() { 367 switch err := result[0].Interface().(error); err { 368 case ErrNotFound: 369 http.Error(w, "", http.StatusNotFound) 370 case ErrOptimisticLock: 371 if matchVersion != 0 { 372 http.Error(w, err.Error(), http.StatusPreconditionFailed) 373 return 374 } 375 http.Error(w, "not the latest version", http.StatusMethodNotAllowed) 376 default: 377 log.Printf("goe rest: update %d v%d: %s", id, version, err) 378 http.Error(w, err.Error(), http.StatusInternalServerError) 379 } 380 return 381 } 382 383 h := w.Header() 384 385 version, _ = el.Int(repo.versionPath, v.Interface()) 386 h.Set("ETag", fmt.Sprintf(`"%d"`, version)) 387 h.Set("Last-Modified", time.Unix(0, version).In(time.UTC).Format(time.RFC1123)) 388 389 h.Set("Allow", repo.resourceMethods()) 390 391 loc := *r.URL // copy 392 loc.RawQuery = fmt.Sprintf("v=%d", version) 393 loc.Fragment = "" 394 h.Set("Content-Location", loc.String()) 395 396 ServeJSON(w, http.StatusOK, v.Interface()) 397} 398 399func (repo *CRUDRepo) serveDelete(w http.ResponseWriter, r *http.Request, id int64) { 400 queryVersion, ok := versionQuery(r, w) 401 if !ok { 402 return 403 } 404 405 matchVersion, ok := versionMatch(r, w) 406 if !ok { 407 return 408 } 409 410 var version int64 411 switch { 412 case queryVersion == 0: 413 version = matchVersion 414 case matchVersion == 0: 415 version = queryVersion 416 case queryVersion != matchVersion: 417 http.Error(w, fmt.Sprintf("query parameter v %d does not match If-Match header %d", queryVersion, matchVersion), http.StatusPreconditionFailed) 418 return 419 default: 420 version = matchVersion 421 } 422 423 result := repo.delete.Call([]reflect.Value{reflect.ValueOf(id), reflect.ValueOf(version)}) 424 if !result[0].IsNil() { 425 switch err := result[0].Interface().(error); err { 426 case ErrNotFound: 427 http.Error(w, "", http.StatusNotFound) 428 case ErrOptimisticLock: 429 if matchVersion != 0 { 430 http.Error(w, err.Error(), http.StatusPreconditionFailed) 431 return 432 } 433 http.Error(w, "not the latest version", http.StatusMethodNotAllowed) 434 default: 435 log.Printf("goe rest: delete %d v%d: %s", id, version, err) 436 http.Error(w, err.Error(), http.StatusInternalServerError) 437 } 438 return 439 } 440 441 w.WriteHeader(http.StatusNoContent) 442 443} 444 445// versionQuery parses URL parameter v or it returns ok false on error. 446func versionQuery(r *http.Request, w http.ResponseWriter) (version int64, ok bool) { 447 switch params := r.URL.Query()["v"]; len(params) { 448 case 0: 449 return 0, true 450 case 1: 451 i, err := strconv.ParseInt(params[0], 10, 64) 452 if err != nil { 453 http.Error(w, "query parameter v: malformed version number", http.StatusNotFound) 454 return 0, false 455 } 456 return i, true 457 default: 458 http.Error(w, "multiple version query parameters", http.StatusBadRequest) 459 return 0, false 460 } 461} 462 463// versionMatch parses the If-Match header or it returns ok false on error. 464func versionMatch(r *http.Request, w http.ResponseWriter) (version int64, ok bool) { 465 // BUG(pascaldekloe): No support for multiple entity tags in If-Match header. 466 467 tags := strings.Join(r.Header["If-Match"], ", ") 468 if tags == "" || tags == "*" { 469 return 0, true 470 } 471 472 const linearWhiteSpace = " \t" 473 tag := strings.Trim(tags, linearWhiteSpace) 474 if tag[0] != '"' || tag[len(tag)-1] != '"' { 475 http.Error(w, fmt.Sprintf("need opaque tags in If-Match header %q", tag), http.StatusBadRequest) 476 return 0, false 477 } 478 s := tag[1 : len(tag)-1] 479 480 i, err := strconv.ParseInt(s, 10, 64) 481 if err != nil { 482 http.Error(w, fmt.Sprintf("malformed or unknow tag in If-Match header %q", tag), http.StatusPreconditionFailed) 483 return 0, false 484 } 485 return i, true 486} 487