1%% -*- mode: erlang; tab-width: 4; indent-tabs-mode: 1; st-rulers: [70] -*-
2%% vim: ts=4 sw=4 ft=erlang noet
3%%%-------------------------------------------------------------------
4%%% @author Andrew Bennett <potatosaladx@gmail.com>
5%%% @copyright 2014-2016, Andrew Bennett
6%%% @doc
7%%%
8%%% @end
9%%% Created :  08 Jan 2016 by Andrew Bennett <potatosaladx@gmail.com>
10%%%-------------------------------------------------------------------
11-module(jose_jwa_sha3).
12
13%% API
14-export([rol64/2]).
15-export([load64/1]).
16-export([store64/1]).
17-export([keccak_f_1600/1]).
18-export([load_lanes/1]).
19-export([store_lanes/1]).
20-export([keccak/5]).
21-export([keccak_absorb/4]).
22-export([keccak_pad/4]).
23-export([shake128/2]).
24-export([shake256/2]).
25-export([sha3_224/1]).
26-export([sha3_256/1]).
27-export([sha3_384/1]).
28-export([sha3_512/1]).
29
30%%====================================================================
31%% API functions
32%%====================================================================
33
%% Rotates the 64-bit word A left by N bit positions.
%%
%% A is first reduced to its low 64 bits and N is taken modulo 64 (via
%% `band 63', which is also correct for negative N), so a negative N
%% rotates right: rol64(1, -1) =:= 1 bsl 63. The result always lies in
%% 0..2^64-1. For 0 =< A < 2^64 and N >= 0 this yields exactly the value of
%% the reference formula ((A >> (64 - n)) + (A << n)) mod 2^64.
rol64(A, N) ->
	Mask = 16#FFFFFFFFFFFFFFFF,
	Word = A band Mask,
	Shift = N band 63,
	% With Shift =:= 0 the right shift by 64 yields 0, leaving Word as-is.
	((Word bsl Shift) bor (Word bsr (64 - Shift))) band Mask.
36
%% @private
%% Keccak-f[1600] permutation on a 5x5 lane matrix: a 5-tuple of columns,
%% each a 5-tuple of lanes, addressed as mget(Lanes, X, Y).
%%
%% Runs 24 rounds. The iota round-constant LFSR state starts at 1 and is
%% threaded from each round into the next.
keccak_f_1600_on_lanes(Lanes) ->
	{Permuted, _FinalLfsr} = lists:foldl(fun keccak_round/2, {Lanes, 1}, lists:seq(1, 24)),
	Permuted.

%% @private
%% One round: theta, then rho and pi (walk starting at lane (1, 0)), then
%% chi, then iota. Returns {NewLanes, NextLfsrState}.
keccak_round(_RoundIndex, {Lanes, Lfsr}) ->
	Mixed = theta(Lanes),
	Moved = rho_and_pi(mget(Mixed, 1, 0), Mixed, 1, 0, 0),
	iota(chi(Moved, 0), Lfsr, 0).
52
%% @private
%% Theta step. For each column X, Parity[X] is the XOR of its five lanes.
%% Every lane in column X is then XORed with
%% Parity[X - 1] xor rol64(Parity[X + 1], 1), indices taken mod 5.
%% Returns a fresh 5x5 matrix.
theta(Lanes) ->
	Idx = lists:seq(0, 4),
	Parity = list_to_tuple([
		lists:foldl(fun(Y, Acc) -> Acc bxor mget(Lanes, X, Y) end, 0, Idx)
	|| X <- Idx]),
	Mix = [mget(Parity, (X + 4) rem 5) bxor rol64(mget(Parity, (X + 1) rem 5), 1) || X <- Idx],
	list_to_tuple([
		list_to_tuple([mget(Lanes, X, Y) bxor DX || Y <- Idx])
	|| {X, DX} <- lists:zip(Idx, Mix)]).
67
%% @private
%% Rho and pi steps, performed together as a walk of 24 positions.
%% Starting from (X, Y), each step does the following:
%%   - move to (Y, (2X + 3Y) mod 5);
%%   - store the carried lane there, rotated left by the triangular number
%%     (Step + 1)(Step + 2)/2;
%%   - carry the lane that was displaced on to the next step.
%% Stops when Step reaches 24 and returns the updated matrix.
rho_and_pi(_Carried, Lanes, _X, _Y, 24) ->
	Lanes;
rho_and_pi(Carried, Lanes, X, Y, Step) ->
	ToX = Y,
	ToY = (2 * X + 3 * Y) rem 5,
	Rotated = rol64(Carried, (Step + 1) * (Step + 2) div 2),
	Displaced = mget(Lanes, ToX, ToY),
	Updated = mput(Lanes, ToX, ToY, Rotated),
	rho_and_pi(Displaced, Updated, ToX, ToY, Step + 1).
76
%% @private
%% Chi step, applied row by row from row Y through row 4. Each lane in a
%% row becomes A xor ((bnot B) band C). Here A, B and C are the lanes at
%% X, X + 1 and X + 2 (mod 5), read from a snapshot of the row taken
%% before any lane in it is modified.
chi(Lanes, 5) ->
	Lanes;
chi(Lanes, Y) ->
	Row = list_to_tuple([mget(Lanes, X, Y) || X <- lists:seq(0, 4)]),
	Mixed = lists:foldl(fun(X, Acc) ->
		A = mget(Row, X),
		B = mget(Row, (X + 1) rem 5),
		C = mget(Row, (X + 2) rem 5),
		mput(Acc, X, Y, A bxor ((bnot B) band C))
	end, Lanes, lists:seq(0, 4)),
	chi(Mixed, Y + 1).
89
%% @private
%% Iota step. For each J from the given value up to 6, the 8-bit LFSR
%% state R is advanced once: shift left and, when the top bit falls out,
%% reduce with 0x71, all modulo 256. Whenever bit 1 of the new state is
%% set, bit (2^J - 1) of lane (0, 0) is flipped. Returns
%% {UpdatedLanes, LfsrState}; the LFSR state is carried into the next round.
iota(Lanes, R, 7) ->
	{Lanes, R};
iota(Lanes, R, J) ->
	Next = ((R bsl 1) bxor ((R bsr 7) * 16#71)) rem 256,
	Flipped = if
		(Next band 2) =/= 0 ->
			Bit = 1 bsl ((1 bsl J) - 1),
			mput(Lanes, 0, 0, mget(Lanes, 0, 0) bxor Bit);
		true ->
			Lanes
	end,
	iota(Flipped, Next, J + 1).
105
%% Decodes exactly 64 bits (8 bytes) of little-endian data into an
%% unsigned integer. Any other input raises function_clause.
load64(Bytes) when bit_size(Bytes) =:= 64 ->
	<< Word:64/little-unsigned-integer >> = Bytes,
	Word.
108
%% Encodes an integer as 8 little-endian bytes. Only the low 64 bits are
%% kept; binary construction truncates higher bits. Non-integers raise
%% function_clause.
store64(Word) when is_integer(Word) ->
	Encoded = << Word:64/little-unsigned-integer >>,
	Encoded.
111
%% Applies the Keccak-f[1600] permutation to a 200-byte state binary. The
%% state is decoded into a lane matrix, permuted, and re-encoded.
keccak_f_1600(State) ->
	store_lanes(keccak_f_1600_on_lanes(load_lanes(State))).
116
%% Decodes a Keccak state binary into a 5x5 lane matrix: a 5-tuple of
%% columns X, each a 5-tuple of lanes Y. Lane (X, Y) is the little-endian
%% 64-bit word at byte offset 8 * (X + 5 * Y).
load_lanes(State) ->
	Indices = lists:seq(0, 4),
	list_to_tuple([
		list_to_tuple([load64(binary:part(State, 8 * (X + 5 * Y), 8)) || Y <- Indices])
	|| X <- Indices]).
129
%% Encodes a 5x5 lane matrix (addressed as mget(Lanes, X, Y)) back into a
%% 200-byte Keccak state. Lane (X, Y) is written little-endian at byte
%% offset 8 * (X + 5 * Y), so Y is the outer (slow) index and X the inner.
%%
%% The binary is built in a single pass. The previous version spliced each
%% of the 25 lanes into a zeroed 200-byte buffer, copying the whole state
%% every time.
store_lanes(Lanes) ->
	Indices = lists:seq(0, 4),
	<< << (store64(mget(Lanes, X, Y)))/binary >> || Y <- Indices, X <- Indices >>.
143
%% Keccak sponge over Keccak-f[1600].
%%
%% Arguments:
%%   Rate, Capacity  - sizes in bits. They must satisfy
%%                     Rate + Capacity =:= 1600, with Rate a positive
%%                     multiple of 8 and Capacity >= 0.
%%   InputBytes      - data to absorb.
%%   DelimitedSuffix - byte XORed in right after the input (the callers in
%%                     this module pass 16#06 or 16#1F).
%%   OutputByteLen   - number of bytes to squeeze out.
%%
%% Any other Rate/Capacity combination raises badarg. In particular,
%% Rate =:= 0 satisfies the sum and multiple-of-8 checks but would make the
%% absorb phase loop forever without consuming input. A Rate above 1600
%% would index past the end of the 200-byte state.
keccak(Rate, Capacity, InputBytes, DelimitedSuffix, OutputByteLen)
		when is_integer(Rate)
		andalso is_integer(Capacity)
		andalso Rate > 0
		andalso Capacity >= 0
		andalso (Rate + Capacity) =:= 1600
		andalso (Rate rem 8) =:= 0 ->
	{RateInBytes, StateBytes} = keccak_absorb(Rate div 8, InputBytes, << 0:1600 >>, DelimitedSuffix),
	keccak_squeeze(RateInBytes, OutputByteLen, StateBytes, <<>>);
keccak(_Rate, _Capacity, _InputBytes, _DelimitedSuffix, _OutputByteLen) ->
	erlang:error(badarg).
152
%% Absorbs InputBytes into the sponge state, RateInBytes bytes at a time.
%%
%% Each full block is XORed into the first RateInBytes bytes of StateBytes
%% and followed by a Keccak-f[1600] permutation. The remaining partial
%% block (possibly empty) is XORed in without permuting. keccak_pad/4 then
%% applies the padding and the final permutation; its result is returned.
%%
%% The full-block clause requires RateInBytes > 0. At 0 it would match
%% forever without consuming any input, hanging the caller.
keccak_absorb(RateInBytes, InputBytes, StateBytes, DelimitedSuffix)
		when is_integer(RateInBytes)
		andalso RateInBytes > 0
		andalso byte_size(InputBytes) >= RateInBytes ->
	<< Block:RateInBytes/binary, Rest/binary >> = InputBytes,
	<< StateHead:RateInBytes/binary, StateTail/binary >> = StateBytes,
	Mixed = << (crypto:exor(StateHead, Block))/binary, StateTail/binary >>,
	keccak_absorb(RateInBytes, Rest, keccak_f_1600(Mixed), DelimitedSuffix);
keccak_absorb(RateInBytes, InputBytes, StateBytes, DelimitedSuffix) ->
	% Last, partial block: XOR it in and hand over to padding.
	BlockSize = byte_size(InputBytes),
	<< StateHead:BlockSize/binary, StateTail/binary >> = StateBytes,
	Mixed = << (crypto:exor(StateHead, InputBytes))/binary, StateTail/binary >>,
	keccak_pad(RateInBytes, BlockSize, Mixed, DelimitedSuffix).
166
%% Pads the absorbed state and runs the final absorb-phase permutation,
%% in preparation for squeezing.
%%
%% BlockSize is the number of bytes of the last (partial) input block
%% already XORed into StateBytes. The steps are:
%%   1. DelimitedSuffix is XORed into the byte at offset BlockSize.
%%   2. If the suffix has its 0x80 bit set and that byte is the last byte
%%      of the rate, the state is permuted.
%%   3. 0x80 is XORed into byte RateInBytes - 1.
%%   4. The state is permuted.
%%
%% Returns {RateInBytes, PermutedState}.
keccak_pad(RateInBytes, BlockSize, StateBytes, DelimitedSuffix) ->
	XorByteAt = fun(Bin, Pos, Mask) ->
		<< Before:Pos/binary, Byte:8/integer, After/binary >> = Bin,
		<< Before/binary, (Byte bxor Mask):8/integer, After/binary >>
	end,
	Suffixed = XorByteAt(StateBytes, BlockSize, DelimitedSuffix),
	NeedsExtraPermutation = (DelimitedSuffix band 16#80) =/= 0 andalso BlockSize =:= (RateInBytes - 1),
	Ready = case NeedsExtraPermutation of
		true -> keccak_f_1600(Suffixed);
		false -> Suffixed
	end,
	Padded = XorByteAt(Ready, RateInBytes - 1, 16#80),
	{RateInBytes, keccak_f_1600(Padded)}.
182
%% Squeezes OutputByteLen bytes out of the sponge and appends them to
%% OutputBytes. Each step takes up to RateInBytes bytes from the front of
%% the state. The state is permuted between steps, but not after the final
%% step.
keccak_squeeze(RateInBytes, OutputByteLen, StateBytes, OutputBytes)
		when OutputByteLen > 0 ->
	Take = min(OutputByteLen, RateInBytes),
	<< Chunk:Take/binary, _/binary >> = StateBytes,
	Remaining = OutputByteLen - Take,
	NextState = if
		Remaining > 0 -> keccak_f_1600(StateBytes);
		true -> StateBytes
	end,
	keccak_squeeze(RateInBytes, Remaining, NextState, << OutputBytes/binary, Chunk/binary >>);
keccak_squeeze(_RateInBytes, _OutputByteLen, _StateBytes, OutputBytes) ->
	OutputBytes.
198
%% SHAKE128: returns OutputByteLen bytes derived from InputBytes.
%% Uses capacity 256 bits (rate 1344) and the 16#1F suffix.
shake128(InputBytes, OutputByteLen)
		when is_binary(InputBytes), is_integer(OutputByteLen), OutputByteLen >= 0 ->
	Capacity = 2 * 128,
	keccak(1600 - Capacity, Capacity, InputBytes, 16#1F, OutputByteLen).
204
%% SHAKE256: returns OutputByteLen bytes derived from InputBytes.
%% Uses capacity 512 bits (rate 1088) and the 16#1F suffix.
shake256(InputBytes, OutputByteLen)
		when is_binary(InputBytes), is_integer(OutputByteLen), OutputByteLen >= 0 ->
	Capacity = 2 * 256,
	keccak(1600 - Capacity, Capacity, InputBytes, 16#1F, OutputByteLen).
210
%% SHA3-224 digest (28 bytes): capacity 448 bits, rate 1152, suffix 16#06.
sha3_224(InputBytes) ->
	OutputBits = 224,
	Capacity = 2 * OutputBits,
	keccak(1600 - Capacity, Capacity, InputBytes, 16#06, OutputBits div 8).
213
%% SHA3-256 digest (32 bytes): capacity 512 bits, rate 1088, suffix 16#06.
sha3_256(InputBytes) ->
	OutputBits = 256,
	Capacity = 2 * OutputBits,
	keccak(1600 - Capacity, Capacity, InputBytes, 16#06, OutputBits div 8).
216
%% SHA3-384 digest (48 bytes): capacity 768 bits, rate 832, suffix 16#06.
sha3_384(InputBytes) ->
	OutputBits = 384,
	Capacity = 2 * OutputBits,
	keccak(1600 - Capacity, Capacity, InputBytes, 16#06, OutputBits div 8).
219
%% SHA3-512 digest (64 bytes): capacity 1024 bits, rate 576, suffix 16#06.
sha3_512(InputBytes) ->
	OutputBits = 512,
	Capacity = 2 * OutputBits,
	keccak(1600 - Capacity, Capacity, InputBytes, 16#06, OutputBits div 8).
222
223%%%-------------------------------------------------------------------
224%%% Internal functions
225%%%-------------------------------------------------------------------
226
%% @private
%% Zero-based tuple access: returns element Index + 1 of Tuple.
mget(Tuple, Index) ->
	element(Index + 1, Tuple).

%% @private
%% Zero-based access into a tuple of tuples: returns element Y of
%% element X of Matrix, both counted from 0.
mget(Matrix, X, Y) ->
	Column = element(X + 1, Matrix),
	element(Y + 1, Column).
234
%% @private
%% Returns a copy of the 5x5 matrix M in which the lane at zero-based
%% (X, Y) is replaced by V. M must be a 5-tuple and X an integer in 0..4;
%% otherwise the call raises function_clause.
mput(M, X, Y, V) when tuple_size(M) =:= 5, is_integer(X), X >= 0, X =< 4 ->
	Column = element(X + 1, M),
	setelement(X + 1, M, mput(Column, Y, V)).

%% @private
%% Returns a copy of the 5-tuple T with the zero-based position I replaced
%% by V. I must be an integer in 0..4; otherwise the call raises
%% function_clause.
mput(T, I, V) when tuple_size(T) =:= 5, is_integer(I), I >= 0, I =< 4 ->
	setelement(I + 1, T, V).
258