1package discover 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "net/url" 9 "testing" 10 11 "gitlab.com/gitlab-org/gitlab-shell/client" 12 13 "github.com/stretchr/testify/require" 14 "gitlab.com/gitlab-org/gitlab-shell/client/testserver" 15 "gitlab.com/gitlab-org/gitlab-shell/internal/config" 16) 17 18var ( 19 requests []testserver.TestRequestHandler 20) 21 22func init() { 23 requests = []testserver.TestRequestHandler{ 24 { 25 Path: "/api/v4/internal/discover", 26 Handler: func(w http.ResponseWriter, r *http.Request) { 27 if r.URL.Query().Get("key_id") == "1" { 28 body := &Response{ 29 UserId: 2, 30 Username: "alex-doe", 31 Name: "Alex Doe", 32 } 33 json.NewEncoder(w).Encode(body) 34 } else if r.URL.Query().Get("username") == "jane-doe" { 35 body := &Response{ 36 UserId: 1, 37 Username: "jane-doe", 38 Name: "Jane Doe", 39 } 40 json.NewEncoder(w).Encode(body) 41 } else if r.URL.Query().Get("username") == "broken_message" { 42 w.WriteHeader(http.StatusForbidden) 43 body := &client.ErrorResponse{ 44 Message: "Not allowed!", 45 } 46 json.NewEncoder(w).Encode(body) 47 } else if r.URL.Query().Get("username") == "broken_json" { 48 w.Write([]byte("{ \"message\": \"broken json!\"")) 49 } else if r.URL.Query().Get("username") == "broken_empty" { 50 w.WriteHeader(http.StatusForbidden) 51 } else { 52 fmt.Fprint(w, "null") 53 } 54 }, 55 }, 56 } 57} 58 59func TestGetByKeyId(t *testing.T) { 60 client := setup(t) 61 62 params := url.Values{} 63 params.Add("key_id", "1") 64 result, err := client.getResponse(context.Background(), params) 65 require.NoError(t, err) 66 require.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result) 67} 68 69func TestGetByUsername(t *testing.T) { 70 client := setup(t) 71 72 params := url.Values{} 73 params.Add("username", "jane-doe") 74 result, err := client.getResponse(context.Background(), params) 75 require.NoError(t, err) 76 require.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result) 77} 78 79func TestMissingUser(t *testing.T) { 80 client := setup(t) 81 82 params := url.Values{} 83 params.Add("username", "missing") 84 result, err := client.getResponse(context.Background(), params) 85 require.NoError(t, err) 86 require.True(t, result.IsAnonymous()) 87} 88 89func TestErrorResponses(t *testing.T) { 90 client := setup(t) 91 92 testCases := []struct { 93 desc string 94 fakeUsername string 95 expectedError string 96 }{ 97 { 98 desc: "A response with an error message", 99 fakeUsername: "broken_message", 100 expectedError: "Not allowed!", 101 }, 102 { 103 desc: "A response with bad JSON", 104 fakeUsername: "broken_json", 105 expectedError: "Parsing failed", 106 }, 107 { 108 desc: "An error response without message", 109 fakeUsername: "broken_empty", 110 expectedError: "Internal API error (403)", 111 }, 112 } 113 114 for _, tc := range testCases { 115 t.Run(tc.desc, func(t *testing.T) { 116 params := url.Values{} 117 params.Add("username", tc.fakeUsername) 118 resp, err := client.getResponse(context.Background(), params) 119 120 require.EqualError(t, err, tc.expectedError) 121 require.Nil(t, resp) 122 }) 123 } 124} 125 126func setup(t *testing.T) *Client { 127 url := testserver.StartSocketHttpServer(t, requests) 128 129 client, err := NewClient(&config.Config{GitlabUrl: url}) 130 require.NoError(t, err) 131 132 return client 133} 134