1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17/*
18  This file is largely based on rjw57/oauth2device's code, with the follow differences:
19   * scope -> resource, and only allow a single one
20   * receive "Message" in the DeviceCode struct and show it to users as the prompt
21   * azure-xplat-cli has the following behavior that this emulates:
22     - does not send client_secret during the token exchange
23     - sends resource again in the token exchange request
24*/
25
26import (
27	"context"
28	"encoding/json"
29	"fmt"
30	"io/ioutil"
31	"net/http"
32	"net/url"
33	"strings"
34	"time"
35)
36
// logPrefix is prepended to every error message produced by the device-flow code.
const logPrefix = "autorest/adal/devicetoken:"

// newDeviceFlowError builds an error whose message reads
// "<logPrefix> Error while retrieving <what>: <reason>".
func newDeviceFlowError(what, reason string) error {
	return fmt.Errorf("%s Error while retrieving %s: %s", logPrefix, what, reason)
}

// Sentinel errors of the device flow. Each value is created exactly once so callers can
// compare returned errors against them by identity.
var (
	// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
	ErrDeviceGeneric = newDeviceFlowError("OAuth token", "Unknown Error")

	// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
	ErrDeviceAccessDenied = newDeviceFlowError("OAuth token", "Access Denied")

	// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
	ErrDeviceAuthorizationPending = newDeviceFlowError("OAuth token", "Authorization Pending")

	// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
	ErrDeviceCodeExpired = newDeviceFlowError("OAuth token", "Code Expired")

	// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
	ErrDeviceSlowDown = newDeviceFlowError("OAuth token", "Slow Down")

	// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
	ErrDeviceCodeEmpty = newDeviceFlowError("device code", "Device Code Empty")

	// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
	ErrOAuthTokenEmpty = newDeviceFlowError("OAuth token", "Token Empty")
)

// Context strings combined with logPrefix when wrapping lower-level failures.
var (
	errCodeSendingFails   = "Error occurred while sending request for Device Authorization Code"
	errCodeHandlingFails  = "Error occurred while handling response from the Device Endpoint"
	errTokenSendingFails  = "Error occurred while sending request with device code for a token"
	errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
	errStatusNotOK        = "Error HTTP status != 200"
)
69
// DeviceCode is the object returned by the device auth endpoint.
// It contains information to instruct the user to complete the auth flow.
//
// The json-tagged fields are decoded from the device endpoint's response body.
// Message is presented to the user as the prompt.
// Resource, OAuthConfig and ClientID carry the request parameters forward so that
// the later token exchange can reuse them.
type DeviceCode struct {
	DeviceCode      *string `json:"device_code,omitempty"`
	UserCode        *string `json:"user_code,omitempty"`
	VerificationURL *string `json:"verification_url,omitempty"`
	// NOTE(review): the ",string" tag option makes encoding/json expect these two values
	// as quoted JSON strings (e.g. "900"); an endpoint that returns bare numbers would make
	// decoding fail. Interval is interpreted as seconds by the polling loop.
	ExpiresIn       *int64  `json:"expires_in,string,omitempty"`
	Interval        *int64  `json:"interval,string,omitempty"`

	Message     *string `json:"message"` // Azure specific
	// The fields below have no json tags. They are set by InitiateDeviceAuthWithContext
	// when the flow is initiated (overwriting anything decoded into them) and are used
	// when exchanging the device code for a token.
	Resource    string  // store the following, stored when initiating, used when exchanging
	OAuthConfig OAuthConfig
	ClientID    string
}
84
// TokenError is the object returned by the token exchange endpoint
// when something is amiss.
//
// Error holds the OAuth error code (for example "authorization_pending" or
// "slow_down") that the device flow inspects to decide whether to keep polling.
// The remaining fields are decoded for diagnostics only; nothing in this file reads them.
type TokenError struct {
	Error            *string `json:"error,omitempty"`
	ErrorCodes       []int   `json:"error_codes,omitempty"`
	ErrorDescription *string `json:"error_description,omitempty"`
	Timestamp        *string `json:"timestamp,omitempty"`
	TraceID          *string `json:"trace_id,omitempty"`
}
94
// deviceToken is the object returned by the token exchange endpoint.
// The response can look like either a Token or a TokenError, so both are embedded here.
// A single json.Unmarshal fills whichever fields are present. A non-nil Error then
// indicates that the endpoint reported an error state instead of granting a token.
type deviceToken struct {
	Token
	TokenError
}
102
103// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
104// that can be used with CheckForUserCompletion or WaitForUserCompletion.
105// Deprecated: use InitiateDeviceAuthWithContext() instead.
106func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
107	return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
108}
109
110// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
111// that can be used with CheckForUserCompletion or WaitForUserCompletion.
112func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
113	v := url.Values{
114		"client_id": []string{clientID},
115		"resource":  []string{resource},
116	}
117
118	s := v.Encode()
119	body := ioutil.NopCloser(strings.NewReader(s))
120
121	req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
122	if err != nil {
123		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
124	}
125
126	req.ContentLength = int64(len(s))
127	req.Header.Set(contentType, mimeTypeFormPost)
128	resp, err := sender.Do(req.WithContext(ctx))
129	if err != nil {
130		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
131	}
132	defer resp.Body.Close()
133
134	rb, err := ioutil.ReadAll(resp.Body)
135	if err != nil {
136		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
137	}
138
139	if resp.StatusCode != http.StatusOK {
140		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
141	}
142
143	if len(strings.Trim(string(rb), " ")) == 0 {
144		return nil, ErrDeviceCodeEmpty
145	}
146
147	var code DeviceCode
148	err = json.Unmarshal(rb, &code)
149	if err != nil {
150		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
151	}
152
153	code.ClientID = clientID
154	code.Resource = resource
155	code.OAuthConfig = oauthConfig
156
157	return &code, nil
158}
159
160// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
161// to see if the device flow has: been completed, timed out, or otherwise failed
162// Deprecated: use CheckForUserCompletionWithContext() instead.
163func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
164	return CheckForUserCompletionWithContext(context.Background(), sender, code)
165}
166
167// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
168// to see if the device flow has: been completed, timed out, or otherwise failed
169func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
170	v := url.Values{
171		"client_id":  []string{code.ClientID},
172		"code":       []string{*code.DeviceCode},
173		"grant_type": []string{OAuthGrantTypeDeviceCode},
174		"resource":   []string{code.Resource},
175	}
176
177	s := v.Encode()
178	body := ioutil.NopCloser(strings.NewReader(s))
179
180	req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
181	if err != nil {
182		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
183	}
184
185	req.ContentLength = int64(len(s))
186	req.Header.Set(contentType, mimeTypeFormPost)
187	resp, err := sender.Do(req.WithContext(ctx))
188	if err != nil {
189		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
190	}
191	defer resp.Body.Close()
192
193	rb, err := ioutil.ReadAll(resp.Body)
194	if err != nil {
195		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
196	}
197
198	if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
199		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
200	}
201	if len(strings.Trim(string(rb), " ")) == 0 {
202		return nil, ErrOAuthTokenEmpty
203	}
204
205	var token deviceToken
206	err = json.Unmarshal(rb, &token)
207	if err != nil {
208		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
209	}
210
211	if token.Error == nil {
212		return &token.Token, nil
213	}
214
215	switch *token.Error {
216	case "authorization_pending":
217		return nil, ErrDeviceAuthorizationPending
218	case "slow_down":
219		return nil, ErrDeviceSlowDown
220	case "access_denied":
221		return nil, ErrDeviceAccessDenied
222	case "code_expired":
223		return nil, ErrDeviceCodeExpired
224	default:
225		return nil, ErrDeviceGeneric
226	}
227}
228
229// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
230// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
231// Deprecated: use WaitForUserCompletionWithContext() instead.
232func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
233	return WaitForUserCompletionWithContext(context.Background(), sender, code)
234}
235
236// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
237// state occurs.  This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
238func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
239	intervalDuration := time.Duration(*code.Interval) * time.Second
240	waitDuration := intervalDuration
241
242	for {
243		token, err := CheckForUserCompletionWithContext(ctx, sender, code)
244
245		if err == nil {
246			return token, nil
247		}
248
249		switch err {
250		case ErrDeviceSlowDown:
251			waitDuration += waitDuration
252		case ErrDeviceAuthorizationPending:
253			// noop
254		default: // everything else is "fatal" to us
255			return nil, err
256		}
257
258		if waitDuration > (intervalDuration * 3) {
259			return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
260		}
261
262		select {
263		case <-time.After(waitDuration):
264			// noop
265		case <-ctx.Done():
266			return nil, ctx.Err()
267		}
268	}
269}
270