1 // Copyright 2018 Chia Network Inc
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <stdio.h>
16 
17 #include <set>
18 
19 #include "../lib/include/catch.hpp"
20 #include "../lib/include/picosha2.hpp"
21 #include "calculate_bucket.hpp"
22 #include "disk.hpp"
23 #include "plotter_disk.hpp"
24 #include "prover_disk.hpp"
25 #include "sort_manager.hpp"
26 #include "verifier.hpp"
27 
28 using namespace std;
29 
// 32-byte plot ids consumed by the plotting tests in this file (they are passed
// to CreatePlotDisk together with an id length of 32). plot_id_3 differs from
// plot_id_1 only in its first two bytes.
uint8_t plot_id_1[] = {35,  2,   52,  4,  51, 55,  23,  84, 91, 10, 111, 12,  13,  222, 151, 16,
                       228, 211, 254, 45, 92, 198, 204, 10, 9,  10, 11,  129, 139, 171, 15,  23};

uint8_t plot_id_3[] = {5,   104, 52,  4,  51, 55,  23,  84, 91, 10, 111, 12,  13,  222, 151, 16,
                       228, 211, 254, 45, 92, 198, 204, 10, 9,  10, 11,  129, 139, 171, 15,  23};
35 
intToBytes(uint32_t paramInt,uint32_t numBytes)36 vector<unsigned char> intToBytes(uint32_t paramInt, uint32_t numBytes)
37 {
38     vector<unsigned char> arrayOfByte(numBytes, 0);
39     for (uint32_t i = 0; paramInt > 0; i++) {
40         arrayOfByte[numBytes - i - 1] = paramInt & 0xff;
41         paramInt >>= 8;
42     }
43     return arrayOfByte;
44 }
45 
to_uint128(uint64_t hi,uint64_t lo)46 static uint128_t to_uint128(uint64_t hi, uint64_t lo) { return (uint128_t)hi << 64 | lo; }
47 
48 TEST_CASE("SliceInt64FromBytes 1 bit")
49 {
50     const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};
51 
52     // since we interpret the first 64 bits (8 bytes) as big endian, the
53     // first byte is 0x01
54     CHECK(Util::SliceInt64FromBytes(bytes, 0, 1) == 0);
55     CHECK(Util::SliceInt64FromBytes(bytes, 1, 1) == 0);
56     CHECK(Util::SliceInt64FromBytes(bytes, 2, 1) == 0);
57     CHECK(Util::SliceInt64FromBytes(bytes, 3, 1) == 0);
58     CHECK(Util::SliceInt64FromBytes(bytes, 4, 1) == 0);
59     CHECK(Util::SliceInt64FromBytes(bytes, 5, 1) == 0);
60     CHECK(Util::SliceInt64FromBytes(bytes, 6, 1) == 0);
61     CHECK(Util::SliceInt64FromBytes(bytes, 7, 1) == 1);
62 
63     // the second byte is 0x2
64     CHECK(Util::SliceInt64FromBytes(bytes, 8, 1) == 0);
65     CHECK(Util::SliceInt64FromBytes(bytes, 9, 1) == 0);
66     CHECK(Util::SliceInt64FromBytes(bytes, 10, 1) == 0);
67     CHECK(Util::SliceInt64FromBytes(bytes, 11, 1) == 0);
68     CHECK(Util::SliceInt64FromBytes(bytes, 12, 1) == 0);
69     CHECK(Util::SliceInt64FromBytes(bytes, 13, 1) == 0);
70     CHECK(Util::SliceInt64FromBytes(bytes, 14, 1) == 1);
71     CHECK(Util::SliceInt64FromBytes(bytes, 15, 1) == 0);
72 
73     // the third byte is 0x3
74     CHECK(Util::SliceInt64FromBytes(bytes, 16, 1) == 0);
75     CHECK(Util::SliceInt64FromBytes(bytes, 17, 1) == 0);
76     CHECK(Util::SliceInt64FromBytes(bytes, 18, 1) == 0);
77     CHECK(Util::SliceInt64FromBytes(bytes, 19, 1) == 0);
78     CHECK(Util::SliceInt64FromBytes(bytes, 20, 1) == 0);
79     CHECK(Util::SliceInt64FromBytes(bytes, 21, 1) == 0);
80     CHECK(Util::SliceInt64FromBytes(bytes, 22, 1) == 1);
81     CHECK(Util::SliceInt64FromBytes(bytes, 23, 1) == 1);
82 }
83 
84 TEST_CASE("SliceInt64FromBytes 8 bits")
85 {
86     const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};
87 
88     // since we interpret the first 64 bits (8 bytes) as big endian, the
89     // first byte is 0x01
90     CHECK(Util::SliceInt64FromBytes(bytes, 0, 8) == 0b00000001);
91     CHECK(Util::SliceInt64FromBytes(bytes, 1, 8) == 0b00000010);
92     CHECK(Util::SliceInt64FromBytes(bytes, 2, 8) == 0b00000100);
93     CHECK(Util::SliceInt64FromBytes(bytes, 3, 8) == 0b00001000);
94     CHECK(Util::SliceInt64FromBytes(bytes, 4, 8) == 0b00010000);
95     CHECK(Util::SliceInt64FromBytes(bytes, 5, 8) == 0b00100000);
96     CHECK(Util::SliceInt64FromBytes(bytes, 6, 8) == 0b01000000);
97     CHECK(Util::SliceInt64FromBytes(bytes, 7, 8) == 0b10000001);
98 
99     CHECK(Util::SliceInt64FromBytes(bytes,  8, 8) == 0b00000010);
100     CHECK(Util::SliceInt64FromBytes(bytes,  9, 8) == 0b00000100);
101     CHECK(Util::SliceInt64FromBytes(bytes, 10, 8) == 0b00001000);
102     CHECK(Util::SliceInt64FromBytes(bytes, 11, 8) == 0b00010000);
103     CHECK(Util::SliceInt64FromBytes(bytes, 12, 8) == 0b00100000);
104     CHECK(Util::SliceInt64FromBytes(bytes, 13, 8) == 0b01000000);
105     CHECK(Util::SliceInt64FromBytes(bytes, 14, 8) == 0b10000000);
106     CHECK(Util::SliceInt64FromBytes(bytes, 15, 8) == 0b00000001);
107 
108     CHECK(Util::SliceInt64FromBytes(bytes, 16, 8) == 0b00000011);
109     CHECK(Util::SliceInt64FromBytes(bytes, 17, 8) == 0b00000110);
110     CHECK(Util::SliceInt64FromBytes(bytes, 18, 8) == 0b00001100);
111     CHECK(Util::SliceInt64FromBytes(bytes, 19, 8) == 0b00011000);
112 }
113 
114 TEST_CASE("SliceInt64FromBytes 24 bits")
115 {
116     const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};
117 
118     // since we interpret the first 64 bits (8 bytes) as big endian, the
119     // first byte is 0x01
120     CHECK(Util::SliceInt64FromBytes(bytes, 0, 24) == 0b00000001'00000010'00000011);
121     CHECK(Util::SliceInt64FromBytes(bytes, 1, 24) == 0b0000001'00000010'00000011'0);
122     CHECK(Util::SliceInt64FromBytes(bytes, 2, 24) == 0b000001'00000010'00000011'00);
123     CHECK(Util::SliceInt64FromBytes(bytes, 3, 24) == 0b00001'00000010'00000011'000);
124     CHECK(Util::SliceInt64FromBytes(bytes, 4, 24) == 0b0001'00000010'00000011'0000);
125     CHECK(Util::SliceInt64FromBytes(bytes, 5, 24) == 0b001'00000010'00000011'00000);
126     CHECK(Util::SliceInt64FromBytes(bytes, 6, 24) == 0b01'00000010'00000011'000001);
127     CHECK(Util::SliceInt64FromBytes(bytes, 7, 24) == 0b1'00000010'00000011'0000010);
128 }
129 
130 TEST_CASE("SliceInt64FromBytesFull")
131 {
132     const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};
133 
134     // since we interpret the first 64 bits (8 bytes) as big endian, the
135     // first byte is 0x01
136     CHECK(Util::SliceInt64FromBytesFull(bytes, 0, 64) == 0x0102030405060708ull);
137     CHECK(Util::SliceInt64FromBytesFull(bytes, 1, 64) == 0x0102030405060708ull << 1);
138     CHECK(Util::SliceInt64FromBytesFull(bytes, 2, 64) == 0x0102030405060708ull << 2);
139     CHECK(Util::SliceInt64FromBytesFull(bytes, 3, 64) == 0x0102030405060708ull << 3);
140     CHECK(Util::SliceInt64FromBytesFull(bytes, 4, 64) == 0x1020304050607080ull);
141     CHECK(Util::SliceInt64FromBytesFull(bytes, 5, 64) == ((0x1020304050607080ull << 1) | 0b1));
142     CHECK(Util::SliceInt64FromBytesFull(bytes, 6, 64) == ((0x1020304050607080ull << 2) | 0b10));
143     CHECK(Util::SliceInt64FromBytesFull(bytes, 7, 64) == ((0x1020304050607080ull << 3) | 0b100));
144     CHECK(Util::SliceInt64FromBytesFull(bytes, 8, 64) == 0x0203040506070809ull);
145 }
146 
147 TEST_CASE("Util")
148 {
149     SECTION("Increment and decrement")
150     {
151         uint8_t bytes[3 + 7] = {45, 172, 225};
152         REQUIRE(Util::SliceInt64FromBytes(bytes, 2, 19) == 374172);
153         uint8_t bytes2[1 + 7] = {213};
154         REQUIRE(Util::SliceInt64FromBytes(bytes2, 1, 5) == 21);
155         uint8_t bytes3[17 + 7] = {1, 2, 3, 4, 5, 6, 7, 255, 255, 10, 11, 12, 13, 14, 15, 16, 255};
156         uint128_t int3 = to_uint128(0x01020304050607ff, 0xff0a0b0c0d0e0f10);
157         REQUIRE(Util::SliceInt64FromBytes(bytes3, 64, 64) == (uint64_t)int3);
158         REQUIRE(Util::SliceInt64FromBytes(bytes3, 0, 60) == (uint64_t)(int3 >> 68));
159         REQUIRE(Util::SliceInt128FromBytes(bytes3, 0, 60) == int3 >> 68);
160         REQUIRE(Util::SliceInt128FromBytes(bytes3, 7, 64) == int3 >> 57);
161         REQUIRE(Util::SliceInt128FromBytes(bytes3, 7, 72) == int3 >> 49);
162         REQUIRE(Util::SliceInt128FromBytes(bytes3, 0, 128) == int3);
163         REQUIRE(Util::SliceInt128FromBytes(bytes3, 3, 125) == int3);
164         REQUIRE(Util::SliceInt128FromBytes(bytes3, 2, 125) == int3 >> 1);
165         REQUIRE(Util::SliceInt128FromBytes(bytes3, 0, 120) == int3 >> 8);
166         REQUIRE(Util::SliceInt128FromBytes(bytes3, 3, 127) == (int3 << 2 | 3));
167     }
168 }
169 
170 TEST_CASE("Bits")
171 {
172     SECTION("Slicing and manipulating")
173     {
174         Bits g = Bits(13271, 15);
175         cout << "G: " << g << endl;
176         cout << "G Slice: " << g.Slice(4, 9) << endl;
177         cout << "G Slice: " << g.Slice(0, 9) << endl;
178         cout << "G Slice: " << g.Slice(9, 10) << endl;
179         cout << "G Slice: " << g.Slice(9, 15) << endl;
180         cout << "G Slice: " << g.Slice(9, 9) << endl;
181         REQUIRE(g.Slice(9, 9) == Bits());
182 
183         uint8_t bytes[2];
184         g.ToBytes(bytes);
185         cout << "bytes: " << static_cast<int>(bytes[0]) << " " << static_cast<int>(bytes[1])
186              << endl;
187         cout << "Back to Bits: " << Bits(bytes, 2, 16) << endl;
188 
189         Bits(256, 9).ToBytes(bytes);
190         cout << "bytes: " << static_cast<int>(bytes[0]) << " " << static_cast<int>(bytes[1])
191              << endl;
192         cout << "Back to Bits: " << Bits(bytes, 2, 16) << endl;
193 
194         cout << Bits(640, 11) << endl;
195         Bits(640, 11).ToBytes(bytes);
196         cout << "bytes: " << static_cast<int>(bytes[0]) << " " << static_cast<int>(bytes[1])
197              << endl;
198 
199         Bits h = Bits(bytes, 2, 16);
200         Bits i = Bits(bytes, 2, 17);
201         cout << "H: " << h << endl;
202         cout << "I: " << i << endl;
203 
204         cout << "G: " << g << endl;
205         cout << "size: " << g.GetSize() << endl;
206 
207         Bits shifted = (g << 150);
208 
209         REQUIRE(shifted.GetSize() == 15);
210         REQUIRE(shifted.ToString() == "000000000000000");
211 
212         Bits large = Bits(13271, 200);
213         REQUIRE(large == ((large << 160)) >> 160);
214         REQUIRE((large << 160).GetSize() == 200);
215 
216         Bits l = Bits(123287490 & ((1U << 20) - 1), 20);
217         l = l + Bits(0, 5);
218 
219         Bits m = Bits(5, 3);
220         uint8_t buf[1];
221         m.ToBytes(buf);
222         REQUIRE(buf[0] == (5 << 5));
223     }
224     SECTION("Park Bits")
225     {
226         uint32_t const num_bytes = 16000;
227         uint8_t buf[num_bytes];
228         uint8_t buf_2[num_bytes];
229         Util::GetRandomBytes(buf, num_bytes);
230         ParkBits my_bits = ParkBits(buf, num_bytes, num_bytes * 8);
231         my_bits.ToBytes(buf_2);
232         for (uint32_t i = 0; i < num_bytes; i++) {
233             REQUIRE(buf[i] == buf_2[i]);
234         }
235     }
236 
237     SECTION("Large Bits")
238     {
239         uint32_t const num_bytes = 200000;
240         uint8_t buf[num_bytes];
241         uint8_t buf_2[num_bytes];
242         Util::GetRandomBytes(buf, num_bytes);
243         LargeBits my_bits = LargeBits(buf, num_bytes, num_bytes * 8);
244         my_bits.ToBytes(buf_2);
245         for (uint32_t i = 0; i < num_bytes; i++) {
246             REQUIRE(buf[i] == buf_2[i]);
247         }
248     }
249 }
250 
CheckMatch(int64_t yl,int64_t yr)251 bool CheckMatch(int64_t yl, int64_t yr)
252 {
253     int64_t bl = yl / kBC;
254     int64_t br = yr / kBC;
255     if (bl + 1 != br)
256         return false;  // Buckets don't match
257     for (int64_t m = 0; m < kExtraBitsPow; m++) {
258         if ((((yr % kBC) / kC - ((yl % kBC) / kC)) - m) % kB == 0) {
259             int64_t c_diff = 2 * m + bl % 2;
260             c_diff *= c_diff;
261 
262             if ((((yr % kBC) % kC - ((yl % kBC) % kC)) - c_diff) % kC == 0) {
263                 return true;
264             }
265         }
266     }
267     return false;
268 }
269 
// Steps `items` (k digits, each in [0, n - 1]) to the next tuple of the
// Cartesian product, like k nested `for` loops from 0 to n - 1. items[0] is
// the fastest-changing digit, as on an odometer.
// With `init` set, all k digits are reset to zero and 0 is returned.
// Otherwise it returns 0 after producing the next tuple, or -1 once the
// product is exhausted (by then `items` has wrapped back to all zeros).
static int CartProdNext(uint8_t* items, uint8_t n, uint8_t k, bool init)
{
    if (init) {
        memset(items, 0, k);
        return 0;
    }

    items[0]++;
    // Propagate carries while a digit has overflowed to n.
    for (uint8_t digit = 0; digit < k && items[digit] == n; digit++) {
        items[digit] = 0;
        if (digit == k - 1) {
            return -1;
        }
        items[digit + 1]++;
    }
    return 0;
}
296 
sq(int n)297 static int sq(int n) { return n * n; }
298 
Have4Cycles(uint32_t extraBits,int B,int C)299 static bool Have4Cycles(uint32_t extraBits, int B, int C)
300 {
301     uint8_t m[4];
302     bool init = true;
303 
304     while (!CartProdNext(m, 1 << extraBits, 4, init)) {
305         uint8_t r1 = m[0], r2 = m[1], s1 = m[2], s2 = m[3];
306 
307         init = false;
308         if (r1 != s1 && (r1 << extraBits) + r2 != (s2 << extraBits) + s1 &&
309             (r1 - s1 + r2 - s2) % B == 0) {
310             uint8_t p[2];
311             bool initp = true;
312 
313             while (!CartProdNext(p, 2, 2, initp)) {
314                 uint8_t p1 = p[0], p2 = p[1];
315                 int lhs = sq(2 * r1 + p1) - sq(2 * s1 + p1) + sq(2 * r2 + p2) - sq(2 * s2 + p2);
316 
317                 initp = false;
318                 if (lhs % C == 0) {
319                     fprintf(stderr, "%d %d %d %d %d %d\n", r1, r2, s1, s2, p1, p2);
320                     return true;
321                 }
322             }
323         }
324     }
325 
326     return false;
327 }
328 
329 TEST_CASE("Matching function")
330 {
331     SECTION("Cycles") { REQUIRE(!Have4Cycles(kExtraBits, kB, kC)); }
332 }
333 
VerifyFC(uint8_t t,uint8_t k,uint64_t L,uint64_t R,uint64_t y1,uint64_t y,uint64_t c)334 void VerifyFC(uint8_t t, uint8_t k, uint64_t L, uint64_t R, uint64_t y1, uint64_t y, uint64_t c)
335 {
336     uint8_t sizes[] = {1, 2, 4, 4, 3, 2};
337     uint8_t size = sizes[t - 2];
338     FxCalculator fcalc(k, t);
339 
340     std::pair<Bits, Bits> res = fcalc.CalculateBucket(
341         Bits(y1, k + kExtraBits), Bits(L, k * size), Bits(R, k * size));
342     REQUIRE(res.first.GetValue() == y);
343     if (c) {
344         REQUIRE(res.second.GetValue() == c);
345     }
346 }
347 
348 TEST_CASE("F functions")
349 {
350     SECTION("F1")
351     {
352         uint8_t test_k = 35;
353         uint8_t test_key[] = {0, 2, 3, 4,  5, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
354                               1, 2, 3, 41, 5, 6, 7, 8, 9, 10, 11, 12, 13, 11, 15, 16};
355         F1Calculator f1(test_k, test_key);
356 
357         Bits L = Bits(525, test_k);
358         pair<Bits, Bits> result1 = f1.CalculateBucket(L);
359         Bits L2 = Bits(526, test_k);
360         pair<Bits, Bits> result2 = f1.CalculateBucket(L2);
361         Bits L3 = Bits(625, test_k);
362         pair<Bits, Bits> result3 = f1.CalculateBucket(L3);
363 
364         uint64_t results[256];
365         f1.CalculateBuckets(L.GetValue(), 101, results);
366         REQUIRE(result1.first.GetValue() == results[0]);
367         REQUIRE(result2.first.GetValue() == results[1]);
368         REQUIRE(result3.first.GetValue() == results[100]);
369 
370         uint32_t max_batch = 1 << kBatchSizes;
371         test_k = 32;
372         F1Calculator f1_2(test_k, test_key);
373         L = Bits(192837491, test_k);
374         result1 = f1_2.CalculateBucket(L);
375         L2 = Bits(192837491 + 1, test_k);
376         result2 = f1_2.CalculateBucket(L2);
377         L3 = Bits(192837491 + 2, test_k);
378         result3 = f1_2.CalculateBucket(L3);
379         Bits L4 = Bits(192837491 + max_batch - 1, test_k);
380         pair<Bits, Bits> result4 = f1_2.CalculateBucket(L4);
381 
382         f1_2.CalculateBuckets(L.GetValue(), max_batch, results);
383         REQUIRE(result1.first.GetValue() == results[0]);
384         REQUIRE(result2.first.GetValue() == results[1]);
385         REQUIRE(result3.first.GetValue() == results[2]);
386         REQUIRE(result4.first.GetValue() == results[max_batch - 1]);
387     }
388 
389     SECTION("F2")
390     {
391         uint8_t test_key_2[] = {20,  2,  5,  4,   51, 52,  23,  84,  91, 10, 111,
392                                 12,  13, 24, 151, 16, 228, 211, 254, 45, 92, 198,
393                                 204, 10, 9,  10,  11, 129, 139, 171, 15, 18};
394         map<uint64_t, vector<pair<Bits, Bits>>> buckets;
395 
396         uint8_t const k = 12;
397         uint64_t num_buckets = (1ULL << (k + kExtraBits)) / kBC + 1;
398         uint64_t x = 0;
399 
400         F1Calculator f1(k, test_key_2);
401         for (uint32_t j = 0; j < (1ULL << (k - 4)) + 1; j++) {
402             uint64_t y[1 << 4];
403 
404             f1.CalculateBuckets(x, 1U << 4, y);
405             for (int i = 0; i < 1 << 4; i++) {
406                 uint64_t bucket = y[i] / kBC;
407                 if (buckets.find(bucket) == buckets.end()) {
408                     buckets[bucket] = vector<std::pair<Bits, Bits>>();
409                 }
410                 buckets[bucket].push_back(std::make_pair(Bits(y[i], k + kExtraBits), Bits(x, k)));
411                 if (x + 1 > (1ULL << k) - 1) {
412                     break;
413                 }
414                 ++x;
415             }
416             if (x + 1 > (1ULL << k) - 1) {
417                 break;
418             }
419         }
420 
421         FxCalculator f2(k, 2);
422         int total_matches = 0;
423 
424         for (auto kv : buckets) {
425             if (kv.first == num_buckets - 1) {
426                 continue;
427             }
428             auto bucket_elements_2 = buckets[kv.first + 1];
429             vector<PlotEntry> left_bucket;
430             vector<PlotEntry> right_bucket;
431             for (auto yx1 : kv.second) {
432                 PlotEntry e;
433                 e.y = get<0>(yx1).GetValue();
434                 left_bucket.push_back(e);
435             }
436             for (auto yx2 : buckets[kv.first + 1]) {
437                 PlotEntry e;
438                 e.y = get<0>(yx2).GetValue();
439                 right_bucket.push_back(e);
440             }
441             sort(
442                 left_bucket.begin(),
443                 left_bucket.end(),
__anon0cce16290102(const PlotEntry& a, const PlotEntry& b) 444                 [](const PlotEntry& a, const PlotEntry& b) -> bool { return a.y > b.y; });
445             sort(
446                 right_bucket.begin(),
447                 right_bucket.end(),
__anon0cce16290202(const PlotEntry& a, const PlotEntry& b) 448                 [](const PlotEntry& a, const PlotEntry& b) -> bool { return a.y > b.y; });
449 
450             uint16_t idx_L[10000];
451             uint16_t idx_R[10000];
452 
453             int32_t idx_count = f2.FindMatches(left_bucket, right_bucket, idx_L, idx_R);
454             for(int32_t i=0; i < idx_count; i++) {
455                 REQUIRE(CheckMatch(left_bucket[idx_L[i]].y, right_bucket[idx_R[i]].y));
456             }
457             total_matches += idx_count;
458         }
459         REQUIRE(total_matches > (1 << k) / 2);
460         REQUIRE(total_matches < (1 << k) * 2);
461     }
462 
463     SECTION("Fx")
464     {
465         VerifyFC(2, 16, 0x44cb, 0x204f, 0x20a61a, 0x2af546, 0x44cb204f);
466         VerifyFC(2, 16, 0x3c5f, 0xfda9, 0x3988ec, 0x15293b, 0x3c5ffda9);
467         VerifyFC(3, 16, 0x35bf992d, 0x7ce42c82, 0x31e541, 0xf73b3, 0x35bf992d7ce42c82);
468         VerifyFC(3, 16, 0x7204e52d, 0xf1fd42a2, 0x28a188, 0x3fb0b5, 0x7204e52df1fd42a2);
469         VerifyFC(
470             4, 16, 0x5b6e6e307d4bedc, 0x8a9a021ea648a7dd, 0x30cb4c, 0x11ad5, 0xd4bd0b144fc26138);
471         VerifyFC(
472             4, 16, 0xb9d179e06c0fd4f5, 0xf06d3fef701966a0, 0x1dd5b6, 0xe69a2, 0xd02115f512009d4d);
473         VerifyFC(5, 16, 0xc2cd789a380208a9, 0x19999e3fa46d6753, 0x25f01e, 0x1f22bd, 0xabe423040a33);
474         VerifyFC(5, 16, 0xbe3edc0a1ef2a4f0, 0x4da98f1d3099fdf5, 0x3feb18, 0x31501e, 0x7300a3a03ac5);
475         VerifyFC(6, 16, 0xc965815a47c5, 0xf5e008d6af57, 0x1f121a, 0x1cabbe, 0xc8cc6947);
476         VerifyFC(6, 16, 0xd420677f6cbd, 0x5894aa2ca1af, 0x2efde9, 0xc2121, 0x421bb8ec);
477         VerifyFC(7, 16, 0x5fec898f, 0x82283d15, 0x14f410, 0x24c3c2, 0x0);
478         VerifyFC(7, 16, 0x64ac5db9, 0x7923986, 0x590fd, 0x1c74a2, 0x0);
479     }
480 }
481 
HexToBytes(const string & hex,uint8_t * result)482 void HexToBytes(const string& hex, uint8_t* result)
483 {
484     for (unsigned int i = 0; i < hex.length(); i += 2) {
485         string byteString = hex.substr(i, 2);
486         uint8_t byte = (uint8_t)strtol(byteString.c_str(), NULL, 16);
487         result[i / 2] = byte;
488     }
489 }
490 
TestProofOfSpace(std::string filename,uint32_t iterations,uint8_t k,uint8_t * plot_id,uint32_t num_proofs)491 void TestProofOfSpace(
492     std::string filename,
493     uint32_t iterations,
494     uint8_t k,
495     uint8_t* plot_id,
496     uint32_t num_proofs)
497 {
498     DiskProver prover(filename);
499     uint8_t* proof_data = new uint8_t[8 * k];
500     uint32_t success = 0;
501     // Tries an edge case challenge with many 1s in the front, and ensures there is no segfault
502     vector<unsigned char> hash(picosha2::k_digest_size);
503     HexToBytes("fffffa2b647d4651c500076d7df4c6f352936cf293bd79c591a7b08e43d6adfb", hash.data());
504     prover.GetQualitiesForChallenge(hash.data());
505 
506     for (uint32_t i = 0; i < iterations; i++) {
507         vector<unsigned char> hash_input = intToBytes(i, 4);
508         vector<unsigned char> hash(picosha2::k_digest_size);
509         picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
510         vector<LargeBits> qualities = prover.GetQualitiesForChallenge(hash.data());
511         Verifier verifier = Verifier();
512         for (uint32_t index = 0; index < qualities.size(); index++) {
513             LargeBits proof = prover.GetFullProof(hash.data(), index);
514             proof.ToBytes(proof_data);
515 
516             LargeBits quality = verifier.ValidateProof(plot_id, k, hash.data(), proof_data, k * 8);
517             REQUIRE(quality.GetSize() == 256);
518             REQUIRE(quality == qualities[index]);
519             success += 1;
520 
521             // Tests invalid proof
522             proof_data[0] = (proof_data[0] + 1) % 256;
523             LargeBits quality_2 =
524                 verifier.ValidateProof(plot_id, k, hash.data(), proof_data, k * 8);
525             REQUIRE(quality_2.GetSize() == 0);
526         }
527     }
528     std::cout << "Success: " << success << "/" << iterations << " "
529               << (100 * ((double)success / (double)iterations)) << "%" << std::endl;
530     REQUIRE(success == num_proofs);
531     REQUIRE(success > 0.5 * iterations);
532     REQUIRE(success < 1.5 * iterations);
533     delete[] proof_data;
534 }
535 
PlotAndTestProofOfSpace(std::string filename,uint32_t iterations,uint8_t k,uint8_t * plot_id,uint32_t buffer,uint32_t num_proofs,uint32_t stripe_size,uint8_t num_threads)536 void PlotAndTestProofOfSpace(
537     std::string filename,
538     uint32_t iterations,
539     uint8_t k,
540     uint8_t* plot_id,
541     uint32_t buffer,
542     uint32_t num_proofs,
543     uint32_t stripe_size,
544     uint8_t num_threads)
545 {
546     DiskPlotter plotter = DiskPlotter();
547     uint8_t memo[5] = {1, 2, 3, 4, 5};
548     plotter.CreatePlotDisk(
549         ".", ".", ".", filename, k, memo, 5, plot_id, 32, buffer, 0, stripe_size, num_threads);
550     TestProofOfSpace(filename, iterations, k, plot_id, num_proofs);
551     REQUIRE(remove(filename.c_str()) == 0);
552 }
553 
554 TEST_CASE("Plotting")
555 {
556     SECTION("Disk plot k18")
557     {
558         PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 18, plot_id_1, 11, 95, 4000, 2);
559     }
560     SECTION("Disk plot k19")
561     {
562         PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 19, plot_id_1, 100, 71, 8192, 2);
563     }
564     SECTION("Disk plot k19 single-thread")
565     {
566         PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 19, plot_id_1, 100, 71, 8192, 1);
567     }
568     SECTION("Disk plot k20")
569     {
570         PlotAndTestProofOfSpace("cpp-test-plot.dat", 500, 20, plot_id_3, 100, 469, 16000, 2);
571     }
572     SECTION("Disk plot k21")
573     {
574         PlotAndTestProofOfSpace("cpp-test-plot.dat", 5000, 21, plot_id_3, 100, 4945, 8192, 4);
575     }
576     // SECTION("Disk plot k24") { PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 24, plot_id_3,
577     // 100, 107); }
578 }
579 
580 TEST_CASE("Invalid plot")
581 {
582     SECTION("File gets deleted")
583     {
584         string filename = "invalid-plot.dat";
585         {
586             DiskPlotter plotter = DiskPlotter();
587             uint8_t memo[5] = {1, 2, 3, 4, 5};
588             uint8_t k = 20;
589             plotter.CreatePlotDisk(".", ".", ".", filename, k, memo, 5, plot_id_1, 32, 200, 32, 8192, 2);
590             DiskProver prover(filename);
591             uint8_t* proof_data = new uint8_t[8 * k];
592             uint8_t challenge[32];
593             size_t i;
594             memset(challenge, 155, 32);
595             vector<LargeBits> qualities;
596             for (i = 0; i < 50; i++) {
597                 qualities = prover.GetQualitiesForChallenge(challenge);
598                 if (qualities.size())
599                     break;
600                 challenge[0]++;
601             }
602             Verifier verifier = Verifier();
603             REQUIRE(qualities.size() > 0);
604             for (uint32_t index = 0; index < qualities.size(); index++) {
605                 LargeBits proof = prover.GetFullProof(challenge, index);
606                 proof.ToBytes(proof_data);
607                 LargeBits quality =
608                     verifier.ValidateProof(plot_id_1, k, challenge, proof_data, k * 8);
609                 REQUIRE(quality == qualities[index]);
610             }
611             delete[] proof_data;
612         }
613         REQUIRE(remove(filename.c_str()) == 0);
__anon0cce16290302() 614         REQUIRE_THROWS_WITH([&]() { DiskProver p(filename); }(), "Invalid file " + filename);
615     }
616 }
617 
618 TEST_CASE("Sort on disk")
619 {
620     SECTION("ExtractNum")
621     {
622         for (int i = 0; i < 15 * 8 - 5; i++) {
623             uint8_t buf[15 + 7];
624             Bits((uint128_t)27 << i, 15 * 8).ToBytes(buf);
625 
626             REQUIRE(Util::ExtractNum(buf, 15, 15 * 8 - 4 - i, 3) == 5);
627         }
628         uint8_t buf[16 + 7];
629         Bits((uint128_t)27 << 5, 128).ToBytes(buf);
630         REQUIRE(Util::ExtractNum(buf, 16, 100, 200) == 864);
631     }
632 
633     SECTION("MemCmpBits")
634     {
635         uint8_t left[3];
636         left[0] = 12;
637         left[1] = 10;
638         left[2] = 100;
639 
640         uint8_t right[3];
641         right[0] = 12;
642         right[1] = 10;
643         right[2] = 100;
644 
645         REQUIRE(Util::MemCmpBits(left, right, 3, 0) == 0);
646         REQUIRE(Util::MemCmpBits(left, right, 3, 10) == 0);
647 
648         right[1] = 11;
649         REQUIRE(Util::MemCmpBits(left, right, 3, 0) < 0);
650         REQUIRE(Util::MemCmpBits(left, right, 3, 16) == 0);
651 
652         right[1] = 9;
653         REQUIRE(Util::MemCmpBits(left, right, 3, 0) > 0);
654 
655         right[1] = 10;
656 
657         // Last bit differs
658         right[2] = 101;
659         REQUIRE(Util::MemCmpBits(left, right, 3, 0) < 0);
660     }
661 
662     SECTION("Quicksort")
663     {
664         uint32_t const iters = 100;
665         vector<string> hashes;
666         uint8_t* hashes_bytes = new uint8_t[iters * 16];
667         memset(hashes_bytes, 0, iters * 16);
668 
669         srand(0);
670         for (uint32_t i = 0; i < iters; i++) {
671             // reverting to rand()
672             string to_insert = std::to_string(rand());
673             while (to_insert.length() < 16) {
674                 to_insert += "0";
675             }
676             hashes.push_back(to_insert);
677             memcpy(hashes_bytes + i * 16, to_insert.data(), to_insert.length());
678         }
679         sort(hashes.begin(), hashes.end());
680         QuickSort::Sort(hashes_bytes, 16, iters, 0);
681 
682         for (uint32_t i = 0; i < iters; i++) {
683             std::string str(reinterpret_cast<char*>(hashes_bytes) + i * 16, 16);
684             REQUIRE(str.compare(hashes[i]) == 0);
685         }
686         delete[] hashes_bytes;
687     }
688 
689     SECTION("File disk")
690     {
691         FileDisk d = FileDisk("test_file.bin");
692         uint8_t buf[5] = {1, 2, 3, 5, 7};
693         d.Write(250, buf, 5);
694 
695         uint8_t read_buf[5];
696         d.Read(250, read_buf, 5);
697 
698         REQUIRE(memcmp(buf, read_buf, 5) == 0);
699         remove("test_file.bin");
700     }
701 
702     SECTION("Lazy Sort Manager QS")
703     {
704         uint32_t iters = 250000;
705         uint32_t const size = 32;
706         vector<Bits> input;
707         const uint32_t memory_len = 1000000;
708         SortManager manager(memory_len, 16, 4, size, ".", "test-files", 0, 1);
709         int total_written_1 = 0;
710         for (uint32_t i = 0; i < iters; i++) {
711             vector<unsigned char> hash_input = intToBytes(i, 4);
712             vector<unsigned char> hash(picosha2::k_digest_size);
713             picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
714             total_written_1 += size;
715             Bits to_write = Bits(hash.data(), size, size * 8);
716             input.emplace_back(to_write);
717             manager.AddToCache(to_write);
718         }
719         manager.FlushCache();
720         uint8_t buf[size];
721         sort(input.begin(), input.end());
722         uint8_t* buf3;
723         for (uint32_t i = 0; i < iters; i++) {
724             buf3 = manager.ReadEntry(i * size);
725             input[i].ToBytes(buf);
726             REQUIRE(memcmp(buf, buf3, size) == 0);
727         }
728     }
729 
730     SECTION("Lazy Sort Manager uniform sort")
731     {
732         uint32_t iters = 120000;
733         uint32_t const size = 32;
734         vector<Bits> input;
735         const uint32_t memory_len = 1000000;
736         SortManager manager(memory_len, 16, 4, size, ".", "test-files", 0, 1);
737         int total_written_1 = 0;
738         for (uint32_t i = 0; i < iters; i++) {
739             vector<unsigned char> hash_input = intToBytes(i, 4);
740             vector<unsigned char> hash(picosha2::k_digest_size);
741             picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
742             total_written_1 += size;
743             Bits to_write = Bits(hash.data(), size, size * 8);
744             input.emplace_back(to_write);
745             manager.AddToCache(to_write);
746         }
747         manager.FlushCache();
748         uint8_t buf[size];
749         sort(input.begin(), input.end());
750         uint8_t* buf3;
751         for (uint32_t i = 0; i < iters; i++) {
752             buf3 = manager.ReadEntry(i * size);
753             input[i].ToBytes(buf);
754             REQUIRE(memcmp(buf, buf3, size) == 0);
755         }
756     }
757 
758     SECTION("Sort in Memory")
759     {
760         uint32_t iters = 100000;
761         uint32_t const size = 32;
762         vector<Bits> input;
763         uint32_t begin = 1000;
764         FileDisk disk("test_file.bin");
765 
766         for (uint32_t i = 0; i < iters; i++) {
767             vector<unsigned char> hash_input = intToBytes(i, 4);
768             vector<unsigned char> hash(picosha2::k_digest_size);
769             picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
770             hash[0] = hash[1] = 0;
771             disk.Write(begin + i * size, hash.data(), size);
772             input.emplace_back(Bits(hash.data(), size, size * 8));
773         }
774 
775         const uint32_t memory_len = Util::RoundSize(iters) * size;
776         auto memory = std::make_unique<uint8_t[]>(memory_len);
777         UniformSort::SortToMemory(disk, begin, memory.get(), size, iters, 16);
778 
779         sort(input.begin(), input.end());
780         uint8_t buf[size];
781         for (uint32_t i = 0; i < iters; i++) {
782             input[i].ToBytes(buf);
783             REQUIRE(memcmp(buf, memory.get() + i * size, size) == 0);
784         }
785     }
786 }
787 
788 TEST_CASE("bitfield-simple")
789 {
790     bitfield b(4);
791     CHECK(!b.get(0));
792     CHECK(!b.get(1));
793     CHECK(!b.get(2));
794     CHECK(!b.get(3));
795 
796     b.set(0);
797     CHECK(b.get(0));
798     CHECK(!b.get(1));
799     CHECK(!b.get(2));
800     CHECK(!b.get(3));
801 
802     b.set(1);
803     CHECK(b.get(0));
804     CHECK(b.get(1));
805     CHECK(!b.get(2));
806     CHECK(!b.get(3));
807 
808     b.set(3);
809     CHECK(b.get(0));
810     CHECK(b.get(1));
811     CHECK(!b.get(2));
812     CHECK(b.get(3));
813 }
814 
815 TEST_CASE("bitfield-count")
816 {
817     bitfield b(512);
818 
819     for (int i = 0; i < 512; ++i) {
820         CHECK(b.count(0, 512) == i);
821         CHECK(!b.get(i));
822         b.set(i);
823         CHECK(b.get(i));
824     }
825     CHECK(b.count(0, 512) == 512);
826 }
827 
828 TEST_CASE("bitfield-count-unaligned")
829 {
830     bitfield b(512);
831 
832     for (int i = 0; i < 512; ++i) {
833         b.set(i);
834     }
835 
836     for (int i = 0; i < 512; ++i) {
837         CHECK(b.count(0, i) == i);
838     }
839 }
840 
841 TEST_CASE("bitfield_index-simple")
842 {
843     bitfield b(64);
844     b.set(0);
845     b.set(1);
846     b.set(3);
847     bitfield_index const idx(b);
848     CHECK(idx.lookup(0, 0) == std::pair<uint64_t, uint64_t>{0,0});
849     CHECK(idx.lookup(0, 1) == std::pair<uint64_t, uint64_t>{0,1});
850 
851     CHECK(idx.lookup(0, 3) == std::pair<uint64_t, uint64_t>{0,2});
852 
853     CHECK(idx.lookup(1, 0) == std::pair<uint64_t, uint64_t>{1,0});
854     CHECK(idx.lookup(1, 2) == std::pair<uint64_t, uint64_t>{1,1});
855     CHECK(idx.lookup(3, 0) == std::pair<uint64_t, uint64_t>{2,0});
856 }
857 
858 TEST_CASE("bitfield_index-use index")
859 {
860     bitfield b(1048576);
861     CHECK(b.size() == 1048576);
862     b.set(1048576 - 3);
863     b.set(1048576 - 2);
864     b.set(1048576 - 1);
865     bitfield_index const idx(b);
866     CHECK(idx.lookup(1048576 - 3, 1) == std::pair<uint64_t, uint64_t>{0,1});
867     CHECK(idx.lookup(1048576 - 2, 1) == std::pair<uint64_t, uint64_t>{1,1});
868 }
869 
870 TEST_CASE("bitfield_index edge-cases")
871 {
872     bitfield b(1048576);
873     CHECK(b.size() == 1048576);
874     b.set(0);
875     b.set(bitfield_index::kIndexBucket);
876     b.set(bitfield_index::kIndexBucket * 2);
877     b.set(1048576 - 1);
878     bitfield_index const idx(b);
879     CHECK(idx.lookup(0, 0) == std::pair<uint64_t, uint64_t>{0,0});
880     CHECK(idx.lookup(0, bitfield_index::kIndexBucket) == std::pair<uint64_t, uint64_t>{0,1});
881     CHECK(idx.lookup(0, bitfield_index::kIndexBucket * 2) == std::pair<uint64_t, uint64_t>{0,2});
882     CHECK(idx.lookup(0, 1048576 - 1) == std::pair<uint64_t, uint64_t>{0,3});
883 
884     CHECK(idx.lookup(bitfield_index::kIndexBucket, 0) == std::pair<uint64_t, uint64_t>{1,0});
885     CHECK(idx.lookup(bitfield_index::kIndexBucket, bitfield_index::kIndexBucket) == std::pair<uint64_t, uint64_t>{1,1});
886     CHECK(idx.lookup(bitfield_index::kIndexBucket, 1048576 - 1 - bitfield_index::kIndexBucket)
887         == std::pair<uint64_t, uint64_t>{1,2});
888 
889     CHECK(idx.lookup(bitfield_index::kIndexBucket * 2, 1048576 - 1 - bitfield_index::kIndexBucket * 2)
890         == std::pair<uint64_t, uint64_t>{2,1});
891     CHECK(idx.lookup(1048576 - 1, 0) == std::pair<uint64_t, uint64_t>{3,0});
892 }
893 
test_bitfield_size(int const size)894 void test_bitfield_size(int const size)
895 {
896     bitfield b(size);
897     b.set(0);
898     b.set(size - 1);
899     bitfield_index const idx(b);
900     CHECK(idx.lookup(0, 0) == std::pair<uint64_t, uint64_t>{0,0});
901     CHECK(idx.lookup(0, size - 1) == std::pair<uint64_t, uint64_t>{0,1});
902     CHECK(idx.lookup(size - 1, 0) == std::pair<uint64_t, uint64_t>{1,0});
903 }
904 
905 TEST_CASE("bitfield_index edge-sizes")
906 {
907     test_bitfield_size(bitfield_index::kIndexBucket - 1);
908     test_bitfield_size(bitfield_index::kIndexBucket);
909     test_bitfield_size(bitfield_index::kIndexBucket + 1);
910 }
911 
912 namespace {
913 
914 constexpr int num_test_entries = 2000000;
915 
write_disk_file(FileDisk & df)916 void write_disk_file(FileDisk& df)
917 {
918     std::uint32_t val = 0;
919     for (int i = 0; i < num_test_entries; ++i) {
920         df.Write(i * 4, reinterpret_cast<std::uint8_t const*>(&val), 4);
921         ++val;
922     }
923 }
924 
925 }
926 
927 TEST_CASE("FileDisk")
928 {
929     FileDisk d = FileDisk("test_file.bin");
930     write_disk_file(d);
931 
932     std::uint32_t val = 0;
933     for (uint32_t i = 0; i < num_test_entries; ++i) {
934         d.Read(i * 4, reinterpret_cast<std::uint8_t*>(&val), 4);
935         REQUIRE(i == val);
936     }
937 
938     for (uint32_t i = num_test_entries - 1; i > 0; --i) {
939         d.Read(i * 4, reinterpret_cast<std::uint8_t*>(&val), 4);
940         CHECK(i == val);
941     }
942 
943     remove("test_file.bin");
944 }
945 
946 TEST_CASE("BufferedDisk")
947 {
948     FileDisk d = FileDisk("test_file.bin");
949     write_disk_file(d);
950 
951     BufferedDisk bd(&d, num_test_entries * 4);
952 
953     for (uint32_t i = 0; i < num_test_entries; ++i) {
954         auto const val = *reinterpret_cast<std::uint32_t const*>(bd.Read(i * 4, 4));
955         CHECK(i == val);
956     }
957 
958     // don't go all the way down to 0, every backwards read cursor movement will
959     // print a warning
960     for (uint32_t i = num_test_entries - 1; i > num_test_entries / 2 + 200; --i) {
961         auto const val = *reinterpret_cast<std::uint32_t const*>(bd.Read(i * 4, 4));
962         CHECK(i == val);
963     }
964 
965     remove("test_file.bin");
966 }
967 
968 TEST_CASE("FilteredDisk")
969 {
970     FileDisk d = FileDisk("test_file.bin");
971     write_disk_file(d);
972 
973     SECTION("filter even")
974     {
975         BufferedDisk bd(&d, num_test_entries * 4);
976         // filter every other entry (starting with 0)
977         bitfield filter(num_test_entries);
978         for (int i = 0; i < num_test_entries; ++i) {
979             if ((i & 1) == 1) filter.set(i);
980         }
981         FilteredDisk fd(std::move(bd), std::move(filter), 4);
982 
983         for (uint32_t i = 0; i < num_test_entries / 2 - 1; ++i) {
984             auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
985             CHECK((i * 2) + 1 == val);
986         }
987 
988         // don't go all the way down to 0, every backwards read cursor movement will
989         // print a warning
990         for (uint32_t i = num_test_entries / 2 - 1; i > num_test_entries / 2 + 200; --i) {
991             auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
992             CHECK((i * 2) + 1 == val);
993         }
994     }
995 
996     SECTION("filter odd")
997     {
998         BufferedDisk bd(&d, num_test_entries * 4);
999         // filter every other entry (starting with 0)
1000         bitfield filter(num_test_entries);
1001         for (int i = 0; i < num_test_entries; ++i) {
1002             if ((i & 1) == 0) filter.set(i);
1003         }
1004         FilteredDisk fd(std::move(bd), std::move(filter), 4);
1005 
1006         for (uint32_t i = 0; i < num_test_entries / 2 - 1; ++i) {
1007             auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
1008             CHECK((i * 2) == val);
1009         }
1010 
1011         // don't go all the way down to 0, every backwards read cursor movement will
1012         // print a warning
1013         for (uint32_t i = num_test_entries / 2 - 1; i > num_test_entries / 2 + 200; --i) {
1014             auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
1015             CHECK((i * 2) == val);
1016         }
1017     }
1018 /*
1019     SECTION("empty bitfield")
1020     {
1021         BufferedDisk bd(&d, num_test_entries * 4);
1022         bitfield filter(num_test_entries);
1023         FilteredDisk fd(std::move(bd), std::move(filter), 4);
1024     }
1025 */
1026     remove("test_file.bin");
1027 }
1028