1 /*
2  * Copyright (c) 2017-2020, Yann Collet, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 
11 #include <limits.h>
12 #include <math.h>
13 #include <stddef.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 
18 #include "util.h"
19 #include "timefn.h"   /* UTIL_clockSpanMicro, SEC_TO_MICRO, UTIL_TIME_INITIALIZER */
20 #include "zstd.h"
21 #include "zstd_internal.h"
22 #include "mem.h"
23 #define ZDICT_STATIC_LINKING_ONLY
24 #include "zdict.h"
25 
26 /* Direct access to internal compression functions is required */
27 #include "zstd_compress.c"
28 
29 #define XXH_STATIC_LINKING_ONLY
30 #include "xxhash.h"     /* XXH64 */
31 
/* Fallback MIN for toolchains whose headers do not already provide one.
 * Both arguments may be evaluated twice, so pass side-effect-free expressions. */
#if !defined(MIN)
#   define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

/* Longest path this tool handles: the platform's PATH_MAX when limits.h
 * exposes it, otherwise a conservative 256. */
#if !defined(MAX_PATH)
#   if defined(PATH_MAX)
#       define MAX_PATH PATH_MAX
#   else
#       define MAX_PATH 256
#   endif
#endif
43 
44 /*-************************************
45 *  DISPLAY Macros
46 **************************************/
47 #define DISPLAY(...)          fprintf(stderr, __VA_ARGS__)
48 #define DISPLAYLEVEL(l, ...)  if (g_displayLevel>=l) { DISPLAY(__VA_ARGS__); }
49 static U32 g_displayLevel = 2;
50 
51 #define DISPLAYUPDATE(...)                                                     \
52     do {                                                                       \
53         if ((UTIL_clockSpanMicro(g_displayClock) > g_refreshRate) ||           \
54             (g_displayLevel >= 4)) {                                           \
55             g_displayClock = UTIL_getTime();                                   \
56             DISPLAY(__VA_ARGS__);                                              \
57             if (g_displayLevel >= 4) fflush(stderr);                           \
58         }                                                                      \
59     } while (0)
60 
61 static const U64 g_refreshRate = SEC_TO_MICRO / 6;
62 static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER;
63 
64 #define CHECKERR(code)                                                         \
65     do {                                                                       \
66         if (ZSTD_isError(code)) {                                              \
67             DISPLAY("Error occurred while generating data: %s\n",              \
68                     ZSTD_getErrorName(code));                                  \
69             exit(1);                                                           \
70         }                                                                      \
71     } while (0)
72 
73 /*-*******************************************************
74 *  Random function
75 *********************************************************/
RAND(U32 * src)76 static U32 RAND(U32* src)
77 {
78 #define RAND_rotl32(x,r) ((x << r) | (x >> (32 - r)))
79     static const U32 prime1 = 2654435761U;
80     static const U32 prime2 = 2246822519U;
81     U32 rand32 = *src;
82     rand32 *= prime1;
83     rand32 += prime2;
84     rand32  = RAND_rotl32(rand32, 13);
85     *src = rand32;
86     return RAND_rotl32(rand32, 27);
87 #undef RAND_rotl32
88 }
89 
/* Number of slots in a symbol-distribution table. RAND_bufferDist draws a
 * uniformly random slot, so a symbol's probability is the number of slots
 * holding it divided by DISTSIZE. */
#define DISTSIZE (8192)
91 
92 /* Write `size` bytes into `ptr`, all of which are less than or equal to `maxSymb` */
RAND_bufferMaxSymb(U32 * seed,void * ptr,size_t size,int maxSymb)93 static void RAND_bufferMaxSymb(U32* seed, void* ptr, size_t size, int maxSymb)
94 {
95     size_t i;
96     BYTE* op = ptr;
97 
98     for (i = 0; i < size; i++) {
99         op[i] = (BYTE) (RAND(seed) % (maxSymb + 1));
100     }
101 }
102 
103 /* Write `size` random bytes into `ptr` */
RAND_buffer(U32 * seed,void * ptr,size_t size)104 static void RAND_buffer(U32* seed, void* ptr, size_t size)
105 {
106     size_t i;
107     BYTE* op = ptr;
108 
109     for (i = 0; i + 4 <= size; i += 4) {
110         MEM_writeLE32(op + i, RAND(seed));
111     }
112     for (; i < size; i++) {
113         op[i] = RAND(seed) & 0xff;
114     }
115 }
116 
117 /* Write `size` bytes into `ptr` following the distribution `dist` */
RAND_bufferDist(U32 * seed,BYTE * dist,void * ptr,size_t size)118 static void RAND_bufferDist(U32* seed, BYTE* dist, void* ptr, size_t size)
119 {
120     size_t i;
121     BYTE* op = ptr;
122 
123     for (i = 0; i < size; i++) {
124         op[i] = dist[RAND(seed) % DISTSIZE];
125     }
126 }
127 
128 /* Generate a random distribution where the frequency of each symbol follows a
129  * geometric distribution defined by `weight`
130  * `dist` should have size at least `DISTSIZE` */
RAND_genDist(U32 * seed,BYTE * dist,double weight)131 static void RAND_genDist(U32* seed, BYTE* dist, double weight)
132 {
133     size_t i = 0;
134     size_t statesLeft = DISTSIZE;
135     BYTE symb = (BYTE) (RAND(seed) % 256);
136     BYTE step = (BYTE) ((RAND(seed) % 256) | 1); /* force it to be odd so it's relatively prime to 256 */
137 
138     while (i < DISTSIZE) {
139         size_t states = ((size_t)(weight * statesLeft)) + 1;
140         size_t j;
141         for (j = 0; j < states && i < DISTSIZE; j++, i++) {
142             dist[i] = symb;
143         }
144 
145         symb += step;
146         statesLeft -= states;
147     }
148 }
149 
150 /* Generates a random number in the range [min, max) */
RAND_range(U32 * seed,U32 min,U32 max)151 static inline U32 RAND_range(U32* seed, U32 min, U32 max)
152 {
153     return (RAND(seed) % (max-min)) + min;
154 }
155 
/* Round a non-negative double to the nearest U32 (halves round up). The
 * argument is parenthesized so compound expressions such as ROUND(a ? b : c)
 * are rounded as a whole. */
#define ROUND(x) ((U32)((x) + 0.5))
157 
158 /* Generates a random number in an exponential distribution with mean `mean` */
RAND_exp(U32 * seed,double mean)159 static double RAND_exp(U32* seed, double mean)
160 {
161     double const u = RAND(seed) / (double) UINT_MAX;
162     return log(1-u) * (-mean);
163 }
164 
/*-*******************************************************
*  Constants and Structs
*********************************************************/
/* Printable block type names; presumably indexed by the format's Block_Type
 * value (0 raw, 1 rle, 2 compressed) — NOTE(review): confirm at use sites. */
const char *BLOCK_TYPES[] = {"raw", "rle", "compressed"};

/* Generated frame content is bounded by 2^MAX_DECOMPRESSED_SIZE_LOG bytes (1 MiB). */
#define MAX_DECOMPRESSED_SIZE_LOG 20
#define MAX_DECOMPRESSED_SIZE (1ULL << MAX_DECOMPRESSED_SIZE_LOG)

#define MAX_WINDOW_LOG 22 /* Recommended support is 8MB, so limit to 4MB + mantissa */

/* Shortest match length this generator emits. */
#define MIN_SEQ_LEN (3)
/* Upper bound on sequences in one block: ceil(ZSTD_BLOCKSIZE_MAX / MIN_SEQ_LEN),
 * reached when every sequence is a minimum-length match. */
#define MAX_NB_SEQ ((ZSTD_BLOCKSIZE_MAX + MIN_SEQ_LEN - 1) / MIN_SEQ_LEN)

/* Static scratch buffers shared by the generator functions. */
BYTE CONTENT_BUFFER[MAX_DECOMPRESSED_SIZE];
/* NOTE(review): sized at twice the max content, presumably as headroom for
 * framing/encoding overhead — confirm against the frame-writing code. */
BYTE FRAME_BUFFER[MAX_DECOMPRESSED_SIZE * 2];
/* Regenerated literals of the block being built: filled by the literals
 * writers, then consumed by generateSequences. */
BYTE LITERAL_BUFFER[ZSTD_BLOCKSIZE_MAX];

/* Backing storage handed to seqStore_t by initSeqStore. */
seqDef SEQUENCE_BUFFER[MAX_NB_SEQ];
BYTE SEQUENCE_LITERAL_BUFFER[ZSTD_BLOCKSIZE_MAX]; /* storeSeq expects a place to copy literals to */
BYTE SEQUENCE_LLCODE[ZSTD_BLOCKSIZE_MAX];
BYTE SEQUENCE_MLCODE[ZSTD_BLOCKSIZE_MAX];
BYTE SEQUENCE_OFCODE[ZSTD_BLOCKSIZE_MAX];

/* Workspace passed to HIST_count_wksp and HUF_buildCTable_wksp. */
unsigned WKSP[HUF_WORKSPACE_SIZE_U32];

/* Sizes chosen for the frame currently being generated. */
typedef struct {
    size_t contentSize; /* 0 means unknown (unless contentSize == windowSize == 0) */
    unsigned windowSize; /* contentSize >= windowSize means single segment */
} frameHeader_t;

/* For repeat modes */
/* Entropy state carried from block to block so later blocks can reuse it
 * (repeat offsets, Huffman table, FSE tables). */
typedef struct {
    /* repeat-offset history, most recent first */
    U32 rep[ZSTD_REP_NUM];

    /* nonzero once hufDist/hufTable hold a table that may be repeated */
    int hufInit;
    /* the distribution used in the previous block for repeat mode */
    BYTE hufDist[DISTSIZE];
    HUF_CElt hufTable [256];

    int fseInit;
    FSE_CTable offcodeCTable  [FSE_CTABLE_SIZE_U32(OffFSELog, MaxOff)];
    FSE_CTable matchlengthCTable[FSE_CTABLE_SIZE_U32(MLFSELog, MaxML)];
    FSE_CTable litlengthCTable  [FSE_CTABLE_SIZE_U32(LLFSELog, MaxLL)];

    /* Symbols that were present in the previous distribution, for use with
     * set_repeat */
    /* NOTE(review): these hold symbol values 0..35, 0..28 and 0..52; the
     * maxSymbolValue passed to initSymbolSet must not exceed size-1 or its
     * memset overruns the array — confirm at the call sites. */
    BYTE litlengthSymbolSet[36];
    BYTE offsetSymbolSet[29];
    BYTE matchlengthSymbolSet[53];
} cblockStats_t;

/* Cursor state for one frame being generated: `data*` is the compressed
 * output being written, `src*` is the regenerated (decompressed) content. */
typedef struct {
    void* data;
    void* dataStart;
    void* dataEnd;

    void* src;
    void* srcStart;
    void* srcEnd;

    frameHeader_t header;

    cblockStats_t stats;
    cblockStats_t oldStats; /* so they can be rolled back if uncompressible */
} frame_t;

/* Dictionary parameters; when useDict is 0 the other fields are unused. */
typedef struct {
    int useDict;
    U32 dictID;
    size_t dictContentSize;
    BYTE* dictContent;
} dictInfo;

typedef enum {
  gt_frame = 0,  /* generate frames */
  gt_block,      /* generate compressed blocks without block/frame headers */
} genType_e;
242 
/*-*******************************************************
*  Global variables (set from command line)
*********************************************************/
/* log2 of the bound on generated frame content size (content < 2^log). */
U32 g_maxDecompressedSizeLog = MAX_DECOMPRESSED_SIZE_LOG;  /* <= 20 */
/* Cap on the literals generated per block. NOTE(review): presumably also caps
 * block size where blocks are cut — confirm in the block-writing code. */
U32 g_maxBlockSize = ZSTD_BLOCKSIZE_MAX;                       /* <= 128 KB */

/*-*******************************************************
*  Generator Functions
*********************************************************/

/* Global generation switches. */
struct {
    int contentSize; /* force the content size to be present */
} opts; /* advanced options on generation */
256 
257 /* Generate and write a random frame header */
writeFrameHeader(U32 * seed,frame_t * frame,dictInfo info)258 static void writeFrameHeader(U32* seed, frame_t* frame, dictInfo info)
259 {
260     BYTE* const op = frame->data;
261     size_t pos = 0;
262     frameHeader_t fh;
263 
264     BYTE windowByte = 0;
265 
266     int singleSegment = 0;
267     int contentSizeFlag = 0;
268     int fcsCode = 0;
269 
270     memset(&fh, 0, sizeof(fh));
271 
272     /* generate window size */
273     {
274         /* Follow window algorithm from specification */
275         int const exponent = RAND(seed) % (MAX_WINDOW_LOG - 10);
276         int const mantissa = RAND(seed) % 8;
277         windowByte = (BYTE) ((exponent << 3) | mantissa);
278         fh.windowSize = (1U << (exponent + 10));
279         fh.windowSize += fh.windowSize / 8 * mantissa;
280     }
281 
282     {
283         /* Generate random content size */
284         size_t highBit;
285         if (RAND(seed) & 7 && g_maxDecompressedSizeLog > 7) {
286             /* do content of at least 128 bytes */
287             highBit = 1ULL << RAND_range(seed, 7, g_maxDecompressedSizeLog);
288         } else if (RAND(seed) & 3) {
289             /* do small content */
290             highBit = 1ULL << RAND_range(seed, 0, MIN(7, 1U << g_maxDecompressedSizeLog));
291         } else {
292             /* 0 size frame */
293             highBit = 0;
294         }
295         fh.contentSize = highBit ? highBit + (RAND(seed) % highBit) : 0;
296 
297         /* provide size sometimes */
298         contentSizeFlag = opts.contentSize | (RAND(seed) & 1);
299 
300         if (contentSizeFlag && (fh.contentSize == 0 || !(RAND(seed) & 7))) {
301             /* do single segment sometimes */
302             fh.windowSize = (U32) fh.contentSize;
303             singleSegment = 1;
304         }
305     }
306 
307     if (contentSizeFlag) {
308         /* Determine how large fcs field has to be */
309         int minFcsCode = (fh.contentSize >= 256) +
310                                (fh.contentSize >= 65536 + 256) +
311                                (fh.contentSize > 0xFFFFFFFFU);
312         if (!singleSegment && !minFcsCode) {
313             minFcsCode = 1;
314         }
315         fcsCode = minFcsCode + (RAND(seed) % (4 - minFcsCode));
316         if (fcsCode == 1 && fh.contentSize < 256) fcsCode++;
317     }
318 
319     /* write out the header */
320     MEM_writeLE32(op + pos, ZSTD_MAGICNUMBER);
321     pos += 4;
322 
323     {
324         /*
325          * fcsCode: 2-bit flag specifying how many bytes used to represent Frame_Content_Size (bits 7-6)
326          * singleSegment: 1-bit flag describing if data must be regenerated within a single continuous memory segment. (bit 5)
327          * contentChecksumFlag: 1-bit flag that is set if frame includes checksum at the end -- set to 1 below (bit 2)
328          * dictBits: 2-bit flag describing how many bytes Dictionary_ID uses -- set to 3 (bits 1-0)
329          * For more information: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_header
330          */
331         int const dictBits = info.useDict ? 3 : 0;
332         BYTE const frameHeaderDescriptor =
333                 (BYTE) ((fcsCode << 6) | (singleSegment << 5) | (1 << 2) | dictBits);
334         op[pos++] = frameHeaderDescriptor;
335     }
336 
337     if (!singleSegment) {
338         op[pos++] = windowByte;
339     }
340     if (info.useDict) {
341         MEM_writeLE32(op + pos, (U32) info.dictID);
342         pos += 4;
343     }
344     if (contentSizeFlag) {
345         switch (fcsCode) {
346         default: /* Impossible */
347         case 0: op[pos++] = (BYTE) fh.contentSize; break;
348         case 1: MEM_writeLE16(op + pos, (U16) (fh.contentSize - 256)); pos += 2; break;
349         case 2: MEM_writeLE32(op + pos, (U32) fh.contentSize); pos += 4; break;
350         case 3: MEM_writeLE64(op + pos, (U64) fh.contentSize); pos += 8; break;
351         }
352     }
353 
354     DISPLAYLEVEL(3, " frame content size:\t%u\n", (unsigned)fh.contentSize);
355     DISPLAYLEVEL(3, " frame window size:\t%u\n", fh.windowSize);
356     DISPLAYLEVEL(3, " content size flag:\t%d\n", contentSizeFlag);
357     DISPLAYLEVEL(3, " single segment flag:\t%d\n", singleSegment);
358 
359     frame->data = op + pos;
360     frame->header = fh;
361 }
362 
363 /* Write a literal block in either raw or RLE form, return the literals size */
writeLiteralsBlockSimple(U32 * seed,frame_t * frame,size_t contentSize)364 static size_t writeLiteralsBlockSimple(U32* seed, frame_t* frame, size_t contentSize)
365 {
366     BYTE* op = (BYTE*)frame->data;
367     int const type = RAND(seed) % 2;
368     int const sizeFormatDesc = RAND(seed) % 8;
369     size_t litSize;
370     size_t maxLitSize = MIN(contentSize, g_maxBlockSize);
371 
372     if (sizeFormatDesc == 0) {
373         /* Size_FormatDesc = ?0 */
374         maxLitSize = MIN(maxLitSize, 31);
375     } else if (sizeFormatDesc <= 4) {
376         /* Size_FormatDesc = 01 */
377         maxLitSize = MIN(maxLitSize, 4095);
378     } else {
379         /* Size_Format = 11 */
380         maxLitSize = MIN(maxLitSize, 1048575);
381     }
382 
383     litSize = RAND(seed) % (maxLitSize + 1);
384     if (frame->src == frame->srcStart && litSize == 0) {
385         litSize = 1; /* no empty literals if there's nothing preceding this block */
386     }
387     if (litSize + 3 > contentSize) {
388         litSize = contentSize; /* no matches shorter than 3 are allowed */
389     }
390     /* use smallest size format that fits */
391     if (litSize < 32) {
392         op[0] = (type | (0 << 2) | (litSize << 3)) & 0xff;
393         op += 1;
394     } else if (litSize < 4096) {
395         op[0] = (type | (1 << 2) | (litSize << 4)) & 0xff;
396         op[1] = (litSize >> 4) & 0xff;
397         op += 2;
398     } else {
399         op[0] = (type | (3 << 2) | (litSize << 4)) & 0xff;
400         op[1] = (litSize >> 4) & 0xff;
401         op[2] = (litSize >> 12) & 0xff;
402         op += 3;
403     }
404 
405     if (type == 0) {
406         /* Raw literals */
407         DISPLAYLEVEL(4, "   raw literals\n");
408 
409         RAND_buffer(seed, LITERAL_BUFFER, litSize);
410         memcpy(op, LITERAL_BUFFER, litSize);
411         op += litSize;
412     } else {
413         /* RLE literals */
414         BYTE const symb = (BYTE) (RAND(seed) % 256);
415 
416         DISPLAYLEVEL(4, "   rle literals: 0x%02x\n", (unsigned)symb);
417 
418         memset(LITERAL_BUFFER, symb, litSize);
419         op[0] = symb;
420         op++;
421     }
422 
423     frame->data = op;
424 
425     return litSize;
426 }
427 
428 /* Generate a Huffman header for the given source */
writeHufHeader(U32 * seed,HUF_CElt * hufTable,void * dst,size_t dstSize,const void * src,size_t srcSize)429 static size_t writeHufHeader(U32* seed, HUF_CElt* hufTable, void* dst, size_t dstSize,
430                                  const void* src, size_t srcSize)
431 {
432     BYTE* const ostart = (BYTE*)dst;
433     BYTE* op = ostart;
434 
435     unsigned huffLog = 11;
436     unsigned maxSymbolValue = 255;
437 
438     unsigned count[HUF_SYMBOLVALUE_MAX+1];
439 
440     /* Scan input and build symbol stats */
441     {   size_t const largest = HIST_count_wksp (count, &maxSymbolValue, (const BYTE*)src, srcSize, WKSP, sizeof(WKSP));
442         assert(!HIST_isError(largest));
443         if (largest == srcSize) { *ostart = ((const BYTE*)src)[0]; return 0; }   /* single symbol, rle */
444         if (largest <= (srcSize >> 7)+1) return 0;   /* Fast heuristic : not compressible enough */
445     }
446 
447     /* Build Huffman Tree */
448     /* Max Huffman log is 11, min is highbit(maxSymbolValue)+1 */
449     huffLog = RAND_range(seed, ZSTD_highbit32(maxSymbolValue)+1, huffLog+1);
450     DISPLAYLEVEL(6, "     huffman log: %u\n", huffLog);
451     {   size_t const maxBits = HUF_buildCTable_wksp (hufTable, count, maxSymbolValue, huffLog, WKSP, sizeof(WKSP));
452         CHECKERR(maxBits);
453         huffLog = (U32)maxBits;
454     }
455 
456     /* Write table description header */
457     {   size_t const hSize = HUF_writeCTable (op, dstSize, hufTable, maxSymbolValue, huffLog);
458         if (hSize + 12 >= srcSize) return 0;   /* not useful to try compression */
459         op += hSize;
460     }
461 
462     return op - ostart;
463 }
464 
465 /* Write a Huffman coded literals block and return the literals size */
writeLiteralsBlockCompressed(U32 * seed,frame_t * frame,size_t contentSize)466 static size_t writeLiteralsBlockCompressed(U32* seed, frame_t* frame, size_t contentSize)
467 {
468     BYTE* origop = (BYTE*)frame->data;
469     BYTE* opend = (BYTE*)frame->dataEnd;
470     BYTE* op;
471     BYTE* const ostart = origop;
472     int const sizeFormat = RAND(seed) % 4;
473     size_t litSize;
474     size_t hufHeaderSize = 0;
475     size_t compressedSize = 0;
476     size_t maxLitSize = MIN(contentSize-3, g_maxBlockSize);
477 
478     symbolEncodingType_e hType;
479 
480     if (contentSize < 64) {
481         /* make sure we get reasonably-sized literals for compression */
482         return ERROR(GENERIC);
483     }
484 
485     DISPLAYLEVEL(4, "   compressed literals\n");
486 
487     switch (sizeFormat) {
488     case 0: /* fall through, size is the same as case 1 */
489     case 1:
490         maxLitSize = MIN(maxLitSize, 1023);
491         origop += 3;
492         break;
493     case 2:
494         maxLitSize = MIN(maxLitSize, 16383);
495         origop += 4;
496         break;
497     case 3:
498         maxLitSize = MIN(maxLitSize, 262143);
499         origop += 5;
500         break;
501     default:; /* impossible */
502     }
503 
504     do {
505         op = origop;
506         do {
507             litSize = RAND(seed) % (maxLitSize + 1);
508         } while (litSize < 32); /* avoid small literal sizes */
509         if (litSize + 3 > contentSize) {
510             litSize = contentSize; /* no matches shorter than 3 are allowed */
511         }
512 
513         /* most of the time generate a new distribution */
514         if ((RAND(seed) & 3) || !frame->stats.hufInit) {
515             do {
516                 if (RAND(seed) & 3) {
517                     /* add 10 to ensure some compressibility */
518                     double const weight = ((RAND(seed) % 90) + 10) / 100.0;
519 
520                     DISPLAYLEVEL(5, "    distribution weight: %d%%\n",
521                                  (int)(weight * 100));
522 
523                     RAND_genDist(seed, frame->stats.hufDist, weight);
524                 } else {
525                     /* sometimes do restricted range literals to force
526                      * non-huffman headers */
527                     DISPLAYLEVEL(5, "    small range literals\n");
528                     RAND_bufferMaxSymb(seed, frame->stats.hufDist, DISTSIZE,
529                                        15);
530                 }
531                 RAND_bufferDist(seed, frame->stats.hufDist, LITERAL_BUFFER,
532                                 litSize);
533 
534                 /* generate the header from the distribution instead of the
535                  * actual data to avoid bugs with symbols that were in the
536                  * distribution but never showed up in the output */
537                 hufHeaderSize = writeHufHeader(
538                         seed, frame->stats.hufTable, op, opend - op,
539                         frame->stats.hufDist, DISTSIZE);
540                 CHECKERR(hufHeaderSize);
541                 /* repeat until a valid header is written */
542             } while (hufHeaderSize == 0);
543             op += hufHeaderSize;
544             hType = set_compressed;
545 
546             frame->stats.hufInit = 1;
547         } else {
548             /* repeat the distribution/table from last time */
549             DISPLAYLEVEL(5, "    huffman repeat stats\n");
550             RAND_bufferDist(seed, frame->stats.hufDist, LITERAL_BUFFER,
551                             litSize);
552             hufHeaderSize = 0;
553             hType = set_repeat;
554         }
555 
556         do {
557             compressedSize =
558                     sizeFormat == 0
559                             ? HUF_compress1X_usingCTable(
560                                       op, opend - op, LITERAL_BUFFER, litSize,
561                                       frame->stats.hufTable)
562                             : HUF_compress4X_usingCTable(
563                                       op, opend - op, LITERAL_BUFFER, litSize,
564                                       frame->stats.hufTable);
565             CHECKERR(compressedSize);
566             /* this only occurs when it could not compress or similar */
567         } while (compressedSize <= 0);
568 
569         op += compressedSize;
570 
571         compressedSize += hufHeaderSize;
572         DISPLAYLEVEL(5, "    regenerated size: %u\n", (unsigned)litSize);
573         DISPLAYLEVEL(5, "    compressed size: %u\n", (unsigned)compressedSize);
574         if (compressedSize >= litSize) {
575             DISPLAYLEVEL(5, "     trying again\n");
576             /* if we have to try again, reset the stats so we don't accidentally
577              * try to repeat a distribution we just made */
578             frame->stats = frame->oldStats;
579         } else {
580             break;
581         }
582     } while (1);
583 
584     /* write header */
585     switch (sizeFormat) {
586     case 0: /* fall through, size is the same as case 1 */
587     case 1: {
588         U32 const header = hType | (sizeFormat << 2) | ((U32)litSize << 4) |
589                            ((U32)compressedSize << 14);
590         MEM_writeLE24(ostart, header);
591         break;
592     }
593     case 2: {
594         U32 const header = hType | (sizeFormat << 2) | ((U32)litSize << 4) |
595                            ((U32)compressedSize << 18);
596         MEM_writeLE32(ostart, header);
597         break;
598     }
599     case 3: {
600         U32 const header = hType | (sizeFormat << 2) | ((U32)litSize << 4) |
601                            ((U32)compressedSize << 22);
602         MEM_writeLE32(ostart, header);
603         ostart[4] = (BYTE)(compressedSize >> 10);
604         break;
605     }
606     default:; /* impossible */
607     }
608 
609     frame->data = op;
610     return litSize;
611 }
612 
/* Write the literals section of a block and return its regenerated size.
 * Seven times out of eight, segments of at least 64 bytes get Huffman
 * compressed literals; everything else gets raw or RLE literals, because
 * small segments risk not compressing. */
static size_t writeLiteralsBlock(U32* seed, frame_t* frame, size_t contentSize)
{
    int const preferCompressed = (RAND(seed) & 7) != 0;

    if (!preferCompressed || contentSize < 64) {
        return writeLiteralsBlockSimple(seed, frame, contentSize);
    }
    return writeLiteralsBlockCompressed(seed, frame, contentSize);
}
622 
/* Point every buffer of `seqStore` at this file's static sequence arrays,
 * set the capacity limits, then reset it to an empty state. */
static inline void initSeqStore(seqStore_t *seqStore)
{
    seqStore->sequencesStart = SEQUENCE_BUFFER;
    seqStore->litStart = SEQUENCE_LITERAL_BUFFER;
    seqStore->llCode = SEQUENCE_LLCODE;
    seqStore->mlCode = SEQUENCE_MLCODE;
    seqStore->ofCode = SEQUENCE_OFCODE;

    seqStore->maxNbSeq = MAX_NB_SEQ;
    seqStore->maxNbLit = ZSTD_BLOCKSIZE_MAX;

    ZSTD_resetSeqStore(seqStore);
}
634 
/* Randomly generate sequence commands.
 *
 * Splits the block's contentSize regenerated bytes into `literalsSize`
 * literal bytes, taken in order from LITERAL_BUFFER, and
 * contentSize - literalsSize match bytes. The match bytes are spread over a
 * random number of sequences, each with a match length >= MIN_SEQ_LEN. For
 * each sequence it:
 *   - copies the sequence's literals to the regenerated output,
 *   - picks an offset: a fresh one, one that reaches back into the
 *     dictionary when info.useDict, or a repeat offset from frame->stats.rep,
 *   - regenerates the match bytes, taking them from the dictionary's tail
 *     while the offset reaches before frame->srcStart,
 *   - updates the repeat-offset history and records the sequence with
 *     ZSTD_storeSeq.
 * Literals left over after the last sequence are appended at the end.
 *
 * Returns the number of sequences (0 when the block is all literals).
 * NOTE(review): output is written through a local copy of frame->src, which
 * is not advanced here — presumably the caller advances it; confirm.
 * Assumes contentSize - literalsSize is 0 or >= MIN_SEQ_LEN; otherwise
 * maxSequences is 0 and the modulo divides by zero. The literal writers
 * above appear to guarantee this.
 */
static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore,
                                size_t contentSize, size_t literalsSize, dictInfo info)
{
    /* The total length of all the matches */
    size_t const remainingMatch = contentSize - literalsSize;
    size_t excessMatch = 0;
    U32 numSequences = 0;

    U32 i;


    const BYTE* literals = LITERAL_BUFFER;
    BYTE* srcPtr = frame->src;

    if (literalsSize != contentSize) {
        /* each match must be at least MIN_SEQ_LEN, so this is the maximum
         * number of sequences we can have */
        U32 const maxSequences = (U32)remainingMatch / MIN_SEQ_LEN;
        numSequences = (RAND(seed) % maxSequences) + 1;

        /* the extra match lengths we have to allocate to each sequence */
        excessMatch = remainingMatch - numSequences * MIN_SEQ_LEN;
    }

    DISPLAYLEVEL(5, "    total match lengths: %u\n", (unsigned)remainingMatch);
    for (i = 0; i < numSequences; i++) {
        /* Generate match and literal lengths by exponential distribution to
         * ensure nice numbers */
        U32 matchLen =
                MIN_SEQ_LEN +
                ROUND(RAND_exp(seed, excessMatch / (double)(numSequences - i)));
        U32 literalLen =
                (RAND(seed) & 7)
                        ? ROUND(RAND_exp(seed,
                                         literalsSize /
                                                 (double)(numSequences - i)))
                        : 0;
        /* actual offset, code to send, and point to copy up to when shifting
         * codes in the repeat offsets history */
        U32 offset, offsetCode, repIndex;

        /* bounds checks */
        matchLen = (U32) MIN(matchLen, excessMatch + MIN_SEQ_LEN);
        literalLen = MIN(literalLen, (U32) literalsSize);
        /* the frame's very first sequence needs a literal to match against */
        if (i == 0 && srcPtr == frame->srcStart && literalLen == 0) literalLen = 1;
        /* the last sequence absorbs all remaining excess so totals add up exactly */
        if (i + 1 == numSequences) matchLen = MIN_SEQ_LEN + (U32) excessMatch;

        memcpy(srcPtr, literals, literalLen);
        srcPtr += literalLen;
        /* draw offsets until one is usable for the data regenerated so far */
        do {
            if (RAND(seed) & 7) {
                /* do a normal offset */
                U32 const dataDecompressed = (U32)((BYTE*)srcPtr-(BYTE*)frame->srcStart);
                offset = (RAND(seed) %
                          MIN(frame->header.windowSize,
                              (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) +
                         1;
                if (info.useDict && (RAND(seed) & 1) && i + 1 != numSequences && dataDecompressed < frame->header.windowSize) {
                    /* need to occasionally generate offsets that go past the start */
                    /* including i+1 != numSequences because the last sequences has to adhere to predetermined contentSize */
                    U32 lenPastStart = (RAND(seed) % info.dictContentSize) + 1;
                    offset = (U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart)+lenPastStart;
                    if (offset > frame->header.windowSize) {
                        if (lenPastStart < MIN_SEQ_LEN) {
                            /* when offset > windowSize, matchLen bound by end of dictionary (lenPastStart) */
                            /* this also means that lenPastStart must be greater than MIN_SEQ_LEN */
                            /* make sure lenPastStart does not go past dictionary start though */
                            lenPastStart = MIN(lenPastStart+MIN_SEQ_LEN, (U32)info.dictContentSize);
                            offset = (U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart) + lenPastStart;
                        }
                        {
                            U32 const matchLenBound = MIN(frame->header.windowSize, lenPastStart);
                            matchLen = MIN(matchLen, matchLenBound);
                        }
                    }
                }
                /* literal offsets are sent shifted past the repeat-offset codes;
                 * a new offset pushes the whole history down (repIndex 2) */
                offsetCode = offset + ZSTD_REP_MOVE;
                repIndex = 2;
            } else {
                /* do a repeat offset */
                offsetCode = RAND(seed) % 3;
                if (literalLen > 0) {
                    offset = frame->stats.rep[offsetCode];
                    repIndex = offsetCode;
                } else {
                    /* special case */
                    /* with no literals, repeat codes shift by one: 0 -> rep[1],
                     * 1 -> rep[2], 2 -> rep[0] - 1 */
                    offset = offsetCode == 2 ? frame->stats.rep[0] - 1
                                           : frame->stats.rep[offsetCode + 1];
                    repIndex = MIN(2, offsetCode + 1);
                }
            }
            /* retry on offset 0, or (without a dictionary) on an offset
             * reaching before the start of the regenerated data */
        } while (((!info.useDict) && (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) || offset == 0);

        /* regenerate the match bytes one at a time (matches may overlap) */
        {
            size_t j;
            BYTE* const dictEnd = info.dictContent + info.dictContentSize;
            for (j = 0; j < matchLen; j++) {
                if ((U32)((BYTE*)srcPtr - (BYTE*)frame->srcStart) < offset) {
                    /* copy from dictionary instead of literals */
                    size_t const dictOffset = offset - (srcPtr - (BYTE*)frame->srcStart);
                    *srcPtr = *(dictEnd - dictOffset);
                }
                else {
                    *srcPtr = *(srcPtr-offset);
                }
                srcPtr++;
            }
        }

        /* update repeat history: shift rep[0..repIndex-1] down one slot and
         * put the offset just used at the front */
        {   int r;
            for (r = repIndex; r > 0; r--) {
                frame->stats.rep[r] = frame->stats.rep[r - 1];
            }
            frame->stats.rep[0] = offset;
        }

        DISPLAYLEVEL(6, "      LL: %5u OF: %5u ML: %5u",
                    (unsigned)literalLen, (unsigned)offset, (unsigned)matchLen);
        DISPLAYLEVEL(7, " srcPos: %8u seqNb: %3u",
                     (unsigned)((BYTE*)srcPtr - (BYTE*)frame->srcStart), (unsigned)i);
        DISPLAYLEVEL(6, "\n");
        if (offsetCode < 3) {
            DISPLAYLEVEL(7, "        repeat offset: %d\n", (int)repIndex);
        }
        /* use libzstd sequence handling */
        ZSTD_storeSeq(seqStore, literalLen, literals, literals + literalLen,
                      offsetCode, matchLen - MINMATCH);

        literalsSize -= literalLen;
        excessMatch -= (matchLen - MIN_SEQ_LEN);
        literals += literalLen;
    }

    /* trailing literals not attached to any sequence */
    memcpy(srcPtr, literals, literalsSize);
    srcPtr += literalsSize;
    DISPLAYLEVEL(6, "      excess literals: %5u", (unsigned)literalsSize);
    DISPLAYLEVEL(7, " srcPos: %8u", (unsigned)((BYTE*)srcPtr - (BYTE*)frame->srcStart));
    DISPLAYLEVEL(6, "\n");

    return numSequences;
}
777 
/* Build a presence map: clear set[0..maxSymbolValue], then mark set[s] = 1
 * for every symbol s among the first `len` entries of `symbols`.
 * Every symbol is assumed to be <= maxSymbolValue. */
static void initSymbolSet(const BYTE* symbols, size_t len, BYTE* set, BYTE maxSymbolValue)
{
    size_t const setSize = (size_t)maxSymbolValue + 1;
    size_t idx = 0;

    memset(set, 0, setSize);
    while (idx < len) {
        set[symbols[idx]] = 1;
        idx++;
    }
}
788 
isSymbolSubset(const BYTE * symbols,size_t len,const BYTE * set,BYTE maxSymbolValue)789 static int isSymbolSubset(const BYTE* symbols, size_t len, const BYTE* set, BYTE maxSymbolValue)
790 {
791     size_t i;
792 
793     for (i = 0; i < len; i++) {
794         if (symbols[i] > maxSymbolValue || !set[symbols[i]]) {
795             return 0;
796         }
797     }
798     return 1;
799 }
800 
writeSequences(U32 * seed,frame_t * frame,seqStore_t * seqStorePtr,size_t nbSeq)801 static size_t writeSequences(U32* seed, frame_t* frame, seqStore_t* seqStorePtr,
802                              size_t nbSeq)
803 {
804     /* This code is mostly copied from ZSTD_compressSequences in zstd_compress.c */
805     unsigned count[MaxSeq+1];
806     S16 norm[MaxSeq+1];
807     FSE_CTable* CTable_LitLength = frame->stats.litlengthCTable;
808     FSE_CTable* CTable_OffsetBits = frame->stats.offcodeCTable;
809     FSE_CTable* CTable_MatchLength = frame->stats.matchlengthCTable;
810     U32 LLtype, Offtype, MLtype;   /* compressed, raw or rle */
811     const seqDef* const sequences = seqStorePtr->sequencesStart;
812     const BYTE* const ofCodeTable = seqStorePtr->ofCode;
813     const BYTE* const llCodeTable = seqStorePtr->llCode;
814     const BYTE* const mlCodeTable = seqStorePtr->mlCode;
815     BYTE* const oend = (BYTE*)frame->dataEnd;
816     BYTE* op = (BYTE*)frame->data;
817     BYTE* seqHead;
818     BYTE scratchBuffer[FSE_BUILD_CTABLE_WORKSPACE_SIZE(MaxSeq, MaxFSELog)];
819 
820     /* literals compressing block removed so that can be done separately */
821 
822     /* Sequences Header */
823     if ((oend-op) < 3 /*max nbSeq Size*/ + 1 /*seqHead */) return ERROR(dstSize_tooSmall);
824     if (nbSeq < 0x7F) *op++ = (BYTE)nbSeq;
825     else if (nbSeq < LONGNBSEQ) op[0] = (BYTE)((nbSeq>>8) + 0x80), op[1] = (BYTE)nbSeq, op+=2;
826     else op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3;
827 
828     if (nbSeq==0) {
829         frame->data = op;
830         return 0;
831     }
832 
833     /* seqHead : flags for FSE encoding type */
834     seqHead = op++;
835 
836     /* convert length/distances into codes */
837     ZSTD_seqToCodes(seqStorePtr);
838 
839     /* CTable for Literal Lengths */
840     {   unsigned max = MaxLL;
841         size_t const mostFrequent = HIST_countFast_wksp(count, &max, llCodeTable, nbSeq, WKSP, sizeof(WKSP));   /* cannot fail */
842         assert(!HIST_isError(mostFrequent));
843         if (frame->stats.fseInit && !(RAND(seed) & 3) &&
844                    isSymbolSubset(llCodeTable, nbSeq,
845                                   frame->stats.litlengthSymbolSet, 35)) {
846             /* maybe do repeat mode if we're allowed to */
847             LLtype = set_repeat;
848         } else if (mostFrequent == nbSeq) {
849             /* do RLE if we have the chance */
850             *op++ = llCodeTable[0];
851             FSE_buildCTable_rle(CTable_LitLength, (BYTE)max);
852             LLtype = set_rle;
853         } else if (!(RAND(seed) & 3)) {
854             /* maybe use the default distribution */
855             CHECKERR(FSE_buildCTable_wksp(CTable_LitLength, LL_defaultNorm, MaxLL, LL_defaultNormLog, scratchBuffer, sizeof(scratchBuffer)));
856             LLtype = set_basic;
857         } else {
858             /* fall back on a full table */
859             size_t nbSeq_1 = nbSeq;
860             const U32 tableLog = FSE_optimalTableLog(LLFSELog, nbSeq, max);
861             if (count[llCodeTable[nbSeq-1]]>1) { count[llCodeTable[nbSeq-1]]--; nbSeq_1--; }
862             FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max, nbSeq >= 2048);
863             { size_t const NCountSize = FSE_writeNCount(op, oend-op, norm, max, tableLog);   /* overflow protected */
864               if (FSE_isError(NCountSize)) return ERROR(GENERIC);
865               op += NCountSize; }
866             CHECKERR(FSE_buildCTable_wksp(CTable_LitLength, norm, max, tableLog, scratchBuffer, sizeof(scratchBuffer)));
867             LLtype = set_compressed;
868     }   }
869 
870     /* CTable for Offsets */
871     /* see Literal Lengths for descriptions of mode choices */
872     {   unsigned max = MaxOff;
873         size_t const mostFrequent = HIST_countFast_wksp(count, &max, ofCodeTable, nbSeq, WKSP, sizeof(WKSP));   /* cannot fail */
874         assert(!HIST_isError(mostFrequent));
875         if (frame->stats.fseInit && !(RAND(seed) & 3) &&
876                    isSymbolSubset(ofCodeTable, nbSeq,
877                                   frame->stats.offsetSymbolSet, 28)) {
878             Offtype = set_repeat;
879         } else if (mostFrequent == nbSeq) {
880             *op++ = ofCodeTable[0];
881             FSE_buildCTable_rle(CTable_OffsetBits, (BYTE)max);
882             Offtype = set_rle;
883         } else if (!(RAND(seed) & 3)) {
884             FSE_buildCTable_wksp(CTable_OffsetBits, OF_defaultNorm, DefaultMaxOff, OF_defaultNormLog, scratchBuffer, sizeof(scratchBuffer));
885             Offtype = set_basic;
886         } else {
887             size_t nbSeq_1 = nbSeq;
888             const U32 tableLog = FSE_optimalTableLog(OffFSELog, nbSeq, max);
889             if (count[ofCodeTable[nbSeq-1]]>1) { count[ofCodeTable[nbSeq-1]]--; nbSeq_1--; }
890             FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max, nbSeq >= 2048);
891             { size_t const NCountSize = FSE_writeNCount(op, oend-op, norm, max, tableLog);   /* overflow protected */
892               if (FSE_isError(NCountSize)) return ERROR(GENERIC);
893               op += NCountSize; }
894             FSE_buildCTable_wksp(CTable_OffsetBits, norm, max, tableLog, scratchBuffer, sizeof(scratchBuffer));
895             Offtype = set_compressed;
896     }   }
897 
898     /* CTable for MatchLengths */
899     /* see Literal Lengths for descriptions of mode choices */
900     {   unsigned max = MaxML;
901         size_t const mostFrequent = HIST_countFast_wksp(count, &max, mlCodeTable, nbSeq, WKSP, sizeof(WKSP));   /* cannot fail */
902         assert(!HIST_isError(mostFrequent));
903         if (frame->stats.fseInit && !(RAND(seed) & 3) &&
904                    isSymbolSubset(mlCodeTable, nbSeq,
905                                   frame->stats.matchlengthSymbolSet, 52)) {
906             MLtype = set_repeat;
907         } else if (mostFrequent == nbSeq) {
908             *op++ = *mlCodeTable;
909             FSE_buildCTable_rle(CTable_MatchLength, (BYTE)max);
910             MLtype = set_rle;
911         } else if (!(RAND(seed) & 3)) {
912             /* sometimes do default distribution */
913             FSE_buildCTable_wksp(CTable_MatchLength, ML_defaultNorm, MaxML, ML_defaultNormLog, scratchBuffer, sizeof(scratchBuffer));
914             MLtype = set_basic;
915         } else {
916             /* fall back on table */
917             size_t nbSeq_1 = nbSeq;
918             const U32 tableLog = FSE_optimalTableLog(MLFSELog, nbSeq, max);
919             if (count[mlCodeTable[nbSeq-1]]>1) { count[mlCodeTable[nbSeq-1]]--; nbSeq_1--; }
920             FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max, nbSeq >= 2048);
921             { size_t const NCountSize = FSE_writeNCount(op, oend-op, norm, max, tableLog);   /* overflow protected */
922               if (FSE_isError(NCountSize)) return ERROR(GENERIC);
923               op += NCountSize; }
924             FSE_buildCTable_wksp(CTable_MatchLength, norm, max, tableLog, scratchBuffer, sizeof(scratchBuffer));
925             MLtype = set_compressed;
926     }   }
927     frame->stats.fseInit = 1;
928     initSymbolSet(llCodeTable, nbSeq, frame->stats.litlengthSymbolSet, 35);
929     initSymbolSet(ofCodeTable, nbSeq, frame->stats.offsetSymbolSet, 28);
930     initSymbolSet(mlCodeTable, nbSeq, frame->stats.matchlengthSymbolSet, 52);
931 
932     DISPLAYLEVEL(5, "    LL type: %d OF type: %d ML type: %d\n", (unsigned)LLtype, (unsigned)Offtype, (unsigned)MLtype);
933 
934     *seqHead = (BYTE)((LLtype<<6) + (Offtype<<4) + (MLtype<<2));
935 
936     /* Encoding Sequences */
937     {   BIT_CStream_t blockStream;
938         FSE_CState_t  stateMatchLength;
939         FSE_CState_t  stateOffsetBits;
940         FSE_CState_t  stateLitLength;
941 
942         RETURN_ERROR_IF(
943             ERR_isError(BIT_initCStream(&blockStream, op, oend-op)),
944             dstSize_tooSmall, "not enough space remaining");
945 
946         /* first symbols */
947         FSE_initCState2(&stateMatchLength, CTable_MatchLength, mlCodeTable[nbSeq-1]);
948         FSE_initCState2(&stateOffsetBits,  CTable_OffsetBits,  ofCodeTable[nbSeq-1]);
949         FSE_initCState2(&stateLitLength,   CTable_LitLength,   llCodeTable[nbSeq-1]);
950         BIT_addBits(&blockStream, sequences[nbSeq-1].litLength, LL_bits[llCodeTable[nbSeq-1]]);
951         if (MEM_32bits()) BIT_flushBits(&blockStream);
952         BIT_addBits(&blockStream, sequences[nbSeq-1].matchLength, ML_bits[mlCodeTable[nbSeq-1]]);
953         if (MEM_32bits()) BIT_flushBits(&blockStream);
954         BIT_addBits(&blockStream, sequences[nbSeq-1].offset, ofCodeTable[nbSeq-1]);
955         BIT_flushBits(&blockStream);
956 
957         {   size_t n;
958             for (n=nbSeq-2 ; n<nbSeq ; n--) {      /* intentional underflow */
959                 BYTE const llCode = llCodeTable[n];
960                 BYTE const ofCode = ofCodeTable[n];
961                 BYTE const mlCode = mlCodeTable[n];
962                 U32  const llBits = LL_bits[llCode];
963                 U32  const ofBits = ofCode;                                     /* 32b*/  /* 64b*/
964                 U32  const mlBits = ML_bits[mlCode];
965                                                                                 /* (7)*/  /* (7)*/
966                 FSE_encodeSymbol(&blockStream, &stateOffsetBits, ofCode);       /* 15 */  /* 15 */
967                 FSE_encodeSymbol(&blockStream, &stateMatchLength, mlCode);      /* 24 */  /* 24 */
968                 if (MEM_32bits()) BIT_flushBits(&blockStream);                  /* (7)*/
969                 FSE_encodeSymbol(&blockStream, &stateLitLength, llCode);        /* 16 */  /* 33 */
970                 if (MEM_32bits() || (ofBits+mlBits+llBits >= 64-7-(LLFSELog+MLFSELog+OffFSELog)))
971                     BIT_flushBits(&blockStream);                                /* (7)*/
972                 BIT_addBits(&blockStream, sequences[n].litLength, llBits);
973                 if (MEM_32bits() && ((llBits+mlBits)>24)) BIT_flushBits(&blockStream);
974                 BIT_addBits(&blockStream, sequences[n].matchLength, mlBits);
975                 if (MEM_32bits()) BIT_flushBits(&blockStream);                  /* (7)*/
976                 BIT_addBits(&blockStream, sequences[n].offset, ofBits);         /* 31 */
977                 BIT_flushBits(&blockStream);                                    /* (7)*/
978         }   }
979 
980         FSE_flushCState(&blockStream, &stateMatchLength);
981         FSE_flushCState(&blockStream, &stateOffsetBits);
982         FSE_flushCState(&blockStream, &stateLitLength);
983 
984         {   size_t const streamSize = BIT_closeCStream(&blockStream);
985             if (streamSize==0) return ERROR(dstSize_tooSmall);   /* not enough space */
986             op += streamSize;
987     }   }
988 
989     frame->data = op;
990 
991     return 0;
992 }
993 
writeSequencesBlock(U32 * seed,frame_t * frame,size_t contentSize,size_t literalsSize,dictInfo info)994 static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize,
995                                   size_t literalsSize, dictInfo info)
996 {
997     seqStore_t seqStore;
998     size_t numSequences;
999 
1000 
1001     initSeqStore(&seqStore);
1002 
1003     /* randomly generate sequences */
1004     numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, info);
1005     /* write them out to the frame data */
1006     CHECKERR(writeSequences(seed, frame, &seqStore, numSequences));
1007 
1008     return numSequences;
1009 }
1010 
writeCompressedBlock(U32 * seed,frame_t * frame,size_t contentSize,dictInfo info)1011 static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, dictInfo info)
1012 {
1013     BYTE* const blockStart = (BYTE*)frame->data;
1014     size_t literalsSize;
1015     size_t nbSeq;
1016 
1017     DISPLAYLEVEL(4, "  compressed block:\n");
1018 
1019     literalsSize = writeLiteralsBlock(seed, frame, contentSize);
1020 
1021     DISPLAYLEVEL(4, "   literals size: %u\n", (unsigned)literalsSize);
1022 
1023     nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, info);
1024 
1025     DISPLAYLEVEL(4, "   number of sequences: %u\n", (unsigned)nbSeq);
1026 
1027     return (BYTE*)frame->data - blockStart;
1028 }
1029 
writeBlock(U32 * seed,frame_t * frame,size_t contentSize,int lastBlock,dictInfo info)1030 static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
1031                        int lastBlock, dictInfo info)
1032 {
1033     int const blockTypeDesc = RAND(seed) % 8;
1034     size_t blockSize;
1035     int blockType;
1036 
1037     BYTE *const header = (BYTE*)frame->data;
1038     BYTE *op = header + 3;
1039 
1040     DISPLAYLEVEL(4, " block:\n");
1041     DISPLAYLEVEL(4, "  block content size: %u\n", (unsigned)contentSize);
1042     DISPLAYLEVEL(4, "  last block: %s\n", lastBlock ? "yes" : "no");
1043 
1044     if (blockTypeDesc == 0) {
1045         /* Raw data frame */
1046 
1047         RAND_buffer(seed, frame->src, contentSize);
1048         memcpy(op, frame->src, contentSize);
1049 
1050         op += contentSize;
1051         blockType = 0;
1052         blockSize = contentSize;
1053     } else if (blockTypeDesc == 1 && frame->header.contentSize > 0) {
1054         /* RLE (Don't create RLE block if frame content is 0 since block size of 1 may exceed max block size)*/
1055         BYTE const symbol = RAND(seed) & 0xff;
1056 
1057         op[0] = symbol;
1058         memset(frame->src, symbol, contentSize);
1059 
1060         op++;
1061         blockType = 1;
1062         blockSize = contentSize;
1063     } else {
1064         /* compressed, most common */
1065         size_t compressedSize;
1066         blockType = 2;
1067 
1068         frame->oldStats = frame->stats;
1069 
1070         frame->data = op;
1071         compressedSize = writeCompressedBlock(seed, frame, contentSize, info);
1072         if (compressedSize >= contentSize) {   /* compressed block must be strictly smaller than uncompressed one */
1073             blockType = 0;
1074             memcpy(op, frame->src, contentSize);
1075 
1076             op += contentSize;
1077             blockSize = contentSize; /* fall back on raw block if data doesn't
1078                                         compress */
1079 
1080             frame->stats = frame->oldStats; /* don't update the stats */
1081         } else {
1082             op += compressedSize;
1083             blockSize = compressedSize;
1084         }
1085     }
1086     frame->src = (BYTE*)frame->src + contentSize;
1087 
1088     DISPLAYLEVEL(4, "  block type: %s\n", BLOCK_TYPES[blockType]);
1089     DISPLAYLEVEL(4, "  block size field: %u\n", (unsigned)blockSize);
1090 
1091     header[0] = (BYTE) ((lastBlock | (blockType << 1) | (blockSize << 3)) & 0xff);
1092     MEM_writeLE16(header + 1, (U16) (blockSize >> 5));
1093 
1094     frame->data = op;
1095 }
1096 
writeBlocks(U32 * seed,frame_t * frame,dictInfo info)1097 static void writeBlocks(U32* seed, frame_t* frame, dictInfo info)
1098 {
1099     size_t contentLeft = frame->header.contentSize;
1100     size_t const maxBlockSize = MIN(g_maxBlockSize, frame->header.windowSize);
1101     while (1) {
1102         /* 1 in 4 chance of ending frame */
1103         int const lastBlock = contentLeft > maxBlockSize ? 0 : !(RAND(seed) & 3);
1104         size_t blockContentSize;
1105         if (lastBlock) {
1106             blockContentSize = contentLeft;
1107         } else {
1108             if (contentLeft > 0 && (RAND(seed) & 7)) {
1109                 /* some variable size block */
1110                 blockContentSize = RAND(seed) % (MIN(maxBlockSize, contentLeft)+1);
1111             } else if (contentLeft > maxBlockSize && (RAND(seed) & 1)) {
1112                 /* some full size block */
1113                 blockContentSize = maxBlockSize;
1114             } else {
1115                 /* some empty block */
1116                 blockContentSize = 0;
1117             }
1118         }
1119 
1120         writeBlock(seed, frame, blockContentSize, lastBlock, info);
1121 
1122         contentLeft -= blockContentSize;
1123         if (lastBlock) break;
1124     }
1125 }
1126 
writeChecksum(frame_t * frame)1127 static void writeChecksum(frame_t* frame)
1128 {
1129     /* write checksum so implementations can verify their output */
1130     U64 digest = XXH64(frame->srcStart, (BYTE*)frame->src-(BYTE*)frame->srcStart, 0);
1131     DISPLAYLEVEL(3, "  checksum: %08x\n", (unsigned)digest);
1132     MEM_writeLE32(frame->data, (U32)digest);
1133     frame->data = (BYTE*)frame->data + 4;
1134 }
1135 
outputBuffer(const void * buf,size_t size,const char * const path)1136 static void outputBuffer(const void* buf, size_t size, const char* const path)
1137 {
1138     /* write data out to file */
1139     const BYTE* ip = (const BYTE*)buf;
1140     FILE* out;
1141     if (path) {
1142         out = fopen(path, "wb");
1143     } else {
1144         out = stdout;
1145     }
1146     if (!out) {
1147         fprintf(stderr, "Failed to open file at %s: ", path);
1148         perror(NULL);
1149         exit(1);
1150     }
1151 
1152     {   size_t fsize = size;
1153         size_t written = 0;
1154         while (written < fsize) {
1155             written += fwrite(ip + written, 1, fsize - written, out);
1156             if (ferror(out)) {
1157                 fprintf(stderr, "Failed to write to file at %s: ", path);
1158                 perror(NULL);
1159                 exit(1);
1160             }
1161         }
1162     }
1163 
1164     if (path) {
1165         fclose(out);
1166     }
1167 }
1168 
initFrame(frame_t * fr)1169 static void initFrame(frame_t* fr)
1170 {
1171     memset(fr, 0, sizeof(*fr));
1172     fr->data = fr->dataStart = FRAME_BUFFER;
1173     fr->dataEnd = FRAME_BUFFER + sizeof(FRAME_BUFFER);
1174     fr->src = fr->srcStart = CONTENT_BUFFER;
1175     fr->srcEnd = CONTENT_BUFFER + sizeof(CONTENT_BUFFER);
1176 
1177     /* init repeat codes */
1178     fr->stats.rep[0] = 1;
1179     fr->stats.rep[1] = 4;
1180     fr->stats.rep[2] = 8;
1181 }
1182 
1183 /**
1184  * Generated a single zstd compressed block with no block/frame header.
1185  * Returns the final seed.
1186  */
generateCompressedBlock(U32 seed,frame_t * frame,dictInfo info)1187 static U32 generateCompressedBlock(U32 seed, frame_t* frame, dictInfo info)
1188 {
1189     size_t blockContentSize;
1190     int blockWritten = 0;
1191     BYTE* op;
1192     DISPLAYLEVEL(4, "block seed: %u\n", (unsigned)seed);
1193     initFrame(frame);
1194     op = (BYTE*)frame->data;
1195 
1196     while (!blockWritten) {
1197         size_t cSize;
1198         /* generate window size */
1199         {   int const exponent = RAND(&seed) % (MAX_WINDOW_LOG - 10);
1200             int const mantissa = RAND(&seed) % 8;
1201             frame->header.windowSize = (1U << (exponent + 10));
1202             frame->header.windowSize += (frame->header.windowSize / 8) * mantissa;
1203         }
1204 
1205         /* generate content size */
1206         {   size_t const maxBlockSize = MIN(g_maxBlockSize, frame->header.windowSize);
1207             if (RAND(&seed) & 15) {
1208                 /* some full size blocks */
1209                 blockContentSize = maxBlockSize;
1210             } else if (RAND(&seed) & 7 && g_maxBlockSize >= (1U << 7)) {
1211                 /* some small blocks <= 128 bytes*/
1212                 blockContentSize = RAND(&seed) % (1U << 7);
1213             } else {
1214                 /* some variable size blocks */
1215                 blockContentSize = RAND(&seed) % maxBlockSize;
1216             }
1217         }
1218 
1219         /* try generating a compressed block */
1220         frame->oldStats = frame->stats;
1221         frame->data = op;
1222         cSize = writeCompressedBlock(&seed, frame, blockContentSize, info);
1223         if (cSize >= blockContentSize) {  /* compressed size must be strictly smaller than decompressed size : https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#blocks */
1224             /* data doesn't compress -- try again */
1225             frame->stats = frame->oldStats; /* don't update the stats */
1226             DISPLAYLEVEL(5, "   can't compress block : try again \n");
1227         } else {
1228             blockWritten = 1;
1229             DISPLAYLEVEL(4, "   block size: %u \n", (unsigned)cSize);
1230             frame->src = (BYTE*)frame->src + blockContentSize;
1231         }
1232     }
1233     return seed;
1234 }
1235 
1236 /* Return the final seed */
generateFrame(U32 seed,frame_t * fr,dictInfo info)1237 static U32 generateFrame(U32 seed, frame_t* fr, dictInfo info)
1238 {
1239     /* generate a complete frame */
1240     DISPLAYLEVEL(3, "frame seed: %u\n", (unsigned)seed);
1241     initFrame(fr);
1242 
1243     writeFrameHeader(&seed, fr, info);
1244     writeBlocks(&seed, fr, info);
1245     writeChecksum(fr);
1246 
1247     return seed;
1248 }
1249 
1250 /*_*******************************************************
1251 *  Dictionary Helper Functions
1252 *********************************************************/
1253 /* returns 0 if successful, otherwise returns 1 upon error */
genRandomDict(U32 dictID,U32 seed,size_t dictSize,BYTE * fullDict)1254 static int genRandomDict(U32 dictID, U32 seed, size_t dictSize, BYTE* fullDict)
1255 {
1256     /* allocate space for samples */
1257     int ret = 0;
1258     unsigned const numSamples = 4;
1259     size_t sampleSizes[4];
1260     BYTE* const samples = malloc(5000*sizeof(BYTE));
1261     if (samples == NULL) {
1262         DISPLAY("Error: could not allocate space for samples\n");
1263         return 1;
1264     }
1265 
1266     /* generate samples */
1267     {   unsigned literalValue = 1;
1268         unsigned samplesPos = 0;
1269         size_t currSize = 1;
1270         while (literalValue <= 4) {
1271             sampleSizes[literalValue - 1] = currSize;
1272             {   size_t k;
1273                 for (k = 0; k < currSize; k++) {
1274                     *(samples + (samplesPos++)) = (BYTE)literalValue;
1275             }   }
1276             literalValue++;
1277             currSize *= 16;
1278     }   }
1279 
1280     {   size_t dictWriteSize = 0;
1281         ZDICT_params_t zdictParams;
1282         size_t const headerSize = MAX(dictSize/4, 256);
1283         size_t const dictContentSize = dictSize - headerSize;
1284         BYTE* const dictContent = fullDict + headerSize;
1285         if (dictContentSize < ZDICT_CONTENTSIZE_MIN || dictSize < ZDICT_DICTSIZE_MIN) {
1286             DISPLAY("Error: dictionary size is too small\n");
1287             ret = 1;
1288             goto exitGenRandomDict;
1289         }
1290 
1291         /* init dictionary params */
1292         memset(&zdictParams, 0, sizeof(zdictParams));
1293         zdictParams.dictID = dictID;
1294         zdictParams.notificationLevel = 1;
1295 
1296         /* fill in dictionary content */
1297         RAND_buffer(&seed, (void*)dictContent, dictContentSize);
1298 
1299         /* finalize dictionary with random samples */
1300         dictWriteSize = ZDICT_finalizeDictionary(fullDict, dictSize,
1301                                     dictContent, dictContentSize,
1302                                     samples, sampleSizes, numSamples,
1303                                     zdictParams);
1304 
1305         if (ZDICT_isError(dictWriteSize)) {
1306             DISPLAY("Could not finalize dictionary: %s\n", ZDICT_getErrorName(dictWriteSize));
1307             ret = 1;
1308         }
1309     }
1310 
1311 exitGenRandomDict:
1312     free(samples);
1313     return ret;
1314 }
1315 
initDictInfo(int useDict,size_t dictContentSize,BYTE * dictContent,U32 dictID)1316 static dictInfo initDictInfo(int useDict, size_t dictContentSize, BYTE* dictContent, U32 dictID){
1317     /* allocate space statically */
1318     dictInfo dictOp;
1319     memset(&dictOp, 0, sizeof(dictOp));
1320     dictOp.useDict = useDict;
1321     dictOp.dictContentSize = dictContentSize;
1322     dictOp.dictContent = dictContent;
1323     dictOp.dictID = dictID;
1324     return dictOp;
1325 }
1326 
1327 /*-*******************************************************
1328 *  Test Mode
1329 *********************************************************/
1330 
/* Output buffer shared by the decode tests below: each one decompresses into
 * it, then memcmp()s the result against the generated content. */
BYTE DECOMPRESSED_BUFFER[MAX_DECOMPRESSED_SIZE];
1332 
testDecodeSimple(frame_t * fr)1333 static size_t testDecodeSimple(frame_t* fr)
1334 {
1335     /* test decoding the generated data with the simple API */
1336     size_t const ret = ZSTD_decompress(DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1337                            fr->dataStart, (BYTE*)fr->data - (BYTE*)fr->dataStart);
1338 
1339     if (ZSTD_isError(ret)) return ret;
1340 
1341     if (memcmp(DECOMPRESSED_BUFFER, fr->srcStart,
1342                (BYTE*)fr->src - (BYTE*)fr->srcStart) != 0) {
1343         return ERROR(corruption_detected);
1344     }
1345 
1346     return ret;
1347 }
1348 
testDecodeStreaming(frame_t * fr)1349 static size_t testDecodeStreaming(frame_t* fr)
1350 {
1351     /* test decoding the generated data with the streaming API */
1352     ZSTD_DStream* zd = ZSTD_createDStream();
1353     ZSTD_inBuffer in;
1354     ZSTD_outBuffer out;
1355     size_t ret;
1356 
1357     if (!zd) return ERROR(memory_allocation);
1358 
1359     in.src = fr->dataStart;
1360     in.pos = 0;
1361     in.size = (BYTE*)fr->data - (BYTE*)fr->dataStart;
1362 
1363     out.dst = DECOMPRESSED_BUFFER;
1364     out.pos = 0;
1365     out.size = ZSTD_DStreamOutSize();
1366 
1367     ZSTD_initDStream(zd);
1368     while (1) {
1369         ret = ZSTD_decompressStream(zd, &out, &in);
1370         if (ZSTD_isError(ret)) goto cleanup; /* error */
1371         if (ret == 0) break; /* frame is done */
1372 
1373         /* force decoding to be done in chunks */
1374         out.size += MIN(ZSTD_DStreamOutSize(), MAX_DECOMPRESSED_SIZE - out.size);
1375     }
1376 
1377     ret = out.pos;
1378 
1379     if (memcmp(out.dst, fr->srcStart, out.pos) != 0) {
1380         return ERROR(corruption_detected);
1381     }
1382 
1383 cleanup:
1384     ZSTD_freeDStream(zd);
1385     return ret;
1386 }
1387 
testDecodeWithDict(U32 seed,genType_e genType)1388 static size_t testDecodeWithDict(U32 seed, genType_e genType)
1389 {
1390     /* create variables */
1391     size_t const dictSize = RAND(&seed) % (10 << 20) + ZDICT_DICTSIZE_MIN + ZDICT_CONTENTSIZE_MIN;
1392     U32 const dictID = RAND(&seed);
1393     size_t errorDetected = 0;
1394     BYTE* const fullDict = malloc(dictSize);
1395     if (fullDict == NULL) {
1396         return ERROR(GENERIC);
1397     }
1398 
1399     /* generate random dictionary */
1400     if (genRandomDict(dictID, seed, dictSize, fullDict)) {  /* return 0 on success */
1401         errorDetected = ERROR(GENERIC);
1402         goto dictTestCleanup;
1403     }
1404 
1405 
1406     {   frame_t fr;
1407         dictInfo info;
1408         ZSTD_DCtx* const dctx = ZSTD_createDCtx();
1409         size_t ret;
1410 
1411         /* get dict info */
1412         {   size_t const headerSize = MAX(dictSize/4, 256);
1413             size_t const dictContentSize = dictSize-headerSize;
1414             BYTE* const dictContent = fullDict+headerSize;
1415             info = initDictInfo(1, dictContentSize, dictContent, dictID);
1416         }
1417 
1418         /* manually decompress and check difference */
1419         if (genType == gt_frame) {
1420             /* Test frame */
1421             generateFrame(seed, &fr, info);
1422             ret = ZSTD_decompress_usingDict(dctx, DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1423                                             fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart,
1424                                             fullDict, dictSize);
1425         } else {
1426             /* Test block */
1427             generateCompressedBlock(seed, &fr, info);
1428             ret = ZSTD_decompressBegin_usingDict(dctx, fullDict, dictSize);
1429             if (ZSTD_isError(ret)) {
1430                 errorDetected = ret;
1431                 ZSTD_freeDCtx(dctx);
1432                 goto dictTestCleanup;
1433             }
1434             ret = ZSTD_decompressBlock(dctx, DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1435                                        fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart);
1436         }
1437         ZSTD_freeDCtx(dctx);
1438 
1439         if (ZSTD_isError(ret)) {
1440             errorDetected = ret;
1441             goto dictTestCleanup;
1442         }
1443 
1444         if (memcmp(DECOMPRESSED_BUFFER, fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart) != 0) {
1445             errorDetected = ERROR(corruption_detected);
1446             goto dictTestCleanup;
1447         }
1448     }
1449 
1450 dictTestCleanup:
1451     free(fullDict);
1452     return errorDetected;
1453 }
1454 
testDecodeRawBlock(frame_t * fr)1455 static size_t testDecodeRawBlock(frame_t* fr)
1456 {
1457     ZSTD_DCtx* dctx = ZSTD_createDCtx();
1458     size_t ret = ZSTD_decompressBegin(dctx);
1459     if (ZSTD_isError(ret)) return ret;
1460 
1461     ret = ZSTD_decompressBlock(
1462             dctx,
1463             DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE,
1464             fr->dataStart, (BYTE*)fr->data - (BYTE*)fr->dataStart);
1465     ZSTD_freeDCtx(dctx);
1466     if (ZSTD_isError(ret)) return ret;
1467 
1468     if (memcmp(DECOMPRESSED_BUFFER, fr->srcStart,
1469                (BYTE*)fr->src - (BYTE*)fr->srcStart) != 0) {
1470         return ERROR(corruption_detected);
1471     }
1472 
1473     return ret;
1474 }
1475 
runBlockTest(U32 * seed)1476 static int runBlockTest(U32* seed)
1477 {
1478     frame_t fr;
1479     U32 const seedCopy = *seed;
1480     {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1481         *seed = generateCompressedBlock(*seed, &fr, info);
1482     }
1483 
1484     {   size_t const r = testDecodeRawBlock(&fr);
1485         if (ZSTD_isError(r)) {
1486             DISPLAY("Error in block mode on test seed %u: %s\n",
1487                     (unsigned)seedCopy, ZSTD_getErrorName(r));
1488             return 1;
1489         }
1490     }
1491 
1492     {   size_t const r = testDecodeWithDict(*seed, gt_block);
1493         if (ZSTD_isError(r)) {
1494             DISPLAY("Error in block mode with dictionary on test seed %u: %s\n",
1495                     (unsigned)seedCopy, ZSTD_getErrorName(r));
1496             return 1;
1497         }
1498     }
1499     return 0;
1500 }
1501 
runFrameTest(U32 * seed)1502 static int runFrameTest(U32* seed)
1503 {
1504     frame_t fr;
1505     U32 const seedCopy = *seed;
1506     {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1507         *seed = generateFrame(*seed, &fr, info);
1508     }
1509 
1510     {   size_t const r = testDecodeSimple(&fr);
1511         if (ZSTD_isError(r)) {
1512             DISPLAY("Error in simple mode on test seed %u: %s\n",
1513                     (unsigned)seedCopy, ZSTD_getErrorName(r));
1514             return 1;
1515         }
1516     }
1517     {   size_t const r = testDecodeStreaming(&fr);
1518         if (ZSTD_isError(r)) {
1519             DISPLAY("Error in streaming mode on test seed %u: %s\n",
1520                     (unsigned)seedCopy, ZSTD_getErrorName(r));
1521             return 1;
1522         }
1523     }
1524     {   size_t const r = testDecodeWithDict(*seed, gt_frame);  /* avoid big dictionaries */
1525         if (ZSTD_isError(r)) {
1526             DISPLAY("Error in dictionary mode on test seed %u: %s\n",
1527                     (unsigned)seedCopy, ZSTD_getErrorName(r));
1528             return 1;
1529         }
1530     }
1531     return 0;
1532 }
1533 
runTestMode(U32 seed,unsigned numFiles,unsigned const testDurationS,genType_e genType)1534 static int runTestMode(U32 seed, unsigned numFiles, unsigned const testDurationS,
1535                        genType_e genType)
1536 {
1537     unsigned fnum;
1538 
1539     UTIL_time_t const startClock = UTIL_getTime();
1540     U64 const maxClockSpan = testDurationS * SEC_TO_MICRO;
1541 
1542     if (numFiles == 0 && !testDurationS) numFiles = 1;
1543 
1544     DISPLAY("seed: %u\n", (unsigned)seed);
1545 
1546     for (fnum = 0; fnum < numFiles || UTIL_clockSpanMicro(startClock) < maxClockSpan; fnum++) {
1547         if (fnum < numFiles)
1548             DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
1549         else
1550             DISPLAYUPDATE("\r%u           ", fnum);
1551 
1552         {   int const ret = (genType == gt_frame) ?
1553                             runFrameTest(&seed) :
1554                             runBlockTest(&seed);
1555             if (ret) return ret;
1556         }
1557     }
1558 
1559     DISPLAY("\r%u tests completed: ", fnum);
1560     DISPLAY("OK\n");
1561 
1562     return 0;
1563 }
1564 
1565 /*-*******************************************************
1566 *  File I/O
1567 *********************************************************/
1568 
generateFile(U32 seed,const char * const path,const char * const origPath,genType_e genType)1569 static int generateFile(U32 seed, const char* const path,
1570                         const char* const origPath, genType_e genType)
1571 {
1572     frame_t fr;
1573 
1574     DISPLAY("seed: %u\n", (unsigned)seed);
1575 
1576     {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1577         if (genType == gt_frame) {
1578             generateFrame(seed, &fr, info);
1579         } else {
1580             generateCompressedBlock(seed, &fr, info);
1581         }
1582     }
1583     outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path);
1584     if (origPath) {
1585         outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, origPath);
1586     }
1587     return 0;
1588 }
1589 
generateCorpus(U32 seed,unsigned numFiles,const char * const path,const char * const origPath,genType_e genType)1590 static int generateCorpus(U32 seed, unsigned numFiles, const char* const path,
1591                           const char* const origPath, genType_e genType)
1592 {
1593     char outPath[MAX_PATH];
1594     unsigned fnum;
1595 
1596     DISPLAY("seed: %u\n", (unsigned)seed);
1597 
1598     for (fnum = 0; fnum < numFiles; fnum++) {
1599         frame_t fr;
1600 
1601         DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
1602 
1603         {   dictInfo const info = initDictInfo(0, 0, NULL, 0);
1604             if (genType == gt_frame) {
1605                 seed = generateFrame(seed, &fr, info);
1606             } else {
1607                 seed = generateCompressedBlock(seed, &fr, info);
1608             }
1609         }
1610 
1611         if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
1612             DISPLAY("Error: path too long\n");
1613             return 1;
1614         }
1615         outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, outPath);
1616 
1617         if (origPath) {
1618             if (snprintf(outPath, MAX_PATH, "%s/z%06u", origPath, fnum) + 1 > MAX_PATH) {
1619                 DISPLAY("Error: path too long\n");
1620                 return 1;
1621             }
1622             outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, outPath);
1623         }
1624     }
1625 
1626     DISPLAY("\r%u/%u      \n", fnum, numFiles);
1627 
1628     return 0;
1629 }
1630 
generateCorpusWithDict(U32 seed,unsigned numFiles,const char * const path,const char * const origPath,const size_t dictSize,genType_e genType)1631 static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const path,
1632                                   const char* const origPath, const size_t dictSize,
1633                                   genType_e genType)
1634 {
1635     char outPath[MAX_PATH];
1636     BYTE* fullDict;
1637     U32 const dictID = RAND(&seed);
1638     int errorDetected = 0;
1639 
1640     if (snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) {
1641         DISPLAY("Error: path too long\n");
1642         return 1;
1643     }
1644 
1645     /* allocate space for the dictionary */
1646     fullDict = malloc(dictSize);
1647     if (fullDict == NULL) {
1648         DISPLAY("Error: could not allocate space for full dictionary.\n");
1649         return 1;
1650     }
1651 
1652     /* randomly generate the dictionary */
1653     {   int const ret = genRandomDict(dictID, seed, dictSize, fullDict);
1654         if (ret != 0) {
1655             errorDetected = ret;
1656             goto dictCleanup;
1657         }
1658     }
1659 
1660     /* write out dictionary */
1661     if (numFiles != 0) {
1662         if (snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) {
1663             DISPLAY("Error: dictionary path too long\n");
1664             errorDetected = 1;
1665             goto dictCleanup;
1666         }
1667         outputBuffer(fullDict, dictSize, outPath);
1668     }
1669     else {
1670         outputBuffer(fullDict, dictSize, "dictionary");
1671     }
1672 
1673     /* generate random compressed/decompressed files */
1674     {   unsigned fnum;
1675         for (fnum = 0; fnum < MAX(numFiles, 1); fnum++) {
1676             frame_t fr;
1677             DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
1678             {
1679                 size_t const headerSize = MAX(dictSize/4, 256);
1680                 size_t const dictContentSize = dictSize-headerSize;
1681                 BYTE* const dictContent = fullDict+headerSize;
1682                 dictInfo const info = initDictInfo(1, dictContentSize, dictContent, dictID);
1683                 if (genType == gt_frame) {
1684                     seed = generateFrame(seed, &fr, info);
1685                 } else {
1686                     seed = generateCompressedBlock(seed, &fr, info);
1687                 }
1688             }
1689 
1690             if (numFiles != 0) {
1691                 if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
1692                     DISPLAY("Error: path too long\n");
1693                     errorDetected = 1;
1694                     goto dictCleanup;
1695                 }
1696                 outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, outPath);
1697 
1698                 if (origPath) {
1699                     if (snprintf(outPath, MAX_PATH, "%s/z%06u", origPath, fnum) + 1 > MAX_PATH) {
1700                         DISPLAY("Error: path too long\n");
1701                         errorDetected = 1;
1702                         goto dictCleanup;
1703                     }
1704                     outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, outPath);
1705                 }
1706             }
1707             else {
1708                 outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path);
1709                 if (origPath) {
1710                     outputBuffer(fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, origPath);
1711                 }
1712             }
1713         }
1714     }
1715 
1716 dictCleanup:
1717     free(fullDict);
1718     return errorDetected;
1719 }
1720 
1721 
1722 /*_*******************************************************
1723 *  Command line
1724 *********************************************************/
makeSeed(void)1725 static U32 makeSeed(void)
1726 {
1727     U32 t = (U32) time(NULL);
1728     return XXH32(&t, sizeof(t), 0) % 65536;
1729 }
1730 
readInt(const char ** argument)1731 static unsigned readInt(const char** argument)
1732 {
1733     unsigned val = 0;
1734     while ((**argument>='0') && (**argument<='9')) {
1735         val *= 10;
1736         val += **argument - '0';
1737         (*argument)++;
1738     }
1739     return val;
1740 }
1741 
usage(const char * programName)1742 static void usage(const char* programName)
1743 {
1744     DISPLAY( "Usage :\n");
1745     DISPLAY( "      %s [args]\n", programName);
1746     DISPLAY( "\n");
1747     DISPLAY( "Arguments :\n");
1748     DISPLAY( " -p<path> : select output path (default:stdout)\n");
1749     DISPLAY( "                in multiple files mode this should be a directory\n");
1750     DISPLAY( " -o<path> : select path to output original file (default:no output)\n");
1751     DISPLAY( "                in multiple files mode this should be a directory\n");
1752     DISPLAY( " -s#      : select seed (default:random based on time)\n");
1753     DISPLAY( " -n#      : number of files to generate (default:1)\n");
1754     DISPLAY( " -t       : activate test mode (test files against libzstd instead of outputting them)\n");
1755     DISPLAY( " -T#      : length of time to run tests for\n");
1756     DISPLAY( " -v       : increase verbosity level (default:0, max:7)\n");
1757     DISPLAY( " -h/H     : display help/long help and exit\n");
1758 }
1759 
advancedUsage(const char * programName)1760 static void advancedUsage(const char* programName)
1761 {
1762     usage(programName);
1763     DISPLAY( "\n");
1764     DISPLAY( "Advanced arguments        :\n");
1765     DISPLAY( " --content-size           : always include the content size in the frame header\n");
1766     DISPLAY( " --use-dict=#             : include a dictionary used to decompress the corpus\n");
1767     DISPLAY( " --gen-blocks             : generate raw compressed blocks without block/frame headers\n");
1768     DISPLAY( " --max-block-size-log=#   : max block size log, must be in range [2, 17]\n");
1769     DISPLAY( " --max-content-size-log=# : max content size log, must be <= 20\n");
1770     DISPLAY( "                            (this is ignored with gen-blocks)\n");
1771 }
1772 
/*! readU32FromChar() :
 *  Parses a decimal number at `*stringPtr`, optionally followed by a size
 *  suffix: K / KB / KiB multiply by 2^10, M / MB / MiB multiply by 2^20.
 *  `*stringPtr` is advanced past the digits and any recognized suffix.
 *  Note : the result silently wraps if the value exceeds UINT_MAX.
 *  @return the parsed (and scaled) unsigned value */
static unsigned readU32FromChar(const char** stringPtr)
{
    const char* p = *stringPtr;
    unsigned value = 0;

    while (*p >= '0' && *p <= '9') {
        value = value * 10 + (unsigned)(*p - '0');
        p++;
    }

    if (*p == 'K' || *p == 'M') {
        unsigned const shift = (*p == 'M') ? 20 : 10;
        value <<= shift;
        p++;
        if (*p == 'i') p++;
        if (*p == 'B') p++;
    }

    *stringPtr = p;
    return value;
}
1792 
/** longCommandWArg() :
 *  Tests whether *stringPtr begins with the text `longCommand`.
 *  On a match, *stringPtr is moved just past that prefix and 1 is returned.
 *  Otherwise *stringPtr is left untouched and 0 is returned.
 */
static unsigned longCommandWArg(const char** stringPtr, const char* longCommand)
{
    size_t const prefixLen = strlen(longCommand);

    if (strncmp(*stringPtr, longCommand, prefixLen) != 0)
        return 0;

    *stringPtr += prefixLen;
    return 1;
}
1805 
main(int argc,char ** argv)1806 int main(int argc, char** argv)
1807 {
1808     U32 seed = 0;
1809     int seedset = 0;
1810     unsigned numFiles = 0;
1811     unsigned testDuration = 0;
1812     int testMode = 0;
1813     const char* path = NULL;
1814     const char* origPath = NULL;
1815     int useDict = 0;
1816     unsigned dictSize = (10 << 10); /* 10 kB default */
1817     genType_e genType = gt_frame;
1818 
1819     int argNb;
1820 
1821     /* Check command line */
1822     for (argNb=1; argNb<argc; argNb++) {
1823         const char* argument = argv[argNb];
1824         if(!argument) continue;   /* Protection if argument empty */
1825 
1826         /* Handle commands. Aggregated commands are allowed */
1827         if (argument[0]=='-') {
1828             argument++;
1829             while (*argument!=0) {
1830                 switch(*argument)
1831                 {
1832                 case 'h':
1833                     usage(argv[0]);
1834                     return 0;
1835                 case 'H':
1836                     advancedUsage(argv[0]);
1837                     return 0;
1838                 case 'v':
1839                     argument++;
1840                     g_displayLevel++;
1841                     break;
1842                 case 's':
1843                     argument++;
1844                     seedset=1;
1845                     seed = readInt(&argument);
1846                     break;
1847                 case 'n':
1848                     argument++;
1849                     numFiles = readInt(&argument);
1850                     break;
1851                 case 'T':
1852                     argument++;
1853                     testDuration = readInt(&argument);
1854                     if (*argument == 'm') {
1855                         testDuration *= 60;
1856                         argument++;
1857                         if (*argument == 'n') argument++;
1858                     }
1859                     break;
1860                 case 'o':
1861                     argument++;
1862                     origPath = argument;
1863                     argument += strlen(argument);
1864                     break;
1865                 case 'p':
1866                     argument++;
1867                     path = argument;
1868                     argument += strlen(argument);
1869                     break;
1870                 case 't':
1871                     argument++;
1872                     testMode = 1;
1873                     break;
1874                 case '-':
1875                     argument++;
1876                     if (strcmp(argument, "content-size") == 0) {
1877                         opts.contentSize = 1;
1878                     } else if (longCommandWArg(&argument, "use-dict=")) {
1879                         dictSize = readU32FromChar(&argument);
1880                         useDict = 1;
1881                     } else if (strcmp(argument, "gen-blocks") == 0) {
1882                         genType = gt_block;
1883                     } else if (longCommandWArg(&argument, "max-block-size-log=")) {
1884                         U32 value = readU32FromChar(&argument);
1885                         if (value >= 2 && value <= ZSTD_BLOCKSIZE_MAX) {
1886                             g_maxBlockSize = 1U << value;
1887                         }
1888                     } else if (longCommandWArg(&argument, "max-content-size-log=")) {
1889                         U32 value = readU32FromChar(&argument);
1890                         g_maxDecompressedSizeLog =
1891                                 MIN(MAX_DECOMPRESSED_SIZE_LOG, value);
1892                     } else {
1893                         advancedUsage(argv[0]);
1894                         return 1;
1895                     }
1896                     argument += strlen(argument);
1897                     break;
1898                 default:
1899                     usage(argv[0]);
1900                     return 1;
1901     }   }   }   }   /* for (argNb=1; argNb<argc; argNb++) */
1902 
1903     if (!seedset) {
1904         seed = makeSeed();
1905     }
1906 
1907     if (testMode) {
1908         return runTestMode(seed, numFiles, testDuration, genType);
1909     } else {
1910         if (testDuration) {
1911             DISPLAY("Error: -T requires test mode (-t)\n\n");
1912             usage(argv[0]);
1913             return 1;
1914         }
1915     }
1916 
1917     if (!path) {
1918         DISPLAY("Error: path is required in file generation mode\n");
1919         usage(argv[0]);
1920         return 1;
1921     }
1922 
1923     if (numFiles == 0 && useDict == 0) {
1924         return generateFile(seed, path, origPath, genType);
1925     } else if (useDict == 0){
1926         return generateCorpus(seed, numFiles, path, origPath, genType);
1927     } else {
1928         /* should generate files with a dictionary */
1929         return generateCorpusWithDict(seed, numFiles, path, origPath, dictSize, genType);
1930     }
1931 
1932 }
1933