1 /* Copyright (c) Mark Harmstone 2016-17
2  * Copyright (c) Reimar Doeffinger 2006
3  * Copyright (c) Markus Oberhumer 1996
4  *
5  * This file is part of WinBtrfs.
6  *
7  * WinBtrfs is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU Lesser General Public Licence as published by
9  * the Free Software Foundation, either version 3 of the Licence, or
10  * (at your option) any later version.
11  *
12  * WinBtrfs is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU Lesser General Public Licence for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public Licence
18  * along with WinBtrfs.  If not, see <http://www.gnu.org/licenses/>. */
19 
20 // Portions of the LZO decompression code here were cribbed from code in
21 // libavcodec, also under the LGPL. Thank you, Reimar Doeffinger.
22 
23 // The LZO compression code comes from v0.22 of lzo, written way back in
24 // 1996, and available here:
25 // https://www.ibiblio.org/pub/historic-linux/ftp-archives/sunsite.unc.edu/Sep-29-1996/libs/lzo-0.22.tar.gz
26 // Modern versions of lzo are licensed under the GPL, but the very oldest
27 // versions are under the LGPL and hence okay to use here.
28 
29 #include "btrfs_drv.h"
30 
31 #define Z_SOLO
32 #define ZLIB_INTERNAL
33 
34 #ifndef __REACTOS__
35 #include "zlib/zlib.h"
36 #include "zlib/inftrees.h"
37 #include "zlib/inflate.h"
38 #else
39 #include <zlib.h>
40 #endif
41 
42 #define ZSTD_STATIC_LINKING_ONLY
43 
44 #include "zstd/zstd.h"
45 
// LZO data is handled in units of this many bytes: lzo_decompress produces at
// most this much output per segment and lzo_write_compressed_bit compresses
// the input in slices of this size.
#define LINUX_PAGE_SIZE 4096

// Cursor state shared by the LZO compressor and decompressor in this file:
// one input buffer and one output buffer, each with a length and a position.
typedef struct {
    UINT8* in;      // input buffer
    UINT32 inlen;   // number of bytes available in `in`
    UINT32 inpos;   // read cursor into `in`
    UINT8* out;     // output buffer
    UINT32 outlen;  // decompression: capacity of `out`; compression: bytes written to `out`
    UINT32 outpos;  // write cursor into `out` (decompression only)
    BOOL error;     // set to TRUE by the lzo_* helpers when a bounds check fails
    void* wrkmem;   // compressor dictionary, LZO1X_MEM_COMPRESS bytes (compression only)
} lzo_stream;
58 
// Size in bytes of the compressor's work memory: a dictionary of 16384 byte
// pointers, one per value of the 14-bit index produced by DINDEX.
#define LZO1X_MEM_COMPRESS ((UINT32) (16384L * sizeof(UINT8*)))

// Match-distance limits. lzo_do_compress compares the match offset against
// M2_MAX_OFFSET and M3_MAX_OFFSET to choose an encoding, and passes
// M4_MAX_OFFSET as the largest distance it will accept at all.
// NOTE(review): M1_MAX_OFFSET appears to be used only via MX_MAX_OFFSET, which
// itself appears unused in this file - confirm.
#define M1_MAX_OFFSET 0x0400
#define M2_MAX_OFFSET 0x0800
#define M3_MAX_OFFSET 0x4000
#define M4_MAX_OFFSET 0xbfff

#define MX_MAX_OFFSET (M1_MAX_OFFSET + M2_MAX_OFFSET)

// Opcode bits identifying each match encoding in the compressed stream.
// The compressor emits M3_MARKER and M4_MARKER explicitly; M4_MARKER | 1
// followed by two zero bytes is the end-of-stream marker.
#define M1_MARKER 0
#define M2_MARKER 64
#define M3_MARKER 32
#define M4_MARKER 16

// Hash helpers for the compressor's dictionary.
// _DV2/_DV/DVAL_FIRST: compute a hash value from the three bytes p[0..2].
// DVAL_NEXT: update the hash after p has advanced by one byte (removes p[-1],
// mixes in p[2]). It expands to TWO statements, so it must never be the sole
// body of an unbraced if/else/loop.
#define _DV2(p, shift1, shift2) (((( (UINT32)(p[2]) << shift1) ^ p[1]) << shift2) ^ p[0])
#define DVAL_NEXT(dv, p) dv ^= p[-1]; dv = (((dv) >> 5) ^ ((UINT32)(p[2]) << (2*5)))
#define _DV(p, shift) _DV2(p, shift, shift)
#define DVAL_FIRST(dv, p) dv = _DV((p), 5)
// Map a hash value to a dictionary slot in the range 0..0x3fff. The `p`
// argument is unused.
#define _DINDEX(dv, p) ((40799u * (dv)) >> 5)
#define DINDEX(dv, p) (((_DINDEX(dv, p)) & 0x3fff) << 0)
// Store position p in the dictionary, either by hash value (UPDATE_D) or by an
// already-computed slot index (UPDATE_I). The `cycle` argument is unused.
#define UPDATE_D(dict, cycle, dv, p) dict[DINDEX(dv, p)] = (p)
#define UPDATE_I(dict, cycle, index, p) dict[index] = (p)

// Evaluates to non-zero when the dictionary candidate m_pos must be rejected:
// it lies before the start of the input, or its distance from ip is not in
// 1..max_offset. As a side effect m_off is assigned that distance.
// lzo_write_compressed_bit does not clear the dictionary before use, so this
// check is what stops a stale slot from being dereferenced.
#define LZO_CHECK_MPOS_NON_DET(m_pos, m_off, in, ip, max_offset) \
    ((void*) m_pos < (void*) in || \
    (m_off = (UINT8*) ip - (UINT8*) m_pos) <= 0 || \
    m_off > max_offset)

// Truncate a value to a single output byte.
#define LZO_BYTE(x) ((unsigned char) (x))
88 
// Pool tag for zstd allocations; the bytes spell "zstd" in memory order.
#define ZSTD_ALLOC_TAG 0x6474737a // "zstd"

// needs to be the same as Linux (fs/btrfs/zstd.c)
// zstd_write_compressed_bit clamps the compression window log to this value.
#define ZSTD_BTRFS_MAX_WINDOWLOG 17

// Allocator callbacks handed to zstd; both are defined elsewhere in this file.
static void* zstd_malloc(void* opaque, size_t size);
static void zstd_free(void* opaque, void* address);

// Custom allocator descriptor passed to ZSTD_createCStream_advanced.
// The two branches initialise the same three members; the ReactOS build uses
// positional initialisers instead of designated ones.
#ifndef __REACTOS__
ZSTD_customMem zstd_mem = { .customAlloc = zstd_malloc, .customFree = zstd_free, .opaque = NULL };
#else
ZSTD_customMem zstd_mem = { zstd_malloc, zstd_free, NULL };
#endif
102 
103 static UINT8 lzo_nextbyte(lzo_stream* stream) {
104     UINT8 c;
105 
106     if (stream->inpos >= stream->inlen) {
107         stream->error = TRUE;
108         return 0;
109     }
110 
111     c = stream->in[stream->inpos];
112     stream->inpos++;
113 
114     return c;
115 }
116 
117 static int lzo_len(lzo_stream* stream, int byte, int mask) {
118     int len = byte & mask;
119 
120     if (len == 0) {
121         while (!(byte = lzo_nextbyte(stream))) {
122             if (stream->error) return 0;
123 
124             len += 255;
125         }
126 
127         len += mask + byte;
128     }
129 
130     return len;
131 }
132 
133 static void lzo_copy(lzo_stream* stream, int len) {
134     if (stream->inpos + len > stream->inlen) {
135         stream->error = TRUE;
136         return;
137     }
138 
139     if (stream->outpos + len > stream->outlen) {
140         stream->error = TRUE;
141         return;
142     }
143 
144     do {
145         stream->out[stream->outpos] = stream->in[stream->inpos];
146         stream->inpos++;
147         stream->outpos++;
148         len--;
149     } while (len > 0);
150 }
151 
152 static void lzo_copyback(lzo_stream* stream, UINT32 back, int len) {
153     if (stream->outpos < back) {
154         stream->error = TRUE;
155         return;
156     }
157 
158     if (stream->outpos + len > stream->outlen) {
159         stream->error = TRUE;
160         return;
161     }
162 
163     do {
164         stream->out[stream->outpos] = stream->out[stream->outpos - back];
165         stream->outpos++;
166         len--;
167     } while (len > 0);
168 }
169 
// Decompresses one LZO segment from stream->in into stream->out.
//
// The caller sets in/inlen/inpos and out/outlen/outpos beforehand. Decoding
// stops with STATUS_SUCCESS when the output buffer is full or the
// end-of-stream marker is read; every copy is clamped to the space left in the
// output. STATUS_INTERNAL_ERROR is returned if the input runs out or a bounds
// check in one of the helpers fails.
static NTSTATUS do_lzo_decompress(lzo_stream* stream) {
    UINT8 byte;
    UINT32 len, back;
    // Selects how an opcode below 16 is interpreted: TRUE = two-byte match,
    // FALSE = literal run.
    BOOL backcopy = FALSE;

    stream->error = FALSE;

    byte = lzo_nextbyte(stream);
    if (stream->error) return STATUS_INTERNAL_ERROR;

    // A first byte above 17 encodes an initial literal run of (byte - 17) bytes.
    if (byte > 17) {
        lzo_copy(stream, min((UINT8)(byte - 17), (UINT32)(stream->outlen - stream->outpos)));
        if (stream->error) return STATUS_INTERNAL_ERROR;

        if (stream->outlen == stream->outpos)
            return STATUS_SUCCESS;

        byte = lzo_nextbyte(stream);
        if (stream->error) return STATUS_INTERNAL_ERROR;

        // such a run must be followed by a match opcode
        if (byte < 16) return STATUS_INTERNAL_ERROR;
    }

    while (1) {
        if (byte >> 4) {
            // opcode >= 16: a match; len + 2 bytes are copied from `back` bytes behind
            backcopy = TRUE;
            if (byte >> 6) {
                // opcode >= 64: length in the top three bits, distance from
                // bits 2-4 plus one further byte
                len = (byte >> 5) - 1;
                back = (lzo_nextbyte(stream) << 3) + ((byte >> 2) & 7) + 1;
                if (stream->error) return STATUS_INTERNAL_ERROR;
            } else if (byte >> 5) {
                // opcode 32..63: extensible length (mask 31), distance from two
                // further bytes; `byte` is reloaded with the first of them, whose
                // low two bits are the trailing literal count used below
                len = lzo_len(stream, byte, 31);
                if (stream->error) return STATUS_INTERNAL_ERROR;

                byte = lzo_nextbyte(stream);
                if (stream->error) return STATUS_INTERNAL_ERROR;

                back = (lzo_nextbyte(stream) << 6) + (byte >> 2) + 1;
                if (stream->error) return STATUS_INTERNAL_ERROR;
            } else {
                // opcode 16..31: extensible length (mask 7), distance of at least
                // 16384 with bit 3 of the opcode supplying an extra high bit
                len = lzo_len(stream, byte, 7);
                if (stream->error) return STATUS_INTERNAL_ERROR;

                back = (1 << 14) + ((byte & 8) << 11);

                byte = lzo_nextbyte(stream);
                if (stream->error) return STATUS_INTERNAL_ERROR;

                back += (lzo_nextbyte(stream) << 6) + (byte >> 2);
                if (stream->error) return STATUS_INTERNAL_ERROR;

                // a distance of exactly 16384 is the end-of-stream marker, as
                // written by lzo1x_1_compress (M4_MARKER | 1, 0, 0)
                if (back == (1 << 14)) {
                    if (len != 1)
                        return STATUS_INTERNAL_ERROR;
                    break;
                }
            }
        } else if (backcopy) {
            // opcode < 16 in match state: two-byte copy from a short distance
            len = 0;
            back = (lzo_nextbyte(stream) << 2) + (byte >> 2) + 1;
            if (stream->error) return STATUS_INTERNAL_ERROR;
        } else {
            // opcode < 16 in literal state: literal run of len + 3 bytes
            len = lzo_len(stream, byte, 15);
            if (stream->error) return STATUS_INTERNAL_ERROR;

            lzo_copy(stream, min(len + 3, stream->outlen - stream->outpos));
            if (stream->error) return STATUS_INTERNAL_ERROR;

            if (stream->outlen == stream->outpos)
                return STATUS_SUCCESS;

            byte = lzo_nextbyte(stream);
            if (stream->error) return STATUS_INTERNAL_ERROR;

            // a match opcode follows: handle it at the top of the loop
            if (byte >> 4)
                continue;

            len = 1;
            back = (1 << 11) + (lzo_nextbyte(stream) << 2) + (byte >> 2) + 1;
            if (stream->error) return STATUS_INTERNAL_ERROR;

            // NOTE(review): this break leaves the loop and returns STATUS_SUCCESS
            // without performing the copy described by len/back just computed, so
            // decoding of the segment stops here. Confirm that this is intended;
            // the compressor in this file never emits this opcode.
            break;
        }

        lzo_copyback(stream, back, min(len + 2, stream->outlen - stream->outpos));
        if (stream->error) return STATUS_INTERNAL_ERROR;

        if (stream->outlen == stream->outpos)
            return STATUS_SUCCESS;

        // low two bits of the last opcode/offset byte: number of literals that
        // directly follow the match
        len = byte & 3;

        if (len) {
            lzo_copy(stream, min(len, stream->outlen - stream->outpos));
            if (stream->error) return STATUS_INTERNAL_ERROR;

            if (stream->outlen == stream->outpos)
                return STATUS_SUCCESS;
        } else
            backcopy = !backcopy;

        byte = lzo_nextbyte(stream);
        if (stream->error) return STATUS_INTERNAL_ERROR;
    }

    return STATUS_SUCCESS;
}
277 
278 NTSTATUS lzo_decompress(UINT8* inbuf, UINT32 inlen, UINT8* outbuf, UINT32 outlen, UINT32 inpageoff) {
279     NTSTATUS Status;
280     UINT32 partlen, inoff, outoff;
281     lzo_stream stream;
282 
283     inoff = 0;
284     outoff = 0;
285 
286     do {
287         partlen = *(UINT32*)&inbuf[inoff];
288 
289         if (partlen + inoff > inlen) {
290             ERR("overflow: %x + %x > %llx\n", partlen, inoff, inlen);
291             return STATUS_INTERNAL_ERROR;
292         }
293 
294         inoff += sizeof(UINT32);
295 
296         stream.in = &inbuf[inoff];
297         stream.inlen = partlen;
298         stream.inpos = 0;
299         stream.out = &outbuf[outoff];
300         stream.outlen = min(outlen, LINUX_PAGE_SIZE);
301         stream.outpos = 0;
302 
303         Status = do_lzo_decompress(&stream);
304         if (!NT_SUCCESS(Status)) {
305             ERR("do_lzo_decompress returned %08x\n", Status);
306             return Status;
307         }
308 
309         if (stream.outpos < stream.outlen)
310             RtlZeroMemory(&stream.out[stream.outpos], stream.outlen - stream.outpos);
311 
312         inoff += partlen;
313         outoff += stream.outlen;
314 
315         if (LINUX_PAGE_SIZE - ((inpageoff + inoff) % LINUX_PAGE_SIZE) < sizeof(UINT32))
316             inoff = ((((inpageoff + inoff) / LINUX_PAGE_SIZE) + 1) * LINUX_PAGE_SIZE) - inpageoff;
317 
318         outlen -= stream.outlen;
319     } while (inoff < inlen && outlen > 0);
320 
321     return STATUS_SUCCESS;
322 }
323 
324 static void* zlib_alloc(void* opaque, unsigned int items, unsigned int size) {
325     UNUSED(opaque);
326 
327     return ExAllocatePoolWithTag(PagedPool, items * size, ALLOC_TAG_ZLIB);
328 }
329 
// Deallocation callback for zlib: hands `ptr` back to the pool.
// `opaque` is not used.
static void zlib_free(void* opaque, void* ptr) {
    ExFreePool(ptr);

    UNUSED(opaque);
}
335 
336 NTSTATUS zlib_decompress(UINT8* inbuf, UINT32 inlen, UINT8* outbuf, UINT32 outlen) {
337     z_stream c_stream;
338     int ret;
339 
340     c_stream.zalloc = zlib_alloc;
341     c_stream.zfree = zlib_free;
342     c_stream.opaque = (voidpf)0;
343 
344     ret = inflateInit(&c_stream);
345 
346     if (ret != Z_OK) {
347         ERR("inflateInit returned %08x\n", ret);
348         return STATUS_INTERNAL_ERROR;
349     }
350 
351     c_stream.next_in = inbuf;
352     c_stream.avail_in = inlen;
353 
354     c_stream.next_out = outbuf;
355     c_stream.avail_out = outlen;
356 
357     do {
358         ret = inflate(&c_stream, Z_NO_FLUSH);
359 
360         if (ret != Z_OK && ret != Z_STREAM_END) {
361             ERR("inflate returned %08x\n", ret);
362             inflateEnd(&c_stream);
363             return STATUS_INTERNAL_ERROR;
364         }
365 
366         if (c_stream.avail_out == 0)
367             break;
368     } while (ret != Z_STREAM_END);
369 
370     ret = inflateEnd(&c_stream);
371 
372     if (ret != Z_OK) {
373         ERR("inflateEnd returned %08x\n", ret);
374         return STATUS_INTERNAL_ERROR;
375     }
376 
377     // FIXME - if we're short, should we zero the end of outbuf so we don't leak information into userspace?
378 
379     return STATUS_SUCCESS;
380 }
381 
382 static NTSTATUS zlib_write_compressed_bit(fcb* fcb, UINT64 start_data, UINT64 end_data, void* data, BOOL* compressed, PIRP Irp, LIST_ENTRY* rollback) {
383     NTSTATUS Status;
384     UINT8 compression;
385     UINT32 comp_length;
386     UINT8* comp_data;
387     UINT32 out_left;
388     LIST_ENTRY* le;
389     chunk* c;
390     z_stream c_stream;
391     int ret;
392 
393     comp_data = ExAllocatePoolWithTag(PagedPool, (UINT32)(end_data - start_data), ALLOC_TAG);
394     if (!comp_data) {
395         ERR("out of memory\n");
396         return STATUS_INSUFFICIENT_RESOURCES;
397     }
398 
399     Status = excise_extents(fcb->Vcb, fcb, start_data, end_data, Irp, rollback);
400     if (!NT_SUCCESS(Status)) {
401         ERR("excise_extents returned %08x\n", Status);
402         ExFreePool(comp_data);
403         return Status;
404     }
405 
406     c_stream.zalloc = zlib_alloc;
407     c_stream.zfree = zlib_free;
408     c_stream.opaque = (voidpf)0;
409 
410     ret = deflateInit(&c_stream, fcb->Vcb->options.zlib_level);
411 
412     if (ret != Z_OK) {
413         ERR("deflateInit returned %08x\n", ret);
414         ExFreePool(comp_data);
415         return STATUS_INTERNAL_ERROR;
416     }
417 
418     c_stream.avail_in = (UINT32)(end_data - start_data);
419     c_stream.next_in = data;
420     c_stream.avail_out = (UINT32)(end_data - start_data);
421     c_stream.next_out = comp_data;
422 
423     do {
424         ret = deflate(&c_stream, Z_FINISH);
425 
426         if (ret == Z_STREAM_ERROR) {
427             ERR("deflate returned %x\n", ret);
428             ExFreePool(comp_data);
429             return STATUS_INTERNAL_ERROR;
430         }
431     } while (c_stream.avail_in > 0 && c_stream.avail_out > 0);
432 
433     out_left = c_stream.avail_out;
434 
435     ret = deflateEnd(&c_stream);
436 
437     if (ret != Z_OK) {
438         ERR("deflateEnd returned %08x\n", ret);
439         ExFreePool(comp_data);
440         return STATUS_INTERNAL_ERROR;
441     }
442 
443     if (out_left < fcb->Vcb->superblock.sector_size) { // compressed extent would be larger than or same size as uncompressed extent
444         ExFreePool(comp_data);
445 
446         comp_length = (UINT32)(end_data - start_data);
447         comp_data = data;
448         compression = BTRFS_COMPRESSION_NONE;
449 
450         *compressed = FALSE;
451     } else {
452         UINT32 cl;
453 
454         compression = BTRFS_COMPRESSION_ZLIB;
455         cl = (UINT32)(end_data - start_data - out_left);
456         comp_length = (UINT32)sector_align(cl, fcb->Vcb->superblock.sector_size);
457 
458         RtlZeroMemory(comp_data + cl, comp_length - cl);
459 
460         *compressed = TRUE;
461     }
462 
463     ExAcquireResourceSharedLite(&fcb->Vcb->chunk_lock, TRUE);
464 
465     le = fcb->Vcb->chunks.Flink;
466     while (le != &fcb->Vcb->chunks) {
467         c = CONTAINING_RECORD(le, chunk, list_entry);
468 
469         if (!c->readonly && !c->reloc) {
470             acquire_chunk_lock(c, fcb->Vcb);
471 
472             if (c->chunk_item->type == fcb->Vcb->data_flags && (c->chunk_item->size - c->used) >= comp_length) {
473                 if (insert_extent_chunk(fcb->Vcb, fcb, c, start_data, comp_length, FALSE, comp_data, Irp, rollback, compression, end_data - start_data, FALSE, 0)) {
474                     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
475 
476                     if (compression != BTRFS_COMPRESSION_NONE)
477                         ExFreePool(comp_data);
478 
479                     return STATUS_SUCCESS;
480                 }
481             }
482 
483             release_chunk_lock(c, fcb->Vcb);
484         }
485 
486         le = le->Flink;
487     }
488 
489     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
490 
491     ExAcquireResourceExclusiveLite(&fcb->Vcb->chunk_lock, TRUE);
492 
493     Status = alloc_chunk(fcb->Vcb, fcb->Vcb->data_flags, &c, FALSE);
494 
495     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
496 
497     if (!NT_SUCCESS(Status)) {
498         ERR("alloc_chunk returned %08x\n", Status);
499 
500         if (compression != BTRFS_COMPRESSION_NONE)
501             ExFreePool(comp_data);
502 
503         return Status;
504     }
505 
506     if (c) {
507         acquire_chunk_lock(c, fcb->Vcb);
508 
509         if (c->chunk_item->type == fcb->Vcb->data_flags && (c->chunk_item->size - c->used) >= comp_length) {
510             if (insert_extent_chunk(fcb->Vcb, fcb, c, start_data, comp_length, FALSE, comp_data, Irp, rollback, compression, end_data - start_data, FALSE, 0)) {
511                 if (compression != BTRFS_COMPRESSION_NONE)
512                     ExFreePool(comp_data);
513 
514                 return STATUS_SUCCESS;
515             }
516         }
517 
518         release_chunk_lock(c, fcb->Vcb);
519     }
520 
521     WARN("couldn't find any data chunks with %llx bytes free\n", comp_length);
522 
523     if (compression != BTRFS_COMPRESSION_NONE)
524         ExFreePool(comp_data);
525 
526     return STATUS_DISK_FULL;
527 }
528 
// Core of the LZO compressor: encodes in_len bytes from `in` into `out`.
//
// in/in_len: input. The only caller in this file, lzo1x_1_compress, invokes
//            this for inputs longer than 9 + 4 bytes; ip_end below is computed
//            as in + in_len - 13.
// out:       destination. No bounds checking is done on writes to `out`; the
//            caller must provide enough room (see lzo_max_outlen).
// out_len:   receives the number of bytes written.
// wrkmem:    dictionary of byte pointers indexed by DINDEX (0..0x3fff), i.e.
//            LZO1X_MEM_COMPRESS bytes. Slots are validated with
//            LZO_CHECK_MPOS_NON_DET before being dereferenced.
//
// The end-of-stream marker is NOT written here; lzo1x_1_compress appends it.
// Returns STATUS_SUCCESS, or STATUS_INTERNAL_ERROR if one of the internal
// consistency checks fails.
static NTSTATUS lzo_do_compress(const UINT8* in, UINT32 in_len, UINT8* out, UINT32* out_len, void* wrkmem) {
    const UINT8* ip;        // current input position
    UINT32 dv;              // rolling hash of the three bytes at ip
    UINT8* op;              // current output position
    const UINT8* in_end = in + in_len;
    const UINT8* ip_end = in + in_len - 9 - 4; // match searching stops here
    const UINT8* ii;        // start of the pending (not yet emitted) literal run
    const UINT8** dict = (const UINT8**)wrkmem;

    op = out;
    ip = in;
    ii = ip;

    // seed the dictionary with the first four positions; they are always literals
    DVAL_FIRST(dv, ip); UPDATE_D(dict, cycle, dv, ip); ip++;
    DVAL_NEXT(dv, ip);  UPDATE_D(dict, cycle, dv, ip); ip++;
    DVAL_NEXT(dv, ip);  UPDATE_D(dict, cycle, dv, ip); ip++;
    DVAL_NEXT(dv, ip);  UPDATE_D(dict, cycle, dv, ip); ip++;

    while (1) {
        const UINT8* m_pos; // candidate match position from the dictionary
        UINT32 m_len;       // match length
        ptrdiff_t m_off;    // distance from m_pos to ip (set by LZO_CHECK_MPOS_NON_DET)
        UINT32 lit, dindex;

        // look up the candidate for the current hash, then record ip in its place
        dindex = DINDEX(dv, ip);
        m_pos = dict[dindex];
        UPDATE_I(dict, cycle, dindex, ip);

        // accept the candidate only if it is in range and its first three bytes match
        if (!LZO_CHECK_MPOS_NON_DET(m_pos, m_off, in, ip, M4_MAX_OFFSET) && m_pos[0] == ip[0] && m_pos[1] == ip[1] && m_pos[2] == ip[2]) {
            lit = (UINT32)(ip - ii);
            m_pos += 3;
            if (m_off <= M2_MAX_OFFSET)
                goto match;

            if (lit == 3) { /* better compression, but slower */
                if (op - 2 <= out)
                    return STATUS_INTERNAL_ERROR;

                // fold the three literals into the low bits of the previous
                // match's offset byte instead of emitting a run header
                op[-2] |= LZO_BYTE(3);
                *op++ = *ii++; *op++ = *ii++; *op++ = *ii++;
                goto code_match;
            }

            // a distant match is only taken if a fourth byte matches as well
            if (*m_pos == ip[3])
                goto match;
        }

        /* a literal */
        ++ip;
        if (ip >= ip_end)
            break;
        DVAL_NEXT(dv, ip);
        continue;

        /* a match */
match:
        /* store current literal run */
        if (lit > 0) {
            UINT32 t = lit;

            if (t <= 3) {
                if (op - 2 <= out)
                    return STATUS_INTERNAL_ERROR;

                // 1-3 literals ride in the low two bits of the previous match
                op[-2] |= LZO_BYTE(t);
            } else if (t <= 18)
                *op++ = LZO_BYTE(t - 3);
            else {
                // long run: a zero opcode, one zero byte per further 255, then
                // the remainder
                UINT32 tt = t - 18;

                *op++ = 0;
                while (tt > 255) {
                    tt -= 255;
                    *op++ = 0;
                }

                if (tt <= 0)
                    return STATUS_INTERNAL_ERROR;

                *op++ = LZO_BYTE(tt);
            }

            do {
                *op++ = *ii++;
            } while (--t > 0);
        }


        /* code the match */
code_match:
        if (ii != ip)
            return STATUS_INTERNAL_ERROR;

        // three bytes are already known to match; probe up to six more
        ip += 3;
        if (*m_pos++ != *ip++ || *m_pos++ != *ip++ || *m_pos++ != *ip++ ||
            *m_pos++ != *ip++ || *m_pos++ != *ip++ || *m_pos++ != *ip++) {
            // short match (3..8 bytes): step back over the mismatching byte
            --ip;
            m_len = (UINT32)(ip - ii);

            if (m_len < 3 || m_len > 8)
                return STATUS_INTERNAL_ERROR;

            if (m_off <= M2_MAX_OFFSET) {
                // two-byte encoding: length and low offset bits, then high offset bits
                m_off -= 1;
                *op++ = LZO_BYTE(((m_len - 1) << 5) | ((m_off & 7) << 2));
                *op++ = LZO_BYTE(m_off >> 3);
            } else if (m_off <= M3_MAX_OFFSET) {
                m_off -= 1;
                *op++ = LZO_BYTE(M3_MARKER | (m_len - 2));
                goto m3_m4_offset;
            } else {
                m_off -= 0x4000;

                if (m_off <= 0 || m_off > 0x7fff)
                    return STATUS_INTERNAL_ERROR;

                *op++ = LZO_BYTE(M4_MARKER | ((m_off & 0x4000) >> 11) | (m_len - 2));
                goto m3_m4_offset;
            }
        } else {
            // long match (9 bytes or more): extend it as far as the input allows
            const UINT8* end;
            end = in_end;
            while (ip < end && *m_pos == *ip)
                m_pos++, ip++;
            m_len = (UINT32)(ip - ii);

            if (m_len < 3)
                return STATUS_INTERNAL_ERROR;

            if (m_off <= M3_MAX_OFFSET) {
                m_off -= 1;
                if (m_len <= 33)
                    *op++ = LZO_BYTE(M3_MARKER | (m_len - 2));
                else {
                    // length does not fit in the opcode: continue in extra bytes
                    m_len -= 33;
                    *op++ = M3_MARKER | 0;
                    goto m3_m4_len;
                }
            } else {
                m_off -= 0x4000;

                if (m_off <= 0 || m_off > 0x7fff)
                    return STATUS_INTERNAL_ERROR;

                if (m_len <= 9)
                    *op++ = LZO_BYTE(M4_MARKER | ((m_off & 0x4000) >> 11) | (m_len - 2));
                else {
                    m_len -= 9;
                    *op++ = LZO_BYTE(M4_MARKER | ((m_off & 0x4000) >> 11));
m3_m4_len:
                    // extended length: one zero byte per 255, then the remainder
                    while (m_len > 255) {
                        m_len -= 255;
                        *op++ = 0;
                    }

                    if (m_len <= 0)
                        return STATUS_INTERNAL_ERROR;

                    *op++ = LZO_BYTE(m_len);
                }
            }

m3_m4_offset:
            // The low two bits of the first offset byte are left clear; a later
            // short literal run may be OR-ed into them through op[-2].
            *op++ = LZO_BYTE((m_off & 63) << 2);
            *op++ = LZO_BYTE(m_off >> 6);
        }

        // the match is consumed: restart the literal run and the hash after it
        ii = ip;
        if (ip >= ip_end)
            break;
        DVAL_FIRST(dv, ip);
    }

    /* store final literal run */
    if (in_end - ii > 0) {
        UINT32 t = (UINT32)(in_end - ii);

        // nothing emitted yet and the run is short enough: use the special
        // first-byte encoding (17 + length)
        if (op == out && t <= 238)
            *op++ = LZO_BYTE(17 + t);
        else if (t <= 3)
            op[-2] |= LZO_BYTE(t);
        else if (t <= 18)
            *op++ = LZO_BYTE(t - 3);
        else {
            UINT32 tt = t - 18;

            *op++ = 0;
            while (tt > 255) {
                tt -= 255;
                *op++ = 0;
            }

            if (tt <= 0)
                return STATUS_INTERNAL_ERROR;

            *op++ = LZO_BYTE(tt);
        }

        do {
            *op++ = *ii++;
        } while (--t > 0);
    }

    *out_len = (UINT32)(op - out);

    return STATUS_SUCCESS;
}
736 
737 static NTSTATUS lzo1x_1_compress(lzo_stream* stream) {
738     UINT8 *op = stream->out;
739     NTSTATUS Status = STATUS_SUCCESS;
740 
741     if (stream->inlen <= 0)
742         stream->outlen = 0;
743     else if (stream->inlen <= 9 + 4) {
744         *op++ = LZO_BYTE(17 + stream->inlen);
745 
746         stream->inpos = 0;
747         do {
748             *op++ = stream->in[stream->inpos];
749             stream->inpos++;
750         } while (stream->inlen < stream->inpos);
751         stream->outlen = (UINT32)(op - stream->out);
752     } else
753         Status = lzo_do_compress(stream->in, stream->inlen, stream->out, &stream->outlen, stream->wrkmem);
754 
755     if (Status == STATUS_SUCCESS) {
756         op = stream->out + stream->outlen;
757         *op++ = M4_MARKER | 1;
758         *op++ = 0;
759         *op++ = 0;
760         stream->outlen += 3;
761     }
762 
763     return Status;
764 }
765 
766 static __inline UINT32 lzo_max_outlen(UINT32 inlen) {
767     return inlen + (inlen / 16) + 64 + 3; // formula comes from LZO.FAQ
768 }
769 
770 static NTSTATUS lzo_write_compressed_bit(fcb* fcb, UINT64 start_data, UINT64 end_data, void* data, BOOL* compressed, PIRP Irp, LIST_ENTRY* rollback) {
771     NTSTATUS Status;
772     UINT8 compression;
773     UINT64 comp_length;
774     ULONG comp_data_len, num_pages, i;
775     UINT8* comp_data;
776     BOOL skip_compression = FALSE;
777     lzo_stream stream;
778     UINT32* out_size;
779     LIST_ENTRY* le;
780     chunk* c;
781 
782     num_pages = (ULONG)((sector_align(end_data - start_data, LINUX_PAGE_SIZE)) / LINUX_PAGE_SIZE);
783 
784     // Four-byte overall header
785     // Another four-byte header page
786     // Each page has a maximum size of lzo_max_outlen(LINUX_PAGE_SIZE)
787     // Plus another four bytes for possible padding
788     comp_data_len = sizeof(UINT32) + ((lzo_max_outlen(LINUX_PAGE_SIZE) + (2 * sizeof(UINT32))) * num_pages);
789 
790     comp_data = ExAllocatePoolWithTag(PagedPool, comp_data_len, ALLOC_TAG);
791     if (!comp_data) {
792         ERR("out of memory\n");
793         return STATUS_INSUFFICIENT_RESOURCES;
794     }
795 
796     stream.wrkmem = ExAllocatePoolWithTag(PagedPool, LZO1X_MEM_COMPRESS, ALLOC_TAG);
797     if (!stream.wrkmem) {
798         ERR("out of memory\n");
799         ExFreePool(comp_data);
800         return STATUS_INSUFFICIENT_RESOURCES;
801     }
802 
803     Status = excise_extents(fcb->Vcb, fcb, start_data, end_data, Irp, rollback);
804     if (!NT_SUCCESS(Status)) {
805         ERR("excise_extents returned %08x\n", Status);
806         ExFreePool(comp_data);
807         ExFreePool(stream.wrkmem);
808         return Status;
809     }
810 
811     out_size = (UINT32*)comp_data;
812     *out_size = sizeof(UINT32);
813 
814     stream.in = data;
815     stream.out = comp_data + (2 * sizeof(UINT32));
816 
817     for (i = 0; i < num_pages; i++) {
818         UINT32* pagelen = (UINT32*)(stream.out - sizeof(UINT32));
819 
820         stream.inlen = (UINT32)min(LINUX_PAGE_SIZE, end_data - start_data - (i * LINUX_PAGE_SIZE));
821 
822         Status = lzo1x_1_compress(&stream);
823         if (!NT_SUCCESS(Status)) {
824             ERR("lzo1x_1_compress returned %08x\n", Status);
825             skip_compression = TRUE;
826             break;
827         }
828 
829         *pagelen = stream.outlen;
830         *out_size += stream.outlen + sizeof(UINT32);
831 
832         stream.in += LINUX_PAGE_SIZE;
833         stream.out += stream.outlen + sizeof(UINT32);
834 
835         if (LINUX_PAGE_SIZE - (*out_size % LINUX_PAGE_SIZE) < sizeof(UINT32)) {
836             RtlZeroMemory(stream.out, LINUX_PAGE_SIZE - (*out_size % LINUX_PAGE_SIZE));
837             stream.out += LINUX_PAGE_SIZE - (*out_size % LINUX_PAGE_SIZE);
838             *out_size += LINUX_PAGE_SIZE - (*out_size % LINUX_PAGE_SIZE);
839         }
840     }
841 
842     ExFreePool(stream.wrkmem);
843 
844     if (skip_compression || *out_size >= end_data - start_data - fcb->Vcb->superblock.sector_size) { // compressed extent would be larger than or same size as uncompressed extent
845         ExFreePool(comp_data);
846 
847         comp_length = end_data - start_data;
848         comp_data = data;
849         compression = BTRFS_COMPRESSION_NONE;
850 
851         *compressed = FALSE;
852     } else {
853         compression = BTRFS_COMPRESSION_LZO;
854         comp_length = sector_align(*out_size, fcb->Vcb->superblock.sector_size);
855 
856         RtlZeroMemory(comp_data + *out_size, (ULONG)(comp_length - *out_size));
857 
858         *compressed = TRUE;
859     }
860 
861     ExAcquireResourceSharedLite(&fcb->Vcb->chunk_lock, TRUE);
862 
863     le = fcb->Vcb->chunks.Flink;
864     while (le != &fcb->Vcb->chunks) {
865         c = CONTAINING_RECORD(le, chunk, list_entry);
866 
867         if (!c->readonly && !c->reloc) {
868             acquire_chunk_lock(c, fcb->Vcb);
869 
870             if (c->chunk_item->type == fcb->Vcb->data_flags && (c->chunk_item->size - c->used) >= comp_length) {
871                 if (insert_extent_chunk(fcb->Vcb, fcb, c, start_data, comp_length, FALSE, comp_data, Irp, rollback, compression, end_data - start_data, FALSE, 0)) {
872                     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
873 
874                     if (compression != BTRFS_COMPRESSION_NONE)
875                         ExFreePool(comp_data);
876 
877                     return STATUS_SUCCESS;
878                 }
879             }
880 
881             release_chunk_lock(c, fcb->Vcb);
882         }
883 
884         le = le->Flink;
885     }
886 
887     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
888 
889     ExAcquireResourceExclusiveLite(&fcb->Vcb->chunk_lock, TRUE);
890 
891     Status = alloc_chunk(fcb->Vcb, fcb->Vcb->data_flags, &c, FALSE);
892 
893     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
894 
895     if (!NT_SUCCESS(Status)) {
896         ERR("alloc_chunk returned %08x\n", Status);
897 
898         if (compression != BTRFS_COMPRESSION_NONE)
899             ExFreePool(comp_data);
900 
901         return Status;
902     }
903 
904     if (c) {
905         acquire_chunk_lock(c, fcb->Vcb);
906 
907         if (c->chunk_item->type == fcb->Vcb->data_flags && (c->chunk_item->size - c->used) >= comp_length) {
908             if (insert_extent_chunk(fcb->Vcb, fcb, c, start_data, comp_length, FALSE, comp_data, Irp, rollback, compression, end_data - start_data, FALSE, 0)) {
909                 if (compression != BTRFS_COMPRESSION_NONE)
910                     ExFreePool(comp_data);
911 
912                 return STATUS_SUCCESS;
913             }
914         }
915 
916         release_chunk_lock(c, fcb->Vcb);
917     }
918 
919     WARN("couldn't find any data chunks with %llx bytes free\n", comp_length);
920 
921     if (compression != BTRFS_COMPRESSION_NONE)
922         ExFreePool(comp_data);
923 
924     return STATUS_DISK_FULL;
925 }
926 
927 static NTSTATUS zstd_write_compressed_bit(fcb* fcb, UINT64 start_data, UINT64 end_data, void* data, BOOL* compressed, PIRP Irp, LIST_ENTRY* rollback) {
928     NTSTATUS Status;
929     UINT8 compression;
930     UINT32 comp_length;
931     UINT8* comp_data;
932     UINT32 out_left;
933     LIST_ENTRY* le;
934     chunk* c;
935     ZSTD_CStream* stream;
936     size_t init_res, written;
937     ZSTD_inBuffer input;
938     ZSTD_outBuffer output;
939     ZSTD_parameters params;
940 
941     comp_data = ExAllocatePoolWithTag(PagedPool, (UINT32)(end_data - start_data), ALLOC_TAG);
942     if (!comp_data) {
943         ERR("out of memory\n");
944         return STATUS_INSUFFICIENT_RESOURCES;
945     }
946 
947     Status = excise_extents(fcb->Vcb, fcb, start_data, end_data, Irp, rollback);
948     if (!NT_SUCCESS(Status)) {
949         ERR("excise_extents returned %08x\n", Status);
950         ExFreePool(comp_data);
951         return Status;
952     }
953 
954     stream = ZSTD_createCStream_advanced(zstd_mem);
955 
956     if (!stream) {
957         ERR("ZSTD_createCStream failed.\n");
958         ExFreePool(comp_data);
959         return STATUS_INTERNAL_ERROR;
960     }
961 
962     params = ZSTD_getParams(fcb->Vcb->options.zstd_level, (UINT32)(end_data - start_data), 0);
963 
964     if (params.cParams.windowLog > ZSTD_BTRFS_MAX_WINDOWLOG)
965         params.cParams.windowLog = ZSTD_BTRFS_MAX_WINDOWLOG;
966 
967     init_res = ZSTD_initCStream_advanced(stream, NULL, 0, params, (UINT32)(end_data - start_data));
968 
969     if (ZSTD_isError(init_res)) {
970         ERR("ZSTD_initCStream_advanced failed: %s\n", ZSTD_getErrorName(init_res));
971         ZSTD_freeCStream(stream);
972         ExFreePool(comp_data);
973         return STATUS_INTERNAL_ERROR;
974     }
975 
976     input.src = data;
977     input.size = (UINT32)(end_data - start_data);
978     input.pos = 0;
979 
980     output.dst = comp_data;
981     output.size = (UINT32)(end_data - start_data);
982     output.pos = 0;
983 
984     while (input.pos < input.size && output.pos < output.size) {
985         written = ZSTD_compressStream(stream, &output, &input);
986 
987         if (ZSTD_isError(written)) {
988             ERR("ZSTD_compressStream failed: %s\n", ZSTD_getErrorName(written));
989             ZSTD_freeCStream(stream);
990             ExFreePool(comp_data);
991             return STATUS_INTERNAL_ERROR;
992         }
993     }
994 
995     written = ZSTD_endStream(stream, &output);
996     if (ZSTD_isError(written)) {
997         ERR("ZSTD_endStream failed: %s\n", ZSTD_getErrorName(written));
998         ZSTD_freeCStream(stream);
999         ExFreePool(comp_data);
1000         return STATUS_INTERNAL_ERROR;
1001     }
1002 
1003     ZSTD_freeCStream(stream);
1004 
1005     out_left = output.size - output.pos;
1006 
1007     if (out_left < fcb->Vcb->superblock.sector_size) { // compressed extent would be larger than or same size as uncompressed extent
1008         ExFreePool(comp_data);
1009 
1010         comp_length = (UINT32)(end_data - start_data);
1011         comp_data = data;
1012         compression = BTRFS_COMPRESSION_NONE;
1013 
1014         *compressed = FALSE;
1015     } else {
1016         UINT32 cl;
1017 
1018         compression = BTRFS_COMPRESSION_ZSTD;
1019         cl = (UINT32)(end_data - start_data - out_left);
1020         comp_length = (UINT32)sector_align(cl, fcb->Vcb->superblock.sector_size);
1021 
1022         RtlZeroMemory(comp_data + cl, comp_length - cl);
1023 
1024         *compressed = TRUE;
1025     }
1026 
1027     ExAcquireResourceSharedLite(&fcb->Vcb->chunk_lock, TRUE);
1028 
1029     le = fcb->Vcb->chunks.Flink;
1030     while (le != &fcb->Vcb->chunks) {
1031         c = CONTAINING_RECORD(le, chunk, list_entry);
1032 
1033         if (!c->readonly && !c->reloc) {
1034             acquire_chunk_lock(c, fcb->Vcb);
1035 
1036             if (c->chunk_item->type == fcb->Vcb->data_flags && (c->chunk_item->size - c->used) >= comp_length) {
1037                 if (insert_extent_chunk(fcb->Vcb, fcb, c, start_data, comp_length, FALSE, comp_data, Irp, rollback, compression, end_data - start_data, FALSE, 0)) {
1038                     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
1039 
1040                     if (compression != BTRFS_COMPRESSION_NONE)
1041                         ExFreePool(comp_data);
1042 
1043                     return STATUS_SUCCESS;
1044                 }
1045             }
1046 
1047             release_chunk_lock(c, fcb->Vcb);
1048         }
1049 
1050         le = le->Flink;
1051     }
1052 
1053     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
1054 
1055     ExAcquireResourceExclusiveLite(&fcb->Vcb->chunk_lock, TRUE);
1056 
1057     Status = alloc_chunk(fcb->Vcb, fcb->Vcb->data_flags, &c, FALSE);
1058 
1059     ExReleaseResourceLite(&fcb->Vcb->chunk_lock);
1060 
1061     if (!NT_SUCCESS(Status)) {
1062         ERR("alloc_chunk returned %08x\n", Status);
1063 
1064         if (compression != BTRFS_COMPRESSION_NONE)
1065             ExFreePool(comp_data);
1066 
1067         return Status;
1068     }
1069 
1070     if (c) {
1071         acquire_chunk_lock(c, fcb->Vcb);
1072 
1073         if (c->chunk_item->type == fcb->Vcb->data_flags && (c->chunk_item->size - c->used) >= comp_length) {
1074             if (insert_extent_chunk(fcb->Vcb, fcb, c, start_data, comp_length, FALSE, comp_data, Irp, rollback, compression, end_data - start_data, FALSE, 0)) {
1075                 if (compression != BTRFS_COMPRESSION_NONE)
1076                     ExFreePool(comp_data);
1077 
1078                 return STATUS_SUCCESS;
1079             }
1080         }
1081 
1082         release_chunk_lock(c, fcb->Vcb);
1083     }
1084 
1085     WARN("couldn't find any data chunks with %llx bytes free\n", comp_length);
1086 
1087     if (compression != BTRFS_COMPRESSION_NONE)
1088         ExFreePool(comp_data);
1089 
1090     return STATUS_DISK_FULL;
1091 }
1092 
1093 NTSTATUS write_compressed_bit(fcb* fcb, UINT64 start_data, UINT64 end_data, void* data, BOOL* compressed, PIRP Irp, LIST_ENTRY* rollback) {
1094     UINT8 type;
1095 
1096     if (fcb->Vcb->options.compress_type != 0 && fcb->prop_compression == PropCompression_None)
1097         type = fcb->Vcb->options.compress_type;
1098     else {
1099         if (!(fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD) && fcb->prop_compression == PropCompression_ZSTD)
1100             type = BTRFS_COMPRESSION_ZSTD;
1101         else if (fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD && fcb->prop_compression != PropCompression_Zlib && fcb->prop_compression != PropCompression_LZO)
1102             type = BTRFS_COMPRESSION_ZSTD;
1103         else if (!(fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO) && fcb->prop_compression == PropCompression_LZO)
1104             type = BTRFS_COMPRESSION_LZO;
1105         else if (fcb->Vcb->superblock.incompat_flags & BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO && fcb->prop_compression != PropCompression_Zlib)
1106             type = BTRFS_COMPRESSION_LZO;
1107         else
1108             type = BTRFS_COMPRESSION_ZLIB;
1109     }
1110 
1111     if (type == BTRFS_COMPRESSION_ZSTD) {
1112         fcb->Vcb->superblock.incompat_flags |= BTRFS_INCOMPAT_FLAGS_COMPRESS_ZSTD;
1113         return zstd_write_compressed_bit(fcb, start_data, end_data, data, compressed, Irp, rollback);
1114     } else if (type == BTRFS_COMPRESSION_LZO) {
1115         fcb->Vcb->superblock.incompat_flags |= BTRFS_INCOMPAT_FLAGS_COMPRESS_LZO;
1116         return lzo_write_compressed_bit(fcb, start_data, end_data, data, compressed, Irp, rollback);
1117     } else
1118         return zlib_write_compressed_bit(fcb, start_data, end_data, data, compressed, Irp, rollback);
1119 }
1120 
1121 static void* zstd_malloc(void* opaque, size_t size) {
1122     UNUSED(opaque);
1123 
1124     return ExAllocatePoolWithTag(PagedPool, size, ZSTD_ALLOC_TAG);
1125 }
1126 
/* Release routine with the (opaque, address) signature of a zstd custom
 * deallocator; presumably installed through zstd_mem - confirm at its definition.
 *
 * opaque  - ignored
 * address - block to release; NULL is accepted and ignored */
static void zstd_free(void* opaque, void* address) {
    UNUSED(opaque);

    // Tolerate NULL the way free() does rather than handing it to ExFreePool.
    if (address)
        ExFreePool(address);
}
1132 
1133 NTSTATUS zstd_decompress(UINT8* inbuf, UINT32 inlen, UINT8* outbuf, UINT32 outlen) {
1134     NTSTATUS Status;
1135     ZSTD_DStream* stream;
1136     size_t init_res, read;
1137     ZSTD_inBuffer input;
1138     ZSTD_outBuffer output;
1139 
1140     stream = ZSTD_createDStream_advanced(zstd_mem);
1141 
1142     if (!stream) {
1143         ERR("ZSTD_createDStream failed.\n");
1144         return STATUS_INTERNAL_ERROR;
1145     }
1146 
1147     init_res = ZSTD_initDStream(stream);
1148 
1149     if (ZSTD_isError(init_res)) {
1150         ERR("ZSTD_initDStream failed: %s\n", ZSTD_getErrorName(init_res));
1151         Status = STATUS_INTERNAL_ERROR;
1152         goto end;
1153     }
1154 
1155     input.src = inbuf;
1156     input.size = inlen;
1157     input.pos = 0;
1158 
1159     output.dst = outbuf;
1160     output.size = outlen;
1161     output.pos = 0;
1162 
1163     read = ZSTD_decompressStream(stream, &output, &input);
1164 
1165     if (ZSTD_isError(read)) {
1166         ERR("ZSTD_decompressStream failed: %s\n", ZSTD_getErrorName(read));
1167         Status = STATUS_INTERNAL_ERROR;
1168         goto end;
1169     }
1170 
1171     Status = STATUS_SUCCESS;
1172 
1173 end:
1174     ZSTD_freeDStream(stream);
1175 
1176     return Status;
1177 }
1178