1// Copyright 2017 Microsoft Corporation
2//
3//  Licensed under the Apache License, Version 2.0 (the "License");
4//  you may not use this file except in compliance with the License.
5//  You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14
15package azure
16
17import (
18	"errors"
19	"fmt"
20	"net/http"
21	"net/url"
22	"strings"
23	"time"
24
25	"github.com/Azure/go-autorest/autorest"
26)
27
28// DoRetryWithRegistration tries to register the resource provider in case it is unregistered.
29// It also handles request retries
30func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
31	return func(s autorest.Sender) autorest.Sender {
32		return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
33			rr := autorest.NewRetriableRequest(r)
34			for currentAttempt := 0; currentAttempt < client.RetryAttempts; currentAttempt++ {
35				err = rr.Prepare()
36				if err != nil {
37					return resp, err
38				}
39
40				resp, err = autorest.SendWithSender(s, rr.Request(),
41					autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
42				)
43				if err != nil {
44					return resp, err
45				}
46
47				if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration {
48					return resp, err
49				}
50
51				var re RequestError
52				if strings.Contains(r.Header.Get("Content-Type"), "xml") {
53					// XML errors (e.g. Storage Data Plane) only return the inner object
54					err = autorest.Respond(resp, autorest.ByUnmarshallingXML(&re.ServiceError))
55				} else {
56					err = autorest.Respond(resp, autorest.ByUnmarshallingJSON(&re))
57				}
58
59				if err != nil {
60					return resp, err
61				}
62				err = re
63
64				if re.ServiceError != nil && re.ServiceError.Code == "MissingSubscriptionRegistration" {
65					regErr := register(client, r, re)
66					if regErr != nil {
67						return resp, fmt.Errorf("failed auto registering Resource Provider: %s. Original error: %s", regErr, err)
68					}
69				}
70			}
71			return resp, err
72		})
73	}
74}
75
76func getProvider(re RequestError) (string, error) {
77	if re.ServiceError != nil && len(re.ServiceError.Details) > 0 {
78		return re.ServiceError.Details[0]["target"].(string), nil
79	}
80	return "", errors.New("provider was not found in the response")
81}
82
83func register(client autorest.Client, originalReq *http.Request, re RequestError) error {
84	subID := getSubscription(originalReq.URL.Path)
85	if subID == "" {
86		return errors.New("missing parameter subscriptionID to register resource provider")
87	}
88	providerName, err := getProvider(re)
89	if err != nil {
90		return fmt.Errorf("missing parameter provider to register resource provider: %s", err)
91	}
92	newURL := url.URL{
93		Scheme: originalReq.URL.Scheme,
94		Host:   originalReq.URL.Host,
95	}
96
97	// taken from the resources SDK
98	// with almost identical code, this sections are easier to mantain
99	// It is also not a good idea to import the SDK here
100	// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L252
101	pathParameters := map[string]interface{}{
102		"resourceProviderNamespace": autorest.Encode("path", providerName),
103		"subscriptionId":            autorest.Encode("path", subID),
104	}
105
106	const APIVersion = "2016-09-01"
107	queryParameters := map[string]interface{}{
108		"api-version": APIVersion,
109	}
110
111	preparer := autorest.CreatePreparer(
112		autorest.AsPost(),
113		autorest.WithBaseURL(newURL.String()),
114		autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register", pathParameters),
115		autorest.WithQueryParameters(queryParameters),
116	)
117
118	req, err := preparer.Prepare(&http.Request{})
119	if err != nil {
120		return err
121	}
122	req = req.WithContext(originalReq.Context())
123
124	resp, err := autorest.SendWithSender(client, req,
125		autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
126	)
127	if err != nil {
128		return err
129	}
130
131	type Provider struct {
132		RegistrationState *string `json:"registrationState,omitempty"`
133	}
134	var provider Provider
135
136	err = autorest.Respond(
137		resp,
138		WithErrorUnlessStatusCode(http.StatusOK),
139		autorest.ByUnmarshallingJSON(&provider),
140		autorest.ByClosing(),
141	)
142	if err != nil {
143		return err
144	}
145
146	// poll for registered provisioning state
147	registrationStartTime := time.Now()
148	for err == nil && (client.PollingDuration == 0 || (client.PollingDuration != 0 && time.Since(registrationStartTime) < client.PollingDuration)) {
149		// taken from the resources SDK
150		// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L45
151		preparer := autorest.CreatePreparer(
152			autorest.AsGet(),
153			autorest.WithBaseURL(newURL.String()),
154			autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}", pathParameters),
155			autorest.WithQueryParameters(queryParameters),
156		)
157		req, err = preparer.Prepare(&http.Request{})
158		if err != nil {
159			return err
160		}
161		req = req.WithContext(originalReq.Context())
162
163		resp, err := autorest.SendWithSender(client, req,
164			autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
165		)
166		if err != nil {
167			return err
168		}
169
170		err = autorest.Respond(
171			resp,
172			WithErrorUnlessStatusCode(http.StatusOK),
173			autorest.ByUnmarshallingJSON(&provider),
174			autorest.ByClosing(),
175		)
176		if err != nil {
177			return err
178		}
179
180		if provider.RegistrationState != nil &&
181			*provider.RegistrationState == "Registered" {
182			break
183		}
184
185		delayed := autorest.DelayWithRetryAfter(resp, originalReq.Context().Done())
186		if !delayed && !autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Context().Done()) {
187			return originalReq.Context().Err()
188		}
189	}
190	if client.PollingDuration != 0 && !(time.Since(registrationStartTime) < client.PollingDuration) {
191		return errors.New("polling for resource provider registration has exceeded the polling duration")
192	}
193	return err
194}
195
// getSubscription returns the path segment immediately following the first
// "subscriptions" segment of path that has a successor (for example "abc" for
// "/subscriptions/abc/resourceGroups/rg"). It returns "" when no such segment
// exists. The comparison is case-sensitive.
func getSubscription(path string) string {
	segments := strings.Split(path, "/")
	// Stop one short of the end so segments[idx+1] is always in range.
	for idx := 0; idx+1 < len(segments); idx++ {
		if segments[idx] == "subscriptions" {
			return segments[idx+1]
		}
	}
	return ""
}
205