1 // This library is part of PLINK 2, copyright (C) 2005-2020 Shaun Purcell,
2 // Christopher Chang.
3 //
4 // This library is free software: you can redistribute it and/or modify it
5 // under the terms of the GNU Lesser General Public License as published by the
6 // Free Software Foundation; either version 3 of the License, or (at your
7 // option) any later version.
8 //
9 // This library is distributed in the hope that it will be useful, but WITHOUT
10 // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 // FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
12 // for more details.
13 //
14 // You should have received a copy of the GNU Lesser General Public License
15 // along with this library.  If not, see <http://www.gnu.org/licenses/>.
16 
17 
18 #include "plink2_bits.h"
19 
20 #ifdef __cplusplus
21 namespace plink2 {
22 #endif
23 
24 #if defined(__LP64__) && !defined(USE_AVX2)
25 // No alignment assumptions.
Pack32bTo16bMask(const void * words,uintptr_t ct_32b,void * dest)26 void Pack32bTo16bMask(const void* words, uintptr_t ct_32b, void* dest) {
27   // This is also competitive in the AVX2 case, but never quite beats the
28   // simple loop.  (We'd want to enable a similar function for Ryzen,
29   // processing one 32-byte vector instead of two 16-byte vectors at a time in
30   // the main loop since _mm256_packus_epi16() doesn't do what we want.)
31   const VecW m1 = VCONST_W(kMask5555);
32 #  ifdef USE_SSE42
33   const VecW swap12 = vecw_setr8(
34       0, 1, 4, 5, 2, 3, 6, 7,
35       8, 9, 12, 13, 10, 11, 14, 15);
36 #  else
37   const VecW m2 = VCONST_W(kMask3333);
38 #  endif
39   const VecW m4 = VCONST_W(kMask0F0F);
40   const VecW m8 = VCONST_W(kMask00FF);
41   const VecW* words_valias = R_CAST(const VecW*, words);
42   __m128i* dest_alias = R_CAST(__m128i*, dest);
43   for (uintptr_t vidx = 0; vidx != ct_32b; ++vidx) {
44     VecW vec_lo = vecw_loadu(&(words_valias[2 * vidx])) & m1;
45     VecW vec_hi = vecw_loadu(&(words_valias[2 * vidx + 1])) & m1;
46 #  ifdef USE_SSE42
47     // this right-shift-3 + shuffle shortcut saves two operations.
48     vec_lo = (vec_lo | vecw_srli(vec_lo, 3)) & m4;
49     vec_hi = (vec_hi | vecw_srli(vec_hi, 3)) & m4;
50     vec_lo = vecw_shuffle8(swap12, vec_lo);
51     vec_hi = vecw_shuffle8(swap12, vec_hi);
52 #  else
53     vec_lo = (vec_lo | vecw_srli(vec_lo, 1)) & m2;
54     vec_hi = (vec_hi | vecw_srli(vec_hi, 1)) & m2;
55     vec_lo = (vec_lo | vecw_srli(vec_lo, 2)) & m4;
56     vec_hi = (vec_hi | vecw_srli(vec_hi, 2)) & m4;
57 #  endif
58     vec_lo = (vec_lo | vecw_srli(vec_lo, 4)) & m8;
59     vec_hi = (vec_hi | vecw_srli(vec_hi, 4)) & m8;
60     const __m128i vec_packed = _mm_packus_epi16(R_CAST(__m128i, vec_lo), R_CAST(__m128i, vec_hi));
61     _mm_storeu_si128(&(dest_alias[vidx]), vec_packed);
62   }
63 }
64 #endif
65 
SetAllBits(uintptr_t ct,uintptr_t * bitarr)66 void SetAllBits(uintptr_t ct, uintptr_t* bitarr) {
67   // leaves bits beyond the end unset
68   // ok for ct == 0
69   uintptr_t quotient = ct / kBitsPerWord;
70   uintptr_t remainder = ct % kBitsPerWord;
71   SetAllWArr(quotient, bitarr);
72   if (remainder) {
73     bitarr[quotient] = (k1LU << remainder) - k1LU;
74   }
75 }
76 
BitvecAnd(const uintptr_t * __restrict arg_bitvec,uintptr_t word_ct,uintptr_t * __restrict main_bitvec)77 void BitvecAnd(const uintptr_t* __restrict arg_bitvec, uintptr_t word_ct, uintptr_t* __restrict main_bitvec) {
78   // main_bitvec := main_bitvec AND arg_bitvec
79 #ifdef __LP64__
80   VecW* main_bitvvec_iter = R_CAST(VecW*, main_bitvec);
81   const VecW* arg_bitvvec_iter = R_CAST(const VecW*, arg_bitvec);
82   const uintptr_t full_vec_ct = word_ct / kWordsPerVec;
83   // ok, retested this explicit unroll (Jun 2018) and it's still noticeably
84   // faster for small cases than the simple loop.  sigh.
85   if (full_vec_ct & 1) {
86     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
87   }
88   if (full_vec_ct & 2) {
89     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
90     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
91   }
92   for (uintptr_t ulii = 3; ulii < full_vec_ct; ulii += 4) {
93     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
94     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
95     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
96     *main_bitvvec_iter++ &= *arg_bitvvec_iter++;
97   }
98 #  ifdef USE_AVX2
99   if (word_ct & 2) {
100     const uintptr_t base_idx = full_vec_ct * kWordsPerVec;
101     main_bitvec[base_idx] &= arg_bitvec[base_idx];
102     main_bitvec[base_idx + 1] &= arg_bitvec[base_idx + 1];
103   }
104 #  endif
105   if (word_ct & 1) {
106     main_bitvec[word_ct - 1] &= arg_bitvec[word_ct - 1];
107   }
108 #else
109   for (uintptr_t widx = 0; widx != word_ct; ++widx) {
110     main_bitvec[widx] &= arg_bitvec[widx];
111   }
112 #endif
113 }
114 
BitvecInvmask(const uintptr_t * __restrict exclude_bitvec,uintptr_t word_ct,uintptr_t * __restrict main_bitvec)115 void BitvecInvmask(const uintptr_t* __restrict exclude_bitvec, uintptr_t word_ct, uintptr_t* __restrict main_bitvec) {
116   // main_bitvec := main_bitvec ANDNOT exclude_bitvec
117   // note that this is the reverse of the _mm_andnot() operand order
118 #ifdef __LP64__
119   VecW* main_bitvvec_iter = R_CAST(VecW*, main_bitvec);
120   const VecW* exclude_bitvvec_iter = R_CAST(const VecW*, exclude_bitvec);
121   const uintptr_t full_vec_ct = word_ct / kWordsPerVec;
122   if (full_vec_ct & 1) {
123     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
124     ++main_bitvvec_iter;
125   }
126   if (full_vec_ct & 2) {
127     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
128     ++main_bitvvec_iter;
129     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
130     ++main_bitvvec_iter;
131   }
132   for (uintptr_t ulii = 3; ulii < full_vec_ct; ulii += 4) {
133     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
134     ++main_bitvvec_iter;
135     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
136     ++main_bitvvec_iter;
137     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
138     ++main_bitvvec_iter;
139     *main_bitvvec_iter = vecw_and_notfirst(*exclude_bitvvec_iter++, *main_bitvvec_iter);
140     ++main_bitvvec_iter;
141   }
142 #  ifdef USE_AVX2
143   if (word_ct & 2) {
144     const uintptr_t base_idx = full_vec_ct * kWordsPerVec;
145     main_bitvec[base_idx] &= ~exclude_bitvec[base_idx];
146     main_bitvec[base_idx + 1] &= ~exclude_bitvec[base_idx + 1];
147   }
148 #  endif
149   if (word_ct & 1) {
150     main_bitvec[word_ct - 1] &= ~exclude_bitvec[word_ct - 1];
151   }
152 #else
153   for (uintptr_t widx = 0; widx != word_ct; ++widx) {
154     main_bitvec[widx] &= ~exclude_bitvec[widx];
155   }
156 #endif
157 }
158 
BitvecOr(const uintptr_t * __restrict arg_bitvec,uintptr_t word_ct,uintptr_t * main_bitvec)159 void BitvecOr(const uintptr_t* __restrict arg_bitvec, uintptr_t word_ct, uintptr_t* main_bitvec) {
160   // main_bitvec := main_bitvec OR arg_bitvec
161 #ifdef __LP64__
162   VecW* main_bitvvec_iter = R_CAST(VecW*, main_bitvec);
163   const VecW* arg_bitvvec_iter = R_CAST(const VecW*, arg_bitvec);
164   const uintptr_t full_vec_ct = word_ct / kWordsPerVec;
165   if (full_vec_ct & 1) {
166     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
167   }
168   if (full_vec_ct & 2) {
169     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
170     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
171   }
172   for (uintptr_t ulii = 3; ulii < full_vec_ct; ulii += 4) {
173     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
174     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
175     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
176     *main_bitvvec_iter++ |= (*arg_bitvvec_iter++);
177   }
178 #  ifdef USE_AVX2
179   if (word_ct & 2) {
180     const uintptr_t base_idx = full_vec_ct * kWordsPerVec;
181     main_bitvec[base_idx] |= arg_bitvec[base_idx];
182     main_bitvec[base_idx + 1] |= arg_bitvec[base_idx + 1];
183   }
184 #  endif
185   if (word_ct & 1) {
186     main_bitvec[word_ct - 1] |= arg_bitvec[word_ct - 1];
187   }
188 #else
189   for (uintptr_t widx = 0; widx != word_ct; ++widx) {
190     main_bitvec[widx] |= arg_bitvec[widx];
191   }
192 #endif
193 }
194 
BitvecInvert(uintptr_t word_ct,uintptr_t * main_bitvec)195 void BitvecInvert(uintptr_t word_ct, uintptr_t* main_bitvec) {
196 #ifdef __LP64__
197   VecW* main_bitvvec_iter = R_CAST(VecW*, main_bitvec);
198   const uintptr_t full_vec_ct = word_ct / kWordsPerVec;
199   const VecW all1 = VCONST_W(~k0LU);
200   if (full_vec_ct & 1) {
201     *main_bitvvec_iter++ ^= all1;
202   }
203   if (full_vec_ct & 2) {
204     *main_bitvvec_iter++ ^= all1;
205     *main_bitvvec_iter++ ^= all1;
206   }
207   for (uintptr_t ulii = 3; ulii < full_vec_ct; ulii += 4) {
208     *main_bitvvec_iter++ ^= all1;
209     *main_bitvvec_iter++ ^= all1;
210     *main_bitvvec_iter++ ^= all1;
211     *main_bitvvec_iter++ ^= all1;
212   }
213 #  ifdef USE_AVX2
214   if (word_ct & 2) {
215     const uintptr_t base_idx = full_vec_ct * kWordsPerVec;
216     main_bitvec[base_idx] ^= ~k0LU;
217     main_bitvec[base_idx + 1] ^= ~k0LU;
218   }
219 #  endif
220   if (word_ct & 1) {
221     main_bitvec[word_ct - 1] ^= ~k0LU;
222   }
223 #else
224   for (uintptr_t widx = 0; widx != word_ct; ++widx) {
225     main_bitvec[widx] ^= ~k0LU;
226   }
227 #endif
228 }
229 
AdvTo1Bit(const uintptr_t * bitarr,uintptr_t loc)230 uintptr_t AdvTo1Bit(const uintptr_t* bitarr, uintptr_t loc) {
231   const uintptr_t* bitarr_iter = &(bitarr[loc / kBitsPerWord]);
232   uintptr_t ulii = (*bitarr_iter) >> (loc % kBitsPerWord);
233   if (ulii) {
234     return loc + ctzw(ulii);
235   }
236   do {
237     ulii = *(++bitarr_iter);
238   } while (!ulii);
239   return S_CAST(uintptr_t, bitarr_iter - bitarr) * kBitsPerWord + ctzw(ulii);
240 }
241 
AdvTo0Bit(const uintptr_t * bitarr,uintptr_t loc)242 uintptr_t AdvTo0Bit(const uintptr_t* bitarr, uintptr_t loc) {
243   const uintptr_t* bitarr_iter = &(bitarr[loc / kBitsPerWord]);
244   uintptr_t ulii = (~(*bitarr_iter)) >> (loc % kBitsPerWord);
245   if (ulii) {
246     return loc + ctzw(ulii);
247   }
248   do {
249     ulii = *(++bitarr_iter);
250   } while (ulii == ~k0LU);
251   return S_CAST(uintptr_t, bitarr_iter - bitarr) * kBitsPerWord + ctzw(~ulii);
252 }
253 
254 /*
255 uintptr_t NextNonmissingUnsafe(const uintptr_t* genoarr, uintptr_t loc) {
256   const uintptr_t* genoarr_iter = &(genoarr[loc / kBitsPerWordD2]);
257   uintptr_t ulii = (~(*genoarr_iter)) >> (2 * (loc % kBitsPerWordD2));
258   if (ulii) {
259     return loc + (ctzw(ulii) / 2);
260   }
261   do {
262     ulii = *(++genoarr_iter);
263   } while (ulii == ~k0LU);
264   return S_CAST(uintptr_t, genoarr_iter - genoarr) * kBitsPerWordD2 + (ctzw(~ulii) / 2);
265 }
266 */
267 
AdvBoundedTo1Bit(const uintptr_t * bitarr,uint32_t loc,uint32_t ceil)268 uint32_t AdvBoundedTo1Bit(const uintptr_t* bitarr, uint32_t loc, uint32_t ceil) {
269   // safe version.
270   const uintptr_t* bitarr_iter = &(bitarr[loc / kBitsPerWord]);
271   uintptr_t ulii = (*bitarr_iter) >> (loc % kBitsPerWord);
272   if (ulii) {
273     const uint32_t rval = loc + ctzw(ulii);
274     return MINV(rval, ceil);
275   }
276   const uintptr_t* bitarr_last = &(bitarr[(ceil - 1) / kBitsPerWord]);
277   do {
278     if (bitarr_iter >= bitarr_last) {
279       return ceil;
280     }
281     ulii = *(++bitarr_iter);
282   } while (!ulii);
283   const uint32_t rval = S_CAST(uintptr_t, bitarr_iter - bitarr) * kBitsPerWord + ctzw(ulii);
284   return MINV(rval, ceil);
285 }
286 
AdvBoundedTo0Bit(const uintptr_t * bitarr,uintptr_t loc,uintptr_t ceil)287 uintptr_t AdvBoundedTo0Bit(const uintptr_t* bitarr, uintptr_t loc, uintptr_t ceil) {
288   assert(ceil >= 1);
289   const uintptr_t* bitarr_ptr = &(bitarr[loc / kBitsPerWord]);
290   uintptr_t ulii = (~(*bitarr_ptr)) >> (loc % kBitsPerWord);
291   if (ulii) {
292     loc += ctzw(ulii);
293     return MINV(loc, ceil);
294   }
295   const uintptr_t* bitarr_last = &(bitarr[(ceil - 1) / kBitsPerWord]);
296   do {
297     if (bitarr_ptr >= bitarr_last) {
298       return ceil;
299     }
300     ulii = *(++bitarr_ptr);
301   } while (ulii == ~k0LU);
302   loc = S_CAST(uintptr_t, bitarr_ptr - bitarr) * kBitsPerWord + ctzw(~ulii);
303   return MINV(loc, ceil);
304 }
305 
FindLast1BitBefore(const uintptr_t * bitarr,uint32_t loc)306 uint32_t FindLast1BitBefore(const uintptr_t* bitarr, uint32_t loc) {
307   // unlike the next_{un}set family, this always returns a STRICTLY earlier
308   // position
309   const uintptr_t* bitarr_iter = &(bitarr[loc / kBitsPerWord]);
310   const uint32_t remainder = loc % kBitsPerWord;
311   uintptr_t ulii;
312   if (remainder) {
313     ulii = bzhi(*bitarr_iter, remainder);
314     if (ulii) {
315       return loc - remainder + bsrw(ulii);
316     }
317   }
318   do {
319     ulii = *(--bitarr_iter);
320   } while (!ulii);
321   return S_CAST(uintptr_t, bitarr_iter - bitarr) * kBitsPerWord + bsrw(ulii);
322 }
323 
AllBytesAreX(const unsigned char * bytes,unsigned char match,uintptr_t byte_ct)324 uint32_t AllBytesAreX(const unsigned char* bytes, unsigned char match, uintptr_t byte_ct) {
325   if (byte_ct < kBytesPerWord) {
326     for (uint32_t uii = 0; uii != byte_ct; ++uii) {
327       if (bytes[uii] != match) {
328         return 0;
329       }
330     }
331     return 1;
332   }
333   const uintptr_t* bytes_alias = R_CAST(const uintptr_t*, bytes);
334   const uintptr_t word_match = S_CAST(uintptr_t, match) * kMask0101;
335   uintptr_t word_ct_m1 = (byte_ct - 1) / kBytesPerWord;
336   // todo: try movemask in AVX2 case
337   for (uintptr_t widx = 0; widx != word_ct_m1; ++widx) {
338     if (bytes_alias[widx] != word_match) {
339       return 0;
340     }
341   }
342   const uintptr_t last_word = *R_CAST(const uintptr_t*, &(bytes[byte_ct - kBytesPerWord]));
343   if (last_word != word_match) {
344     return 0;
345   }
346   return 1;
347 }
348 
349 #ifdef USE_AVX2
350 // void CopyBitarrSubsetEx(const uintptr_t* __restrict raw_bitarr, const uintptr_t* __restrict subset_mask, uint32_t bit_idx_start, uint32_t output_bit_idx_end, uintptr_t* __restrict output_bitarr) {
CopyBitarrSubset(const uintptr_t * __restrict raw_bitarr,const uintptr_t * __restrict subset_mask,uint32_t output_bit_idx_end,uintptr_t * __restrict output_bitarr)351 void CopyBitarrSubset(const uintptr_t* __restrict raw_bitarr, const uintptr_t* __restrict subset_mask, uint32_t output_bit_idx_end, uintptr_t* __restrict output_bitarr) {
352   const uint32_t output_bit_idx_end_lowbits = output_bit_idx_end % kBitsPerWord;
353   uintptr_t* output_bitarr_iter = output_bitarr;
354   uintptr_t* output_bitarr_last = &(output_bitarr[output_bit_idx_end / kBitsPerWord]);
355   uintptr_t cur_output_word = 0;
356   uint32_t read_widx = UINT32_MAX;  // deliberate overflow
357   uint32_t write_idx_lowbits = 0;
358   while ((output_bitarr_iter != output_bitarr_last) || (write_idx_lowbits != output_bit_idx_end_lowbits)) {
359     uintptr_t cur_mask_word;
360     // sparse subset_mask optimization
361     // guaranteed to terminate since there's at least one more set bit
362     do {
363       cur_mask_word = subset_mask[++read_widx];
364     } while (!cur_mask_word);
365     uintptr_t extracted_bits = raw_bitarr[read_widx];
366     uint32_t set_bit_ct = kBitsPerWord;
367     if (cur_mask_word != ~k0LU) {
368       extracted_bits = _pext_u64(extracted_bits, cur_mask_word);
369       set_bit_ct = PopcountWord(cur_mask_word);
370     }
371     cur_output_word |= extracted_bits << write_idx_lowbits;
372     const uint32_t new_write_idx_lowbits = write_idx_lowbits + set_bit_ct;
373     if (new_write_idx_lowbits >= kBitsPerWord) {
374       *output_bitarr_iter++ = cur_output_word;
375       // ...and these are the bits that fell off
376       // bugfix: unsafe to right-shift 64
377       if (write_idx_lowbits) {
378         cur_output_word = extracted_bits >> (kBitsPerWord - write_idx_lowbits);
379       } else {
380         cur_output_word = 0;
381       }
382     }
383     write_idx_lowbits = new_write_idx_lowbits % kBitsPerWord;
384   }
385   if (write_idx_lowbits) {
386     *output_bitarr_iter = cur_output_word;
387   }
388 }
389 
PopcountVecsAvx2(const VecW * bit_vvec,uintptr_t vec_ct)390 uintptr_t PopcountVecsAvx2(const VecW* bit_vvec, uintptr_t vec_ct) {
391   // See popcnt_avx2() in libpopcnt.
392   VecW cnt = vecw_setzero();
393   VecW ones = vecw_setzero();
394   VecW twos = vecw_setzero();
395   VecW fours = vecw_setzero();
396   VecW eights = vecw_setzero();
397   VecW prev_sad_result = vecw_setzero();
398   const uintptr_t vec_ct_a16 = RoundDownPow2(vec_ct, 16);
399   for (uintptr_t vec_idx = 0; vec_idx != vec_ct_a16; vec_idx += 16) {
400     VecW twos_a = Csa256(bit_vvec[vec_idx + 0], bit_vvec[vec_idx + 1], &ones);
401     VecW twos_b = Csa256(bit_vvec[vec_idx + 2], bit_vvec[vec_idx + 3], &ones);
402     VecW fours_a = Csa256(twos_a, twos_b, &twos);
403 
404     twos_a = Csa256(bit_vvec[vec_idx + 4], bit_vvec[vec_idx + 5], &ones);
405     twos_b = Csa256(bit_vvec[vec_idx + 6], bit_vvec[vec_idx + 7], &ones);
406     VecW fours_b = Csa256(twos_a, twos_b, &twos);
407     const VecW eights_a = Csa256(fours_a, fours_b, &fours);
408 
409     twos_a = Csa256(bit_vvec[vec_idx + 8], bit_vvec[vec_idx + 9], &ones);
410     twos_b = Csa256(bit_vvec[vec_idx + 10], bit_vvec[vec_idx + 11], &ones);
411     fours_a = Csa256(twos_a, twos_b, &twos);
412 
413     twos_a = Csa256(bit_vvec[vec_idx + 12], bit_vvec[vec_idx + 13], &ones);
414     twos_b = Csa256(bit_vvec[vec_idx + 14], bit_vvec[vec_idx + 15], &ones);
415     fours_b = Csa256(twos_a, twos_b, &twos);
416     const VecW eights_b = Csa256(fours_a, fours_b, &fours);
417     const VecW sixteens = Csa256(eights_a, eights_b, &eights);
418     cnt = cnt + prev_sad_result;
419     // work around high SAD latency
420     prev_sad_result = PopcountVecAvx2(sixteens);
421   }
422   bit_vvec = &(bit_vvec[vec_ct_a16]);
423   const uintptr_t remainder = vec_ct % 16;
424   cnt = cnt + prev_sad_result;
425   if (remainder < 12) {
426     cnt = vecw_slli(cnt, 4);
427     if (remainder) {
428       VecW popcnt1_acc = vecw_setzero();
429       VecW popcnt2_acc = vecw_setzero();
430       const VecW lookup1 = vecw_setr8(4, 5, 5, 6, 5, 6, 6, 7,
431                                       5, 6, 6, 7, 6, 7, 7, 8);
432       const VecW lookup2 = vecw_setr8(4, 3, 3, 2, 3, 2, 2, 1,
433                                       3, 2, 2, 1, 2, 1, 1, 0);
434 
435       const VecW m4 = VCONST_W(kMask0F0F);
436       for (uintptr_t vec_idx = 0; vec_idx != remainder; ++vec_idx) {
437         const VecW vv = bit_vvec[vec_idx];
438         const VecW lo = vv & m4;
439         const VecW hi = vecw_srli(vv, 4) & m4;
440         popcnt1_acc = popcnt1_acc + vecw_shuffle8(lookup1, lo);
441         popcnt2_acc = popcnt2_acc + vecw_shuffle8(lookup2, hi);
442       }
443       cnt = cnt + vecw_sad(popcnt1_acc, popcnt2_acc);
444     }
445   } else {
446     VecW twos_a = Csa256(bit_vvec[0], bit_vvec[1], &ones);
447     VecW twos_b = Csa256(bit_vvec[2], bit_vvec[3], &ones);
448     VecW fours_a = Csa256(twos_a, twos_b, &twos);
449     twos_a = Csa256(bit_vvec[4], bit_vvec[5], &ones);
450     twos_b = Csa256(bit_vvec[6], bit_vvec[7], &ones);
451     VecW fours_b = Csa256(twos_a, twos_b, &twos);
452     const VecW eights_a = Csa256(fours_a, fours_b, &fours);
453     twos_a = Csa256(bit_vvec[8], bit_vvec[9], &ones);
454     twos_b = Csa256(bit_vvec[10], bit_vvec[11], &ones);
455     fours_a = Csa256(twos_a, twos_b, &twos);
456     twos_a = vecw_setzero();
457     if (remainder & 2) {
458       twos_a = Csa256(bit_vvec[12], bit_vvec[13], &ones);
459     }
460     twos_b = vecw_setzero();
461     if (remainder & 1) {
462       twos_b = CsaOne256(bit_vvec[remainder - 1], &ones);
463     }
464     fours_b = Csa256(twos_a, twos_b, &twos);
465     const VecW eights_b = Csa256(fours_a, fours_b, &fours);
466     const VecW sixteens = Csa256(eights_a, eights_b, &eights);
467     cnt = cnt + PopcountVecAvx2(sixteens);
468     cnt = vecw_slli(cnt, 4);
469   }
470   // Appears to be counterproductive to put multiple SAD instructions in
471   // flight.
472   // Compiler is smart enough that it's pointless to manually inline
473   // PopcountVecAvx2.  (Tried combining the 4 SAD calls into one, didn't help.)
474   cnt = cnt + vecw_slli(PopcountVecAvx2(eights), 3);
475   cnt = cnt + vecw_slli(PopcountVecAvx2(fours), 2);
476   cnt = cnt + vecw_slli(PopcountVecAvx2(twos), 1);
477   cnt = cnt + PopcountVecAvx2(ones);
478   return HsumW(cnt);
479 }
480 
PopcountVecsAvx2Intersect(const VecW * __restrict vvec1_iter,const VecW * __restrict vvec2_iter,uintptr_t vec_ct)481 uintptr_t PopcountVecsAvx2Intersect(const VecW* __restrict vvec1_iter, const VecW* __restrict vvec2_iter, uintptr_t vec_ct) {
482   // See popcnt_avx2() in libpopcnt.  vec_ct must be a multiple of 16.
483   VecW cnt = vecw_setzero();
484   VecW ones = vecw_setzero();
485   VecW twos = vecw_setzero();
486   VecW fours = vecw_setzero();
487   VecW eights = vecw_setzero();
488   for (uintptr_t vec_idx = 0; vec_idx < vec_ct; vec_idx += 16) {
489     VecW twos_a = Csa256(vvec1_iter[vec_idx + 0] & vvec2_iter[vec_idx + 0], vvec1_iter[vec_idx + 1] & vvec2_iter[vec_idx + 1], &ones);
490     VecW twos_b = Csa256(vvec1_iter[vec_idx + 2] & vvec2_iter[vec_idx + 2], vvec1_iter[vec_idx + 3] & vvec2_iter[vec_idx + 3], &ones);
491     VecW fours_a = Csa256(twos_a, twos_b, &twos);
492 
493     twos_a = Csa256(vvec1_iter[vec_idx + 4] & vvec2_iter[vec_idx + 4], vvec1_iter[vec_idx + 5] & vvec2_iter[vec_idx + 5], &ones);
494     twos_b = Csa256(vvec1_iter[vec_idx + 6] & vvec2_iter[vec_idx + 6], vvec1_iter[vec_idx + 7] & vvec2_iter[vec_idx + 7], &ones);
495     VecW fours_b = Csa256(twos_a, twos_b, &twos);
496     const VecW eights_a = Csa256(fours_a, fours_b, &fours);
497 
498     twos_a = Csa256(vvec1_iter[vec_idx + 8] & vvec2_iter[vec_idx + 8], vvec1_iter[vec_idx + 9] & vvec2_iter[vec_idx + 9], &ones);
499     twos_b = Csa256(vvec1_iter[vec_idx + 10] & vvec2_iter[vec_idx + 10], vvec1_iter[vec_idx + 11] & vvec2_iter[vec_idx + 11], &ones);
500     fours_a = Csa256(twos_a, twos_b, &twos);
501 
502     twos_a = Csa256(vvec1_iter[vec_idx + 12] & vvec2_iter[vec_idx + 12], vvec1_iter[vec_idx + 13] & vvec2_iter[vec_idx + 13], &ones);
503     twos_b = Csa256(vvec1_iter[vec_idx + 14] & vvec2_iter[vec_idx + 14], vvec1_iter[vec_idx + 15] & vvec2_iter[vec_idx + 15], &ones);
504     fours_b = Csa256(twos_a, twos_b, &twos);
505     const VecW eights_b = Csa256(fours_a, fours_b, &fours);
506     const VecW sixteens = Csa256(eights_a, eights_b, &eights);
507     cnt = cnt + PopcountVecAvx2(sixteens);
508   }
509   cnt = vecw_slli(cnt, 4);
510   cnt = cnt + vecw_slli(PopcountVecAvx2(eights), 3);
511   cnt = cnt + vecw_slli(PopcountVecAvx2(fours), 2);
512   cnt = cnt + vecw_slli(PopcountVecAvx2(twos), 1);
513   cnt = cnt + PopcountVecAvx2(ones);
514   return HsumW(cnt);
515 }
516 
PopcountWordsIntersect(const uintptr_t * __restrict bitvec1_iter,const uintptr_t * __restrict bitvec2_iter,uintptr_t word_ct)517 uintptr_t PopcountWordsIntersect(const uintptr_t* __restrict bitvec1_iter, const uintptr_t* __restrict bitvec2_iter, uintptr_t word_ct) {
518   const uintptr_t* bitvec1_end = &(bitvec1_iter[word_ct]);
519   const uintptr_t block_ct = word_ct / (16 * kWordsPerVec);
520   uintptr_t tot = 0;
521   if (block_ct) {
522     tot = PopcountVecsAvx2Intersect(R_CAST(const VecW*, bitvec1_iter), R_CAST(const VecW*, bitvec2_iter), block_ct * 16);
523     bitvec1_iter = &(bitvec1_iter[block_ct * (16 * kWordsPerVec)]);
524     bitvec2_iter = &(bitvec2_iter[block_ct * (16 * kWordsPerVec)]);
525   }
526   while (bitvec1_iter < bitvec1_end) {
527     tot += PopcountWord((*bitvec1_iter++) & (*bitvec2_iter++));
528   }
529   return tot;
530 }
531 
ExpandBytearr(const void * __restrict compact_bitarr,const uintptr_t * __restrict expand_mask,uint32_t word_ct,uint32_t expand_size,uint32_t read_start_bit,uintptr_t * __restrict target)532 void ExpandBytearr(const void* __restrict compact_bitarr, const uintptr_t* __restrict expand_mask, uint32_t word_ct, uint32_t expand_size, uint32_t read_start_bit, uintptr_t* __restrict target) {
533   const uint32_t expand_sizex_m1 = expand_size + read_start_bit - 1;
534   const uint32_t leading_byte_ct = 1 + (expand_sizex_m1 % kBitsPerWord) / CHAR_BIT;
535   uintptr_t compact_word = SubwordLoad(compact_bitarr, leading_byte_ct) >> read_start_bit;
536   const uintptr_t* compact_bitarr_iter = R_CAST(const uintptr_t*, &(S_CAST(const unsigned char*, compact_bitarr)[leading_byte_ct]));
537   uint32_t compact_idx_lowbits = read_start_bit + CHAR_BIT * (sizeof(intptr_t) - leading_byte_ct);
538   for (uint32_t widx = 0; widx != word_ct; ++widx) {
539     const uintptr_t mask_word = expand_mask[widx];
540     uintptr_t write_word = 0;
541     if (mask_word) {
542       const uint32_t mask_set_ct = PopcountWord(mask_word);
543       uint32_t next_compact_idx_lowbits = compact_idx_lowbits + mask_set_ct;
544       if (next_compact_idx_lowbits <= kBitsPerWord) {
545         write_word = _pdep_u64(compact_word, mask_word);
546         if (mask_set_ct != kBitsPerWord) {
547           compact_word = compact_word >> mask_set_ct;
548         } else {
549           // avoid nasal demons
550           compact_word = 0;
551         }
552       } else {
553 #  ifdef __arm__
554 #    error "Unaligned accesses in ExpandBytearr()."
555 #  endif
556         uintptr_t next_compact_word = *compact_bitarr_iter++;
557         next_compact_idx_lowbits -= kBitsPerWord;
558         compact_word |= next_compact_word << (kBitsPerWord - compact_idx_lowbits);
559         write_word = _pdep_u64(compact_word, mask_word);
560         if (next_compact_idx_lowbits != kBitsPerWord) {
561           compact_word = next_compact_word >> next_compact_idx_lowbits;
562         } else {
563           compact_word = 0;
564         }
565       }
566       compact_idx_lowbits = next_compact_idx_lowbits;
567     }
568     target[widx] = write_word;
569   }
570 }
571 
ExpandThenSubsetBytearr(const void * __restrict compact_bitarr,const uintptr_t * __restrict expand_mask,const uintptr_t * __restrict subset_mask,uint32_t expand_size,uint32_t subset_size,uint32_t read_start_bit,uintptr_t * __restrict target)572 void ExpandThenSubsetBytearr(const void* __restrict compact_bitarr, const uintptr_t* __restrict expand_mask, const uintptr_t* __restrict subset_mask, uint32_t expand_size, uint32_t subset_size, uint32_t read_start_bit, uintptr_t* __restrict target) {
573   const uint32_t expand_sizex_m1 = expand_size + read_start_bit - 1;
574   const uint32_t leading_byte_ct = 1 + (expand_sizex_m1 % kBitsPerWord) / CHAR_BIT;
575   uintptr_t compact_word = SubwordLoad(compact_bitarr, leading_byte_ct) >> read_start_bit;
576   const uintptr_t* compact_bitarr_alias = R_CAST(const uintptr_t*, &(S_CAST(const unsigned char*, compact_bitarr)[leading_byte_ct]));
577   uint32_t compact_widx = UINT32_MAX;  // deliberate overflow
578   uint32_t compact_idx_lowbits = read_start_bit + CHAR_BIT * (sizeof(uintptr_t) - leading_byte_ct);
579   const uint32_t subset_size_lowbits = subset_size % kBitsPerWord;
580   uintptr_t* target_iter = target;
581   uintptr_t* target_last = &(target[subset_size / kBitsPerWord]);
582   uintptr_t cur_output_word = 0;
583   uint32_t read_widx = UINT32_MAX;  // deliberate overflow
584   uint32_t write_idx_lowbits = 0;
585 
586   // bugfix (5 Feb 2018): missed a case in sparse subset_mask optimization
587   uint32_t expand_bit_ct_skip = 0;
588   while ((target_iter != target_last) || (write_idx_lowbits != subset_size_lowbits)) {
589     uintptr_t expand_word;
590     uintptr_t subset_word;
591     uint32_t expand_bit_ct;
592     while (1) {
593       ++read_widx;
594       expand_word = expand_mask[read_widx];
595       subset_word = subset_mask[read_widx];
596       expand_bit_ct = PopcountWord(expand_word);
597       if (subset_word) {
598         break;
599       }
600       expand_bit_ct_skip += expand_bit_ct;
601     }
602     uintptr_t extracted_bits = 0;
603     const uint32_t set_bit_ct = PopcountWord(subset_word);
604     if (expand_word & subset_word) {
605       // lazy load
606       compact_idx_lowbits += expand_bit_ct_skip;
607       if (compact_idx_lowbits >= kBitsPerWord) {
608         compact_widx += compact_idx_lowbits / kBitsPerWord;
609         compact_idx_lowbits = compact_idx_lowbits % kBitsPerWord;
610 #  ifdef __arm__
611 #    error "Unaligned accesses in ExpandThenSubsetBytearr()."
612 #  endif
613         compact_word = compact_bitarr_alias[compact_widx] >> compact_idx_lowbits;
614       } else {
615         compact_word = compact_word >> expand_bit_ct_skip;
616       }
617       uint32_t next_compact_idx_lowbits = compact_idx_lowbits + expand_bit_ct;
618       uintptr_t expanded_bits;
619       if (next_compact_idx_lowbits <= kBitsPerWord) {
620         expanded_bits = _pdep_u64(compact_word, expand_word);
621         if (expand_bit_ct != kBitsPerWord) {
622           compact_word = compact_word >> expand_bit_ct;
623         }
624       } else {
625         uintptr_t next_compact_word = compact_bitarr_alias[++compact_widx];
626         next_compact_idx_lowbits -= kBitsPerWord;
627         compact_word |= next_compact_word << (kBitsPerWord - compact_idx_lowbits);
628         expanded_bits = _pdep_u64(compact_word, expand_word);
629         if (next_compact_idx_lowbits != kBitsPerWord) {
630           compact_word = next_compact_word >> next_compact_idx_lowbits;
631         }
632       }
633       extracted_bits = _pext_u64(expanded_bits, subset_word);
634       compact_idx_lowbits = next_compact_idx_lowbits;
635       cur_output_word |= extracted_bits << write_idx_lowbits;
636       expand_bit_ct_skip = 0;
637     } else {
638       expand_bit_ct_skip += expand_bit_ct;
639     }
640     const uint32_t new_write_idx_lowbits = write_idx_lowbits + set_bit_ct;
641     if (new_write_idx_lowbits >= kBitsPerWord) {
642       *target_iter++ = cur_output_word;
643       // ...and these are the bits that fell off
644       if (write_idx_lowbits) {
645         cur_output_word = extracted_bits >> (kBitsPerWord - write_idx_lowbits);
646       } else {
647         cur_output_word = 0;
648       }
649     }
650     write_idx_lowbits = new_write_idx_lowbits % kBitsPerWord;
651   }
652   if (write_idx_lowbits) {
653     *target_iter = cur_output_word;
654   }
655 }
656 
ExpandBytearrNested(const void * __restrict compact_bitarr,const uintptr_t * __restrict mid_bitarr,const uintptr_t * __restrict top_expand_mask,uint32_t word_ct,uint32_t mid_popcount,uint32_t mid_start_bit,uintptr_t * __restrict mid_target,uintptr_t * __restrict compact_target)657 void ExpandBytearrNested(const void* __restrict compact_bitarr, const uintptr_t* __restrict mid_bitarr, const uintptr_t* __restrict top_expand_mask, uint32_t word_ct, uint32_t mid_popcount, uint32_t mid_start_bit, uintptr_t* __restrict mid_target, uintptr_t* __restrict compact_target) {
658   assert(mid_popcount);
659   const uint32_t leading_byte_ct = 1 + ((mid_popcount - 1) % kBitsPerWord) / CHAR_BIT;
660   uintptr_t compact_read_word = SubwordLoad(compact_bitarr, leading_byte_ct);
661   uint32_t compact_idx_lowbits = CHAR_BIT * (sizeof(intptr_t) - leading_byte_ct);
662   const uintptr_t* compact_bitarr_iter = R_CAST(const uintptr_t*, &(S_CAST(const unsigned char*, compact_bitarr)[leading_byte_ct]));
663   const uintptr_t* mid_bitarr_iter = mid_bitarr;
664   uint32_t mid_idx_lowbits = mid_start_bit;
665   uintptr_t mid_read_word = (*mid_bitarr_iter) >> mid_start_bit;
666   for (uint32_t widx = 0; widx != word_ct; ++widx) {
667     const uintptr_t top_word = top_expand_mask[widx];
668     uintptr_t mid_write_word = 0;
669     uintptr_t compact_write_word = 0;
670     if (top_word) {
671       const uint32_t top_set_ct = PopcountWord(top_word);
672       uint32_t next_mid_idx_lowbits = mid_idx_lowbits + top_set_ct;
673       if (next_mid_idx_lowbits <= kBitsPerWord) {
674         mid_write_word = _pdep_u64(mid_read_word, top_word);
675         if (top_set_ct != kBitsPerWord) {
676           mid_read_word = mid_read_word >> top_set_ct;
677         } else {
678           // avoid nasal demons
679           mid_read_word = 0;
680         }
681       } else {
682         uintptr_t next_mid_read_word = *(++mid_bitarr_iter);
683         next_mid_idx_lowbits -= kBitsPerWord;
684         mid_read_word |= next_mid_read_word << (kBitsPerWord - mid_idx_lowbits);
685         mid_write_word = _pdep_u64(mid_read_word, top_word);
686         if (next_mid_idx_lowbits != kBitsPerWord) {
687           mid_read_word = next_mid_read_word >> next_mid_idx_lowbits;
688         } else {
689           mid_read_word = 0;
690         }
691       }
692       mid_idx_lowbits = next_mid_idx_lowbits;
693       if (mid_write_word) {
694         const uint32_t mid_set_ct = PopcountWord(mid_write_word);
695         uint32_t next_compact_idx_lowbits = compact_idx_lowbits + mid_set_ct;
696         if (next_compact_idx_lowbits <= kBitsPerWord) {
697           compact_write_word = _pdep_u64(compact_read_word, mid_write_word);
698           if (mid_set_ct != kBitsPerWord) {
699             compact_read_word = compact_read_word >> mid_set_ct;
700           } else {
701             compact_read_word = 0;
702           }
703         } else {
704 #  ifdef __arm__
705 #    error "Unaligned accesses in ExpandBytearrNested()."
706 #  endif
707           uintptr_t next_compact_word = *compact_bitarr_iter++;
708           next_compact_idx_lowbits -= kBitsPerWord;
709           compact_read_word |= next_compact_word << (kBitsPerWord - compact_idx_lowbits);
710           compact_write_word = _pdep_u64(compact_read_word, mid_write_word);
711           if (next_compact_idx_lowbits != kBitsPerWord) {
712             compact_read_word = next_compact_word >> next_compact_idx_lowbits;
713           } else {
714             compact_read_word = 0;
715           }
716         }
717         compact_idx_lowbits = next_compact_idx_lowbits;
718       }
719     }
720     mid_target[widx] = mid_write_word;
721     compact_target[widx] = compact_write_word;
722   }
723 }
724 
ExpandThenSubsetBytearrNested(const void * __restrict compact_bitarr,const uintptr_t * __restrict mid_bitarr,const uintptr_t * __restrict top_expand_mask,const uintptr_t * __restrict subset_mask,uint32_t subset_size,uint32_t mid_popcount,uint32_t mid_start_bit,uintptr_t * __restrict mid_target,uintptr_t * __restrict compact_target)725 void ExpandThenSubsetBytearrNested(const void* __restrict compact_bitarr, const uintptr_t* __restrict mid_bitarr, const uintptr_t* __restrict top_expand_mask, const uintptr_t* __restrict subset_mask, uint32_t subset_size, uint32_t mid_popcount, uint32_t mid_start_bit, uintptr_t* __restrict mid_target, uintptr_t* __restrict compact_target) {
726   assert(mid_popcount);
727   const uint32_t leading_byte_ct = 1 + ((mid_popcount - 1) % kBitsPerWord) / CHAR_BIT;
728   uintptr_t compact_read_word = SubwordLoad(compact_bitarr, leading_byte_ct);
729   uint32_t compact_idx_lowbits = CHAR_BIT * (sizeof(intptr_t) - leading_byte_ct);
730   const uintptr_t* compact_bitarr_alias = R_CAST(const uintptr_t*, &(S_CAST(const unsigned char*, compact_bitarr)[leading_byte_ct]));
731   const uintptr_t* mid_bitarr_iter = mid_bitarr;
732   const uint32_t subset_size_lowbits = subset_size % kBitsPerWord;
733   const uint32_t write_widx_last = subset_size / kBitsPerWord;
734   uintptr_t mid_read_word = (*mid_bitarr_iter) >> mid_start_bit;
735   uintptr_t mid_output_word = 0;
736   uintptr_t compact_output_word = 0;
737   uint32_t mid_idx_lowbits = mid_start_bit;
738   uint32_t compact_widx = UINT32_MAX;  // deliberate overflow
739   uint32_t read_widx = UINT32_MAX;  // deliberate overflow
740   uint32_t write_idx_lowbits = 0;
741   uint32_t write_widx = 0;
742 
743   // bugfix (5 Feb 2018): missed a case in sparse subset_mask optimization
744   uint32_t mid_set_skip = 0;
745   while ((write_widx != write_widx_last) || (write_idx_lowbits != subset_size_lowbits)) {
746     uintptr_t subset_word;
747     uintptr_t mid_expanded_bits;
748     uint32_t mid_set_ct;
749     while (1) {
750       ++read_widx;
751       uintptr_t top_word = top_expand_mask[read_widx];
752       subset_word = subset_mask[read_widx];
753       mid_expanded_bits = 0;
754       if (top_word) {
755         uint32_t top_set_ct = PopcountWord(top_word);
756         uint32_t next_mid_idx_lowbits = mid_idx_lowbits + top_set_ct;
757         if (next_mid_idx_lowbits <= kBitsPerWord) {
758           mid_expanded_bits = _pdep_u64(mid_read_word, top_word);
759           if (top_set_ct != kBitsPerWord) {
760             mid_read_word = mid_read_word >> top_set_ct;
761           } else {
762             // avoid nasal demons
763             mid_read_word = 0;
764           }
765         } else {
766           uintptr_t next_mid_read_word = *(++mid_bitarr_iter);
767           next_mid_idx_lowbits -= kBitsPerWord;
768           mid_read_word |= next_mid_read_word << (kBitsPerWord - mid_idx_lowbits);
769           mid_expanded_bits = _pdep_u64(mid_read_word, top_word);
770           if (next_mid_idx_lowbits != kBitsPerWord) {
771             mid_read_word = next_mid_read_word >> next_mid_idx_lowbits;
772           } else {
773             mid_read_word = 0;
774           }
775         }
776         mid_idx_lowbits = next_mid_idx_lowbits;
777       }
778       mid_set_ct = PopcountWord(mid_expanded_bits);
779       if (subset_word) {
780         break;
781       }
782       mid_set_skip += mid_set_ct;
783     }
784 
785     uintptr_t mid_extracted_bits = 0;
786     uintptr_t compact_extracted_bits = 0;
787     uint32_t set_bit_ct = PopcountWord(subset_word);
788     if (mid_expanded_bits & subset_word) {
789       // lazy load
790       compact_idx_lowbits += mid_set_skip;
791       if (compact_idx_lowbits >= kBitsPerWord) {
792         compact_widx += compact_idx_lowbits / kBitsPerWord;
793         compact_idx_lowbits = compact_idx_lowbits % kBitsPerWord;
794 #  ifdef __arm__
795 #    error "Unaligned accesses in ExpandThenSubsetBytearrNested()."
796 #  endif
797         compact_read_word = compact_bitarr_alias[compact_widx] >> compact_idx_lowbits;
798       } else {
799         compact_read_word = compact_read_word >> mid_set_skip;
800       }
801       uint32_t next_compact_idx_lowbits = compact_idx_lowbits + mid_set_ct;
802       uintptr_t compact_expanded_bits;
803       if (next_compact_idx_lowbits <= kBitsPerWord) {
804         compact_expanded_bits = _pdep_u64(compact_read_word, mid_expanded_bits);
805         if (mid_set_ct != kBitsPerWord) {
806           compact_read_word = compact_read_word >> mid_set_ct;
807         }
808       } else {
809         uintptr_t next_compact_word = compact_bitarr_alias[++compact_widx];
810         next_compact_idx_lowbits -= kBitsPerWord;
811         compact_read_word |= next_compact_word << (kBitsPerWord - compact_idx_lowbits);
812         compact_expanded_bits = _pdep_u64(compact_read_word, mid_expanded_bits);
813         if (next_compact_idx_lowbits != kBitsPerWord) {
814           compact_read_word = next_compact_word >> next_compact_idx_lowbits;
815         }
816       }
817       compact_extracted_bits = _pext_u64(compact_expanded_bits, subset_word);
818       mid_extracted_bits = _pext_u64(mid_expanded_bits, subset_word);
819       compact_idx_lowbits = next_compact_idx_lowbits;
820       compact_output_word |= compact_extracted_bits << write_idx_lowbits;
821       mid_output_word |= mid_extracted_bits << write_idx_lowbits;
822       mid_set_skip = 0;
823     } else {
824       mid_set_skip += mid_set_ct;
825     }
826     const uint32_t new_write_idx_lowbits = write_idx_lowbits + set_bit_ct;
827     if (new_write_idx_lowbits >= kBitsPerWord) {
828       mid_target[write_widx] = mid_output_word;
829       compact_target[write_widx] = compact_output_word;
830       ++write_widx;
831       if (write_idx_lowbits) {
832         mid_output_word = mid_extracted_bits >> (kBitsPerWord - write_idx_lowbits);
833         compact_output_word = compact_extracted_bits >> (kBitsPerWord - write_idx_lowbits);
834       } else {
835         mid_output_word = 0;
836         compact_output_word = 0;
837       }
838     }
839     write_idx_lowbits = new_write_idx_lowbits % kBitsPerWord;
840   }
841   if (write_idx_lowbits) {
842     mid_target[write_widx] = mid_output_word;
843     compact_target[write_widx] = compact_output_word;
844   }
845 }
846 #else  // !USE_AVX2
CopyBitarrSubset(const uintptr_t * __restrict raw_bitarr,const uintptr_t * __restrict subset_mask,uint32_t output_bit_idx_end,uintptr_t * __restrict output_bitarr)847 void CopyBitarrSubset(const uintptr_t* __restrict raw_bitarr, const uintptr_t* __restrict subset_mask, uint32_t output_bit_idx_end, uintptr_t* __restrict output_bitarr) {
848   const uint32_t output_bit_idx_end_lowbits = output_bit_idx_end % kBitsPerWord;
849   uintptr_t* output_bitarr_iter = output_bitarr;
850   uintptr_t* output_bitarr_last = &(output_bitarr[output_bit_idx_end / kBitsPerWord]);
851   uintptr_t cur_output_word = 0;
852   uint32_t read_widx = UINT32_MAX;  // deliberate overflow
853   uint32_t write_idx_lowbits = 0;
854   while ((output_bitarr_iter != output_bitarr_last) || (write_idx_lowbits != output_bit_idx_end_lowbits)) {
855     uintptr_t cur_mask_word;
856     // sparse subset_mask optimization
857     // guaranteed to terminate since there's at least one more set bit
858     do {
859       cur_mask_word = subset_mask[++read_widx];
860     } while (!cur_mask_word);
861     uintptr_t cur_masked_input_word = raw_bitarr[read_widx] & cur_mask_word;
862     const uint32_t cur_mask_popcount = PopcountWord(cur_mask_word);
863     uintptr_t subsetted_input_word = 0;
864     while (cur_masked_input_word) {
865       const uintptr_t mask_word_high = (cur_mask_word | (cur_masked_input_word ^ (cur_masked_input_word - 1))) + 1;
866       if (!mask_word_high) {
867         subsetted_input_word |= cur_masked_input_word >> (kBitsPerWord - cur_mask_popcount);
868         break;
869       }
870       const uint32_t cur_read_end = ctzw(mask_word_high);
871       const uintptr_t bits_to_copy = cur_masked_input_word & (~mask_word_high);
872       cur_masked_input_word ^= bits_to_copy;
873       const uint32_t cur_write_end = PopcountWord(cur_mask_word & (~mask_word_high));
874       subsetted_input_word |= bits_to_copy >> (cur_read_end - cur_write_end);
875     }
876     cur_output_word |= subsetted_input_word << write_idx_lowbits;
877     const uint32_t new_write_idx_lowbits = write_idx_lowbits + cur_mask_popcount;
878     if (new_write_idx_lowbits >= kBitsPerWord) {
879       *output_bitarr_iter++ = cur_output_word;
880       // ...and these are the bits that fell off
881       // bugfix: unsafe to right-shift 64
882       if (write_idx_lowbits) {
883         cur_output_word = subsetted_input_word >> (kBitsPerWord - write_idx_lowbits);
884       } else {
885         cur_output_word = 0;
886       }
887     }
888     write_idx_lowbits = new_write_idx_lowbits % kBitsPerWord;
889   }
890   if (write_idx_lowbits) {
891     *output_bitarr_iter = cur_output_word;
892   }
893 }
894 
895 // Basic SSE2 implementation of Lauradoux/Walisch popcount.
PopcountVecsNoAvx2(const VecW * bit_vvec,uintptr_t vec_ct)896 uintptr_t PopcountVecsNoAvx2(const VecW* bit_vvec, uintptr_t vec_ct) {
897   // popcounts vptr[0..(vec_ct-1)].  Assumes vec_ct is a multiple of 3 (0 ok).
898   assert(!(vec_ct % 3));
899   const VecW m0 = vecw_setzero();
900   const VecW m1 = VCONST_W(kMask5555);
901   const VecW m2 = VCONST_W(kMask3333);
902   const VecW m4 = VCONST_W(kMask0F0F);
903   const VecW* bit_vvec_iter = bit_vvec;
904   VecW prev_sad_result = vecw_setzero();
905   VecW acc = vecw_setzero();
906   uintptr_t cur_incr = 30;
907   for (; ; vec_ct -= cur_incr) {
908     if (vec_ct < 30) {
909       if (!vec_ct) {
910         acc = acc + prev_sad_result;
911         return HsumW(acc);
912       }
913       cur_incr = vec_ct;
914     }
915     VecW inner_acc = vecw_setzero();
916     const VecW* bit_vvec_stop = &(bit_vvec_iter[cur_incr]);
917     do {
918       VecW count1 = *bit_vvec_iter++;
919       VecW count2 = *bit_vvec_iter++;
920       VecW half1 = *bit_vvec_iter++;
921       VecW half2 = vecw_srli(half1, 1) & m1;
922       half1 = half1 & m1;
923       // Two bits can represent values from 0-3, so make each pair in count1
924       // count2 store a partial bitcount covering themselves AND another bit
925       // from elsewhere.
926       count1 = count1 - (vecw_srli(count1, 1) & m1);
927       count2 = count2 - (vecw_srli(count2, 1) & m1);
928       count1 = count1 + half1;
929       count2 = count2 + half2;
930       // Four bits represent 0-15, so we can safely add four 0-3 partial
931       // bitcounts together.
932       count1 = (count1 & m2) + (vecw_srli(count1, 2) & m2);
933       count1 = count1 + (count2 & m2) + (vecw_srli(count2, 2) & m2);
934       // Accumulator stores sixteen 0-255 counts in parallel.
935       // (32 in AVX2 case, 4 in 32-bit case)
936       inner_acc = inner_acc + (count1 & m4) + (vecw_srli(count1, 4) & m4);
937     } while (bit_vvec_iter < bit_vvec_stop);
938     // _mm_sad_epu8() has better throughput than the previous method of
939     // horizontal-summing the bytes in inner_acc, by enough to compensate for
940     // the loop length being reduced from 30 to 15 vectors, but it has high
941     // latency.  We work around that by waiting till the end of the next full
942     // loop iteration to actually use the SAD result.
943     acc = acc + prev_sad_result;
944     prev_sad_result = vecw_bytesum(inner_acc, m0);
945   }
946 }
947 
PopcountVecsNoAvx2Intersect(const VecW * __restrict vvec1_iter,const VecW * __restrict vvec2_iter,uintptr_t vec_ct)948 static inline uintptr_t PopcountVecsNoAvx2Intersect(const VecW* __restrict vvec1_iter, const VecW* __restrict vvec2_iter, uintptr_t vec_ct) {
949   // popcounts vvec1 AND vvec2[0..(ct-1)].  ct is a multiple of 3.
950   assert(!(vec_ct % 3));
951   const VecW m0 = vecw_setzero();
952   const VecW m1 = VCONST_W(kMask5555);
953   const VecW m2 = VCONST_W(kMask3333);
954   const VecW m4 = VCONST_W(kMask0F0F);
955   VecW prev_sad_result = vecw_setzero();
956   VecW acc = vecw_setzero();
957   uintptr_t cur_incr = 30;
958   for (; ; vec_ct -= cur_incr) {
959     if (vec_ct < 30) {
960       if (!vec_ct) {
961         acc = acc + prev_sad_result;
962         return HsumW(acc);
963       }
964       cur_incr = vec_ct;
965     }
966     VecW inner_acc = vecw_setzero();
967     const VecW* vvec1_stop = &(vvec1_iter[cur_incr]);
968     do {
969       VecW count1 = (*vvec1_iter++) & (*vvec2_iter++);
970       VecW count2 = (*vvec1_iter++) & (*vvec2_iter++);
971       VecW half1 = (*vvec1_iter++) & (*vvec2_iter++);
972       const VecW half2 = vecw_srli(half1, 1) & m1;
973       half1 = half1 & m1;
974       count1 = count1 - (vecw_srli(count1, 1) & m1);
975       count2 = count2 - (vecw_srli(count2, 1) & m1);
976       count1 = count1 + half1;
977       count2 = count2 + half2;
978       count1 = (count1 & m2) + (vecw_srli(count1, 2) & m2);
979       count1 = count1 + (count2 & m2) + (vecw_srli(count2, 2) & m2);
980       inner_acc = inner_acc + (count1 & m4) + (vecw_srli(count1, 4) & m4);
981     } while (vvec1_iter < vvec1_stop);
982     acc = acc + prev_sad_result;
983     prev_sad_result = vecw_bytesum(inner_acc, m0);
984   }
985 }
986 
PopcountWordsIntersect(const uintptr_t * __restrict bitvec1_iter,const uintptr_t * __restrict bitvec2_iter,uintptr_t word_ct)987 uintptr_t PopcountWordsIntersect(const uintptr_t* __restrict bitvec1_iter, const uintptr_t* __restrict bitvec2_iter, uintptr_t word_ct) {
988   uintptr_t tot = 0;
989   const uintptr_t* bitvec1_end = &(bitvec1_iter[word_ct]);
990   const uintptr_t trivec_ct = word_ct / (3 * kWordsPerVec);
991   tot += PopcountVecsNoAvx2Intersect(R_CAST(const VecW*, bitvec1_iter), R_CAST(const VecW*, bitvec2_iter), trivec_ct * 3);
992   bitvec1_iter = &(bitvec1_iter[trivec_ct * (3 * kWordsPerVec)]);
993   bitvec2_iter = &(bitvec2_iter[trivec_ct * (3 * kWordsPerVec)]);
994   while (bitvec1_iter < bitvec1_end) {
995     tot += PopcountWord((*bitvec1_iter++) & (*bitvec2_iter++));
996   }
997   return tot;
998 }
999 
ExpandBytearr(const void * __restrict compact_bitarr,const uintptr_t * __restrict expand_mask,uint32_t word_ct,uint32_t expand_size,uint32_t read_start_bit,uintptr_t * __restrict target)1000 void ExpandBytearr(const void* __restrict compact_bitarr, const uintptr_t* __restrict expand_mask, uint32_t word_ct, uint32_t expand_size, uint32_t read_start_bit, uintptr_t* __restrict target) {
1001   ZeroWArr(word_ct, target);
1002   const uintptr_t* compact_bitarr_alias = S_CAST(const uintptr_t*, compact_bitarr);
1003   const uint32_t expand_sizex_m1 = expand_size + read_start_bit - 1;
1004   const uint32_t compact_widx_last = expand_sizex_m1 / kBitsPerWord;
1005   uint32_t compact_idx_lowbits = read_start_bit;
1006   uint32_t loop_len = kBitsPerWord;
1007   uintptr_t write_widx = 0;
1008   uintptr_t expand_mask_bits = expand_mask[0];
1009   for (uint32_t compact_widx = 0; ; ++compact_widx) {
1010     uintptr_t compact_word;
1011     if (compact_widx >= compact_widx_last) {
1012       if (compact_widx > compact_widx_last) {
1013         return;
1014       }
1015       loop_len = 1 + (expand_sizex_m1 % kBitsPerWord);
1016       // avoid possible segfault
1017       compact_word = SubwordLoad(&(compact_bitarr_alias[compact_widx]), DivUp(loop_len, CHAR_BIT));
1018     } else {
1019 #  ifdef __arm__
1020 #    error "Unaligned accesses in ExpandBytearr()."
1021 #  endif
1022       compact_word = compact_bitarr_alias[compact_widx];
1023     }
1024     for (; compact_idx_lowbits != loop_len; ++compact_idx_lowbits) {
1025       const uintptr_t lowbit = BitIter1y(expand_mask, &write_widx, &expand_mask_bits);
1026       // bugfix: can't just use (compact_word & 1) and compact_word >>= 1,
1027       // since we may skip the first bit on the first loop iteration
1028       if ((compact_word >> compact_idx_lowbits) & 1) {
1029         target[write_widx] |= lowbit;
1030       }
1031     }
1032     compact_idx_lowbits = 0;
1033   }
1034 }
1035 
ExpandThenSubsetBytearr(const void * __restrict compact_bitarr,const uintptr_t * __restrict expand_mask,const uintptr_t * __restrict subset_mask,uint32_t expand_size,uint32_t subset_size,uint32_t read_start_bit,uintptr_t * __restrict target)1036 void ExpandThenSubsetBytearr(const void* __restrict compact_bitarr, const uintptr_t* __restrict expand_mask, const uintptr_t* __restrict subset_mask, uint32_t expand_size, uint32_t subset_size, uint32_t read_start_bit, uintptr_t* __restrict target) {
1037   const uint32_t expand_sizex_m1 = expand_size + read_start_bit - 1;
1038   const uint32_t leading_byte_ct = 1 + (expand_sizex_m1 % kBitsPerWord) / CHAR_BIT;
1039   uint32_t read_idx_lowbits = CHAR_BIT * (sizeof(intptr_t) - leading_byte_ct);
1040   uintptr_t compact_read_word = SubwordLoad(compact_bitarr, leading_byte_ct) << read_idx_lowbits;
1041   read_idx_lowbits += read_start_bit;
1042   const uintptr_t* compact_bitarr_iter = R_CAST(const uintptr_t*, &(S_CAST(const unsigned char*, compact_bitarr)[leading_byte_ct]));
1043   const uint32_t subset_size_lowbits = subset_size % kBitsPerWord;
1044   uintptr_t* target_iter = target;
1045   uintptr_t* target_last = &(target[subset_size / kBitsPerWord]);
1046   uintptr_t compact_write_word = 0;
1047   uint32_t read_widx = 0;
1048   // further improvement is probably possible (e.g. use AVX2 lazy-load), but
1049   // I'll postpone for now
1050   uint32_t write_idx_lowbits = 0;
1051   while ((target_iter != target_last) || (write_idx_lowbits != subset_size_lowbits)) {
1052     const uintptr_t subset_word = subset_mask[read_widx];
1053     const uintptr_t expand_word = expand_mask[read_widx];
1054     ++read_widx;
1055     uintptr_t tmp_compact_write_word = 0;
1056     if (expand_word) {
1057       const uint32_t expand_bit_ct = PopcountWord(expand_word);
1058       uint32_t read_idx_lowbits_end = read_idx_lowbits + expand_bit_ct;
1059       uintptr_t tmp_compact_read_word = 0;
1060       if (read_idx_lowbits != kBitsPerWord) {
1061         tmp_compact_read_word = compact_read_word >> read_idx_lowbits;
1062       }
1063       if (read_idx_lowbits_end > kBitsPerWord) {
1064 #  ifdef __arm__
1065 #    error "Unaligned accesses in ExpandThenSubsetBytearr()."
1066 #  endif
1067         compact_read_word = *compact_bitarr_iter++;
1068         tmp_compact_read_word |= compact_read_word << (kBitsPerWord - read_idx_lowbits);
1069         read_idx_lowbits_end -= kBitsPerWord;
1070       }
1071       tmp_compact_read_word = bzhi_max(tmp_compact_read_word, expand_bit_ct);
1072       read_idx_lowbits = read_idx_lowbits_end;
1073       if (tmp_compact_read_word) {
1074         uintptr_t cur_intersect = subset_word & expand_word;
1075         while (cur_intersect) {
1076           const uintptr_t cur_intersect_and_arg = cur_intersect - k1LU;
1077           const uintptr_t lowmask = (cur_intersect ^ cur_intersect_and_arg) >> 1;
1078           const uint32_t read_idx_offset = PopcountWord(expand_word & lowmask);
1079           uintptr_t shifted_compact_read_word = tmp_compact_read_word >> read_idx_offset;
1080           if (shifted_compact_read_word & 1) {
1081             tmp_compact_write_word |= (k1LU << PopcountWord(subset_word & lowmask));
1082             if (shifted_compact_read_word == 1) {
1083               break;
1084             }
1085           }
1086           cur_intersect &= cur_intersect_and_arg;
1087         }
1088       }
1089       compact_write_word |= tmp_compact_write_word << write_idx_lowbits;
1090     }
1091     const uint32_t write_idx_lowbits_end = write_idx_lowbits + PopcountWord(subset_word);
1092     if (write_idx_lowbits_end >= kBitsPerWord) {
1093       *target_iter++ = compact_write_word;
1094       if (write_idx_lowbits) {
1095         compact_write_word = tmp_compact_write_word >> (kBitsPerWord - write_idx_lowbits);
1096       } else {
1097         compact_write_word = 0;
1098       }
1099     }
1100     write_idx_lowbits = write_idx_lowbits_end % kBitsPerWord;
1101   }
1102   if (write_idx_lowbits) {
1103     *target_iter = compact_write_word;
1104   }
1105 }
1106 
1107 // compact_bitarr := phaseinfo
1108 // mid_bitarr := phasepresent, [1 + het_ct]
1109 // top_expand_mask := all_hets, [raw_sample_ct]
ExpandBytearrNested(const void * __restrict compact_bitarr,const uintptr_t * __restrict mid_bitarr,const uintptr_t * __restrict top_expand_mask,uint32_t word_ct,uint32_t mid_popcount,uint32_t mid_start_bit,uintptr_t * __restrict mid_target,uintptr_t * __restrict compact_target)1110 void ExpandBytearrNested(const void* __restrict compact_bitarr, const uintptr_t* __restrict mid_bitarr, const uintptr_t* __restrict top_expand_mask, uint32_t word_ct, uint32_t mid_popcount, uint32_t mid_start_bit, uintptr_t* __restrict mid_target, uintptr_t* __restrict compact_target) {
1111   ZeroWArr(word_ct, mid_target);
1112   ZeroWArr(word_ct, compact_target);
1113   const uintptr_t* compact_bitarr_alias = S_CAST(const uintptr_t*, compact_bitarr);
1114   const uint32_t mid_popcount_m1 = mid_popcount - 1;
1115   const uint32_t compact_widx_last = mid_popcount_m1 / kBitsPerWord;
1116   uint32_t mid_idx = mid_start_bit;
1117   // can allow compact_idx_lowbits to be initialized to nonzero
1118   uint32_t loop_len = kBitsPerWord;
1119   uintptr_t write_widx = 0;
1120   uintptr_t top_expand_mask_bits = top_expand_mask[0];
1121   for (uint32_t compact_widx = 0; ; ++compact_widx) {
1122     uintptr_t compact_word;
1123     if (compact_widx >= compact_widx_last) {
1124       if (compact_widx > compact_widx_last) {
1125         return;
1126       }
1127       loop_len = 1 + (mid_popcount_m1 % kBitsPerWord);
1128       // avoid possible segfault
1129       compact_word = SubwordLoad(&(compact_bitarr_alias[compact_widx]), DivUp(loop_len, CHAR_BIT));
1130     } else {
1131 #ifdef __arm__
1132 #  error "Unaligned accesses in ExpandBytearrNested()."
1133 #endif
1134       compact_word = compact_bitarr_alias[compact_widx];
1135     }
1136     for (uint32_t compact_idx_lowbits = 0; compact_idx_lowbits != loop_len; ++mid_idx) {
1137       const uintptr_t lowbit = BitIter1y(top_expand_mask, &write_widx, &top_expand_mask_bits);
1138       if (IsSet(mid_bitarr, mid_idx)) {
1139         mid_target[write_widx] |= lowbit;
1140         compact_target[write_widx] |= lowbit * (compact_word & 1);
1141         compact_word >>= 1;
1142         ++compact_idx_lowbits;
1143       }
1144     }
1145   }
1146 }
1147 
ExpandThenSubsetBytearrNested(const void * __restrict compact_bitarr,const uintptr_t * __restrict mid_bitarr,const uintptr_t * __restrict top_expand_mask,const uintptr_t * __restrict subset_mask,uint32_t subset_size,uint32_t mid_popcount,uint32_t mid_start_bit,uintptr_t * __restrict mid_target,uintptr_t * __restrict compact_target)1148 void ExpandThenSubsetBytearrNested(const void* __restrict compact_bitarr, const uintptr_t* __restrict mid_bitarr, const uintptr_t* __restrict top_expand_mask, const uintptr_t* __restrict subset_mask, uint32_t subset_size, uint32_t mid_popcount, uint32_t mid_start_bit, uintptr_t* __restrict mid_target, uintptr_t* __restrict compact_target) {
1149   assert(mid_popcount);
1150   const uint32_t leading_byte_ct = 1 + ((mid_popcount - 1) % kBitsPerWord) / CHAR_BIT;
1151   uint32_t compact_idx_lowbits = CHAR_BIT * (sizeof(intptr_t) - leading_byte_ct);
1152   uintptr_t compact_read_word = SubwordLoad(compact_bitarr, leading_byte_ct) << compact_idx_lowbits;
1153   const uintptr_t* compact_bitarr_iter = R_CAST(const uintptr_t*, &(S_CAST(const unsigned char*, compact_bitarr)[leading_byte_ct]));
1154   // bugfix (12 Apr 2018): need to round down here
1155   const uint32_t subset_size_dl = subset_size / kBitsPerWord;
1156   const uint32_t subset_size_lowbits = subset_size % kBitsPerWord;
1157   const uintptr_t* mid_read_iter = mid_bitarr;
1158   uintptr_t mid_read_word = *mid_read_iter++;
1159   uintptr_t mid_write_word = 0;
1160   uintptr_t compact_write_word = 0;
1161   uint32_t mid_idx_lowbits = mid_start_bit;
1162   uint32_t write_idx_lowbits = 0;
1163   uint32_t write_widx = 0;
1164   uint32_t read_widx = 0;
1165   while ((write_widx != subset_size_dl) || (write_idx_lowbits != subset_size_lowbits)) {
1166     const uintptr_t subset_word = subset_mask[read_widx];
1167     const uintptr_t top_word = top_expand_mask[read_widx];
1168     ++read_widx;
1169     uintptr_t tmp_mid_write_word = 0;
1170     uintptr_t tmp_compact_write_word = 0;
1171     if (top_word) {
1172       const uint32_t top_set_ct = PopcountWord(top_word);
1173       uint32_t mid_idx_lowbits_end = mid_idx_lowbits + top_set_ct;
1174       uintptr_t tmp_mid_read_word = 0;
1175       if (mid_idx_lowbits != kBitsPerWord) {
1176         tmp_mid_read_word = mid_read_word >> mid_idx_lowbits;
1177       }
1178       if (mid_idx_lowbits_end > kBitsPerWord) {
1179         // be paranoid for now re: reading an extra word off the end of
1180         // mid_bitarr
1181         mid_read_word = *mid_read_iter++;
1182         tmp_mid_read_word |= mid_read_word << (kBitsPerWord - mid_idx_lowbits);
1183         mid_idx_lowbits_end -= kBitsPerWord;
1184       }
1185       tmp_mid_read_word = bzhi_max(tmp_mid_read_word, top_set_ct);
1186       mid_idx_lowbits = mid_idx_lowbits_end;
1187       if (tmp_mid_read_word) {
1188         const uint32_t mid_set_ct = PopcountWord(tmp_mid_read_word);
1189         uintptr_t tmp_compact_read_word;
1190         if (compact_idx_lowbits != kBitsPerWord) {
1191           const uint32_t compact_idx_lowbits_end = compact_idx_lowbits + mid_set_ct;
1192           tmp_compact_read_word = compact_read_word >> compact_idx_lowbits;
1193           // avoid reading off end of compact_bitarr here
1194           if (compact_idx_lowbits_end <= kBitsPerWord) {
1195             compact_idx_lowbits = compact_idx_lowbits_end;
1196           } else {
1197 #ifdef __arm__
1198 #  error "Unaligned accesses in ExpandThenSubsetBytearrNested()."
1199 #endif
1200             compact_read_word = *compact_bitarr_iter++;
1201             tmp_compact_read_word |= compact_read_word << (kBitsPerWord - compact_idx_lowbits);
1202             compact_idx_lowbits = compact_idx_lowbits_end - kBitsPerWord;
1203           }
1204         } else {
1205           // special case, can't right-shift 64
1206           compact_read_word = *compact_bitarr_iter++;
1207           compact_idx_lowbits = mid_set_ct;
1208           tmp_compact_read_word = compact_read_word;
1209         }
1210         tmp_compact_read_word = bzhi_max(tmp_compact_read_word, mid_set_ct);
1211 
1212         uintptr_t cur_masked_top = subset_word & top_word;
1213         while (cur_masked_top) {
1214           const uintptr_t cur_masked_top_and_arg = cur_masked_top - k1LU;
1215           const uintptr_t lowmask = (cur_masked_top ^ cur_masked_top_and_arg) >> 1;
1216           const uint32_t read_idx_offset = PopcountWord(top_word & lowmask);
1217           uintptr_t shifted_mid_read_word = tmp_mid_read_word >> read_idx_offset;
1218           if (shifted_mid_read_word & 1) {
1219             // bugfix (7 Sep 2017): forgot the "k1LU << " part of this
1220             const uintptr_t cur_bit = k1LU << PopcountWord(subset_word & lowmask);
1221             tmp_mid_write_word |= cur_bit;
1222             tmp_compact_write_word += cur_bit * ((tmp_compact_read_word >> (mid_set_ct - PopcountWord(shifted_mid_read_word))) & 1);
1223             if (shifted_mid_read_word == 1) {
1224               break;
1225             }
1226           }
1227           cur_masked_top &= cur_masked_top_and_arg;
1228         }
1229       }
1230       mid_write_word |= tmp_mid_write_word << write_idx_lowbits;
1231       compact_write_word |= tmp_compact_write_word << write_idx_lowbits;
1232     }
1233     const uint32_t write_idx_lowbits_end = write_idx_lowbits + PopcountWord(subset_word);
1234     if (write_idx_lowbits_end >= kBitsPerWord) {
1235       mid_target[write_widx] = mid_write_word;
1236       compact_target[write_widx] = compact_write_word;
1237       ++write_widx;
1238       if (write_idx_lowbits) {
1239         const uint32_t rshift = kBitsPerWord - write_idx_lowbits;
1240         mid_write_word = tmp_mid_write_word >> rshift;
1241         compact_write_word = tmp_compact_write_word >> rshift;
1242       } else {
1243         mid_write_word = 0;
1244         compact_write_word = 0;
1245       }
1246     }
1247     write_idx_lowbits = write_idx_lowbits_end % kBitsPerWord;
1248   }
1249   if (write_idx_lowbits) {
1250     mid_target[write_widx] = mid_write_word;
1251     compact_target[write_widx] = compact_write_word;
1252   }
1253 }
1254 #endif
PopcountBytes(const void * bitarr,uintptr_t byte_ct)1255 uintptr_t PopcountBytes(const void* bitarr, uintptr_t byte_ct) {
1256   const unsigned char* bitarr_uc = S_CAST(const unsigned char*, bitarr);
1257   const uint32_t lead_byte_ct = (-R_CAST(uintptr_t, bitarr_uc)) % kBytesPerVec;
1258   uintptr_t tot = 0;
1259   const uintptr_t* bitarr_iter;
1260   uint32_t trail_byte_ct;
1261   // bugfix: had wrong condition here
1262   if (byte_ct >= lead_byte_ct) {
1263 #ifdef __LP64__
1264     const uint32_t word_rem = lead_byte_ct % kBytesPerWord;
1265     if (word_rem) {
1266       tot = PopcountWord(ProperSubwordLoad(bitarr_uc, word_rem));
1267     }
1268     bitarr_iter = R_CAST(const uintptr_t*, &(bitarr_uc[word_rem]));
1269     if (lead_byte_ct >= kBytesPerWord) {
1270       tot += PopcountWord(*bitarr_iter++);
1271 #  ifdef USE_AVX2
1272       if (lead_byte_ct >= 2 * kBytesPerWord) {
1273         tot += PopcountWord(*bitarr_iter++);
1274         if (lead_byte_ct >= 3 * kBytesPerWord) {
1275           tot += PopcountWord(*bitarr_iter++);
1276         }
1277       }
1278 #  endif
1279     }
1280 #else
1281     if (lead_byte_ct) {
1282       tot = PopcountWord(ProperSubwordLoad(bitarr_uc, lead_byte_ct));
1283     }
1284     bitarr_iter = R_CAST(const uintptr_t*, &(bitarr_uc[lead_byte_ct]));
1285 #endif
1286     byte_ct -= lead_byte_ct;
1287     const uintptr_t word_ct = byte_ct / kBytesPerWord;
1288     // vec-alignment required here
1289     tot += PopcountWords(bitarr_iter, word_ct);
1290     bitarr_iter = &(bitarr_iter[word_ct]);
1291     trail_byte_ct = byte_ct % kBytesPerWord;
1292   } else {
1293     bitarr_iter = R_CAST(const uintptr_t*, bitarr_uc);
1294     // this may still be >= kBytesPerWord, so can't remove loop
1295     trail_byte_ct = byte_ct;
1296   }
1297   for (uint32_t bytes_remaining = trail_byte_ct; ; ) {
1298     uintptr_t cur_word;
1299     if (bytes_remaining < kBytesPerWord) {
1300       if (!bytes_remaining) {
1301         return tot;
1302       }
1303       cur_word = ProperSubwordLoad(bitarr_iter, bytes_remaining);
1304       bytes_remaining = 0;
1305     } else {
1306       cur_word = *bitarr_iter++;
1307       bytes_remaining -= kBytesPerWord;
1308     }
1309     tot += PopcountWord(cur_word);
1310   }
1311 }
1312 
PopcountBytesMasked(const void * bitarr,const uintptr_t * mask_arr,uintptr_t byte_ct)1313 uintptr_t PopcountBytesMasked(const void* bitarr, const uintptr_t* mask_arr, uintptr_t byte_ct) {
1314   // todo: try modifying PopcountWordsIntersect() to use unaligned load
1315   // instructions; then, if there is no performance penalty, try modifying this
1316   // main loop to call it.
1317   const uintptr_t word_ct = byte_ct / kBytesPerWord;
1318 #ifdef USE_SSE42
1319   const uintptr_t* bitarr_w = S_CAST(const uintptr_t*, bitarr);
1320   uintptr_t tot = 0;
1321   for (uintptr_t widx = 0; widx != word_ct; ++widx) {
1322     tot += PopcountWord(bitarr_w[widx] & mask_arr[widx]);
1323   }
1324   const uint32_t trail_byte_ct = byte_ct % kBytesPerWord;
1325   if (trail_byte_ct) {
1326     uintptr_t cur_word = ProperSubwordLoad(&(bitarr_w[word_ct]), trail_byte_ct);
1327     tot += PopcountWord(cur_word & mask_arr[word_ct]);
1328   }
1329   return tot;
1330 #else
1331   const uintptr_t* bitarr_iter = S_CAST(const uintptr_t*, bitarr);
1332   const uintptr_t mainblock_word_ct = word_ct - (word_ct % (24 / kBytesPerWord));
1333   const uintptr_t* bitarr_24b_end = &(bitarr_iter[mainblock_word_ct]);
1334   const uintptr_t* mask_arr_iter = mask_arr;
1335   uintptr_t tot = 0;
1336   while (bitarr_iter < bitarr_24b_end) {
1337     uintptr_t loader = (*bitarr_iter++) & (*mask_arr_iter++);
1338     uintptr_t ulii = loader - ((loader >> 1) & kMask5555);
1339     loader = (*bitarr_iter++) & (*mask_arr_iter++);
1340     uintptr_t uljj = loader - ((loader >> 1) & kMask5555);
1341     loader = (*bitarr_iter++) & (*mask_arr_iter++);
1342     ulii += (loader >> 1) & kMask5555;
1343     uljj += loader & kMask5555;
1344     ulii = (ulii & kMask3333) + ((ulii >> 2) & kMask3333);
1345     ulii += (uljj & kMask3333) + ((uljj >> 2) & kMask3333);
1346     uintptr_t tmp_stor = (ulii & kMask0F0F) + ((ulii >> 4) & kMask0F0F);
1347 
1348 #  ifndef __LP64__
1349     loader = (*bitarr_iter++) & (*mask_arr_iter++);
1350     ulii = loader - ((loader >> 1) & kMask5555);
1351     loader = (*bitarr_iter++) & (*mask_arr_iter++);
1352     uljj = loader - ((loader >> 1) & kMask5555);
1353     loader = (*bitarr_iter++) & (*mask_arr_iter++);
1354     ulii += (loader >> 1) & kMask5555;
1355     uljj += loader & kMask5555;
1356     ulii = (ulii & kMask3333) + ((ulii >> 2) & kMask3333);
1357     ulii += (uljj & kMask3333) + ((uljj >> 2) & kMask3333);
1358     tmp_stor += (ulii & kMask0F0F) + ((ulii >> 4) & kMask0F0F);
1359 #  endif
1360 
1361     // 32-bit case: each 8-bit slot stores a number in 0..48.  Multiplying by
1362     // 0x01010101 is equivalent to the left-shifts and adds we need to sum
1363     // those four 8-bit numbers in the high-order slot.
1364     // 64-bit case: each 8-bit slot stores a number in 0..24.
1365     tot += (tmp_stor * kMask0101) >> (kBitsPerWord - 8);
1366   }
1367   for (uint32_t trail_byte_ct = byte_ct - (mainblock_word_ct * kBytesPerWord); ; ) {
1368     uintptr_t cur_word;
1369     if (trail_byte_ct < kBytesPerWord) {
1370       if (!trail_byte_ct) {
1371         return tot;
1372       }
1373       cur_word = ProperSubwordLoad(bitarr_iter, trail_byte_ct);
1374       trail_byte_ct = 0;
1375     } else {
1376       cur_word = *bitarr_iter++;
1377       trail_byte_ct -= kBytesPerWord;
1378     }
1379     tot += PopcountWord(cur_word & (*mask_arr_iter++));
1380   }
1381 #endif
1382 }
1383 
FillCumulativePopcounts(const uintptr_t * subset_mask,uint32_t word_ct,uint32_t * cumulative_popcounts)1384 void FillCumulativePopcounts(const uintptr_t* subset_mask, uint32_t word_ct, uint32_t* cumulative_popcounts) {
1385   assert(word_ct);
1386   const uint32_t word_ct_m1 = word_ct - 1;
1387   uint32_t cur_sum = 0;
1388   for (uint32_t widx = 0; widx != word_ct_m1; ++widx) {
1389     cumulative_popcounts[widx] = cur_sum;
1390     cur_sum += PopcountWord(subset_mask[widx]);
1391   }
1392   cumulative_popcounts[word_ct_m1] = cur_sum;
1393 }
1394 
UidxsToIdxs(const uintptr_t * subset_mask,const uint32_t * subset_cumulative_popcounts,const uintptr_t idx_list_len,uint32_t * idx_list)1395 void UidxsToIdxs(const uintptr_t* subset_mask, const uint32_t* subset_cumulative_popcounts, const uintptr_t idx_list_len, uint32_t* idx_list) {
1396   uint32_t* idx_list_end = &(idx_list[idx_list_len]);
1397   for (uint32_t* idx_list_iter = idx_list; idx_list_iter != idx_list_end; ++idx_list_iter) {
1398     *idx_list_iter = RawToSubsettedPos(subset_mask, subset_cumulative_popcounts, *idx_list_iter);
1399   }
1400 }
1401 
Expand1bitTo8(const void * __restrict bytearr,uint32_t input_bit_ct,uint32_t incr,uintptr_t * __restrict dst)1402 void Expand1bitTo8(const void* __restrict bytearr, uint32_t input_bit_ct, uint32_t incr, uintptr_t* __restrict dst) {
1403   const unsigned char* bytearr_uc = S_CAST(const unsigned char*, bytearr);
1404   const uint32_t input_bit_ct_plus = input_bit_ct + kBytesPerWord - 1;
1405 #ifdef USE_SSE42
1406   const uint32_t input_byte_ct = input_bit_ct_plus / 8;
1407   const uint32_t fullvec_ct = input_byte_ct / (kBytesPerVec / 8);
1408   uint32_t byte_idx = 0;
1409   if (fullvec_ct) {
1410     const Vec8thUint* bytearr_alias = R_CAST(const Vec8thUint*, bytearr);
1411 #  ifdef USE_AVX2
1412     const VecUc byte_gather = R_CAST(VecUc, _mm256_setr_epi64x(0, kMask0101, 2 * kMask0101, 3 * kMask0101));
1413     const VecUc bit_mask = R_CAST(VecUc, _mm256_set1_epi64x(0x7fbfdfeff7fbfdfeLL));
1414 #  else
1415     const VecUc byte_gather = R_CAST(VecUc, _mm_setr_epi32(0, 0, 0x01010101, 0x01010101));
1416     const VecUc bit_mask = R_CAST(VecUc, _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL));
1417 #  endif
1418     const VecUc all1 = vecuc_set1(255);
1419     const VecUc subfrom = vecuc_set1(incr);
1420     VecUc* dst_alias = R_CAST(VecUc*, dst);
1421     for (uint32_t vec_idx = 0; vec_idx != fullvec_ct; ++vec_idx) {
1422 #  ifdef USE_AVX2
1423       VecUc vmask = R_CAST(VecUc, _mm256_set1_epi32(bytearr_alias[vec_idx]));
1424 #  else
1425       VecUc vmask = R_CAST(VecUc, _mm_set1_epi16(bytearr_alias[vec_idx]));
1426 #  endif
1427       vmask = vecuc_shuffle8(vmask, byte_gather);
1428       vmask = vmask | bit_mask;
1429       vmask = (vmask == all1);
1430       const VecUc result = subfrom - vmask;
1431       vecuc_storeu(&(dst_alias[vec_idx]), result);
1432     }
1433     byte_idx = fullvec_ct * (kBytesPerVec / 8);
1434   }
1435   const uintptr_t incr_word = incr * kMask0101;
1436   for (; byte_idx != input_byte_ct; ++byte_idx) {
1437     const uintptr_t input_byte = bytearr_uc[byte_idx];
1438 #  ifdef USE_AVX2
1439     const uintptr_t input_byte_scatter = _pdep_u64(input_byte, kMask0101);
1440 #  else
1441     const uintptr_t input_byte_scatter = (((input_byte & 0xfe) * 0x2040810204080LLU) & kMask0101) | (input_byte & 1);
1442 #  endif
1443     dst[byte_idx] = incr_word + input_byte_scatter;
1444   }
1445 #else
1446   const uintptr_t incr_word = incr * kMask0101;
1447 #  ifdef __LP64__
1448   const uint32_t input_byte_ct = input_bit_ct_plus / 8;
1449   for (uint32_t uii = 0; uii != input_byte_ct; ++uii) {
1450     // this operation maps binary hgfedcba to h0000000g0000000f...
1451     //                                        ^       ^       ^
1452     //                                        |       |       |
1453     //                                       56      48      40
1454     // 1. (cur_variant_include_word & 0xfe) gives us hgfedcb0; necessary to
1455     //    avoid carryover.
1456     // 2. multiply by the number with bits 7, 14, 21, ..., 49 set, to get
1457     //    hgfedcbhgfedcbhgf...
1458     //    ^       ^       ^
1459     //    |       |       |
1460     //   56      48      40
1461     // 3. mask out all but bits 8, 16, 24, ..., 56
1462     // todo: test if this actually beats the per-character loop...
1463     const uintptr_t input_byte = bytearr_uc[uii];
1464     const uintptr_t input_byte_scatter = (((input_byte & 0xfe) * 0x2040810204080LLU) & kMask0101) | (input_byte & 1);
1465     dst[uii] = incr_word + input_byte_scatter;
1466   }
1467 #  else
1468   const uint32_t fullbyte_ct = input_bit_ct_plus / 8;
1469   for (uint32_t uii = 0; uii != fullbyte_ct; ++uii) {
1470     // dcba -> d0000000c0000000b0000000a
1471     const uintptr_t input_byte = bytearr_uc[uii];
1472     uintptr_t input_byte_scatter = ((input_byte & 0xf) * 0x204081) & kMask0101;
1473     dst[2 * uii] = incr_word + input_byte_scatter;
1474     input_byte_scatter = ((input_byte >> 4) * 0x204081) & kMask0101;
1475     dst[2 * uii + 1] = incr_word + input_byte_scatter;
1476   }
1477   if (input_bit_ct_plus & 4) {
1478     uintptr_t input_byte = bytearr_uc[fullbyte_ct];
1479     // input_bit_ct mod 8 in 1..4, so high bits zeroed out
1480     uintptr_t input_byte_scatter = (input_byte * 0x204081) & kMask0101;
1481     dst[2 * fullbyte_ct] = incr_word + input_byte_scatter;
1482   }
1483 #  endif
1484 #endif
1485 }
1486 
Expand1bitTo16(const void * __restrict bytearr,uint32_t input_bit_ct,uint32_t incr,uintptr_t * __restrict dst)1487 void Expand1bitTo16(const void* __restrict bytearr, uint32_t input_bit_ct, uint32_t incr, uintptr_t* __restrict dst) {
1488   const unsigned char* bytearr_uc = S_CAST(const unsigned char*, bytearr);
1489 #ifdef USE_SSE42
1490   const uint32_t input_nybble_ct = DivUp(input_bit_ct, 4);
1491   const uint32_t fullvec_ct = input_nybble_ct / (kBytesPerVec / 8);
1492   uint32_t byte_idx = 0;
1493   if (fullvec_ct) {
1494     const Vec16thUint* bytearr_alias = R_CAST(const Vec16thUint*, bytearr);
1495 #  ifdef USE_AVX2
1496     const VecU16 byte_gather = R_CAST(VecU16, _mm256_setr_epi64x(0, 0, kMask0101, kMask0101));
1497     const VecU16 bit_mask = R_CAST(VecU16, _mm256_set_epi32(0xff7fffbfU, 0xffdfffefU, 0xfff7fffbU, 0xfffdfffeU, 0xff7fffbfU, 0xffdfffefU, 0xfff7fffbU, 0xfffdfffeU));
1498 #  else
1499     const VecU16 byte_gather = VCONST_S(0);
1500     const VecU16 bit_mask = R_CAST(VecU16, _mm_set_epi32(0xff7fffbfU, 0xffdfffefU, 0xfff7fffbU, 0xfffdfffeU));
1501 #  endif
1502     const VecU16 all1 = VCONST_S(0xffff);
1503     const VecU16 subfrom = vecu16_set1(incr);
1504     VecU16* dst_alias = R_CAST(VecU16*, dst);
1505     // todo: check whether this is actually any better than the non-vectorized
1506     // loop
1507     for (uint32_t vec_idx = 0; vec_idx != fullvec_ct; ++vec_idx) {
1508 #  ifdef USE_AVX2
1509       VecU16 vmask = R_CAST(VecU16, _mm256_set1_epi16(bytearr_alias[vec_idx]));
1510 #  else
1511       VecU16 vmask = R_CAST(VecU16, _mm_set1_epi8(bytearr_alias[vec_idx]));
1512 #  endif
1513       vmask = vecu16_shuffle8(vmask, byte_gather);
1514       vmask = vmask | bit_mask;
1515       vmask = (vmask == all1);
1516       const VecU16 result = subfrom - vmask;
1517       vecu16_storeu(&(dst_alias[vec_idx]), result);
1518     }
1519     byte_idx = fullvec_ct * (kBytesPerVec / 16);
1520   }
1521   const uintptr_t incr_word = incr * kMask0001;
1522   const uint32_t fullbyte_ct = input_nybble_ct / 2;
1523   for (; byte_idx != fullbyte_ct; ++byte_idx) {
1524     const uintptr_t input_byte = bytearr_uc[byte_idx];
1525     const uintptr_t input_byte_scatter = input_byte * 0x200040008001LLU;
1526     const uintptr_t write0 = input_byte_scatter & kMask0001;
1527     const uintptr_t write1 = (input_byte_scatter >> 4) & kMask0001;
1528     dst[2 * byte_idx] = incr_word + write0;
1529     dst[2 * byte_idx + 1] = incr_word + write1;
1530   }
1531   if (input_nybble_ct % 2) {
1532     const uintptr_t input_byte = bytearr_uc[byte_idx];
1533     const uintptr_t write0 = (input_byte * 0x200040008001LLU) & kMask0001;
1534     dst[input_nybble_ct - 1] = incr_word + write0;
1535   }
1536 #else
1537   const uintptr_t incr_word = incr * kMask0001;
1538 #  ifdef __LP64__
1539   const uint32_t input_nybble_ct = DivUp(input_bit_ct, 4);
1540   const uint32_t fullbyte_ct = input_nybble_ct / 2;
1541   for (uint32_t uii = 0; uii != fullbyte_ct; ++uii) {
1542     const uintptr_t input_byte = bytearr_uc[uii];
1543     const uintptr_t input_byte_scatter = input_byte * 0x200040008001LLU;
1544     const uintptr_t write0 = input_byte_scatter & kMask0001;
1545     const uintptr_t write1 = (input_byte_scatter >> 4) & kMask0001;
1546     dst[2 * uii] = incr_word + write0;
1547     dst[2 * uii + 1] = incr_word + write1;
1548   }
1549   if (input_nybble_ct % 2) {
1550     const uintptr_t input_byte = bytearr_uc[fullbyte_ct];
1551     const uintptr_t write0 = (input_byte * 0x200040008001LLU) & kMask0001;
1552     dst[input_nybble_ct - 1] = incr_word + write0;
1553   }
1554 #  else
1555   const uint32_t fullbyte_ct = input_bit_ct / 8;
1556   for (uint32_t uii = 0; uii != fullbyte_ct; ++uii) {
1557     uintptr_t input_byte = bytearr_uc[uii];
1558     const uintptr_t input_byte_scatter = input_byte * 0x8001;
1559     dst[4 * uii] = (input_byte_scatter & kMask0001) + incr_word;
1560     dst[4 * uii + 1] = ((input_byte_scatter >> 2) & kMask0001) + incr_word;
1561     dst[4 * uii + 2] = ((input_byte_scatter >> 4) & kMask0001) + incr_word;
1562     dst[4 * uii + 3] = ((input_byte_scatter >> 6) & kMask0001) + incr_word;
1563   }
1564   const uint32_t remainder = input_bit_ct % 8;
1565   if (remainder) {
1566     uintptr_t input_byte = bytearr_uc[fullbyte_ct];
1567     uint16_t* dst_alias = R_CAST(uint16_t*, &(dst[4 * fullbyte_ct]));
1568     for (uint32_t uii = 0; uii < remainder; ++uii) {
1569       dst_alias[uii] = (input_byte & 1) + incr;
1570       input_byte = input_byte >> 1;
1571     }
1572   }
1573 #  endif
1574 #endif
1575 }
1576 
1577 static_assert(kPglBitTransposeBatch == S_CAST(uint32_t, kBitsPerCacheline), "TransposeBitblock64() needs to be updated.");
1578 #ifdef __LP64__
TransposeBitblock64(const uintptr_t * read_iter,uintptr_t read_ul_stride,uintptr_t write_ul_stride,uint32_t read_row_ct,uint32_t write_row_ct,uintptr_t * write_iter,VecW * __restrict buf0,VecW * __restrict buf1)1579 void TransposeBitblock64(const uintptr_t* read_iter, uintptr_t read_ul_stride, uintptr_t write_ul_stride, uint32_t read_row_ct, uint32_t write_row_ct, uintptr_t* write_iter, VecW* __restrict buf0, VecW* __restrict buf1) {
1580   // We need to perform the equivalent of 9 shuffles (assuming a full-size
1581   // 512x512 bitblock).
1582   // The first shuffles are performed by the ingestion loop: we write the first
1583   // word from every row to buf0, then the second word from every row, etc.,
1584   // yielding
1585   //   (0,0) ...   (0,63)  (1,0) ...   (1,63)  (2,0) ...   (511,63)
1586   //   (0,64) ...  (0,127) (1,64) ...  (1,127) (2,64) ...  (511,127)
1587   //   ...
1588   //   (0,448) ... (0,511) (1,448) ... (1,511) (2,448) ... (511,511)
1589   // in terms of the original bit positions.
1590   // Since each input row has 8 words, this amounts to 3 shuffles.
1591   //
1592   // The second step writes
1593   //   (0,0) (0,1) ... (0,7)   (1,0) (1,1) ... (1,7) ...   (511,7)
1594   //   (0,8) (0,9) ... (0,15)  (1,8) (1,9) ... (1,15) ...  (511,15)
1595   //   ...
1596   //   (0,504) ...     (0,511) (1,504) ...     (1,511) ... (511,511)
1597   // to buf1, performing the equivalent of 3 shuffles, and the third step
1598   // finishes the transpose using movemask.
1599   //
1600   // buf0 and buf1 must both be 32KiB vector-aligned buffers.
1601 
1602   const uint32_t buf0_row_ct = DivUp(write_row_ct, 64);
1603   {
1604     uintptr_t* buf0_ul = R_CAST(uintptr_t*, buf0);
1605     const uint32_t zfill_ct = (-read_row_ct) & 63;
1606     for (uint32_t bidx = 0; bidx != buf0_row_ct; ++bidx) {
1607       const uintptr_t* read_iter_tmp = &(read_iter[bidx]);
1608       uintptr_t* buf0_row_start = &(buf0_ul[512 * bidx]);
1609       for (uint32_t uii = 0; uii != read_row_ct; ++uii) {
1610         buf0_row_start[uii] = *read_iter_tmp;
1611         read_iter_tmp = &(read_iter_tmp[read_ul_stride]);
1612       }
1613       // This is a simple way of fulfilling the trailing-zero part of the
1614       // function contract.
1615       // (   buf0 rows zeroed out to 512 bytes
1616       //  -> buf1 rows zeroed out to 64 bytes
1617       //  -> output rows zeroed out to 8 bytes)
1618       ZeroWArr(zfill_ct, &(buf0_row_start[read_row_ct]));
1619     }
1620   }
1621   // Each width-unit corresponds to 64 input rows.
1622   const uint32_t buf_row_xwidth = DivUp(read_row_ct, 64);
1623   {
1624     const VecW* buf0_read_iter = buf0;
1625     uintptr_t* write_iter0 = R_CAST(uintptr_t*, buf1);
1626 #  ifdef USE_SSE42
1627     const VecW gather_u16s = vecw_setr8(0, 8, 1, 9, 2, 10, 3, 11,
1628                                         4, 12, 5, 13, 6, 14, 7, 15);
1629 #    ifdef USE_AVX2
1630     const VecW gather_u32s = vecw_setr8(0, 1, 8, 9, 2, 3, 10, 11,
1631                                         4, 5, 12, 13, 6, 7, 14, 15);
1632 #    endif
1633 #  else
1634     const VecW m8 = VCONST_W(kMask00FF);
1635 #  endif
1636     const uint32_t buf0_row_clwidth = buf_row_xwidth * 8;
1637     for (uint32_t bidx = 0; bidx != buf0_row_ct; ++bidx) {
1638       uintptr_t* write_iter1 = &(write_iter0[64]);
1639       uintptr_t* write_iter2 = &(write_iter1[64]);
1640       uintptr_t* write_iter3 = &(write_iter2[64]);
1641       uintptr_t* write_iter4 = &(write_iter3[64]);
1642       uintptr_t* write_iter5 = &(write_iter4[64]);
1643       uintptr_t* write_iter6 = &(write_iter5[64]);
1644       uintptr_t* write_iter7 = &(write_iter6[64]);
1645       for (uint32_t clidx = 0; clidx != buf0_row_clwidth; ++clidx) {
1646 #  ifdef USE_AVX2
1647         VecW loader0 = buf0_read_iter[clidx * 2];
1648         VecW loader1 = buf0_read_iter[clidx * 2 + 1];
1649         //    (0,0) (0,1) ... (0,7) (1,0) (1,1) ... (1,7) (2,0) ... (3,7)
1650         // -> (0,0) (1,0) (0,1) (1,1) (0,2) .... (1,7) (2,0) (3,0) (2,1) ...
1651         loader0 = vecw_shuffle8(loader0, gather_u16s);
1652         loader1 = vecw_shuffle8(loader1, gather_u16s);
1653         // -> (0,0) (1,0) (0,1) (1,1) (0,2) (1,2) (0,3) (1,3) (2,0) (3,0) ...
1654         VecW vec_lo = vecw_permute0xd8_if_avx2(loader0);
1655         VecW vec_hi = vecw_permute0xd8_if_avx2(loader1);
1656         // -> (0,0) (1,0) (2,0) (3,0) (0,1) (1,1) (2,1) (3,1) (0,2) ...
1657         vec_lo = vecw_shuffle8(vec_lo, gather_u32s);
1658         // -> (4,0) (5,0) (6,0) (7,0) (4,1) (5,1) (6,1) (7,1) (4,2) ...
1659         vec_hi = vecw_shuffle8(vec_hi, gather_u32s);
1660         const VecW final0145 = vecw_unpacklo32(vec_lo, vec_hi);
1661         const VecW final2367 = vecw_unpackhi32(vec_lo, vec_hi);
1662         write_iter0[clidx] = vecw_extract64_0(final0145);
1663         write_iter1[clidx] = vecw_extract64_1(final0145);
1664         write_iter2[clidx] = vecw_extract64_0(final2367);
1665         write_iter3[clidx] = vecw_extract64_1(final2367);
1666         write_iter4[clidx] = vecw_extract64_2(final0145);
1667         write_iter5[clidx] = vecw_extract64_3(final0145);
1668         write_iter6[clidx] = vecw_extract64_2(final2367);
1669         write_iter7[clidx] = vecw_extract64_3(final2367);
1670 #  else  // !USE_AVX2
1671         VecW loader0 = buf0_read_iter[clidx * 4];
1672         VecW loader1 = buf0_read_iter[clidx * 4 + 1];
1673         VecW loader2 = buf0_read_iter[clidx * 4 + 2];
1674         VecW loader3 = buf0_read_iter[clidx * 4 + 3];
1675         //    (0,0) (0,1) ... (0,7) (1,0) (1,1) ... (1,7)
1676         // -> (0,0) (1,0) (0,1) (1,1) (0,2) ... (1,7)
1677 #    ifdef USE_SSE42
1678         loader0 = vecw_shuffle8(loader0, gather_u16s);
1679         loader1 = vecw_shuffle8(loader1, gather_u16s);
1680         loader2 = vecw_shuffle8(loader2, gather_u16s);
1681         loader3 = vecw_shuffle8(loader3, gather_u16s);
1682 #    else
1683         VecW tmp_lo = vecw_unpacklo8(loader0, loader1);
1684         VecW tmp_hi = vecw_unpackhi8(loader0, loader1);
1685         loader0 = vecw_blendv(vecw_slli(tmp_hi, 8), tmp_lo, m8);
1686         loader1 = vecw_blendv(tmp_hi, vecw_srli(tmp_lo, 8), m8);
1687         tmp_lo = vecw_unpacklo8(loader2, loader3);
1688         tmp_hi = vecw_unpackhi8(loader2, loader3);
1689         loader2 = vecw_blendv(vecw_slli(tmp_hi, 8), tmp_lo, m8);
1690         loader3 = vecw_blendv(tmp_hi, vecw_srli(tmp_lo, 8), m8);
1691 #    endif
1692         // -> (0,0) (1,0) (2,0) (3,0) (0,1) (1,1) (2,1) (3,1) (0,2) ...
1693         const VecW lo_0123 = vecw_unpacklo16(loader0, loader1);
1694         // -> (0,4) (1,4) (2,4) (3,4) (0,5) (1,5) (2,5) (3,5) (0,6) ...
1695         const VecW lo_4567 = vecw_unpackhi16(loader0, loader1);
1696         const VecW hi_0123 = vecw_unpacklo16(loader2, loader3);
1697         const VecW hi_4567 = vecw_unpackhi16(loader2, loader3);
1698 
1699         VecW final01 = vecw_unpacklo32(lo_0123, hi_0123);
1700         VecW final23 = vecw_unpackhi32(lo_0123, hi_0123);
1701         VecW final45 = vecw_unpacklo32(lo_4567, hi_4567);
1702         VecW final67 = vecw_unpackhi32(lo_4567, hi_4567);
1703         write_iter0[clidx] = vecw_extract64_0(final01);
1704         write_iter1[clidx] = vecw_extract64_1(final01);
1705         write_iter2[clidx] = vecw_extract64_0(final23);
1706         write_iter3[clidx] = vecw_extract64_1(final23);
1707         write_iter4[clidx] = vecw_extract64_0(final45);
1708         write_iter5[clidx] = vecw_extract64_1(final45);
1709         write_iter6[clidx] = vecw_extract64_0(final67);
1710         write_iter7[clidx] = vecw_extract64_1(final67);
1711 #  endif  // !USE_AVX2
1712       }
1713       buf0_read_iter = &(buf0_read_iter[512 / kWordsPerVec]);
1714       write_iter0 = &(write_iter7[64]);
1715     }
1716   }
1717   const VecW* buf1_read_iter = buf1;
1718   const uint32_t write_v8ui_stride = kVec8thUintPerWord * write_ul_stride;
1719   const uint32_t buf1_fullrow_ct = write_row_ct / 8;
1720   const uint32_t buf1_row_vecwidth = buf_row_xwidth * (8 / kWordsPerVec);
1721   Vec8thUint* write_iter0 = R_CAST(Vec8thUint*, write_iter);
1722   for (uint32_t bidx = 0; bidx != buf1_fullrow_ct; ++bidx) {
1723     Vec8thUint* write_iter1 = &(write_iter0[write_v8ui_stride]);
1724     Vec8thUint* write_iter2 = &(write_iter1[write_v8ui_stride]);
1725     Vec8thUint* write_iter3 = &(write_iter2[write_v8ui_stride]);
1726     Vec8thUint* write_iter4 = &(write_iter3[write_v8ui_stride]);
1727     Vec8thUint* write_iter5 = &(write_iter4[write_v8ui_stride]);
1728     Vec8thUint* write_iter6 = &(write_iter5[write_v8ui_stride]);
1729     Vec8thUint* write_iter7 = &(write_iter6[write_v8ui_stride]);
1730     for (uint32_t vidx = 0; vidx != buf1_row_vecwidth; ++vidx) {
1731       VecW loader = buf1_read_iter[vidx];
1732       write_iter7[vidx] = vecw_movemask(loader);
1733       loader = vecw_slli(loader, 1);
1734       write_iter6[vidx] = vecw_movemask(loader);
1735       loader = vecw_slli(loader, 1);
1736       write_iter5[vidx] = vecw_movemask(loader);
1737       loader = vecw_slli(loader, 1);
1738       write_iter4[vidx] = vecw_movemask(loader);
1739       loader = vecw_slli(loader, 1);
1740       write_iter3[vidx] = vecw_movemask(loader);
1741       loader = vecw_slli(loader, 1);
1742       write_iter2[vidx] = vecw_movemask(loader);
1743       loader = vecw_slli(loader, 1);
1744       write_iter1[vidx] = vecw_movemask(loader);
1745       loader = vecw_slli(loader, 1);
1746       write_iter0[vidx] = vecw_movemask(loader);
1747     }
1748     buf1_read_iter = &(buf1_read_iter[64 / kWordsPerVec]);
1749     write_iter0 = &(write_iter7[write_v8ui_stride]);
1750   }
1751   const uint32_t row_ct_rem = write_row_ct % 8;
1752   if (!row_ct_rem) {
1753     return;
1754   }
1755   const uint32_t lshift = 8 - row_ct_rem;
1756   Vec8thUint* write_iter_last = &(write_iter0[write_v8ui_stride * (row_ct_rem - 1)]);
1757   for (uint32_t vidx = 0; vidx != buf1_row_vecwidth; ++vidx) {
1758     VecW loader = buf1_read_iter[vidx];
1759     loader = vecw_slli(loader, lshift);
1760     Vec8thUint* inner_write_iter = &(write_iter_last[vidx]);
1761     for (uint32_t uii = 0; uii != row_ct_rem; ++uii) {
1762       *inner_write_iter = vecw_movemask(loader);
1763       loader = vecw_slli(loader, 1);
1764       inner_write_iter -= write_v8ui_stride;
1765     }
1766   }
1767 }
1768 #else  // !__LP64__
1769 static_assert(kWordsPerVec == 1, "TransposeBitblock32() needs to be updated.");
TransposeBitblock32(const uintptr_t * read_iter,uintptr_t read_ul_stride,uintptr_t write_ul_stride,uint32_t read_batch_size,uint32_t write_batch_size,uintptr_t * write_iter,VecW * __restrict buf0,VecW * __restrict buf1)1770 void TransposeBitblock32(const uintptr_t* read_iter, uintptr_t read_ul_stride, uintptr_t write_ul_stride, uint32_t read_batch_size, uint32_t write_batch_size, uintptr_t* write_iter, VecW* __restrict buf0, VecW* __restrict buf1) {
1771   // buf must be vector-aligned and have size 64k
1772   const uint32_t initial_read_byte_ct = DivUp(write_batch_size, CHAR_BIT);
1773   // fold the first 6 shuffles into the initial ingestion loop
1774   const unsigned char* initial_read_iter = R_CAST(const unsigned char*, read_iter);
1775   const unsigned char* initial_read_end = &(initial_read_iter[initial_read_byte_ct]);
1776   unsigned char* initial_target_iter = R_CAST(unsigned char*, buf0);
1777   const uint32_t read_byte_stride = read_ul_stride * kBytesPerWord;
1778   const uint32_t read_batch_rem = kBitsPerCacheline - read_batch_size;
1779   for (; initial_read_iter != initial_read_end; ++initial_read_iter) {
1780     const unsigned char* read_iter_tmp = initial_read_iter;
1781     for (uint32_t ujj = 0; ujj != read_batch_size; ++ujj) {
1782       *initial_target_iter++ = *read_iter_tmp;
1783       read_iter_tmp = &(read_iter_tmp[read_byte_stride]);
1784     }
1785     initial_target_iter = memsetua(initial_target_iter, 0, read_batch_rem);
1786   }
1787 
1788   // third-to-last shuffle, 8 bit spacing -> 4
1789   const VecW* source_iter = buf0;
1790   uintptr_t* target_iter0 = buf1;
1791   const uint32_t write_word_ct = BitCtToWordCt(read_batch_size);
1792   const uint32_t first_inner_loop_iter_ct = 4 * write_word_ct;
1793   uint32_t cur_write_skip = 4 * kWordsPerCacheline - first_inner_loop_iter_ct;
1794   // coincidentally, this also needs to run DivUp(write_batch_size, CHAR_BIT)
1795   // times
1796   for (uint32_t uii = 0; uii != initial_read_byte_ct; ++uii) {
1797     uintptr_t* target_iter1 = &(target_iter0[kWordsPerCacheline * 4]);
1798     for (uint32_t ujj = 0; ujj != first_inner_loop_iter_ct; ++ujj) {
1799       const uintptr_t source_word_lo = *source_iter++;
1800       const uintptr_t source_word_hi = *source_iter++;
1801       uintptr_t target_word0_lo = source_word_lo & kMask0F0F;
1802       uintptr_t target_word1_lo = (source_word_lo >> 4) & kMask0F0F;
1803       uintptr_t target_word0_hi = source_word_hi & kMask0F0F;
1804       uintptr_t target_word1_hi = (source_word_hi >> 4) & kMask0F0F;
1805       target_word0_lo = (target_word0_lo | (target_word0_lo >> 4)) & kMask00FF;
1806       target_word1_lo = (target_word1_lo | (target_word1_lo >> 4)) & kMask00FF;
1807       target_word0_hi = (target_word0_hi | (target_word0_hi >> 4)) & kMask00FF;
1808       target_word1_hi = (target_word1_hi | (target_word1_hi >> 4)) & kMask00FF;
1809       target_word0_lo = target_word0_lo | (target_word0_lo >> kBitsPerWordD4);
1810       target_word1_lo = target_word1_lo | (target_word1_lo >> kBitsPerWordD4);
1811       target_word0_hi = target_word0_hi | (target_word0_hi >> kBitsPerWordD4);
1812       target_word1_hi = target_word1_hi | (target_word1_hi >> kBitsPerWordD4);
1813       *target_iter0++ = S_CAST(Halfword, target_word0_lo) | (target_word0_hi << kBitsPerWordD2);
1814       *target_iter1++ = S_CAST(Halfword, target_word1_lo) | (target_word1_hi << kBitsPerWordD2);
1815     }
1816     source_iter = &(source_iter[2 * cur_write_skip]);
1817     target_iter0 = &(target_iter1[cur_write_skip]);
1818   }
1819 
1820   // second-to-last shuffle, 4 bit spacing -> 2
1821   source_iter = buf1;
1822   target_iter0 = buf0;
1823   const uint32_t second_outer_loop_iter_ct = DivUp(write_batch_size, 4);
1824   const uint32_t second_inner_loop_iter_ct = 2 * write_word_ct;
1825   cur_write_skip = 2 * kWordsPerCacheline - second_inner_loop_iter_ct;
1826   for (uint32_t uii = 0; uii != second_outer_loop_iter_ct; ++uii) {
1827     uintptr_t* target_iter1 = &(target_iter0[kWordsPerCacheline * 2]);
1828     for (uint32_t ujj = 0; ujj != second_inner_loop_iter_ct; ++ujj) {
1829       const uintptr_t source_word_lo = *source_iter++;
1830       const uintptr_t source_word_hi = *source_iter++;
1831       uintptr_t target_word0_lo = source_word_lo & kMask3333;
1832       uintptr_t target_word1_lo = (source_word_lo >> 2) & kMask3333;
1833       uintptr_t target_word0_hi = source_word_hi & kMask3333;
1834       uintptr_t target_word1_hi = (source_word_hi >> 2) & kMask3333;
1835       target_word0_lo = (target_word0_lo | (target_word0_lo >> 2)) & kMask0F0F;
1836       target_word1_lo = (target_word1_lo | (target_word1_lo >> 2)) & kMask0F0F;
1837       target_word0_hi = (target_word0_hi | (target_word0_hi >> 2)) & kMask0F0F;
1838       target_word1_hi = (target_word1_hi | (target_word1_hi >> 2)) & kMask0F0F;
1839       target_word0_lo = (target_word0_lo | (target_word0_lo >> 4)) & kMask00FF;
1840       target_word1_lo = (target_word1_lo | (target_word1_lo >> 4)) & kMask00FF;
1841       target_word0_hi = (target_word0_hi | (target_word0_hi >> 4)) & kMask00FF;
1842       target_word1_hi = (target_word1_hi | (target_word1_hi >> 4)) & kMask00FF;
1843       target_word0_lo = target_word0_lo | (target_word0_lo >> kBitsPerWordD4);
1844       target_word1_lo = target_word1_lo | (target_word1_lo >> kBitsPerWordD4);
1845       target_word0_hi = target_word0_hi | (target_word0_hi >> kBitsPerWordD4);
1846       target_word1_hi = target_word1_hi | (target_word1_hi >> kBitsPerWordD4);
1847       *target_iter0++ = S_CAST(Halfword, target_word0_lo) | (target_word0_hi << kBitsPerWordD2);
1848       *target_iter1++ = S_CAST(Halfword, target_word1_lo) | (target_word1_hi << kBitsPerWordD2);
1849     }
1850     source_iter = &(source_iter[2 * cur_write_skip]);
1851     target_iter0 = &(target_iter1[cur_write_skip]);
1852   }
1853   // last shuffle, 2 bit spacing -> 1
1854   source_iter = buf0;
1855   target_iter0 = write_iter;
1856   const uint32_t last_loop_iter_ct = DivUp(write_batch_size, 2);
1857   for (uint32_t uii = 0; uii != last_loop_iter_ct; ++uii) {
1858     uintptr_t* target_iter1 = &(target_iter0[write_ul_stride]);
1859     for (uint32_t ujj = 0; ujj != write_word_ct; ++ujj) {
1860       const uintptr_t source_word_lo = S_CAST(uintptr_t, *source_iter++);
1861       const uintptr_t source_word_hi = S_CAST(uintptr_t, *source_iter++);
1862       uintptr_t target_word0_lo = source_word_lo & kMask5555;
1863       uintptr_t target_word1_lo = (source_word_lo >> 1) & kMask5555;
1864       uintptr_t target_word0_hi = source_word_hi & kMask5555;
1865       uintptr_t target_word1_hi = (source_word_hi >> 1) & kMask5555;
1866       target_word0_lo = (target_word0_lo | (target_word0_lo >> 1)) & kMask3333;
1867       target_word1_lo = (target_word1_lo | (target_word1_lo >> 1)) & kMask3333;
1868       target_word0_hi = (target_word0_hi | (target_word0_hi >> 1)) & kMask3333;
1869       target_word1_hi = (target_word1_hi | (target_word1_hi >> 1)) & kMask3333;
1870       target_word0_lo = (target_word0_lo | (target_word0_lo >> 2)) & kMask0F0F;
1871       target_word1_lo = (target_word1_lo | (target_word1_lo >> 2)) & kMask0F0F;
1872       target_word0_hi = (target_word0_hi | (target_word0_hi >> 2)) & kMask0F0F;
1873       target_word1_hi = (target_word1_hi | (target_word1_hi >> 2)) & kMask0F0F;
1874       target_word0_lo = (target_word0_lo | (target_word0_lo >> 4)) & kMask00FF;
1875       target_word1_lo = (target_word1_lo | (target_word1_lo >> 4)) & kMask00FF;
1876       target_word0_hi = (target_word0_hi | (target_word0_hi >> 4)) & kMask00FF;
1877       target_word1_hi = (target_word1_hi | (target_word1_hi >> 4)) & kMask00FF;
1878       target_word0_lo = target_word0_lo | (target_word0_lo >> kBitsPerWordD4);
1879       target_word1_lo = target_word1_lo | (target_word1_lo >> kBitsPerWordD4);
1880       target_word0_hi = target_word0_hi | (target_word0_hi >> kBitsPerWordD4);
1881       target_word1_hi = target_word1_hi | (target_word1_hi >> kBitsPerWordD4);
1882       target_iter0[ujj] = S_CAST(Halfword, target_word0_lo) | (target_word0_hi << kBitsPerWordD2);
1883       target_iter1[ujj] = S_CAST(Halfword, target_word1_lo) | (target_word1_hi << kBitsPerWordD2);
1884     }
1885     source_iter = &(source_iter[2 * (kWordsPerCacheline - write_word_ct)]);
1886     target_iter0 = &(target_iter1[write_ul_stride]);
1887   }
1888 }
1889 #endif  // !__LP64__
1890 
1891 #ifdef __LP64__
TransposeNybbleblock(const uintptr_t * read_iter,uint32_t read_ul_stride,uint32_t write_ul_stride,uint32_t read_batch_size,uint32_t write_batch_size,uintptr_t * __restrict write_iter,VecW * vecaligned_buf)1892 void TransposeNybbleblock(const uintptr_t* read_iter, uint32_t read_ul_stride, uint32_t write_ul_stride, uint32_t read_batch_size, uint32_t write_batch_size, uintptr_t* __restrict write_iter, VecW* vecaligned_buf) {
1893   // Very similar to TransposeNypblock64() in pgenlib_internal.
1894   // vecaligned_buf must be vector-aligned and have size 8k
1895   const uint32_t buf_row_ct = DivUp(write_batch_size, 8);
1896   // fold the first 4 shuffles into the initial ingestion loop
1897   const uint32_t* initial_read_iter = R_CAST(const uint32_t*, read_iter);
1898   const uint32_t* initial_read_end = &(initial_read_iter[buf_row_ct]);
1899   uint32_t* initial_target_iter = R_CAST(uint32_t*, vecaligned_buf);
1900   const uint32_t read_u32_stride = read_ul_stride * (kBytesPerWord / 4);
1901   const uint32_t read_batch_rem = kNybblesPerCacheline - read_batch_size;
1902   for (; initial_read_iter != initial_read_end; ++initial_read_iter) {
1903     const uint32_t* read_iter_tmp = initial_read_iter;
1904     for (uint32_t ujj = 0; ujj != read_batch_size; ++ujj) {
1905       *initial_target_iter++ = *read_iter_tmp;
1906       read_iter_tmp = &(read_iter_tmp[read_u32_stride]);
1907     }
1908     if (!read_batch_rem) {
1909       continue;
1910     }
1911     memset(initial_target_iter, 0, read_batch_rem * 4);
1912     initial_target_iter = &(initial_target_iter[read_batch_rem]);
1913   }
1914 
1915   // 32 bit spacing -> 4
1916   const VecW* source_iter = vecaligned_buf;
1917   const VecW m4 = VCONST_W(kMask0F0F);
1918   const uint32_t buf_fullrow_ct = write_batch_size / 8;
1919   const uint32_t eightword_ct = DivUp(read_batch_size, 16);
1920   uintptr_t* target_iter0 = write_iter;
1921   uint32_t cur_dst_row_ct = 8;
1922 #  ifdef USE_SSE42
1923   const VecW gather_u16s = vecw_setr8(0, 8, 1, 9, 2, 10, 3, 11,
1924                                       4, 12, 5, 13, 6, 14, 7, 15);
1925 #  else
1926   const VecW m8 = VCONST_W(kMask00FF);
1927 #  endif
1928 #  ifdef USE_AVX2
1929   // movemask is slower even in AVX2 case
1930   const VecW gather_u32s = vecw_setr8(0, 1, 8, 9, 2, 3, 10, 11,
1931                                       4, 5, 12, 13, 6, 7, 14, 15);
1932   for (uint32_t buf_row_idx = 0; ; ++buf_row_idx) {
1933     if (buf_row_idx >= buf_fullrow_ct) {
1934       if (buf_row_idx == buf_row_ct) {
1935         return;
1936       }
1937       cur_dst_row_ct = write_batch_size % 8;
1938     }
1939     uintptr_t* target_iter1 = &(target_iter0[write_ul_stride]);
1940     uintptr_t* target_iter2 = &(target_iter1[write_ul_stride]);
1941     uintptr_t* target_iter3 = &(target_iter2[write_ul_stride]);
1942     uintptr_t* target_iter4 = &(target_iter3[write_ul_stride]);
1943     uintptr_t* target_iter5 = &(target_iter4[write_ul_stride]);
1944     uintptr_t* target_iter6 = &(target_iter5[write_ul_stride]);
1945     uintptr_t* target_iter7 = &(target_iter6[write_ul_stride]);
1946     for (uint32_t dvidx = 0; dvidx != eightword_ct; ++dvidx) {
1947       const VecW loader0 = source_iter[dvidx * 2];
1948       const VecW loader1 = source_iter[dvidx * 2 + 1];
1949       VecW even_nybbles0 = loader0 & m4;
1950       VecW odd_nybbles0 = vecw_and_notfirst(m4, loader0);
1951       VecW even_nybbles1 = loader1 & m4;
1952       VecW odd_nybbles1 = vecw_and_notfirst(m4, loader1);
1953       even_nybbles0 = even_nybbles0 | vecw_srli(even_nybbles0, 28);
1954       odd_nybbles0 = vecw_slli(odd_nybbles0, 28) | odd_nybbles0;
1955       even_nybbles1 = even_nybbles1 | vecw_srli(even_nybbles1, 28);
1956       odd_nybbles1 = vecw_slli(odd_nybbles1, 28) | odd_nybbles1;
1957       // Label the bytes in even_nybbles0 (0, 1, 2, ..., 31), and the bytes in
1958       // even_nybbles1 (32, 33, ..., 63).  We wish to generate the following
1959       // lane-and-vector-crossing permutation:
1960       //   (0, 8, 16, 24, 32, 40, 48, 56, 1, 9, 17, 25, 33, 41, 49, 57)
1961       //   (2, 10, 18, 26, 34, 42, 50, 58, 3, 11, 19, 27, 35, 43, 51, 59)
1962 
1963       // first shuffle:
1964       //   (0, 8, 1, 9, 2, 10, 3, 11, _, _, _, _, _, _, _, _,
1965       //    16, 24, 17, 25, 18, 26, 19, 27, _, _, _, _, _, _, _, _)
1966       //
1967       //   (32, 40, 33, 41, 34, 42, 35, 43, _, _, _, _, _, _, _, _,
1968       //    48, 56, 49, 57, 50, 58, 51, 59, _, _, _, _, _, _, _, _)
1969       //
1970       // _mm256_unpacklo_epi16:
1971       //   (0, 8, 32, 40, 1, 9, 33, 41, 2, 10, 34, 42, 3, 11, 35, 43,
1972       //    16, 24, 48, 56, 17, 25, 49, 57, 18, 26, 50, 58, 19, 27, 51, 59)
1973       //
1974       // {0, 2, 1, 3} permute:
1975       //   (0, 8, 32, 40, 1, 9, 33, 41, 16, 24, 48, 56, 17, 25, 49, 57,
1976       //    2, 10, 34, 42, 3, 11, 35, 43, 18, 26, 50, 58, 19, 27, 51, 59)
1977       //
1978       // final shuffle gives us what we want.
1979       even_nybbles0 = vecw_shuffle8(even_nybbles0, gather_u16s);
1980       odd_nybbles0 = vecw_shuffle8(odd_nybbles0, gather_u16s);
1981       even_nybbles1 = vecw_shuffle8(even_nybbles1, gather_u16s);
1982       odd_nybbles1 = vecw_shuffle8(odd_nybbles1, gather_u16s);
1983 
1984       VecW target_even = vecw_unpacklo16(even_nybbles0, even_nybbles1);
1985       VecW target_odd = vecw_unpackhi16(odd_nybbles0, odd_nybbles1);
1986 
1987       target_even = vecw_permute0xd8_if_avx2(target_even);
1988       target_odd = vecw_permute0xd8_if_avx2(target_odd);
1989 
1990       target_even = vecw_shuffle8(target_even, gather_u32s);
1991       target_odd = vecw_shuffle8(target_odd, gather_u32s);
1992 
1993       // tried using _mm_stream_si64 here, that totally sucked
1994       switch (cur_dst_row_ct) {
1995         case 8:
1996           target_iter7[dvidx] = vecw_extract64_3(target_odd);
1997           // fall through
1998         case 7:
1999           target_iter6[dvidx] = vecw_extract64_3(target_even);
2000           // fall through
2001         case 6:
2002           target_iter5[dvidx] = vecw_extract64_2(target_odd);
2003           // fall through
2004         case 5:
2005           target_iter4[dvidx] = vecw_extract64_2(target_even);
2006           // fall through
2007         case 4:
2008           target_iter3[dvidx] = vecw_extract64_1(target_odd);
2009           // fall through
2010         case 3:
2011           target_iter2[dvidx] = vecw_extract64_1(target_even);
2012           // fall through
2013         case 2:
2014           target_iter1[dvidx] = vecw_extract64_0(target_odd);
2015           // fall through
2016         default:
2017           target_iter0[dvidx] = vecw_extract64_0(target_even);
2018       }
2019     }
2020     source_iter = &(source_iter[(4 * kPglNybbleTransposeBatch) / kBytesPerVec]);
2021     target_iter0 = &(target_iter7[write_ul_stride]);
2022   }
2023 #  else  // !USE_AVX2
2024   for (uint32_t buf_row_idx = 0; ; ++buf_row_idx) {
2025     if (buf_row_idx >= buf_fullrow_ct) {
2026       if (buf_row_idx == buf_row_ct) {
2027         return;
2028       }
2029       cur_dst_row_ct = write_batch_size % 8;
2030     }
2031     uintptr_t* target_iter1 = &(target_iter0[write_ul_stride]);
2032     uintptr_t* target_iter2 = &(target_iter1[write_ul_stride]);
2033     uintptr_t* target_iter3 = &(target_iter2[write_ul_stride]);
2034     uintptr_t* target_iter4 = &(target_iter3[write_ul_stride]);
2035     uintptr_t* target_iter5 = &(target_iter4[write_ul_stride]);
2036     uintptr_t* target_iter6 = &(target_iter5[write_ul_stride]);
2037     uintptr_t* target_iter7 = &(target_iter6[write_ul_stride]);
2038     for (uint32_t qvidx = 0; qvidx != eightword_ct; ++qvidx) {
2039       const VecW loader0 = source_iter[qvidx * 4];
2040       const VecW loader1 = source_iter[qvidx * 4 + 1];
2041       const VecW loader2 = source_iter[qvidx * 4 + 2];
2042       const VecW loader3 = source_iter[qvidx * 4 + 3];
2043       VecW even_nybbles0 = loader0 & m4;
2044       VecW odd_nybbles0 = vecw_and_notfirst(m4, loader0);
2045       VecW even_nybbles1 = loader1 & m4;
2046       VecW odd_nybbles1 = vecw_and_notfirst(m4, loader1);
2047       VecW even_nybbles2 = loader2 & m4;
2048       VecW odd_nybbles2 = vecw_and_notfirst(m4, loader2);
2049       VecW even_nybbles3 = loader3 & m4;
2050       VecW odd_nybbles3 = vecw_and_notfirst(m4, loader3);
2051       even_nybbles0 = even_nybbles0 | vecw_srli(even_nybbles0, 28);
2052       odd_nybbles0 = vecw_slli(odd_nybbles0, 28) | odd_nybbles0;
2053       even_nybbles1 = even_nybbles1 | vecw_srli(even_nybbles1, 28);
2054       odd_nybbles1 = vecw_slli(odd_nybbles1, 28) | odd_nybbles1;
2055       even_nybbles2 = even_nybbles2 | vecw_srli(even_nybbles2, 28);
2056       odd_nybbles2 = vecw_slli(odd_nybbles2, 28) | odd_nybbles2;
2057       even_nybbles3 = even_nybbles3 | vecw_srli(even_nybbles3, 28);
2058       odd_nybbles3 = vecw_slli(odd_nybbles3, 28) | odd_nybbles3;
2059       // Label the bytes in even_nybbles0 (0, 1, 2, ..., 15), the bytes in
2060       // even_nybbles1 (16, 17, ..., 31), ..., up to even_nybbles3 being (48,
2061       // 49, ..., 63).  We wish to generate the following vector-crossing
2062       // permutation:
2063       //   (0, 8, 16, 24, 32, 40, 48, 56, 1, 9, 17, 25, 33, 41, 49, 57)
2064       //   (2, 10, 18, 26, 34, 42, 50, 58, 3, 11, 19, 27, 35, 43, 51, 59)
2065 
2066       // first shuffle:
2067       //   (0, 8, 1, 9, 2, 10, 3, 11, _, _, _, _, _, _, _, _)
2068       //   (16, 24, 17, 25, 18, 26, 19, 27, _, _, _, _, _, _, _, _)
2069       //   (32, 40, 33, 41, 34, 42, 35, 43, _, _, _, _, _, _, _, _)
2070       //   (48, 56, 49, 57, 50, 58, 51, 59, _, _, _, _, _, _, _, _)
2071 
2072       // _mm_unpacklo_epi16:
2073       //   (0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27)
2074       //   (32, 40, 48, 56, 33, 41, 49, 57, 34, 42, 50, 58, 35, 43, 51, 59)
2075       //
2076       // finish with _mm_unpack{lo,hi}_epi32
2077 #    ifdef USE_SSE42
2078       even_nybbles0 = vecw_shuffle8(even_nybbles0, gather_u16s);
2079       odd_nybbles0 = vecw_shuffle8(odd_nybbles0, gather_u16s);
2080       even_nybbles1 = vecw_shuffle8(even_nybbles1, gather_u16s);
2081       odd_nybbles1 = vecw_shuffle8(odd_nybbles1, gather_u16s);
2082       even_nybbles2 = vecw_shuffle8(even_nybbles2, gather_u16s);
2083       odd_nybbles2 = vecw_shuffle8(odd_nybbles2, gather_u16s);
2084       even_nybbles3 = vecw_shuffle8(even_nybbles3, gather_u16s);
2085       odd_nybbles3 = vecw_shuffle8(odd_nybbles3, gather_u16s);
2086 #    else
2087       VecW tmp_lo = vecw_unpacklo8(even_nybbles0, odd_nybbles0);
2088       VecW tmp_hi = vecw_unpackhi8(even_nybbles0, odd_nybbles0);
2089       even_nybbles0 = vecw_blendv(vecw_slli(tmp_hi, 8), tmp_lo, m8);
2090       odd_nybbles0 = vecw_blendv(tmp_hi, vecw_srli(tmp_lo, 8), m8);
2091       tmp_lo = vecw_unpacklo8(even_nybbles1, odd_nybbles1);
2092       tmp_hi = vecw_unpackhi8(even_nybbles1, odd_nybbles1);
2093       even_nybbles1 = vecw_blendv(vecw_slli(tmp_hi, 8), tmp_lo, m8);
2094       odd_nybbles1 = vecw_blendv(tmp_hi, vecw_srli(tmp_lo, 8), m8);
2095       tmp_lo = vecw_unpacklo8(even_nybbles2, odd_nybbles2);
2096       tmp_hi = vecw_unpackhi8(even_nybbles2, odd_nybbles2);
2097       even_nybbles2 = vecw_blendv(vecw_slli(tmp_hi, 8), tmp_lo, m8);
2098       odd_nybbles2 = vecw_blendv(tmp_hi, vecw_srli(tmp_lo, 8), m8);
2099       tmp_lo = vecw_unpacklo8(even_nybbles3, odd_nybbles3);
2100       tmp_hi = vecw_unpackhi8(even_nybbles3, odd_nybbles3);
2101       even_nybbles3 = vecw_blendv(vecw_slli(tmp_hi, 8), tmp_lo, m8);
2102       odd_nybbles3 = vecw_blendv(tmp_hi, vecw_srli(tmp_lo, 8), m8);
2103 #    endif
2104 
2105       const VecW even_lo = vecw_unpacklo16(even_nybbles0, even_nybbles1);
2106       const VecW odd_lo = vecw_unpackhi16(odd_nybbles0, odd_nybbles1);
2107       const VecW even_hi = vecw_unpacklo16(even_nybbles2, even_nybbles3);
2108       const VecW odd_hi = vecw_unpackhi16(odd_nybbles2, odd_nybbles3);
2109 
2110       const VecW final02 = vecw_unpacklo32(even_lo, even_hi);
2111       const VecW final13 = vecw_unpacklo32(odd_lo, odd_hi);
2112       const VecW final46 = vecw_unpackhi32(even_lo, even_hi);
2113       const VecW final57 = vecw_unpackhi32(odd_lo, odd_hi);
2114       switch (cur_dst_row_ct) {
2115         case 8:
2116           target_iter7[qvidx] = vecw_extract64_1(final57);
2117           // fall through
2118         case 7:
2119           target_iter6[qvidx] = vecw_extract64_1(final46);
2120           // fall through
2121         case 6:
2122           target_iter5[qvidx] = vecw_extract64_0(final57);
2123           // fall through
2124         case 5:
2125           target_iter4[qvidx] = vecw_extract64_0(final46);
2126           // fall through
2127         case 4:
2128           target_iter3[qvidx] = vecw_extract64_1(final13);
2129           // fall through
2130         case 3:
2131           target_iter2[qvidx] = vecw_extract64_1(final02);
2132           // fall through
2133         case 2:
2134           target_iter1[qvidx] = vecw_extract64_0(final13);
2135           // fall through
2136         default:
2137           target_iter0[qvidx] = vecw_extract64_0(final02);
2138       }
2139     }
2140     source_iter = &(source_iter[(4 * kPglNybbleTransposeBatch) / kBytesPerVec]);
2141     target_iter0 = &(target_iter7[write_ul_stride]);
2142   }
2143 #  endif  // !USE_AVX2
2144 }
2145 #else  // !__LP64__
2146 static_assert(kWordsPerVec == 1, "TransposeNybbleblock() needs to be updated.");
TransposeNybbleblock(const uintptr_t * read_iter,uint32_t read_ul_stride,uint32_t write_ul_stride,uint32_t read_batch_size,uint32_t write_batch_size,uintptr_t * __restrict write_iter,VecW * vecaligned_buf)2147 void TransposeNybbleblock(const uintptr_t* read_iter, uint32_t read_ul_stride, uint32_t write_ul_stride, uint32_t read_batch_size, uint32_t write_batch_size, uintptr_t* __restrict write_iter, VecW* vecaligned_buf) {
2148   // Very similar to TransposeNypblock32() in pgenlib_internal.
2149   // vecaligned_buf must be vector-aligned and have size 8k
2150   const uint32_t buf_row_ct = NybbleCtToByteCt(write_batch_size);
2151   // fold the first 6 shuffles into the initial ingestion loop
2152   const unsigned char* initial_read_iter = R_CAST(const unsigned char*, read_iter);
2153   const unsigned char* initial_read_end = &(initial_read_iter[buf_row_ct]);
2154   unsigned char* initial_target_iter = R_CAST(unsigned char*, vecaligned_buf);
2155   const uint32_t read_byte_stride = read_ul_stride * kBytesPerWord;
2156   const uint32_t read_batch_rem = kNybblesPerCacheline - read_batch_size;
2157   for (; initial_read_iter != initial_read_end; ++initial_read_iter) {
2158     const unsigned char* read_iter_tmp = initial_read_iter;
2159     for (uint32_t ujj = 0; ujj != read_batch_size; ++ujj) {
2160       *initial_target_iter++ = *read_iter_tmp;
2161       read_iter_tmp = &(read_iter_tmp[read_byte_stride]);
2162     }
2163     initial_target_iter = memsetua(initial_target_iter, 0, read_batch_rem);
2164   }
2165 
2166   // 8 bit spacing -> 4
2167   const VecW* source_iter = vecaligned_buf;
2168   uintptr_t* target_iter0 = write_iter;
2169   const uint32_t buf_fullrow_ct = write_batch_size / 2;
2170   const uint32_t write_word_ct = NybbleCtToWordCt(read_batch_size);
2171   for (uint32_t uii = 0; uii != buf_fullrow_ct; ++uii) {
2172     uintptr_t* target_iter1 = &(target_iter0[write_ul_stride]);
2173     for (uint32_t ujj = 0; ujj != write_word_ct; ++ujj) {
2174       const uintptr_t source_word_lo = *source_iter++;
2175       const uintptr_t source_word_hi = *source_iter++;
2176       uintptr_t target_word0_lo = source_word_lo & kMask0F0F;
2177       uintptr_t target_word1_lo = (source_word_lo >> 4) & kMask0F0F;
2178       uintptr_t target_word0_hi = source_word_hi & kMask0F0F;
2179       uintptr_t target_word1_hi = (source_word_hi >> 4) & kMask0F0F;
2180       target_word0_lo = (target_word0_lo | (target_word0_lo >> 4)) & kMask00FF;
2181       target_word1_lo = (target_word1_lo | (target_word1_lo >> 4)) & kMask00FF;
2182       target_word0_hi = (target_word0_hi | (target_word0_hi >> 4)) & kMask00FF;
2183       target_word1_hi = (target_word1_hi | (target_word1_hi >> 4)) & kMask00FF;
2184       target_word0_lo = target_word0_lo | (target_word0_lo >> kBitsPerWordD4);
2185       target_word1_lo = target_word1_lo | (target_word1_lo >> kBitsPerWordD4);
2186       target_word0_hi = target_word0_hi | (target_word0_hi >> kBitsPerWordD4);
2187       target_word1_hi = target_word1_hi | (target_word1_hi >> kBitsPerWordD4);
2188       target_iter0[ujj] = S_CAST(Halfword, target_word0_lo) | (target_word0_hi << kBitsPerWordD2);
2189       target_iter1[ujj] = S_CAST(Halfword, target_word1_lo) | (target_word1_hi << kBitsPerWordD2);
2190     }
2191     source_iter = &(source_iter[2 * (kWordsPerCacheline - write_word_ct)]);
2192     target_iter0 = &(target_iter1[write_ul_stride]);
2193   }
2194   const uint32_t remainder = write_batch_size % 2;
2195   if (!remainder) {
2196     return;
2197   }
2198   for (uint32_t ujj = 0; ujj != write_word_ct; ++ujj) {
2199     const uintptr_t source_word_lo = *source_iter++;
2200     const uintptr_t source_word_hi = *source_iter++;
2201     uintptr_t target_word0_lo = source_word_lo & kMask0F0F;
2202     uintptr_t target_word0_hi = source_word_hi & kMask0F0F;
2203     target_word0_lo = (target_word0_lo | (target_word0_lo >> 4)) & kMask00FF;
2204     target_word0_hi = (target_word0_hi | (target_word0_hi >> 4)) & kMask00FF;
2205     target_word0_lo = target_word0_lo | (target_word0_lo >> kBitsPerWordD4);
2206     target_word0_hi = target_word0_hi | (target_word0_hi >> kBitsPerWordD4);
2207     target_iter0[ujj] = S_CAST(Halfword, target_word0_lo) | (target_word0_hi << kBitsPerWordD2);
2208   }
2209 }
2210 #endif  // !__LP64__
2211 
2212 #ifdef __LP64__
2213 #  ifdef USE_AVX2
2214 const unsigned char kLeadMask[2 * kBytesPerVec] __attribute__ ((aligned (64))) =
2215   {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2216    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2217    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
2218    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255};
2219 #  else
2220 const unsigned char kLeadMask[2 * kBytesPerVec] __attribute__ ((aligned (32))) =
2221   {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2222    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255};
2223 #  endif
2224 
BytesumArr(const void * bytearr,uintptr_t byte_ct)2225 uintptr_t BytesumArr(const void* bytearr, uintptr_t byte_ct) {
2226   uintptr_t tot = 0;
2227   if (byte_ct < kBytesPerVec) {
2228     const unsigned char* bytearr_uc = S_CAST(const unsigned char*, bytearr);
2229     for (uintptr_t ulii = 0; ulii != byte_ct; ++ulii) {
2230       tot += bytearr_uc[ulii];
2231     }
2232     return tot;
2233   }
2234   const unsigned char* bytearr_uc_iter = S_CAST(const unsigned char*, bytearr);
2235   const unsigned char* bytearr_uc_final = &(bytearr_uc_iter[byte_ct - kBytesPerVec]);
2236   const VecW m0 = vecw_setzero();
2237   VecW acc = vecw_setzero();
2238   while (bytearr_uc_iter < bytearr_uc_final) {
2239     const VecW cur_vec = vecw_loadu(bytearr_uc_iter);
2240     acc = acc + vecw_sad(cur_vec, m0);
2241     bytearr_uc_iter = &(bytearr_uc_iter[kBytesPerVec]);
2242   }
2243   VecW cur_vec = vecw_loadu(bytearr_uc_final);
2244   const uintptr_t overlap_byte_ct = bytearr_uc_iter - bytearr_uc_final;
2245   const VecW mask_vec = vecw_loadu(&(kLeadMask[kBytesPerVec - overlap_byte_ct]));
2246   cur_vec = cur_vec & mask_vec;
2247   acc = acc + vecw_sad(cur_vec, m0);
2248   return HsumW(acc);
2249 }
2250 
2251 #else  // !__LP64__
BytesumArr(const void * bytearr,uintptr_t byte_ct)2252 uintptr_t BytesumArr(const void* bytearr, uintptr_t byte_ct) {
2253   // Assumes sum < 2^32.
2254 #  ifdef __arm__
2255 #    error "Unaligned accesses in BytesumArr()."
2256 #  endif
2257   const uint32_t word_ct = byte_ct / kBytesPerWord;
2258   const uintptr_t* bytearr_alias_iter = S_CAST(const uintptr_t*, bytearr);
2259   const uint32_t wordblock_idx_trail = word_ct / 256;
2260   const uint32_t wordblock_idx_end = DivUp(word_ct, 256);
2261   uint32_t wordblock_len = 256;
2262   uintptr_t tot = 0;
2263   for (uint32_t wordblock_idx = 0; ; ++wordblock_idx) {
2264     if (wordblock_idx >= wordblock_idx_trail) {
2265       if (wordblock_idx == wordblock_idx_end) {
2266         byte_ct = byte_ct % kBytesPerWord;
2267         const unsigned char* bytearr_alias_iter2 = R_CAST(const unsigned char*, bytearr_alias_iter);
2268         for (uint32_t uii = 0; uii != byte_ct; ++uii) {
2269           tot += bytearr_alias_iter2[uii];
2270         }
2271         return tot;
2272       }
2273       wordblock_len = word_ct % 256;
2274     }
2275     const uintptr_t* bytearr_alias_stop = &(bytearr_alias_iter[wordblock_len]);
2276     uintptr_t acc_even = 0;
2277     uintptr_t acc_odd = 0;
2278     do {
2279       uintptr_t cur_word = *bytearr_alias_iter++;
2280       acc_even += cur_word & kMask00FF;
2281       acc_odd += (cur_word >> 8) & kMask00FF;
2282     } while (bytearr_alias_iter < bytearr_alias_stop);
2283     acc_even = S_CAST(Halfword, acc_even) + (acc_even >> kBitsPerWordD2);
2284     acc_odd = S_CAST(Halfword, acc_odd) + (acc_odd >> kBitsPerWordD2);
2285     tot += acc_even + acc_odd;
2286   }
2287 }
2288 #endif  // !__LP64__
2289 
CountByte(const void * bytearr,unsigned char ucc,uintptr_t byte_ct)2290 uintptr_t CountByte(const void* bytearr, unsigned char ucc, uintptr_t byte_ct) {
2291 #ifdef __LP64__
2292   if (byte_ct < kBytesPerVec) {
2293 #endif
2294     const unsigned char* bytearr_uc = S_CAST(const unsigned char*, bytearr);
2295     uintptr_t tot = 0;
2296     for (uintptr_t ulii = 0; ulii != byte_ct; ++ulii) {
2297       tot += (bytearr_uc[ulii] == ucc);
2298     }
2299     return tot;
2300 #ifdef __LP64__
2301   }
2302   const unsigned char* bytearr_uc_iter = S_CAST(const unsigned char*, bytearr);
2303   const VecW m0 = vecw_setzero();
2304   const VecUc match_vvec = vecuc_set1(ucc);
2305   VecW acc = vecw_setzero();
2306   while (byte_ct > 255 * kBytesPerVec) {
2307     VecUc inner_acc = vecuc_setzero();
2308     for (uint32_t uii = 0; uii != 255; ++uii) {
2309       const VecUc cur_vvec = vecuc_loadu(bytearr_uc_iter);
2310       bytearr_uc_iter = &(bytearr_uc_iter[kBytesPerVec]);
2311       inner_acc = inner_acc - (cur_vvec == match_vvec);
2312     }
2313     acc = acc + vecw_sad(R_CAST(VecW, inner_acc), m0);
2314     byte_ct -= 255 * kBytesPerVec;
2315   }
2316   const unsigned char* bytearr_uc_final = &(bytearr_uc_iter[byte_ct - kBytesPerVec]);
2317   VecUc inner_acc = vecuc_setzero();
2318   while (bytearr_uc_iter < bytearr_uc_final) {
2319     const VecUc cur_vvec = vecuc_loadu(bytearr_uc_iter);
2320     bytearr_uc_iter = &(bytearr_uc_iter[kBytesPerVec]);
2321     inner_acc = inner_acc - (cur_vvec == match_vvec);
2322   }
2323   VecUc cur_vvec = vecuc_loadu(bytearr_uc_final);
2324   const uintptr_t overlap_byte_ct = bytearr_uc_iter - bytearr_uc_final;
2325   const VecUc mask_vvec = vecuc_loadu(&(kLeadMask[kBytesPerVec - overlap_byte_ct]));
2326   cur_vvec = (cur_vvec == match_vvec) & mask_vvec;
2327   inner_acc = inner_acc - cur_vvec;
2328   acc = acc + vecw_sad(R_CAST(VecW, inner_acc), m0);
2329   return HsumW(acc);
2330 #endif  // __LP64__
2331 }
2332 
CountU16(const void * u16arr,uint16_t usii,uintptr_t u16_ct)2333 uintptr_t CountU16(const void* u16arr, uint16_t usii, uintptr_t u16_ct) {
2334 #ifdef __LP64__
2335   if (u16_ct < (kBytesPerVec / 2)) {
2336 #endif
2337     const uint16_t* u16arr_alias = S_CAST(const uint16_t*, u16arr);
2338     uintptr_t tot = 0;
2339     for (uintptr_t ulii = 0; ulii != u16_ct; ++ulii) {
2340       tot += (u16arr_alias[ulii] == usii);
2341     }
2342     return tot;
2343 #ifdef __LP64__
2344   }
2345   const uint16_t* u16arr_iter = S_CAST(const uint16_t*, u16arr);
2346   const VecW m0 = vecw_setzero();
2347   const VecU16 match_vvec = vecu16_set1(usii);
2348   VecW acc = vecw_setzero();
2349   // can also use larger loop and a slightly different accumulation algorithm,
2350   // but it should make practically no difference; lets keep these loops as
2351   // similar as possible for now.
2352   while (u16_ct > 255 * (kBytesPerVec / 2)) {
2353     VecU16 inner_acc = vecu16_setzero();
2354     for (uint32_t uii = 0; uii != 255; ++uii) {
2355       const VecU16 cur_vvec = vecu16_loadu(u16arr_iter);
2356       u16arr_iter = &(u16arr_iter[kBytesPerVec / 2]);
2357       inner_acc = inner_acc - (cur_vvec == match_vvec);
2358     }
2359     acc = acc + vecw_sad(R_CAST(VecW, inner_acc), m0);
2360     u16_ct -= 255 * (kBytesPerVec / 2);
2361   }
2362   const uint16_t* u16arr_final = &(u16arr_iter[u16_ct - (kBytesPerVec / 2)]);
2363   VecU16 inner_acc = vecu16_setzero();
2364   while (u16arr_iter < u16arr_final) {
2365     const VecU16 cur_vvec = vecu16_loadu(u16arr_iter);
2366     u16arr_iter = &(u16arr_iter[kBytesPerVec / 2]);
2367     inner_acc = inner_acc - (cur_vvec == match_vvec);
2368   }
2369   VecU16 cur_vvec = vecu16_loadu(u16arr_final);
2370   const uintptr_t overlap_u16_ct = u16arr_iter - u16arr_final;
2371   const VecU16 mask_vvec = vecu16_loadu(&(kLeadMask[kBytesPerVec - 2 * overlap_u16_ct]));
2372   cur_vvec = (cur_vvec == match_vvec) & mask_vvec;
2373   inner_acc = inner_acc - cur_vvec;
2374   acc = acc + vecw_sad(R_CAST(VecW, inner_acc), m0);
2375   return HsumW(acc);
2376 #endif  // __LP64__
2377 }
2378 
Copy1bit8Subset(const uintptr_t * __restrict src_subset,const void * __restrict src_vals,const uintptr_t * __restrict sample_include,uint32_t src_subset_size,uint32_t sample_ct,uintptr_t * __restrict dst_subset,void * __restrict dst_vals)2379 uint32_t Copy1bit8Subset(const uintptr_t* __restrict src_subset, const void* __restrict src_vals, const uintptr_t* __restrict sample_include, uint32_t src_subset_size, uint32_t sample_ct, uintptr_t* __restrict dst_subset, void* __restrict dst_vals) {
2380   if (!src_subset_size) {
2381     return 0;
2382   }
2383   CopyBitarrSubset(src_subset, sample_include, sample_ct, dst_subset);
2384   const unsigned char* src_vals_uc = S_CAST(const unsigned char*, src_vals);
2385   unsigned char* dst_vals_uc = S_CAST(unsigned char*, dst_vals);
2386   unsigned char* dst_vals_iter = dst_vals_uc;
2387   uintptr_t sample_widx = 0;
2388   uintptr_t src_subset_bits = src_subset[0];
2389   for (uint32_t src_idx = 0; src_idx != src_subset_size; ++src_idx) {
2390     const uintptr_t lowbit = BitIter1y(src_subset, &sample_widx, &src_subset_bits);
2391     if (sample_include[sample_widx] & lowbit) {
2392       *dst_vals_iter++ = src_vals_uc[src_idx];
2393     }
2394   }
2395   return dst_vals_iter - dst_vals_uc;
2396 }
2397 
Copy1bit16Subset(const uintptr_t * __restrict src_subset,const void * __restrict src_vals,const uintptr_t * __restrict sample_include,uint32_t src_subset_size,uint32_t sample_ct,uintptr_t * __restrict dst_subset,void * __restrict dst_vals)2398 uint32_t Copy1bit16Subset(const uintptr_t* __restrict src_subset, const void* __restrict src_vals, const uintptr_t* __restrict sample_include, uint32_t src_subset_size, uint32_t sample_ct, uintptr_t* __restrict dst_subset, void* __restrict dst_vals) {
2399   if (!src_subset_size) {
2400     return 0;
2401   }
2402   CopyBitarrSubset(src_subset, sample_include, sample_ct, dst_subset);
2403   const uint16_t* src_vals_u16 = S_CAST(const uint16_t*, src_vals);
2404   uint16_t* dst_vals_u16 = S_CAST(uint16_t*, dst_vals);
2405   uint16_t* dst_vals_iter = dst_vals_u16;
2406   uintptr_t sample_widx = 0;
2407   uintptr_t src_subset_bits = src_subset[0];
2408   for (uint32_t src_idx = 0; src_idx != src_subset_size; ++src_idx) {
2409     const uintptr_t lowbit = BitIter1y(src_subset, &sample_widx, &src_subset_bits);
2410     if (sample_include[sample_widx] & lowbit) {
2411       *dst_vals_iter++ = src_vals_u16[src_idx];
2412     }
2413   }
2414   return dst_vals_iter - dst_vals_u16;
2415 }
2416 
2417 #ifdef __cplusplus
2418 }  // namespace plink2
2419 #endif
2420