1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <immintrin.h>
19 
20 #include "arrow/compute/exec/key_compare.h"
21 #include "arrow/util/bit_util.h"
22 
23 namespace arrow {
24 namespace compute {
25 
26 #if defined(ARROW_HAVE_AVX2)
27 
// Builds a byte mask for the first n bytes of a 256-bit register: byte i of the
// result is 0xff when i < n and 0x00 otherwise. The comparison is a signed byte
// compare against n truncated to a byte, so n is meant to lie in [0, 32].
inline __m256i set_first_n_bytes_avx2(int n) {
  const __m256i byte_index =
      _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
  const __m256i limit = _mm256_set1_epi8(static_cast<char>(n));
  return _mm256_cmpgt_epi8(limit, byte_index);
}
38 
39 template <bool use_selection>
NullUpdateColumnToRowImp_avx2(uint32_t id_col,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector)40 uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2(
41     uint32_t id_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
42     const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
43     const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
44     uint8_t* match_bytevector) {
45   if (!rows.has_any_nulls(ctx) && !col.data(0)) {
46     return num_rows_to_compare;
47   }
48   if (!col.data(0)) {
49     // Remove rows from the result for which the column value is a null
50     const uint8_t* null_masks = rows.null_masks();
51     uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
52 
53     uint32_t num_processed = 0;
54     constexpr uint32_t unroll = 8;
55     for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
56       __m256i irow_right;
57       if (use_selection) {
58         __m256i irow_left = _mm256_cvtepu16_epi32(
59             _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
60         irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
61       } else {
62         irow_right =
63             _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
64       }
65       __m256i bitid =
66           _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8));
67       bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col));
68       __m256i right =
69           _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1);
70       right = _mm256_and_si256(
71           _mm256_set1_epi32(1),
72           _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7))));
73       __m256i cmp = _mm256_cmpeq_epi32(right, _mm256_setzero_si256());
74       uint32_t result_lo =
75           _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
76       uint32_t result_hi =
77           _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
78       reinterpret_cast<uint64_t*>(match_bytevector)[i] &=
79           result_lo | (static_cast<uint64_t>(result_hi) << 32);
80     }
81     num_processed = num_rows_to_compare / unroll * unroll;
82     return num_processed;
83   } else if (!rows.has_any_nulls(ctx)) {
84     // Remove rows from the result for which the column value on left side is null
85     const uint8_t* non_nulls = col.data(0);
86     ARROW_DCHECK(non_nulls);
87     uint32_t num_processed = 0;
88     constexpr uint32_t unroll = 8;
89     for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
90       __m256i cmp;
91       if (use_selection) {
92         __m256i irow_left = _mm256_cvtepu16_epi32(
93             _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
94         irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(col.bit_offset(0)));
95         __m256i left = _mm256_i32gather_epi32((const int*)non_nulls,
96                                               _mm256_srli_epi32(irow_left, 3), 1);
97         left = _mm256_and_si256(
98             _mm256_set1_epi32(1),
99             _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7))));
100         cmp = _mm256_cmpeq_epi32(left, _mm256_set1_epi32(1));
101       } else {
102         __m256i left = _mm256_cvtepu8_epi32(_mm_set1_epi8(static_cast<uint8_t>(
103             reinterpret_cast<const uint16_t*>(non_nulls + i)[0] >> col.bit_offset(0))));
104         __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
105         cmp = _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), bits);
106       }
107       uint32_t result_lo =
108           _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
109       uint32_t result_hi =
110           _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
111       reinterpret_cast<uint64_t*>(match_bytevector)[i] &=
112           result_lo | (static_cast<uint64_t>(result_hi) << 32);
113       num_processed = num_rows_to_compare / unroll * unroll;
114     }
115     return num_processed;
116   } else {
117     const uint8_t* null_masks = rows.null_masks();
118     uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
119     const uint8_t* non_nulls = col.data(0);
120     ARROW_DCHECK(non_nulls);
121 
122     uint32_t num_processed = 0;
123     constexpr uint32_t unroll = 8;
124     for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
125       __m256i left_null;
126       __m256i irow_right;
127       if (use_selection) {
128         __m256i irow_left = _mm256_cvtepu16_epi32(
129             _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
130         irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
131         irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(col.bit_offset(0)));
132         __m256i left = _mm256_i32gather_epi32((const int*)non_nulls,
133                                               _mm256_srli_epi32(irow_left, 3), 1);
134         left = _mm256_and_si256(
135             _mm256_set1_epi32(1),
136             _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7))));
137         left_null = _mm256_cmpeq_epi32(left, _mm256_setzero_si256());
138       } else {
139         irow_right =
140             _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
141         __m256i left = _mm256_cvtepu8_epi32(_mm_set1_epi8(static_cast<uint8_t>(
142             reinterpret_cast<const uint16_t*>(non_nulls + i)[0] >> col.bit_offset(0))));
143         __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
144         left_null =
145             _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), _mm256_setzero_si256());
146       }
147       __m256i bitid =
148           _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8));
149       bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col));
150       __m256i right =
151           _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1);
152       right = _mm256_and_si256(
153           _mm256_set1_epi32(1),
154           _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7))));
155       __m256i right_null = _mm256_cmpeq_epi32(right, _mm256_set1_epi32(1));
156 
157       uint64_t left_null_64 =
158           static_cast<uint32_t>(_mm256_movemask_epi8(
159               _mm256_cvtepi32_epi64(_mm256_castsi256_si128(left_null)))) |
160           (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8(
161                _mm256_cvtepi32_epi64(_mm256_extracti128_si256(left_null, 1)))))
162            << 32);
163 
164       uint64_t right_null_64 =
165           static_cast<uint32_t>(_mm256_movemask_epi8(
166               _mm256_cvtepi32_epi64(_mm256_castsi256_si128(right_null)))) |
167           (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8(
168                _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_null, 1)))))
169            << 32);
170 
171       reinterpret_cast<uint64_t*>(match_bytevector)[i] |= left_null_64 & right_null_64;
172       reinterpret_cast<uint64_t*>(match_bytevector)[i] &= ~(left_null_64 ^ right_null_64);
173     }
174     num_processed = num_rows_to_compare / unroll * unroll;
175     return num_processed;
176   }
177 }
178 
179 template <bool use_selection, class COMPARE8_FN>
CompareBinaryColumnToRowHelper_avx2(uint32_t offset_within_row,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector,COMPARE8_FN compare8_fn)180 uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2(
181     uint32_t offset_within_row, uint32_t num_rows_to_compare,
182     const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
183     KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
184     const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector,
185     COMPARE8_FN compare8_fn) {
186   bool is_fixed_length = rows.metadata().is_fixed_length;
187   if (is_fixed_length) {
188     uint32_t fixed_length = rows.metadata().fixed_length;
189     const uint8_t* rows_left = col.data(1);
190     const uint8_t* rows_right = rows.data(1);
191     constexpr uint32_t unroll = 8;
192     __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
193     for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
194       if (use_selection) {
195         irow_left = _mm256_cvtepu16_epi32(
196             _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
197       }
198       __m256i irow_right;
199       if (use_selection) {
200         irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
201       } else {
202         irow_right =
203             _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
204       }
205 
206       __m256i offset_right =
207           _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(fixed_length));
208       offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row));
209 
210       reinterpret_cast<uint64_t*>(match_bytevector)[i] =
211           compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right);
212 
213       if (!use_selection) {
214         irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8));
215       }
216     }
217     return num_rows_to_compare - (num_rows_to_compare % unroll);
218   } else {
219     const uint8_t* rows_left = col.data(1);
220     const uint32_t* offsets_right = rows.offsets();
221     const uint8_t* rows_right = rows.data(2);
222     constexpr uint32_t unroll = 8;
223     __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
224     for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
225       if (use_selection) {
226         irow_left = _mm256_cvtepu16_epi32(
227             _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
228       }
229       __m256i irow_right;
230       if (use_selection) {
231         irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
232       } else {
233         irow_right =
234             _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
235       }
236       __m256i offset_right =
237           _mm256_i32gather_epi32((const int*)offsets_right, irow_right, 4);
238       offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row));
239 
240       reinterpret_cast<uint64_t*>(match_bytevector)[i] =
241           compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right);
242 
243       if (!use_selection) {
244         irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8));
245       }
246     }
247     return num_rows_to_compare - (num_rows_to_compare % unroll);
248   }
249 }
250 
// Compares eight left values, gathered at the (selected) left row ids in irow_left,
// with eight right values gathered at byte offsets offset_right from right_base.
// column_width is the value width in bytes (1, 2 or 4), or 0 for a bit-packed
// boolean column whose first bit is at bit_offset (each boolean is expanded to
// 0xff / 0x00 and compared with one right-side byte).
// Returns one byte per row packed into a 64-bit word: 0xff if equal, 0x00 otherwise.
template <int column_width>
inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* right_base,
                                      __m256i irow_left, __m256i offset_right,
                                      int bit_offset = 0) {
  // Reject unsupported widths at compile time instead of leaving `left` unset.
  static_assert(column_width == 0 || column_width == 1 || column_width == 2 ||
                    column_width == 4,
                "CompareSelected8_avx2 supports column widths 0 (bit-packed), 1, 2 and 4");
  __m256i left;
  switch (column_width) {
    case 0:
      irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(bit_offset));
      left = _mm256_i32gather_epi32(reinterpret_cast<const int*>(left_base),
                                    _mm256_srli_epi32(irow_left, 3), 1);
      left = _mm256_and_si256(
          _mm256_set1_epi32(1),
          _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7))));
      left = _mm256_mullo_epi32(left, _mm256_set1_epi32(0xff));
      break;
    case 1:
      left = _mm256_i32gather_epi32(reinterpret_cast<const int*>(left_base), irow_left, 1);
      left = _mm256_and_si256(left, _mm256_set1_epi32(0xff));
      break;
    case 2:
      left = _mm256_i32gather_epi32(reinterpret_cast<const int*>(left_base), irow_left, 2);
      // Keep the full 16-bit value so it matches the 0xffff mask applied to the right
      // side below; masking with 0xff would drop the high byte and make equal values
      // with a non-zero high byte compare unequal.
      left = _mm256_and_si256(left, _mm256_set1_epi32(0xffff));
      break;
    case 4:
    default:  // Only column_width == 4 reaches here (see static_assert).
      left = _mm256_i32gather_epi32(reinterpret_cast<const int*>(left_base), irow_left, 4);
      break;
  }

  // Gather 4 bytes at each right offset and keep only the column's width.
  __m256i right =
      _mm256_i32gather_epi32(reinterpret_cast<const int*>(right_base), offset_right, 1);
  if (column_width != sizeof(uint32_t)) {
    constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff;
    right = _mm256_and_si256(right, _mm256_set1_epi32(mask));
  }

  __m256i cmp = _mm256_cmpeq_epi32(left, right);

  // Sign-extend each 32-bit lane mask to 64 bits so movemask yields one byte per row.
  uint32_t result_lo =
      _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
  uint32_t result_hi =
      _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));

  return result_lo | (static_cast<uint64_t>(result_hi) << 32);
}
296 
// Compares eight consecutive left values (rows irow_left_first .. irow_left_first + 7)
// with eight right values gathered at byte offsets offset_right from right_base.
// column_width is the value width in bytes (1, 2 or 4), or 0 for a bit-packed
// boolean column whose first bit is at bit_offset (each boolean is expanded to
// 0xff / 0x00 and compared with one right-side byte).
// Returns one byte per row packed into a 64-bit word: 0xff if equal, 0x00 otherwise.
// NOTE(review): irow_left_first is assumed to be a multiple of 8, since the contiguous
// loads below index by irow_left_first / 8.
template <int column_width>
inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_base,
                              uint32_t irow_left_first, __m256i offset_right,
                              int bit_offset = 0) {
  // Reject unsupported widths at compile time instead of leaving `left` unset.
  static_assert(column_width == 0 || column_width == 1 || column_width == 2 ||
                    column_width == 4,
                "Compare8_avx2 supports column widths 0 (bit-packed), 1, 2 and 4");
  __m256i left;
  switch (column_width) {
    case 0: {
      __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
      uint32_t start_bit_index = irow_left_first + bit_offset;
      // Read 16 bits so the shift can pull in bits from the following byte.
      uint8_t left_bits_8 =
          (reinterpret_cast<const uint16_t*>(left_base + start_bit_index / 8)[0] >>
           (start_bit_index % 8)) &
          0xff;
      // Lane k becomes 0xff if bit k of left_bits_8 is set, 0x00 otherwise.
      left =
          _mm256_cmpeq_epi32(_mm256_and_si256(bits, _mm256_set1_epi8(left_bits_8)), bits);
      left = _mm256_and_si256(left, _mm256_set1_epi32(0xff));
    } break;
    case 1:
      left = _mm256_cvtepu8_epi32(_mm_set1_epi64x(
          reinterpret_cast<const uint64_t*>(left_base)[irow_left_first / 8]));
      break;
    case 2:
      left = _mm256_cvtepu16_epi32(_mm_loadu_si128(
          reinterpret_cast<const __m128i*>(left_base) + irow_left_first / 8));
      break;
    case 4:
    default:  // Only column_width == 4 reaches here (see static_assert).
      left = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_base) +
                                irow_left_first / 8);
      break;
  }

  // Gather 4 bytes at each right offset and keep only the column's width.
  __m256i right =
      _mm256_i32gather_epi32(reinterpret_cast<const int*>(right_base), offset_right, 1);
  if (column_width != sizeof(uint32_t)) {
    constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff;
    right = _mm256_and_si256(right, _mm256_set1_epi32(mask));
  }

  __m256i cmp = _mm256_cmpeq_epi32(left, right);

  // Sign-extend each 32-bit lane mask to 64 bits so movemask yields one byte per row.
  uint32_t result_lo =
      _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
  uint32_t result_hi =
      _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));

  return result_lo | (static_cast<uint64_t>(result_hi) << 32);
}
345 
346 template <bool use_selection>
Compare8_64bit_avx2(const uint8_t * left_base,const uint8_t * right_base,__m256i irow_left,uint32_t irow_left_first,__m256i offset_right)347 inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* right_base,
348                                     __m256i irow_left, uint32_t irow_left_first,
349                                     __m256i offset_right) {
350   auto left_base_i64 =
351       reinterpret_cast<const arrow::util::int64_for_gather_t*>(left_base);
352   __m256i left_lo =
353       _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8);
354   __m256i left_hi =
355       _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8);
356   if (use_selection) {
357     left_lo = _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8);
358     left_hi =
359         _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8);
360   } else {
361     left_lo = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_base) +
362                                  irow_left_first / 4);
363     left_hi = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_base) +
364                                  irow_left_first / 4 + 1);
365   }
366   auto right_base_i64 =
367       reinterpret_cast<const arrow::util::int64_for_gather_t*>(right_base);
368   __m256i right_lo =
369       _mm256_i32gather_epi64(right_base_i64, _mm256_castsi256_si128(offset_right), 1);
370   __m256i right_hi = _mm256_i32gather_epi64(right_base_i64,
371                                             _mm256_extracti128_si256(offset_right, 1), 1);
372   uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo));
373   uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi));
374   return result_lo | (static_cast<uint64_t>(result_hi) << 32);
375 }
376 
377 template <bool use_selection>
Compare8_Binary_avx2(uint32_t length,const uint8_t * left_base,const uint8_t * right_base,__m256i irow_left,uint32_t irow_left_first,__m256i offset_right)378 inline uint64_t Compare8_Binary_avx2(uint32_t length, const uint8_t* left_base,
379                                      const uint8_t* right_base, __m256i irow_left,
380                                      uint32_t irow_left_first, __m256i offset_right) {
381   uint32_t irow_left_array[8];
382   uint32_t offset_right_array[8];
383   if (use_selection) {
384     _mm256_storeu_si256(reinterpret_cast<__m256i*>(irow_left_array), irow_left);
385   }
386   _mm256_storeu_si256(reinterpret_cast<__m256i*>(offset_right_array), offset_right);
387 
388   // Non-zero length guarantees no underflow
389   int32_t num_loops_less_one = (static_cast<int32_t>(length) + 31) / 32 - 1;
390 
391   __m256i tail_mask = set_first_n_bytes_avx2(length - num_loops_less_one * 32);
392 
393   uint64_t result = 0;
394   for (uint32_t irow = 0; irow < 8; ++irow) {
395     const __m256i* key_left_ptr = reinterpret_cast<const __m256i*>(
396         left_base +
397         (use_selection ? irow_left_array[irow] : irow_left_first + irow) * length);
398     const __m256i* key_right_ptr =
399         reinterpret_cast<const __m256i*>(right_base + offset_right_array[irow]);
400     __m256i result_or = _mm256_setzero_si256();
401     int32_t i;
402     // length cannot be zero
403     for (i = 0; i < num_loops_less_one; ++i) {
404       __m256i key_left = _mm256_loadu_si256(key_left_ptr + i);
405       __m256i key_right = _mm256_loadu_si256(key_right_ptr + i);
406       result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right));
407     }
408     __m256i key_left = _mm256_loadu_si256(key_left_ptr + i);
409     __m256i key_right = _mm256_loadu_si256(key_right_ptr + i);
410     result_or = _mm256_or_si256(
411         result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right)));
412     uint64_t result_single = _mm256_testz_si256(result_or, result_or) * 0xff;
413     result |= result_single << (8 * irow);
414   }
415   return result;
416 }
417 
418 template <bool use_selection>
CompareBinaryColumnToRowImp_avx2(uint32_t offset_within_row,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector)419 uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2(
420     uint32_t offset_within_row, uint32_t num_rows_to_compare,
421     const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
422     KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
423     const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
424   uint32_t col_width = col.metadata().fixed_length;
425   if (col_width == 0) {
426     int bit_offset = col.bit_offset(1);
427     return CompareBinaryColumnToRowHelper_avx2<use_selection>(
428         offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
429         ctx, col, rows, match_bytevector,
430         [bit_offset](const uint8_t* left_base, const uint8_t* right_base,
431                      uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) {
432           if (use_selection) {
433             return CompareSelected8_avx2<0>(left_base, right_base, irow_left,
434                                             offset_right, bit_offset);
435           } else {
436             return Compare8_avx2<0>(left_base, right_base, irow_left_base, offset_right,
437                                     bit_offset);
438           }
439         });
440   } else if (col_width == 1) {
441     return CompareBinaryColumnToRowHelper_avx2<use_selection>(
442         offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
443         ctx, col, rows, match_bytevector,
444         [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
445            __m256i irow_left, __m256i offset_right) {
446           if (use_selection) {
447             return CompareSelected8_avx2<1>(left_base, right_base, irow_left,
448                                             offset_right);
449           } else {
450             return Compare8_avx2<1>(left_base, right_base, irow_left_base, offset_right);
451           }
452         });
453   } else if (col_width == 2) {
454     return CompareBinaryColumnToRowHelper_avx2<use_selection>(
455         offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
456         ctx, col, rows, match_bytevector,
457         [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
458            __m256i irow_left, __m256i offset_right) {
459           if (use_selection) {
460             return CompareSelected8_avx2<2>(left_base, right_base, irow_left,
461                                             offset_right);
462           } else {
463             return Compare8_avx2<2>(left_base, right_base, irow_left_base, offset_right);
464           }
465         });
466   } else if (col_width == 4) {
467     return CompareBinaryColumnToRowHelper_avx2<use_selection>(
468         offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
469         ctx, col, rows, match_bytevector,
470         [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
471            __m256i irow_left, __m256i offset_right) {
472           if (use_selection) {
473             return CompareSelected8_avx2<4>(left_base, right_base, irow_left,
474                                             offset_right);
475           } else {
476             return Compare8_avx2<4>(left_base, right_base, irow_left_base, offset_right);
477           }
478         });
479   } else if (col_width == 8) {
480     return CompareBinaryColumnToRowHelper_avx2<use_selection>(
481         offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
482         ctx, col, rows, match_bytevector,
483         [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
484            __m256i irow_left, __m256i offset_right) {
485           return Compare8_64bit_avx2<use_selection>(left_base, right_base, irow_left,
486                                                     irow_left_base, offset_right);
487         });
488   } else {
489     return CompareBinaryColumnToRowHelper_avx2<use_selection>(
490         offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
491         ctx, col, rows, match_bytevector,
492         [&col](const uint8_t* left_base, const uint8_t* right_base,
493                uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) {
494           uint32_t length = col.metadata().fixed_length;
495           return Compare8_Binary_avx2<use_selection>(
496               length, left_base, right_base, irow_left, irow_left_base, offset_right);
497         });
498   }
499 }
500 
501 // Overwrites the match_bytevector instead of updating it
502 template <bool use_selection, bool is_first_varbinary_col>
CompareVarBinaryColumnToRowImp_avx2(uint32_t id_varbinary_col,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector)503 void KeyCompare::CompareVarBinaryColumnToRowImp_avx2(
504     uint32_t id_varbinary_col, uint32_t num_rows_to_compare,
505     const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
506     KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
507     const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
508   const uint32_t* offsets_left = col.offsets();
509   const uint32_t* offsets_right = rows.offsets();
510   const uint8_t* rows_left = col.data(2);
511   const uint8_t* rows_right = rows.data(2);
512   for (uint32_t i = 0; i < num_rows_to_compare; ++i) {
513     uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
514     uint32_t irow_right = left_to_right_map[irow_left];
515     uint32_t begin_left = offsets_left[irow_left];
516     uint32_t length_left = offsets_left[irow_left + 1] - begin_left;
517     uint32_t begin_right = offsets_right[irow_right];
518     uint32_t length_right;
519     uint32_t offset_within_row;
520     if (!is_first_varbinary_col) {
521       rows.metadata().nth_varbinary_offset_and_length(
522           rows_right + begin_right, id_varbinary_col, &offset_within_row, &length_right);
523     } else {
524       rows.metadata().first_varbinary_offset_and_length(
525           rows_right + begin_right, &offset_within_row, &length_right);
526     }
527     begin_right += offset_within_row;
528 
529     __m256i result_or = _mm256_setzero_si256();
530     uint32_t length = std::min(length_left, length_right);
531     if (length > 0) {
532       const __m256i* key_left_ptr =
533           reinterpret_cast<const __m256i*>(rows_left + begin_left);
534       const __m256i* key_right_ptr =
535           reinterpret_cast<const __m256i*>(rows_right + begin_right);
536       int32_t j;
537       // length can be zero
538       for (j = 0; j < (static_cast<int32_t>(length) + 31) / 32 - 1; ++j) {
539         __m256i key_left = _mm256_loadu_si256(key_left_ptr + j);
540         __m256i key_right = _mm256_loadu_si256(key_right_ptr + j);
541         result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right));
542       }
543 
544       __m256i tail_mask = set_first_n_bytes_avx2(length - j * 32);
545 
546       __m256i key_left = _mm256_loadu_si256(key_left_ptr + j);
547       __m256i key_right = _mm256_loadu_si256(key_right_ptr + j);
548       result_or = _mm256_or_si256(
549           result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right)));
550     }
551     int result = _mm256_testz_si256(result_or, result_or) * 0xff;
552     result *= (length_left == length_right ? 1 : 0);
553     match_bytevector[i] = result;
554   }
555 }
556 
AndByteVectors_avx2(uint32_t num_elements,uint8_t * bytevector_A,const uint8_t * bytevector_B)557 uint32_t KeyCompare::AndByteVectors_avx2(uint32_t num_elements, uint8_t* bytevector_A,
558                                          const uint8_t* bytevector_B) {
559   constexpr int unroll = 32;
560   for (uint32_t i = 0; i < num_elements / unroll; ++i) {
561     __m256i result = _mm256_and_si256(
562         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bytevector_A) + i),
563         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bytevector_B) + i));
564     _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytevector_A) + i, result);
565   }
566   return (num_elements - (num_elements % unroll));
567 }
568 
NullUpdateColumnToRow_avx2(bool use_selection,uint32_t id_col,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector)569 uint32_t KeyCompare::NullUpdateColumnToRow_avx2(
570     bool use_selection, uint32_t id_col, uint32_t num_rows_to_compare,
571     const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
572     KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
573     const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
574   if (use_selection) {
575     return NullUpdateColumnToRowImp_avx2<true>(id_col, num_rows_to_compare,
576                                                sel_left_maybe_null, left_to_right_map,
577                                                ctx, col, rows, match_bytevector);
578   } else {
579     return NullUpdateColumnToRowImp_avx2<false>(id_col, num_rows_to_compare,
580                                                 sel_left_maybe_null, left_to_right_map,
581                                                 ctx, col, rows, match_bytevector);
582   }
583 }
584 
CompareBinaryColumnToRow_avx2(bool use_selection,uint32_t offset_within_row,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector)585 uint32_t KeyCompare::CompareBinaryColumnToRow_avx2(
586     bool use_selection, uint32_t offset_within_row, uint32_t num_rows_to_compare,
587     const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
588     KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
589     const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
590   if (use_selection) {
591     return CompareBinaryColumnToRowImp_avx2<true>(offset_within_row, num_rows_to_compare,
592                                                   sel_left_maybe_null, left_to_right_map,
593                                                   ctx, col, rows, match_bytevector);
594   } else {
595     return CompareBinaryColumnToRowImp_avx2<false>(offset_within_row, num_rows_to_compare,
596                                                    sel_left_maybe_null, left_to_right_map,
597                                                    ctx, col, rows, match_bytevector);
598   }
599 }
600 
CompareVarBinaryColumnToRow_avx2(bool use_selection,bool is_first_varbinary_col,uint32_t id_varlen_col,uint32_t num_rows_to_compare,const uint16_t * sel_left_maybe_null,const uint32_t * left_to_right_map,KeyEncoder::KeyEncoderContext * ctx,const KeyEncoder::KeyColumnArray & col,const KeyEncoder::KeyRowArray & rows,uint8_t * match_bytevector)601 void KeyCompare::CompareVarBinaryColumnToRow_avx2(
602     bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col,
603     uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
604     const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
605     const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
606     uint8_t* match_bytevector) {
607   if (use_selection) {
608     if (is_first_varbinary_col) {
609       CompareVarBinaryColumnToRowImp_avx2<true, true>(
610           id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
611           col, rows, match_bytevector);
612     } else {
613       CompareVarBinaryColumnToRowImp_avx2<true, false>(
614           id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
615           col, rows, match_bytevector);
616     }
617   } else {
618     if (is_first_varbinary_col) {
619       CompareVarBinaryColumnToRowImp_avx2<false, true>(
620           id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
621           col, rows, match_bytevector);
622     } else {
623       CompareVarBinaryColumnToRowImp_avx2<false, false>(
624           id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
625           col, rows, match_bytevector);
626     }
627   }
628 }
629 
630 #endif
631 
632 }  // namespace compute
633 }  // namespace arrow
634