1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  */
9 
10 /*
 * A simple demo that sums up all the decompressed bytes of a seekable zstd
 * file in parallel, using seekable decompression and the zstd thread pool.
13  */
14 
#include <stdlib.h>    // malloc, exit
#include <stdio.h>     // fprintf, perror, feof
#include <string.h>    // strerror
#include <errno.h>     // errno
#include <limits.h>    // INT_MAX
19 #define ZSTD_STATIC_LINKING_ONLY
20 #include <zstd.h>      // presumes zstd library is installed
21 #include <zstd_errors.h>
/* SLEEP(ms): block the calling thread for roughly `ms` milliseconds.
 * The argument is parenthesized so that expressions such as SLEEP(a + b)
 * are scaled as a whole rather than only their last term. */
#if defined(WIN32) || defined(_WIN32)
#  include <windows.h>
#  define SLEEP(x) Sleep(x)
#else
#  include <unistd.h>
#  define SLEEP(x) usleep((x) * 1000)
#endif
29 
30 #include "pool.h"      // use zstd thread pool for demo
31 
32 #include "zstd_seekable.h"
33 
/* Smaller of two values; both arguments may be evaluated more than once. */
#define MIN(lhs, rhs) ((lhs) < (rhs) ? (lhs) : (rhs))
35 
/* Allocate `size` bytes with malloc(); on failure print a perror("malloc")
 * diagnostic and terminate the process with exit status 1. */
static void* malloc_orDie(size_t size)
{
    void* const block = malloc(size);
    if (block == NULL) {
        perror("malloc");
        exit(1);
    }
    return block;
}
44 
/* Resize `ptr` to `size` bytes with realloc(); on failure print a
 * perror("realloc") diagnostic and terminate with exit status 1.
 * Because the process exits on failure, the old pointer need not be kept. */
static void* realloc_orDie(void* ptr, size_t size)
{
    void* const resized = realloc(ptr, size);
    if (resized == NULL) {
        perror("realloc");
        exit(1);
    }
    return resized;
}
53 
/* Open `filename` with fopen() mode string `instruction`; on failure print
 * perror(filename) and terminate the process with exit status 3. */
static FILE* fopen_orDie(const char *filename, const char *instruction)
{
    FILE* const stream = fopen(filename, instruction);
    if (stream == NULL) {
        perror(filename);
        exit(3);
    }
    return stream;
}
62 
/* Read up to `sizeToRead` bytes from `file` into `buffer` and return the
 * number of bytes read. A short read is accepted only when end-of-file was
 * reached; any other short read prints perror("fread") and exits with 4. */
static size_t fread_orDie(void* buffer, size_t sizeToRead, FILE* file)
{
    size_t const got = fread(buffer, 1, sizeToRead, file);
    if (got != sizeToRead && !feof(file)) {
        perror("fread");
        exit(4);
    }
    return got;
}
72 
/* Write exactly `sizeToWrite` bytes from `buffer` to `file` and return that
 * count; any short write prints perror("fwrite") and exits with status 5. */
static size_t fwrite_orDie(const void* buffer, size_t sizeToWrite, FILE* file)
{
    if (fwrite(buffer, 1, sizeToWrite, file) != sizeToWrite) {
        perror("fwrite");
        exit(5);
    }
    return sizeToWrite;
}
81 
/* Close `file` and return 0; if fclose() reports an error, print
 * perror("fclose") and terminate the process with exit status 6. */
static size_t fclose_orDie(FILE* file)
{
    if (fclose(file) != 0) {
        perror("fclose");
        exit(6);
    }
    return 0;
}
89 
/* Reposition `file` with fseek(offset, origin), then fflush() it.
 * The flush is attempted only after a successful seek; if either call
 * fails, print perror("fseek") and terminate with exit status 7. */
static void fseek_orDie(FILE* file, long int offset, int origin)
{
    if (fseek(file, offset, origin) != 0 || fflush(file) != 0) {
        perror("fseek");
        exit(7);
    }
}
98 
/* Work item for one frame, passed to sumFrame() through the thread pool. */
struct sum_job {
    const char* fname;        /* path of the seekable file; sumFrame() reopens it per job */
    unsigned long long sum;   /* output: sum of the frame's decompressed bytes */
    unsigned frameNb;         /* index of the frame this job decompresses */
    int done;                 /* set to 1 by sumFrame() after `sum` has been stored */
    /* NOTE(review): `done` is a plain int, not an atomic. Reading it from another
     * thread while a worker writes it is a data race, and it gives no ordering
     * guarantee for `sum`. Synchronize (e.g. join the workers) before relying on it. */
};
105 
sumFrame(void * opaque)106 static void sumFrame(void* opaque)
107 {
108     struct sum_job* job = (struct sum_job*)opaque;
109     job->done = 0;
110 
111     FILE* const fin = fopen_orDie(job->fname, "rb");
112 
113     ZSTD_seekable* const seekable = ZSTD_seekable_create();
114     if (seekable==NULL) { fprintf(stderr, "ZSTD_seekable_create() error \n"); exit(10); }
115 
116     size_t const initResult = ZSTD_seekable_initFile(seekable, fin);
117     if (ZSTD_isError(initResult)) { fprintf(stderr, "ZSTD_seekable_init() error : %s \n", ZSTD_getErrorName(initResult)); exit(11); }
118 
119     size_t const frameSize = ZSTD_seekable_getFrameDecompressedSize(seekable, job->frameNb);
120     unsigned char* data = malloc_orDie(frameSize);
121 
122     size_t result = ZSTD_seekable_decompressFrame(seekable, data, frameSize, job->frameNb);
123     if (ZSTD_isError(result)) { fprintf(stderr, "ZSTD_seekable_decompressFrame() error : %s \n", ZSTD_getErrorName(result)); exit(12); }
124 
125     unsigned long long sum = 0;
126     size_t i;
127     for (i = 0; i < frameSize; i++) {
128         sum += data[i];
129     }
130     job->sum = sum;
131     job->done = 1;
132 
133     fclose(fin);
134     ZSTD_seekable_free(seekable);
135     free(data);
136 }
137 
sumFile_orDie(const char * fname,int nbThreads)138 static void sumFile_orDie(const char* fname, int nbThreads)
139 {
140     POOL_ctx* pool = POOL_create(nbThreads, nbThreads);
141     if (pool == NULL) { fprintf(stderr, "POOL_create() error \n"); exit(9); }
142 
143     FILE* const fin = fopen_orDie(fname, "rb");
144 
145     ZSTD_seekable* const seekable = ZSTD_seekable_create();
146     if (seekable==NULL) { fprintf(stderr, "ZSTD_seekable_create() error \n"); exit(10); }
147 
148     size_t const initResult = ZSTD_seekable_initFile(seekable, fin);
149     if (ZSTD_isError(initResult)) { fprintf(stderr, "ZSTD_seekable_init() error : %s \n", ZSTD_getErrorName(initResult)); exit(11); }
150 
151     unsigned const numFrames = ZSTD_seekable_getNumFrames(seekable);
152     struct sum_job* jobs = (struct sum_job*)malloc(numFrames * sizeof(struct sum_job));
153 
154     unsigned fnb;
155     for (fnb = 0; fnb < numFrames; fnb++) {
156         jobs[fnb] = (struct sum_job){ fname, 0, fnb, 0 };
157         POOL_add(pool, sumFrame, &jobs[fnb]);
158     }
159 
160     unsigned long long total = 0;
161 
162     for (fnb = 0; fnb < numFrames; fnb++) {
163         while (!jobs[fnb].done) SLEEP(5); /* wake up every 5 milliseconds to check */
164         total += jobs[fnb].sum;
165     }
166 
167     printf("Sum: %llu\n", total);
168 
169     POOL_free(pool);
170     ZSTD_seekable_free(seekable);
171     fclose(fin);
172     free(jobs);
173 }
174 
175 
/*
 * Usage: <exe> FILE NB_THREADS
 * Sums the decompressed bytes of seekable zstd file FILE using NB_THREADS
 * worker threads. Returns 1 on a usage error or an invalid NB_THREADS value.
 * Other failures are reported by sumFile_orDie(), which exits the process.
 */
int main(int argc, const char** argv)
{
    const char* const exeName = argv[0];

    if (argc!=3) {
        fprintf(stderr, "wrong arguments\n");
        fprintf(stderr, "usage:\n");
        fprintf(stderr, "%s FILE NB_THREADS\n", exeName);
        return 1;
    }

    {
        const char* const inFilename = argv[1];
        const char* const threadsArg = argv[2];
        char* end = NULL;
        long nbThreadsParsed;

        /* strtol instead of atoi: atoi silently turns "abc" into 0 and "4x" into 4,
         * and has undefined behaviour when the value is out of range. */
        errno = 0;
        nbThreadsParsed = strtol(threadsArg, &end, 10);
        if (end == threadsArg || *end != '\0' || errno == ERANGE
            || nbThreadsParsed < 1 || nbThreadsParsed > INT_MAX) {
            fprintf(stderr, "invalid NB_THREADS '%s': expected a positive integer\n", threadsArg);
            return 1;
        }
        sumFile_orDie(inFilename, (int)nbThreadsParsed);
    }

    return 0;
}
195