/*
 * AArch64 specific aes acceleration.
 * SPDX-License-Identifier: GPL-2.0-or-later
 */

#ifndef AARCH64_HOST_CRYPTO_AES_ROUND_H
#define AARCH64_HOST_CRYPTO_AES_ROUND_H

#include "host/cpuinfo.h"
#include <arm_neon.h>

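/*
 * HAVE_AES_ACCEL is a compile-time constant when the compiler already
 * targets the AES extension; otherwise it is a runtime test of the host
 * via cpuinfo.  ATTR_AES_ACCEL enables the crypto extension for just the
 * functions below when the translation unit as a whole is built without it.
 */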
#ifdef __ARM_FEATURE_AES
# define HAVE_AES_ACCEL  true
#else
# define HAVE_AES_ACCEL  likely(cpuinfo & CPUINFO_AES)
#endif
#if !defined(__ARM_FEATURE_AES) && defined(CONFIG_ARM_AES_BUILTIN)
# define ATTR_AES_ACCEL  __attribute__((target("+crypto")))
#else
# define ATTR_AES_ACCEL
#endif

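/*
 * Reverse the bytes of the 128-bit vector; used below to convert between
 * the byte order the AES instructions operate on and the byte-reversed
 * (big-endian) state layout selected by the 'be' flag.
 */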
static inline uint8x16_t aes_accel_bswap(uint8x16_t x)
{
    return vqtbl1q_u8(x, (uint8x16_t){ 15, 14, 13, 12, 11, 10, 9, 8,
                                        7,  6,  5,  4,  3,  2, 1, 0, });
}

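/*
 * With CONFIG_ARM_AES_BUILTIN, use the ACLE intrinsics directly; otherwise
 * emit the instructions via inline asm, using .arch_extension so that the
 * assembler accepts them even when the compiler's baseline lacks AES.
 */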
#ifdef CONFIG_ARM_AES_BUILTIN
# define aes_accel_aesd   vaesdq_u8
# define aes_accel_aese   vaeseq_u8
# define aes_accel_aesmc  vaesmcq_u8
# define aes_accel_aesimc vaesimcq_u8
# define aes_accel_aesd_imc(S, K) vaesimcq_u8(vaesdq_u8(S, K))
# define aes_accel_aese_mc(S, K)  vaesmcq_u8(vaeseq_u8(S, K))
#else
static inline uint8x16_t aes_accel_aesd(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aesd %0.16b, %1.16b" : "+w"(d) : "w"(k));
    return d;
}

static inline uint8x16_t aes_accel_aese(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aese %0.16b, %1.16b" : "+w"(d) : "w"(k));
    return d;
}

static inline uint8x16_t aes_accel_aesmc(uint8x16_t d)
{
    asm(".arch_extension aes\n\t"
        "aesmc %0.16b, %1.16b" : "=w"(d) : "w"(d));
    return d;
}

static inline uint8x16_t aes_accel_aesimc(uint8x16_t d)
{
    asm(".arch_extension aes\n\t"
        "aesimc %0.16b, %1.16b" : "=w"(d) : "w"(d));
    return d;
}

/* Most CPUs fuse AESD+AESIMC in the execution pipeline. */
static inline uint8x16_t aes_accel_aesd_imc(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aesd %0.16b, %1.16b\n\t"
        "aesimc %0.16b, %0.16b" : "+w"(d) : "w"(k));
    return d;
}

/* Most CPUs fuse AESE+AESMC in the execution pipeline. */
static inline uint8x16_t aes_accel_aese_mc(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aese %0.16b, %1.16b\n\t"
        "aesmc %0.16b, %0.16b" : "+w"(d) : "w"(k));
    return d;
}
#endif /* CONFIG_ARM_AES_BUILTIN */

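/*
 * The helpers below implement partial AES round transformations; the name
 * encodes the operations performed: SB = SubBytes, SR = ShiftRows,
 * MC = MixColumns, AK = AddRoundKey, with an I prefix for the inverse.
 * The 'be' flag selects a byte-reversed (big-endian) state layout.
 *
 * AESE/AESD xor their key operand into the state before the S-box, so an
 * all-zero key yields the bare SubBytes+ShiftRows (or inverse) step and
 * the real round key can be xored in separately, in whatever order the
 * transformation requires.
 */

/* MixColumns. */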
static inline void ATTR_AES_ACCEL
aesenc_MC_accel(AESState *ret, const AESState *st, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesmc(t);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesmc(t);
    }
    ret->v = (AESStateVec)t;
}

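/* SubBytes + ShiftRows + AddRoundKey. */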
static inline void ATTR_AES_ACCEL
aesenc_SB_SR_AK_accel(AESState *ret, const AESState *st,
                      const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aese(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aese(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

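/* SubBytes + ShiftRows + MixColumns + AddRoundKey. */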
static inline void ATTR_AES_ACCEL
aesenc_SB_SR_MC_AK_accel(AESState *ret, const AESState *st,
                         const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aese_mc(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aese_mc(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

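/* InvMixColumns. */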
static inline void ATTR_AES_ACCEL
aesdec_IMC_accel(AESState *ret, const AESState *st, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesimc(t);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesimc(t);
    }
    ret->v = (AESStateVec)t;
}

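/* InvSubBytes + InvShiftRows + AddRoundKey. */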
static inline void ATTR_AES_ACCEL
aesdec_ISB_ISR_AK_accel(AESState *ret, const AESState *st,
                        const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesd(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesd(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

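/*
 * InvSubBytes + InvShiftRows + AddRoundKey + InvMixColumns.  Here the
 * round key is xored in before InvMixColumns, so in big-endian mode the
 * key must be byte-reversed as well.
 */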
static inline void ATTR_AES_ACCEL
aesdec_ISB_ISR_AK_IMC_accel(AESState *ret, const AESState *st,
                            const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t k = (uint8x16_t)rk->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        k = aes_accel_bswap(k);
        t = aes_accel_aesd(t, z);
        t ^= k;
        t = aes_accel_aesimc(t);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesd(t, z);
        t ^= k;
        t = aes_accel_aesimc(t);
    }
    ret->v = (AESStateVec)t;
}

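/* InvSubBytes + InvShiftRows + InvMixColumns + AddRoundKey. */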
static inline void ATTR_AES_ACCEL
aesdec_ISB_ISR_IMC_AK_accel(AESState *ret, const AESState *st,
                            const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesd_imc(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesd_imc(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

#endif /* AARCH64_HOST_CRYPTO_AES_ROUND_H */