1//
2// Adapted from BearSSL's ctmul64 implementation originally written by Thomas Pornin <pornin@bolet.org>
3
4const std = @import("../std.zig");
5const builtin = @import("builtin");
6const assert = std.debug.assert;
7const math = std.math;
8const mem = std.mem;
9const utils = std.crypto.utils;
10
11/// GHASH is a universal hash function that features multiplication
12/// by a fixed parameter within a Galois field.
13///
14/// It is not a general purpose hash function - The key must be secret, unpredictable and never reused.
15///
16/// GHASH is typically used to compute the authentication tag in the AES-GCM construction.
pub const Ghash = struct {
    pub const block_length: usize = 16;
    pub const mac_length = 16;
    pub const key_length = 16;

    // Running hash value. `y1` corresponds to the first eight bytes of a
    // block (read big-endian), `y0` to the last eight.
    y0: u64 = 0,
    y1: u64 = 0,

    // Key H split into two 64-bit halves (`h1` = key[0..8], `h0` = key[8..16]),
    // their XOR (`h2`), and the bit-reversed counterparts (`*r`).
    h0: u64,
    h1: u64,
    h2: u64,
    h0r: u64,
    h1r: u64,
    h2r: u64,

    // Same layout for H^2. Only initialized when not building in ReleaseSmall.
    hh0: u64 = undefined,
    hh1: u64 = undefined,
    hh2: u64 = undefined,
    hh0r: u64 = undefined,
    hh1r: u64 = undefined,
    hh2r: u64 = undefined,

    // Number of bytes currently waiting in `buf` (always < block_length
    // between calls).
    leftover: usize = 0,
    buf: [block_length]u8 align(16) = undefined,

    /// The six 64-bit carry-less products of one Karatsuba multiplication,
    /// before they are assembled into a 256-bit value and reduced.
    /// The `*_rev` members are products of bit-reversed operands.
    const Partial = struct {
        lo: u64,
        hi: u64,
        mid: u64,
        lo_rev: u64,
        hi_rev: u64,
        mid_rev: u64,

        /// Member-wise XOR of two sets of partial products.
        fn add(a: Partial, b: Partial) Partial {
            return Partial{
                .lo = a.lo ^ b.lo,
                .hi = a.hi ^ b.hi,
                .mid = a.mid ^ b.mid,
                .lo_rev = a.lo_rev ^ b.lo_rev,
                .hi_rev = a.hi_rev ^ b.hi_rev,
                .mid_rev = a.mid_rev ^ b.mid_rev,
            };
        }
    };

    /// A 128-bit accumulator value, split the same way as `y1`/`y0`.
    const Acc = struct {
        y1: u64,
        y0: u64,
    };

    /// Set up a GHASH state for the 16-byte key `key`.
    /// Unless building in ReleaseSmall, H^2 is precomputed as well so that
    /// `blocks` can process two blocks per reduction.
    pub fn init(key: *const [key_length]u8) Ghash {
        const key_hi = mem.readIntBig(u64, key[0..8]);
        const key_lo = mem.readIntBig(u64, key[8..16]);
        const key_hi_rev = @bitReverse(u64, key_hi);
        const key_lo_rev = @bitReverse(u64, key_lo);

        var st = Ghash{
            .h0 = key_lo,
            .h1 = key_hi,
            .h2 = key_lo ^ key_hi,
            .h0r = key_lo_rev,
            .h1r = key_hi_rev,
            .h2r = key_lo_rev ^ key_hi_rev,
        };

        if (builtin.mode != .ReleaseSmall) {
            // Precompute H^2: hashing the key bytes as a single block with a
            // zero accumulator leaves H * H in (y1, y0).
            var squarer = st;
            squarer.update(key);

            const sq_hi = squarer.y1;
            const sq_lo = squarer.y0;
            const sq_hi_rev = @bitReverse(u64, sq_hi);
            const sq_lo_rev = @bitReverse(u64, sq_lo);

            st.hh0 = sq_lo;
            st.hh1 = sq_hi;
            st.hh2 = sq_lo ^ sq_hi;
            st.hh0r = sq_lo_rev;
            st.hh1r = sq_hi_rev;
            st.hh2r = sq_lo_rev ^ sq_hi_rev;
        }

        return st;
    }

    /// Carry-less multiplication using the x86 `vpclmulqdq` instruction.
    /// Returns lane 0 (64 bits) of the 128-bit result.
    inline fn clmul_pclmul(x: u64, y: u64) u64 {
        const Lanes = std.meta.Vector(2, u64);
        const wide = asm (
            \\ vpclmulqdq $0x00, %[x], %[y], %[out]
            : [out] "=x" (-> Lanes),
            : [x] "x" (@bitCast(Lanes, @as(u128, x))),
              [y] "x" (@bitCast(Lanes, @as(u128, y))),
        );
        return wide[0];
    }

    /// Carry-less multiplication using the AArch64 `pmull` instruction.
    /// Returns lane 0 (64 bits) of the 128-bit result.
    inline fn clmul_pmull(x: u64, y: u64) u64 {
        const Lanes = std.meta.Vector(2, u64);
        const wide = asm (
            \\ pmull %[out].1q, %[x].1d, %[y].1d
            : [out] "=w" (-> Lanes),
            : [x] "w" (@bitCast(Lanes, @as(u128, x))),
              [y] "w" (@bitCast(Lanes, @as(u128, y))),
        );
        return wide[0];
    }

    /// Portable fallback used when neither hardware instruction is selected.
    /// Each operand is split into four classes of bits (every fourth bit),
    /// the classes are multiplied with wrapping integer multiplications, and
    /// the results are masked back to their class before being merged.
    fn clmul_soft(x: u64, y: u64) u64 {
        const m0: u64 = 0x1111111111111111;
        const m1: u64 = 0x2222222222222222;
        const m2: u64 = 0x4444444444444444;
        const m3: u64 = 0x8888888888888888;

        const x0 = x & m0;
        const x1 = x & m1;
        const x2 = x & m2;
        const x3 = x & m3;
        const y0 = y & m0;
        const y1 = y & m1;
        const y2 = y & m2;
        const y3 = y & m3;

        const z0 = ((x0 *% y0) ^ (x1 *% y3) ^ (x2 *% y2) ^ (x3 *% y1)) & m0;
        const z1 = ((x0 *% y1) ^ (x1 *% y0) ^ (x2 *% y3) ^ (x3 *% y2)) & m1;
        const z2 = ((x0 *% y2) ^ (x1 *% y1) ^ (x2 *% y0) ^ (x3 *% y3)) & m2;
        const z3 = ((x0 *% y3) ^ (x1 *% y2) ^ (x2 *% y1) ^ (x3 *% y0)) & m3;

        return z0 | z1 | z2 | z3;
    }

    const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
    const has_avx = std.Target.x86.featureSetHas(builtin.cpu.features, .avx);
    const has_armaes = std.Target.aarch64.featureSetHas(builtin.cpu.features, .aes);

    // Compile-time selection of the carry-less multiplication routine for
    // the build target.
    const clmul = if (builtin.cpu.arch == .x86_64 and has_pclmul and has_avx)
        clmul_pclmul
    else if (builtin.cpu.arch == .aarch64 and has_armaes)
        clmul_pmull
    else
        clmul_soft;

    /// Karatsuba multiplication of the 128-bit value (`hi`, `lo`) by H, or by
    /// H^2 when `squared` is true, without the final reduction.
    /// `squared` is comptime, so the operand selection costs nothing at runtime.
    fn multiply(st: *const Ghash, comptime squared: bool, hi: u64, lo: u64) Partial {
        const k0 = if (squared) st.hh0 else st.h0;
        const k1 = if (squared) st.hh1 else st.h1;
        const k2 = if (squared) st.hh2 else st.h2;
        const k0r = if (squared) st.hh0r else st.h0r;
        const k1r = if (squared) st.hh1r else st.h1r;
        const k2r = if (squared) st.hh2r else st.h2r;

        const hi_rev = @bitReverse(u64, hi);
        const lo_rev = @bitReverse(u64, lo);

        const p_lo = clmul(lo, k0);
        const p_hi = clmul(hi, k1);
        const p_lo_rev = clmul(lo_rev, k0r);
        const p_hi_rev = clmul(hi_rev, k1r);

        return Partial{
            .lo = p_lo,
            .hi = p_hi,
            .mid = clmul(lo ^ hi, k2) ^ p_lo ^ p_hi,
            .lo_rev = p_lo_rev,
            .hi_rev = p_hi_rev,
            .mid_rev = clmul(lo_rev ^ hi_rev, k2r) ^ p_lo_rev ^ p_hi_rev,
        };
    }

    /// Shift & reduce: assemble the partial products into four 64-bit words,
    /// shift the whole value left by one bit and fold it back to 128 bits.
    fn reduce(p: Partial) Acc {
        // Products of bit-reversed operands are reversed back and shifted
        // right by one to serve as the upper halves.
        const lo_up = @bitReverse(u64, p.lo_rev) >> 1;
        const hi_up = @bitReverse(u64, p.hi_rev) >> 1;
        const mid_up = @bitReverse(u64, p.mid_rev) >> 1;

        const v3 = hi_up;
        const v2 = p.hi ^ mid_up;
        const v1 = lo_up ^ p.mid;
        const v0 = p.lo;

        // 256-bit left shift by one.
        const w3 = (v3 << 1) | (v2 >> 63);
        const w2 = (v2 << 1) | (v1 >> 63);
        const w1 = (v1 << 1) | (v0 >> 63);
        const w0 = v0 << 1;

        // Fold the two low words into the two high words.
        const f2 = w2 ^ w0 ^ (w0 >> 1) ^ (w0 >> 2) ^ (w0 >> 7);
        const f1 = w1 ^ (w0 << 63) ^ (w0 << 62) ^ (w0 << 57);

        return Acc{
            .y1 = w3 ^ f1 ^ (f1 >> 1) ^ (f1 >> 2) ^ (f1 >> 7),
            .y0 = f2 ^ (f1 << 63) ^ (f1 << 62) ^ (f1 << 57),
        };
    }

    /// Absorb `msg`, whose length must be a multiple of 16 bytes, into the
    /// accumulator.
    fn blocks(st: *Ghash, msg: []const u8) void {
        assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
        var acc = Acc{ .y1 = st.y1, .y0 = st.y0 };
        var off: usize = 0;

        // 2-blocks aggregated reduction
        if (builtin.mode != .ReleaseSmall) {
            while (off + 32 <= msg.len) : (off += 32) {
                const pair = msg[off..][0..32];

                // B0 * H^2 unreduced
                const b0_hi = acc.y1 ^ mem.readIntBig(u64, pair[0..8]);
                const b0_lo = acc.y0 ^ mem.readIntBig(u64, pair[8..16]);
                const first = st.multiply(true, b0_hi, b0_lo);

                // B1 * H unreduced
                const b1_hi = mem.readIntBig(u64, pair[16..24]);
                const b1_lo = mem.readIntBig(u64, pair[24..32]);
                const second = st.multiply(false, b1_hi, b1_lo);

                // ((B0 * H^2) + B1 * H) (mod M)
                acc = reduce(first.add(second));
            }
        }

        // single block
        while (off + 16 <= msg.len) : (off += 16) {
            const block = msg[off..][0..16];
            const b_hi = acc.y1 ^ mem.readIntBig(u64, block[0..8]);
            const b_lo = acc.y0 ^ mem.readIntBig(u64, block[8..16]);
            acc = reduce(st.multiply(false, b_hi, b_lo));
        }

        st.y1 = acc.y1;
        st.y0 = acc.y0;
    }

    /// Absorb `m` into the state. Input may be supplied in pieces of any
    /// length; bytes that do not complete a block are buffered until more
    /// data arrives or `pad`/`final` is called.
    pub fn update(st: *Ghash, m: []const u8) void {
        var rest = m;

        // Complete a previously buffered partial block first.
        if (st.leftover > 0) {
            const take = math.min(block_length - st.leftover, rest.len);
            mem.copy(u8, st.buf[st.leftover..], rest[0..take]);
            rest = rest[take..];
            st.leftover += take;
            if (st.leftover < block_length) return;
            st.blocks(&st.buf);
            st.leftover = 0;
        }

        // Process every whole block directly from the input.
        if (rest.len >= block_length) {
            const whole = rest.len & ~(block_length - 1);
            st.blocks(rest[0..whole]);
            rest = rest[whole..];
        }

        // Keep the tail for later.
        if (rest.len > 0) {
            mem.copy(u8, st.buf[st.leftover..], rest);
            st.leftover += rest.len;
        }
    }

    /// Zero-pad to align the next input to the first byte of a block
    pub fn pad(st: *Ghash) void {
        if (st.leftover == 0) return;

        mem.set(u8, st.buf[st.leftover..], 0);
        st.blocks(&st.buf);
        st.leftover = 0;
    }

    /// Pad any buffered input, write the 16-byte result to `out`, then wipe
    /// the whole state. The state must be re-initialized before further use.
    pub fn final(st: *Ghash, out: *[mac_length]u8) void {
        st.pad();
        mem.writeIntBig(u64, out[0..8], st.y1);
        mem.writeIntBig(u64, out[8..16], st.y0);

        const state_bytes = @ptrCast([*]u8, st)[0..@sizeOf(Ghash)];
        utils.secureZero(u8, state_bytes);
    }

    /// One-shot helper: compute the GHASH of `msg` under `key` into `out`.
    pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void {
        var hasher = Ghash.init(key);
        hasher.update(msg);
        hasher.final(out);
    }
};
314
315const htest = @import("test.zig");
316
test "ghash" {
    const expected_hex = "889295fa746e8b174bf4ec80a65dea41";
    const key = [_]u8{0x42} ** 16;
    const msg = [_]u8{0x69} ** 256;
    var tag: [16]u8 = undefined;

    // Whole message absorbed by a single update() call.
    var one_shot = Ghash.init(&key);
    one_shot.update(&msg);
    one_shot.final(&tag);
    try htest.assertEqual(expected_hex, &tag);

    // Same message split at byte 100 (not a multiple of the block length)
    // must produce the same result.
    var streamed = Ghash.init(&key);
    streamed.update(msg[0..100]);
    streamed.update(msg[100..]);
    streamed.final(&tag);
    try htest.assertEqual(expected_hex, &tag);
}
333