1%% This Source Code Form is subject to the terms of the Mozilla Public
2%% License, v. 2.0. If a copy of the MPL was not distributed with this
3%% file, You can obtain one at https://mozilla.org/MPL/2.0/.
4%%
5%% Copyright (c) 2007-2021 VMware, Inc. or its affiliates.  All rights reserved.
6%%
7
8-module(rabbit_auth_backend_oauth2).
9
10-include_lib("rabbit_common/include/rabbit.hrl").
11
12-behaviour(rabbit_authn_backend).
13-behaviour(rabbit_authz_backend).
14
15-export([description/0]).
16-export([user_login_authentication/2, user_login_authorization/2,
17         check_vhost_access/3, check_resource_access/4,
18         check_topic_access/4, check_token/1, state_can_expire/0, update_state/2]).
19
20% for testing
21-export([post_process_payload/1]).
22
23-import(rabbit_data_coercion, [to_map/1]).
24
25-ifdef(TEST).
26-compile(export_all).
27-endif.
28%%--------------------------------------------------------------------
29
30-define(APP, rabbitmq_auth_backend_oauth2).
31-define(RESOURCE_SERVER_ID, resource_server_id).
32%% a term used by the IdentityServer community
33-define(COMPLEX_CLAIM, extra_scopes_source).
34
%% Human-readable name and summary of this auth backend.
description() ->
    Name    = <<"OAuth 2">>,
    Summary = <<"Performs authentication and authorisation using JWT tokens and OAuth 2 scopes">>,
    [{name, Name}, {description, Summary}].
38
39%%--------------------------------------------------------------------
40
%% authn callback: authenticates with the token found in AuthProps.
%% Refusals are additionally logged at debug level before being returned
%% unchanged; every other result is passed straight through.
user_login_authentication(Username, AuthProps) ->
    Result = authenticate(Username, AuthProps),
    case Result of
        {refused, Msg, Args} -> rabbit_log:debug(Msg, Args);
        _Other               -> ok
    end,
    Result.
49
%% authz callback: on success returns the decoded token (the `impl' of the
%% authenticated #auth_user{}); any non-success result is returned as-is.
user_login_authorization(Username, AuthProps) ->
    Result = authenticate(Username, AuthProps),
    case Result of
        {ok, #auth_user{impl = DecodedToken}} -> {ok, DecodedToken};
        _NotOk                                -> Result
    end.
55
%% Checks whether the token's scopes grant access to VHost. An expired token
%% yields {error, Msg} from with_decoded_token/2 instead of a decision.
check_vhost_access(#auth_user{impl = DecodedToken}, VHost, _AuthzData) ->
    MatchVHost =
        fun() ->
            TokenScopes = get_scopes(DecodedToken),
            Joined      = rabbit_oauth2_scope:concat_scopes(TokenScopes, ","),
            rabbit_log:debug("Matching virtual host '~s' against the following scopes: ~s",
                             [VHost, Joined]),
            rabbit_oauth2_scope:vhost_access(VHost, TokenScopes)
        end,
    with_decoded_token(DecodedToken, MatchVHost).
65
%% Checks Permission on Resource against the token's scopes, provided the
%% token has not expired (see with_decoded_token/2).
check_resource_access(#auth_user{impl = DecodedToken}, Resource, Permission, _AuthzContext) ->
    CheckAgainstScopes =
        fun() ->
            rabbit_oauth2_scope:resource_access(Resource, Permission,
                                                get_scopes(DecodedToken))
        end,
    with_decoded_token(DecodedToken, CheckAgainstScopes).
73
%% Checks topic-level Permission on Resource (with the caller's Context)
%% against the token's scopes, provided the token has not expired.
check_topic_access(#auth_user{impl = DecodedToken}, Resource, Permission, Context) ->
    CheckAgainstScopes =
        fun() ->
            rabbit_oauth2_scope:topic_access(Resource, Permission, Context,
                                             get_scopes(DecodedToken))
        end,
    with_decoded_token(DecodedToken, CheckAgainstScopes).
81
%% Tokens carry an expiry, so authentication state handled by this backend
%% can expire and may be refreshed through update_state/2.
state_can_expire() ->
    true.
83
%% Replaces the token held by an authenticated user with NewToken.
%% On success the user's tags are recomputed from the new token and the
%% decoded token becomes the new `impl'. Refusals are turned into a
%% human-readable message; an invalid-token failure is reported without its
%% exception details to avoid logging the token.
update_state(AuthUser, NewToken) ->
    case check_token(NewToken) of
        {ok, DecodedToken} ->
            NewTags = tags_from(DecodedToken),
            {ok, AuthUser#auth_user{tags = NewTags, impl = DecodedToken}};
        {error, _} = Error ->
            Error;
        {refused, {error, {invalid_token, error, _Err, _Stacktrace}}} ->
            {refused, "Authentication using an OAuth 2/JWT token failed: provided token is invalid"};
        {refused, Err} ->
            {refused, rabbit_misc:format("Authentication using an OAuth 2/JWT token failed: ~p", [Err])}
    end.
98
99%%--------------------------------------------------------------------
100
%% Shared core of user_login_authentication/2 and user_login_authorization/2.
%% Takes the token from the `password' auth property, verifies it with
%% check_token/1 and, on success, builds the #auth_user{} for it.
%% Returns {ok, #auth_user{}}, {refused, Format, Args} or {error, _}.
authenticate(Username0, AuthProps0) ->
    Token = token_from_context(to_map(AuthProps0)),
    case check_token(Token) of
        {ok, DecodedToken} ->
            authenticate_with_decoded_token(Username0, DecodedToken);
        %% the next two clauses avoid logging the token: errors pass through
        %% unformatted and invalid-token exception details are dropped
        {error, _} = Error ->
            Error;
        {refused, {error, {invalid_token, error, _Err, _Stacktrace}}} ->
            {refused, "Authentication using an OAuth 2/JWT token failed: provided token is invalid", []};
        {refused, Err} ->
            {refused, "Authentication using an OAuth 2/JWT token failed: ~p", [Err]}
    end.

%% Builds the #auth_user{} (username, tags, decoded token) for a verified
%% token, turning an expiry error from with_decoded_token/2 into a refusal.
authenticate_with_decoded_token(Username0, DecodedToken) ->
    BuildUser =
        fun() ->
            Username = username_from(Username0, DecodedToken),
            Tags     = tags_from(DecodedToken),
            {ok, #auth_user{username = Username,
                            tags     = Tags,
                            impl     = DecodedToken}}
        end,
    case with_decoded_token(DecodedToken, BuildUser) of
        {error, Err} ->
            {refused, "Authentication using an OAuth 2/JWT token failed: ~p", [Err]};
        Result ->
            Result
    end.
127
%% Runs Fun() only if DecodedToken has not expired. When it has, the expiry
%% message is logged at error level and {error, Msg} is returned instead.
with_decoded_token(DecodedToken, Fun) ->
    case validate_token_expiry(DecodedToken) of
        {error, Msg} = Expired ->
            rabbit_log:error(Msg),
            Expired;
        ok ->
            Fun()
    end.
135
%% Checks the token's "exp" claim against the current system time (seconds).
%% Returns ok when there is no numeric "exp" or it lies in the future, and
%% {error, Msg} when Exp =< now.
%%
%% "exp" is a JSON NumericDate, which may be a non-integer (e.g. 1530849387.5).
%% Guarding with is_number/1 rather than is_integer/1 ensures such tokens are
%% still checked. Previously they fell through to the catch-all clause and
%% never expired.
validate_token_expiry(#{<<"exp">> := Exp}) when is_number(Exp) ->
    Now = os:system_time(seconds),
    case Exp =< Now of
        true  -> {error, rabbit_misc:format("Provided JWT token has expired at timestamp ~p (validated at ~p)", [Exp, Now])};
        false -> ok
    end;
validate_token_expiry(#{}) -> ok.
143
%% Verifies Token's signature with uaa_jwt:decode_and_verify/1, then
%% normalises the claims (post_process_payload/1) and validates audience and
%% scopes (validate_payload/1).
%%
%% Decode errors and invalid signatures are reported as {refused, Reason}.
%% The previous spec omitted {refused, _}, which is what this function
%% actually returns on failure. {error, term()} is kept in the contract
%% because callers still match on it.
-spec check_token(binary()) -> {ok, map()} | {refused, term()} | {error, term()}.
check_token(Token) ->
    case uaa_jwt:decode_and_verify(Token) of
        {error, Reason} -> {refused, {error, Reason}};
        {true, Payload} -> validate_payload(post_process_payload(Payload));
        {false, _}      -> {refused, signature_invalid}
    end.
151
%% Normalises a decoded JWT payload:
%%  1. space-separated binary "aud" and "scope" claims become lists;
%%  2. scopes from the configured complex claim (extra_scopes_source) are
%%     merged into "scope" when that claim is present;
%%  3. Keycloak-style "authorization" permissions are merged into "scope".
post_process_payload(Payload) when is_map(Payload) ->
    SplitClaim =
        fun(<<"aud">>, V) when is_binary(V)   -> binary:split(V, <<" ">>, [global, trim_all]);
           (<<"scope">>, V) when is_binary(V) -> binary:split(V, <<" ">>, [global, trim_all]);
           (_Key, V)                          -> V
        end,
    Normalised = maps:map(SplitClaim, Payload),
    WithExtraScopes =
        case does_include_complex_claim_field(Normalised) of
            true  -> post_process_payload_complex_claim(Normalised);
            false -> Normalised
        end,
    case maps:is_key(<<"authorization">>, WithExtraScopes) of
        true  -> post_process_payload_keycloak(WithExtraScopes);
        false -> WithExtraScopes
    end.
173
%% True when Payload has a key equal to the configured complex-claim name
%% (app env `extra_scopes_source', `undefined' when unset).
does_include_complex_claim_field(Payload) when is_map(Payload) ->
    ClaimName = application:get_env(?APP, ?COMPLEX_CLAIM, undefined),
    maps:is_key(ClaimName, Payload).
176
%% Prepends scopes taken from the configured complex claim to the payload's
%% "scope" list. The payload is returned unchanged when no scopes result.
%% Crashes (badkey) if the claim is absent; callers check for it first.
post_process_payload_complex_claim(Payload) ->
    ClaimName        = application:get_env(?APP, ?COMPLEX_CLAIM, undefined),
    ComplexClaim     = maps:get(ClaimName, Payload),
    ResourceServerId = rabbit_data_coercion:to_binary(application:get_env(?APP, ?RESOURCE_SERVER_ID, <<>>)),
    case complex_claim_scopes(ComplexClaim, ResourceServerId) of
        [] ->
            Payload;
        AdditionalScopes ->
            ExistingScopes = maps:get(<<"scope">>, Payload, []),
            Payload#{<<"scope">> => AdditionalScopes ++ ExistingScopes}
    end.

%% Turns the complex claim's value into a list of scopes:
%%  - a list is used as-is;
%%  - a map is looked up by resource server id; the list, or the
%%    space-separated binary, found there is prefixed with "<id>.";
%%  - a binary is split on spaces;
%%  - anything else yields no scopes.
complex_claim_scopes(Scopes, _ResourceServerId) when is_list(Scopes) ->
    Scopes;
complex_claim_scopes(ScopesByServer, ResourceServerId) when is_map(ScopesByServer) ->
    Unprefixed =
        case maps:get(ResourceServerId, ScopesByServer, undefined) of
            Names when is_list(Names)   -> Names;
            Names when is_binary(Names) -> binary:split(Names, <<" ">>, [global, trim_all]);
            _Other                      -> []
        end,
    [iolist_to_binary([ResourceServerId, <<".">>, Name]) || Name <- Unprefixed];
complex_claim_scopes(Scopes, _ResourceServerId) when is_binary(Scopes) ->
    binary:split(Scopes, <<" ">>, [global, trim_all]);
complex_claim_scopes(_Other, _ResourceServerId) ->
    [].
205
206%% keycloak token format: https://github.com/rabbitmq/rabbitmq-auth-backend-oauth2/issues/36
%% Prepends the scopes listed under authorization.permissions[*].scopes to the
%% payload's "scope" claim.
%%
%% A token without a "scope" claim is treated as having no scopes
%% (previously maps:get/2 crashed with badkey). This matches the default used
%% by post_process_payload_complex_claim/1.
post_process_payload_keycloak(#{<<"authorization">> := Authorization} = Payload) ->
    AdditionalScopes = case maps:get(<<"permissions">>, Authorization, undefined) of
        undefined   -> [];
        Permissions -> extract_scopes_from_keycloak_permissions([], Permissions)
    end,
    ExistingScopes = maps:get(<<"scope">>, Payload, []),
    maps:put(<<"scope">>, AdditionalScopes ++ ExistingScopes, Payload).

%% Returns Acc followed by the scopes of every map-shaped permission, in order.
%% A permission's "scopes" may be a list or a single binary. Non-map
%% permissions are skipped. The result is built in one pass instead of by
%% repeated `Acc ++ Scopes' appends, which were quadratic.
extract_scopes_from_keycloak_permissions(Acc, Permissions) ->
    Acc ++ lists:flatmap(fun keycloak_permission_scopes/1, Permissions).

%% Scopes carried by a single Keycloak permission entry.
keycloak_permission_scopes(Permission) when is_map(Permission) ->
    case maps:get(<<"scopes">>, Permission, []) of
        ScopesAsList when is_list(ScopesAsList)     -> ScopesAsList;
        ScopesAsBinary when is_binary(ScopesAsBinary) -> [ScopesAsBinary]
    end;
keycloak_permission_scopes(_NotAMap) ->
    [].
227
%% Validates a post-processed token payload against the configured resource
%% server id (app env `resource_server_id', default <<>>).
%%  - "aud" must contain the resource server id, unless that id is empty;
%%  - "scope" is narrowed to entries prefixed with "<id>.", with the prefix
%%    stripped. An empty id keeps all scopes.
%% Returns {ok, Token} or {refused, Reason}. Payloads missing "scope" and/or
%% "aud" are refused with {missing_claims, [ClaimName]} rather than crashing
%% with function_clause.
validate_payload(DecodedToken) ->
    ResourceServerEnv = application:get_env(?APP, ?RESOURCE_SERVER_ID, <<>>),
    ResourceServerId = rabbit_data_coercion:to_binary(ResourceServerEnv),
    validate_payload(DecodedToken, ResourceServerId).

validate_payload(#{<<"scope">> := Scope, <<"aud">> := Aud} = DecodedToken, ResourceServerId) ->
    case check_aud(Aud, ResourceServerId) of
        ok           -> {ok, DecodedToken#{<<"scope">> => filter_scopes(Scope, ResourceServerId)}};
        {error, Err} -> {refused, {invalid_aud, Err}}
    end;
validate_payload(DecodedToken, _ResourceServerId) when is_map(DecodedToken) ->
    Missing = [Claim || Claim <- [<<"scope">>, <<"aud">>],
                        not maps:is_key(Claim, DecodedToken)],
    {refused, {missing_claims, Missing}}.
238
%% Keeps only the scopes addressed to ResourceServerId ("<id>.<scope>") and
%% strips that prefix. An empty id disables filtering.
filter_scopes(Scopes, ResourceServerId) ->
    case ResourceServerId of
        <<>> ->
            Scopes;
        _NonEmpty ->
            matching_scopes_without_prefix(Scopes, <<ResourceServerId/binary, ".">>)
    end.
243
%% Verifies that the audience list contains ResourceServerId.
%% An empty id skips the check; a non-list audience is an error.
check_aud(_Aud, <<>>) ->
    ok;
check_aud(Aud, ResourceServerId) when is_list(Aud) ->
    case lists:member(ResourceServerId, Aud) of
        true  -> ok;
        false -> {error, {resource_id_not_found_in_aud, ResourceServerId, Aud}}
    end;
check_aud(Aud, _ResourceServerId) ->
    {error, {badarg, {aud_is_not_a_list, Aud}}}.
254
255%%--------------------------------------------------------------------
256
%% The (already validated) "scope" claim of a decoded token.
get_scopes(#{<<"scope">> := TokenScopes}) ->
    TokenScopes.
258
%% The token is supplied in the `password' auth property; `undefined' if absent.
-spec token_from_context(map()) -> binary() | undefined.
token_from_context(AuthProps) ->
    case maps:find(password, AuthProps) of
        {ok, Token} -> Token;
        error       -> undefined
    end.
262
263%% Decoded tokens look like this:
264%%
265%% #{<<"aud">>         => [<<"rabbitmq">>, <<"rabbit_client">>],
266%%   <<"authorities">> => [<<"rabbitmq.read:*/*">>, <<"rabbitmq.write:*/*">>, <<"rabbitmq.configure:*/*">>],
267%%   <<"azp">>         => <<"rabbit_client">>,
268%%   <<"cid">>         => <<"rabbit_client">>,
269%%   <<"client_id">>   => <<"rabbit_client">>,
270%%   <<"exp">>         => 1530849387,
271%%   <<"grant_type">>  => <<"client_credentials">>,
272%%   <<"iat">>         => 1530806187,
273%%   <<"iss">>         => <<"http://localhost:8080/uaa/oauth/token">>,
274%%   <<"jti">>         => <<"df5d50a1cdcb4fa6bf32e7e03acfc74d">>,
275%%   <<"rev_sig">>     => <<"2f880d5b">>,
276%%   <<"scope">>       => [<<"rabbitmq.read:*/*">>, <<"rabbitmq.write:*/*">>, <<"rabbitmq.configure:*/*">>],
277%%   <<"sub">>         => <<"rabbit_client">>,
278%%   <<"zid">>         => <<"uaa">>}
279
%% Chooses the username for a decoded token: the value of
%% uaa_jwt:client_id(Token, Sub) (presumably the client id, falling back to
%% "sub" -- NOTE(review): confirm uaa_jwt default semantics). If that is
%% `undefined', the client-provided username is used, unless it is
%% `undefined' or empty.
-spec username_from(binary(), map()) -> binary() | undefined.
username_from(ClientProvidedUsername, DecodedToken) ->
    ClientId = uaa_jwt:client_id(DecodedToken, undefined),
    Sub      = uaa_jwt:sub(DecodedToken, undefined),
    rabbit_log:debug("Computing username from client's JWT token, client ID: '~s', sub: '~s'",
                     [ClientId, Sub]),
    case uaa_jwt:client_id(DecodedToken, Sub) of
        undefined when ClientProvidedUsername =:= undefined;
                       ClientProvidedUsername =:= <<>> ->
            undefined;
        undefined ->
            ClientProvidedUsername;
        TokenUsername ->
            TokenUsername
    end.
298
%% User tags are the "tag:<name>" scopes of the token, converted to a sorted,
%% de-duplicated list of atoms.
%% NOTE(review): to_atom/1 appears to create atoms from token-supplied text;
%% atoms are never garbage collected, so consider restricting to known tags.
-spec tags_from(map()) -> list(atom()).
tags_from(DecodedToken) ->
    TagNames = matching_scopes_without_prefix(maps:get(<<"scope">>, DecodedToken, []), <<"tag:">>),
    lists:usort([rabbit_data_coercion:to_atom(Name) || Name <- TagNames]).
304
%% Keeps the scopes that start with PrefixPattern and strips that prefix,
%% preserving order. Scopes not starting with the prefix are dropped.
matching_scopes_without_prefix(Scopes, PrefixPattern) ->
    PrefixLen = byte_size(PrefixPattern),
    StripPrefix =
        fun(Scope) ->
            %% binary:match/2 reports the leftmost occurrence, so {0, PrefixLen}
            %% means Scope begins with the prefix
            case binary:match(Scope, PrefixPattern) =:= {0, PrefixLen} of
                true  -> {true, binary:part(Scope, PrefixLen, byte_size(Scope) - PrefixLen)};
                false -> false
            end
        end,
    lists:filtermap(StripPrefix, Scopes).
319