1package jwtauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "encoding/pem" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 13 "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/vault/sdk/logical" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/require" 17 "golang.org/x/oauth2" 18) 19 20type azureServer struct { 21 t *testing.T 22 server *httptest.Server 23} 24 25func newAzureServer(t *testing.T) *azureServer { 26 a := new(azureServer) 27 a.t = t 28 a.server = httptest.NewTLSServer(a) 29 30 return a 31} 32 33func (a *azureServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 34 w.Header().Set("Content-Type", "application/json") 35 36 switch r.URL.Path { 37 case "/.well-known/openid-configuration": 38 w.Write([]byte(strings.Replace(` 39 { 40 "issuer": "%s", 41 "authorization_endpoint": "%s/auth", 42 "token_endpoint": "%s/oauth2/v2.0/token", 43 "jwks_uri": "%s/certs", 44 "userinfo_endpoint": "%s/userinfo" 45 }`, "%s", a.server.URL, -1))) 46 case "/getMemberObjects": 47 groups := azureGroups{ 48 Value: []interface{}{"group1", "group2"}, 49 } 50 gBytes, _ := json.Marshal(groups) 51 w.Write(gBytes) 52 default: 53 a.t.Fatalf("unexpected path: %q", r.URL.Path) 54 } 55} 56 57// getTLSCert returns the certificate for this provider in PEM format 58func (a *azureServer) getTLSCert() (string, error) { 59 cert := a.server.Certificate() 60 block := &pem.Block{ 61 Type: "CERTIFICATE", 62 Bytes: cert.Raw, 63 } 64 65 pemBuf := new(bytes.Buffer) 66 if err := pem.Encode(pemBuf, block); err != nil { 67 return "", err 68 } 69 70 return pemBuf.String(), nil 71} 72 73func TestLogin_fetchGroups(t *testing.T) { 74 75 aServer := newAzureServer(t) 76 aCert, err := aServer.getTLSCert() 77 require.NoError(t, err) 78 79 b, storage := getBackend(t) 80 ctx := context.Background() 81 82 data := map[string]interface{}{ 83 "oidc_discovery_url": aServer.server.URL, 84 "oidc_discovery_ca_pem": aCert, 85 "oidc_client_id": "abc", 86 "oidc_client_secret": "def", 87 "default_role": "test", 88 "bound_issuer": "http://vault.example.com/", 89 "provider_config": map[string]interface{}{ 90 "provider": "azure", 91 }, 92 } 93 94 // basic configuration 95 req := &logical.Request{ 96 Operation: logical.UpdateOperation, 97 Path: configPath, 98 Storage: storage, 99 Data: data, 100 } 101 102 resp, err := b.HandleRequest(context.Background(), req) 103 if err != nil || (resp != nil && resp.IsError()) { 104 t.Fatalf("err:%v resp:%#v\n", err, resp) 105 } 106 107 // set up test role 108 data = map[string]interface{}{ 109 "user_claim": "email", 110 "groups_claim": "groups", 111 "allowed_redirect_uris": []string{"https://example.com"}, 112 } 113 114 req = &logical.Request{ 115 Operation: logical.CreateOperation, 116 Path: "role/test", 117 Storage: storage, 118 Data: data, 119 } 120 121 resp, err = b.HandleRequest(context.Background(), req) 122 if err != nil || (resp != nil && resp.IsError()) { 123 t.Fatalf("err:%v resp:%#v\n", err, resp) 124 } 125 126 role := &jwtRole{ 127 GroupsClaim: "groups", 128 } 129 allClaims := map[string]interface{}{ 130 "_claim_names": H{ 131 "groups": "src1", 132 }, 133 "_claim_sources": H{ 134 "src1": H{ 135 "endpoint": aServer.server.URL + "/getMemberObjects", 136 }, 137 }, 138 } 139 140 // Ensure b.cachedConfig is populated 141 config, err := b.(*jwtAuthBackend).config(ctx, storage) 142 if err != nil { 143 t.Fatal(err) 144 } 145 146 // Initialize the azure provider 147 provider, err := NewProviderConfig(ctx, config, ProviderMap()) 148 if err != nil { 149 t.Fatal(err) 150 } 151 152 // Ensure groups are as expected 153 tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test.access.token"}) 154 groupsResp, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource) 155 assert.NoError(t, err) 156 assert.Equal(t, []interface{}{"group1", "group2"}, groupsResp) 157} 158 159func Test_getClaimSources(t *testing.T) { 160 t.Run("normal case", func(t *testing.T) { 161 a := &AzureProvider{} 162 role := &jwtRole{ 163 GroupsClaim: "groups", 164 } 165 allClaims := H{ 166 claimNamesField: H{ 167 role.GroupsClaim: "src1", 168 }, 169 claimSourcesField: H{ 170 "src1": H{ 171 "endpoint": "/test/endpoint", 172 }, 173 }, 174 } 175 source, err := a.getClaimSource(hclog.Default(), allClaims, role) 176 assert.NoError(t, err) 177 assert.Equal(t, "/test/endpoint", source) 178 }) 179 180 t.Run("no _claim_names", func(t *testing.T) { 181 a := AzureProvider{} 182 role := &jwtRole{ 183 GroupsClaim: "groups", 184 } 185 allClaims := H{ 186 "not_claim_names": "blank", 187 } 188 source, err := a.getClaimSource(hclog.Default(), allClaims, role) 189 assert.Error(t, err) 190 assert.Empty(t, source) 191 }) 192 193 t.Run("no _claim_sources", func(t *testing.T) { 194 a := AzureProvider{} 195 role := &jwtRole{ 196 GroupsClaim: "groups", 197 } 198 allClaims := H{ 199 claimNamesField: H{ 200 role.GroupsClaim: "src1", 201 }, 202 } 203 source, err := a.getClaimSource(hclog.Default(), allClaims, role) 204 assert.Error(t, err) 205 assert.Empty(t, source) 206 }) 207} 208