1package jwtauth 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io/ioutil" 9 "net/http" 10 "net/url" 11 "strings" 12 13 log "github.com/hashicorp/go-hclog" 14 "golang.org/x/oauth2" 15) 16 17const ( 18 // Deprecated: The host of the Azure Active Directory (AAD) graph API 19 azureADGraphHost = "graph.windows.net" 20 21 // The host and version of the Microsoft Graph API 22 microsoftGraphHost = "graph.microsoft.com" 23 microsoftGraphAPIVersion = "/v1.0" 24 25 // Distributed claim fields 26 claimNamesField = "_claim_names" 27 claimSourcesField = "_claim_sources" 28) 29 30// AzureProvider is used for Azure-specific configuration 31type AzureProvider struct { 32 // Context for azure calls 33 ctx context.Context 34} 35 36// Initialize anything in the AzureProvider struct - satisfying the CustomProvider interface 37func (a *AzureProvider) Initialize(_ context.Context, _ *jwtConfig) error { 38 return nil 39} 40 41// SensitiveKeys - satisfying the CustomProvider interface 42func (a *AzureProvider) SensitiveKeys() []string { 43 return []string{} 44} 45 46// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface 47func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) { 48 groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim) 49 50 if groupsClaimRaw == nil { 51 // If the "groups" claim is missing, it might be because the user is a 52 // member of more than 200 groups, which means the token contains 53 // distributed claim information. Attempt to look that up here. 54 azureClaimSourcesURL, err := a.getClaimSource(b.Logger(), allClaims, role) 55 if err != nil { 56 return nil, fmt.Errorf("unable to get claim sources: %s", err) 57 } 58 59 a.ctx, err = b.createCAContext(b.providerCtx, b.cachedConfig.OIDCDiscoveryCAPEM) 60 if err != nil { 61 return nil, fmt.Errorf("unable to create CA Context: %s", err) 62 } 63 64 azureGroups, err := a.getAzureGroups(azureClaimSourcesURL, tokenSource) 65 if err != nil { 66 return nil, fmt.Errorf("%q claim not found in token: %v", role.GroupsClaim, err) 67 } 68 groupsClaimRaw = azureGroups 69 } 70 b.Logger().Debug(fmt.Sprintf("groups claim raw is %v", groupsClaimRaw)) 71 return groupsClaimRaw, nil 72} 73 74// In Azure, if you are indirectly member of more than 200 groups, they will 75// send _claim_names and _claim_sources instead of the groups, per OIDC Core 76// 1.0, section 5.6.2: 77// https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims 78// In the future this could be used with other providers as well. Example: 79// 80// { 81// "_claim_names": { 82// "groups": "src1" 83// }, 84// "_claim_sources": { 85// "src1": { 86// "endpoint": "https://graph.windows.net...." 87// } 88// } 89// } 90// 91// For this to work, "profile" should be set in "oidc_scopes" in the vault oidc role. 92// 93func (a *AzureProvider) getClaimSource(logger log.Logger, allClaims map[string]interface{}, role *jwtRole) (string, error) { 94 // Get the source key for the groups claim 95 name := fmt.Sprintf("/%s/%s", claimNamesField, role.GroupsClaim) 96 groupsClaimSource := getClaim(logger, allClaims, name) 97 if groupsClaimSource == nil { 98 return "", fmt.Errorf("unable to locate groups claim %q in %s", role.GroupsClaim, claimNamesField) 99 } 100 // Get the endpoint source for the groups claim 101 endpoint := fmt.Sprintf("/%s/%s/endpoint", claimSourcesField, groupsClaimSource.(string)) 102 val := getClaim(logger, allClaims, endpoint) 103 if val == nil { 104 return "", fmt.Errorf("unable to locate %s in claims", endpoint) 105 } 106 107 urlParsed, err := url.Parse(fmt.Sprintf("%v", val)) 108 if err != nil { 109 return "", fmt.Errorf("unable to parse claim source URL: %w", err) 110 } 111 112 // If the endpoint source for the groups claim has a host of the deprecated AAD graph API, 113 // then replace it to instead use the Microsoft graph API. The AAD graph API is deprecated 114 // and will eventually stop servicing requests. See details at: 115 // - https://developer.microsoft.com/en-us/office/blogs/microsoft-graph-or-azure-ad-graph/ 116 // - https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0 117 if urlParsed.Host == azureADGraphHost { 118 urlParsed.Host = microsoftGraphHost 119 urlParsed.Path = microsoftGraphAPIVersion + urlParsed.Path 120 } 121 122 logger.Debug(fmt.Sprintf("found Azure Graph API endpoint for group membership: %v", urlParsed.String())) 123 return urlParsed.String(), nil 124} 125 126// Fetch user groups from the Microsoft Graph API 127func (a *AzureProvider) getAzureGroups(groupsURL string, tokenSource oauth2.TokenSource) (interface{}, error) { 128 urlParsed, err := url.Parse(groupsURL) 129 if err != nil { 130 return nil, fmt.Errorf("failed to parse distributed groups source url %s: %s", groupsURL, err) 131 } 132 133 // Use the Access Token that was pre-negotiated between the Claims Provider and RP 134 // via https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims. 135 if tokenSource == nil { 136 return nil, errors.New("token unavailable to call Microsoft Graph API") 137 } 138 token, err := tokenSource.Token() 139 if err != nil { 140 return nil, fmt.Errorf("unable to get token: %s", err) 141 } 142 payload := strings.NewReader("{\"securityEnabledOnly\": false}") 143 req, err := http.NewRequest("POST", urlParsed.String(), payload) 144 if err != nil { 145 return nil, fmt.Errorf("error constructing groups endpoint request: %s", err) 146 } 147 req.Header.Add("content-type", "application/json") 148 token.SetAuthHeader(req) 149 150 client := http.DefaultClient 151 if c, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client); ok { 152 client = c 153 } 154 res, err := client.Do(req) 155 if err != nil { 156 return nil, fmt.Errorf("unable to call Microsoft Graph API: %s", err) 157 } 158 defer res.Body.Close() 159 body, err := ioutil.ReadAll(res.Body) 160 if err != nil { 161 return nil, fmt.Errorf("failed to read Microsoft Graph API response: %s", err) 162 } 163 if res.StatusCode != http.StatusOK { 164 return nil, fmt.Errorf("failed to get groups: %s", string(body)) 165 } 166 167 var target azureGroups 168 if err := json.Unmarshal(body, &target); err != nil { 169 return nil, fmt.Errorf("unabled to decode response: %s", err) 170 } 171 return target.Value, nil 172} 173 174type azureGroups struct { 175 Value []interface{} `json:"value"` 176} 177