1%%
2%% %CopyrightBegin%
3%%
4%% Copyright Ericsson AB 2004-2016. All Rights Reserved.
5%%
6%% Licensed under the Apache License, Version 2.0 (the "License");
7%% you may not use this file except in compliance with the License.
8%% You may obtain a copy of the License at
9%%
10%%     http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing, software
13%% distributed under the License is distributed on an "AS IS" BASIS,
14%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15%% See the License for the specific language governing permissions and
16%% limitations under the License.
17%%
18%% %CopyrightEnd%
19%%
20%% AES: RFC 3826
21%%
22
23-module(snmp_usm).
24
25%% Avoid warning for local function error/1 clashing with autoimported BIF.
26-compile({no_auto_import,[error/1]}).
27-export([passwd2localized_key/3, localize_key/3]).
28-export([auth_in/4, auth_out/4, set_msg_auth_params/3]).
29-export([des_encrypt/3, des_decrypt/3]).
30-export([aes_encrypt/5, aes_decrypt/5]).
31
32
33-define(SNMP_USE_V3, true).
34-include("snmp_types.hrl").
35-include("SNMP-USER-BASED-SM-MIB.hrl").
36-include("SNMP-USM-AES-MIB.hrl").
37
38-define(VMODULE,"USM").
39-include("snmp_verbosity.hrl").
40
41
42%%-----------------------------------------------------------------
43
%% Placeholder value for msgAuthenticationParameters: twelve zero octets.
%% Outgoing messages are encoded with it before their MAC is computed, and
%% incoming packets are patched back to it before their MAC is recomputed.
%% (The misspelled name is kept; it is referenced throughout the module.)
-define(twelwe_zeros, [0,0,0,0,0,0,0,0,0,0,0,0]).

%% Expands to the four octets of the integer Int, most significant first,
%% as four comma-separated expressions (meant to be spliced into a list).
-define(i32(Int), (Int bsr 24) band 255, (Int bsr 16) band 255, (Int bsr 8) band 255, Int band 255).

%% AES-CFB128 cipher name for crypto:crypto_one_time/5, chosen from the
%% size in bits of Key (any iodata). A key that is not 128, 192 or 256
%% bits long raises a case_clause error.
-define(BLOCK_CIPHER_AES(Key), case bit_size(iolist_to_binary(Key)) of
                                   128 -> aes_128_cfb128;
                                   192 -> aes_192_cfb128;
                                   256 -> aes_256_cfb128
                               end).

%% Cipher name for crypto:crypto_one_time/5 used by the DES
%% encrypt/decrypt functions.
-define(BLOCK_CIPHER_DES, des_cbc).
55
56
%%-----------------------------------------------------------------
%% Func: passwd2localized_key/3
%% Types: Alg      = md5 | sha
%%        Passwd   = string()   (must be non-empty)
%%        EngineID = string()
%% Returns: the localized key as a list of octets.
%% Purpose: Derives a key usable as an authentication or privacy key
%%          from Passwd with the password-to-key algorithm (MD5 or
%%          SHA), then localizes it for EngineID. The algorithm is
%%          described in appendix A.1 of RFC 2274 (now RFC 3414).
%%-----------------------------------------------------------------
passwd2localized_key(Alg, [_ | _] = Passwd, EngineID) ->
    localize_key(Alg, mk_digest(Alg, Passwd), EngineID).
71
72
%%-----------------------------------------------------------------
%% Func: localize_key/3
%% Types: Alg      = md5 | sha
%%        Key      = iodata()   (the non-localized key)
%%        EngineID = string()
%% Returns: Hash(Key, EngineID, Key) as a list of octets.
%% Purpose: Localizes a non-localized key for EngineID. See RFC 2274
%%          section 2.6 for a definition of localized keys.
%%-----------------------------------------------------------------
localize_key(Alg, Key, EngineID) ->
    Digest = crypto:hash(Alg, [Key, EngineID, Key]),
    binary_to_list(Digest).
84
85
%% Password-to-key digest (RFC 3414, appendix A.1): hashes with Alg the
%% first 1048576 octets of the string obtained by repeating Passwd
%% endlessly, and returns the digest as a binary.
%%
%% Passwd must be a non-empty flat list of octets of any length. (The
%% former carry-over-buffer implementation crashed for passwords longer
%% than 64 octets because the carried remainder could exceed 64 octets.)
%% Alg may be any algorithm accepted by crypto:hash_init/1, not only
%% md5 and sha.
mk_digest(Alg, Passwd) ->
    PasswdLen = length(Passwd),
    %% A few back-to-back copies of the password, enough that every
    %% 64-octet window starting inside the first copy fits:
    %% (64 div Len + 2) * Len >= 65 + Len > (Len - 1) + 64.
    Cycle = list_to_binary(lists:duplicate(64 div PasswdLen + 2, Passwd)),
    Ctx = digest_loop(0, Cycle, PasswdLen, crypto:hash_init(Alg)),
    crypto:hash_final(Ctx).

%% Feeds the repeated password to the hash context 64 octets at a time.
%% Count is the number of octets hashed so far. The stream is periodic
%% with period PasswdLen, so the next 64 octets start at offset
%% Count rem PasswdLen of Cycle.
digest_loop(Count, Cycle, PasswdLen, Ctx) when Count < 1048576 ->
    Chunk = binary:part(Cycle, Count rem PasswdLen, 64),
    digest_loop(Count + 64, Cycle, PasswdLen, crypto:hash_update(Ctx, Chunk));
digest_loop(_Count, _Cycle, _PasswdLen, Ctx) ->
    Ctx.
127
128
129%%-----------------------------------------------------------------
130%% Auth and priv algorithms
131%%-----------------------------------------------------------------
132
%% Verifies the MAC of an incoming packet. Protocol may be given either
%% as an atom or as the corresponding OID macro value. Returns the
%% boolean result of the protocol-specific check; any other protocol
%% raises function_clause.
auth_in(Proto, AuthKey, AuthParams, Packet)
  when Proto =:= usmHMACMD5AuthProtocol;
       Proto =:= ?usmHMACMD5AuthProtocol ->
    md5_auth_in(AuthKey, AuthParams, Packet);
auth_in(Proto, AuthKey, AuthParams, Packet)
  when Proto =:= usmHMACSHAAuthProtocol;
       Proto =:= ?usmHMACSHAAuthProtocol ->
    sha_auth_in(AuthKey, AuthParams, Packet).
141
%% Authenticates an outgoing message with the given protocol (an atom or
%% the corresponding OID macro value) and returns Message with its
%% msgAuthenticationParameters set to the computed MAC.
%% usmNoAuthProtocol throws {error, unSupportedSecurityLevel} (3.1.3);
%% any other unknown protocol raises function_clause.
auth_out(Proto, _AuthKey, _Message, _UsmSecParams)
  when Proto =:= usmNoAuthProtocol;
       Proto =:= ?usmNoAuthProtocol ->
    error(unSupportedSecurityLevel);
auth_out(Proto, AuthKey, Message, UsmSecParams)
  when Proto =:= usmHMACMD5AuthProtocol;
       Proto =:= ?usmHMACMD5AuthProtocol ->
    md5_auth_out(AuthKey, Message, UsmSecParams);
auth_out(Proto, AuthKey, Message, UsmSecParams)
  when Proto =:= usmHMACSHAAuthProtocol;
       Proto =:= ?usmHMACSHAAuthProtocol ->
    sha_auth_out(AuthKey, Message, UsmSecParams).
154
%% HMAC-MD5-96 authentication of an outgoing message (RFC 3414, 6.3.1).
%% Encodes Message with zeroed authentication parameters, computes the
%% HMAC-MD5 of that encoding truncated to 12 octets, and returns Message
%% with that MAC as msgAuthenticationParameters.
md5_auth_out(AuthKey, Message, UsmSecParams) ->
    %% 6.3.1.1: encode with twelve zero octets as placeholder
    Zeroed = set_msg_auth_params(Message, UsmSecParams, ?twelwe_zeros),
    Encoded = snmp_pdus:enc_message_only(Zeroed),
    %% 6.3.1.2-4: the truncated HMAC is computed by crypto
    Mac = crypto:macN(hmac, md5, AuthKey, Encoded, 12),
    %% 6.3.1.5: put the MAC into the original message
    set_msg_auth_params(Message, UsmSecParams, binary_to_list(Mac)).
170
%% Verifies the HMAC-MD5-96 MAC of an incoming message (RFC 3414, 6.3.2).
%% Packet is the received message as a binary; AuthParams is its
%% msgAuthenticationParameters as a list of octets.
%% Returns true if the MAC computed over Packet (with the authentication
%% parameters zeroed) equals AuthParams, and false otherwise, including
%% when AuthParams is not 12 octets long.
md5_auth_in(AuthKey, AuthParams, Packet) when length(AuthParams) =:= 12 ->
    %% 6.3.2.3: zero the authentication parameters inside the packet
    Packet2 = patch_packet(binary_to_list(Packet)),
    %% 6.3.2.5
    MAC = binary_to_list(crypto:macN(hmac, md5, AuthKey, Packet2, 12)),
    %% 6.3.2.6: compare without short-circuiting, so the time taken does
    %% not reveal how many leading octets of a forged MAC were correct.
    Diff = lists:foldl(fun({A, B}, Acc) -> Acc bor (A bxor B) end,
                       0, lists:zip(MAC, AuthParams)),
    Diff =:= 0;
md5_auth_in(_AuthKey, _AuthParams, _Packet) ->
    %% 6.3.2.1: msgAuthenticationParameters has the wrong length.
    %% The secret key is deliberately not traced.
    ?vtrace("md5_auth_in -> entry with"
            "~n   _AuthParams: ~p", [_AuthParams]),
    false.
190
191
%% HMAC-SHA-96 authentication of an outgoing message (RFC 3414, 7.3.1).
%% Encodes Message with zeroed authentication parameters, computes the
%% HMAC-SHA-1 of that encoding truncated to 12 octets, and returns
%% Message with that MAC as msgAuthenticationParameters.
sha_auth_out(AuthKey, Message, UsmSecParams) ->
    %% 7.3.1.1: encode with twelve zero octets as placeholder
    Encoded = snmp_pdus:enc_message_only(
                set_msg_auth_params(Message, UsmSecParams, ?twelwe_zeros)),
    %% 7.3.1.2-4: the truncated HMAC is computed by crypto
    Mac = crypto:macN(hmac, sha, AuthKey, Encoded, 12),
    %% 7.3.1.5: put the MAC into the original message
    set_msg_auth_params(Message, UsmSecParams, binary_to_list(Mac)).
201
%% Verifies the HMAC-SHA-96 MAC of an incoming message (RFC 3414, 7.3.2).
%% Packet is the received message as a binary; AuthParams is its
%% msgAuthenticationParameters as a list of octets.
%% Returns true if the MAC computed over Packet (with the authentication
%% parameters zeroed) equals AuthParams, and false otherwise, including
%% when AuthParams is not 12 octets long.
sha_auth_in(AuthKey, AuthParams, Packet) when length(AuthParams) =:= 12 ->
    %% 7.3.2.3: zero the authentication parameters inside the packet
    Packet2 = patch_packet(binary_to_list(Packet)),
    %% 7.3.2.5
    MAC = binary_to_list(crypto:macN(hmac, sha, AuthKey, Packet2, 12)),
    %% 7.3.2.6: compare without short-circuiting, so the time taken does
    %% not reveal how many leading octets of a forged MAC were correct.
    Diff = lists:foldl(fun({A, B}, Acc) -> Acc bor (A bxor B) end,
                       0, lists:zip(MAC, AuthParams)),
    Diff =:= 0;
sha_auth_in(_AuthKey, _AuthParams, _Packet) ->
    %% 7.3.2.1: msgAuthenticationParameters has the wrong length.
    %% The secret key is deliberately not traced.
    ?vtrace("sha_auth_in -> entry with"
            "~n   _AuthParams: ~p", [_AuthParams]),
    false.
215
216
%% CBC-DES encryption of Data.
%% The first 8 octets of PrivKey are the DES key. The remaining octets
%% (the pre-IV) are combined with the salt returned by SaltFun() through
%% snmp_misc:str_xor/2 to form the IV. Data is padded with zero octets
%% to a multiple of 8. Returns {ok, EncryptedOctets, Salt}.
des_encrypt(PrivKey, Data, SaltFun) ->
    {DesKey, PreIV} = lists:split(8, PrivKey),
    Salt = SaltFun(),
    IV = list_to_binary(snmp_misc:str_xor(PreIV, Salt)),
    PadLen = case length(Data) rem 8 of
                 0   -> 0;
                 Rem -> 8 - Rem
             end,
    Cipher = crypto:crypto_one_time(?BLOCK_CIPHER_DES, DesKey, IV,
                                    [Data, mk_tail(PadLen)], true),
    {ok, binary_to_list(Cipher), Salt}.
227
%% CBC-DES decryption of an encrypted scopedPDU.
%% PrivKey:       the first 8 octets are the DES key; the rest (pre-IV)
%%                is combined with the salt via snmp_misc:str_xor/2 to
%%                form the IV.
%% MsgPrivParams: the salt; must be exactly 8 octets.
%% EncData:       the ciphertext; its length must be a multiple of 8
%%                octets (RFC 3414, 8.1.1.3).
%% Returns {ok, Data}, where Data is the decrypted octet list after
%% snmp_pdus:strip_encrypted_scoped_pdu_data/1.
%% Throws {error, {bad_encrypted_data_length, Len}} or
%% {error, {bad_msgPrivParams, PrivKey, MsgPrivParams, EncData}} on
%% malformed input.
des_decrypt(PrivKey, MsgPrivParams, EncData)
  when length(MsgPrivParams) =:= 8 ->
    %% PrivKey is secret key material and is deliberately not traced.
    ?vtrace("des_decrypt -> entry with"
            "~n   MsgPrivParams: ~p"
            "~n   EncData:       ~p", [MsgPrivParams, EncData]),
    EncLen = iolist_size(EncData),
    case EncLen rem 8 of
        0 ->
            ok;
        _ ->
            %% Rejected up front (RFC 3414, 8.1.1.3) so a truncated
            %% ciphertext gives a well-defined thrown error instead of
            %% relying on crypto to reject it.
            throw({error, {bad_encrypted_data_length, EncLen}})
    end,
    [A,B,C,D,E,F,G,H | PreIV] = PrivKey,
    DesKey = [A,B,C,D,E,F,G,H],
    Salt = MsgPrivParams,
    IV = list_to_binary(snmp_misc:str_xor(PreIV, Salt)),
    Data = binary_to_list(crypto:crypto_one_time(?BLOCK_CIPHER_DES,
                                                 DesKey, IV, EncData, false)),
    Data2 = snmp_pdus:strip_encrypted_scoped_pdu_data(Data),
    {ok, Data2};
des_decrypt(PrivKey, BadMsgPrivParams, EncData) ->
    ?vtrace("des_decrypt -> entry when bad MsgPrivParams"
            "~n   BadMsgPrivParams: ~p"
            "~n   EncData:          ~p",
            [BadMsgPrivParams, EncData]),
    %% NOTE(review): the thrown term still carries PrivKey so existing
    %% handlers keep matching; consider redacting it.
    throw({error, {bad_msgPrivParams, PrivKey, BadMsgPrivParams, EncData}}).
250
251
%% AES-CFB128 encryption of Data (RFC 3826). PrivKey is the AES key; its
%% size selects the AES variant. The IV is EngineBoots and EngineTime as
%% four octets each (most significant first) followed by the salt
%% returned by SaltFun(). Returns {ok, EncryptedOctets, Salt}.
aes_encrypt(PrivKey, Data, SaltFun, EngineBoots, EngineTime) ->
    Salt = SaltFun(),
    IV = list_to_binary([?i32(EngineBoots), ?i32(EngineTime) | Salt]),
    Cipher = ?BLOCK_CIPHER_AES(PrivKey),
    Encrypted = crypto:crypto_one_time(Cipher, PrivKey, IV, Data, true),
    {ok, binary_to_list(Encrypted), Salt}.
259
%% AES-CFB128 decryption of an encrypted scopedPDU (RFC 3826).
%% PrivKey is the AES key; its size selects the AES variant.
%% MsgPrivParams is the salt and must be exactly 8 octets
%% (RFC 3826, 3.3.2 step 1). Otherwise
%% {error, {bad_msgPrivParams, MsgPrivParams}} is thrown, matching the
%% throw-based error reporting of des_decrypt/3 rather than failing with
%% function_clause.
%% The IV is EngineBoots and EngineTime as four octets each (most
%% significant first) followed by the salt. CFB is a stream mode, so
%% EncData needs no particular length.
%% Returns {ok, Data}, where Data is the decrypted octet list after
%% snmp_pdus:strip_encrypted_scoped_pdu_data/1.
aes_decrypt(PrivKey, MsgPrivParams, EncData, EngineBoots, EngineTime)
  when length(MsgPrivParams) =:= 8 ->
    AesKey = PrivKey,
    Salt = MsgPrivParams,
    IV = list_to_binary([?i32(EngineBoots), ?i32(EngineTime) | Salt]),
    Data = binary_to_list(crypto:crypto_one_time(?BLOCK_CIPHER_AES(AesKey),
                                                 AesKey, IV, EncData, false)),
    Data2 = snmp_pdus:strip_encrypted_scoped_pdu_data(Data),
    {ok, Data2};
aes_decrypt(_PrivKey, BadMsgPrivParams, _EncData, _EngineBoots, _EngineTime) ->
    throw({error, {bad_msgPrivParams, BadMsgPrivParams}}).
270
271
272%%-----------------------------------------------------------------
273%% Utility functions
274%%-----------------------------------------------------------------
%% Returns a list of N zero octets (padding for DES encryption).
mk_tail(N) ->
    lists:duplicate(N, 0).
279
%% Returns Message with msgAuthenticationParameters in UsmSecParams set
%% to AuthParams. The updated security parameters are re-encoded and
%% stored as msgSecurityParameters of the message's v3 header.
set_msg_auth_params(Message, UsmSecParams, AuthParams) ->
    SecBytes = snmp_pdus:enc_usm_security_parameters(
                 UsmSecParams#usmSecurityParameters{
                   msgAuthenticationParameters = AuthParams}),
    OldHdr = Message#message.vsn_hdr,
    Message#message{vsn_hdr = OldHdr#v3_hdr{msgSecurityParameters = SecBytes}}.
288
289
%% Rewrites a BER-encoded SNMPv3 message, given as a flat list of
%% octets, so that the contents of msgAuthenticationParameters become
%% twelve zero octets. This lets the MAC of a received message be
%% recomputed. Length fields are copied through as returned by
%% split_len/1 (an integer for the short form, a list for the long
%% form), so the result is a deep list of octets. Unexpected structure
%% makes a match fail and raises an error.
%%
%% Walked here: SEQUENCE (48) { INTEGER (2) msgVersion,
%% SEQUENCE (48) msgGlobalData, ... }; the rest is handled by pp2/2.
patch_packet([48 | Rest0]) ->
    %% outer SEQUENCE, then the msgVersion INTEGER tag
    {MsgLen, [2 | Rest1]} = split_len(Rest0),
    %% msgVersion has a single content octet, then the header SEQUENCE
    {VsnLen, [Vsn, 48 | Rest2]} = split_len(Rest1),
    {HdrLen, Rest3} = split_len(Rest2),
    Patched = pp2(dec_len(HdrLen), Rest3),
    [48, MsgLen, 2, VsnLen, Vsn, 48, HdrLen | Patched].
302
%% Copies the N octets of the header data unchanged. Then walks the
%% msgSecurityParameters OCTET STRING (tag 4), the UsmSecurityParameters
%% SEQUENCE (tag 48) inside it, and the EngineID OCTET STRING (tag 4)
%% header, and continues with pp3/2 over the EngineID contents.
pp2(N, Bytes) ->
    {HdrData, [4 | Rest0]} = lists:split(N, Bytes),
    {SecParamsLen, [48 | Rest1]} = split_len(Rest0),
    {UsmLen, [4 | Rest2]} = split_len(Rest1),
    {EngineIdLen, Rest3} = split_len(Rest2),
    HdrData ++ [4, SecParamsLen, 48, UsmLen, 4, EngineIdLen
                | pp3(dec_len(EngineIdLen), Rest3)].
314
%% Copies the N octets of the EngineID unchanged, then the EngineBoots
%% INTEGER (tag 2) header, and continues with pp4/2.
pp3(N, Bytes) ->
    {EngineId, [2 | Rest0]} = lists:split(N, Bytes),
    {BootsLen, Rest} = split_len(Rest0),
    EngineId ++ [2, BootsLen | pp4(dec_len(BootsLen), Rest)].
321
%% Copies the N octets of EngineBoots unchanged, then the EngineTime
%% INTEGER (tag 2) header, and continues with pp5/2.
pp4(N, Bytes) ->
    {Boots, [2 | Rest0]} = lists:split(N, Bytes),
    {TimeLen, Rest} = split_len(Rest0),
    Boots ++ [2, TimeLen | pp5(dec_len(TimeLen), Rest)].
328
%% Copies the N octets of EngineTime unchanged, then the UserName
%% OCTET STRING (tag 4) header, and continues with pp6/2.
pp5(N, Bytes) ->
    {Time, [4 | Rest0]} = lists:split(N, Bytes),
    {UserNameLen, Rest} = split_len(Rest0),
    Time ++ [4, UserNameLen | pp6(dec_len(UserNameLen), Rest)].
335
%% Copies the N octets of the UserName unchanged, then replaces the
%% contents of the AuthenticationParameters OCTET STRING (tag 4) with
%% twelve zero octets. The field must hold exactly 12 octets; otherwise
%% a match fails. The remaining octets are returned untouched.
pp6(N, Bytes) ->
    {UserName, [4 | Rest0]} = lists:split(N, Bytes),
    {AuthLen, Rest1} = split_len(Rest0),
    {_ReceivedMac, Rest} = lists:split(12, Rest1),
    12 = dec_len(AuthLen),
    UserName ++ [4, AuthLen, ?twelwe_zeros | Rest].
344
345
%% Splits a BER definite-form length off the front of an octet list.
%% Returns {LengthOctets, Rest}. LengthOctets is the single octet (an
%% integer) for the short form, or [FirstOctet | FollowingOctets] for
%% the long form, where the low 7 bits of FirstOctet give the number of
%% following octets. Decode it with dec_len/1.
split_len([First | Rest]) ->
    case is8set(First) of
        1 ->
            {Octets, Tail} = head(clear(First, 8), Rest),
            {[First | Octets], Tail};
        0 ->
            {First, Rest}
    end.
357
%% Decodes a length as returned by split_len/1. A short-form length (an
%% integer) is returned unchanged. For the long form, the leading count
%% octet is dropped and the remaining one to three octets are combined
%% most significant first. Leading zero octets of longer forms are
%% skipped; anything else raises function_clause.
dec_len(Short) when is_integer(Short) ->
    Short;
dec_len([_CountOctet | Octets]) ->
    dl(Octets).

dl([Lo]) ->
    Lo;
dl([Mid, Lo]) ->
    (Mid bsl 8) bor Lo;
dl([Hi, Mid, Lo]) ->
    (Hi bsl 16) bor (Mid bsl 8) bor Lo;
dl([0 | Octets]) ->
    dl(Octets).
370
%% Splits List after its first N elements and returns {FirstN, Rest}.
%% Fails if List has fewer than N elements.
head(N, List) ->
    lists:split(N, List).
380
%% Clears the most significant bit (bit 8) of an octet. Only position 8
%% is supported.
clear(Byte, 8) ->
    Byte band 2#01111111.
392
%% Returns 1 if the most significant bit of the octet is set
%% (Byte > 127), otherwise 0.
is8set(Byte) when Byte > 127 -> 1;
is8set(_Byte) -> 0.
398
%% Module-local error/1 (auto-import of erlang:error/1 is disabled at
%% the top of the module): throws {error, Reason} instead of raising an
%% error exception.
error(Reason) ->
    Thrown = {error, Reason},
    throw(Thrown).
401
402