1package main
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"flag"
19	"fmt"
20	"log"
21	"strings"
22
23	"crypto/rsa"
24	"crypto/x509"
25	"io/ioutil"
26	"net/http"
27	"os/user"
28
29	"github.com/Azure/go-autorest/autorest/adal"
30)
31
// Authentication modes accepted by the -mode flag.
const (
	clientCertMode    = "cert"
	clientSecretMode  = "secret"
	deviceMode        = "device"
	refreshMode       = "refresh"
	msiDefaultMode    = "msiDefault"
	msiClientIDMode   = "msiClientID"
	msiResourceIDMode = "msiResourceID"
)

// activeDirectoryEndpoint is the Azure AD authority from which the OAuth
// configuration is built in main.
const activeDirectoryEndpoint = "https://login.microsoftonline.com/"

// option pairs a command-line option name with the value supplied for it, so a
// missing mandatory option can be reported by its flag name.
type option struct {
	name, value string
}

// Command-line settings; they are bound to flags and populated in init.
var (
	mode, resource string

	tenantID, applicationID, identityResourceID string

	applicationSecret, certificatePath string

	tokenCachePath string
)
62
63func checkMandatoryOptions(mode string, options ...option) {
64	for _, option := range options {
65		if strings.TrimSpace(option.value) == "" {
66			log.Fatalf("Authentication mode '%s' requires mandatory option '%s'.", mode, option.name)
67		}
68	}
69}
70
// defaultTokenCachePath returns the default token cache location,
// "<home>/.adal/accessToken.json", under the current user's home directory.
// If the current user cannot be determined, the error is logged and the
// process exits via log.Fatal.
func defaultTokenCachePath() string {
	current, lookupErr := user.Current()
	if lookupErr == nil {
		return strings.Join([]string{current.HomeDir, ".adal", "accessToken.json"}, "/")
	}
	log.Fatal(lookupErr)
	return "" // unreachable: log.Fatal exits the process
}
79
80func init() {
81	flag.StringVar(&mode, "mode", "device", "authentication mode (device, secret, cert, refresh)")
82	flag.StringVar(&resource, "resource", "", "resource for which the token is requested")
83	flag.StringVar(&tenantID, "tenantId", "", "tenant id")
84	flag.StringVar(&applicationID, "applicationId", "", "application id")
85	flag.StringVar(&applicationSecret, "secret", "", "application secret")
86	flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/PFC application certificate")
87	flag.StringVar(&tokenCachePath, "tokenCachePath", defaultTokenCachePath(), "location of oath token cache")
88	flag.StringVar(&identityResourceID, "identityResourceID", "", "managedIdentity azure resource id")
89
90	flag.Parse()
91
92	switch mode = strings.TrimSpace(mode); mode {
93	case msiDefaultMode:
94		checkMandatoryOptions(msiDefaultMode,
95			option{name: "resource", value: resource},
96			option{name: "tenantId", value: tenantID},
97		)
98	case msiClientIDMode:
99		checkMandatoryOptions(msiClientIDMode,
100			option{name: "resource", value: resource},
101			option{name: "tenantId", value: tenantID},
102			option{name: "applicationId", value: applicationID},
103		)
104	case msiResourceIDMode:
105		checkMandatoryOptions(msiResourceIDMode,
106			option{name: "resource", value: resource},
107			option{name: "tenantId", value: tenantID},
108			option{name: "identityResourceID", value: identityResourceID},
109		)
110	case clientSecretMode:
111		checkMandatoryOptions(clientSecretMode,
112			option{name: "resource", value: resource},
113			option{name: "tenantId", value: tenantID},
114			option{name: "applicationId", value: applicationID},
115			option{name: "secret", value: applicationSecret},
116		)
117	case clientCertMode:
118		checkMandatoryOptions(clientCertMode,
119			option{name: "resource", value: resource},
120			option{name: "tenantId", value: tenantID},
121			option{name: "applicationId", value: applicationID},
122			option{name: "certificatePath", value: certificatePath},
123		)
124	case deviceMode:
125		checkMandatoryOptions(deviceMode,
126			option{name: "resource", value: resource},
127			option{name: "tenantId", value: tenantID},
128			option{name: "applicationId", value: applicationID},
129		)
130	case refreshMode:
131		checkMandatoryOptions(refreshMode,
132			option{name: "resource", value: resource},
133			option{name: "tenantId", value: tenantID},
134			option{name: "applicationId", value: applicationID},
135		)
136	default:
137		log.Fatalln("Authentication modes 'secret, 'cert', 'device' or 'refresh' are supported.")
138	}
139}
140
141func acquireTokenClientSecretFlow(oauthConfig adal.OAuthConfig,
142	appliationID string,
143	applicationSecret string,
144	resource string,
145	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
146
147	spt, err := adal.NewServicePrincipalToken(
148		oauthConfig,
149		appliationID,
150		applicationSecret,
151		resource,
152		callbacks...)
153	if err != nil {
154		return nil, err
155	}
156
157	return spt, spt.Refresh()
158}
159
160func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
161	return adal.DecodePfxCertificateData(pkcs, password)
162}
163
164func acquireTokenMSIFlow(applicationID string,
165	identityResourceID string,
166	resource string,
167	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
168
169	// only one of them can be present:
170	if applicationID != "" && identityResourceID != "" {
171		return nil, fmt.Errorf("didn't expect applicationID and identityResourceID at same time")
172	}
173
174	msiEndpoint, _ := adal.GetMSIVMEndpoint()
175	var spt *adal.ServicePrincipalToken
176	var err error
177
178	// both can be empty, systemAssignedMSI scenario
179	if applicationID == "" && identityResourceID == "" {
180		spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource, callbacks...)
181	}
182
183	// msi login with clientID
184	if applicationID != "" {
185		spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, applicationID, callbacks...)
186	}
187
188	// msi login with resourceID
189	if identityResourceID != "" {
190		spt, err = adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource, identityResourceID, callbacks...)
191	}
192
193	if err != nil {
194		return nil, err
195	}
196
197	return spt, spt.Refresh()
198}
199
200func acquireTokenClientCertFlow(oauthConfig adal.OAuthConfig,
201	applicationID string,
202	applicationCertPath string,
203	resource string,
204	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
205
206	certData, err := ioutil.ReadFile(certificatePath)
207	if err != nil {
208		return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
209	}
210
211	certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
212	if err != nil {
213		return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
214	}
215
216	spt, err := adal.NewServicePrincipalTokenFromCertificate(
217		oauthConfig,
218		applicationID,
219		certificate,
220		rsaPrivateKey,
221		resource,
222		callbacks...)
223	if err != nil {
224		return nil, err
225	}
226
227	return spt, spt.Refresh()
228}
229
230func acquireTokenDeviceCodeFlow(oauthConfig adal.OAuthConfig,
231	applicationID string,
232	resource string,
233	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
234
235	oauthClient := &http.Client{}
236	deviceCode, err := adal.InitiateDeviceAuth(
237		oauthClient,
238		oauthConfig,
239		applicationID,
240		resource)
241	if err != nil {
242		return nil, fmt.Errorf("Failed to start device auth flow: %s", err)
243	}
244
245	fmt.Println(*deviceCode.Message)
246
247	token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
248	if err != nil {
249		return nil, fmt.Errorf("Failed to finish device auth flow: %s", err)
250	}
251
252	spt, err := adal.NewServicePrincipalTokenFromManualToken(
253		oauthConfig,
254		applicationID,
255		resource,
256		*token,
257		callbacks...)
258	return spt, err
259}
260
261func refreshToken(oauthConfig adal.OAuthConfig,
262	applicationID string,
263	resource string,
264	tokenCachePath string,
265	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
266
267	token, err := adal.LoadToken(tokenCachePath)
268	if err != nil {
269		return nil, fmt.Errorf("failed to load token from cache: %v", err)
270	}
271
272	spt, err := adal.NewServicePrincipalTokenFromManualToken(
273		oauthConfig,
274		applicationID,
275		resource,
276		*token,
277		callbacks...)
278	if err != nil {
279		return nil, err
280	}
281	return spt, spt.Refresh()
282}
283
284func saveToken(spt adal.Token) error {
285	if tokenCachePath != "" {
286		err := adal.SaveToken(tokenCachePath, 0600, spt)
287		if err != nil {
288			return err
289		}
290		log.Printf("Acquired token was saved in '%s' file\n", tokenCachePath)
291		return nil
292
293	}
294	return fmt.Errorf("empty path for token cache")
295}
296
297func main() {
298	oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
299	if err != nil {
300		panic(err)
301	}
302
303	callback := func(token adal.Token) error {
304		return saveToken(token)
305	}
306
307	log.Printf("Authenticating with mode '%s'\n", mode)
308	switch mode {
309	case clientSecretMode:
310		_, err = acquireTokenClientSecretFlow(
311			*oauthConfig,
312			applicationID,
313			applicationSecret,
314			resource,
315			callback)
316	case clientCertMode:
317		_, err = acquireTokenClientCertFlow(
318			*oauthConfig,
319			applicationID,
320			certificatePath,
321			resource,
322			callback)
323	case deviceMode:
324		var spt *adal.ServicePrincipalToken
325		spt, err = acquireTokenDeviceCodeFlow(
326			*oauthConfig,
327			applicationID,
328			resource,
329			callback)
330		if err == nil {
331			err = saveToken(spt.Token())
332		}
333	case msiResourceIDMode:
334		fallthrough
335	case msiClientIDMode:
336		fallthrough
337	case msiDefaultMode:
338		var spt *adal.ServicePrincipalToken
339		spt, err = acquireTokenMSIFlow(
340			applicationID,
341			identityResourceID,
342			resource,
343			callback)
344		if err == nil {
345			err = saveToken(spt.Token())
346		}
347	case refreshMode:
348		_, err = refreshToken(
349			*oauthConfig,
350			applicationID,
351			resource,
352			tokenCachePath,
353			callback)
354	}
355
356	if err != nil {
357		log.Fatalf("Failed to acquire a token for resource %s. Error: %v", resource, err)
358	}
359}
360