1package ec2metadata_test 2 3import ( 4 "net/http" 5 "net/http/httptest" 6 "os" 7 "strings" 8 "sync" 9 "testing" 10 "time" 11 12 "github.com/aws/aws-sdk-go/aws" 13 "github.com/aws/aws-sdk-go/aws/awserr" 14 "github.com/aws/aws-sdk-go/aws/ec2metadata" 15 "github.com/aws/aws-sdk-go/aws/request" 16 "github.com/aws/aws-sdk-go/awstesting/unit" 17 "github.com/aws/aws-sdk-go/internal/sdktesting" 18) 19 20func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) { 21 svc := ec2metadata.New(unit.Session) 22 23 if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a { 24 t.Errorf("expect %v, not to equal %v", e, a) 25 } 26 27 if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a { 28 t.Errorf("expect %v to be %v", e, a) 29 } 30} 31 32func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) { 33 http.DefaultClient.Transport = &http.Transport{} 34 defer func() { 35 http.DefaultClient.Transport = nil 36 }() 37 38 svc := ec2metadata.New(unit.Session) 39 40 if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a { 41 t.Errorf("expect %v, got %v", e, a) 42 } 43 44 tr := svc.Config.HTTPClient.Transport.(*http.Transport) 45 if tr == nil { 46 t.Fatalf("expect transport not to be nil") 47 } 48 if tr.Dial != nil { 49 t.Errorf("expect dial to be nil, was not") 50 } 51} 52 53func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) { 54 svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true)) 55 56 if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a { 57 t.Errorf("expect %v, got %v", e, a) 58 } 59} 60 61func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) { 62 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 w.Write([]byte("us-east-1a")) 64 })) 65 defer server.Close() 66 67 cfg := aws.NewConfig().WithEndpoint(server.URL) 68 runEC2MetadataClients(t, cfg, 50) 69} 70 71func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) { 72 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 73 w.Write([]byte("us-east-1a")) 74 })) 75 defer server.Close() 76 77 cfg := aws.NewConfig().WithEndpoint(server.URL).WithHTTPClient(&http.Client{ 78 Transport: &http.Transport{ 79 DisableKeepAlives: true, 80 }, 81 }) 82 83 runEC2MetadataClients(t, cfg, 50) 84} 85 86func TestClientDisableIMDS(t *testing.T) { 87 restoreEnvFn := sdktesting.StashEnv() 88 defer restoreEnvFn() 89 90 os.Setenv("AWS_EC2_METADATA_DISABLED", "true") 91 92 svc := ec2metadata.New(unit.Session, &aws.Config{ 93 LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), 94 }) 95 resp, err := svc.GetUserData() 96 if err == nil { 97 t.Fatalf("expect error, got none") 98 } 99 if len(resp) != 0 { 100 t.Errorf("expect no response, got %v", resp) 101 } 102 103 aerr := err.(awserr.Error) 104 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 105 t.Errorf("expect %v error code, got %v", e, a) 106 } 107 if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) { 108 t.Errorf("expect %v in error message, got %v", e, a) 109 } 110} 111 112func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) { 113 var wg sync.WaitGroup 114 wg.Add(atOnce) 115 svc := ec2metadata.New(unit.Session, cfg) 116 for i := 0; i < atOnce; i++ { 117 go func() { 118 defer wg.Done() 119 _, err := svc.GetUserData() 120 if err != nil { 121 t.Errorf("expect no error, got %v", err) 122 } 123 }() 124 } 125 wg.Wait() 126} 127