1package main 2 3// Copyright 2017 Microsoft Corporation 4// 5// Licensed under the Apache License, Version 2.0 (the "License"); 6// you may not use this file except in compliance with the License. 7// You may obtain a copy of the License at 8// 9// http://www.apache.org/licenses/LICENSE-2.0 10// 11// Unless required by applicable law or agreed to in writing, software 12// distributed under the License is distributed on an "AS IS" BASIS, 13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14// See the License for the specific language governing permissions and 15// limitations under the License. 16 17import ( 18 "crypto/rsa" 19 "crypto/x509" 20 "encoding/json" 21 "flag" 22 "fmt" 23 "io/ioutil" 24 "log" 25 "net/http" 26 "strings" 27 28 "github.com/Azure/go-autorest/autorest" 29 "github.com/Azure/go-autorest/autorest/adal" 30 "github.com/Azure/go-autorest/autorest/azure" 31 "golang.org/x/crypto/pkcs12" 32) 33 34const ( 35 resourceGroupURLTemplate = "https://management.azure.com" 36 apiVersion = "2015-01-01" 37 nativeAppClientID = "a87032a7-203c-4bf7-913c-44c50d23409a" 38 resource = "https://management.core.windows.net/" 39) 40 41var ( 42 mode string 43 tenantID string 44 subscriptionID string 45 applicationID string 46 47 tokenCachePath string 48 forceRefresh bool 49 impatient bool 50 51 certificatePath string 52) 53 54func init() { 55 flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation") 56 flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate") 57 flag.StringVar(&applicationID, "applicationId", "", "application id") 58 flag.StringVar(&tenantID, "tenantId", "", "tenant id") 59 flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id") 60 flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache") 61 flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh") 62 63 flag.Parse() 64 65 log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n", 66 mode, certificatePath, applicationID, tenantID, subscriptionID) 67 68 if mode == "certificate" && 69 (strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "") { 70 log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID") 71 } 72 73 if mode != "certificate" && mode != "device" { 74 log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.") 75 } 76 77 if mode == "device" && strings.TrimSpace(applicationID) == "" { 78 log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the comand line.") 79 applicationID = nativeAppClientID 80 } 81 82 if mode == "certificate" && strings.TrimSpace(certificatePath) == "" { 83 log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.") 84 } 85 86 if strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "" || strings.TrimSpace(applicationID) == "" { 87 log.Fatalln("Bad usage. Must specify the 'tenantId' and 'subscriptionId'") 88 } 89} 90 91func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) { 92 token, err := adal.LoadToken(tokenCachePath) 93 if err != nil { 94 return nil, fmt.Errorf("failed to load token from cache: %v", err) 95 } 96 97 spt, _ := adal.NewServicePrincipalTokenFromManualToken( 98 oauthConfig, 99 clientID, 100 resource, 101 *token, 102 callbacks...) 103 104 return spt, nil 105} 106 107func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) { 108 privateKey, certificate, err := pkcs12.Decode(pkcs, password) 109 if err != nil { 110 return nil, nil, err 111 } 112 113 rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) 114 if !isRsaKey { 115 return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key") 116 } 117 118 return certificate, rsaPrivateKey, nil 119} 120 121func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certicatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) { 122 certData, err := ioutil.ReadFile(certificatePath) 123 if err != nil { 124 return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err) 125 } 126 127 certificate, rsaPrivateKey, err := decodePkcs12(certData, "") 128 if err != nil { 129 return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err) 130 } 131 132 spt, _ := adal.NewServicePrincipalTokenFromCertificate( 133 oauthConfig, 134 clientID, 135 certificate, 136 rsaPrivateKey, 137 resource, 138 callbacks...) 139 140 return spt, nil 141} 142 143func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) { 144 oauthClient := &autorest.Client{} 145 deviceCode, err := adal.InitiateDeviceAuth(oauthClient, oauthConfig, clientID, resource) 146 if err != nil { 147 return nil, fmt.Errorf("failed to start device auth flow: %s", err) 148 } 149 150 fmt.Println(*deviceCode.Message) 151 152 token, err := adal.WaitForUserCompletion(oauthClient, deviceCode) 153 if err != nil { 154 return nil, fmt.Errorf("failed to finish device auth flow: %s", err) 155 } 156 157 spt, err := adal.NewServicePrincipalTokenFromManualToken( 158 oauthConfig, 159 clientID, 160 resource, 161 *token, 162 callbacks...) 163 if err != nil { 164 return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err) 165 } 166 167 return spt, nil 168} 169 170func printResourceGroups(client *autorest.Client) error { 171 p := map[string]interface{}{"subscription-id": subscriptionID} 172 q := map[string]interface{}{"api-version": apiVersion} 173 174 req, _ := autorest.Prepare(&http.Request{}, 175 autorest.AsGet(), 176 autorest.WithBaseURL(resourceGroupURLTemplate), 177 autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p), 178 autorest.WithQueryParameters(q)) 179 180 resp, err := autorest.SendWithSender(client, req) 181 if err != nil { 182 return err 183 } 184 185 value := struct { 186 ResourceGroups []struct { 187 Name string `json:"name"` 188 } `json:"value"` 189 }{} 190 191 defer resp.Body.Close() 192 dec := json.NewDecoder(resp.Body) 193 err = dec.Decode(&value) 194 if err != nil { 195 return err 196 } 197 198 var groupNames = make([]string, len(value.ResourceGroups)) 199 for i, name := range value.ResourceGroups { 200 groupNames[i] = name.Name 201 } 202 203 log.Println("Groups:", strings.Join(groupNames, ", ")) 204 return err 205} 206 207func saveToken(spt adal.Token) { 208 if tokenCachePath != "" { 209 err := adal.SaveToken(tokenCachePath, 0600, spt) 210 if err != nil { 211 log.Println("error saving token", err) 212 } else { 213 log.Println("saved token to", tokenCachePath) 214 } 215 } 216} 217 218func main() { 219 var spt *adal.ServicePrincipalToken 220 var err error 221 222 callback := func(t adal.Token) error { 223 log.Println("refresh callback was called") 224 saveToken(t) 225 return nil 226 } 227 228 oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID) 229 if err != nil { 230 panic(err) 231 } 232 233 if tokenCachePath != "" { 234 log.Println("tokenCachePath specified; attempting to load from", tokenCachePath) 235 spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback) 236 if err != nil { 237 spt = nil // just in case, this is the condition below 238 log.Println("loading from cache failed:", err) 239 } 240 } 241 242 if spt == nil { 243 log.Println("authenticating via 'mode'", mode) 244 switch mode { 245 case "device": 246 spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback) 247 case "certificate": 248 spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback) 249 } 250 if err != nil { 251 log.Fatalln("failed to retrieve token:", err) 252 } 253 254 // should save it as soon as you get it since Refresh won't be called for some time 255 if tokenCachePath != "" { 256 saveToken(spt.Token()) 257 } 258 } 259 260 client := &autorest.Client{} 261 client.Authorizer = autorest.NewBearerAuthorizer(spt) 262 263 printResourceGroups(client) 264 265 if forceRefresh { 266 err = spt.Refresh() 267 if err != nil { 268 panic(err) 269 } 270 printResourceGroups(client) 271 } 272} 273