1//go:build go1.16
2// +build go1.16
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package pollers
8
9import (
10	"context"
11	"encoding/json"
12	"errors"
13	"fmt"
14	"io/ioutil"
15	"net/http"
16	"reflect"
17	"time"
18
19	"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline"
20	"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
21	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
22)
23
24// KindFromToken extracts the poller kind from the provided token.
25// If the pollerID doesn't match what's in the token an error is returned.
26func KindFromToken(pollerID, token string) (string, error) {
27	// unmarshal into JSON object to determine the poller type
28	obj := map[string]interface{}{}
29	err := json.Unmarshal([]byte(token), &obj)
30	if err != nil {
31		return "", err
32	}
33	t, ok := obj["type"]
34	if !ok {
35		return "", errors.New("missing type field")
36	}
37	tt, ok := t.(string)
38	if !ok {
39		return "", fmt.Errorf("invalid type format %T", t)
40	}
41	ttID, ttKind, err := DecodeID(tt)
42	if err != nil {
43		return "", err
44	}
45	// ensure poller types match
46	if ttID != pollerID {
47		return "", fmt.Errorf("cannot resume from this poller token.  expected %s, received %s", pollerID, ttID)
48	}
49	return ttKind, nil
50}
51
52// PollerType returns the concrete type of the poller (FOR TESTING PURPOSES).
53func PollerType(p *Poller) reflect.Type {
54	return reflect.TypeOf(p.lro)
55}
56
57// NewPoller creates a Poller from the specified input.
58func NewPoller(lro Operation, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) *Poller {
59	return &Poller{lro: lro, pl: pl, eu: eu, resp: resp}
60}
61
// Poller encapsulates state and logic for polling on long-running operations.
// Its methods read and mutate resp and err without any synchronization.
type Poller struct {
	lro  Operation                  // supplies the polling/final GET URLs and tracks the operation's status
	pl   pipeline.Pipeline          // sends the polling and final GET requests
	eu   func(*http.Response) error // converts a failed response into an error
	resp *http.Response             // latest successful response; nil after a terminal failure
	err  error                      // terminal failure; once non-nil, Done reports true
}
70
71// Done returns true if the LRO has reached a terminal state.
72func (l *Poller) Done() bool {
73	if l.err != nil {
74		return true
75	}
76	return l.lro.Done()
77}
78
79// Poll sends a polling request to the polling endpoint and returns the response or error.
80func (l *Poller) Poll(ctx context.Context) (*http.Response, error) {
81	if l.Done() {
82		// the LRO has reached a terminal state, don't poll again
83		if l.resp != nil {
84			return l.resp, nil
85		}
86		return nil, l.err
87	}
88	req, err := pipeline.NewRequest(ctx, http.MethodGet, l.lro.URL())
89	if err != nil {
90		return nil, err
91	}
92	resp, err := l.pl.Do(req)
93	if err != nil {
94		// don't update the poller for failed requests
95		return nil, err
96	}
97	defer resp.Body.Close()
98	if !StatusCodeValid(resp) {
99		// the LRO failed.  unmarshall the error and update state
100		l.err = l.eu(resp)
101		l.resp = nil
102		return nil, l.err
103	}
104	if err = l.lro.Update(resp); err != nil {
105		return nil, err
106	}
107	l.resp = resp
108	log.Writef(log.LongRunningOperation, "Status %s", l.lro.Status())
109	if Failed(l.lro.Status()) {
110		l.err = l.eu(resp)
111		l.resp = nil
112		return nil, l.err
113	}
114	return l.resp, nil
115}
116
117// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state.
118func (l *Poller) ResumeToken() (string, error) {
119	if l.Done() {
120		return "", errors.New("cannot create a ResumeToken from a poller in a terminal state")
121	}
122	b, err := json.Marshal(l.lro)
123	if err != nil {
124		return "", err
125	}
126	return string(b), nil
127}
128
129// FinalResponse will perform a final GET request and return the final HTTP response for the polling
130// operation and unmarshall the content of the payload into the respType interface that is provided.
131func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) {
132	if !l.Done() {
133		return nil, errors.New("cannot return a final response from a poller in a non-terminal state")
134	}
135	// update l.resp with the content from final GET if applicable
136	if u := l.lro.FinalGetURL(); u != "" {
137		log.Write(log.LongRunningOperation, "Performing final GET.")
138		req, err := pipeline.NewRequest(ctx, http.MethodGet, u)
139		if err != nil {
140			return nil, err
141		}
142		resp, err := l.pl.Do(req)
143		if err != nil {
144			return nil, err
145		}
146		if !StatusCodeValid(resp) {
147			return nil, l.eu(resp)
148		}
149		l.resp = resp
150	}
151	// if there's nothing to unmarshall into or no response body just return the final response
152	if respType == nil {
153		return l.resp, nil
154	} else if l.resp.StatusCode == http.StatusNoContent || l.resp.ContentLength == 0 {
155		log.Write(log.LongRunningOperation, "final response specifies a response type but no payload was received")
156		return l.resp, nil
157	}
158	body, err := ioutil.ReadAll(l.resp.Body)
159	l.resp.Body.Close()
160	if err != nil {
161		return nil, err
162	}
163	if err = json.Unmarshal(body, respType); err != nil {
164		return nil, err
165	}
166	return l.resp, nil
167}
168
169// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached,
170// then return the final HTTP response for the polling operation and unmarshal the content of the payload
171// into the respType interface that is provided.
172// freq - the time to wait between polling intervals if the endpoint doesn't send a Retry-After header.
173//        A good starting value is 30 seconds.  Note that some resources might benefit from a different value.
174func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) {
175	start := time.Now()
176	logPollUntilDoneExit := func(v interface{}) {
177		log.Writef(log.LongRunningOperation, "END PollUntilDone() for %T: %v, total time: %s", l.lro, v, time.Since(start))
178	}
179	log.Writef(log.LongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro)
180	if l.resp != nil {
181		// initial check for a retry-after header existing on the initial response
182		if retryAfter := shared.RetryAfter(l.resp); retryAfter > 0 {
183			log.Writef(log.LongRunningOperation, "initial Retry-After delay for %s", retryAfter.String())
184			if err := shared.Delay(ctx, retryAfter); err != nil {
185				logPollUntilDoneExit(err)
186				return nil, err
187			}
188		}
189	}
190	// begin polling the endpoint until a terminal state is reached
191	for {
192		resp, err := l.Poll(ctx)
193		if err != nil {
194			logPollUntilDoneExit(err)
195			return nil, err
196		}
197		if l.Done() {
198			logPollUntilDoneExit(l.lro.Status())
199			return l.FinalResponse(ctx, respType)
200		}
201		d := freq
202		if retryAfter := shared.RetryAfter(resp); retryAfter > 0 {
203			log.Writef(log.LongRunningOperation, "Retry-After delay for %s", retryAfter.String())
204			d = retryAfter
205		} else {
206			log.Writef(log.LongRunningOperation, "delay for %s", d.String())
207		}
208		if err = shared.Delay(ctx, d); err != nil {
209			logPollUntilDoneExit(err)
210			return nil, err
211		}
212	}
213}
214