/*
 * AArch64 specific AES acceleration.
 * SPDX-License-Identifier: GPL-2.0-or-later
 */

#ifndef AARCH64_HOST_CRYPTO_AES_ROUND_H
#define AARCH64_HOST_CRYPTO_AES_ROUND_H

#include "host/cpuinfo.h"
#include <arm_neon.h>

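/*
 * With __ARM_FEATURE_AES the compiler already targets a CPU that has the
 * AES instructions, so acceleration is unconditionally available; otherwise
 * it is a runtime check against the host cpuinfo bits.  When the intrinsics
 * are used without such a baseline (CONFIG_ARM_AES_BUILTIN and no
 * __ARM_FEATURE_AES), ATTR_AES_ACCEL enables "+crypto" for just the
 * annotated functions instead of the whole translation unit.
 */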
#ifdef __ARM_FEATURE_AES
# define HAVE_AES_ACCEL  true
#else
# define HAVE_AES_ACCEL  likely(cpuinfo & CPUINFO_AES)
#endif
#if !defined(__ARM_FEATURE_AES) && defined(CONFIG_ARM_AES_BUILTIN)
# define ATTR_AES_ACCEL  __attribute__((target("+crypto")))
#else
# define ATTR_AES_ACCEL
#endif

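/* Reverse the byte order of a 128-bit vector with a single TBL lookup. */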
static inline uint8x16_t aes_accel_bswap(uint8x16_t x)
{
    return vqtbl1q_u8(x, (uint8x16_t){ 15, 14, 13, 12, 11, 10,  9,  8,
                                        7,  6,  5,  4,  3,  2,  1,  0, });
}

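/*
 * With CONFIG_ARM_AES_BUILTIN the ACLE AES intrinsics are available and the
 * helpers map directly onto them; otherwise fall back to inline assembly.
 */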
#ifdef CONFIG_ARM_AES_BUILTIN
# define aes_accel_aesd            vaesdq_u8
# define aes_accel_aese            vaeseq_u8
# define aes_accel_aesmc           vaesmcq_u8
# define aes_accel_aesimc          vaesimcq_u8
# define aes_accel_aesd_imc(S, K)  vaesimcq_u8(vaesdq_u8(S, K))
# define aes_accel_aese_mc(S, K)   vaesmcq_u8(vaeseq_u8(S, K))
#else
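/*
 * Inline assembly fallback.  The ".arch_extension aes" directive makes the
 * assembler accept the AES instructions even though the translation unit
 * is not built with the crypto extension enabled.
 */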
static inline uint8x16_t aes_accel_aesd(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aesd %0.16b, %1.16b" : "+w"(d) : "w"(k));
    return d;
}

static inline uint8x16_t aes_accel_aese(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aese %0.16b, %1.16b" : "+w"(d) : "w"(k));
    return d;
}

static inline uint8x16_t aes_accel_aesmc(uint8x16_t d)
{
    asm(".arch_extension aes\n\t"
        "aesmc %0.16b, %1.16b" : "=w"(d) : "w"(d));
    return d;
}

static inline uint8x16_t aes_accel_aesimc(uint8x16_t d)
{
    asm(".arch_extension aes\n\t"
        "aesimc %0.16b, %1.16b" : "=w"(d) : "w"(d));
    return d;
}

/* Most CPUs fuse AESD+AESIMC in the execution pipeline. */
static inline uint8x16_t aes_accel_aesd_imc(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aesd %0.16b, %1.16b\n\t"
        "aesimc %0.16b, %0.16b" : "+w"(d) : "w"(k));
    return d;
}

/* Most CPUs fuse AESE+AESMC in the execution pipeline. */
static inline uint8x16_t aes_accel_aese_mc(uint8x16_t d, uint8x16_t k)
{
    asm(".arch_extension aes\n\t"
        "aese %0.16b, %1.16b\n\t"
        "aesmc %0.16b, %0.16b" : "+w"(d) : "w"(k));
    return d;
}
#endif /* CONFIG_ARM_AES_BUILTIN */

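/*
 * The *_accel helpers below implement individual AES round steps on an
 * AESState.  AESE/AESD add the round key *before* (Inv)ShiftRows and
 * (Inv)SubBytes, so passing an all-zero key isolates the S-box and shift
 * steps; the real round key is xor-ed in separately where the required
 * ordering differs.  When "be" is set, the state (and, where needed, the
 * round key) is byte-reversed around the instructions.
 */
/* MixColumns only. */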
static inline void ATTR_AES_ACCEL
aesenc_MC_accel(AESState *ret, const AESState *st, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesmc(t);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesmc(t);
    }
    ret->v = (AESStateVec)t;
}

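/*
 * SubBytes + ShiftRows + AddRoundKey: AESE with an all-zero round key,
 * then xor in the real round key.
 */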
static inline void ATTR_AES_ACCEL
aesenc_SB_SR_AK_accel(AESState *ret, const AESState *st,
                      const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aese(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aese(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

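/*
 * SubBytes + ShiftRows + MixColumns + AddRoundKey, using the fused
 * AESE+AESMC pair before the final xor with the round key.
 */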
static inline void ATTR_AES_ACCEL
aesenc_SB_SR_MC_AK_accel(AESState *ret, const AESState *st,
                         const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aese_mc(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aese_mc(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

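/* InvMixColumns only. */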
static inline void ATTR_AES_ACCEL
aesdec_IMC_accel(AESState *ret, const AESState *st, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesimc(t);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesimc(t);
    }
    ret->v = (AESStateVec)t;
}

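/*
 * InvSubBytes + InvShiftRows + AddRoundKey: AESD with an all-zero round
 * key, then xor in the real round key.
 */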
static inline void ATTR_AES_ACCEL
aesdec_ISB_ISR_AK_accel(AESState *ret, const AESState *st,
                        const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesd(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesd(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

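/*
 * InvSubBytes + InvShiftRows + AddRoundKey + InvMixColumns.  Here the key
 * is added before AESIMC, so in the big-endian case the round key must be
 * byte-swapped into the same domain as the state before the xor.
 */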
static inline void ATTR_AES_ACCEL
aesdec_ISB_ISR_AK_IMC_accel(AESState *ret, const AESState *st,
                            const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t k = (uint8x16_t)rk->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        k = aes_accel_bswap(k);
        t = aes_accel_aesd(t, z);
        t ^= k;
        t = aes_accel_aesimc(t);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesd(t, z);
        t ^= k;
        t = aes_accel_aesimc(t);
    }
    ret->v = (AESStateVec)t;
}

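/*
 * InvSubBytes + InvShiftRows + InvMixColumns via the fused AESD+AESIMC
 * pair, then AddRoundKey with the real round key.
 */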
static inline void ATTR_AES_ACCEL
aesdec_ISB_ISR_IMC_AK_accel(AESState *ret, const AESState *st,
                            const AESState *rk, bool be)
{
    uint8x16_t t = (uint8x16_t)st->v;
    uint8x16_t z = { };

    if (be) {
        t = aes_accel_bswap(t);
        t = aes_accel_aesd_imc(t, z);
        t = aes_accel_bswap(t);
    } else {
        t = aes_accel_aesd_imc(t, z);
    }
    ret->v = (AESStateVec)t ^ rk->v;
}

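/*
 * Illustrative caller sketch, not part of this header: a dispatcher would
 * typically test HAVE_AES_ACCEL and otherwise fall back to a portable
 * implementation.  The fallback name below is hypothetical.
 *
 *     static void aesenc_SB_SR_AK(AESState *r, const AESState *st,
 *                                 const AESState *rk, bool be)
 *     {
 *         if (HAVE_AES_ACCEL) {
 *             aesenc_SB_SR_AK_accel(r, st, rk, be);
 *         } else {
 *             aesenc_SB_SR_AK_generic(r, st, rk, be);  // hypothetical
 *         }
 *     }
 */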
#endif /* AARCH64_HOST_CRYPTO_AES_ROUND_H */