1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17/*
  This file is largely based on rjw57/oauth2device's code, with the following differences:
19   * scope -> resource, and only allow a single one
20   * receive "Message" in the DeviceCode struct and show it to users as the prompt
21   * azure-xplat-cli has the following behavior that this emulates:
22     - does not send client_secret during the token exchange
23     - sends resource again in the token exchange request
24*/
25
26import (
27	"encoding/json"
28	"fmt"
29	"io/ioutil"
30	"net/http"
31	"net/url"
32	"strings"
33	"time"
34)
35
const (
	logPrefix = "autorest/adal/devicetoken:"

	// Context fragments used when wrapping lower-level errors in this file.
	// They never change, so they are constants rather than package-level
	// variables that other code in the package could accidentally reassign.
	errCodeSendingFails   = "Error occurred while sending request for Device Authorization Code"
	errCodeHandlingFails  = "Error occurred while handling response from the Device Endpoint"
	errTokenSendingFails  = "Error occurred while sending request with device code for a token"
	errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
	errStatusNotOK        = "Error HTTP status != 200"
)

// Sentinel errors returned by the device flow functions. Callers can compare
// against these values directly (e.g. err == ErrDeviceAuthorizationPending).
var (
	// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow.
	ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)

	// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow.
	ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)

	// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow.
	ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)

	// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow.
	ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)

	// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow.
	ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)

	// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow.
	ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)

	// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow.
	ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)
)
68
// DeviceCode is the object returned by the device auth endpoint.
// It contains information to instruct the user to complete the auth flow.
//
// InitiateDeviceAuth fills the JSON-tagged fields by decoding the device
// endpoint's response body. ExpiresIn and Interval use the ",string" tag
// option, so encoding/json expects them as quoted numbers (e.g. "900") rather
// than bare JSON numbers. WaitForUserCompletion reads Interval as a number of
// seconds, and CheckForUserCompletion sends DeviceCode as the "code" form
// parameter of the token request.
//
// Message is the Azure-specific prompt intended to be shown to the user.
// Resource, OAuthConfig and ClientID are not part of the endpoint response:
// InitiateDeviceAuth stores the values it was called with, and
// CheckForUserCompletion reuses them for the token exchange.
type DeviceCode struct {
	DeviceCode      *string `json:"device_code,omitempty"`
	UserCode        *string `json:"user_code,omitempty"`
	VerificationURL *string `json:"verification_url,omitempty"`
	ExpiresIn       *int64  `json:"expires_in,string,omitempty"`
	Interval        *int64  `json:"interval,string,omitempty"`

	Message     *string `json:"message"` // Azure specific
	Resource    string  // set by InitiateDeviceAuth, reused when exchanging the code for a token
	OAuthConfig OAuthConfig
	ClientID    string
}
83
// TokenError is the object returned by the token exchange endpoint
// when something is amiss.
//
// CheckForUserCompletion inspects only Error, mapping known values
// ("authorization_pending", "slow_down", "access_denied", "code_expired") to
// the corresponding ErrDevice* sentinels and anything else to
// ErrDeviceGeneric. The remaining fields are decoded for completeness but are
// not read by the functions in this file.
type TokenError struct {
	Error            *string `json:"error,omitempty"`
	ErrorCodes       []int   `json:"error_codes,omitempty"`
	ErrorDescription *string `json:"error_description,omitempty"`
	Timestamp        *string `json:"timestamp,omitempty"`
	TraceID          *string `json:"trace_id,omitempty"`
}
93
// deviceToken is the object returned by the token exchange endpoint.
// A response can look either like a Token or like a TokenError, so both are
// embedded here and a single json.Unmarshal fills whichever fields are
// present; a non-nil Error indicates the response is an error.
type deviceToken struct {
	Token
	TokenError
}
101
102// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
103// that can be used with CheckForUserCompletion or WaitForUserCompletion.
104func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
105	v := url.Values{
106		"client_id": []string{clientID},
107		"resource":  []string{resource},
108	}
109
110	s := v.Encode()
111	body := ioutil.NopCloser(strings.NewReader(s))
112
113	req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
114	if err != nil {
115		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
116	}
117
118	req.ContentLength = int64(len(s))
119	req.Header.Set(contentType, mimeTypeFormPost)
120	resp, err := sender.Do(req)
121	if err != nil {
122		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
123	}
124	defer resp.Body.Close()
125
126	rb, err := ioutil.ReadAll(resp.Body)
127	if err != nil {
128		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
129	}
130
131	if resp.StatusCode != http.StatusOK {
132		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
133	}
134
135	if len(strings.Trim(string(rb), " ")) == 0 {
136		return nil, ErrDeviceCodeEmpty
137	}
138
139	var code DeviceCode
140	err = json.Unmarshal(rb, &code)
141	if err != nil {
142		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
143	}
144
145	code.ClientID = clientID
146	code.Resource = resource
147	code.OAuthConfig = oauthConfig
148
149	return &code, nil
150}
151
152// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
153// to see if the device flow has: been completed, timed out, or otherwise failed
154func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
155	v := url.Values{
156		"client_id":  []string{code.ClientID},
157		"code":       []string{*code.DeviceCode},
158		"grant_type": []string{OAuthGrantTypeDeviceCode},
159		"resource":   []string{code.Resource},
160	}
161
162	s := v.Encode()
163	body := ioutil.NopCloser(strings.NewReader(s))
164
165	req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
166	if err != nil {
167		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
168	}
169
170	req.ContentLength = int64(len(s))
171	req.Header.Set(contentType, mimeTypeFormPost)
172	resp, err := sender.Do(req)
173	if err != nil {
174		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
175	}
176	defer resp.Body.Close()
177
178	rb, err := ioutil.ReadAll(resp.Body)
179	if err != nil {
180		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
181	}
182
183	if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
184		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
185	}
186	if len(strings.Trim(string(rb), " ")) == 0 {
187		return nil, ErrOAuthTokenEmpty
188	}
189
190	var token deviceToken
191	err = json.Unmarshal(rb, &token)
192	if err != nil {
193		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
194	}
195
196	if token.Error == nil {
197		return &token.Token, nil
198	}
199
200	switch *token.Error {
201	case "authorization_pending":
202		return nil, ErrDeviceAuthorizationPending
203	case "slow_down":
204		return nil, ErrDeviceSlowDown
205	case "access_denied":
206		return nil, ErrDeviceAccessDenied
207	case "code_expired":
208		return nil, ErrDeviceCodeExpired
209	default:
210		return nil, ErrDeviceGeneric
211	}
212}
213
214// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
215// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
216func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
217	intervalDuration := time.Duration(*code.Interval) * time.Second
218	waitDuration := intervalDuration
219
220	for {
221		token, err := CheckForUserCompletion(sender, code)
222
223		if err == nil {
224			return token, nil
225		}
226
227		switch err {
228		case ErrDeviceSlowDown:
229			waitDuration += waitDuration
230		case ErrDeviceAuthorizationPending:
231			// noop
232		default: // everything else is "fatal" to us
233			return nil, err
234		}
235
236		if waitDuration > (intervalDuration * 3) {
237			return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
238		}
239
240		time.Sleep(waitDuration)
241	}
242}
243