1//go:build go1.16 2// +build go1.16 3 4// Copyright (c) Microsoft Corporation. All rights reserved. 5// Licensed under the MIT License. 6 7package arm 8 9import ( 10 "context" 11 "net/http" 12 "strings" 13 "testing" 14 "time" 15 16 "github.com/Azure/azure-sdk-for-go/sdk/azcore" 17 armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime" 18 "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" 19 "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" 20 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" 21 azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" 22 "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" 23) 24 25type mockTokenCred struct{} 26 27func (mockTokenCred) NewAuthenticationPolicy(azruntime.AuthenticationOptions) policy.Policy { 28 return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { 29 return req.Next() 30 }) 31} 32 33func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { 34 return &azcore.AccessToken{ 35 Token: "abc123", 36 ExpiresOn: time.Now().Add(1 * time.Hour), 37 }, nil 38} 39 40const rpUnregisteredResp = `{ 41 "error":{ 42 "code":"MissingSubscriptionRegistration", 43 "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", 44 "details":[{ 45 "code":"MissingSubscriptionRegistration", 46 "target":"Microsoft.Storage", 47 "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions." 48 } 49 ] 50 } 51}` 52 53func TestNewDefaultConnection(t *testing.T) { 54 opt := ConnectionOptions{} 55 con := NewDefaultConnection(mockTokenCred{}, &opt) 56 if ep := con.Endpoint(); ep != AzurePublicCloud { 57 t.Fatalf("unexpected endpoint %s", ep) 58 } 59} 60 61func TestNewConnection(t *testing.T) { 62 const customEndpoint = "https://contoso.com/fake/endpoint" 63 con := NewConnection(customEndpoint, mockTokenCred{}, nil) 64 if ep := con.Endpoint(); ep != customEndpoint { 65 t.Fatalf("unexpected endpoint %s", ep) 66 } 67} 68 69func TestNewConnectionWithOptions(t *testing.T) { 70 srv, close := mock.NewServer() 71 defer close() 72 srv.AppendResponse() 73 opt := ConnectionOptions{} 74 opt.HTTPClient = srv 75 con := NewConnection(srv.URL(), mockTokenCred{}, &opt) 76 if ep := con.Endpoint(); ep != srv.URL() { 77 t.Fatalf("unexpected endpoint %s", ep) 78 } 79 req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) 80 if err != nil { 81 t.Fatalf("Unexpected error: %v", err) 82 } 83 resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) 84 if err != nil { 85 t.Fatalf("Unexpected error: %v", err) 86 } 87 if resp.StatusCode != http.StatusOK { 88 t.Fatalf("unexpected status code: %d", resp.StatusCode) 89 } 90 if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, "azsdk-go-armtest/v1.2.3") { 91 t.Fatalf("unexpected User-Agent %s", ua) 92 } 93} 94 95func TestNewConnectionWithCustomTelemetry(t *testing.T) { 96 const myTelemetry = "something" 97 srv, close := mock.NewServer() 98 defer close() 99 srv.AppendResponse() 100 opt := ConnectionOptions{} 101 opt.HTTPClient = srv 102 opt.Telemetry.ApplicationID = myTelemetry 103 con := NewConnection(srv.URL(), mockTokenCred{}, &opt) 104 if ep := con.Endpoint(); ep != srv.URL() { 105 t.Fatalf("unexpected endpoint %s", ep) 106 } 107 if opt.Telemetry.ApplicationID != myTelemetry { 108 t.Fatalf("telemetry was modified: %s", opt.Telemetry.ApplicationID) 109 } 110 req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) 111 if err != nil { 112 t.Fatalf("Unexpected error: %v", err) 113 } 114 resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) 115 if err != nil { 116 t.Fatalf("Unexpected error: %v", err) 117 } 118 if resp.StatusCode != http.StatusOK { 119 t.Fatalf("unexpected status code: %d", resp.StatusCode) 120 } 121 if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, myTelemetry+" "+"azsdk-go-armtest/v1.2.3") { 122 t.Fatalf("unexpected User-Agent %s", ua) 123 } 124} 125 126func TestDisableAutoRPRegistration(t *testing.T) { 127 srv, close := mock.NewServer() 128 defer close() 129 // initial response that RP is unregistered 130 srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) 131 con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{DisableRPRegistration: true}) 132 if ep := con.Endpoint(); ep != srv.URL() { 133 t.Fatalf("unexpected endpoint %s", ep) 134 } 135 req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) 136 if err != nil { 137 t.Fatalf("Unexpected error: %v", err) 138 } 139 // log only RP registration 140 log.SetClassifications(armruntime.LogRPRegistration) 141 defer func() { 142 // reset logging 143 log.SetClassifications() 144 }() 145 logEntries := 0 146 log.SetListener(func(cls log.Classification, msg string) { 147 logEntries++ 148 }) 149 resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) 150 if err != nil { 151 t.Fatal(err) 152 } 153 if resp.StatusCode != http.StatusConflict { 154 t.Fatalf("unexpected status code %d:", resp.StatusCode) 155 } 156 // shouldn't be any log entries 157 if logEntries != 0 { 158 t.Fatalf("expected 0 log entries, got %d", logEntries) 159 } 160} 161 162// policy that tracks the number of times it was invoked 163type countingPolicy struct { 164 count int 165} 166 167func (p *countingPolicy) Do(req *policy.Request) (*http.Response, error) { 168 p.count++ 169 return req.Next() 170} 171 172func TestConnectionWithCustomPolicies(t *testing.T) { 173 srv, close := mock.NewServer() 174 defer close() 175 // initial response is a failure to trigger retry 176 srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) 177 srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) 178 perCallPolicy := countingPolicy{} 179 perRetryPolicy := countingPolicy{} 180 con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{ 181 DisableRPRegistration: true, 182 PerCallPolicies: []policy.Policy{&perCallPolicy}, 183 PerRetryPolicies: []policy.Policy{&perRetryPolicy}, 184 }) 185 req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) 186 if err != nil { 187 t.Fatal(err) 188 } 189 resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) 190 if err != nil { 191 t.Fatal(err) 192 } 193 if resp.StatusCode != http.StatusOK { 194 t.Fatalf("unexpected status code %d", resp.StatusCode) 195 } 196 if perCallPolicy.count != 1 { 197 t.Fatalf("unexpected per call policy count %d", perCallPolicy.count) 198 } 199 if perRetryPolicy.count != 2 { 200 t.Fatalf("unexpected per retry policy count %d", perRetryPolicy.count) 201 } 202} 203