1// Copyright 2017 Microsoft Corporation 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package azure 16 17import ( 18 "errors" 19 "fmt" 20 "net/http" 21 "net/url" 22 "strings" 23 "time" 24 25 "github.com/Azure/go-autorest/autorest" 26) 27 28// DoRetryWithRegistration tries to register the resource provider in case it is unregistered. 29// It also handles request retries 30func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator { 31 return func(s autorest.Sender) autorest.Sender { 32 return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) { 33 rr := autorest.NewRetriableRequest(r) 34 for currentAttempt := 0; currentAttempt < client.RetryAttempts; currentAttempt++ { 35 err = rr.Prepare() 36 if err != nil { 37 return resp, err 38 } 39 40 resp, err = autorest.SendWithSender(s, rr.Request(), 41 autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), 42 ) 43 if err != nil { 44 return resp, err 45 } 46 47 if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration { 48 return resp, err 49 } 50 51 var re RequestError 52 if strings.Contains(r.Header.Get("Content-Type"), "xml") { 53 // XML errors (e.g. Storage Data Plane) only return the inner object 54 err = autorest.Respond(resp, autorest.ByUnmarshallingXML(&re.ServiceError)) 55 } else { 56 err = autorest.Respond(resp, autorest.ByUnmarshallingJSON(&re)) 57 } 58 59 if err != nil { 60 return resp, err 61 } 62 err = re 63 64 if re.ServiceError != nil && re.ServiceError.Code == "MissingSubscriptionRegistration" { 65 regErr := register(client, r, re) 66 if regErr != nil { 67 return resp, fmt.Errorf("failed auto registering Resource Provider: %s. Original error: %s", regErr, err) 68 } 69 } 70 } 71 return resp, err 72 }) 73 } 74} 75 76func getProvider(re RequestError) (string, error) { 77 if re.ServiceError != nil && len(re.ServiceError.Details) > 0 { 78 return re.ServiceError.Details[0]["target"].(string), nil 79 } 80 return "", errors.New("provider was not found in the response") 81} 82 83func register(client autorest.Client, originalReq *http.Request, re RequestError) error { 84 subID := getSubscription(originalReq.URL.Path) 85 if subID == "" { 86 return errors.New("missing parameter subscriptionID to register resource provider") 87 } 88 providerName, err := getProvider(re) 89 if err != nil { 90 return fmt.Errorf("missing parameter provider to register resource provider: %s", err) 91 } 92 newURL := url.URL{ 93 Scheme: originalReq.URL.Scheme, 94 Host: originalReq.URL.Host, 95 } 96 97 // taken from the resources SDK 98 // with almost identical code, this sections are easier to mantain 99 // It is also not a good idea to import the SDK here 100 // https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L252 101 pathParameters := map[string]interface{}{ 102 "resourceProviderNamespace": autorest.Encode("path", providerName), 103 "subscriptionId": autorest.Encode("path", subID), 104 } 105 106 const APIVersion = "2016-09-01" 107 queryParameters := map[string]interface{}{ 108 "api-version": APIVersion, 109 } 110 111 preparer := autorest.CreatePreparer( 112 autorest.AsPost(), 113 autorest.WithBaseURL(newURL.String()), 114 autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register", pathParameters), 115 autorest.WithQueryParameters(queryParameters), 116 ) 117 118 req, err := preparer.Prepare(&http.Request{}) 119 if err != nil { 120 return err 121 } 122 req = req.WithContext(originalReq.Context()) 123 124 resp, err := autorest.SendWithSender(client, req, 125 autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), 126 ) 127 if err != nil { 128 return err 129 } 130 131 type Provider struct { 132 RegistrationState *string `json:"registrationState,omitempty"` 133 } 134 var provider Provider 135 136 err = autorest.Respond( 137 resp, 138 WithErrorUnlessStatusCode(http.StatusOK), 139 autorest.ByUnmarshallingJSON(&provider), 140 autorest.ByClosing(), 141 ) 142 if err != nil { 143 return err 144 } 145 146 // poll for registered provisioning state 147 registrationStartTime := time.Now() 148 for err == nil && (client.PollingDuration == 0 || (client.PollingDuration != 0 && time.Since(registrationStartTime) < client.PollingDuration)) { 149 // taken from the resources SDK 150 // https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L45 151 preparer := autorest.CreatePreparer( 152 autorest.AsGet(), 153 autorest.WithBaseURL(newURL.String()), 154 autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}", pathParameters), 155 autorest.WithQueryParameters(queryParameters), 156 ) 157 req, err = preparer.Prepare(&http.Request{}) 158 if err != nil { 159 return err 160 } 161 req = req.WithContext(originalReq.Context()) 162 163 resp, err := autorest.SendWithSender(client, req, 164 autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), 165 ) 166 if err != nil { 167 return err 168 } 169 170 err = autorest.Respond( 171 resp, 172 WithErrorUnlessStatusCode(http.StatusOK), 173 autorest.ByUnmarshallingJSON(&provider), 174 autorest.ByClosing(), 175 ) 176 if err != nil { 177 return err 178 } 179 180 if provider.RegistrationState != nil && 181 *provider.RegistrationState == "Registered" { 182 break 183 } 184 185 delayed := autorest.DelayWithRetryAfter(resp, originalReq.Context().Done()) 186 if !delayed && !autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Context().Done()) { 187 return originalReq.Context().Err() 188 } 189 } 190 if client.PollingDuration != 0 && !(time.Since(registrationStartTime) < client.PollingDuration) { 191 return errors.New("polling for resource provider registration has exceeded the polling duration") 192 } 193 return err 194} 195 196func getSubscription(path string) string { 197 parts := strings.Split(path, "/") 198 for i, v := range parts { 199 if v == "subscriptions" && (i+1) < len(parts) { 200 return parts[i+1] 201 } 202 } 203 return "" 204} 205