//
// Adapted from BearSSL's ctmul64 implementation originally written by Thomas Pornin <pornin@bolet.org>

const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const math = std.math;
const mem = std.mem;
const utils = std.crypto.utils;

/// GHASH is a universal hash function that features multiplication
/// by a fixed parameter within a Galois field.
///
/// It is not a general purpose hash function - The key must be secret, unpredictable and never reused.
///
/// GHASH is typically used to compute the authentication tag in the AES-GCM construction.
pub const Ghash = struct {
    pub const block_length: usize = 16;
    pub const mac_length = 16;
    pub const key_length = 16;

    // Running accumulator Y, kept as two 64-bit halves (y1 = high half,
    // y0 = low half — see the big-endian reads in `blocks()` and the
    // write order in `final()`).
    y0: u64 = 0,
    y1: u64 = 0,
    // Hash subkey H split into 64-bit halves, plus h2 = h0 ^ h1, which is the
    // extra operand used by the 3-multiplication Karatsuba decomposition of a
    // 128x128 carryless product (see `blocks()`).
    h0: u64,
    h1: u64,
    h2: u64,
    // Bit-reversed copies of the halves above. Multiplying bit-reversed
    // operands with a 64x64->low-64 carryless multiply recovers the
    // (bit-reversed) HIGH half of the product — the classic trick used by
    // BearSSL's ctmul64, since `clmul` here only returns the low 64 bits.
    h0r: u64,
    h1r: u64,
    h2r: u64,

    // Same layout as h*/h*r above, but for H^2. Only computed (in `init`)
    // and used (in `blocks`) outside ReleaseSmall builds, where two blocks
    // are folded per iteration; left `undefined` otherwise.
    hh0: u64 = undefined,
    hh1: u64 = undefined,
    hh2: u64 = undefined,
    hh0r: u64 = undefined,
    hh1r: u64 = undefined,
    hh2r: u64 = undefined,

    // Number of bytes currently pending in `buf` (always < block_length
    // between calls).
    leftover: usize = 0,
    // Staging buffer for partial input blocks.
    buf: [block_length]u8 align(16) = undefined,

    /// Initialize the GHASH state with a 16-byte key (the hash subkey H,
    /// typically derived by encrypting an all-zero block in AES-GCM).
    /// Outside ReleaseSmall, H^2 is also precomputed so `blocks()` can use
    /// 2-block aggregated reduction.
    pub fn init(key: *const [key_length]u8) Ghash {
        const h1 = mem.readIntBig(u64, key[0..8]);
        const h0 = mem.readIntBig(u64, key[8..16]);
        const h1r = @bitReverse(u64, h1);
        const h0r = @bitReverse(u64, h0);
        const h2 = h0 ^ h1;
        const h2r = h0r ^ h1r;

        if (builtin.mode == .ReleaseSmall) {
            return Ghash{
                .h0 = h0,
                .h1 = h1,
                .h2 = h2,
                .h0r = h0r,
                .h1r = h1r,
                .h2r = h2r,
            };
        } else {
            // Precompute H^2 by hashing H itself (one block through the
            // single-block path of `update`/`blocks` yields Y = H * H).
            var hh = Ghash{
                .h0 = h0,
                .h1 = h1,
                .h2 = h2,
                .h0r = h0r,
                .h1r = h1r,
                .h2r = h2r,
            };
            hh.update(key);
            const hh1 = hh.y1;
            const hh0 = hh.y0;
            const hh1r = @bitReverse(u64, hh1);
            const hh0r = @bitReverse(u64, hh0);
            const hh2 = hh0 ^ hh1;
            const hh2r = hh0r ^ hh1r;

            return Ghash{
                .h0 = h0,
                .h1 = h1,
                .h2 = h2,
                .h0r = h0r,
                .h1r = h1r,
                .h2r = h2r,

                .hh0 = hh0,
                .hh1 = hh1,
                .hh2 = hh2,
                .hh0r = hh0r,
                .hh1r = hh1r,
                .hh2r = hh2r,
            };
        }
    }

    // 64x64 -> low-64 carryless multiply using the x86 PCLMULQDQ instruction
    // (VEX-encoded, hence the additional AVX requirement in the selector
    // below). Immediate 0x00 selects the low qword of each source operand.
    inline fn clmul_pclmul(x: u64, y: u64) u64 {
        const Vector = std.meta.Vector;
        const product = asm (
            \\ vpclmulqdq $0x00, %[x], %[y], %[out]
            : [out] "=x" (-> Vector(2, u64)),
            : [x] "x" (@bitCast(Vector(2, u64), @as(u128, x))),
              [y] "x" (@bitCast(Vector(2, u64), @as(u128, y))),
        );
        return product[0];
    }

    // 64x64 -> low-64 carryless multiply using the AArch64 PMULL instruction
    // (polynomial multiply of the low 64-bit lanes).
    inline fn clmul_pmull(x: u64, y: u64) u64 {
        const Vector = std.meta.Vector;
        const product = asm (
            \\ pmull %[out].1q, %[x].1d, %[y].1d
            : [out] "=w" (-> Vector(2, u64)),
            : [x] "w" (@bitCast(Vector(2, u64), @as(u128, x))),
              [y] "w" (@bitCast(Vector(2, u64), @as(u128, y))),
        );
        return product[0];
    }

    // Constant-time software carryless multiply (low 64 bits), per BearSSL's
    // ctmul64: each operand is split into four interleaved bit groups so
    // that the wrapping integer multiplies below never produce carries that
    // cross between retained bit positions; masking afterwards discards the
    // cross-term garbage. No data-dependent branches or table lookups.
    fn clmul_soft(x: u64, y: u64) u64 {
        const x0 = x & 0x1111111111111111;
        const x1 = x & 0x2222222222222222;
        const x2 = x & 0x4444444444444444;
        const x3 = x & 0x8888888888888888;
        const y0 = y & 0x1111111111111111;
        const y1 = y & 0x2222222222222222;
        const y2 = y & 0x4444444444444444;
        const y3 = y & 0x8888888888888888;
        var z0 = (x0 *% y0) ^ (x1 *% y3) ^ (x2 *% y2) ^ (x3 *% y1);
        var z1 = (x0 *% y1) ^ (x1 *% y0) ^ (x2 *% y3) ^ (x3 *% y2);
        var z2 = (x0 *% y2) ^ (x1 *% y1) ^ (x2 *% y0) ^ (x3 *% y3);
        var z3 = (x0 *% y3) ^ (x1 *% y2) ^ (x2 *% y1) ^ (x3 *% y0);
        z0 &= 0x1111111111111111;
        z1 &= 0x2222222222222222;
        z2 &= 0x4444444444444444;
        z3 &= 0x8888888888888888;
        return z0 | z1 | z2 | z3;
    }

    // Select the carryless-multiply backend at compile time based on the
    // target CPU's feature set; fall back to the portable software version.
    const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
    const has_avx = std.Target.x86.featureSetHas(builtin.cpu.features, .avx);
    const has_armaes = std.Target.aarch64.featureSetHas(builtin.cpu.features, .aes);
    const clmul = if (builtin.cpu.arch == .x86_64 and has_pclmul and has_avx) impl: {
        break :impl clmul_pclmul;
    } else if (builtin.cpu.arch == .aarch64 and has_armaes) impl: {
        break :impl clmul_pmull;
    } else impl: {
        break :impl clmul_soft;
    };

    // Absorb full 16-byte blocks into the accumulator: for each block B,
    // Y <- (Y ^ B) * H in GF(2^128). Each 128x128 product is built from
    // three 64x64 clmuls (Karatsuba, via the h2/hh2 operands) for the low
    // halves plus three more on bit-reversed inputs for the high halves,
    // then reduced. The shift amounts in the reduction (>>1, >>2, >>7 and
    // <<63, <<62, <<57) implement folding by the GCM reduction polynomial
    // x^128 + x^7 + x^2 + x + 1.
    fn blocks(st: *Ghash, msg: []const u8) void {
        assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
        var y1 = st.y1;
        var y0 = st.y0;

        var i: usize = 0;

        // 2-blocks aggregated reduction: compute (Y^B0)*H^2 + B1*H
        // unreduced, then reduce once — halving the number of reductions.
        if (builtin.mode != .ReleaseSmall) {
            while (i + 32 <= msg.len) : (i += 32) {
                // B0 * H^2 unreduced
                y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
                y0 ^= mem.readIntBig(u64, msg[i..][8..16]);

                const y1r = @bitReverse(u64, y1);
                const y0r = @bitReverse(u64, y0);
                const y2 = y0 ^ y1;
                const y2r = y0r ^ y1r;

                var z0 = clmul(y0, st.hh0);
                var z1 = clmul(y1, st.hh1);
                var z2 = clmul(y2, st.hh2) ^ z0 ^ z1;
                var z0h = clmul(y0r, st.hh0r);
                var z1h = clmul(y1r, st.hh1r);
                var z2h = clmul(y2r, st.hh2r) ^ z0h ^ z1h;

                // B1 * H unreduced
                const sy1 = mem.readIntBig(u64, msg[i..][16..24]);
                const sy0 = mem.readIntBig(u64, msg[i..][24..32]);

                const sy1r = @bitReverse(u64, sy1);
                const sy0r = @bitReverse(u64, sy0);
                const sy2 = sy0 ^ sy1;
                const sy2r = sy0r ^ sy1r;

                const sz0 = clmul(sy0, st.h0);
                const sz1 = clmul(sy1, st.h1);
                const sz2 = clmul(sy2, st.h2) ^ sz0 ^ sz1;
                const sz0h = clmul(sy0r, st.h0r);
                const sz1h = clmul(sy1r, st.h1r);
                const sz2h = clmul(sy2r, st.h2r) ^ sz0h ^ sz1h;

                // ((B0 * H^2) + B1 * H) (mod M)
                z0 ^= sz0;
                z1 ^= sz1;
                z2 ^= sz2;
                z0h ^= sz0h;
                z1h ^= sz1h;
                z2h ^= sz2h;
                // Un-reverse the high halves; the >>1 compensates for the
                // one-bit offset inherent to the reversed-multiply trick.
                z0h = @bitReverse(u64, z0h) >> 1;
                z1h = @bitReverse(u64, z1h) >> 1;
                z2h = @bitReverse(u64, z2h) >> 1;

                // Assemble the 256-bit unreduced product into four 64-bit
                // words (v3:v2:v1:v0), folding in the Karatsuba middle term.
                var v3 = z1h;
                var v2 = z1 ^ z2h;
                var v1 = z0h ^ z2;
                var v0 = z0;

                // Shift left by 1 to realign (GCM's reflected bit order),
                // then reduce modulo the GCM polynomial.
                v3 = (v3 << 1) | (v2 >> 63);
                v2 = (v2 << 1) | (v1 >> 63);
                v1 = (v1 << 1) | (v0 >> 63);
                v0 = (v0 << 1);

                v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
                v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
                y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
                y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
            }
        }

        // single block: same multiply-and-reduce as above, for one block
        // at a time (also handles the odd trailing block of the fast path).
        while (i + 16 <= msg.len) : (i += 16) {
            y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
            y0 ^= mem.readIntBig(u64, msg[i..][8..16]);

            const y1r = @bitReverse(u64, y1);
            const y0r = @bitReverse(u64, y0);
            const y2 = y0 ^ y1;
            const y2r = y0r ^ y1r;

            const z0 = clmul(y0, st.h0);
            const z1 = clmul(y1, st.h1);
            var z2 = clmul(y2, st.h2) ^ z0 ^ z1;
            var z0h = clmul(y0r, st.h0r);
            var z1h = clmul(y1r, st.h1r);
            var z2h = clmul(y2r, st.h2r) ^ z0h ^ z1h;
            z0h = @bitReverse(u64, z0h) >> 1;
            z1h = @bitReverse(u64, z1h) >> 1;
            z2h = @bitReverse(u64, z2h) >> 1;

            // shift & reduce
            var v3 = z1h;
            var v2 = z1 ^ z2h;
            var v1 = z0h ^ z2;
            var v0 = z0;

            v3 = (v3 << 1) | (v2 >> 63);
            v2 = (v2 << 1) | (v1 >> 63);
            v1 = (v1 << 1) | (v0 >> 63);
            v0 = (v0 << 1);

            v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
            v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
            y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
            y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
        }
        st.y1 = y1;
        st.y0 = y0;
    }

    /// Absorb message bytes. Input may be provided incrementally in chunks
    /// of any size; partial blocks are buffered in `buf` until 16 bytes are
    /// available.
    pub fn update(st: *Ghash, m: []const u8) void {
        var mb = m;

        // First, top up any partially-filled buffered block.
        if (st.leftover > 0) {
            const want = math.min(block_length - st.leftover, mb.len);
            const mc = mb[0..want];
            for (mc) |x, i| {
                st.buf[st.leftover + i] = x;
            }
            mb = mb[want..];
            st.leftover += want;
            if (st.leftover < block_length) {
                return;
            }
            st.blocks(&st.buf);
            st.leftover = 0;
        }
        // Then process as many whole blocks as possible directly from the
        // input (round mb.len down to a multiple of block_length).
        if (mb.len >= block_length) {
            const want = mb.len & ~(block_length - 1);
            st.blocks(mb[0..want]);
            mb = mb[want..];
        }
        // Stash the remaining tail (< block_length bytes) for later.
        if (mb.len > 0) {
            for (mb) |x, i| {
                st.buf[st.leftover + i] = x;
            }
            st.leftover += mb.len;
        }
    }

    /// Zero-pad to align the next input to the first byte of a block
    pub fn pad(st: *Ghash) void {
        if (st.leftover == 0) {
            return;
        }
        var i = st.leftover;
        while (i < block_length) : (i += 1) {
            st.buf[i] = 0;
        }
        st.blocks(&st.buf);
        st.leftover = 0;
    }

    /// Finalize: zero-pad any pending partial block, write the 16-byte tag
    /// (big-endian, high half first), then wipe the state. The state must
    /// not be reused after calling this.
    pub fn final(st: *Ghash, out: *[mac_length]u8) void {
        st.pad();
        mem.writeIntBig(u64, out[0..8], st.y1);
        mem.writeIntBig(u64, out[8..16], st.y0);

        // Scrub key material and accumulator from memory.
        utils.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]);
    }

    /// One-shot convenience: compute the GHASH of `msg` under `key` into `out`.
    pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void {
        var st = Ghash.init(key);
        st.update(msg);
        st.final(out);
    }
};

const htest = @import("test.zig");

test "ghash" {
    const key = [_]u8{0x42} ** 16;
    const m = [_]u8{0x69} ** 256;

    var st = Ghash.init(&key);
    st.update(&m);
    var out: [16]u8 = undefined;
    st.final(&out);
    try htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);

    // Same input split across two update() calls must yield the same tag
    // (exercises the partial-block buffering path).
    st = Ghash.init(&key);
    st.update(m[0..100]);
    st.update(m[100..]);
    st.final(&out);
    try htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
}