#include "cado.h" // IWYU pragma: keep
#include <stdint.h>     /* AIX wants it first (it's a bug) */
#include <inttypes.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "knapsack.h"
#include "macros.h"     // ASSERT
8 
9 
10 
/* One entry of a partial-sum table as produced by all_sums():
 *   x  the signed sum (reduced modulo 2^64) for one choice of signs;
 *   v  bitmask of that choice: bit j set <=> element j has sign +1.
 */
struct kns_sum {
    int64_t x;
    unsigned long v;
};

/* Comparator signature that qsort() expects. */
typedef int (*sortfunc_t)(const void *, const void *);
17 
/* Three-way comparison of two signed 64-bit integers.
 * Returns -1 if a < b, 0 if a == b, and 1 if a > b -- i.e. the sign of
 * a - b, obtained without forming the difference (which could overflow). */
static inline int int64_cmp(int64_t a, int64_t b)
{
    if (a < b)
        return -1;
    return (a > b) ? 1 : 0;
}
24 
kns_sum_cmp(const struct kns_sum * s,const struct kns_sum * t)25 int kns_sum_cmp(const struct kns_sum *s, const struct kns_sum *t)
26 {
27     int d = int64_cmp(s->x, t->x);
28     if (d)
29 	return d;
30     else
31 	return (s->v > t->v) - (s->v < t->v);
32 }
33 
kns_sum_negcmp(const struct kns_sum * s,const struct kns_sum * t)34 int kns_sum_negcmp(const struct kns_sum *s, const struct kns_sum *t)
35 {
36     int d = int64_cmp(-s->x, -t->x);
37     if (d)
38 	return d;
39     else
40 	return (s->v > t->v) - (s->v < t->v);
41 }
42 
all_sums(const int64_t * t,unsigned int t_stride,unsigned long k,size_t extra_alloc,unsigned long offset)43 struct kns_sum *all_sums(const int64_t * t, unsigned int t_stride,
44 			 unsigned long k, size_t extra_alloc,
45 			 unsigned long offset)
46 {
47     struct kns_sum *r =
48 	malloc((extra_alloc + (1UL << k)) * sizeof(struct kns_sum));
49     int64_t *t2 = malloc(k * sizeof(int64_t));
50     int64_t x = offset;
51     for (unsigned long i = 0; i < k; i++, t += t_stride) {
52 	x -= *t;
53 	t2[i] = *t * 2;
54     }
55     r[0].x = x;
56     unsigned long v = 0;
57     r[0].v = 0;
58     for (unsigned long i = 1; i < (1UL << k); i++) {
59 	unsigned int s = __builtin_ffs(i) - 1;
60 	unsigned long sh = 1UL << s;
61 	unsigned int down = v & sh;
62 	v ^= sh;
63 	x += down ? -t2[s] : t2[s];
64 	r[i].x = x;
65 	r[i].v = v;
66     }
67     free(t2);
68 
69     return r;
70 }
71 
72 #define ALLOC_EXTRA     256
73 
knapsack_solve(knapsack_object_ptr ks)74 int knapsack_solve(knapsack_object_ptr ks)
75 {
76     const int64_t * tab = ks->tab;
77     unsigned int stride = ks->stride;
78     unsigned int offset = ks->offset;
79     unsigned int nelems = ks->nelems;
80     int64_t bound = ks->bound;
81     knapsack_object_callback_t cb = ks->cb;
82     void * cb_arg = ks->cb_arg;
83 
84     /* purpose: find a combination of the 64-bit integers in tab[], with
85      * coefficients -1 or +1, which is within the interval
86      * [-bound..bound].
87      */
88 
89     unsigned int k1 = nelems / 2;
90     unsigned int k2 = nelems - k1;
91     struct kns_sum *s1 = all_sums(tab + offset, stride, k1, 0, bound);
92     struct kns_sum *s2 = all_sums(tab + k1 * stride + offset, stride, k2, 0, 0);
93 
94     unsigned int n1 = (1UL << k1);
95     unsigned int n2 = (1UL << k2);
96 
97     qsort(s1, n1, sizeof(struct kns_sum), (sortfunc_t) & kns_sum_cmp);
98     qsort(s2, n2, sizeof(struct kns_sum), (sortfunc_t) & kns_sum_negcmp);
99 
100     int64_t ebound = 2 * bound;
101     int res = 0;
102 
103     /* elements in u1 are sorted in increasing order, in [-B/2,B/2[ */
104     /* elements in u2 are sorted in DEcreasing order, or more
105      * accurately, in increasing order of their opposite. It is therefore
106      * best to consider u2 as containing the OPPOSITES of the y values,
107      * even though in reality we don't have to take the negation. These
108      * OPPOSITES, for comparison purposes, are thus considered in
109      * [-B/2,B/2[
110      *
111      * We search for x1 - (-x2) in [0,epsilon[ mod B. There are three
112      * possible cases:
113      *
114      *  x1 - (-x2) in [-B, -B + epsilon[.
115      *          this implies x1 in [(-x2)-B, (-x2)-B+epsilon[.
116      *          this implies x1 in [-B/2, -B/2+epsilon-1[, provided -x2
117      *          has its maximal value B/2-1
118      *  x1 - (-x2) in [0, 0 + epsilon[.
119      *          this is the main case.
120      *  x1 - (-x2) in [B, B + epsilon[.
121      *          this implies x1 in [(-x2)+B, (-x2)+B+epsilon[.
122      *          Since -x2 >= B/2, this case clearly impossible.
123      */
124 
125     /* first treat case 1, which is exceptional */
126     for (unsigned int u2 = n2; u2--;) {
127 	/* x2, at the end of the array, is a large negative number. -x2,
128 	 * (which is not the value present in s2), is at its peak, thus a
129 	 * large positive number. It must still be very large for the
130 	 * interval we're going to consider to be non-empty.
131 	 */
132 	/*      do we have (-x2)-B+epsilon >= -B/2 ? */
133 	/* iow, do we have (-x2) >= B/2 - epsilon ? */
134 	/* iow, do we have (-x2) >= -(-B/2 + epsilon) ? */
135 	int64_t against = INT64_MIN + ebound;
136 	if (int64_cmp(-against, -s2[u2].x) > 0)
137 	    break;
138 	for (unsigned int u1 = 0; u1 < n1; u1++) {
139 	    // continue as long as we have:
140 	    // x1 < (-x2)-B+epsilon
141 	    // note that we know that (-x2)-B+epsilon >= -B/2, thus
142 	    // computing (-x2)+epsilon will wrap around to a negative number
143 	    // in the range [-B/2, -B/2+epsilon[
144 	    int64_t cut = -s2[u2].x + ebound;	// wraps around.
145 	    if (int64_cmp(s1[u1].x, cut) >= 0)
146 		break;
147 
148 	    unsigned long v = s2[u2].v << k1 | s1[u1].v;
149 	    int64_t x = s1[u1].x + s2[u2].x;
150 	    res += cb(cb_arg, v, x - bound);
151 	}
152     }
153 
154     unsigned int u1 = 0;
155 
156     for (; u1 < n1; u1++) {
157 	if (s1[u1].x + s2[0].x >= 0) {
158 	    break;
159 	}
160     }
161 
162     for (unsigned int u2 = 0; u2 < n2; u2++) {
163 
164 	// printf("u2=%d\n",u2);
165 	/* strategy: first expand the available interval, later restrict
166 	 * it */
167 
168 	// compared to the previous turn, x2 has decreased.
169 	int64_t last_x = s1[u1].x + s2[u2].x;
170 	for (; u1 < n1; u1++) {
171 	    int64_t x = s1[u1].x + s2[u2].x;
172 	    // printf("x=%" PRId64 "\n", x);
173 	    /* if there is _no_ element in s2 such that x1+x2 >= 0, then
174 	     * we'll notice this since x will drop _below_ last_x. In
175 	     * such a case, we know that we have a gap of size mor than
176 	     * B/2 in s2. This gap will be exposed one again for the next
177 	     * turn.
178 	     */
179 	    if (x >= 0 || x < last_x)
180 		break;
181 	    // printf("u1=%d\n",u1+1);
182 	}
183 
184 	/* Now u1 is the first index such that x1 + x2 >= 0 */
185 	// ASSERT(u1 == n1 || s2[u2].x + s1[u1].x >= 0);
186 	// ASSERT(s2[u2].x + s1[u1+pd].x >= ebound);
187 
188 	for (unsigned int h = 0; u1 + h < n1; h++) {
189 	    int64_t x = s1[u1 + h].x + s2[u2].x;
190 	    if (x >= ebound || x < 0)
191 		break;
192 	    unsigned long v = s2[u2].v << k1 | s1[u1 + h].v;
193 	    // fprintf(stderr, "u1=%u u2=%u h=%u: %" PRId64 "\n", u1,u2,h,x);
194 	    ASSERT(x >= 0);
195 	    ASSERT(x < ebound);
196 	    res += cb(cb_arg, v, x - bound);
197 	}
198     }
199 
200     free(s1);
201     free(s2);
202 
203     return res;
204 }
205 
/* Put *ptr into its default state: every byte of the object is zeroed,
 * then stride is set to 1, so consecutive elements of the table are used
 * by default. */
void knapsack_object_init(knapsack_object_ptr ptr)
{
    size_t const nbytes = sizeof(knapsack_object);
    memset(ptr, '\0', nbytes);
    ptr->stride = 1;
}
211 
/* Wipe *ptr by zeroing every byte of the object. No memory is released:
 * the table and cb_arg it points to remain the caller's responsibility. */
void knapsack_object_clear(knapsack_object_ptr ptr)
{
    size_t const nbytes = sizeof(knapsack_object);
    memset(ptr, '\0', nbytes);
}
216 
217 #ifdef  DEMO
/* Demo callback for knapsack_solve(). Prints one solution as
 *
 *     <v in hex> (<sign string>) <x>
 *
 * Character s of the sign string is '+' if bit s of v is set, else '-'.
 * x is printed with PRId64 because int64_t need not be long.
 * Always returns 1; knapsack_solve() sums the callback's return values, so
 * its result is then the number of solutions printed. */
int print_solution(knapsack_object_ptr ks, unsigned long v, int64_t x)
{
    char * signs = malloc(ks->nelems + 1);
    if (signs == NULL) {
        /* Still report the solution, just without the sign string. */
        printf("%lx %" PRId64 "\n", v, x);
        return 1;
    }
    for (unsigned int s = 0; s < ks->nelems; s++)
        signs[s] = (v & (1UL << s)) ? '+' : '-';
    signs[ks->nelems] = '\0';
    printf("%lx (%s) %" PRId64 "\n", v, signs, x);
    free(signs);
    return 1;
}
228 
/* Built-in demo inputs. Only the first branch below is compiled (#if 1).
 * Change it to #if 0 to use the 48-element set; the 5-element #else branch
 * is reached only if both conditions are set to 0. */
#if 1
/* 24 elements. */
#define NELEMS  24
const int64_t tab[NELEMS] = {
    4162221379340189282L, -6647306186101789600L, -1421709520630629187L,
    7978579249304052465L, -5210946216003197439L, 1071743655218434855L,
    2848467872950476511L, -1370801619961069543L, -480116246366186631L,
    7246359761961066352L, -3820114215392891062L, 1960455329265570848L,
    2169371464082239491L, 6918027011352649575L, -687610025789514251L,
    2178270899400006382L, -2751086820472252564L, 326442006929102621L,
    -7009969887660261263L, -5003156455825490387L, 76450565619814227L,
    7595450547102556048L, 4069562109599364928L, -510019920501521658L,
};

/* NOTE(review): presumably the expected solution for this table
 * (winning[s] == 1 <=> element s has sign +1). It is not referenced
 * anywhere in this file. */
const int winning[NELEMS] =
    { 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
1 };
    // 0x9b12e5: winning[] read as a bitmask (winning[s] is bit s), i.e. the
    // same encoding as the v printed by print_solution().
#elif 1
// 48 elements; passes in 25 seconds on a U9400 1.4GHz.
#define NELEMS 48
const int64_t tab[NELEMS] = {
    -1538040154316158950L, 8258967369930019632L, -4004688999714422921L,
    1159889400949648397L, 5249521295396238696L, 3650114626103672295L,
    -7931626704593073451L, 7674135400418350404L, -2222438450190335011L,
    7860030247650033749L, 4262500207389102065L, 7160554140656125087L,
    3467957067211623500L, -2390204226671589277L, 7688679274618040461L,
    5203982469041323703L, -6554011846537959286L, -1746110400544620451L,
    -5555145176637642721L, -1550913742863465293L, 8257438977466066771L,
    -438015496418932509L, -2243959839000145910L, -3640817532474311269L,
    -1990598364962709281L, 8659906920932625770L, 664803566864969265L,
    2489632956508442703L, 2742566267756791913L, -4145814854589824541L,
    -4289798920807923686L, 4733442344418006680L, 429022769516086920L,
    2715560624548339171L, 6738808048152006654L, 4539026061699952726L,
    -4489030772365149543L, 8876121350766866461L, 1804278128331619321L,
    -2057342976905341813L, -6373799295888301404L, -8752958562853688943L,
    -7543544823992563848L, -6571543819875671449L, 1291491683018881388L,
    3459301977037432953L, 7219978704923241933L, 7852402516694331902L
};
#else
/* Small 5-element set. */
#define NELEMS  5
const int64_t tab[NELEMS] = {
    3697953231077617412, -9135238050079885539, -187993369997840440,
    -7852759334484681542, 2603467885480253854
};
#endif
274 
main()275 int main()
276 {
277     knapsack_object ks;
278     knapsack_object_init(ks);
279     ks->tab = tab;
280     ks->nelems = NELEMS;
281     ks->bound = 6;
282     ks->cb = (knapsack_object_callback_t) print_solution;
283     ks->cb_arg = ks;
284     knapsack_solve(ks);
285     knapsack_object_clear(ks);
286     return 0;
287 }
288 #endif
289