1 /*
2 * Copyright 2019 The libgav1 Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_
18 #define LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_
19
20 #include "src/utils/compiler_attributes.h"
21 #include "src/utils/cpu.h"
22
23 #if LIBGAV1_ENABLE_SSE4_1
24
25 #include <emmintrin.h>
26 #include <smmintrin.h>
27
28 #include <cassert>
29 #include <cstddef>
30 #include <cstdint>
31 #include <cstdlib>
32 #include <cstring>
33
#if 0
#include <cinttypes>
#include <cstdio>

// Quite useful macro for debugging. Left here for convenience.
// Prints the 128-bit register |r| to stderr as unsigned hex lanes of width
// |size| bits; |size| may be 8, 16, or 32, and any other value prints 64-bit
// lanes.
inline void PrintReg(const __m128i r, const char* const name, int size) {
  int n;
  // Reinterpret the register's bytes at each supported lane width.
  union {
    __m128i r;
    uint8_t i8[16];
    uint16_t i16[8];
    uint32_t i32[4];
    uint64_t i64[2];
  } tmp;
  tmp.r = r;
  fprintf(stderr, "%s\t: ", name);
  if (size == 8) {
    for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", tmp.i8[n]);
  } else if (size == 16) {
    for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", tmp.i16[n]);
  } else if (size == 32) {
    for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", tmp.i32[n]);
  } else {
    for (n = 0; n < 2; ++n)
      fprintf(stderr, "%.16" PRIx64 " ", static_cast<uint64_t>(tmp.i64[n]));
  }
  fprintf(stderr, "\n");
}

// Prints the scalar |r| to stderr in decimal.
inline void PrintReg(const int r, const char* const name) {
  fprintf(stderr, "%s: %d\n", name, r);
}

// Prints the scalar |r| to stderr in zero-padded hexadecimal.
inline void PrintRegX(const int r, const char* const name) {
  fprintf(stderr, "%s: %.8x\n", name, r);
}

#define PR(var, N) PrintReg(var, #var, N)
#define PD(var) PrintReg(var, #var);
#define PX(var) PrintRegX(var, #var);
#endif  // 0
75
76 namespace libgav1 {
77 namespace dsp {
78
79 //------------------------------------------------------------------------------
80 // Load functions.
81
// Loads 2 bytes as a 16-bit value into the low lane of the result; the value
// is sign-extended through the low 32-bit lane (matching
// _mm_cvtsi32_si128's int argument) and the upper lanes are zeroed.
inline __m128i Load2(const void* src) {
  int16_t word;
  memcpy(&word, src, sizeof(word));
  const int widened = word;  // Sign-extends, as the original promotion did.
  return _mm_cvtsi32_si128(widened);
}
87
// Loads one uint16_t from each of |src1| and |src2| and packs them into the
// low 32-bit lane of the result (|src1| in the low half, |src2| in the high
// half). The remaining lanes are zeroed.
inline __m128i Load2x2(const void* src1, const void* src2) {
  uint16_t val1;
  uint16_t val2;
  memcpy(&val1, src1, sizeof(val1));
  memcpy(&val2, src2, sizeof(val2));
  // Combine in unsigned arithmetic: the previous expression shifted a
  // promoted (signed) int left by 16, which is undefined behavior when
  // val2 >= 0x8000. The unsigned form is well defined for all inputs.
  const uint32_t packed =
      static_cast<uint32_t>(val1) | (static_cast<uint32_t>(val2) << 16);
  return _mm_cvtsi32_si128(static_cast<int>(packed));
}
95
// Loads 2 uint8_t values into 16-bit lane |lane| of |val| (i.e. bytes
// |lane| * 2 and |lane| * 2 + 1); all other lanes pass through unchanged.
template <int lane>
inline __m128i Load2(const void* const buf, __m128i val) {
  uint16_t word;
  memcpy(&word, buf, sizeof(word));
  return _mm_insert_epi16(val, word, lane);
}
103
// Loads 4 bytes into the low 32-bit lane of the result; upper lanes are
// zeroed.
inline __m128i Load4(const void* src) {
  // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32
  // intrinsic; both it and this memcpy form compile to a movss instruction.
  // Until compiler support of _mm_loadu_si32 is widespread, its use is
  // banned.
  int bits;
  memcpy(&bits, src, sizeof(bits));
  return _mm_cvtsi32_si128(bits);
}
115
Load4x2(const void * src1,const void * src2)116 inline __m128i Load4x2(const void* src1, const void* src2) {
117 // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32
118 // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a
119 // movss instruction.
120 //
121 // Until compiler support of _mm_loadu_si32 is widespread, use of
122 // _mm_loadu_si32 is banned.
123 int val1, val2;
124 memcpy(&val1, src1, sizeof(val1));
125 memcpy(&val2, src2, sizeof(val2));
126 return _mm_insert_epi32(_mm_cvtsi32_si128(val1), val2, 1);
127 }
128
// Loads 8 bytes from |a| into the low half of the result; the high half is
// zeroed.
inline __m128i LoadLo8(const void* a) {
  const __m128i* const src = static_cast<const __m128i*>(a);
  return _mm_loadl_epi64(src);
}
132
// Replaces the high 8 bytes of |v| with 8 bytes loaded from |a|; the low 8
// bytes of |v| are preserved. The float cast round-trip is bit-exact: it only
// satisfies _mm_loadh_pi's operand types.
inline __m128i LoadHi8(const __m128i v, const void* a) {
  const __m64* const src = static_cast<const __m64*>(a);
  return _mm_castps_si128(_mm_loadh_pi(_mm_castsi128_ps(v), src));
}
138
// Loads 16 bytes from |a|; |a| need not be aligned.
inline __m128i LoadUnaligned16(const void* a) {
  const __m128i* const src = static_cast<const __m128i*>(a);
  return _mm_loadu_si128(src);
}
142
// Loads 16 bytes from |a|, which must be 16-byte aligned.
inline __m128i LoadAligned16(const void* a) {
  // _mm_load_si128 faults on misaligned addresses; catch that in debug
  // builds.
  assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0);
  const __m128i* const src = static_cast<const __m128i*>(a);
  return _mm_load_si128(src);
}
147
148 //------------------------------------------------------------------------------
149 // Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.
150
// Returns |source| with its |over_read_in_bytes| most significant bytes
// zeroed — but only in MemorySanitizer builds, where it silences
// use-of-uninitialized-value reports from deliberate over-reads. In all other
// builds |source| is returned unchanged.
inline __m128i MaskOverreads(const __m128i source,
                             const ptrdiff_t over_read_in_bytes) {
  __m128i result = source;
#if LIBGAV1_MSAN
  if (over_read_in_bytes > 0) {
    // Build a mask whose top |over_read_in_bytes| bytes are zero by shifting
    // an all-ones register right one byte at a time. _mm_srli_si128 requires
    // an immediate shift amount, hence the loop.
    __m128i keep = _mm_set1_epi8(-1);
    for (ptrdiff_t i = 0; i < over_read_in_bytes; ++i) {
      keep = _mm_srli_si128(keep, 1);
    }
    result = _mm_and_si128(result, keep);
  }
#else
  static_cast<void>(over_read_in_bytes);
#endif
  return result;
}
167
LoadLo8Msan(const void * const source,const ptrdiff_t over_read_in_bytes)168 inline __m128i LoadLo8Msan(const void* const source,
169 const ptrdiff_t over_read_in_bytes) {
170 return MaskOverreads(LoadLo8(source), over_read_in_bytes + 8);
171 }
172
LoadHi8Msan(const __m128i v,const void * source,const ptrdiff_t over_read_in_bytes)173 inline __m128i LoadHi8Msan(const __m128i v, const void* source,
174 const ptrdiff_t over_read_in_bytes) {
175 return MaskOverreads(LoadHi8(v, source), over_read_in_bytes);
176 }
177
LoadAligned16Msan(const void * const source,const ptrdiff_t over_read_in_bytes)178 inline __m128i LoadAligned16Msan(const void* const source,
179 const ptrdiff_t over_read_in_bytes) {
180 return MaskOverreads(LoadAligned16(source), over_read_in_bytes);
181 }
182
LoadUnaligned16Msan(const void * const source,const ptrdiff_t over_read_in_bytes)183 inline __m128i LoadUnaligned16Msan(const void* const source,
184 const ptrdiff_t over_read_in_bytes) {
185 return MaskOverreads(LoadUnaligned16(source), over_read_in_bytes);
186 }
187
188 //------------------------------------------------------------------------------
189 // Store functions.
190
// Stores the low 2 bytes of |x| to |dst|.
inline void Store2(void* dst, const __m128i x) {
  const int low32 = _mm_cvtsi128_si32(x);
  memcpy(dst, &low32, 2);
}
195
// Stores the low 4 bytes of |x| to |dst|.
inline void Store4(void* dst, const __m128i x) {
  const int low32 = _mm_cvtsi128_si32(x);
  memcpy(dst, &low32, sizeof(low32));
}
200
// Stores the low 8 bytes of |v| to |a|.
inline void StoreLo8(void* a, const __m128i v) {
  __m128i* const dst = static_cast<__m128i*>(a);
  _mm_storel_epi64(dst, v);
}
204
// Stores the high 8 bytes of |v| to |a|. The float cast round-trip is
// bit-exact: it only satisfies _mm_storeh_pi's operand types.
inline void StoreHi8(void* a, const __m128i v) {
  __m64* const dst = static_cast<__m64*>(a);
  _mm_storeh_pi(dst, _mm_castsi128_ps(v));
}
208
// Stores all 16 bytes of |v| to |a|, which must be 16-byte aligned.
inline void StoreAligned16(void* a, const __m128i v) {
  __m128i* const dst = static_cast<__m128i*>(a);
  _mm_store_si128(dst, v);
}
212
// Stores all 16 bytes of |v| to |a|; |a| need not be aligned.
inline void StoreUnaligned16(void* a, const __m128i v) {
  __m128i* const dst = static_cast<__m128i*>(a);
  _mm_storeu_si128(dst, v);
}
216
217 //------------------------------------------------------------------------------
218 // Arithmetic utilities.
219
// Shifts each unsigned 16-bit lane of |v_val_d| right by |bits|, rounding
// half up (adds 2^(bits-1) before the shift).
inline __m128i RightShiftWithRounding_U16(const __m128i v_val_d, int bits) {
  assert(bits <= 16);
  // The cast keeps the bias's bit pattern intact when bits == 16.
  const int16_t rounding = static_cast<int16_t>((1 << bits) >> 1);
  const __m128i v_rounded = _mm_add_epi16(v_val_d, _mm_set1_epi16(rounding));
  return _mm_srli_epi16(v_rounded, bits);
}
227
// Shifts each signed 16-bit lane of |v_val_d| right (arithmetically) by
// |bits|, rounding half up (adds 2^(bits-1) before the shift).
inline __m128i RightShiftWithRounding_S16(const __m128i v_val_d, int bits) {
  assert(bits <= 16);
  // The cast keeps the bias's bit pattern intact when bits == 16.
  const int16_t rounding = static_cast<int16_t>((1 << bits) >> 1);
  const __m128i v_rounded = _mm_add_epi16(v_val_d, _mm_set1_epi16(rounding));
  return _mm_srai_epi16(v_rounded, bits);
}
235
// Shifts each unsigned 32-bit lane of |v_val_d| right by |bits|, rounding
// half up (adds 2^(bits-1) before the shift).
inline __m128i RightShiftWithRounding_U32(const __m128i v_val_d, int bits) {
  const __m128i v_bias = _mm_set1_epi32((1 << bits) >> 1);
  const __m128i v_rounded = _mm_add_epi32(v_val_d, v_bias);
  return _mm_srli_epi32(v_rounded, bits);
}
241
// Shifts each signed 32-bit lane of |v_val_d| right (arithmetically) by
// |bits|, rounding half up (adds 2^(bits-1) before the shift).
inline __m128i RightShiftWithRounding_S32(const __m128i v_val_d, int bits) {
  const __m128i v_bias = _mm_set1_epi32((1 << bits) >> 1);
  const __m128i v_rounded = _mm_add_epi32(v_val_d, v_bias);
  return _mm_srai_epi32(v_rounded, bits);
}
247
248 //------------------------------------------------------------------------------
249 // Masking utilities
MaskHighNBytes(int n)250 inline __m128i MaskHighNBytes(int n) {
251 static constexpr uint8_t kMask[32] = {
252 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
253 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255,
254 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255 };
256
257 return LoadUnaligned16(kMask + n);
258 }
259
260 } // namespace dsp
261 } // namespace libgav1
262
263 #endif // LIBGAV1_ENABLE_SSE4_1
264 #endif // LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_
265