1 /* Copyright (c) Mark Harmstone 2016-17
2  * Copyright (c) Reimar Doeffinger 2006
3  * Copyright (c) Markus Oberhumer 1996
4  *
5  * This file is part of WinBtrfs.
6  *
7  * WinBtrfs is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU Lesser General Public Licence as published by
9  * the Free Software Foundation, either version 3 of the Licence, or
10  * (at your option) any later version.
11  *
12  * WinBtrfs is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU Lesser General Public Licence for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public Licence
18  * along with WinBtrfs.  If not, see <http://www.gnu.org/licenses/>. */
19 
20 // Portions of the LZO decompression code here were cribbed from code in
21 // libavcodec, also under the LGPL. Thank you, Reimar Doeffinger.
22 
23 // The LZO compression code comes from v0.22 of lzo, written way back in
24 // 1996, and available here:
25 // https://www.ibiblio.org/pub/historic-linux/ftp-archives/sunsite.unc.edu/Sep-29-1996/libs/lzo-0.22.tar.gz
26 // Modern versions of lzo are licensed under the GPL, but the very oldest
27 // versions are under the LGPL and hence okay to use here.
28 
29 #include "btrfs_drv.h"
30 
31 #define Z_SOLO
32 #define ZLIB_INTERNAL
33 
34 #ifndef __REACTOS__
35 #include "zlib/zlib.h"
36 #include "zlib/inftrees.h"
37 #include "zlib/inflate.h"
38 #else
39 #include <zlib.h>
40 #endif // __REACTOS__
41 
42 #define ZSTD_STATIC_LINKING_ONLY
43 
44 #include "zstd/zstd.h"
45 
46 #define LZO_PAGE_SIZE 4096
47 
// Cursor state shared by the LZO helpers in this file. The decompressor reads
// from `in` and writes to `out`; the compressor additionally uses `wrkmem`.
typedef struct {
    uint8_t *in;      // source buffer
    uint32_t inlen;   // number of bytes available at `in`
    uint32_t inpos;   // read offset into `in`
    uint8_t *out;     // destination buffer
    uint32_t outlen;  // decompression: capacity of `out`; compression: bytes produced
    uint32_t outpos;  // write offset into `out`
    bool error;       // set by the decode helpers when a bounds check fails
    void *wrkmem;     // compressor dictionary scratch (LZO1X_MEM_COMPRESS bytes)
} lzo_stream;
58 
// Size in bytes of the compressor dictionary: 16384 pointer slots, matching
// the 0x3fff mask applied by DINDEX.
#define LZO1X_MEM_COMPRESS ((uint32_t) (16384L * sizeof(uint8_t*)))

// Largest match offset representable by each LZO1X match class.
#define M1_MAX_OFFSET 0x0400
#define M2_MAX_OFFSET 0x0800
#define M3_MAX_OFFSET 0x4000
#define M4_MAX_OFFSET 0xbfff

#define MX_MAX_OFFSET (M1_MAX_OFFSET + M2_MAX_OFFSET)

// Opcode marker bits for each match class.
#define M1_MARKER 0
#define M2_MARKER 64
#define M3_MARKER 32
#define M4_MARKER 16

// Rolling hash over the three bytes at p. DVAL_FIRST computes it from scratch;
// DVAL_NEXT updates it after p has been advanced by one byte.
// All macro arguments are parenthesized, and DVAL_NEXT is wrapped in
// do { } while (0) so that it behaves as a single statement.
#define _DV2(p, shift1, shift2) (((((uint32_t)((p)[2]) << (shift1)) ^ (p)[1]) << (shift2)) ^ (p)[0])
#define DVAL_NEXT(dv, p) do { (dv) ^= (p)[-1]; (dv) = (((dv) >> 5) ^ ((uint32_t)((p)[2]) << (2*5))); } while (0)
#define _DV(p, shift) _DV2(p, shift, shift)
#define DVAL_FIRST(dv, p) ((dv) = _DV((p), 5))

// Maps a hash value to a dictionary slot in the range 0..0x3fff.
#define _DINDEX(dv, p) ((40799u * (dv)) >> 5)
#define DINDEX(dv, p) (((_DINDEX(dv, p)) & 0x3fff) << 0)

// Store position p in the dictionary, by hash value or by precomputed index.
// The `cycle` argument is unused.
#define UPDATE_D(dict, cycle, dv, p) ((dict)[DINDEX(dv, p)] = (p))
#define UPDATE_I(dict, cycle, index, p) ((dict)[index] = (p))

// Evaluates to non-zero when m_pos is NOT usable as a match for ip: it lies
// before `in`, is not strictly behind ip, or is further back than max_offset.
// Side effect: assigns m_off = ip - m_pos (unless the first test short-circuits).
#define LZO_CHECK_MPOS_NON_DET(m_pos, m_off, in, ip, max_offset) \
    ((void*) (m_pos) < (void*) (in) || \
    ((m_off) = (uint8_t*) (ip) - (uint8_t*) (m_pos)) <= 0 || \
    (m_off) > (max_offset))

#define LZO_BYTE(x) ((unsigned char) (x))
88 
89 #define ZSTD_ALLOC_TAG 0x6474737a // "zstd"
90 
91 // needs to be the same as Linux (fs/btrfs/zstd.c)
92 #define ZSTD_BTRFS_MAX_WINDOWLOG 17
93 
94 static void* zstd_malloc(void* opaque, size_t size);
95 static void zstd_free(void* opaque, void* address);
96 
97 #ifndef __REACTOS__
98 ZSTD_customMem zstd_mem = { .customAlloc = zstd_malloc, .customFree = zstd_free, .opaque = NULL };
99 #else
100 ZSTD_customMem zstd_mem = { zstd_malloc, zstd_free, NULL };
101 #endif
102 
103 static uint8_t lzo_nextbyte(lzo_stream* stream) {
104     uint8_t c;
105 
106     if (stream->inpos >= stream->inlen) {
107         stream->error = true;
108         return 0;
109     }
110 
111     c = stream->in[stream->inpos];
112     stream->inpos++;
113 
114     return c;
115 }
116 
117 static int lzo_len(lzo_stream* stream, int byte, int mask) {
118     int len = byte & mask;
119 
120     if (len == 0) {
121         while (!(byte = lzo_nextbyte(stream))) {
122             if (stream->error) return 0;
123 
124             len += 255;
125         }
126 
127         len += mask + byte;
128     }
129 
130     return len;
131 }
132 
133 static void lzo_copy(lzo_stream* stream, int len) {
134     if (stream->inpos + len > stream->inlen) {
135         stream->error = true;
136         return;
137     }
138 
139     if (stream->outpos + len > stream->outlen) {
140         stream->error = true;
141         return;
142     }
143 
144     do {
145         stream->out[stream->outpos] = stream->in[stream->inpos];
146         stream->inpos++;
147         stream->outpos++;
148         len--;
149     } while (len > 0);
150 }
151 
152 static void lzo_copyback(lzo_stream* stream, uint32_t back, int len) {
153     if (stream->outpos < back) {
154         stream->error = true;
155         return;
156     }
157 
158     if (stream->outpos + len > stream->outlen) {
159         stream->error = true;
160         return;
161     }
162 
163     do {
164         stream->out[stream->outpos] = stream->out[stream->outpos - back];
165         stream->outpos++;
166         len--;
167     } while (len > 0);
168 }
169 
170 static NTSTATUS do_lzo_decompress(lzo_stream* stream) {
171     uint8_t byte;
172     uint32_t len, back;
173     bool backcopy = false;
174 
175     stream->error = false;
176 
177     byte = lzo_nextbyte(stream);
178     if (stream->error) return STATUS_INTERNAL_ERROR;
179 
180     if (byte > 17) {
181         lzo_copy(stream, min((uint8_t)(byte - 17), (uint32_t)(stream->outlen - stream->outpos)));
182         if (stream->error) return STATUS_INTERNAL_ERROR;
183 
184         if (stream->outlen == stream->outpos)
185             return STATUS_SUCCESS;
186 
187         byte = lzo_nextbyte(stream);
188         if (stream->error) return STATUS_INTERNAL_ERROR;
189 
190         if (byte < 16) return STATUS_INTERNAL_ERROR;
191     }
192 
193     while (1) {
194         if (byte >> 4) {
195             backcopy = true;
196             if (byte >> 6) {
197                 len = (byte >> 5) - 1;
198                 back = (lzo_nextbyte(stream) << 3) + ((byte >> 2) & 7) + 1;
199                 if (stream->error) return STATUS_INTERNAL_ERROR;
200             } else if (byte >> 5) {
201                 len = lzo_len(stream, byte, 31);
202                 if (stream->error) return STATUS_INTERNAL_ERROR;
203 
204                 byte = lzo_nextbyte(stream);
205                 if (stream->error) return STATUS_INTERNAL_ERROR;
206 
207                 back = (lzo_nextbyte(stream) << 6) + (byte >> 2) + 1;
208                 if (stream->error) return STATUS_INTERNAL_ERROR;
209             } else {
210                 len = lzo_len(stream, byte, 7);
211                 if (stream->error) return STATUS_INTERNAL_ERROR;
212 
213                 back = (1 << 14) + ((byte & 8) << 11);
214 
215                 byte = lzo_nextbyte(stream);
216                 if (stream->error) return STATUS_INTERNAL_ERROR;
217 
218                 back += (lzo_nextbyte(stream) << 6) + (byte >> 2);
219                 if (stream->error) return STATUS_INTERNAL_ERROR;
220 
221                 if (back == (1 << 14)) {
222                     if (len != 1)
223                         return STATUS_INTERNAL_ERROR;
224                     break;
225                 }
226             }
227         } else if (backcopy) {
228             len = 0;
229             back = (lzo_nextbyte(stream) << 2) + (byte >> 2) + 1;
230             if (stream->error) return STATUS_INTERNAL_ERROR;
231         } else {
232             len = lzo_len(stream, byte, 15);
233             if (stream->error) return STATUS_INTERNAL_ERROR;
234 
235             lzo_copy(stream, min(len + 3, stream->outlen - stream->outpos));
236             if (stream->error) return STATUS_INTERNAL_ERROR;
237 
238             if (stream->outlen == stream->outpos)
239                 return STATUS_SUCCESS;
240 
241             byte = lzo_nextbyte(stream);
242             if (stream->error) return STATUS_INTERNAL_ERROR;
243 
244             if (byte >> 4)
245                 continue;
246 
247             len = 1;
248             back = (1 << 11) + (lzo_nextbyte(stream) << 2) + (byte >> 2) + 1;
249             if (stream->error) return STATUS_INTERNAL_ERROR;
250 
251             break;
252         }
253 
254         lzo_copyback(stream, back, min(len + 2, stream->outlen - stream->outpos));
255         if (stream->error) return STATUS_INTERNAL_ERROR;
256 
257         if (stream->outlen == stream->outpos)
258             return STATUS_SUCCESS;
259 
260         len = byte & 3;
261 
262         if (len) {
263             lzo_copy(stream, min(len, stream->outlen - stream->outpos));
264             if (stream->error) return STATUS_INTERNAL_ERROR;
265 
266             if (stream->outlen == stream->outpos)
267                 return STATUS_SUCCESS;
268         } else
269             backcopy = !backcopy;
270 
271         byte = lzo_nextbyte(stream);
272         if (stream->error) return STATUS_INTERNAL_ERROR;
273     }
274 
275     return STATUS_SUCCESS;
276 }
277 
278 NTSTATUS lzo_decompress(uint8_t* inbuf, uint32_t inlen, uint8_t* outbuf, uint32_t outlen, uint32_t inpageoff) {
279     NTSTATUS Status;
280     uint32_t partlen, inoff, outoff;
281     lzo_stream stream;
282 
283     inoff = 0;
284     outoff = 0;
285 
286     do {
287         partlen = *(uint32_t*)&inbuf[inoff];
288 
289         if (partlen + inoff > inlen) {
290             ERR("overflow: %x + %x > %x\n", partlen, inoff, inlen);
291             return STATUS_INTERNAL_ERROR;
292         }
293 
294         inoff += sizeof(uint32_t);
295 
296         stream.in = &inbuf[inoff];
297         stream.inlen = partlen;
298         stream.inpos = 0;
299         stream.out = &outbuf[outoff];
300         stream.outlen = min(outlen, LZO_PAGE_SIZE);
301         stream.outpos = 0;
302 
303         Status = do_lzo_decompress(&stream);
304         if (!NT_SUCCESS(Status)) {
305             ERR("do_lzo_decompress returned %08lx\n", Status);
306             return Status;
307         }
308 
309         if (stream.outpos < stream.outlen)
310             RtlZeroMemory(&stream.out[stream.outpos], stream.outlen - stream.outpos);
311 
312         inoff += partlen;
313         outoff += stream.outlen;
314 
315         if (LZO_PAGE_SIZE - ((inpageoff + inoff) % LZO_PAGE_SIZE) < sizeof(uint32_t))
316             inoff = ((((inpageoff + inoff) / LZO_PAGE_SIZE) + 1) * LZO_PAGE_SIZE) - inpageoff;
317 
318         outlen -= stream.outlen;
319     } while (inoff < inlen && outlen > 0);
320 
321     return STATUS_SUCCESS;
322 }
323 
324 static void* zlib_alloc(void* opaque, unsigned int items, unsigned int size) {
325     UNUSED(opaque);
326 
327     return ExAllocatePoolWithTag(PagedPool, items * size, ALLOC_TAG_ZLIB);
328 }
329 
330 static void zlib_free(void* opaque, void* ptr) {
331     UNUSED(opaque);
332 
333     ExFreePool(ptr);
334 }
335 
336 NTSTATUS zlib_compress(uint8_t* inbuf, uint32_t inlen, uint8_t* outbuf, uint32_t outlen, unsigned int level, unsigned int* space_left) {
337     z_stream c_stream;
338     int ret;
339 
340     c_stream.zalloc = zlib_alloc;
341     c_stream.zfree = zlib_free;
342     c_stream.opaque = (voidpf)0;
343 
344     ret = deflateInit(&c_stream, level);
345 
346     if (ret != Z_OK) {
347         ERR("deflateInit returned %i\n", ret);
348         return STATUS_INTERNAL_ERROR;
349     }
350 
351     c_stream.next_in = inbuf;
352     c_stream.avail_in = inlen;
353 
354     c_stream.next_out = outbuf;
355     c_stream.avail_out = outlen;
356 
357     do {
358         ret = deflate(&c_stream, Z_FINISH);
359 
360         if (ret != Z_OK && ret != Z_STREAM_END) {
361             ERR("deflate returned %i\n", ret);
362             deflateEnd(&c_stream);
363             return STATUS_INTERNAL_ERROR;
364         }
365 
366         if (c_stream.avail_in == 0 || c_stream.avail_out == 0)
367             break;
368     } while (ret != Z_STREAM_END);
369 
370     deflateEnd(&c_stream);
371 
372     *space_left = c_stream.avail_in > 0 ? 0 : c_stream.avail_out;
373 
374     return STATUS_SUCCESS;
375 }
376 
377 NTSTATUS zlib_decompress(uint8_t* inbuf, uint32_t inlen, uint8_t* outbuf, uint32_t outlen) {
378     z_stream c_stream;
379     int ret;
380 
381     c_stream.zalloc = zlib_alloc;
382     c_stream.zfree = zlib_free;
383     c_stream.opaque = (voidpf)0;
384 
385     ret = inflateInit(&c_stream);
386 
387     if (ret != Z_OK) {
388         ERR("inflateInit returned %i\n", ret);
389         return STATUS_INTERNAL_ERROR;
390     }
391 
392     c_stream.next_in = inbuf;
393     c_stream.avail_in = inlen;
394 
395     c_stream.next_out = outbuf;
396     c_stream.avail_out = outlen;
397 
398     do {
399         ret = inflate(&c_stream, Z_NO_FLUSH);
400 
401         if (ret != Z_OK && ret != Z_STREAM_END) {
402             ERR("inflate returned %i\n", ret);
403             inflateEnd(&c_stream);
404             return STATUS_INTERNAL_ERROR;
405         }
406 
407         if (c_stream.avail_out == 0)
408             break;
409     } while (ret != Z_STREAM_END);
410 
411     ret = inflateEnd(&c_stream);
412 
413     if (ret != Z_OK) {
414         ERR("inflateEnd returned %i\n", ret);
415         return STATUS_INTERNAL_ERROR;
416     }
417 
418     // FIXME - if we're short, should we zero the end of outbuf so we don't leak information into userspace?
419 
420     return STATUS_SUCCESS;
421 }
422 
// Core LZO compressor: encodes in[0..in_len) into out and stores the number of
// bytes written in *out_len.
//
// in / in_len - source data. The only caller in this file (lzo1x_1_compress)
//               invokes this when in_len > 9 + 4.
// out         - destination. No capacity is passed in and no bound is checked
//               here, so the caller must provide a large enough buffer.
// out_len     - receives the encoded size on success.
// wrkmem      - scratch memory used as a dictionary of input positions, indexed
//               by DINDEX (slots 0..0x3fff).
//
// The end-of-stream marker is not written here; lzo1x_1_compress appends it.
// Returns STATUS_INTERNAL_ERROR when one of the internal consistency checks
// fails, STATUS_SUCCESS otherwise.
static NTSTATUS lzo_do_compress(const uint8_t* in, uint32_t in_len, uint8_t* out, uint32_t* out_len, void* wrkmem) {
    const uint8_t* ip;
    uint32_t dv;
    uint8_t* op;
    const uint8_t* in_end = in + in_len;
    const uint8_t* ip_end = in + in_len - 9 - 4; // match searching stops here; the tail is emitted as literals
    const uint8_t* ii; // start of the pending (not yet emitted) literal run
    const uint8_t** dict = (const uint8_t**)wrkmem;

    op = out;
    ip = in;
    ii = ip;

    // seed the dictionary with the first four positions
    DVAL_FIRST(dv, ip); UPDATE_D(dict, cycle, dv, ip); ip++;
    DVAL_NEXT(dv, ip);  UPDATE_D(dict, cycle, dv, ip); ip++;
    DVAL_NEXT(dv, ip);  UPDATE_D(dict, cycle, dv, ip); ip++;
    DVAL_NEXT(dv, ip);  UPDATE_D(dict, cycle, dv, ip); ip++;

    while (1) {
        const uint8_t* m_pos;
        uint32_t m_len;
        ptrdiff_t m_off;
        uint32_t lit, dindex;

        // look up the candidate for the current hash, then record ip in its place
        dindex = DINDEX(dv, ip);
        m_pos = dict[dindex];
        UPDATE_I(dict, cycle, dindex, ip);

        // the candidate must lie inside the input, within M4_MAX_OFFSET behind ip,
        // and agree on the first three bytes
        if (!LZO_CHECK_MPOS_NON_DET(m_pos, m_off, in, ip, M4_MAX_OFFSET) && m_pos[0] == ip[0] && m_pos[1] == ip[1] && m_pos[2] == ip[2]) {
            lit = (uint32_t)(ip - ii);
            m_pos += 3;
            if (m_off <= M2_MAX_OFFSET)
                goto match;

            if (lit == 3) { /* better compression, but slower */
                if (op - 2 <= out)
                    return STATUS_INTERNAL_ERROR;

                // fold the three literals into the low bits of the previous opcode
                op[-2] |= LZO_BYTE(3);
                *op++ = *ii++; *op++ = *ii++; *op++ = *ii++;
                goto code_match;
            }

            // for longer offsets, only accept the match if a fourth byte agrees too
            if (*m_pos == ip[3])
                goto match;
        }

        /* a literal */
        ++ip;
        if (ip >= ip_end)
            break;
        DVAL_NEXT(dv, ip);
        continue;

        /* a match */
match:
        /* store current literal run */
        if (lit > 0) {
            uint32_t t = lit;

            if (t <= 3) {
                if (op - 2 <= out)
                    return STATUS_INTERNAL_ERROR;

                // runs of 1-3 literals live in the low two bits of the previous opcode
                op[-2] |= LZO_BYTE(t);
            } else if (t <= 18)
                *op++ = LZO_BYTE(t - 3);
            else {
                uint32_t tt = t - 18;

                // long run: a zero byte, then one zero byte per further 255, then the remainder
                *op++ = 0;
                while (tt > 255) {
                    tt -= 255;
                    *op++ = 0;
                }

                if (tt <= 0)
                    return STATUS_INTERNAL_ERROR;

                *op++ = LZO_BYTE(tt);
            }

            do {
                *op++ = *ii++;
            } while (--t > 0);
        }


        /* code the match */
code_match:
        if (ii != ip)
            return STATUS_INTERNAL_ERROR;

        ip += 3;
        // try to extend the match by up to six more bytes
        if (*m_pos++ != *ip++ || *m_pos++ != *ip++ || *m_pos++ != *ip++ ||
            *m_pos++ != *ip++ || *m_pos++ != *ip++ || *m_pos++ != *ip++) {
            // mismatch within those six bytes: match length is 3..8
            --ip;
            m_len = (uint32_t)(ip - ii);

            if (m_len < 3 || m_len > 8)
                return STATUS_INTERNAL_ERROR;

            if (m_off <= M2_MAX_OFFSET) {
                m_off -= 1;
                *op++ = LZO_BYTE(((m_len - 1) << 5) | ((m_off & 7) << 2));
                *op++ = LZO_BYTE(m_off >> 3);
            } else if (m_off <= M3_MAX_OFFSET) {
                m_off -= 1;
                *op++ = LZO_BYTE(M3_MARKER | (m_len - 2));
                goto m3_m4_offset;
            } else {
                m_off -= 0x4000;

                if (m_off <= 0 || m_off > 0x7fff)
                    return STATUS_INTERNAL_ERROR;

                *op++ = LZO_BYTE(M4_MARKER | ((m_off & 0x4000) >> 11) | (m_len - 2));
                goto m3_m4_offset;
            }
        } else {
            // all six extra bytes matched: keep extending up to the end of the input
            const uint8_t* end;
            end = in_end;
            while (ip < end && *m_pos == *ip)
                m_pos++, ip++;
            m_len = (uint32_t)(ip - ii);

            if (m_len < 3)
                return STATUS_INTERNAL_ERROR;

            if (m_off <= M3_MAX_OFFSET) {
                m_off -= 1;
                if (m_len <= 33)
                    *op++ = LZO_BYTE(M3_MARKER | (m_len - 2));
                else {
                    m_len -= 33;
                    *op++ = M3_MARKER | 0;
                    goto m3_m4_len;
                }
            } else {
                m_off -= 0x4000;

                if (m_off <= 0 || m_off > 0x7fff)
                    return STATUS_INTERNAL_ERROR;

                if (m_len <= 9)
                    *op++ = LZO_BYTE(M4_MARKER | ((m_off & 0x4000) >> 11) | (m_len - 2));
                else {
                    m_len -= 9;
                    *op++ = LZO_BYTE(M4_MARKER | ((m_off & 0x4000) >> 11));
m3_m4_len:
                    // extended length: one zero byte per 255, then the remainder
                    while (m_len > 255) {
                        m_len -= 255;
                        *op++ = 0;
                    }

                    if (m_len <= 0)
                        return STATUS_INTERNAL_ERROR;

                    *op++ = LZO_BYTE(m_len);
                }
            }

m3_m4_offset:
            // two offset bytes; the low two bits of the first are left clear for
            // a following short literal count
            *op++ = LZO_BYTE((m_off & 63) << 2);
            *op++ = LZO_BYTE(m_off >> 6);
        }

        ii = ip;
        if (ip >= ip_end)
            break;
        DVAL_FIRST(dv, ip);
    }

    /* store final literal run */
    if (in_end - ii > 0) {
        uint32_t t = (uint32_t)(in_end - ii);

        // nothing emitted yet and the run is short enough: single-byte (17 + t) header
        if (op == out && t <= 238)
            *op++ = LZO_BYTE(17 + t);
        else if (t <= 3)
            op[-2] |= LZO_BYTE(t);
        else if (t <= 18)
            *op++ = LZO_BYTE(t - 3);
        else {
            uint32_t tt = t - 18;

            *op++ = 0;
            while (tt > 255) {
                tt -= 255;
                *op++ = 0;
            }

            if (tt <= 0)
                return STATUS_INTERNAL_ERROR;

            *op++ = LZO_BYTE(tt);
        }

        do {
            *op++ = *ii++;
        } while (--t > 0);
    }

    *out_len = (uint32_t)(op - out);

    return STATUS_SUCCESS;
}
630 
631 static NTSTATUS lzo1x_1_compress(lzo_stream* stream) {
632     uint8_t *op = stream->out;
633     NTSTATUS Status = STATUS_SUCCESS;
634 
635     if (stream->inlen <= 0)
636         stream->outlen = 0;
637     else if (stream->inlen <= 9 + 4) {
638         *op++ = LZO_BYTE(17 + stream->inlen);
639 
640         stream->inpos = 0;
641         do {
642             *op++ = stream->in[stream->inpos];
643             stream->inpos++;
644         } while (stream->inlen < stream->inpos);
645         stream->outlen = (uint32_t)(op - stream->out);
646     } else
647         Status = lzo_do_compress(stream->in, stream->inlen, stream->out, &stream->outlen, stream->wrkmem);
648 
649     if (Status == STATUS_SUCCESS) {
650         op = stream->out + stream->outlen;
651         *op++ = M4_MARKER | 1;
652         *op++ = 0;
653         *op++ = 0;
654         stream->outlen += 3;
655     }
656 
657     return Status;
658 }
659 
// Worst-case size of the LZO output for inlen bytes of input
// (formula comes from LZO.FAQ).
static __inline uint32_t lzo_max_outlen(uint32_t inlen) {
    const uint32_t expansion = inlen >> 4; // inlen / 16
    const uint32_t overhead = 64 + 3;

    return inlen + expansion + overhead;
}
663 
664 static void* zstd_malloc(void* opaque, size_t size) {
665     UNUSED(opaque);
666 
667     return ExAllocatePoolWithTag(PagedPool, size, ZSTD_ALLOC_TAG);
668 }
669 
670 static void zstd_free(void* opaque, void* address) {
671     UNUSED(opaque);
672 
673     ExFreePool(address);
674 }
675 
676 NTSTATUS zstd_decompress(uint8_t* inbuf, uint32_t inlen, uint8_t* outbuf, uint32_t outlen) {
677     NTSTATUS Status;
678     ZSTD_DStream* stream;
679     size_t init_res, read;
680     ZSTD_inBuffer input;
681     ZSTD_outBuffer output;
682 
683     stream = ZSTD_createDStream_advanced(zstd_mem);
684 
685     if (!stream) {
686         ERR("ZSTD_createDStream failed.\n");
687         return STATUS_INTERNAL_ERROR;
688     }
689 
690     init_res = ZSTD_initDStream(stream);
691 
692     if (ZSTD_isError(init_res)) {
693         ERR("ZSTD_initDStream failed: %s\n", ZSTD_getErrorName(init_res));
694         Status = STATUS_INTERNAL_ERROR;
695         goto end;
696     }
697 
698     input.src = inbuf;
699     input.size = inlen;
700     input.pos = 0;
701 
702     output.dst = outbuf;
703     output.size = outlen;
704     output.pos = 0;
705 
706     do {
707         read = ZSTD_decompressStream(stream, &output, &input);
708 
709         if (ZSTD_isError(read)) {
710             ERR("ZSTD_decompressStream failed: %s\n", ZSTD_getErrorName(read));
711             Status = STATUS_INTERNAL_ERROR;
712             goto end;
713         }
714 
715         if (output.pos == output.size)
716             break;
717     } while (read != 0);
718 
719     Status = STATUS_SUCCESS;
720 
721 end:
722     ZSTD_freeDStream(stream);
723 
724     return Status;
725 }
726 
727 NTSTATUS lzo_compress(uint8_t* inbuf, uint32_t inlen, uint8_t* outbuf, uint32_t outlen, unsigned int* space_left) {
728     NTSTATUS Status;
729     unsigned int num_pages;
730     unsigned int comp_data_len;
731     uint8_t* comp_data;
732     lzo_stream stream;
733     uint32_t* out_size;
734 #ifdef __REACTOS__
735     unsigned int i;
736 #endif // __REACTOS__
737 
738     num_pages = (unsigned int)sector_align(inlen, LZO_PAGE_SIZE) / LZO_PAGE_SIZE;
739 
740     // Four-byte overall header
741     // Another four-byte header page
742     // Each page has a maximum size of lzo_max_outlen(LZO_PAGE_SIZE)
743     // Plus another four bytes for possible padding
744     comp_data_len = sizeof(uint32_t) + ((lzo_max_outlen(LZO_PAGE_SIZE) + (2 * sizeof(uint32_t))) * num_pages);
745 
746     // FIXME - can we write this so comp_data isn't necessary?
747 
748     comp_data = ExAllocatePoolWithTag(PagedPool, comp_data_len, ALLOC_TAG);
749     if (!comp_data) {
750         ERR("out of memory\n");
751         return STATUS_INSUFFICIENT_RESOURCES;
752     }
753 
754     stream.wrkmem = ExAllocatePoolWithTag(PagedPool, LZO1X_MEM_COMPRESS, ALLOC_TAG);
755     if (!stream.wrkmem) {
756         ERR("out of memory\n");
757         ExFreePool(comp_data);
758         return STATUS_INSUFFICIENT_RESOURCES;
759     }
760 
761     out_size = (uint32_t*)comp_data;
762     *out_size = sizeof(uint32_t);
763 
764     stream.in = inbuf;
765     stream.out = comp_data + (2 * sizeof(uint32_t));
766 
767 #ifndef __REACTOS__
768     for (unsigned int i = 0; i < num_pages; i++) {
769 #else
770     for (i = 0; i < num_pages; i++) {
771 #endif // __REACTOS__
772         uint32_t* pagelen = (uint32_t*)(stream.out - sizeof(uint32_t));
773 
774         stream.inlen = (uint32_t)min(LZO_PAGE_SIZE, outlen - (i * LZO_PAGE_SIZE));
775 
776         Status = lzo1x_1_compress(&stream);
777         if (!NT_SUCCESS(Status)) {
778             ERR("lzo1x_1_compress returned %08lx\n", Status);
779             ExFreePool(comp_data);
780             return Status;
781         }
782 
783         *pagelen = stream.outlen;
784         *out_size += stream.outlen + sizeof(uint32_t);
785 
786         stream.in += LZO_PAGE_SIZE;
787         stream.out += stream.outlen + sizeof(uint32_t);
788 
789         // new page needs to start at a 32-bit boundary
790         if (LZO_PAGE_SIZE - (*out_size % LZO_PAGE_SIZE) < sizeof(uint32_t)) {
791             RtlZeroMemory(stream.out, LZO_PAGE_SIZE - (*out_size % LZO_PAGE_SIZE));
792             stream.out += LZO_PAGE_SIZE - (*out_size % LZO_PAGE_SIZE);
793             *out_size += LZO_PAGE_SIZE - (*out_size % LZO_PAGE_SIZE);
794         }
795     }
796 
797     ExFreePool(stream.wrkmem);
798 
799     if (*out_size >= outlen)
800         *space_left = 0;
801     else {
802         *space_left = outlen - *out_size;
803 
804         RtlCopyMemory(outbuf, comp_data, *out_size);
805     }
806 
807     ExFreePool(comp_data);
808 
809     return STATUS_SUCCESS;
810 }
811 
812 NTSTATUS zstd_compress(uint8_t* inbuf, uint32_t inlen, uint8_t* outbuf, uint32_t outlen, uint32_t level, unsigned int* space_left) {
813     ZSTD_CStream* stream;
814     size_t init_res, written;
815     ZSTD_inBuffer input;
816     ZSTD_outBuffer output;
817     ZSTD_parameters params;
818 
819     stream = ZSTD_createCStream_advanced(zstd_mem);
820 
821     if (!stream) {
822         ERR("ZSTD_createCStream failed.\n");
823         return STATUS_INTERNAL_ERROR;
824     }
825 
826     params = ZSTD_getParams(level, inlen, 0);
827 
828     if (params.cParams.windowLog > ZSTD_BTRFS_MAX_WINDOWLOG)
829         params.cParams.windowLog = ZSTD_BTRFS_MAX_WINDOWLOG;
830 
831     init_res = ZSTD_initCStream_advanced(stream, NULL, 0, params, inlen);
832 
833     if (ZSTD_isError(init_res)) {
834         ERR("ZSTD_initCStream_advanced failed: %s\n", ZSTD_getErrorName(init_res));
835         ZSTD_freeCStream(stream);
836         return STATUS_INTERNAL_ERROR;
837     }
838 
839     input.src = inbuf;
840     input.size = inlen;
841     input.pos = 0;
842 
843     output.dst = outbuf;
844     output.size = outlen;
845     output.pos = 0;
846 
847     while (input.pos < input.size && output.pos < output.size) {
848         written = ZSTD_compressStream(stream, &output, &input);
849 
850         if (ZSTD_isError(written)) {
851             ERR("ZSTD_compressStream failed: %s\n", ZSTD_getErrorName(written));
852             ZSTD_freeCStream(stream);
853             return STATUS_INTERNAL_ERROR;
854         }
855     }
856 
857     written = ZSTD_endStream(stream, &output);
858     if (ZSTD_isError(written)) {
859         ERR("ZSTD_endStream failed: %s\n", ZSTD_getErrorName(written));
860         ZSTD_freeCStream(stream);
861         return STATUS_INTERNAL_ERROR;
862     }
863 
864     ZSTD_freeCStream(stream);
865 
866     if (input.pos < input.size) // output would be larger than input
867         *space_left = 0;
868     else
869         *space_left = output.size - output.pos;
870 
871     return STATUS_SUCCESS;
872 }
873 
874 typedef struct {
875     uint8_t buf[COMPRESSED_EXTENT_SIZE];
876     uint8_t compression_type;
877     unsigned int inlen;
878     unsigned int outlen;
879     calc_job* cj;
880 } comp_part;
881 
882 NTSTATUS write_compressed(fcb* fcb, uint64_t start_data, uint64_t end_data, void* data, PIRP Irp, LIST_ENTRY* rollback) {
883     NTSTATUS Status;
884     uint64_t i;
885     unsigned int num_parts = (unsigned int)sector_align(end_data - start_data, COMPRESSED_EXTENT_SIZE) / COMPRESSED_EXTENT_SIZE;
886     uint8_t type;
887     comp_part* parts;
888     unsigned int buflen = 0;
889     uint8_t* buf;
890     chunk* c = NULL;
891     LIST_ENTRY* le;
892     uint64_t address, extaddr;
893     void* csum = NULL;
894 #ifdef __REACTOS__
895     int32_t i2;
896     uint32_t i3, j;
897 #endif // __REACTOS__
898 
899     if (fcb->Vcb->options.compress_type != 0 && fcb->prop_compression == PropCompression_None)
900         type = fcb->Vcb->options.compress_type;
901     else {
902         if (!(fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD) && fcb->prop_compression == PropCompression_ZSTD)
903             type = BTRFS_COMPRESSION_ZSTD;
904         else if (fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD && fcb->prop_compression != PropCompression_Zlib && fcb->prop_compression != PropCompression_LZO)
905             type = BTRFS_COMPRESSION_ZSTD;
906         else if (!(fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO) && fcb->prop_compression == PropCompression_LZO)
907             type = BTRFS_COMPRESSION_LZO;
908         else if (fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO && fcb->prop_compression != PropCompression_Zlib)
909             type = BTRFS_COMPRESSION_LZO;
910         else
911             type = BTRFS_COMPRESSION_ZLIB;
912     }
913 
914     Status = excise_extents(fcb->Vcb, fcb, start_data, end_data, Irp, rollback);
915     if (!NT_SUCCESS(Status)) {
916         ERR("excise_extents returned %08lx\n", Status);
917         return Status;
918     }
919 
920     parts = ExAllocatePoolWithTag(PagedPool, sizeof(comp_part) * num_parts, ALLOC_TAG);
921     if (!parts) {
922         ERR("out of memory\n");
923         return STATUS_INSUFFICIENT_RESOURCES;
924     }
925 
926     for (i = 0; i < num_parts; i++) {
927         if (i == num_parts - 1)
928             parts[i].inlen = ((unsigned int)(end_data - start_data) - ((num_parts - 1) * COMPRESSED_EXTENT_SIZE));
929         else
930             parts[i].inlen = COMPRESSED_EXTENT_SIZE;
931 
932         Status = add_calc_job_comp(fcb->Vcb, type, (uint8_t*)data + (i * COMPRESSED_EXTENT_SIZE), parts[i].inlen,
933                                    parts[i].buf, parts[i].inlen, &parts[i].cj);
934         if (!NT_SUCCESS(Status)) {
935             ERR("add_calc_job_comp returned %08lx\n", Status);
936 
937 #ifndef __REACTOS__
938             for (unsigned int j = 0; j < i; j++) {
939 #else
940             for (j = 0; j < i; j++) {
941 #endif // __REACTOS__
942                 KeWaitForSingleObject(&parts[j].cj->event, Executive, KernelMode, false, NULL);
943                 ExFreePool(parts[j].cj);
944             }
945 
946             ExFreePool(parts);
947             return Status;
948         }
949     }
950 
951     Status = STATUS_SUCCESS;
952 
953 #ifndef __REACTOS__
954     for (int i = num_parts - 1; i >= 0; i--) {
955         calc_thread_main(fcb->Vcb, parts[i].cj);
956 
957         KeWaitForSingleObject(&parts[i].cj->event, Executive, KernelMode, false, NULL);
958 
959         if (!NT_SUCCESS(parts[i].cj->Status))
960             Status = parts[i].cj->Status;
961     }
962 #else
963     for (i2 = num_parts - 1; i2 >= 0; i2--) {
964         calc_thread_main(fcb->Vcb, parts[i].cj);
965 
966         KeWaitForSingleObject(&parts[i2].cj->event, Executive, KernelMode, false, NULL);
967 
968         if (!NT_SUCCESS(parts[i2].cj->Status))
969             Status = parts[i2].cj->Status;
970     }
971 #endif // __REACTOS__
972 
973     if (!NT_SUCCESS(Status)) {
974         ERR("calc job returned %08lx\n", Status);
975 
976 #ifndef __REACTOS__
977         for (unsigned int i = 0; i < num_parts; i++) {
978             ExFreePool(parts[i].cj);
979         }
980 #else
981         for (i3 = 0; i3 < num_parts; i3++) {
982             ExFreePool(parts[i3].cj);
983         }
984 #endif // __REACTOS__
985 
986         ExFreePool(parts);
987         return Status;
988     }
989 
990 #ifndef __REACTOS__
991     for (unsigned int i = 0; i < num_parts; i++) {
992         if (parts[i].cj->space_left >= fcb->Vcb->superblock.sector_size) {
993             parts[i].compression_type = type;
994             parts[i].outlen = parts[i].inlen - parts[i].cj->space_left;
995 
996             if (type == BTRFS_COMPRESSION_LZO)
997                 fcb->Vcb->superblock.incompat_flags |= BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO;
998             else if (type == BTRFS_COMPRESSION_ZSTD)
999                 fcb->Vcb->superblock.incompat_flags |= BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD;
1000 
1001             if ((parts[i].outlen & (fcb->Vcb->superblock.sector_size - 1)) != 0) {
1002                 unsigned int newlen = (unsigned int)sector_align(parts[i].outlen, fcb->Vcb->superblock.sector_size);
1003 
1004                 RtlZeroMemory(parts[i].buf + parts[i].outlen, newlen - parts[i].outlen);
1005 
1006                 parts[i].outlen = newlen;
1007             }
1008         } else {
1009             parts[i].compression_type = BTRFS_COMPRESSION_NONE;
1010             parts[i].outlen = (unsigned int)sector_align(parts[i].inlen, fcb->Vcb->superblock.sector_size);
1011         }
1012 
1013         buflen += parts[i].outlen;
1014         ExFreePool(parts[i].cj);
1015     }
1016 #else
1017     for (i3 = 0; i3 < num_parts; i3++) {
1018         if (parts[i3].cj->space_left >= fcb->Vcb->superblock.sector_size) {
1019             parts[i3].compression_type = type;
1020             parts[i3].outlen = parts[i3].inlen - parts[i3].cj->space_left;
1021 
1022             if (type == BTRFS_COMPRESSION_LZO)
1023                 fcb->Vcb->superblock.incompat_flags |= BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO;
1024             else if (type == BTRFS_COMPRESSION_ZSTD)
1025                 fcb->Vcb->superblock.incompat_flags |= BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD;
1026 
1027             if ((parts[i3].outlen % fcb->Vcb->superblock.sector_size) != 0) {
1028                 unsigned int newlen = (unsigned int)sector_align(parts[i3].outlen, fcb->Vcb->superblock.sector_size);
1029 
1030                 RtlZeroMemory(parts[i3].buf + parts[i3].outlen, newlen - parts[i3].outlen);
1031 
1032                 parts[i3].outlen = newlen;
1033             }
1034         } else {
1035             parts[i3].compression_type = BTRFS_COMPRESSION_NONE;
1036             parts[i3].outlen = (unsigned int)sector_align(parts[i3].inlen, fcb->Vcb->superblock.sector_size);
1037         }
1038 
1039         buflen += parts[i3].outlen;
1040         ExFreePool(parts[i3].cj);
1041     }
1042 #endif // __REACTOS__
1043 
1044     // check if first 128 KB of file is incompressible
1045 
1046     if (start_data == 0 && parts[0].compression_type == BTRFS_COMPRESSION_NONE && !fcb->Vcb->options.compress_force) {
1047         TRACE("adding nocompress flag to subvol %I64x, inode %I64x\n", fcb->subvol->id, fcb->inode);
1048 
1049         fcb->inode_item.flags |= BTRFS_INODE_NOCOMPRESS;
1050         fcb->inode_item_changed = true;
1051         mark_fcb_dirty(fcb);
1052     }
1053 
1054     // join together into continuous buffer
1055 
1056     buf = ExAllocatePoolWithTag(PagedPool, buflen, ALLOC_TAG);
1057     if (!buf) {
1058         ERR("out of memory\n");
1059         ExFreePool(parts);
1060         return STATUS_INSUFFICIENT_RESOURCES;
1061     }
1062 
1063     {
1064         uint8_t* buf2 = buf;
1065 
1066         for (i = 0; i < num_parts; i++) {
1067             if (parts[i].compression_type == BTRFS_COMPRESSION_NONE)
1068                 RtlCopyMemory(buf2, (uint8_t*)data + (i * COMPRESSED_EXTENT_SIZE), parts[i].outlen);
1069             else
1070                 RtlCopyMemory(buf2, parts[i].buf, parts[i].outlen);
1071 
1072             buf2 += parts[i].outlen;
1073         }
1074     }
1075 
1076     // find an address
1077 
1078     ExAcquireResourceSharedLite(&fcb->Vcb->chunk_lock, true);
1079 
1080     le = fcb->Vcb->chunks.Flink;
1081     while (le != &fcb->Vcb->chunks) {
1082         chunk* c2 = CONTAINING_RECORD(le, chunk, list_entry);
1083 
1084         if (!c2->readonly && !c2->reloc) {
1085             acquire_chunk_lock(c2, fcb->Vcb);
1086 
1087             if (c2->chunk_item->type == fcb->Vcb->data_flags && (c2->chunk_item->size - c2->used) >= buflen) {
1088                 if (find_data_address_in_chunk(fcb->Vcb, c2, buflen, &address)) {
1089                     c = c2;
1090                     c->used += buflen;
1091                     space_list_subtract(c, address, buflen, rollback);
1092                     release_chunk_lock(c2, fcb->Vcb);
1093                     break;
1094                 }
1095             }
1096 
1097             release_chunk_lock(c2, fcb->Vcb);
1098         }
1099 
1100         le = le->Flink;
1101     }
1102 
1103     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
1104 
1105     if (!c) {
1106         chunk* c2;
1107 
1108         ExAcquireResourceExclusiveLite(&fcb->Vcb->chunk_lock, true);
1109 
1110         Status = alloc_chunk(fcb->Vcb, fcb->Vcb->data_flags, &c2, false);
1111 
1112         ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
1113 
1114         if (!NT_SUCCESS(Status)) {
1115             ERR("alloc_chunk returned %08lx\n", Status);
1116             ExFreePool(buf);
1117             ExFreePool(parts);
1118             return Status;
1119         }
1120 
1121         acquire_chunk_lock(c2, fcb->Vcb);
1122 
1123         if (find_data_address_in_chunk(fcb->Vcb, c2, buflen, &address)) {
1124             c = c2;
1125             c->used += buflen;
1126             space_list_subtract(c, address, buflen, rollback);
1127         }
1128 
1129         release_chunk_lock(c2, fcb->Vcb);
1130     }
1131 
1132     if (!c) {
1133         WARN("couldn't find any data chunks with %x bytes free\n", buflen);
1134         ExFreePool(buf);
1135         ExFreePool(parts);
1136         return STATUS_DISK_FULL;
1137     }
1138 
1139     // write to disk
1140 
1141     TRACE("writing %x bytes to %I64x\n", buflen, address);
1142 
1143     Status = write_data_complete(fcb->Vcb, address, buf, buflen, Irp, NULL, false, 0,
1144                                  fcb->Header.Flags2 & FSRTL_FLAG2_IS_PAGING_FILE ? HighPagePriority : NormalPagePriority);
1145     if (!NT_SUCCESS(Status)) {
1146         ERR("write_data_complete returned %08lx\n", Status);
1147         ExFreePool(buf);
1148         ExFreePool(parts);
1149         return Status;
1150     }
1151 
1152     // FIXME - do rest of the function while we're waiting for I/O to finish?
1153 
1154     // calculate csums if necessary
1155 
1156     if (!(fcb->inode_item.flags & BTRFS_INODE_NODATASUM)) {
1157         unsigned int sl = buflen >> fcb->Vcb->sector_shift;
1158 
1159         csum = ExAllocatePoolWithTag(PagedPool, sl * fcb->Vcb->csum_size, ALLOC_TAG);
1160         if (!csum) {
1161             ERR("out of memory\n");
1162             ExFreePool(buf);
1163             ExFreePool(parts);
1164             return STATUS_INSUFFICIENT_RESOURCES;
1165         }
1166 
1167         do_calc_job(fcb->Vcb, buf, sl, csum);
1168     }
1169 
1170     ExFreePool(buf);
1171 
1172     // add extents to fcb
1173 
1174     extaddr = address;
1175 
1176     for (i = 0; i < num_parts; i++) {
1177         EXTENT_DATA* ed;
1178         EXTENT_DATA2* ed2;
1179         void* csum2;
1180 
1181         ed = ExAllocatePoolWithTag(PagedPool, offsetof(EXTENT_DATA, data[0]) + sizeof(EXTENT_DATA2), ALLOC_TAG);
1182         if (!ed) {
1183             ERR("out of memory\n");
1184             ExFreePool(parts);
1185 
1186             if (csum)
1187                 ExFreePool(csum);
1188 
1189             return STATUS_INSUFFICIENT_RESOURCES;
1190         }
1191 
1192         ed->generation = fcb->Vcb->superblock.generation;
1193         ed->decoded_size = parts[i].inlen;
1194         ed->compression = parts[i].compression_type;
1195         ed->encryption = BTRFS_ENCRYPTION_NONE;
1196         ed->encoding = BTRFS_ENCODING_NONE;
1197         ed->type = EXTENT_TYPE_REGULAR;
1198 
1199         ed2 = (EXTENT_DATA2*)ed->data;
1200         ed2->address = extaddr;
1201         ed2->size = parts[i].outlen;
1202         ed2->offset = 0;
1203         ed2->num_bytes = parts[i].inlen;
1204 
1205         if (csum) {
1206             csum2 = ExAllocatePoolWithTag(PagedPool, (parts[i].outlen * fcb->Vcb->csum_size) >> fcb->Vcb->sector_shift, ALLOC_TAG);
1207             if (!csum2) {
1208                 ERR("out of memory\n");
1209                 ExFreePool(ed);
1210                 ExFreePool(parts);
1211                 ExFreePool(csum);
1212                 return STATUS_INSUFFICIENT_RESOURCES;
1213             }
1214 
1215             RtlCopyMemory(csum2, (uint8_t*)csum + (((extaddr - address) * fcb->Vcb->csum_size) >> fcb->Vcb->sector_shift),
1216                           (parts[i].outlen * fcb->Vcb->csum_size) >> fcb->Vcb->sector_shift);
1217         } else
1218             csum2 = NULL;
1219 
1220         Status = add_extent_to_fcb(fcb, start_data + (i * COMPRESSED_EXTENT_SIZE), ed, offsetof(EXTENT_DATA, data[0]) + sizeof(EXTENT_DATA2),
1221                                    true, csum2, rollback);
1222         if (!NT_SUCCESS(Status)) {
1223             ERR("add_extent_to_fcb returned %08lx\n", Status);
1224             ExFreePool(ed);
1225             ExFreePool(parts);
1226 
1227             if (csum)
1228                 ExFreePool(csum);
1229 
1230             return Status;
1231         }
1232 
1233         ExFreePool(ed);
1234 
1235         fcb->inode_item.st_blocks += parts[i].inlen;
1236 
1237         extaddr += parts[i].outlen;
1238     }
1239 
1240     if (csum)
1241         ExFreePool(csum);
1242 
1243     // update extent refcounts
1244 
1245     ExAcquireResourceExclusiveLite(&c->changed_extents_lock, true);
1246 
1247     extaddr = address;
1248 
1249     for (i = 0; i < num_parts; i++) {
1250         add_changed_extent_ref(c, extaddr, parts[i].outlen, fcb->subvol->id, fcb->inode,
1251                                start_data + (i * COMPRESSED_EXTENT_SIZE), 1, fcb->inode_item.flags & BTRFS_INODE_NODATASUM);
1252 
1253         extaddr += parts[i].outlen;
1254     }
1255 
1256     ExReleaseResourceLite(&c->changed_extents_lock);
1257 
1258     fcb->extents_changed = true;
1259     fcb->inode_item_changed = true;
1260     mark_fcb_dirty(fcb);
1261 
1262     ExFreePool(parts);
1263 
1264     return STATUS_SUCCESS;
1265 }
1266