1package jwtauth
2
3import (
4	"context"
5	"encoding/json"
6	"errors"
7	"fmt"
8	"io/ioutil"
9	"net/http"
10	"net/url"
11	"strings"
12
13	log "github.com/hashicorp/go-hclog"
14	"golang.org/x/oauth2"
15)
16
// azureADGraphHost is the host of the Azure Active Directory (AAD) Graph API.
// That API is deprecated; endpoints on this host are rewritten to the
// Microsoft Graph API.
const azureADGraphHost = "graph.windows.net"

// Host and version path prefix of the Microsoft Graph API.
const (
	microsoftGraphHost       = "graph.microsoft.com"
	microsoftGraphAPIVersion = "/v1.0"
)

// Names of the top-level token fields that carry distributed claim
// information.
const (
	claimNamesField   = "_claim_names"
	claimSourcesField = "_claim_sources"
)
29
// AzureProvider holds the Azure-specific provider state.
type AzureProvider struct {
	// ctx is the context used for Azure calls. FetchGroups assigns it before
	// fetching distributed groups, and getAzureGroups looks up an
	// *http.Client in it under the oauth2.HTTPClient key.
	ctx context.Context
}
35
36// Initialize anything in the AzureProvider struct - satisfying the CustomProvider interface
37func (a *AzureProvider) Initialize(_ context.Context, _ *jwtConfig) error {
38	return nil
39}
40
41// SensitiveKeys - satisfying the CustomProvider interface
42func (a *AzureProvider) SensitiveKeys() []string {
43	return []string{}
44}
45
46// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface
47func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) {
48	groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)
49
50	if groupsClaimRaw == nil {
51		// If the "groups" claim is missing, it might be because the user is a
52		// member of more than 200 groups, which means the token contains
53		// distributed claim information. Attempt to look that up here.
54		azureClaimSourcesURL, err := a.getClaimSource(b.Logger(), allClaims, role)
55		if err != nil {
56			return nil, fmt.Errorf("unable to get claim sources: %s", err)
57		}
58
59		a.ctx, err = b.createCAContext(b.providerCtx, b.cachedConfig.OIDCDiscoveryCAPEM)
60		if err != nil {
61			return nil, fmt.Errorf("unable to create CA Context: %s", err)
62		}
63
64		azureGroups, err := a.getAzureGroups(azureClaimSourcesURL, tokenSource)
65		if err != nil {
66			return nil, fmt.Errorf("%q claim not found in token: %v", role.GroupsClaim, err)
67		}
68		groupsClaimRaw = azureGroups
69	}
70	b.Logger().Debug(fmt.Sprintf("groups claim raw is %v", groupsClaimRaw))
71	return groupsClaimRaw, nil
72}
73
74// In Azure, if you are indirectly member of more than 200 groups, they will
75// send _claim_names and _claim_sources instead of the groups, per OIDC Core
76// 1.0, section 5.6.2:
77// https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims
78// In the future this could be used with other providers as well. Example:
79//
80// {
81// 	 "_claim_names": {
82// 	   "groups": "src1"
83// 	 },
84// 	 "_claim_sources": {
85// 	   "src1": {
86// 	     "endpoint": "https://graph.windows.net...."
87// 	   }
88//   }
89// }
90//
91// For this to work, "profile" should be set in "oidc_scopes" in the vault oidc role.
92//
93func (a *AzureProvider) getClaimSource(logger log.Logger, allClaims map[string]interface{}, role *jwtRole) (string, error) {
94	// Get the source key for the groups claim
95	name := fmt.Sprintf("/%s/%s", claimNamesField, role.GroupsClaim)
96	groupsClaimSource := getClaim(logger, allClaims, name)
97	if groupsClaimSource == nil {
98		return "", fmt.Errorf("unable to locate groups claim %q in %s", role.GroupsClaim, claimNamesField)
99	}
100	// Get the endpoint source for the groups claim
101	endpoint := fmt.Sprintf("/%s/%s/endpoint", claimSourcesField, groupsClaimSource.(string))
102	val := getClaim(logger, allClaims, endpoint)
103	if val == nil {
104		return "", fmt.Errorf("unable to locate %s in claims", endpoint)
105	}
106
107	urlParsed, err := url.Parse(fmt.Sprintf("%v", val))
108	if err != nil {
109		return "", fmt.Errorf("unable to parse claim source URL: %w", err)
110	}
111
112	// If the endpoint source for the groups claim has a host of the deprecated AAD graph API,
113	// then replace it to instead use the Microsoft graph API. The AAD graph API is deprecated
114	// and will eventually stop servicing requests. See details at:
115	// - https://developer.microsoft.com/en-us/office/blogs/microsoft-graph-or-azure-ad-graph/
116	// - https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
117	if urlParsed.Host == azureADGraphHost {
118		urlParsed.Host = microsoftGraphHost
119		urlParsed.Path = microsoftGraphAPIVersion + urlParsed.Path
120	}
121
122	logger.Debug(fmt.Sprintf("found Azure Graph API endpoint for group membership: %v", urlParsed.String()))
123	return urlParsed.String(), nil
124}
125
126// Fetch user groups from the Microsoft Graph API
127func (a *AzureProvider) getAzureGroups(groupsURL string, tokenSource oauth2.TokenSource) (interface{}, error) {
128	urlParsed, err := url.Parse(groupsURL)
129	if err != nil {
130		return nil, fmt.Errorf("failed to parse distributed groups source url %s: %s", groupsURL, err)
131	}
132
133	// Use the Access Token that was pre-negotiated between the Claims Provider and RP
134	// via https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims.
135	if tokenSource == nil {
136		return nil, errors.New("token unavailable to call Microsoft Graph API")
137	}
138	token, err := tokenSource.Token()
139	if err != nil {
140		return nil, fmt.Errorf("unable to get token: %s", err)
141	}
142	payload := strings.NewReader("{\"securityEnabledOnly\": false}")
143	req, err := http.NewRequest("POST", urlParsed.String(), payload)
144	if err != nil {
145		return nil, fmt.Errorf("error constructing groups endpoint request: %s", err)
146	}
147	req.Header.Add("content-type", "application/json")
148	token.SetAuthHeader(req)
149
150	client := http.DefaultClient
151	if c, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
152		client = c
153	}
154	res, err := client.Do(req)
155	if err != nil {
156		return nil, fmt.Errorf("unable to call Microsoft Graph API: %s", err)
157	}
158	defer res.Body.Close()
159	body, err := ioutil.ReadAll(res.Body)
160	if err != nil {
161		return nil, fmt.Errorf("failed to read Microsoft Graph API response: %s", err)
162	}
163	if res.StatusCode != http.StatusOK {
164		return nil, fmt.Errorf("failed to get groups: %s", string(body))
165	}
166
167	var target azureGroups
168	if err := json.Unmarshal(body, &target); err != nil {
169		return nil, fmt.Errorf("unabled to decode response: %s", err)
170	}
171	return target.Value, nil
172}
173
// azureGroups is the decoding target for the groups endpoint response; only
// the top-level "value" array is kept.
type azureGroups struct {
	// Value holds the elements of the response's "value" array, undecoded
	// beyond generic JSON values.
	Value []interface{} `json:"value"`
}
177