1package main
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"crypto/rsa"
19	"crypto/x509"
20	"encoding/json"
21	"flag"
22	"fmt"
23	"io/ioutil"
24	"log"
25	"net/http"
26	"strings"
27
28	"github.com/Azure/go-autorest/autorest"
29	"github.com/Azure/go-autorest/autorest/adal"
30	"github.com/Azure/go-autorest/autorest/azure"
31	"golang.org/x/crypto/pkcs12"
32)
33
// Azure Resource Manager request settings used when listing resource groups.
const (
	// resourceGroupURLTemplate is the base URL of the ARM REST API; the
	// subscription/resourcegroups path is appended by printResourceGroups.
	resourceGroupURLTemplate = "https://management.azure.com"

	// apiVersion is sent as the api-version query parameter.
	apiVersion = "2015-01-01"
)

// Authentication settings.
const (
	// nativeAppClientID is the client ID used in device mode when no
	// -applicationId flag is supplied (logged as the `azkube` client ID).
	nativeAppClientID = "a87032a7-203c-4bf7-913c-44c50d23409a"

	// resource is the audience the OAuth tokens are requested for.
	resource = "https://management.core.windows.net/"
)
40
// Azure identity and authentication mode; populated from command-line flags
// in init.
var (
	mode           string // "device" or "certificate"
	tenantID       string
	subscriptionID string
	applicationID  string
)

// Token cache behaviour.
var (
	tokenCachePath string // empty disables the on-disk token cache
	forceRefresh   bool   // refresh the token and list groups a second time
	impatient      bool   // NOTE(review): not bound to any flag visible in this file
)

// certificatePath is the PKCS#12 file used by "certificate" mode.
var certificatePath string
53
54func init() {
55	flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation")
56	flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate")
57	flag.StringVar(&applicationID, "applicationId", "", "application id")
58	flag.StringVar(&tenantID, "tenantId", "", "tenant id")
59	flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id")
60	flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache")
61	flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh")
62
63	flag.Parse()
64
65	log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n",
66		mode, certificatePath, applicationID, tenantID, subscriptionID)
67
68	if mode == "certificate" &&
69		(strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "") {
70		log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID")
71	}
72
73	if mode != "certificate" && mode != "device" {
74		log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.")
75	}
76
77	if mode == "device" && strings.TrimSpace(applicationID) == "" {
78		log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the comand line.")
79		applicationID = nativeAppClientID
80	}
81
82	if mode == "certificate" && strings.TrimSpace(certificatePath) == "" {
83		log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.")
84	}
85
86	if strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "" || strings.TrimSpace(applicationID) == "" {
87		log.Fatalln("Bad usage. Must specify the 'tenantId' and 'subscriptionId'")
88	}
89}
90
91func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
92	token, err := adal.LoadToken(tokenCachePath)
93	if err != nil {
94		return nil, fmt.Errorf("failed to load token from cache: %v", err)
95	}
96
97	spt, _ := adal.NewServicePrincipalTokenFromManualToken(
98		oauthConfig,
99		clientID,
100		resource,
101		*token,
102		callbacks...)
103
104	return spt, nil
105}
106
107func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
108	privateKey, certificate, err := pkcs12.Decode(pkcs, password)
109	if err != nil {
110		return nil, nil, err
111	}
112
113	rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
114	if !isRsaKey {
115		return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
116	}
117
118	return certificate, rsaPrivateKey, nil
119}
120
121func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certicatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
122	certData, err := ioutil.ReadFile(certificatePath)
123	if err != nil {
124		return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
125	}
126
127	certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
128	if err != nil {
129		return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
130	}
131
132	spt, _ := adal.NewServicePrincipalTokenFromCertificate(
133		oauthConfig,
134		clientID,
135		certificate,
136		rsaPrivateKey,
137		resource,
138		callbacks...)
139
140	return spt, nil
141}
142
143func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
144	oauthClient := &autorest.Client{}
145	deviceCode, err := adal.InitiateDeviceAuth(oauthClient, oauthConfig, clientID, resource)
146	if err != nil {
147		return nil, fmt.Errorf("failed to start device auth flow: %s", err)
148	}
149
150	fmt.Println(*deviceCode.Message)
151
152	token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
153	if err != nil {
154		return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
155	}
156
157	spt, err := adal.NewServicePrincipalTokenFromManualToken(
158		oauthConfig,
159		clientID,
160		resource,
161		*token,
162		callbacks...)
163	if err != nil {
164		return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
165	}
166
167	return spt, nil
168}
169
170func printResourceGroups(client *autorest.Client) error {
171	p := map[string]interface{}{"subscription-id": subscriptionID}
172	q := map[string]interface{}{"api-version": apiVersion}
173
174	req, _ := autorest.Prepare(&http.Request{},
175		autorest.AsGet(),
176		autorest.WithBaseURL(resourceGroupURLTemplate),
177		autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p),
178		autorest.WithQueryParameters(q))
179
180	resp, err := autorest.SendWithSender(client, req)
181	if err != nil {
182		return err
183	}
184
185	value := struct {
186		ResourceGroups []struct {
187			Name string `json:"name"`
188		} `json:"value"`
189	}{}
190
191	defer resp.Body.Close()
192	dec := json.NewDecoder(resp.Body)
193	err = dec.Decode(&value)
194	if err != nil {
195		return err
196	}
197
198	var groupNames = make([]string, len(value.ResourceGroups))
199	for i, name := range value.ResourceGroups {
200		groupNames[i] = name.Name
201	}
202
203	log.Println("Groups:", strings.Join(groupNames, ", "))
204	return err
205}
206
207func saveToken(spt adal.Token) {
208	if tokenCachePath != "" {
209		err := adal.SaveToken(tokenCachePath, 0600, spt)
210		if err != nil {
211			log.Println("error saving token", err)
212		} else {
213			log.Println("saved token to", tokenCachePath)
214		}
215	}
216}
217
218func main() {
219	var spt *adal.ServicePrincipalToken
220	var err error
221
222	callback := func(t adal.Token) error {
223		log.Println("refresh callback was called")
224		saveToken(t)
225		return nil
226	}
227
228	oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
229	if err != nil {
230		panic(err)
231	}
232
233	if tokenCachePath != "" {
234		log.Println("tokenCachePath specified; attempting to load from", tokenCachePath)
235		spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback)
236		if err != nil {
237			spt = nil // just in case, this is the condition below
238			log.Println("loading from cache failed:", err)
239		}
240	}
241
242	if spt == nil {
243		log.Println("authenticating via 'mode'", mode)
244		switch mode {
245		case "device":
246			spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback)
247		case "certificate":
248			spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback)
249		}
250		if err != nil {
251			log.Fatalln("failed to retrieve token:", err)
252		}
253
254		// should save it as soon as you get it since Refresh won't be called for some time
255		if tokenCachePath != "" {
256			saveToken(spt.Token())
257		}
258	}
259
260	client := &autorest.Client{}
261	client.Authorizer = autorest.NewBearerAuthorizer(spt)
262
263	printResourceGroups(client)
264
265	if forceRefresh {
266		err = spt.Refresh()
267		if err != nil {
268			panic(err)
269		}
270		printResourceGroups(client)
271	}
272}
273