1//go:build go1.16
2// +build go1.16
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package arm
8
9import (
10	"context"
11	"net/http"
12	"strings"
13	"testing"
14	"time"
15
16	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
17	armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
18	"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline"
19	"github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
20	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
21	azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
22	"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
23)
24
25type mockTokenCred struct{}
26
27func (mockTokenCred) NewAuthenticationPolicy(azruntime.AuthenticationOptions) policy.Policy {
28	return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) {
29		return req.Next()
30	})
31}
32
33func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) {
34	return &azcore.AccessToken{
35		Token:     "abc123",
36		ExpiresOn: time.Now().Add(1 * time.Hour),
37	}, nil
38}
39
// rpUnregisteredResp is an ARM error payload with code
// MissingSubscriptionRegistration for the Microsoft.Storage namespace.
// Tests serve it with a 409 Conflict status to simulate a subscription whose
// resource provider has not been registered.
const rpUnregisteredResp = `{
	"error":{
		"code":"MissingSubscriptionRegistration",
		"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.",
		"details":[{
				"code":"MissingSubscriptionRegistration",
				"target":"Microsoft.Storage",
				"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions."
			}
		]
	}
}`
52
53func TestNewDefaultConnection(t *testing.T) {
54	opt := ConnectionOptions{}
55	con := NewDefaultConnection(mockTokenCred{}, &opt)
56	if ep := con.Endpoint(); ep != AzurePublicCloud {
57		t.Fatalf("unexpected endpoint %s", ep)
58	}
59}
60
61func TestNewConnection(t *testing.T) {
62	const customEndpoint = "https://contoso.com/fake/endpoint"
63	con := NewConnection(customEndpoint, mockTokenCred{}, nil)
64	if ep := con.Endpoint(); ep != customEndpoint {
65		t.Fatalf("unexpected endpoint %s", ep)
66	}
67}
68
69func TestNewConnectionWithOptions(t *testing.T) {
70	srv, close := mock.NewServer()
71	defer close()
72	srv.AppendResponse()
73	opt := ConnectionOptions{}
74	opt.HTTPClient = srv
75	con := NewConnection(srv.URL(), mockTokenCred{}, &opt)
76	if ep := con.Endpoint(); ep != srv.URL() {
77		t.Fatalf("unexpected endpoint %s", ep)
78	}
79	req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
80	if err != nil {
81		t.Fatalf("Unexpected error: %v", err)
82	}
83	resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req)
84	if err != nil {
85		t.Fatalf("Unexpected error: %v", err)
86	}
87	if resp.StatusCode != http.StatusOK {
88		t.Fatalf("unexpected status code: %d", resp.StatusCode)
89	}
90	if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, "azsdk-go-armtest/v1.2.3") {
91		t.Fatalf("unexpected User-Agent %s", ua)
92	}
93}
94
95func TestNewConnectionWithCustomTelemetry(t *testing.T) {
96	const myTelemetry = "something"
97	srv, close := mock.NewServer()
98	defer close()
99	srv.AppendResponse()
100	opt := ConnectionOptions{}
101	opt.HTTPClient = srv
102	opt.Telemetry.ApplicationID = myTelemetry
103	con := NewConnection(srv.URL(), mockTokenCred{}, &opt)
104	if ep := con.Endpoint(); ep != srv.URL() {
105		t.Fatalf("unexpected endpoint %s", ep)
106	}
107	if opt.Telemetry.ApplicationID != myTelemetry {
108		t.Fatalf("telemetry was modified: %s", opt.Telemetry.ApplicationID)
109	}
110	req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
111	if err != nil {
112		t.Fatalf("Unexpected error: %v", err)
113	}
114	resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req)
115	if err != nil {
116		t.Fatalf("Unexpected error: %v", err)
117	}
118	if resp.StatusCode != http.StatusOK {
119		t.Fatalf("unexpected status code: %d", resp.StatusCode)
120	}
121	if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, myTelemetry+" "+"azsdk-go-armtest/v1.2.3") {
122		t.Fatalf("unexpected User-Agent %s", ua)
123	}
124}
125
126func TestDisableAutoRPRegistration(t *testing.T) {
127	srv, close := mock.NewServer()
128	defer close()
129	// initial response that RP is unregistered
130	srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp)))
131	con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{DisableRPRegistration: true})
132	if ep := con.Endpoint(); ep != srv.URL() {
133		t.Fatalf("unexpected endpoint %s", ep)
134	}
135	req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
136	if err != nil {
137		t.Fatalf("Unexpected error: %v", err)
138	}
139	// log only RP registration
140	log.SetClassifications(armruntime.LogRPRegistration)
141	defer func() {
142		// reset logging
143		log.SetClassifications()
144	}()
145	logEntries := 0
146	log.SetListener(func(cls log.Classification, msg string) {
147		logEntries++
148	})
149	resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req)
150	if err != nil {
151		t.Fatal(err)
152	}
153	if resp.StatusCode != http.StatusConflict {
154		t.Fatalf("unexpected status code %d:", resp.StatusCode)
155	}
156	// shouldn't be any log entries
157	if logEntries != 0 {
158		t.Fatalf("expected 0 log entries, got %d", logEntries)
159	}
160}
161
162// policy that tracks the number of times it was invoked
163type countingPolicy struct {
164	count int
165}
166
167func (p *countingPolicy) Do(req *policy.Request) (*http.Response, error) {
168	p.count++
169	return req.Next()
170}
171
172func TestConnectionWithCustomPolicies(t *testing.T) {
173	srv, close := mock.NewServer()
174	defer close()
175	// initial response is a failure to trigger retry
176	srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError))
177	srv.AppendResponse(mock.WithStatusCode(http.StatusOK))
178	perCallPolicy := countingPolicy{}
179	perRetryPolicy := countingPolicy{}
180	con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{
181		DisableRPRegistration: true,
182		PerCallPolicies:       []policy.Policy{&perCallPolicy},
183		PerRetryPolicies:      []policy.Policy{&perRetryPolicy},
184	})
185	req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
186	if err != nil {
187		t.Fatal(err)
188	}
189	resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req)
190	if err != nil {
191		t.Fatal(err)
192	}
193	if resp.StatusCode != http.StatusOK {
194		t.Fatalf("unexpected status code %d", resp.StatusCode)
195	}
196	if perCallPolicy.count != 1 {
197		t.Fatalf("unexpected per call policy count %d", perCallPolicy.count)
198	}
199	if perRetryPolicy.count != 2 {
200		t.Fatalf("unexpected per retry policy count %d", perRetryPolicy.count)
201	}
202}
203