1package main
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
import (
	"crypto/rsa"
	"crypto/x509"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/Azure/go-autorest/autorest/adal"
	"golang.org/x/crypto/pkcs12"
)
32
// Authentication modes accepted by the -mode command line flag.
const (
	deviceMode       = "device"
	clientSecretMode = "secret"
	clientCertMode   = "cert"
	refreshMode      = "refresh"
)

// activeDirectoryEndpoint is the Active Directory base URL that main passes
// to adal.NewOAuthConfig together with the tenant id.
const activeDirectoryEndpoint = "https://login.microsoftonline.com/"
41
// option pairs the name of a command line flag with the value supplied for
// it, so that a missing mandatory flag can be reported by name.
type option struct {
	name, value string
}
46
// Command line settings. init binds each of them to a flag and validates the
// ones that the selected authentication mode requires.
var (
	// mode is the authentication mode to run and resource is the resource
	// for which the token is requested.
	mode, resource string

	// tenantID and applicationID identify the directory tenant and the
	// application requesting the token.
	tenantID, applicationID string

	// applicationSecret is required by the secret mode; certificatePath is
	// required by the cert mode.
	applicationSecret, certificatePath string

	// tokenCachePath is the file that acquired tokens are saved to and that
	// the refresh mode loads its token from.
	tokenCachePath string
)
59
60func checkMandatoryOptions(mode string, options ...option) {
61	for _, option := range options {
62		if strings.TrimSpace(option.value) == "" {
63			log.Fatalf("Authentication mode '%s' requires mandatory option '%s'.", mode, option.name)
64		}
65	}
66}
67
// defaultTokenCachePath returns the default location of the token cache:
// ".adal/accessToken.json" beneath the home directory of the current user.
//
// The path is assembled with filepath.Join so that it uses the path separator
// of the host operating system instead of a hard-coded "/".
//
// The program is terminated through log.Fatal when the current user cannot be
// determined.
func defaultTokenCachePath() string {
	usr, err := user.Current()
	if err != nil {
		log.Fatal(err)
	}
	return filepath.Join(usr.HomeDir, ".adal", "accessToken.json")
}
76
77func init() {
78	flag.StringVar(&mode, "mode", "device", "authentication mode (device, secret, cert, refresh)")
79	flag.StringVar(&resource, "resource", "", "resource for which the token is requested")
80	flag.StringVar(&tenantID, "tenantId", "", "tenant id")
81	flag.StringVar(&applicationID, "applicationId", "", "application id")
82	flag.StringVar(&applicationSecret, "secret", "", "application secret")
83	flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/PFC application certificate")
84	flag.StringVar(&tokenCachePath, "tokenCachePath", defaultTokenCachePath(), "location of oath token cache")
85
86	flag.Parse()
87
88	switch mode = strings.TrimSpace(mode); mode {
89	case clientSecretMode:
90		checkMandatoryOptions(clientSecretMode,
91			option{name: "resource", value: resource},
92			option{name: "tenantId", value: tenantID},
93			option{name: "applicationId", value: applicationID},
94			option{name: "secret", value: applicationSecret},
95		)
96	case clientCertMode:
97		checkMandatoryOptions(clientCertMode,
98			option{name: "resource", value: resource},
99			option{name: "tenantId", value: tenantID},
100			option{name: "applicationId", value: applicationID},
101			option{name: "certificatePath", value: certificatePath},
102		)
103	case deviceMode:
104		checkMandatoryOptions(deviceMode,
105			option{name: "resource", value: resource},
106			option{name: "tenantId", value: tenantID},
107			option{name: "applicationId", value: applicationID},
108		)
109	case refreshMode:
110		checkMandatoryOptions(refreshMode,
111			option{name: "resource", value: resource},
112			option{name: "tenantId", value: tenantID},
113			option{name: "applicationId", value: applicationID},
114		)
115	default:
116		log.Fatalln("Authentication modes 'secret, 'cert', 'device' or 'refresh' are supported.")
117	}
118}
119
120func acquireTokenClientSecretFlow(oauthConfig adal.OAuthConfig,
121	appliationID string,
122	applicationSecret string,
123	resource string,
124	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
125
126	spt, err := adal.NewServicePrincipalToken(
127		oauthConfig,
128		appliationID,
129		applicationSecret,
130		resource,
131		callbacks...)
132	if err != nil {
133		return nil, err
134	}
135
136	return spt, spt.Refresh()
137}
138
139func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
140	privateKey, certificate, err := pkcs12.Decode(pkcs, password)
141	if err != nil {
142		return nil, nil, err
143	}
144
145	rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
146	if !isRsaKey {
147		return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
148	}
149
150	return certificate, rsaPrivateKey, nil
151}
152
153func acquireTokenClientCertFlow(oauthConfig adal.OAuthConfig,
154	applicationID string,
155	applicationCertPath string,
156	resource string,
157	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
158
159	certData, err := ioutil.ReadFile(certificatePath)
160	if err != nil {
161		return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
162	}
163
164	certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
165	if err != nil {
166		return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
167	}
168
169	spt, err := adal.NewServicePrincipalTokenFromCertificate(
170		oauthConfig,
171		applicationID,
172		certificate,
173		rsaPrivateKey,
174		resource,
175		callbacks...)
176	if err != nil {
177		return nil, err
178	}
179
180	return spt, spt.Refresh()
181}
182
183func acquireTokenDeviceCodeFlow(oauthConfig adal.OAuthConfig,
184	applicationID string,
185	resource string,
186	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
187
188	oauthClient := &http.Client{}
189	deviceCode, err := adal.InitiateDeviceAuth(
190		oauthClient,
191		oauthConfig,
192		applicationID,
193		resource)
194	if err != nil {
195		return nil, fmt.Errorf("Failed to start device auth flow: %s", err)
196	}
197
198	fmt.Println(*deviceCode.Message)
199
200	token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
201	if err != nil {
202		return nil, fmt.Errorf("Failed to finish device auth flow: %s", err)
203	}
204
205	spt, err := adal.NewServicePrincipalTokenFromManualToken(
206		oauthConfig,
207		applicationID,
208		resource,
209		*token,
210		callbacks...)
211	return spt, err
212}
213
214func refreshToken(oauthConfig adal.OAuthConfig,
215	applicationID string,
216	resource string,
217	tokenCachePath string,
218	callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
219
220	token, err := adal.LoadToken(tokenCachePath)
221	if err != nil {
222		return nil, fmt.Errorf("failed to load token from cache: %v", err)
223	}
224
225	spt, err := adal.NewServicePrincipalTokenFromManualToken(
226		oauthConfig,
227		applicationID,
228		resource,
229		*token,
230		callbacks...)
231	if err != nil {
232		return nil, err
233	}
234	return spt, spt.Refresh()
235}
236
237func saveToken(spt adal.Token) error {
238	if tokenCachePath != "" {
239		err := adal.SaveToken(tokenCachePath, 0600, spt)
240		if err != nil {
241			return err
242		}
243		log.Printf("Acquired token was saved in '%s' file\n", tokenCachePath)
244		return nil
245
246	}
247	return fmt.Errorf("empty path for token cache")
248}
249
250func main() {
251	oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
252	if err != nil {
253		panic(err)
254	}
255
256	callback := func(token adal.Token) error {
257		return saveToken(token)
258	}
259
260	log.Printf("Authenticating with mode '%s'\n", mode)
261	switch mode {
262	case clientSecretMode:
263		_, err = acquireTokenClientSecretFlow(
264			*oauthConfig,
265			applicationID,
266			applicationSecret,
267			resource,
268			callback)
269	case clientCertMode:
270		_, err = acquireTokenClientCertFlow(
271			*oauthConfig,
272			applicationID,
273			certificatePath,
274			resource,
275			callback)
276	case deviceMode:
277		var spt *adal.ServicePrincipalToken
278		spt, err = acquireTokenDeviceCodeFlow(
279			*oauthConfig,
280			applicationID,
281			resource,
282			callback)
283		if err == nil {
284			err = saveToken(spt.Token())
285		}
286	case refreshMode:
287		_, err = refreshToken(
288			*oauthConfig,
289			applicationID,
290			resource,
291			tokenCachePath,
292			callback)
293	}
294
295	if err != nil {
296		log.Fatalf("Failed to acquire a token for resource %s. Error: %v", resource, err)
297	}
298}
299