1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 /* Function for soft heap */
11 
12 #include <faiss/utils/Heap.h>
13 
14 namespace faiss {
15 
16 template <typename C>
heapify()17 void HeapArray<C>::heapify() {
18 #pragma omp parallel for
19     for (int64_t j = 0; j < nh; j++)
20         heap_heapify<C>(k, val + j * k, ids + j * k);
21 }
22 
23 template <typename C>
reorder()24 void HeapArray<C>::reorder() {
25 #pragma omp parallel for
26     for (int64_t j = 0; j < nh; j++)
27         heap_reorder<C>(k, val + j * k, ids + j * k);
28 }
29 
30 template <typename C>
addn(size_t nj,const T * vin,TI j0,size_t i0,int64_t ni)31 void HeapArray<C>::addn(size_t nj, const T* vin, TI j0, size_t i0, int64_t ni) {
32     if (ni == -1)
33         ni = nh;
34     assert(i0 >= 0 && i0 + ni <= nh);
35 #pragma omp parallel for
36     for (int64_t i = i0; i < i0 + ni; i++) {
37         T* __restrict simi = get_val(i);
38         TI* __restrict idxi = get_ids(i);
39         const T* ip_line = vin + (i - i0) * nj;
40 
41         for (size_t j = 0; j < nj; j++) {
42             T ip = ip_line[j];
43             if (C::cmp(simi[0], ip)) {
44                 heap_replace_top<C>(k, simi, idxi, ip, j + j0);
45             }
46         }
47     }
48 }
49 
50 template <typename C>
addn_with_ids(size_t nj,const T * vin,const TI * id_in,int64_t id_stride,size_t i0,int64_t ni)51 void HeapArray<C>::addn_with_ids(
52         size_t nj,
53         const T* vin,
54         const TI* id_in,
55         int64_t id_stride,
56         size_t i0,
57         int64_t ni) {
58     if (id_in == nullptr) {
59         addn(nj, vin, 0, i0, ni);
60         return;
61     }
62     if (ni == -1)
63         ni = nh;
64     assert(i0 >= 0 && i0 + ni <= nh);
65 #pragma omp parallel for
66     for (int64_t i = i0; i < i0 + ni; i++) {
67         T* __restrict simi = get_val(i);
68         TI* __restrict idxi = get_ids(i);
69         const T* ip_line = vin + (i - i0) * nj;
70         const TI* id_line = id_in + (i - i0) * id_stride;
71 
72         for (size_t j = 0; j < nj; j++) {
73             T ip = ip_line[j];
74             if (C::cmp(simi[0], ip)) {
75                 heap_replace_top<C>(k, simi, idxi, ip, id_line[j]);
76             }
77         }
78     }
79 }
80 
81 template <typename C>
per_line_extrema(T * out_val,TI * out_ids) const82 void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
83 #pragma omp parallel for
84     for (int64_t j = 0; j < nh; j++) {
85         int64_t imin = -1;
86         typename C::T xval = C::Crev::neutral();
87         const typename C::T* x_ = val + j * k;
88         for (size_t i = 0; i < k; i++)
89             if (C::cmp(x_[i], xval)) {
90                 xval = x_[i];
91                 imin = i;
92             }
93         if (out_val)
94             out_val[j] = xval;
95 
96         if (out_ids) {
97             if (ids && imin != -1)
98                 out_ids[j] = ids[j * k + imin];
99             else
100                 out_ids[j] = imin;
101         }
102     }
103 }
104 
105 // explicit instanciations
106 
107 template struct HeapArray<CMin<float, int64_t>>;
108 template struct HeapArray<CMax<float, int64_t>>;
109 template struct HeapArray<CMin<int, int64_t>>;
110 template struct HeapArray<CMax<int, int64_t>>;
111 
112 } // namespace faiss
113