/*
 * Copyright 2019 The libgav1 Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_
#define LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_

#include "src/utils/compiler_attributes.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_SSE4_1

#include <emmintrin.h>
#include <smmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#if 0
#include <cinttypes>
#include <cstdio>

// Quite useful macros for debugging. Left here for convenience.
inline void PrintReg(const __m128i r, const char* const name, int size) {
  int n;
  union {
    __m128i r;
    uint8_t i8[16];
    uint16_t i16[8];
    uint32_t i32[4];
    uint64_t i64[2];
  } tmp;
  tmp.r = r;
  fprintf(stderr, "%s\t: ", name);
  if (size == 8) {
    for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", tmp.i8[n]);
  } else if (size == 16) {
    for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", tmp.i16[n]);
  } else if (size == 32) {
    for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", tmp.i32[n]);
  } else {
    for (n = 0; n < 2; ++n)
      fprintf(stderr, "%.16" PRIx64 " ", static_cast<uint64_t>(tmp.i64[n]));
  }
  fprintf(stderr, "\n");
}

inline void PrintReg(const int r, const char* const name) {
  fprintf(stderr, "%s: %d\n", name, r);
}

inline void PrintRegX(const int r, const char* const name) {
  fprintf(stderr, "%s: %.8x\n", name, r);
}

#define PR(var, N) PrintReg(var, #var, N)
#define PD(var) PrintReg(var, #var);
#define PX(var) PrintRegX(var, #var);
#endif  // 0

namespace libgav1 {
namespace dsp {

//------------------------------------------------------------------------------
// Load functions.

// Load 2 bytes into the low 16 bits of an __m128i. The value is widened to
// int, so it is sign extended into bits 16..31 of the low lane.
inline __m128i Load2(const void* src) {
  int16_t val;
  memcpy(&val, src, sizeof(val));
  return _mm_cvtsi32_si128(val);
}

// Load 2 bytes from each of |src1| and |src2| into the low 32 bits of an
// __m128i, with |src1| in bits 0..15 and |src2| in bits 16..31.
inline __m128i Load2x2(const void* src1, const void* src2) {
  uint16_t val1;
  uint16_t val2;
  memcpy(&val1, src1, sizeof(val1));
  memcpy(&val2, src2, sizeof(val2));
  return _mm_cvtsi32_si128(val1 | (val2 << 16));
}

// Load 2 uint8_t values into bytes |lane| * 2 and |lane| * 2 + 1 of |val|.
template <int lane>
inline __m128i Load2(const void* const buf, __m128i val) {
  uint16_t temp;
  memcpy(&temp, buf, 2);
  return _mm_insert_epi16(val, temp, lane);
}
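
// Illustrative only (not part of this header's API): the templated Load2 can
// gather four 2-byte pairs from four hypothetical row pointers into a single
// register.
//   __m128i x = Load2(row0);   // bytes 0..1
//   x = Load2<1>(row1, x);     // bytes 2..3
//   x = Load2<2>(row2, x);     // bytes 4..5
//   x = Load2<3>(row3, x);     // bytes 6..7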

inline __m128i Load4(const void* src) {
  // With newer compilers, such as clang 8.0.0, we could use the
  // _mm_loadu_si32 intrinsic directly; both _mm_loadu_si32(src) and the code
  // here compile to a movss instruction.
  //
  // Until compiler support for _mm_loadu_si32 is widespread, its use is
  // banned.
  int val;
  memcpy(&val, src, sizeof(val));
  return _mm_cvtsi32_si128(val);
}

inline __m128i Load4x2(const void* src1, const void* src2) {
  // See the comment in Load4() regarding _mm_loadu_si32.
  int val1, val2;
  memcpy(&val1, src1, sizeof(val1));
  memcpy(&val2, src2, sizeof(val2));
  return _mm_insert_epi32(_mm_cvtsi32_si128(val1), val2, 1);
}

inline __m128i LoadLo8(const void* a) {
  return _mm_loadl_epi64(static_cast<const __m128i*>(a));
}

// Load 8 bytes from |a| into the high half of |v|, leaving the low half
// unchanged.
inline __m128i LoadHi8(const __m128i v, const void* a) {
  const __m128 x =
      _mm_loadh_pi(_mm_castsi128_ps(v), static_cast<const __m64*>(a));
  return _mm_castps_si128(x);
}

inline __m128i LoadUnaligned16(const void* a) {
  return _mm_loadu_si128(static_cast<const __m128i*>(a));
}

inline __m128i LoadAligned16(const void* a) {
  assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0);
  return _mm_load_si128(static_cast<const __m128i*>(a));
}
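
// Illustrative only (not part of this header's API): a common pattern is to
// pack two 8-byte rows of a hypothetical |src| buffer with stride |stride|
// into one register.
//   __m128i rows = LoadLo8(src);            // row 0 in bytes 0..7
//   rows = LoadHi8(rows, src + stride);     // row 1 in bytes 8..15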

//------------------------------------------------------------------------------
// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.

// Zero the |over_read_in_bytes| high bytes of |source| so that MemorySanitizer
// does not flag them as uninitialized when they are used. A no-op unless MSAN
// is enabled.
inline __m128i MaskOverreads(const __m128i source,
                             const ptrdiff_t over_read_in_bytes) {
  __m128i dst = source;
#if LIBGAV1_MSAN
  if (over_read_in_bytes > 0) {
    __m128i mask = _mm_set1_epi8(-1);
    for (ptrdiff_t i = 0; i < over_read_in_bytes; ++i) {
      mask = _mm_srli_si128(mask, 1);
    }
    dst = _mm_and_si128(dst, mask);
  }
#else
  static_cast<void>(over_read_in_bytes);
#endif
  return dst;
}

// |over_read_in_bytes| is relative to the 8 bytes read by LoadLo8(). An extra
// 8 bytes are masked to cover the high half of the register, which the load
// does not fill with data from |source|.
inline __m128i LoadLo8Msan(const void* const source,
                           const ptrdiff_t over_read_in_bytes) {
  return MaskOverreads(LoadLo8(source), over_read_in_bytes + 8);
}

inline __m128i LoadHi8Msan(const __m128i v, const void* source,
                           const ptrdiff_t over_read_in_bytes) {
  return MaskOverreads(LoadHi8(v, source), over_read_in_bytes);
}

inline __m128i LoadAligned16Msan(const void* const source,
                                 const ptrdiff_t over_read_in_bytes) {
  return MaskOverreads(LoadAligned16(source), over_read_in_bytes);
}

inline __m128i LoadUnaligned16Msan(const void* const source,
                                   const ptrdiff_t over_read_in_bytes) {
  return MaskOverreads(LoadUnaligned16(source), over_read_in_bytes);
}
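
// Illustrative only (not part of this header's API): when a hypothetical row
// holds |width| valid bytes and |width| < 16, the tail of a 16-byte load is
// uninitialized, so mask it off under MemorySanitizer.
//   const __m128i row = LoadUnaligned16Msan(src, 16 - width);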

//------------------------------------------------------------------------------
// Store functions.

inline void Store2(void* dst, const __m128i x) {
  const int val = _mm_cvtsi128_si32(x);
  memcpy(dst, &val, 2);
}

inline void Store4(void* dst, const __m128i x) {
  const int val = _mm_cvtsi128_si32(x);
  memcpy(dst, &val, sizeof(val));
}

inline void StoreLo8(void* a, const __m128i v) {
  _mm_storel_epi64(static_cast<__m128i*>(a), v);
}

inline void StoreHi8(void* a, const __m128i v) {
  _mm_storeh_pi(static_cast<__m64*>(a), _mm_castsi128_ps(v));
}

inline void StoreAligned16(void* a, const __m128i v) {
  _mm_store_si128(static_cast<__m128i*>(a), v);
}

inline void StoreUnaligned16(void* a, const __m128i v) {
  _mm_storeu_si128(static_cast<__m128i*>(a), v);
}
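
// Illustrative only (not part of this header's API): the store counterpart of
// the LoadLo8()/LoadHi8() pattern above writes one register back to two
// hypothetical 8-byte rows.
//   StoreLo8(dst, rows);             // bytes 0..7 to row 0
//   StoreHi8(dst + stride, rows);    // bytes 8..15 to row 1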

//------------------------------------------------------------------------------
// Arithmetic utilities.

inline __m128i RightShiftWithRounding_U16(const __m128i v_val_d, int bits) {
  assert(bits <= 16);
  const __m128i v_bias_d =
      _mm_set1_epi16(static_cast<int16_t>((1 << bits) >> 1));
  const __m128i v_tmp_d = _mm_add_epi16(v_val_d, v_bias_d);
  return _mm_srli_epi16(v_tmp_d, bits);
}

inline __m128i RightShiftWithRounding_S16(const __m128i v_val_d, int bits) {
  assert(bits <= 16);
  const __m128i v_bias_d =
      _mm_set1_epi16(static_cast<int16_t>((1 << bits) >> 1));
  const __m128i v_tmp_d = _mm_add_epi16(v_val_d, v_bias_d);
  return _mm_srai_epi16(v_tmp_d, bits);
}

inline __m128i RightShiftWithRounding_U32(const __m128i v_val_d, int bits) {
  const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1);
  const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d);
  return _mm_srli_epi32(v_tmp_d, bits);
}

inline __m128i RightShiftWithRounding_S32(const __m128i v_val_d, int bits) {
  const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1);
  const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d);
  return _mm_srai_epi32(v_tmp_d, bits);
}
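
// Illustrative only (not part of this header's API): each function above is
// the vector form of the scalar expression
//   (val + ((1 << bits) >> 1)) >> bits
// which divides by 2^bits and rounds to nearest, with ties rounding toward
// +infinity. For example, with bits = 2 every 16-bit lane of
//   RightShiftWithRounding_U16(_mm_set1_epi16(6), 2)
// holds (6 + 2) >> 2 = 2 (6 / 4 = 1.5 rounded up), where a plain shift would
// yield 6 >> 2 = 1.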

//------------------------------------------------------------------------------
// Masking utilities.

// Return a mask whose |n| high bytes are 0xff and whose 16 - |n| low bytes are
// zero. |n| must be in [0, 16].
inline __m128i MaskHighNBytes(int n) {
  static constexpr uint8_t kMask[32] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   255, 255, 255, 255, 255, 255,
      255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
  };

  return LoadUnaligned16(kMask + n);
}
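
// Illustrative only (not part of this header's API): the mask can select the
// |n| high bytes of |a| over |b| with a byte blend, for hypothetical inputs
// a, b, and n.
//   const __m128i mask = MaskHighNBytes(n);
//   const __m128i blended = _mm_blendv_epi8(b, a, mask);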

}  // namespace dsp
}  // namespace libgav1

#endif  // LIBGAV1_ENABLE_SSE4_1
#endif  // LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_