1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 // -*- c++ -*-
9
/* Method definitions for HeapArray<C>: an array of nh heaps of k slots each */
11
12 #include <faiss/utils/Heap.h>
13
14 namespace faiss {
15
16 template <typename C>
heapify()17 void HeapArray<C>::heapify() {
18 #pragma omp parallel for
19 for (int64_t j = 0; j < nh; j++)
20 heap_heapify<C>(k, val + j * k, ids + j * k);
21 }
22
23 template <typename C>
reorder()24 void HeapArray<C>::reorder() {
25 #pragma omp parallel for
26 for (int64_t j = 0; j < nh; j++)
27 heap_reorder<C>(k, val + j * k, ids + j * k);
28 }
29
30 template <typename C>
addn(size_t nj,const T * vin,TI j0,size_t i0,int64_t ni)31 void HeapArray<C>::addn(size_t nj, const T* vin, TI j0, size_t i0, int64_t ni) {
32 if (ni == -1)
33 ni = nh;
34 assert(i0 >= 0 && i0 + ni <= nh);
35 #pragma omp parallel for
36 for (int64_t i = i0; i < i0 + ni; i++) {
37 T* __restrict simi = get_val(i);
38 TI* __restrict idxi = get_ids(i);
39 const T* ip_line = vin + (i - i0) * nj;
40
41 for (size_t j = 0; j < nj; j++) {
42 T ip = ip_line[j];
43 if (C::cmp(simi[0], ip)) {
44 heap_replace_top<C>(k, simi, idxi, ip, j + j0);
45 }
46 }
47 }
48 }
49
50 template <typename C>
addn_with_ids(size_t nj,const T * vin,const TI * id_in,int64_t id_stride,size_t i0,int64_t ni)51 void HeapArray<C>::addn_with_ids(
52 size_t nj,
53 const T* vin,
54 const TI* id_in,
55 int64_t id_stride,
56 size_t i0,
57 int64_t ni) {
58 if (id_in == nullptr) {
59 addn(nj, vin, 0, i0, ni);
60 return;
61 }
62 if (ni == -1)
63 ni = nh;
64 assert(i0 >= 0 && i0 + ni <= nh);
65 #pragma omp parallel for
66 for (int64_t i = i0; i < i0 + ni; i++) {
67 T* __restrict simi = get_val(i);
68 TI* __restrict idxi = get_ids(i);
69 const T* ip_line = vin + (i - i0) * nj;
70 const TI* id_line = id_in + (i - i0) * id_stride;
71
72 for (size_t j = 0; j < nj; j++) {
73 T ip = ip_line[j];
74 if (C::cmp(simi[0], ip)) {
75 heap_replace_top<C>(k, simi, idxi, ip, id_line[j]);
76 }
77 }
78 }
79 }
80
81 template <typename C>
per_line_extrema(T * out_val,TI * out_ids) const82 void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
83 #pragma omp parallel for
84 for (int64_t j = 0; j < nh; j++) {
85 int64_t imin = -1;
86 typename C::T xval = C::Crev::neutral();
87 const typename C::T* x_ = val + j * k;
88 for (size_t i = 0; i < k; i++)
89 if (C::cmp(x_[i], xval)) {
90 xval = x_[i];
91 imin = i;
92 }
93 if (out_val)
94 out_val[j] = xval;
95
96 if (out_ids) {
97 if (ids && imin != -1)
98 out_ids[j] = ids[j * k + imin];
99 else
100 out_ids[j] = imin;
101 }
102 }
103 }
104
105 // explicit instanciations
106
107 template struct HeapArray<CMin<float, int64_t>>;
108 template struct HeapArray<CMax<float, int64_t>>;
109 template struct HeapArray<CMin<int, int64_t>>;
110 template struct HeapArray<CMax<int, int64_t>>;
111
112 } // namespace faiss
113