1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "flag" 7 "math/rand" 8 "net/http" 9 "strconv" 10 "strings" 11 "time" 12 13 dcontext "github.com/docker/distribution/context" 14 "github.com/docker/distribution/registry/api/errcode" 15 "github.com/docker/distribution/registry/auth" 16 _ "github.com/docker/distribution/registry/auth/htpasswd" 17 "github.com/docker/libtrust" 18 "github.com/gorilla/mux" 19 "github.com/sirupsen/logrus" 20) 21 22var ( 23 enforceRepoClass bool 24) 25 26func main() { 27 var ( 28 issuer = &TokenIssuer{} 29 pkFile string 30 addr string 31 debug bool 32 err error 33 34 passwdFile string 35 realm string 36 37 cert string 38 certKey string 39 ) 40 41 flag.StringVar(&issuer.Issuer, "issuer", "distribution-token-server", "Issuer string for token") 42 flag.StringVar(&pkFile, "key", "", "Private key file") 43 flag.StringVar(&addr, "addr", "localhost:8080", "Address to listen on") 44 flag.BoolVar(&debug, "debug", false, "Debug mode") 45 46 flag.StringVar(&passwdFile, "passwd", ".htpasswd", "Passwd file") 47 flag.StringVar(&realm, "realm", "", "Authentication realm") 48 49 flag.StringVar(&cert, "tlscert", "", "Certificate file for TLS") 50 flag.StringVar(&certKey, "tlskey", "", "Certificate key for TLS") 51 52 flag.BoolVar(&enforceRepoClass, "enforce-class", false, "Enforce policy for single repository class") 53 54 flag.Parse() 55 56 if debug { 57 logrus.SetLevel(logrus.DebugLevel) 58 } 59 60 if pkFile == "" { 61 issuer.SigningKey, err = libtrust.GenerateECP256PrivateKey() 62 if err != nil { 63 logrus.Fatalf("Error generating private key: %v", err) 64 } 65 logrus.Debugf("Using newly generated key with id %s", issuer.SigningKey.KeyID()) 66 } else { 67 issuer.SigningKey, err = libtrust.LoadKeyFile(pkFile) 68 if err != nil { 69 logrus.Fatalf("Error loading key file %s: %v", pkFile, err) 70 } 71 logrus.Debugf("Loaded private key with id %s", issuer.SigningKey.KeyID()) 72 } 73 74 if realm == "" { 75 logrus.Fatalf("Must provide realm") 76 } 77 78 ac, err := auth.GetAccessController("htpasswd", map[string]interface{}{ 79 "realm": realm, 80 "path": passwdFile, 81 }) 82 if err != nil { 83 logrus.Fatalf("Error initializing access controller: %v", err) 84 } 85 86 // TODO: Make configurable 87 issuer.Expiration = 15 * time.Minute 88 89 ctx := dcontext.Background() 90 91 ts := &tokenServer{ 92 issuer: issuer, 93 accessController: ac, 94 refreshCache: map[string]refreshToken{}, 95 } 96 97 router := mux.NewRouter() 98 router.Path("/token/").Methods("GET").Handler(handlerWithContext(ctx, ts.getToken)) 99 router.Path("/token/").Methods("POST").Handler(handlerWithContext(ctx, ts.postToken)) 100 101 if cert == "" { 102 err = http.ListenAndServe(addr, router) 103 } else if certKey == "" { 104 logrus.Fatalf("Must provide certficate (-tlscert) and key (-tlskey)") 105 } else { 106 err = http.ListenAndServeTLS(addr, cert, certKey, router) 107 } 108 109 if err != nil { 110 logrus.Infof("Error serving: %v", err) 111 } 112 113} 114 115// handlerWithContext wraps the given context-aware handler by setting up the 116// request context from a base context. 117func handlerWithContext(ctx context.Context, handler func(context.Context, http.ResponseWriter, *http.Request)) http.Handler { 118 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 119 ctx := dcontext.WithRequest(ctx, r) 120 logger := dcontext.GetRequestLogger(ctx) 121 ctx = dcontext.WithLogger(ctx, logger) 122 123 handler(ctx, w, r) 124 }) 125} 126 127func handleError(ctx context.Context, err error, w http.ResponseWriter) { 128 ctx, w = dcontext.WithResponseWriter(ctx, w) 129 130 if serveErr := errcode.ServeJSON(w, err); serveErr != nil { 131 dcontext.GetResponseLogger(ctx).Errorf("error sending error response: %v", serveErr) 132 return 133 } 134 135 dcontext.GetResponseLogger(ctx).Info("application error") 136} 137 138var refreshCharacters = []rune("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 139 140const refreshTokenLength = 15 141 142func newRefreshToken() string { 143 s := make([]rune, refreshTokenLength) 144 for i := range s { 145 s[i] = refreshCharacters[rand.Intn(len(refreshCharacters))] 146 } 147 return string(s) 148} 149 150type refreshToken struct { 151 subject string 152 service string 153} 154 155type tokenServer struct { 156 issuer *TokenIssuer 157 accessController auth.AccessController 158 refreshCache map[string]refreshToken 159} 160 161type tokenResponse struct { 162 Token string `json:"access_token"` 163 RefreshToken string `json:"refresh_token,omitempty"` 164 ExpiresIn int `json:"expires_in,omitempty"` 165} 166 167var repositoryClassCache = map[string]string{} 168 169func filterAccessList(ctx context.Context, scope string, requestedAccessList []auth.Access) []auth.Access { 170 if !strings.HasSuffix(scope, "/") { 171 scope = scope + "/" 172 } 173 grantedAccessList := make([]auth.Access, 0, len(requestedAccessList)) 174 for _, access := range requestedAccessList { 175 if access.Type == "repository" { 176 if !strings.HasPrefix(access.Name, scope) { 177 dcontext.GetLogger(ctx).Debugf("Resource scope not allowed: %s", access.Name) 178 continue 179 } 180 if enforceRepoClass { 181 if class, ok := repositoryClassCache[access.Name]; ok { 182 if class != access.Class { 183 dcontext.GetLogger(ctx).Debugf("Different repository class: %q, previously %q", access.Class, class) 184 continue 185 } 186 } else if strings.EqualFold(access.Action, "push") { 187 repositoryClassCache[access.Name] = access.Class 188 } 189 } 190 } else if access.Type == "registry" { 191 if access.Name != "catalog" { 192 dcontext.GetLogger(ctx).Debugf("Unknown registry resource: %s", access.Name) 193 continue 194 } 195 // TODO: Limit some actions to "admin" users 196 } else { 197 dcontext.GetLogger(ctx).Debugf("Skipping unsupported resource type: %s", access.Type) 198 continue 199 } 200 grantedAccessList = append(grantedAccessList, access) 201 } 202 return grantedAccessList 203} 204 205type acctSubject struct{} 206 207func (acctSubject) String() string { return "acctSubject" } 208 209type requestedAccess struct{} 210 211func (requestedAccess) String() string { return "requestedAccess" } 212 213type grantedAccess struct{} 214 215func (grantedAccess) String() string { return "grantedAccess" } 216 217// getToken handles authenticating the request and authorizing access to the 218// requested scopes. 219func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { 220 dcontext.GetLogger(ctx).Info("getToken") 221 222 params := r.URL.Query() 223 service := params.Get("service") 224 scopeSpecifiers := params["scope"] 225 var offline bool 226 if offlineStr := params.Get("offline_token"); offlineStr != "" { 227 var err error 228 offline, err = strconv.ParseBool(offlineStr) 229 if err != nil { 230 handleError(ctx, ErrorBadTokenOption.WithDetail(err), w) 231 return 232 } 233 } 234 235 requestedAccessList := ResolveScopeSpecifiers(ctx, scopeSpecifiers) 236 237 authorizedCtx, err := ts.accessController.Authorized(ctx, requestedAccessList...) 238 if err != nil { 239 challenge, ok := err.(auth.Challenge) 240 if !ok { 241 handleError(ctx, err, w) 242 return 243 } 244 245 // Get response context. 246 ctx, w = dcontext.WithResponseWriter(ctx, w) 247 248 challenge.SetHeaders(r, w) 249 handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail(challenge.Error()), w) 250 251 dcontext.GetResponseLogger(ctx).Info("get token authentication challenge") 252 253 return 254 } 255 ctx = authorizedCtx 256 257 username := dcontext.GetStringValue(ctx, "auth.user.name") 258 259 ctx = context.WithValue(ctx, acctSubject{}, username) 260 ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, acctSubject{})) 261 262 dcontext.GetLogger(ctx).Info("authenticated client") 263 264 ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) 265 ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, requestedAccess{})) 266 267 grantedAccessList := filterAccessList(ctx, username, requestedAccessList) 268 ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) 269 ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, grantedAccess{})) 270 271 token, err := ts.issuer.CreateJWT(username, service, grantedAccessList) 272 if err != nil { 273 handleError(ctx, err, w) 274 return 275 } 276 277 dcontext.GetLogger(ctx).Info("authorized client") 278 279 response := tokenResponse{ 280 Token: token, 281 ExpiresIn: int(ts.issuer.Expiration.Seconds()), 282 } 283 284 if offline { 285 response.RefreshToken = newRefreshToken() 286 ts.refreshCache[response.RefreshToken] = refreshToken{ 287 subject: username, 288 service: service, 289 } 290 } 291 292 ctx, w = dcontext.WithResponseWriter(ctx, w) 293 294 w.Header().Set("Content-Type", "application/json") 295 json.NewEncoder(w).Encode(response) 296 297 dcontext.GetResponseLogger(ctx).Info("get token complete") 298} 299 300type postTokenResponse struct { 301 Token string `json:"access_token"` 302 Scope string `json:"scope,omitempty"` 303 ExpiresIn int `json:"expires_in,omitempty"` 304 IssuedAt string `json:"issued_at,omitempty"` 305 RefreshToken string `json:"refresh_token,omitempty"` 306} 307 308// postToken handles authenticating the request and authorizing access to the 309// requested scopes. 310func (ts *tokenServer) postToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { 311 grantType := r.PostFormValue("grant_type") 312 if grantType == "" { 313 handleError(ctx, ErrorMissingRequiredField.WithDetail("missing grant_type value"), w) 314 return 315 } 316 317 service := r.PostFormValue("service") 318 if service == "" { 319 handleError(ctx, ErrorMissingRequiredField.WithDetail("missing service value"), w) 320 return 321 } 322 323 clientID := r.PostFormValue("client_id") 324 if clientID == "" { 325 handleError(ctx, ErrorMissingRequiredField.WithDetail("missing client_id value"), w) 326 return 327 } 328 329 var offline bool 330 switch r.PostFormValue("access_type") { 331 case "", "online": 332 case "offline": 333 offline = true 334 default: 335 handleError(ctx, ErrorUnsupportedValue.WithDetail("unknown access_type value"), w) 336 return 337 } 338 339 requestedAccessList := ResolveScopeList(ctx, r.PostFormValue("scope")) 340 341 var subject string 342 var rToken string 343 switch grantType { 344 case "refresh_token": 345 rToken = r.PostFormValue("refresh_token") 346 if rToken == "" { 347 handleError(ctx, ErrorUnsupportedValue.WithDetail("missing refresh_token value"), w) 348 return 349 } 350 rt, ok := ts.refreshCache[rToken] 351 if !ok || rt.service != service { 352 handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail("invalid refresh token"), w) 353 return 354 } 355 subject = rt.subject 356 case "password": 357 ca, ok := ts.accessController.(auth.CredentialAuthenticator) 358 if !ok { 359 handleError(ctx, ErrorUnsupportedValue.WithDetail("password grant type not supported"), w) 360 return 361 } 362 subject = r.PostFormValue("username") 363 if subject == "" { 364 handleError(ctx, ErrorUnsupportedValue.WithDetail("missing username value"), w) 365 return 366 } 367 password := r.PostFormValue("password") 368 if password == "" { 369 handleError(ctx, ErrorUnsupportedValue.WithDetail("missing password value"), w) 370 return 371 } 372 if err := ca.AuthenticateUser(subject, password); err != nil { 373 handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail("invalid credentials"), w) 374 return 375 } 376 default: 377 handleError(ctx, ErrorUnsupportedValue.WithDetail("unknown grant_type value"), w) 378 return 379 } 380 381 ctx = context.WithValue(ctx, acctSubject{}, subject) 382 ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, acctSubject{})) 383 384 dcontext.GetLogger(ctx).Info("authenticated client") 385 386 ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) 387 ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, requestedAccess{})) 388 389 grantedAccessList := filterAccessList(ctx, subject, requestedAccessList) 390 ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) 391 ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, grantedAccess{})) 392 393 token, err := ts.issuer.CreateJWT(subject, service, grantedAccessList) 394 if err != nil { 395 handleError(ctx, err, w) 396 return 397 } 398 399 dcontext.GetLogger(ctx).Info("authorized client") 400 401 response := postTokenResponse{ 402 Token: token, 403 ExpiresIn: int(ts.issuer.Expiration.Seconds()), 404 IssuedAt: time.Now().UTC().Format(time.RFC3339), 405 Scope: ToScopeList(grantedAccessList), 406 } 407 408 if offline { 409 rToken = newRefreshToken() 410 ts.refreshCache[rToken] = refreshToken{ 411 subject: subject, 412 service: service, 413 } 414 } 415 416 if rToken != "" { 417 response.RefreshToken = rToken 418 } 419 420 ctx, w = dcontext.WithResponseWriter(ctx, w) 421 422 w.Header().Set("Content-Type", "application/json") 423 json.NewEncoder(w).Encode(response) 424 425 dcontext.GetResponseLogger(ctx).Info("post token complete") 426} 427