1package basic
2
3import (
4	"bytes"
5	"context"
6	"crypto/sha256"
7	"crypto/subtle"
8	"encoding/base64"
9	"fmt"
10	"net/http"
11	"strings"
12
13	"github.com/go-kit/kit/endpoint"
14	httptransport "github.com/go-kit/kit/transport/http"
15)
16
// AuthError represents an authorization error. AuthMiddleware returns it when
// the Authorization value is missing, is not valid Basic credentials, or does
// not match the configured user and password.
//
// It implements the StatusCoder and Headerer interfaces of go-kit/http, so the
// error carries its own status code and response headers, including a Basic
// challenge for Realm.
type AuthError struct {
	// Realm is the protection space advertised in the WWW-Authenticate challenge.
	Realm string
}

// StatusCode is an implementation of the StatusCoder interface in go-kit/http.
// It always reports 401 Unauthorized.
func (AuthError) StatusCode() int {
	return http.StatusUnauthorized
}

// Error is an implementation of the error interface. It returns the standard
// status text for 401 ("Unauthorized").
func (AuthError) Error() string {
	return http.StatusText(http.StatusUnauthorized)
}

// Headers is an implementation of the Headerer interface in go-kit/http.
// It returns the headers for the 401 reply: a plain-text content type,
// nosniff, and a Basic challenge naming e.Realm.
//
// Keys are added with Header.Set so they are stored in canonical MIME form;
// a literal "WWW-Authenticate" map key is non-canonical and would be invisible
// to Header.Get, Header.Values and Header.Del.
func (e AuthError) Headers() http.Header {
	h := make(http.Header, 3)
	h.Set("Content-Type", "text/plain; charset=utf-8")
	h.Set("X-Content-Type-Options", "nosniff")
	// %q wraps the realm in double quotes and backslash-escapes any embedded
	// quotes or backslashes.
	h.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, e.Realm))
	return h
}
40
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ([]byte("Aladdin"), []byte("open sesame"), true).
//
// The "Basic" scheme name is matched case-insensitively, because auth-scheme
// tokens are case-insensitive (RFC 7235); net/http's Request.BasicAuth does
// the same. The remainder must be standard base64 of "user:password". The
// username ends at the first colon, so the password may itself contain colons.
//
// ok is false if the scheme does not match, the payload is not valid base64,
// or the decoded payload contains no colon. On success, username and password
// are sub-slices of the same decoded buffer.
func parseBasicAuth(auth string) (username, password []byte, ok bool) {
	const prefix = "Basic "
	// Compare only the prefix bytes; EqualFold also accepts "basic ", "BASIC ", etc.
	if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
		return nil, nil, false
	}
	c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
	if err != nil {
		return nil, nil, false
	}

	s := bytes.IndexByte(c, ':')
	if s < 0 {
		return nil, nil, false
	}
	return c[:s], c[s+1:], true
}
59
// toHashSlice returns the SHA-256 digest of s as a slice of sha256.Size (32)
// bytes.
func toHashSlice(s []byte) []byte {
	digest := sha256.New()
	digest.Write(s) // hash.Hash.Write is documented never to return an error
	return digest.Sum(nil)
}
65
66// AuthMiddleware returns a Basic Authentication middleware for a particular user and password.
67func AuthMiddleware(requiredUser, requiredPassword, realm string) endpoint.Middleware {
68	requiredUserBytes := toHashSlice([]byte(requiredUser))
69	requiredPasswordBytes := toHashSlice([]byte(requiredPassword))
70
71	return func(next endpoint.Endpoint) endpoint.Endpoint {
72		return func(ctx context.Context, request interface{}) (interface{}, error) {
73			auth, ok := ctx.Value(httptransport.ContextKeyRequestAuthorization).(string)
74			if !ok {
75				return nil, AuthError{realm}
76			}
77
78			givenUser, givenPassword, ok := parseBasicAuth(auth)
79			if !ok {
80				return nil, AuthError{realm}
81			}
82
83			givenUserBytes := toHashSlice(givenUser)
84			givenPasswordBytes := toHashSlice(givenPassword)
85
86			if subtle.ConstantTimeCompare(givenUserBytes, requiredUserBytes) == 0 ||
87				subtle.ConstantTimeCompare(givenPasswordBytes, requiredPasswordBytes) == 0 {
88				return nil, AuthError{realm}
89			}
90
91			return next(ctx, request)
92		}
93	}
94}
95