1 // This library is part of PLINK 2.00, copyright (C) 2005-2020 Shaun Purcell,
2 // Christopher Chang.
3 //
4 // This library is free software: you can redistribute it and/or modify it
5 // under the terms of the GNU Lesser General Public License as published by the
6 // Free Software Foundation, either version 3 of the License, or (at your
7 // option) any later version.
8 //
9 // This library is distributed in the hope that it will be useful, but WITHOUT
10 // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 // FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
12 // for more details.
13 //
14 // You should have received a copy of the GNU Lesser General Public License
15 // along with this library.  If not, see <http://www.gnu.org/licenses/>.
16 
17 #include <errno.h>
18 #include "plink2_zstfile.h"
19 
20 #ifdef __cplusplus
21 namespace plink2 {
22 #endif
23 
GetZrfp(zstRFILE * zrf_ptr)24 static inline zstRFILEMain* GetZrfp(zstRFILE* zrf_ptr) {
25   return &GET_PRIVATE(*zrf_ptr, m);
26 }
27 
// Puts *zrf_ptr into the "not open" state: no file handle, no decompression
// stream, no input buffer, no error message, and a stream status of
// kPglRetEof.
//
// ZstRfileOpen() treats a non-null file handle as "already open", and
// CleanupZstRfile() only releases non-null members, so both rely on the
// nulls written here.
void PreinitZstRfile(zstRFILE* zrf_ptr) {
  zstRFILEMain* const stream = GetZrfp(zrf_ptr);
  stream->reterr = kPglRetEof;
  stream->errmsg = nullptr;
  stream->zib.src = nullptr;
  stream->zds = nullptr;
  stream->ff = nullptr;
}
36 
// Message stored by ZstRfileOpen() when the stream already has a file handle.
const char kShortErrZstAlreadyOpen[] = "ZstRfileOpen can't be called on an already-open file";

// Messages for decompression failures that this file detects itself, without
// going through ZSTD_getErrorName().
//
// PrefixUnknown: the 4-byte magic number at the start of a frame is
// incomplete or is not a zstd frame magic.  (The text is the zstd
// prefix_unknown error string.)
const char kShortErrZstdPrefixUnknown[] = "Unknown frame descriptor";
// CorruptionDetected: the file ended while the decoder still expected more
// input for the current frame.
const char kShortErrZstdCorruptionDetected[] = "Corrupted block detected";
42 
ZstRfileOpen(const char * fname,zstRFILE * zrf_ptr)43 PglErr ZstRfileOpen(const char* fname, zstRFILE* zrf_ptr) {
44   zstRFILEMain* zrfp = GetZrfp(zrf_ptr);
45   PglErr reterr = kPglRetSuccess;
46   {
47     if (unlikely(zrfp->ff)) {
48       reterr = kPglRetImproperFunctionCall;
49       zrfp->errmsg = kShortErrZstAlreadyOpen;
50       goto ZstRfileOpen_ret_1;
51     }
52     zrfp->ff = fopen(fname, FOPEN_RB);
53     if (unlikely(!zrfp->ff)) {
54       goto ZstRfileOpen_ret_OPEN_FAIL;
55     }
56     zrfp->zds = ZSTD_createDStream();
57     if (unlikely(!zrfp->zds)) {
58       goto ZstRfileOpen_ret_NOMEM;
59     }
60     zrfp->zib.src = malloc(ZSTD_DStreamInSize());
61     if (!zrfp->zib.src) {
62       goto ZstRfileOpen_ret_NOMEM;
63     }
64     uint32_t nbytes = fread_unlocked(K_CAST(void*, zrfp->zib.src), 1, 4, zrfp->ff);
65     if (nbytes < 4) {
66       if (!feof_unlocked(zrfp->ff)) {
67         goto ZstRfileOpen_ret_READ_FAIL;
68       }
69       if (!nbytes) {
70         // May as well accept this.
71         // Don't jump to ret_1 since we're setting zrfp->reterr to a
72         // different value than what we're returning.
73         zrfp->reterr = kPglRetEof;
74         return kPglRetSuccess;
75       }
76       reterr = kPglRetDecompressFail;
77       zrfp->errmsg = kShortErrZstdPrefixUnknown;
78       goto ZstRfileOpen_ret_1;
79     }
80     zrfp->zib.size = 4;
81     zrfp->zib.pos = 0;
82   }
83   while (0) {
84   ZstRfileOpen_ret_NOMEM:
85     reterr = kPglRetNomem;
86     break;
87   ZstRfileOpen_ret_OPEN_FAIL:
88     reterr = kPglRetOpenFail;
89     zrfp->errmsg = strerror(errno);
90     break;
91   ZstRfileOpen_ret_READ_FAIL:
92     reterr = kPglRetReadFail;
93     zrfp->errmsg = strerror(errno);
94     break;
95   }
96  ZstRfileOpen_ret_1:
97   zrfp->reterr = reterr;
98   return reterr;
99 }
100 
// Decompresses up to len bytes from the stream into dst.
//
// Parameters:
//   zrf_ptr - stream previously set up by ZstRfileOpen().
//   dst     - output buffer with room for at least len bytes.
//   len     - number of bytes requested; asserted to be below 2^31.
//
// Returns:
//   >= 0 - number of bytes written to dst.  This is less than len only when
//          the end of the file was reached (stream status becomes
//          kPglRetEof).  It is 0 without any work when the stream status was
//          already set on entry.
//   -1   - file read error (status kPglRetReadFail, errmsg = strerror(errno)).
//   -2   - decompression error (status kPglRetDecompressFail), including a
//          bad or partial frame magic and a file that ends mid-frame.
//
// Multi-frame files are supported: when a frame ends, the next 4 bytes are
// read and checked with IsZstdFrame() before decoding continues.
int32_t zstread(zstRFILE* zrf_ptr, void* dst, uint32_t len) {
  zstRFILEMain* const stream = GetZrfp(zrf_ptr);
  if (stream->reterr) {
    return 0;
  }
  assert(len < 0x80000000U);
  unsigned char* const out_bytes = S_CAST(unsigned char*, dst);
  uint32_t n_written = 0;
  while (n_written != len) {
    ZSTD_outBuffer out_buf = {&(out_bytes[n_written]), len - n_written, 0};
    const uintptr_t hint = ZSTD_decompressStream(stream->zds, &out_buf, &stream->zib);
    if (ZSTD_isError(hint)) {
      stream->reterr = kPglRetDecompressFail;
      stream->errmsg = ZSTD_getErrorName(hint);
      // zlib Z_STREAM_ERROR.  may want to return Z_DATA_ERROR when appropriate
      return -2;
    }
    n_written += out_buf.pos;

    if (hint) {
      // Still inside a frame.
      if (stream->zib.pos != stream->zib.size) {
        // Unconsumed input remains, so the decoder must have stopped because
        // the output buffer is full.
        assert(n_written == len);
        break;
      }
      // Input exhausted: refill with the amount the decoder asked for,
      // capped at the input buffer's capacity.
      const uint32_t chunk_size = MINV(hint, ZSTD_DStreamInSize());
      unsigned char* in_bytes = S_CAST(unsigned char*, K_CAST(void*, stream->zib.src));
      if (unlikely(!fread_unlocked(in_bytes, chunk_size, 1, stream->ff))) {
        if (!feof_unlocked(stream->ff)) {
          stream->reterr = kPglRetReadFail;
          stream->errmsg = strerror(errno);
          return -1;
        }
        // File ended before the frame did.
        stream->reterr = kPglRetDecompressFail;
        stream->errmsg = kShortErrZstdCorruptionDetected;
        return -2;
      }
      stream->zib.size = chunk_size;
      stream->zib.pos = 0;
      continue;
    }

    // hint == 0: the current frame is fully decoded.  Load the magic number
    // of the next frame, or stop at end-of-file.
    assert(stream->zib.size == stream->zib.pos);
    const uint32_t magic_len = fread_unlocked(K_CAST(void*, stream->zib.src), 1, 4, stream->ff);
    stream->zib.size = magic_len;
    stream->zib.pos = 0;
    if (magic_len >= 4) {
      if (unlikely(!IsZstdFrame(*R_CAST(const uint32_t*, stream->zib.src)))) {
        stream->reterr = kPglRetDecompressFail;
        stream->errmsg = kShortErrZstdPrefixUnknown;
        return -2;
      }
      // impossible for this to fail
      ZSTD_DCtx_reset(stream->zds, ZSTD_reset_session_only);
      continue;
    }
    if (unlikely(!feof_unlocked(stream->ff))) {
      // zlib Z_ERRNO
      stream->reterr = kPglRetReadFail;
      stream->errmsg = strerror(errno);
      return -1;
    }
    if (unlikely(magic_len)) {
      // 1-3 trailing bytes cannot be a frame header.
      stream->reterr = kPglRetDecompressFail;
      stream->errmsg = kShortErrZstdPrefixUnknown;
      return -2;
    }
    // Clean end of file.
    stream->reterr = kPglRetEof;
    break;
  }
  return n_written;
}
170 
// Repositions the stream at the start of the file so that the next zstread()
// call decodes from the first byte again.
//
// The rewind happens only when the stream has a file handle and its status
// is kPglRetSuccess or kPglRetEof.  In every other state (an earlier
// read/decompress/open error, or a stream that is not open) this is a no-op
// and the stored status and message are left as they are.
//
// On a rewind, the file position and zstd session are reset, the input
// buffer is marked empty, and the status becomes kPglRetSuccess.
void zstrewind(zstRFILE* zrf_ptr) {
  zstRFILEMain* zrfp = GetZrfp(zrf_ptr);
  // PreinitZstRfile() and CleanupZstRfile() both leave the status at
  // kPglRetEof with null ff/zds, so the status check below is not enough on
  // its own: without this guard an unopened or already-closed stream would
  // reach rewind(NULL) and ZSTD_DCtx_reset(NULL, ...).
  if (unlikely(!zrfp->ff)) {
    return;
  }
  if ((zrfp->reterr == kPglRetSuccess) || (zrfp->reterr == kPglRetEof)) {
    rewind(zrfp->ff);
    ZSTD_DCtx_reset(zrfp->zds, ZSTD_reset_session_only);
    // Empty input buffer: the decoder requests the frame header on the next
    // zstread() call, which then reads it from the file.
    // NOTE(review): for a zero-byte file this differs from the state
    // ZstRfileOpen() leaves (status kPglRetEof); the next zstread() will
    // presumably report a decompression error instead of 0 bytes - confirm
    // whether callers rewind empty files.
    zrfp->zib.size = 0;
    zrfp->zib.pos = 0;
    zrfp->reterr = kPglRetSuccess;
  }
}
181 
CleanupZstRfile(zstRFILE * zrf_ptr,PglErr * reterrp)182 BoolErr CleanupZstRfile(zstRFILE* zrf_ptr, PglErr* reterrp) {
183   zstRFILEMain* zrfp = GetZrfp(zrf_ptr);
184   zrfp->reterr = kPglRetEof;
185   zrfp->errmsg = nullptr;
186   if (zrfp->zib.src) {
187     free_const(zrfp->zib.src);
188     zrfp->zib.src = nullptr;
189   }
190   if (zrfp->zds) {
191     ZSTD_freeDStream(zrfp->zds); // this should never fail in practice
192     zrfp->zds = nullptr;
193   }
194   if (zrfp->ff) {
195     if (unlikely(fclose_null(&zrfp->ff))) {
196       if (!reterrp) {
197         return 1;
198       }
199       if (*reterrp == kPglRetSuccess) {
200         *reterrp = kPglRetReadFail;
201         return 1;
202       }
203     }
204   }
205   return 0;
206 }
207 
208 #ifdef __cplusplus
209 }
210 #endif
211