1package azblob
2
3import (
4	"context"
5	"errors"
6	"sync/atomic"
7
8	"runtime"
9	"sync"
10	"time"
11
12	"github.com/Azure/azure-pipeline-go/pipeline"
13)
14
// TokenRefresher is the signature of a callback that you write so the token credential's value can
// be refreshed periodically. The callback receives the credential whose token should be updated
// (typically by calling SetToken) and returns how long to wait before it should be called again.
// In this file, refresh schedules the next call only when the returned duration is positive; a zero
// or negative duration ends the refresh cycle.
type TokenRefresher func(credential TokenCredential) time.Duration
18
// TokenCredential represents a token credential (which is also a pipeline.Factory).
// Both implementations in this file (tokenCredential and tokenCredentialWithRefresh) provide
// these methods together with credentialMarker and New.
type TokenCredential interface {
	// Credential is declared elsewhere in the package; the implementations in this file
	// satisfy it through their credentialMarker and New methods.
	Credential
	// Token returns the current token value.
	Token() string
	// SetToken replaces the current token value with newToken.
	SetToken(newToken string)
}
25
26// NewTokenCredential creates a token credential for use with role-based access control (RBAC) access to Azure Storage
27// resources. You initialize the TokenCredential with an initial token value. If you pass a non-nil value for
28// tokenRefresher, then the function you pass will be called immediately so it can refresh and change the
29// TokenCredential's token value by calling SetToken. Your tokenRefresher function must return a time.Duration
30// indicating how long the TokenCredential object should wait before calling your tokenRefresher function again.
31// If your tokenRefresher callback fails to refresh the token, you can return a duration of 0 to stop your
32// TokenCredential object from ever invoking tokenRefresher again. Also, oen way to deal with failing to refresh a
33// token is to cancel a context.Context object used by requests that have the TokenCredential object in their pipeline.
34func NewTokenCredential(initialToken string, tokenRefresher TokenRefresher) TokenCredential {
35	tc := &tokenCredential{}
36	tc.SetToken(initialToken) // We don't set it above to guarantee atomicity
37	if tokenRefresher == nil {
38		return tc // If no callback specified, return the simple tokenCredential
39	}
40
41	tcwr := &tokenCredentialWithRefresh{token: tc}
42	tcwr.token.startRefresh(tokenRefresher)
43	runtime.SetFinalizer(tcwr, func(deadTC *tokenCredentialWithRefresh) {
44		deadTC.token.stopRefresh()
45		deadTC.token = nil //  Sanity (not really required)
46	})
47	return tcwr
48}
49
// tokenCredentialWithRefresh is a wrapper over a token credential that has a refresh callback.
// NewTokenCredential attaches a finalizer to this wrapper: when the wrapper object gets GC'd,
// the finalizer stops the wrapped tokenCredential's timer, which allows the tokenCredential
// object to also be GC'd.
type tokenCredentialWithRefresh struct {
	// token is the wrapped credential that every method of the wrapper delegates to.
	token *tokenCredential
}
56
// credentialMarker is a package-internal method that exists just to satisfy the Credential interface.
// It intentionally does nothing.
func (*tokenCredentialWithRefresh) credentialMarker() {}
59
60// Token returns the current token value
61func (f *tokenCredentialWithRefresh) Token() string { return f.token.Token() }
62
63// SetToken changes the current token value
64func (f *tokenCredentialWithRefresh) SetToken(token string) { f.token.SetToken(token) }
65
66// New satisfies pipeline.Factory's New method creating a pipeline policy object.
67func (f *tokenCredentialWithRefresh) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
68	return f.token.New(next, po)
69}
70
71///////////////////////////////////////////////////////////////////////////////
72
// tokenCredential is the basic token credential; it is also the pipeline.Factory that
// creates the credential's policy (see New).
type tokenCredential struct {
	// token holds the current token value; SetToken stores a string here and Token loads it back.
	token atomic.Value

	// The members below are only used if the user specified a tokenRefresher callback function.

	// timer is the pending timer that will call refresh again; it is assigned in refresh
	// and stopped in stopRefresh.
	timer *time.Timer
	// tokenRefresher is the user's callback, recorded by startRefresh and invoked by refresh.
	tokenRefresher TokenRefresher
	// lock is held by refresh and stopRefresh while they touch timer and stopped.
	lock sync.Mutex
	// stopped is set by stopRefresh; while it is true, refresh does not start a new timer.
	stopped bool
}
83
// credentialMarker is a package-internal method that exists just to satisfy the Credential interface.
// It intentionally does nothing.
func (*tokenCredential) credentialMarker() {}
86
87// Token returns the current token value
88func (f *tokenCredential) Token() string { return f.token.Load().(string) }
89
90// SetToken changes the current token value
91func (f *tokenCredential) SetToken(token string) { f.token.Store(token) }
92
93// startRefresh calls refresh which immediately calls tokenRefresher
94// and then starts a timer to call tokenRefresher in the future.
95func (f *tokenCredential) startRefresh(tokenRefresher TokenRefresher) {
96	f.tokenRefresher = tokenRefresher
97	f.stopped = false // In case user calls StartRefresh, StopRefresh, & then StartRefresh again
98	f.refresh()
99}
100
101// refresh calls the user's tokenRefresher so they can refresh the token (by
102// calling SetToken) and then starts another time (based on the returned duration)
103// in order to refresh the token again in the future.
104func (f *tokenCredential) refresh() {
105	d := f.tokenRefresher(f) // Invoke the user's refresh callback outside of the lock
106	if d > 0 {               // If duration is 0 or negative, refresher wants to not be called again
107		f.lock.Lock()
108		if !f.stopped {
109			f.timer = time.AfterFunc(d, f.refresh)
110		}
111		f.lock.Unlock()
112	}
113}
114
115// stopRefresh stops any pending timer and sets stopped field to true to prevent
116// any new timer from starting.
117// NOTE: Stopping the timer allows the GC to destroy the tokenCredential object.
118func (f *tokenCredential) stopRefresh() {
119	f.lock.Lock()
120	f.stopped = true
121	if f.timer != nil {
122		f.timer.Stop()
123	}
124	f.lock.Unlock()
125}
126
127// New satisfies pipeline.Factory's New method creating a pipeline policy object.
128func (f *tokenCredential) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
129	return pipeline.PolicyFunc(func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
130		if request.URL.Scheme != "https" {
131			// HTTPS must be used, otherwise the tokens are at the risk of being exposed
132			return nil, errors.New("token credentials require a URL using the https protocol scheme")
133		}
134		request.Header[headerAuthorization] = []string{"Bearer " + f.Token()}
135		return next.Do(ctx, request)
136	})
137}
138