package adal

// Copyright 2017 Microsoft Corporation
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.

/*
  This file is largely based on rjw57/oauth2device's code, with the following differences:
   * scope -> resource, and only allow a single one
   * receive "Message" in the DeviceCode struct and show it to users as the prompt
   * azure-xplat-cli has the following behavior that this emulates:
     - does not send client_secret during the token exchange
     - sends resource again in the token exchange request
*/

import (
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"
)

const (
	// logPrefix is prepended to every error message produced in this file.
	logPrefix = "autorest/adal/devicetoken:"
)

var (
	// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
	ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)

	// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
	ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)

	// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
	ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)

	// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
	ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)

	// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
	ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)

	// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
	ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)

	// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
	ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)

	// The strings below are message fragments; the functions in this file
	// combine them with logPrefix and a cause when building non-sentinel errors.
	errCodeSendingFails   = "Error occurred while sending request for Device Authorization Code"
	errCodeHandlingFails  = "Error occurred while handling response from the Device Endpoint"
	errTokenSendingFails  = "Error occurred while sending request with device code for a token"
	errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
	errStatusNotOK        = "Error HTTP status != 200"
)

// DeviceCode is the object returned by the device auth endpoint.
// It contains information to instruct the user to complete the auth flow.
type DeviceCode struct {
	DeviceCode      *string `json:"device_code,omitempty"`
	UserCode        *string `json:"user_code,omitempty"`
	VerificationURL *string `json:"verification_url,omitempty"`
	// ExpiresIn is decoded from a JSON string value.
	// NOTE(review): presumably a lifetime in seconds; confirm against the endpoint's contract.
	ExpiresIn *int64 `json:"expires_in,string,omitempty"`
	// Interval is decoded from a JSON string value; WaitForUserCompletionWithContext
	// interprets it as the number of seconds to wait between polls.
	Interval *int64 `json:"interval,string,omitempty"`

	Message *string `json:"message"` // Azure specific

	// The following fields are not taken from the endpoint response: they are
	// stored when initiating the flow and used again when exchanging the device
	// code for a token.
	Resource    string
	OAuthConfig OAuthConfig
	ClientID    string
}

// TokenError is the object returned by the token exchange endpoint
// when something is amiss.
type TokenError struct {
	Error            *string `json:"error,omitempty"`
	ErrorCodes       []int   `json:"error_codes,omitempty"`
	ErrorDescription *string `json:"error_description,omitempty"`
	Timestamp        *string `json:"timestamp,omitempty"`
	TraceID          *string `json:"trace_id,omitempty"`
}

// deviceToken is the object returned by the token exchange endpoint.
// It can either look like a Token or a TokenError, so both are embedded here;
// check for presence of "Error" to know if we are in an error state.
type deviceToken struct {
	Token
	TokenError
}

// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
// It delegates to InitiateDeviceAuthWithContext using context.Background().
//
// Deprecated: use InitiateDeviceAuthWithContext() instead.
func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
	return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
}

// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
112func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) { 113 v := url.Values{ 114 "client_id": []string{clientID}, 115 "resource": []string{resource}, 116 } 117 118 s := v.Encode() 119 body := ioutil.NopCloser(strings.NewReader(s)) 120 121 req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body) 122 if err != nil { 123 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error()) 124 } 125 126 req.ContentLength = int64(len(s)) 127 req.Header.Set(contentType, mimeTypeFormPost) 128 resp, err := sender.Do(req.WithContext(ctx)) 129 if err != nil { 130 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error()) 131 } 132 defer resp.Body.Close() 133 134 rb, err := ioutil.ReadAll(resp.Body) 135 if err != nil { 136 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error()) 137 } 138 139 if resp.StatusCode != http.StatusOK { 140 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK) 141 } 142 143 if len(strings.Trim(string(rb), " ")) == 0 { 144 return nil, ErrDeviceCodeEmpty 145 } 146 147 var code DeviceCode 148 err = json.Unmarshal(rb, &code) 149 if err != nil { 150 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error()) 151 } 152 153 code.ClientID = clientID 154 code.Resource = resource 155 code.OAuthConfig = oauthConfig 156 157 return &code, nil 158} 159 160// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint 161// to see if the device flow has: been completed, timed out, or otherwise failed 162// Deprecated: use CheckForUserCompletionWithContext() instead. 
163func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { 164 return CheckForUserCompletionWithContext(context.Background(), sender, code) 165} 166 167// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint 168// to see if the device flow has: been completed, timed out, or otherwise failed 169func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) { 170 v := url.Values{ 171 "client_id": []string{code.ClientID}, 172 "code": []string{*code.DeviceCode}, 173 "grant_type": []string{OAuthGrantTypeDeviceCode}, 174 "resource": []string{code.Resource}, 175 } 176 177 s := v.Encode() 178 body := ioutil.NopCloser(strings.NewReader(s)) 179 180 req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body) 181 if err != nil { 182 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error()) 183 } 184 185 req.ContentLength = int64(len(s)) 186 req.Header.Set(contentType, mimeTypeFormPost) 187 resp, err := sender.Do(req.WithContext(ctx)) 188 if err != nil { 189 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error()) 190 } 191 defer resp.Body.Close() 192 193 rb, err := ioutil.ReadAll(resp.Body) 194 if err != nil { 195 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error()) 196 } 197 198 if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 { 199 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK) 200 } 201 if len(strings.Trim(string(rb), " ")) == 0 { 202 return nil, ErrOAuthTokenEmpty 203 } 204 205 var token deviceToken 206 err = json.Unmarshal(rb, &token) 207 if err != nil { 208 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error()) 209 } 210 211 if token.Error == nil { 212 return &token.Token, nil 213 } 214 215 switch *token.Error { 216 case "authorization_pending": 217 return 
nil, ErrDeviceAuthorizationPending 218 case "slow_down": 219 return nil, ErrDeviceSlowDown 220 case "access_denied": 221 return nil, ErrDeviceAccessDenied 222 case "code_expired": 223 return nil, ErrDeviceCodeExpired 224 default: 225 return nil, ErrDeviceGeneric 226 } 227} 228 229// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs. 230// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'. 231// Deprecated: use WaitForUserCompletionWithContext() instead. 232func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { 233 return WaitForUserCompletionWithContext(context.Background(), sender, code) 234} 235 236// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error 237// state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'. 238func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) { 239 intervalDuration := time.Duration(*code.Interval) * time.Second 240 waitDuration := intervalDuration 241 242 for { 243 token, err := CheckForUserCompletionWithContext(ctx, sender, code) 244 245 if err == nil { 246 return token, nil 247 } 248 249 switch err { 250 case ErrDeviceSlowDown: 251 waitDuration += waitDuration 252 case ErrDeviceAuthorizationPending: 253 // noop 254 default: // everything else is "fatal" to us 255 return nil, err 256 } 257 258 if waitDuration > (intervalDuration * 3) { 259 return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix) 260 } 261 262 select { 263 case <-time.After(waitDuration): 264 // noop 265 case <-ctx.Done(): 266 return nil, ctx.Err() 267 } 268 } 269} 270