1package jwtauth
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"encoding/pem"
8	"net/http"
9	"net/http/httptest"
10	"strings"
11	"testing"
12
13	"github.com/hashicorp/go-hclog"
14	"github.com/hashicorp/vault/sdk/logical"
15	"github.com/stretchr/testify/assert"
16	"github.com/stretchr/testify/require"
17	"golang.org/x/oauth2"
18)
19
// azureServer is a fake Azure provider used by the tests in this file. It
// answers OIDC discovery and group-membership requests over TLS.
type azureServer struct {
	// t is the owning test; the handler reports unexpected requests to it.
	t *testing.T
	// server is the TLS test server whose handler is this azureServer.
	server *httptest.Server
}
24
25func newAzureServer(t *testing.T) *azureServer {
26	a := new(azureServer)
27	a.t = t
28	a.server = httptest.NewTLSServer(a)
29
30	return a
31}
32
33func (a *azureServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
34	w.Header().Set("Content-Type", "application/json")
35
36	switch r.URL.Path {
37	case "/.well-known/openid-configuration":
38		w.Write([]byte(strings.Replace(`
39			{
40				"issuer": "%s",
41				"authorization_endpoint": "%s/auth",
42				"token_endpoint": "%s/oauth2/v2.0/token",
43				"jwks_uri": "%s/certs",
44				"userinfo_endpoint": "%s/userinfo"
45			}`, "%s", a.server.URL, -1)))
46	case "/getMemberObjects":
47		groups := azureGroups{
48			Value: []interface{}{"group1", "group2"},
49		}
50		gBytes, _ := json.Marshal(groups)
51		w.Write(gBytes)
52	default:
53		a.t.Fatalf("unexpected path: %q", r.URL.Path)
54	}
55}
56
57// getTLSCert returns the certificate for this provider in PEM format
58func (a *azureServer) getTLSCert() (string, error) {
59	cert := a.server.Certificate()
60	block := &pem.Block{
61		Type:  "CERTIFICATE",
62		Bytes: cert.Raw,
63	}
64
65	pemBuf := new(bytes.Buffer)
66	if err := pem.Encode(pemBuf, block); err != nil {
67		return "", err
68	}
69
70	return pemBuf.String(), nil
71}
72
73func TestLogin_fetchGroups(t *testing.T) {
74
75	aServer := newAzureServer(t)
76	aCert, err := aServer.getTLSCert()
77	require.NoError(t, err)
78
79	b, storage := getBackend(t)
80	ctx := context.Background()
81
82	data := map[string]interface{}{
83		"oidc_discovery_url":    aServer.server.URL,
84		"oidc_discovery_ca_pem": aCert,
85		"oidc_client_id":        "abc",
86		"oidc_client_secret":    "def",
87		"default_role":          "test",
88		"bound_issuer":          "http://vault.example.com/",
89		"provider_config": map[string]interface{}{
90			"provider": "azure",
91		},
92	}
93
94	// basic configuration
95	req := &logical.Request{
96		Operation: logical.UpdateOperation,
97		Path:      configPath,
98		Storage:   storage,
99		Data:      data,
100	}
101
102	resp, err := b.HandleRequest(context.Background(), req)
103	if err != nil || (resp != nil && resp.IsError()) {
104		t.Fatalf("err:%v resp:%#v\n", err, resp)
105	}
106
107	// set up test role
108	data = map[string]interface{}{
109		"user_claim":            "email",
110		"groups_claim":          "groups",
111		"allowed_redirect_uris": []string{"https://example.com"},
112	}
113
114	req = &logical.Request{
115		Operation: logical.CreateOperation,
116		Path:      "role/test",
117		Storage:   storage,
118		Data:      data,
119	}
120
121	resp, err = b.HandleRequest(context.Background(), req)
122	if err != nil || (resp != nil && resp.IsError()) {
123		t.Fatalf("err:%v resp:%#v\n", err, resp)
124	}
125
126	role := &jwtRole{
127		GroupsClaim: "groups",
128	}
129	allClaims := map[string]interface{}{
130		"_claim_names": H{
131			"groups": "src1",
132		},
133		"_claim_sources": H{
134			"src1": H{
135				"endpoint": aServer.server.URL + "/getMemberObjects",
136			},
137		},
138	}
139
140	// Ensure b.cachedConfig is populated
141	config, err := b.(*jwtAuthBackend).config(ctx, storage)
142	if err != nil {
143		t.Fatal(err)
144	}
145
146	// Initialize the azure provider
147	provider, err := NewProviderConfig(ctx, config, ProviderMap())
148	if err != nil {
149		t.Fatal(err)
150	}
151
152	// Ensure groups are as expected
153	tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test.access.token"})
154	groupsResp, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource)
155	assert.NoError(t, err)
156	assert.Equal(t, []interface{}{"group1", "group2"}, groupsResp)
157}
158
159func Test_getClaimSources(t *testing.T) {
160	t.Run("normal case", func(t *testing.T) {
161		a := &AzureProvider{}
162		role := &jwtRole{
163			GroupsClaim: "groups",
164		}
165		allClaims := H{
166			claimNamesField: H{
167				role.GroupsClaim: "src1",
168			},
169			claimSourcesField: H{
170				"src1": H{
171					"endpoint": "/test/endpoint",
172				},
173			},
174		}
175		source, err := a.getClaimSource(hclog.Default(), allClaims, role)
176		assert.NoError(t, err)
177		assert.Equal(t, "/test/endpoint", source)
178	})
179
180	t.Run("no _claim_names", func(t *testing.T) {
181		a := AzureProvider{}
182		role := &jwtRole{
183			GroupsClaim: "groups",
184		}
185		allClaims := H{
186			"not_claim_names": "blank",
187		}
188		source, err := a.getClaimSource(hclog.Default(), allClaims, role)
189		assert.Error(t, err)
190		assert.Empty(t, source)
191	})
192
193	t.Run("no _claim_sources", func(t *testing.T) {
194		a := AzureProvider{}
195		role := &jwtRole{
196			GroupsClaim: "groups",
197		}
198		allClaims := H{
199			claimNamesField: H{
200				role.GroupsClaim: "src1",
201			},
202		}
203		source, err := a.getClaimSource(hclog.Default(), allClaims, role)
204		assert.Error(t, err)
205		assert.Empty(t, source)
206	})
207}
208