1 // Copyright 2018 Chia Network Inc
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6
7 // http://www.apache.org/licenses/LICENSE-2.0
8
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <stdio.h>
16
17 #include <set>
18
19 #include "../lib/include/catch.hpp"
20 #include "../lib/include/picosha2.hpp"
21 #include "calculate_bucket.hpp"
22 #include "disk.hpp"
23 #include "plotter_disk.hpp"
24 #include "prover_disk.hpp"
25 #include "sort_manager.hpp"
26 #include "verifier.hpp"
27
28 using namespace std;
29
// 32-byte plot ids shared by the plotting tests. They are passed to
// DiskPlotter::CreatePlotDisk (with an id length of 32) and to
// Verifier::ValidateProof. They are non-const because the test helpers take a
// plain `uint8_t*`.
uint8_t plot_id_1[] = {35, 2, 52, 4, 51, 55, 23, 84, 91, 10, 111, 12, 13, 222, 151, 16,
                       228, 211, 254, 45, 92, 198, 204, 10, 9, 10, 11, 129, 139, 171, 15, 23};

// Identical to plot_id_1 except for the first two bytes.
uint8_t plot_id_3[] = {5, 104, 52, 4, 51, 55, 23, 84, 91, 10, 111, 12, 13, 222, 151, 16,
                       228, 211, 254, 45, 92, 198, 204, 10, 9, 10, 11, 129, 139, 171, 15, 23};
35
// Encodes `paramInt` big-endian into a vector of exactly `numBytes` bytes,
// zero-padded on the left.
// If the value needs more than `numBytes` bytes, only its low-order `numBytes`
// bytes are kept. The loop is bounded by the buffer size, so such values (or
// numBytes == 0) can no longer write out of bounds.
vector<unsigned char> intToBytes(uint32_t paramInt, uint32_t numBytes)
{
    vector<unsigned char> arrayOfByte(numBytes, 0);
    for (uint32_t i = 0; i < numBytes && paramInt > 0; i++) {
        // Fill from the last (least significant) byte towards the front.
        arrayOfByte[numBytes - i - 1] = paramInt & 0xff;
        paramInt >>= 8;
    }
    return arrayOfByte;
}
45
to_uint128(uint64_t hi,uint64_t lo)46 static uint128_t to_uint128(uint64_t hi, uint64_t lo) { return (uint128_t)hi << 64 | lo; }
47
// Extracts single bits at offsets 0..23 and checks them against the bit
// patterns of the first three bytes (0x01, 0x02, 0x03).
TEST_CASE("SliceInt64FromBytes 1 bit")
{
    // Only 9 bytes are meaningful. The extra 7 zero bytes are presumably padding
    // because the slicing helper may read whole 64-bit words past the last
    // meaningful byte. TODO confirm against Util::SliceInt64FromBytes.
    const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};

    // since we interpret the first 64 bits (8 bytes) as big endian, the
    // first byte is 0x01
    CHECK(Util::SliceInt64FromBytes(bytes, 0, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 1, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 2, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 3, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 4, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 5, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 6, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 7, 1) == 1);

    // the second byte is 0x2 (0b00000010: only bit 14 overall is set)
    CHECK(Util::SliceInt64FromBytes(bytes, 8, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 9, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 10, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 11, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 12, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 13, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 14, 1) == 1);
    CHECK(Util::SliceInt64FromBytes(bytes, 15, 1) == 0);

    // the third byte is 0x3 (0b00000011: bits 22 and 23 overall are set)
    CHECK(Util::SliceInt64FromBytes(bytes, 16, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 17, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 18, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 19, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 20, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 21, 1) == 0);
    CHECK(Util::SliceInt64FromBytes(bytes, 22, 1) == 1);
    CHECK(Util::SliceInt64FromBytes(bytes, 23, 1) == 1);
}
83
// Extracts 8-bit windows at every start offset 0..19. Each expectation is the
// big-endian bit string 00000001 00000010 00000011 00000100 ... read from that
// offset.
TEST_CASE("SliceInt64FromBytes 8 bits")
{
    // 9 meaningful bytes plus 7 zero bytes, presumably padding for word-sized
    // reads. TODO confirm.
    const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};

    // since we interpret the first 64 bits (8 bytes) as big endian, the
    // first byte is 0x01
    CHECK(Util::SliceInt64FromBytes(bytes, 0, 8) == 0b00000001);
    CHECK(Util::SliceInt64FromBytes(bytes, 1, 8) == 0b00000010);
    CHECK(Util::SliceInt64FromBytes(bytes, 2, 8) == 0b00000100);
    CHECK(Util::SliceInt64FromBytes(bytes, 3, 8) == 0b00001000);
    CHECK(Util::SliceInt64FromBytes(bytes, 4, 8) == 0b00010000);
    CHECK(Util::SliceInt64FromBytes(bytes, 5, 8) == 0b00100000);
    CHECK(Util::SliceInt64FromBytes(bytes, 6, 8) == 0b01000000);
    // window straddles byte 0 (last bit 1) and byte 1 (first 7 bits 0000001)
    CHECK(Util::SliceInt64FromBytes(bytes, 7, 8) == 0b10000001);

    CHECK(Util::SliceInt64FromBytes(bytes, 8, 8) == 0b00000010);
    CHECK(Util::SliceInt64FromBytes(bytes, 9, 8) == 0b00000100);
    CHECK(Util::SliceInt64FromBytes(bytes, 10, 8) == 0b00001000);
    CHECK(Util::SliceInt64FromBytes(bytes, 11, 8) == 0b00010000);
    CHECK(Util::SliceInt64FromBytes(bytes, 12, 8) == 0b00100000);
    CHECK(Util::SliceInt64FromBytes(bytes, 13, 8) == 0b01000000);
    CHECK(Util::SliceInt64FromBytes(bytes, 14, 8) == 0b10000000);
    CHECK(Util::SliceInt64FromBytes(bytes, 15, 8) == 0b00000001);

    CHECK(Util::SliceInt64FromBytes(bytes, 16, 8) == 0b00000011);
    CHECK(Util::SliceInt64FromBytes(bytes, 17, 8) == 0b00000110);
    CHECK(Util::SliceInt64FromBytes(bytes, 18, 8) == 0b00001100);
    CHECK(Util::SliceInt64FromBytes(bytes, 19, 8) == 0b00011000);
}
113
// Extracts 24-bit windows at start offsets 0..7. In each expectation the digit
// separators mark where the window crosses byte boundaries.
TEST_CASE("SliceInt64FromBytes 24 bits")
{
    // 9 meaningful bytes plus 7 zero bytes, presumably padding for word-sized
    // reads. TODO confirm.
    const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};

    // since we interpret the first 64 bits (8 bytes) as big endian, the
    // first byte is 0x01
    CHECK(Util::SliceInt64FromBytes(bytes, 0, 24) == 0b00000001'00000010'00000011);
    CHECK(Util::SliceInt64FromBytes(bytes, 1, 24) == 0b0000001'00000010'00000011'0);
    CHECK(Util::SliceInt64FromBytes(bytes, 2, 24) == 0b000001'00000010'00000011'00);
    CHECK(Util::SliceInt64FromBytes(bytes, 3, 24) == 0b00001'00000010'00000011'000);
    CHECK(Util::SliceInt64FromBytes(bytes, 4, 24) == 0b0001'00000010'00000011'0000);
    CHECK(Util::SliceInt64FromBytes(bytes, 5, 24) == 0b001'00000010'00000011'00000);
    // from offset 6 on, the trailing bits come from byte 3 (0b00000100)
    CHECK(Util::SliceInt64FromBytes(bytes, 6, 24) == 0b01'00000010'00000011'000001);
    CHECK(Util::SliceInt64FromBytes(bytes, 7, 24) == 0b1'00000010'00000011'0000010);
}
129
// Extracts full 64-bit windows at start offsets 0..8. For offsets 5..7 the low
// bits are the leading bits of the ninth byte (0x09 = 0b00001001).
TEST_CASE("SliceInt64FromBytesFull")
{
    // 9 meaningful bytes plus 7 zero bytes, presumably padding for word-sized
    // reads. TODO confirm.
    const uint8_t bytes[9 + 7] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9};

    // since we interpret the first 64 bits (8 bytes) as big endian, the
    // first byte is 0x01
    CHECK(Util::SliceInt64FromBytesFull(bytes, 0, 64) == 0x0102030405060708ull);
    CHECK(Util::SliceInt64FromBytesFull(bytes, 1, 64) == 0x0102030405060708ull << 1);
    CHECK(Util::SliceInt64FromBytesFull(bytes, 2, 64) == 0x0102030405060708ull << 2);
    CHECK(Util::SliceInt64FromBytesFull(bytes, 3, 64) == 0x0102030405060708ull << 3);
    CHECK(Util::SliceInt64FromBytesFull(bytes, 4, 64) == 0x1020304050607080ull);
    CHECK(Util::SliceInt64FromBytesFull(bytes, 5, 64) == ((0x1020304050607080ull << 1) | 0b1));
    CHECK(Util::SliceInt64FromBytesFull(bytes, 6, 64) == ((0x1020304050607080ull << 2) | 0b10));
    CHECK(Util::SliceInt64FromBytesFull(bytes, 7, 64) == ((0x1020304050607080ull << 3) | 0b100));
    CHECK(Util::SliceInt64FromBytesFull(bytes, 8, 64) == 0x0203040506070809ull);
}
146
// Checks the 64- and 128-bit slicing helpers on short buffers and on a 17-byte
// buffer whose first 16 bytes equal `int3`.
TEST_CASE("Util")
{
    // NOTE(review): despite its name, this section only exercises
    // SliceInt64FromBytes / SliceInt128FromBytes.
    SECTION("Increment and decrement")
    {
        // 45, 172, 225 = 00101101 10101100 11100001. Bits 2..20 give 374172.
        uint8_t bytes[3 + 7] = {45, 172, 225};
        REQUIRE(Util::SliceInt64FromBytes(bytes, 2, 19) == 374172);
        // 213 = 11010101. Bits 1..5 are 10101 = 21.
        uint8_t bytes2[1 + 7] = {213};
        REQUIRE(Util::SliceInt64FromBytes(bytes2, 1, 5) == 21);
        // The first 16 bytes equal int3 (big-endian); byte 17 (255) lies past it.
        uint8_t bytes3[17 + 7] = {1, 2, 3, 4, 5, 6, 7, 255, 255, 10, 11, 12, 13, 14, 15, 16, 255};
        uint128_t int3 = to_uint128(0x01020304050607ff, 0xff0a0b0c0d0e0f10);
        REQUIRE(Util::SliceInt64FromBytes(bytes3, 64, 64) == (uint64_t)int3);
        REQUIRE(Util::SliceInt64FromBytes(bytes3, 0, 60) == (uint64_t)(int3 >> 68));
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 0, 60) == int3 >> 68);
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 7, 64) == int3 >> 57);
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 7, 72) == int3 >> 49);
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 0, 128) == int3);
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 3, 125) == int3);
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 2, 125) == int3 >> 1);
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 0, 120) == int3 >> 8);
        // Slice running 2 bits into byte 17 (0b11...), hence the trailing `| 3`
        REQUIRE(Util::SliceInt128FromBytes(bytes3, 3, 127) == (int3 << 2 | 3));
    }
}
169
// Exercises Bits, ParkBits and LargeBits. Much of the first section only prints
// intermediate values for manual inspection; only the REQUIREs are checked.
TEST_CASE("Bits")
{
    SECTION("Slicing and manipulating")
    {
        Bits g = Bits(13271, 15);
        cout << "G: " << g << endl;
        cout << "G Slice: " << g.Slice(4, 9) << endl;
        cout << "G Slice: " << g.Slice(0, 9) << endl;
        cout << "G Slice: " << g.Slice(9, 10) << endl;
        cout << "G Slice: " << g.Slice(9, 15) << endl;
        cout << "G Slice: " << g.Slice(9, 9) << endl;
        // An empty slice compares equal to a default-constructed Bits
        REQUIRE(g.Slice(9, 9) == Bits());

        uint8_t bytes[2];
        g.ToBytes(bytes);
        cout << "bytes: " << static_cast<int>(bytes[0]) << " " << static_cast<int>(bytes[1])
             << endl;
        cout << "Back to Bits: " << Bits(bytes, 2, 16) << endl;

        Bits(256, 9).ToBytes(bytes);
        cout << "bytes: " << static_cast<int>(bytes[0]) << " " << static_cast<int>(bytes[1])
             << endl;
        cout << "Back to Bits: " << Bits(bytes, 2, 16) << endl;

        cout << Bits(640, 11) << endl;
        Bits(640, 11).ToBytes(bytes);
        cout << "bytes: " << static_cast<int>(bytes[0]) << " " << static_cast<int>(bytes[1])
             << endl;

        Bits h = Bits(bytes, 2, 16);
        Bits i = Bits(bytes, 2, 17);
        cout << "H: " << h << endl;
        cout << "I: " << i << endl;

        cout << "G: " << g << endl;
        cout << "size: " << g.GetSize() << endl;

        // Shifting past the width keeps the size (15) and clears every bit
        Bits shifted = (g << 150);

        REQUIRE(shifted.GetSize() == 15);
        REQUIRE(shifted.ToString() == "000000000000000");

        // 13271 fits in the low 40 bits, so a 160-bit round trip is lossless
        Bits large = Bits(13271, 200);
        REQUIRE(large == ((large << 160)) >> 160);
        REQUIRE((large << 160).GetSize() == 200);

        // Concatenation via operator+; the result is not checked
        Bits l = Bits(123287490 & ((1U << 20) - 1), 20);
        l = l + Bits(0, 5);

        // A 3-bit value is written into the most significant bits of the byte
        Bits m = Bits(5, 3);
        uint8_t buf[1];
        m.ToBytes(buf);
        REQUIRE(buf[0] == (5 << 5));
    }
    SECTION("Park Bits")
    {
        // Round trip: random bytes -> ParkBits -> bytes must be unchanged
        uint32_t const num_bytes = 16000;
        uint8_t buf[num_bytes];
        uint8_t buf_2[num_bytes];
        Util::GetRandomBytes(buf, num_bytes);
        ParkBits my_bits = ParkBits(buf, num_bytes, num_bytes * 8);
        my_bits.ToBytes(buf_2);
        for (uint32_t i = 0; i < num_bytes; i++) {
            REQUIRE(buf[i] == buf_2[i]);
        }
    }

    SECTION("Large Bits")
    {
        // Same round trip for LargeBits.
        // NOTE(review): the two arrays below take about 400 KB of stack; consider
        // std::vector if this runs on threads with small stacks.
        uint32_t const num_bytes = 200000;
        uint8_t buf[num_bytes];
        uint8_t buf_2[num_bytes];
        Util::GetRandomBytes(buf, num_bytes);
        LargeBits my_bits = LargeBits(buf, num_bytes, num_bytes * 8);
        my_bits.ToBytes(buf_2);
        for (uint32_t i = 0; i < num_bytes; i++) {
            REQUIRE(buf[i] == buf_2[i]);
        }
    }
}
250
CheckMatch(int64_t yl,int64_t yr)251 bool CheckMatch(int64_t yl, int64_t yr)
252 {
253 int64_t bl = yl / kBC;
254 int64_t br = yr / kBC;
255 if (bl + 1 != br)
256 return false; // Buckets don't match
257 for (int64_t m = 0; m < kExtraBitsPow; m++) {
258 if ((((yr % kBC) / kC - ((yl % kBC) / kC)) - m) % kB == 0) {
259 int64_t c_diff = 2 * m + bl % 2;
260 c_diff *= c_diff;
261
262 if ((((yr % kBC) % kC - ((yl % kBC) % kC)) - c_diff) % kC == 0) {
263 return true;
264 }
265 }
266 }
267 return false;
268 }
269
270 // Get next set in the Cartesian product of k ranges of [0, n - 1], similar to
271 // k nested 'for' loops from 0 to n - 1
// Treats `items` as a k-digit base-n odometer, with items[0] the least
// significant digit.
//   init == true : resets all k digits to 0 and returns 0 (the first tuple).
//   init == false: advances to the next tuple and returns 0, or returns -1 once
//                  every tuple has been produced (all digits are back at 0).
// For k == 0 the product contains only the empty tuple, so a non-init call
// reports exhaustion right away. Without this guard it would write items[0],
// which the caller need not have allocated, and it would never return -1.
static int CartProdNext(uint8_t* items, uint8_t n, uint8_t k, bool init)
{
    if (init) {
        memset(items, 0, k);
        return 0;
    }
    if (k == 0) {
        return -1;
    }

    items[0]++;
    for (uint8_t i = 0; i < k; i++) {
        if (items[i] != n) {
            break;  // no carry needed
        }
        // Digit wrapped around: reset it and carry into the next one
        items[i] = 0;
        if (i == k - 1) {
            return -1;  // carried out of the most significant digit: done
        }
        items[i + 1]++;
    }

    return 0;
}
296
// Returns n squared (plain int arithmetic, no overflow check).
static int sq(int n)
{
    const int squared = n * n;
    return squared;
}
298
// Brute-force search for a length-4 cycle in the matching relation.
// It enumerates every tuple (r1, r2, s1, s2) in [0, 2^extraBits)^4 and keeps those
// that satisfy all of:
//   r1 != s1,
//   not (r1 == s2 && r2 == s1)   (the packed-digit comparison below),
//   (r1 - s1 + r2 - s2) divisible by B.
// For each kept tuple it tries both parities p1, p2 in {0, 1} and checks whether
//   (2r1+p1)^2 - (2s1+p1)^2 + (2r2+p2)^2 - (2s2+p2)^2
// is divisible by C. This is the same (2 * m + parity)^2 term that CheckMatch uses.
// Prints the first hit to stderr and returns true; otherwise returns false.
// NOTE(review): `1 << extraBits` is passed as CartProdNext's uint8_t `n`, so
// extraBits must be < 8 or the range truncates.
static bool Have4Cycles(uint32_t extraBits, int B, int C)
{
    uint8_t m[4];
    bool init = true;

    while (!CartProdNext(m, 1 << extraBits, 4, init)) {
        uint8_t r1 = m[0], r2 = m[1], s1 = m[2], s2 = m[3];

        init = false;
        // r2 and s1 are < 2^extraBits, so the packed values are equal exactly when
        // r1 == s2 and r2 == s1
        if (r1 != s1 && (r1 << extraBits) + r2 != (s2 << extraBits) + s1 &&
            (r1 - s1 + r2 - s2) % B == 0) {
            uint8_t p[2];
            bool initp = true;

            // Try all four parity combinations (p1, p2)
            while (!CartProdNext(p, 2, 2, initp)) {
                uint8_t p1 = p[0], p2 = p[1];
                int lhs = sq(2 * r1 + p1) - sq(2 * s1 + p1) + sq(2 * r2 + p2) - sq(2 * s2 + p2);

                initp = false;
                if (lhs % C == 0) {
                    fprintf(stderr, "%d %d %d %d %d %d\n", r1, r2, s1, s2, p1, p2);
                    return true;
                }
            }
        }
    }

    return false;
}
328
// The matching parameters (kExtraBits, kB, kC) must not admit any 4-cycle, as
// searched for by Have4Cycles.
TEST_CASE("Matching function")
{
    SECTION("Cycles") { REQUIRE(!Have4Cycles(kExtraBits, kB, kC)); }
}
333
VerifyFC(uint8_t t,uint8_t k,uint64_t L,uint64_t R,uint64_t y1,uint64_t y,uint64_t c)334 void VerifyFC(uint8_t t, uint8_t k, uint64_t L, uint64_t R, uint64_t y1, uint64_t y, uint64_t c)
335 {
336 uint8_t sizes[] = {1, 2, 4, 4, 3, 2};
337 uint8_t size = sizes[t - 2];
338 FxCalculator fcalc(k, t);
339
340 std::pair<Bits, Bits> res = fcalc.CalculateBucket(
341 Bits(y1, k + kExtraBits), Bits(L, k * size), Bits(R, k * size));
342 REQUIRE(res.first.GetValue() == y);
343 if (c) {
344 REQUIRE(res.second.GetValue() == c);
345 }
346 }
347
348 TEST_CASE("F functions")
349 {
350 SECTION("F1")
351 {
352 uint8_t test_k = 35;
353 uint8_t test_key[] = {0, 2, 3, 4, 5, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
354 1, 2, 3, 41, 5, 6, 7, 8, 9, 10, 11, 12, 13, 11, 15, 16};
355 F1Calculator f1(test_k, test_key);
356
357 Bits L = Bits(525, test_k);
358 pair<Bits, Bits> result1 = f1.CalculateBucket(L);
359 Bits L2 = Bits(526, test_k);
360 pair<Bits, Bits> result2 = f1.CalculateBucket(L2);
361 Bits L3 = Bits(625, test_k);
362 pair<Bits, Bits> result3 = f1.CalculateBucket(L3);
363
364 uint64_t results[256];
365 f1.CalculateBuckets(L.GetValue(), 101, results);
366 REQUIRE(result1.first.GetValue() == results[0]);
367 REQUIRE(result2.first.GetValue() == results[1]);
368 REQUIRE(result3.first.GetValue() == results[100]);
369
370 uint32_t max_batch = 1 << kBatchSizes;
371 test_k = 32;
372 F1Calculator f1_2(test_k, test_key);
373 L = Bits(192837491, test_k);
374 result1 = f1_2.CalculateBucket(L);
375 L2 = Bits(192837491 + 1, test_k);
376 result2 = f1_2.CalculateBucket(L2);
377 L3 = Bits(192837491 + 2, test_k);
378 result3 = f1_2.CalculateBucket(L3);
379 Bits L4 = Bits(192837491 + max_batch - 1, test_k);
380 pair<Bits, Bits> result4 = f1_2.CalculateBucket(L4);
381
382 f1_2.CalculateBuckets(L.GetValue(), max_batch, results);
383 REQUIRE(result1.first.GetValue() == results[0]);
384 REQUIRE(result2.first.GetValue() == results[1]);
385 REQUIRE(result3.first.GetValue() == results[2]);
386 REQUIRE(result4.first.GetValue() == results[max_batch - 1]);
387 }
388
389 SECTION("F2")
390 {
391 uint8_t test_key_2[] = {20, 2, 5, 4, 51, 52, 23, 84, 91, 10, 111,
392 12, 13, 24, 151, 16, 228, 211, 254, 45, 92, 198,
393 204, 10, 9, 10, 11, 129, 139, 171, 15, 18};
394 map<uint64_t, vector<pair<Bits, Bits>>> buckets;
395
396 uint8_t const k = 12;
397 uint64_t num_buckets = (1ULL << (k + kExtraBits)) / kBC + 1;
398 uint64_t x = 0;
399
400 F1Calculator f1(k, test_key_2);
401 for (uint32_t j = 0; j < (1ULL << (k - 4)) + 1; j++) {
402 uint64_t y[1 << 4];
403
404 f1.CalculateBuckets(x, 1U << 4, y);
405 for (int i = 0; i < 1 << 4; i++) {
406 uint64_t bucket = y[i] / kBC;
407 if (buckets.find(bucket) == buckets.end()) {
408 buckets[bucket] = vector<std::pair<Bits, Bits>>();
409 }
410 buckets[bucket].push_back(std::make_pair(Bits(y[i], k + kExtraBits), Bits(x, k)));
411 if (x + 1 > (1ULL << k) - 1) {
412 break;
413 }
414 ++x;
415 }
416 if (x + 1 > (1ULL << k) - 1) {
417 break;
418 }
419 }
420
421 FxCalculator f2(k, 2);
422 int total_matches = 0;
423
424 for (auto kv : buckets) {
425 if (kv.first == num_buckets - 1) {
426 continue;
427 }
428 auto bucket_elements_2 = buckets[kv.first + 1];
429 vector<PlotEntry> left_bucket;
430 vector<PlotEntry> right_bucket;
431 for (auto yx1 : kv.second) {
432 PlotEntry e;
433 e.y = get<0>(yx1).GetValue();
434 left_bucket.push_back(e);
435 }
436 for (auto yx2 : buckets[kv.first + 1]) {
437 PlotEntry e;
438 e.y = get<0>(yx2).GetValue();
439 right_bucket.push_back(e);
440 }
441 sort(
442 left_bucket.begin(),
443 left_bucket.end(),
__anon0cce16290102(const PlotEntry& a, const PlotEntry& b) 444 [](const PlotEntry& a, const PlotEntry& b) -> bool { return a.y > b.y; });
445 sort(
446 right_bucket.begin(),
447 right_bucket.end(),
__anon0cce16290202(const PlotEntry& a, const PlotEntry& b) 448 [](const PlotEntry& a, const PlotEntry& b) -> bool { return a.y > b.y; });
449
450 uint16_t idx_L[10000];
451 uint16_t idx_R[10000];
452
453 int32_t idx_count = f2.FindMatches(left_bucket, right_bucket, idx_L, idx_R);
454 for(int32_t i=0; i < idx_count; i++) {
455 REQUIRE(CheckMatch(left_bucket[idx_L[i]].y, right_bucket[idx_R[i]].y));
456 }
457 total_matches += idx_count;
458 }
459 REQUIRE(total_matches > (1 << k) / 2);
460 REQUIRE(total_matches < (1 << k) * 2);
461 }
462
463 SECTION("Fx")
464 {
465 VerifyFC(2, 16, 0x44cb, 0x204f, 0x20a61a, 0x2af546, 0x44cb204f);
466 VerifyFC(2, 16, 0x3c5f, 0xfda9, 0x3988ec, 0x15293b, 0x3c5ffda9);
467 VerifyFC(3, 16, 0x35bf992d, 0x7ce42c82, 0x31e541, 0xf73b3, 0x35bf992d7ce42c82);
468 VerifyFC(3, 16, 0x7204e52d, 0xf1fd42a2, 0x28a188, 0x3fb0b5, 0x7204e52df1fd42a2);
469 VerifyFC(
470 4, 16, 0x5b6e6e307d4bedc, 0x8a9a021ea648a7dd, 0x30cb4c, 0x11ad5, 0xd4bd0b144fc26138);
471 VerifyFC(
472 4, 16, 0xb9d179e06c0fd4f5, 0xf06d3fef701966a0, 0x1dd5b6, 0xe69a2, 0xd02115f512009d4d);
473 VerifyFC(5, 16, 0xc2cd789a380208a9, 0x19999e3fa46d6753, 0x25f01e, 0x1f22bd, 0xabe423040a33);
474 VerifyFC(5, 16, 0xbe3edc0a1ef2a4f0, 0x4da98f1d3099fdf5, 0x3feb18, 0x31501e, 0x7300a3a03ac5);
475 VerifyFC(6, 16, 0xc965815a47c5, 0xf5e008d6af57, 0x1f121a, 0x1cabbe, 0xc8cc6947);
476 VerifyFC(6, 16, 0xd420677f6cbd, 0x5894aa2ca1af, 0x2efde9, 0xc2121, 0x421bb8ec);
477 VerifyFC(7, 16, 0x5fec898f, 0x82283d15, 0x14f410, 0x24c3c2, 0x0);
478 VerifyFC(7, 16, 0x64ac5db9, 0x7923986, 0x590fd, 0x1c74a2, 0x0);
479 }
480 }
481
// Parses `hex` two characters at a time and stores each parsed value into
// result[0], result[1], ...
// Uses strtol base 16, so a trailing odd character is parsed on its own and
// invalid digits yield 0. `result` must hold at least (hex.length() + 1) / 2 bytes.
void HexToBytes(const string& hex, uint8_t* result)
{
    uint8_t* out = result;
    for (unsigned int pos = 0; pos < hex.length(); pos += 2) {
        const string digit_pair = hex.substr(pos, 2);
        const long parsed = strtol(digit_pair.c_str(), nullptr, 16);
        *out++ = static_cast<uint8_t>(parsed);
    }
}
490
TestProofOfSpace(std::string filename,uint32_t iterations,uint8_t k,uint8_t * plot_id,uint32_t num_proofs)491 void TestProofOfSpace(
492 std::string filename,
493 uint32_t iterations,
494 uint8_t k,
495 uint8_t* plot_id,
496 uint32_t num_proofs)
497 {
498 DiskProver prover(filename);
499 uint8_t* proof_data = new uint8_t[8 * k];
500 uint32_t success = 0;
501 // Tries an edge case challenge with many 1s in the front, and ensures there is no segfault
502 vector<unsigned char> hash(picosha2::k_digest_size);
503 HexToBytes("fffffa2b647d4651c500076d7df4c6f352936cf293bd79c591a7b08e43d6adfb", hash.data());
504 prover.GetQualitiesForChallenge(hash.data());
505
506 for (uint32_t i = 0; i < iterations; i++) {
507 vector<unsigned char> hash_input = intToBytes(i, 4);
508 vector<unsigned char> hash(picosha2::k_digest_size);
509 picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
510 vector<LargeBits> qualities = prover.GetQualitiesForChallenge(hash.data());
511 Verifier verifier = Verifier();
512 for (uint32_t index = 0; index < qualities.size(); index++) {
513 LargeBits proof = prover.GetFullProof(hash.data(), index);
514 proof.ToBytes(proof_data);
515
516 LargeBits quality = verifier.ValidateProof(plot_id, k, hash.data(), proof_data, k * 8);
517 REQUIRE(quality.GetSize() == 256);
518 REQUIRE(quality == qualities[index]);
519 success += 1;
520
521 // Tests invalid proof
522 proof_data[0] = (proof_data[0] + 1) % 256;
523 LargeBits quality_2 =
524 verifier.ValidateProof(plot_id, k, hash.data(), proof_data, k * 8);
525 REQUIRE(quality_2.GetSize() == 0);
526 }
527 }
528 std::cout << "Success: " << success << "/" << iterations << " "
529 << (100 * ((double)success / (double)iterations)) << "%" << std::endl;
530 REQUIRE(success == num_proofs);
531 REQUIRE(success > 0.5 * iterations);
532 REQUIRE(success < 1.5 * iterations);
533 delete[] proof_data;
534 }
535
// Creates a plot named `filename` in the current directory using a fixed 5-byte
// memo and a 32-byte plot id. It then runs TestProofOfSpace on that plot and
// requires that the file can be removed afterwards.
// `buffer`, `stripe_size` and `num_threads` are forwarded unchanged to
// DiskPlotter::CreatePlotDisk. The literal 0 passed before `stripe_size` is a
// further CreatePlotDisk argument (presumably the bucket count, with 0 meaning
// "choose automatically"; TODO confirm).
// NOTE(review): if a REQUIRE inside TestProofOfSpace fails, the plot file is left
// on disk.
void PlotAndTestProofOfSpace(
    std::string filename,
    uint32_t iterations,
    uint8_t k,
    uint8_t* plot_id,
    uint32_t buffer,
    uint32_t num_proofs,
    uint32_t stripe_size,
    uint8_t num_threads)
{
    DiskPlotter plotter = DiskPlotter();
    uint8_t memo[5] = {1, 2, 3, 4, 5};
    plotter.CreatePlotDisk(
        ".", ".", ".", filename, k, memo, 5, plot_id, 32, buffer, 0, stripe_size, num_threads);
    TestProofOfSpace(filename, iterations, k, plot_id, num_proofs);
    REQUIRE(remove(filename.c_str()) == 0);
}
553
// End-to-end plotting tests. Arguments of each call:
//   (filename, iterations, k, plot_id, buffer, expected num_proofs, stripe_size, num_threads)
// The expected proof counts are exact, so they depend on the plot ids and the
// challenge sequence used by TestProofOfSpace.
TEST_CASE("Plotting")
{
    SECTION("Disk plot k18")
    {
        PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 18, plot_id_1, 11, 95, 4000, 2);
    }
    SECTION("Disk plot k19")
    {
        PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 19, plot_id_1, 100, 71, 8192, 2);
    }
    // Same as k19 with a single thread: it must find the same number of proofs
    SECTION("Disk plot k19 single-thread")
    {
        PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 19, plot_id_1, 100, 71, 8192, 1);
    }
    SECTION("Disk plot k20")
    {
        PlotAndTestProofOfSpace("cpp-test-plot.dat", 500, 20, plot_id_3, 100, 469, 16000, 2);
    }
    SECTION("Disk plot k21")
    {
        PlotAndTestProofOfSpace("cpp-test-plot.dat", 5000, 21, plot_id_3, 100, 4945, 8192, 4);
    }
    // NOTE(review): this disabled call predates the stripe_size and num_threads
    // parameters and would need both added to compile.
    // SECTION("Disk plot k24") { PlotAndTestProofOfSpace("cpp-test-plot.dat", 100, 24, plot_id_3,
    // 100, 107); }
}
579
580 TEST_CASE("Invalid plot")
581 {
582 SECTION("File gets deleted")
583 {
584 string filename = "invalid-plot.dat";
585 {
586 DiskPlotter plotter = DiskPlotter();
587 uint8_t memo[5] = {1, 2, 3, 4, 5};
588 uint8_t k = 20;
589 plotter.CreatePlotDisk(".", ".", ".", filename, k, memo, 5, plot_id_1, 32, 200, 32, 8192, 2);
590 DiskProver prover(filename);
591 uint8_t* proof_data = new uint8_t[8 * k];
592 uint8_t challenge[32];
593 size_t i;
594 memset(challenge, 155, 32);
595 vector<LargeBits> qualities;
596 for (i = 0; i < 50; i++) {
597 qualities = prover.GetQualitiesForChallenge(challenge);
598 if (qualities.size())
599 break;
600 challenge[0]++;
601 }
602 Verifier verifier = Verifier();
603 REQUIRE(qualities.size() > 0);
604 for (uint32_t index = 0; index < qualities.size(); index++) {
605 LargeBits proof = prover.GetFullProof(challenge, index);
606 proof.ToBytes(proof_data);
607 LargeBits quality =
608 verifier.ValidateProof(plot_id_1, k, challenge, proof_data, k * 8);
609 REQUIRE(quality == qualities[index]);
610 }
611 delete[] proof_data;
612 }
613 REQUIRE(remove(filename.c_str()) == 0);
__anon0cce16290302() 614 REQUIRE_THROWS_WITH([&]() { DiskProver p(filename); }(), "Invalid file " + filename);
615 }
616 }
617
618 TEST_CASE("Sort on disk")
619 {
620 SECTION("ExtractNum")
621 {
622 for (int i = 0; i < 15 * 8 - 5; i++) {
623 uint8_t buf[15 + 7];
624 Bits((uint128_t)27 << i, 15 * 8).ToBytes(buf);
625
626 REQUIRE(Util::ExtractNum(buf, 15, 15 * 8 - 4 - i, 3) == 5);
627 }
628 uint8_t buf[16 + 7];
629 Bits((uint128_t)27 << 5, 128).ToBytes(buf);
630 REQUIRE(Util::ExtractNum(buf, 16, 100, 200) == 864);
631 }
632
633 SECTION("MemCmpBits")
634 {
635 uint8_t left[3];
636 left[0] = 12;
637 left[1] = 10;
638 left[2] = 100;
639
640 uint8_t right[3];
641 right[0] = 12;
642 right[1] = 10;
643 right[2] = 100;
644
645 REQUIRE(Util::MemCmpBits(left, right, 3, 0) == 0);
646 REQUIRE(Util::MemCmpBits(left, right, 3, 10) == 0);
647
648 right[1] = 11;
649 REQUIRE(Util::MemCmpBits(left, right, 3, 0) < 0);
650 REQUIRE(Util::MemCmpBits(left, right, 3, 16) == 0);
651
652 right[1] = 9;
653 REQUIRE(Util::MemCmpBits(left, right, 3, 0) > 0);
654
655 right[1] = 10;
656
657 // Last bit differs
658 right[2] = 101;
659 REQUIRE(Util::MemCmpBits(left, right, 3, 0) < 0);
660 }
661
662 SECTION("Quicksort")
663 {
664 uint32_t const iters = 100;
665 vector<string> hashes;
666 uint8_t* hashes_bytes = new uint8_t[iters * 16];
667 memset(hashes_bytes, 0, iters * 16);
668
669 srand(0);
670 for (uint32_t i = 0; i < iters; i++) {
671 // reverting to rand()
672 string to_insert = std::to_string(rand());
673 while (to_insert.length() < 16) {
674 to_insert += "0";
675 }
676 hashes.push_back(to_insert);
677 memcpy(hashes_bytes + i * 16, to_insert.data(), to_insert.length());
678 }
679 sort(hashes.begin(), hashes.end());
680 QuickSort::Sort(hashes_bytes, 16, iters, 0);
681
682 for (uint32_t i = 0; i < iters; i++) {
683 std::string str(reinterpret_cast<char*>(hashes_bytes) + i * 16, 16);
684 REQUIRE(str.compare(hashes[i]) == 0);
685 }
686 delete[] hashes_bytes;
687 }
688
689 SECTION("File disk")
690 {
691 FileDisk d = FileDisk("test_file.bin");
692 uint8_t buf[5] = {1, 2, 3, 5, 7};
693 d.Write(250, buf, 5);
694
695 uint8_t read_buf[5];
696 d.Read(250, read_buf, 5);
697
698 REQUIRE(memcmp(buf, read_buf, 5) == 0);
699 remove("test_file.bin");
700 }
701
702 SECTION("Lazy Sort Manager QS")
703 {
704 uint32_t iters = 250000;
705 uint32_t const size = 32;
706 vector<Bits> input;
707 const uint32_t memory_len = 1000000;
708 SortManager manager(memory_len, 16, 4, size, ".", "test-files", 0, 1);
709 int total_written_1 = 0;
710 for (uint32_t i = 0; i < iters; i++) {
711 vector<unsigned char> hash_input = intToBytes(i, 4);
712 vector<unsigned char> hash(picosha2::k_digest_size);
713 picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
714 total_written_1 += size;
715 Bits to_write = Bits(hash.data(), size, size * 8);
716 input.emplace_back(to_write);
717 manager.AddToCache(to_write);
718 }
719 manager.FlushCache();
720 uint8_t buf[size];
721 sort(input.begin(), input.end());
722 uint8_t* buf3;
723 for (uint32_t i = 0; i < iters; i++) {
724 buf3 = manager.ReadEntry(i * size);
725 input[i].ToBytes(buf);
726 REQUIRE(memcmp(buf, buf3, size) == 0);
727 }
728 }
729
730 SECTION("Lazy Sort Manager uniform sort")
731 {
732 uint32_t iters = 120000;
733 uint32_t const size = 32;
734 vector<Bits> input;
735 const uint32_t memory_len = 1000000;
736 SortManager manager(memory_len, 16, 4, size, ".", "test-files", 0, 1);
737 int total_written_1 = 0;
738 for (uint32_t i = 0; i < iters; i++) {
739 vector<unsigned char> hash_input = intToBytes(i, 4);
740 vector<unsigned char> hash(picosha2::k_digest_size);
741 picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
742 total_written_1 += size;
743 Bits to_write = Bits(hash.data(), size, size * 8);
744 input.emplace_back(to_write);
745 manager.AddToCache(to_write);
746 }
747 manager.FlushCache();
748 uint8_t buf[size];
749 sort(input.begin(), input.end());
750 uint8_t* buf3;
751 for (uint32_t i = 0; i < iters; i++) {
752 buf3 = manager.ReadEntry(i * size);
753 input[i].ToBytes(buf);
754 REQUIRE(memcmp(buf, buf3, size) == 0);
755 }
756 }
757
758 SECTION("Sort in Memory")
759 {
760 uint32_t iters = 100000;
761 uint32_t const size = 32;
762 vector<Bits> input;
763 uint32_t begin = 1000;
764 FileDisk disk("test_file.bin");
765
766 for (uint32_t i = 0; i < iters; i++) {
767 vector<unsigned char> hash_input = intToBytes(i, 4);
768 vector<unsigned char> hash(picosha2::k_digest_size);
769 picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
770 hash[0] = hash[1] = 0;
771 disk.Write(begin + i * size, hash.data(), size);
772 input.emplace_back(Bits(hash.data(), size, size * 8));
773 }
774
775 const uint32_t memory_len = Util::RoundSize(iters) * size;
776 auto memory = std::make_unique<uint8_t[]>(memory_len);
777 UniformSort::SortToMemory(disk, begin, memory.get(), size, iters, 16);
778
779 sort(input.begin(), input.end());
780 uint8_t buf[size];
781 for (uint32_t i = 0; i < iters; i++) {
782 input[i].ToBytes(buf);
783 REQUIRE(memcmp(buf, memory.get() + i * size, size) == 0);
784 }
785 }
786 }
787
// A new 4-bit bitfield starts cleared, and set() affects only the requested bit.
TEST_CASE("bitfield-simple")
{
    bitfield b(4);
    CHECK(!b.get(0));
    CHECK(!b.get(1));
    CHECK(!b.get(2));
    CHECK(!b.get(3));

    b.set(0);
    CHECK(b.get(0));
    CHECK(!b.get(1));
    CHECK(!b.get(2));
    CHECK(!b.get(3));

    b.set(1);
    CHECK(b.get(0));
    CHECK(b.get(1));
    CHECK(!b.get(2));
    CHECK(!b.get(3));

    b.set(3);
    CHECK(b.get(0));
    CHECK(b.get(1));
    CHECK(!b.get(2));
    CHECK(b.get(3));
}
814
// Sets bits 0..511 one at a time. Before bit i is set, count(0, 512) must equal i.
TEST_CASE("bitfield-count")
{
    bitfield b(512);

    for (int i = 0; i < 512; ++i) {
        CHECK(b.count(0, 512) == i);
        CHECK(!b.get(i));
        b.set(i);
        CHECK(b.get(i));
    }
    CHECK(b.count(0, 512) == 512);
}
827
// With every bit set, count(0, i) must equal i for every end position. This
// includes end positions that are not on a word boundary.
TEST_CASE("bitfield-count-unaligned")
{
    bitfield b(512);

    for (int i = 0; i < 512; ++i) {
        b.set(i);
    }

    for (int i = 0; i < 512; ++i) {
        CHECK(b.count(0, i) == i);
    }
}
840
// Bits 0, 1 and 3 are set. The expected pairs fit lookup(pos, offset) returning
// {set bits in [0, pos), set bits in [pos, pos + offset)}. This reading is
// inferred from the expectations; TODO confirm against bitfield_index.
TEST_CASE("bitfield_index-simple")
{
    bitfield b(64);
    b.set(0);
    b.set(1);
    b.set(3);
    bitfield_index const idx(b);
    CHECK(idx.lookup(0, 0) == std::pair<uint64_t, uint64_t>{0,0});
    CHECK(idx.lookup(0, 1) == std::pair<uint64_t, uint64_t>{0,1});

    CHECK(idx.lookup(0, 3) == std::pair<uint64_t, uint64_t>{0,2});

    CHECK(idx.lookup(1, 0) == std::pair<uint64_t, uint64_t>{1,0});
    CHECK(idx.lookup(1, 2) == std::pair<uint64_t, uint64_t>{1,1});
    CHECK(idx.lookup(3, 0) == std::pair<uint64_t, uint64_t>{2,0});
}
857
// In a large (2^20-bit) bitfield only the last three bits are set. Lookups near
// the end must therefore count correctly across many index buckets.
TEST_CASE("bitfield_index-use index")
{
    bitfield b(1048576);
    CHECK(b.size() == 1048576);
    b.set(1048576 - 3);
    b.set(1048576 - 2);
    b.set(1048576 - 1);
    bitfield_index const idx(b);
    CHECK(idx.lookup(1048576 - 3, 1) == std::pair<uint64_t, uint64_t>{0,1});
    CHECK(idx.lookup(1048576 - 2, 1) == std::pair<uint64_t, uint64_t>{1,1});
}
869
// Set bits sit exactly on index-bucket boundaries (0, kIndexBucket and
// 2 * kIndexBucket) and at the very last position. Lookups start and end on
// those boundaries.
TEST_CASE("bitfield_index edge-cases")
{
    bitfield b(1048576);
    CHECK(b.size() == 1048576);
    b.set(0);
    b.set(bitfield_index::kIndexBucket);
    b.set(bitfield_index::kIndexBucket * 2);
    b.set(1048576 - 1);
    bitfield_index const idx(b);
    CHECK(idx.lookup(0, 0) == std::pair<uint64_t, uint64_t>{0,0});
    CHECK(idx.lookup(0, bitfield_index::kIndexBucket) == std::pair<uint64_t, uint64_t>{0,1});
    CHECK(idx.lookup(0, bitfield_index::kIndexBucket * 2) == std::pair<uint64_t, uint64_t>{0,2});
    CHECK(idx.lookup(0, 1048576 - 1) == std::pair<uint64_t, uint64_t>{0,3});

    CHECK(idx.lookup(bitfield_index::kIndexBucket, 0) == std::pair<uint64_t, uint64_t>{1,0});
    CHECK(idx.lookup(bitfield_index::kIndexBucket, bitfield_index::kIndexBucket) == std::pair<uint64_t, uint64_t>{1,1});
    CHECK(idx.lookup(bitfield_index::kIndexBucket, 1048576 - 1 - bitfield_index::kIndexBucket)
        == std::pair<uint64_t, uint64_t>{1,2});

    CHECK(idx.lookup(bitfield_index::kIndexBucket * 2, 1048576 - 1 - bitfield_index::kIndexBucket * 2)
        == std::pair<uint64_t, uint64_t>{2,1});
    CHECK(idx.lookup(1048576 - 1, 0) == std::pair<uint64_t, uint64_t>{3,0});
}
893
// Builds a `size`-bit bitfield with only its first and last bits set, indexes it,
// and checks lookups anchored at both ends.
// Assumes size >= 2; otherwise the first and last bit coincide and the expected
// counts no longer hold.
void test_bitfield_size(int const size)
{
    bitfield b(size);
    b.set(0);
    b.set(size - 1);
    bitfield_index const idx(b);
    CHECK(idx.lookup(0, 0) == std::pair<uint64_t, uint64_t>{0,0});
    CHECK(idx.lookup(0, size - 1) == std::pair<uint64_t, uint64_t>{0,1});
    CHECK(idx.lookup(size - 1, 0) == std::pair<uint64_t, uint64_t>{1,0});
}
904
// Bitfield sizes just below, at and just above one index bucket.
TEST_CASE("bitfield_index edge-sizes")
{
    test_bitfield_size(bitfield_index::kIndexBucket - 1);
    test_bitfield_size(bitfield_index::kIndexBucket);
    test_bitfield_size(bitfield_index::kIndexBucket + 1);
}
911
912 namespace {
913
914 constexpr int num_test_entries = 2000000;
915
write_disk_file(FileDisk & df)916 void write_disk_file(FileDisk& df)
917 {
918 std::uint32_t val = 0;
919 for (int i = 0; i < num_test_entries; ++i) {
920 df.Write(i * 4, reinterpret_cast<std::uint8_t const*>(&val), 4);
921 ++val;
922 }
923 }
924
925 }
926
// Reads every record written by write_disk_file() directly from the FileDisk:
// first in order, then in reverse.
TEST_CASE("FileDisk")
{
    FileDisk d = FileDisk("test_file.bin");
    write_disk_file(d);

    std::uint32_t val = 0;
    for (uint32_t i = 0; i < num_test_entries; ++i) {
        d.Read(i * 4, reinterpret_cast<std::uint8_t*>(&val), 4);
        REQUIRE(i == val);
    }

    // Reverse pass. It stops before record 0, which the forward pass already checked.
    for (uint32_t i = num_test_entries - 1; i > 0; --i) {
        d.Read(i * 4, reinterpret_cast<std::uint8_t*>(&val), 4);
        CHECK(i == val);
    }

    remove("test_file.bin");
}
945
// Reads the records written by write_disk_file() through a BufferedDisk wrapping
// the FileDisk. The second argument appears to be the total file size in bytes
// (TODO confirm). BufferedDisk::Read returns a pointer to the requested bytes.
TEST_CASE("BufferedDisk")
{
    FileDisk d = FileDisk("test_file.bin");
    write_disk_file(d);

    BufferedDisk bd(&d, num_test_entries * 4);

    for (uint32_t i = 0; i < num_test_entries; ++i) {
        auto const val = *reinterpret_cast<std::uint32_t const*>(bd.Read(i * 4, 4));
        CHECK(i == val);
    }

    // don't go all the way down to 0, every backwards read cursor movement will
    // print a warning
    for (uint32_t i = num_test_entries - 1; i > num_test_entries / 2 + 200; --i) {
        auto const val = *reinterpret_cast<std::uint32_t const*>(bd.Read(i * 4, 4));
        CHECK(i == val);
    }

    remove("test_file.bin");
}
967
// Reads the records written by write_disk_file() through a FilteredDisk. Only
// entries whose bit is set in the filter are expected to be visible, packed
// densely (logical entry i maps to the i-th set bit).
// Catch2 re-runs the code before the first SECTION for each SECTION, so each
// section starts from a freshly written test_file.bin.
TEST_CASE("FilteredDisk")
{
    FileDisk d = FileDisk("test_file.bin");
    write_disk_file(d);

    SECTION("filter even")
    {
        BufferedDisk bd(&d, num_test_entries * 4);
        // Keep only odd indices: entry 0 (and every even one) is filtered out, so
        // logical entry i should read back the value 2 * i + 1
        bitfield filter(num_test_entries);
        for (int i = 0; i < num_test_entries; ++i) {
            if ((i & 1) == 1) filter.set(i);
        }
        FilteredDisk fd(std::move(bd), std::move(filter), 4);

        // NOTE(review): stops one short of the last visible entry (num_test_entries / 2 - 1)
        for (uint32_t i = 0; i < num_test_entries / 2 - 1; ++i) {
            auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
            CHECK((i * 2) + 1 == val);
        }

        // don't go all the way down to 0, every backwards read cursor movement will
        // print a warning
        // NOTE(review): this loop never executes. It starts at
        // num_test_entries / 2 - 1, which is already below the bound
        // num_test_entries / 2 + 200 (apparently copied from the BufferedDisk test,
        // whose disk is twice as long). Before "fixing" the bound, confirm that
        // FilteredDisk::Read supports backwards reads at all.
        for (uint32_t i = num_test_entries / 2 - 1; i > num_test_entries / 2 + 200; --i) {
            auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
            CHECK((i * 2) + 1 == val);
        }
    }

    SECTION("filter odd")
    {
        BufferedDisk bd(&d, num_test_entries * 4);
        // Keep only even indices (starting with 0), so logical entry i should read
        // back the value 2 * i
        bitfield filter(num_test_entries);
        for (int i = 0; i < num_test_entries; ++i) {
            if ((i & 1) == 0) filter.set(i);
        }
        FilteredDisk fd(std::move(bd), std::move(filter), 4);

        for (uint32_t i = 0; i < num_test_entries / 2 - 1; ++i) {
            auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
            CHECK((i * 2) == val);
        }

        // don't go all the way down to 0, every backwards read cursor movement will
        // print a warning
        // NOTE(review): dead loop for the same reason as in "filter even": the start
        // value is already below the bound.
        for (uint32_t i = num_test_entries / 2 - 1; i > num_test_entries / 2 + 200; --i) {
            auto const val = *reinterpret_cast<std::uint32_t const*>(fd.Read(i * 4, 4));
            CHECK((i * 2) == val);
        }
    }
    /*
    SECTION("empty bitfield")
    {
        BufferedDisk bd(&d, num_test_entries * 4);
        bitfield filter(num_test_entries);
        FilteredDisk fd(std::move(bd), std::move(filter), 4);
    }
    */
    remove("test_file.bin");
}
1028