1package auth 2 3import ( 4 "encoding/base64" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "os" 9 "reflect" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/fabiolb/fabio/config" 15 "github.com/fabiolb/fabio/uuid" 16) 17 18type responseWriter struct { 19 header http.Header 20 code int 21 written []byte 22} 23 24func (rw *responseWriter) Header() http.Header { 25 if rw.header == nil { 26 rw.header = map[string][]string{} 27 } 28 return rw.header 29} 30 31func (rw *responseWriter) Write(b []byte) (int, error) { 32 rw.written = append(rw.written, b...) 33 return len(rw.written), nil 34} 35 36func (rw *responseWriter) WriteHeader(statusCode int) { 37 rw.code = statusCode 38} 39 40func createBasicAuthFile(contents string) (string, error) { 41 dir, err := ioutil.TempDir("", "basicauth") 42 43 if err != nil { 44 return "", fmt.Errorf("could not create temp dir: %s", err) 45 } 46 47 filename := fmt.Sprintf("%s/%s", dir, uuid.NewUUID()) 48 49 err = ioutil.WriteFile(filename, []byte(contents), 0666) 50 51 if err != nil { 52 return "", fmt.Errorf("could not write password file: %s", err) 53 } 54 55 return filename, nil 56} 57 58func createBasicAuth(user string, password string) (AuthScheme, error) { 59 contents := fmt.Sprintf("%s:%s", user, password) 60 61 filename, err := createBasicAuthFile(contents) 62 63 a, err := newBasicAuth(config.BasicAuth{ 64 File: filename, 65 Realm: "testrealm", 66 }) 67 68 if err != nil { 69 return nil, fmt.Errorf("could not create basic auth: %s", err) 70 } 71 72 return a, nil 73} 74 75func TestNewBasicAuth(t *testing.T) { 76 77 t.Run("should create a basic auth scheme from the supplied config", func(t *testing.T) { 78 filename, err := createBasicAuthFile("foo:bar") 79 80 if err != nil { 81 t.Error(err) 82 } 83 84 _, err = newBasicAuth(config.BasicAuth{ 85 File: filename, 86 }) 87 88 if err != nil { 89 t.Error(err) 90 } 91 }) 92 93 t.Run("should log a warning when credentials are malformed", func(t *testing.T) { 94 filename, err := createBasicAuthFile("foosdlijdgohdgdbar") 95 96 if err != nil { 97 t.Error(err) 98 } 99 100 _, err = newBasicAuth(config.BasicAuth{ 101 File: filename, 102 }) 103 104 if err != nil { 105 t.Error(err) 106 } 107 }) 108} 109 110func TestBasic_Authorised(t *testing.T) { 111 basicAuth, err := createBasicAuth("foo", "bar") 112 creds := []byte("foo:bar") 113 114 if err != nil { 115 t.Fatal(err) 116 } 117 118 tests := []struct { 119 name string 120 req *http.Request 121 res http.ResponseWriter 122 out bool 123 }{ 124 { 125 "correct credentials should be authorized", 126 &http.Request{ 127 Header: http.Header{ 128 "Authorization": []string{fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(creds))}, 129 }, 130 }, 131 &responseWriter{}, 132 true, 133 }, 134 { 135 "incorrect credentials should not be authorized", 136 &http.Request{ 137 Header: http.Header{ 138 "Authorization": []string{fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("baz:blarg")))}, 139 }, 140 }, 141 &responseWriter{}, 142 false, 143 }, 144 { 145 "missing Authorization header should not be authorized", 146 &http.Request{ 147 Header: http.Header{}, 148 }, 149 &responseWriter{}, 150 false, 151 }, 152 { 153 "malformed Authorization header should not be authorized", 154 &http.Request{ 155 Header: http.Header{ 156 "Authorization": []string{"malformed"}, 157 }, 158 }, 159 &responseWriter{}, 160 false, 161 }, 162 } 163 164 for _, tt := range tests { 165 t.Run(tt.name, func(t *testing.T) { 166 if got, want := basicAuth.Authorized(tt.req, tt.res), tt.out; !reflect.DeepEqual(got, want) { 167 t.Errorf("got %v want %v", got, want) 168 } 169 }) 170 } 171} 172 173func TestBasic_Authorised_should_fail_without_htpasswd_file(t *testing.T) { 174 filename, err := createBasicAuthFile("foo:bar") 175 if err != nil { 176 t.Error(err) 177 } 178 179 a, err := newBasicAuth(config.BasicAuth{ 180 File: filename, 181 Refresh: time.Second, 182 }) 183 if err != nil { 184 t.Error(err) 185 } 186 187 creds := []byte("foo:bar") 188 r := &http.Request{ 189 Header: http.Header{ 190 "Authorization": []string{fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(creds))}, 191 }, 192 } 193 194 w := &responseWriter{} 195 196 t.Run("should authorize against supplied htpasswd file", func(t *testing.T) { 197 if got, want := a.Authorized(r, w), true; !reflect.DeepEqual(got, want) { 198 t.Errorf("got %v want %v", got, want) 199 } 200 }) 201 202 if err := os.Remove(filename); err != nil { 203 t.Fatalf("removing htpasswd file: %s", err) 204 } 205 206 time.Sleep(2 * time.Second) // ensure htpasswd file refresh happend 207 208 t.Run("should not authorize after removing htpasswd file", func(t *testing.T) { 209 if got, want := a.Authorized(r, w), false; !reflect.DeepEqual(got, want) { 210 t.Errorf("got %v want %v", got, want) 211 } 212 }) 213} 214 215func TestBasic_Authorized_should_set_www_realm_header(t *testing.T) { 216 basicAuth, err := createBasicAuth("foo", "bar") 217 218 if err != nil { 219 t.Fatal(err) 220 } 221 222 rw := &responseWriter{} 223 224 _ = basicAuth.Authorized(&http.Request{Header: http.Header{}}, rw) 225 226 got := rw.Header().Get("WWW-Authenticate") 227 want := `Basic realm="testrealm"` 228 229 if strings.Compare(got, want) != 0 { 230 t.Errorf("got '%s', want '%s'", got, want) 231 } 232} 233