1 /* Copyright (c) 2020 Dovecot authors, see the included COPYING file */
2 
3 #include "lib.h"
4 
5 #ifdef HAVE_ZSTD
6 
7 #include "ostream.h"
8 #include "ostream-private.h"
9 #include "ostream-zlib.h"
10 
11 #include "zstd.h"
12 #include "zstd_errors.h"
13 #include "iostream-zstd-private.h"
14 
/* State of a zstd-compressing output stream. Compressed bytes are
   written to the parent stream (ostream.parent). */
struct zstd_ostream {
	/* Embedded generic ostream; the callbacks recover the containing
	   zstd_ostream from it with container_of(). */
	struct ostream_private ostream;

	/* zstd compression context, created in o_stream_create_zstd() and
	   freed on close. */
	ZSTD_CStream *cstream;
	/* zstd's view of outbuf: dst/size describe the buffer, pos is the
	   number of compressed bytes not yet sent to the parent. */
	ZSTD_outBuffer output;

	/* Compressed-output buffer of ZSTD_CStreamOutSize() bytes (NULL if
	   ZSTD_initCStream() failed during creation). */
	unsigned char *outbuf;

	/* TRUE once the final flush has completed with outbuf empty. */
	bool flushed:1;
	/* NOTE(review): not referenced in this file - confirm whether it is
	   still needed. */
	bool closed:1;
	/* TRUE once ZSTD_endStream() has been issued for the final flush. */
	bool finished:1;
};
27 
/* Lowest compression level accepted by o_stream_create_zstd(). Uses the
   library's ZSTD_minCLevel() when the build detected it, otherwise 1. */
int compression_get_min_level_zstd(void)
{
	int min_level;

#if HAVE_DECL_ZSTD_MINCLEVEL == 1
	min_level = ZSTD_minCLevel();
#else
	min_level = 1;
#endif
	return min_level;
}
36 
/* Compression level used when none is configured. Prefers the header's
   ZSTD_CLEVEL_DEFAULT and falls back to 3 when it is not defined. */
int compression_get_default_level_zstd(void)
{
#ifndef ZSTD_CLEVEL_DEFAULT
	/* Documentation says 3 is default */
	return 3;
#else
	return ZSTD_CLEVEL_DEFAULT;
#endif
}
46 
/* Highest compression level accepted by o_stream_create_zstd(), as
   reported by the linked zstd library. */
int compression_get_max_level_zstd(void)
{
	int max_level = ZSTD_maxCLevel();

	return max_level;
}
51 
/* Translate a zstd error code into the stream's errno and error string.
   Allocation failures are fatal (FATAL_OUTOFMEM). A fixed set of zstd
   errors maps to EINVAL; every other error maps to EIO. */
static void o_stream_zstd_write_error(struct zstd_ostream *zstream, size_t err)
{
	struct ostream *output = &zstream->ostream.ostream;
	ZSTD_ErrorCode code = zstd_version_errcode(ZSTD_getErrorCode(err));
	const char *errname = ZSTD_getErrorName(err);

	if (code == ZSTD_error_memory_allocation) {
		/* i_fatal_status() does not return */
		i_fatal_status(FATAL_OUTOFMEM, "zstd.write(%s): Out of memory",
			       o_stream_get_name(output));
	}

	switch (code) {
	case ZSTD_error_prefix_unknown:
#if HAVE_DECL_ZSTD_ERROR_PARAMETER_UNSUPPORTED == 1
	case ZSTD_error_parameter_unsupported:
#endif
	case ZSTD_error_dictionary_wrong:
	case ZSTD_error_init_missing:
		output->stream_errno = EINVAL;
		break;
	default:
		output->stream_errno = EIO;
		break;
	}

	io_stream_set_error(&zstream->ostream.iostream,
			    "zstd.write(%s): %s at %"PRIuUOFF_T,
			    o_stream_get_name(output), errname,
			    output->offset);
}
74 
o_stream_zstd_send_outbuf(struct zstd_ostream * zstream)75 static ssize_t o_stream_zstd_send_outbuf(struct zstd_ostream *zstream)
76 {
77 	ssize_t ret;
78 	/* nothing to send */
79 	if (zstream->output.pos == 0)
80 		return 1;
81 	ret = o_stream_send(zstream->ostream.parent, zstream->output.dst,
82 			    zstream->output.pos);
83 	if (ret < 0) {
84 		o_stream_copy_error_from_parent(&zstream->ostream);
85 		return -1;
86 	} else {
87 		memmove(zstream->outbuf, zstream->outbuf+ret, zstream->output.pos-ret);
88 		zstream->output.pos -= ret;
89 	}
90 	if (zstream->output.pos > 0)
91 		return 0;
92 	return 1;
93 }
94 
95 static ssize_t
o_stream_zstd_sendv(struct ostream_private * stream,const struct const_iovec * iov,unsigned int iov_count)96 o_stream_zstd_sendv(struct ostream_private *stream,
97 		    const struct const_iovec *iov, unsigned int iov_count)
98 {
99 	struct zstd_ostream *zstream =
100 		container_of(stream, struct zstd_ostream, ostream);
101 	ssize_t total = 0;
102 	size_t ret;
103 
104 	for (unsigned int i = 0; i < iov_count; i++) {
105 		/* does it actually fit there */
106 		ZSTD_inBuffer input = {
107 			.src = iov[i].iov_base,
108 			.pos = 0,
109 			.size = iov[i].iov_len
110 		};
111 		bool flush_attempted = FALSE;
112 		for (;;) {
113 			size_t prev_pos = input.pos;
114 			ret = ZSTD_compressStream(zstream->cstream, &zstream->output,
115 						  &input);
116 			if (ZSTD_isError(ret) != 0) {
117 				o_stream_zstd_write_error(zstream, ret);
118 				return -1;
119 			}
120 			size_t new_input_size = input.pos - prev_pos;
121 			if (new_input_size == 0 && flush_attempted) {
122 				/* non-blocking output buffer full */
123 				return total;
124 			}
125 			stream->ostream.offset += new_input_size;
126 			total += new_input_size;
127 			if (input.pos == input.size)
128 				break;
129 			/* output buffer full. try to flush it. */
130 			if (o_stream_zstd_send_outbuf(zstream) < 0)
131 				return -1;
132 			flush_attempted = TRUE;
133 		}
134 	}
135 	if (o_stream_zstd_send_outbuf(zstream) < 0)
136 		return -1;
137 	return total;
138 }
139 
/* Push everything zstd has buffered internally to the parent stream.

   With final=FALSE the pending input is flushed as a complete block
   (ZSTD_flushStream()). With final=TRUE the frame is ended
   (ZSTD_endStream()).

   Returns 1 when everything has been handed to the parent, 0 when the
   parent did not accept all of it (call again later), and -1 on error. */
static int o_stream_zstd_send_flush(struct zstd_ostream *zstream, bool final)
{
	size_t remaining;
	int ret;

	if (zstream->flushed) {
		i_assert(zstream->output.pos == 0);
		return 1;
	}

	if ((ret = o_stream_flush_parent_if_needed(&zstream->ostream)) <= 0)
		return ret;

	/* ZSTD_flushStream() and ZSTD_endStream() return how many bytes are
	   still held inside zstd (0 = done). A single call is not enough
	   when outbuf fills up, so empty outbuf to the parent and call again
	   until 0 is returned. Once ZSTD_endStream() has returned 0 the
	   frame is complete; it must not be called again, since that would
	   begin a new frame. */
	while (!zstream->finished) {
		if ((ret = o_stream_zstd_send_outbuf(zstream)) <= 0)
			return ret;
		remaining = final ?
			ZSTD_endStream(zstream->cstream, &zstream->output) :
			ZSTD_flushStream(zstream->cstream, &zstream->output);
		if (ZSTD_isError(remaining) != 0) {
			o_stream_zstd_write_error(zstream, remaining);
			return -1;
		}
		if (remaining == 0) {
			if (final)
				zstream->finished = TRUE;
			break;
		}
	}

	if ((ret = o_stream_zstd_send_outbuf(zstream)) <= 0)
		return ret;

	if (final)
		zstream->flushed = TRUE;
	i_assert(zstream->output.pos == 0);
	return 1;
}
178 
o_stream_zstd_flush(struct ostream_private * stream)179 static int o_stream_zstd_flush(struct ostream_private *stream)
180 {
181 	struct zstd_ostream *zstream =
182 		container_of(stream, struct zstd_ostream, ostream);
183 
184 	int ret;
185 	if ((ret = o_stream_zstd_send_flush(zstream, stream->finished)) < 0)
186 		return -1;
187 	else if (ret > 0)
188 		return o_stream_flush_parent(stream);
189 	return ret;
190 }
191 
/* close() callback: release the zstd context and outbuf, and optionally
   close the parent stream. */
static void o_stream_zstd_close(struct iostream_private *stream,
				bool close_parent)
{
	struct ostream_private *ostream_priv =
		container_of(stream, struct ostream_private, iostream);
	struct zstd_ostream *zstream =
		container_of(ostream_priv, struct zstd_ostream, ostream);
	ZSTD_CStream *cstream;

	/* the stream must have been finished, unless it already failed or
	   error handling was explicitly disabled */
	i_assert(zstream->ostream.finished ||
		 zstream->ostream.ostream.stream_errno != 0 ||
		 zstream->ostream.error_handling_disabled);

	cstream = zstream->cstream;
	zstream->cstream = NULL;
	if (cstream != NULL)
		ZSTD_freeCStream(cstream);

	i_free(zstream->outbuf);
	i_zero(&zstream->output);

	if (close_parent)
		o_stream_close(zstream->ostream.parent);
}
212 
213 struct ostream *
o_stream_create_zstd(struct ostream * output,int level)214 o_stream_create_zstd(struct ostream *output, int level)
215 {
216 	struct zstd_ostream *zstream;
217 	size_t ret;
218 
219 	i_assert(level >= compression_get_min_level_zstd() &&
220 		 level <= compression_get_max_level_zstd());
221 
222 	zstd_version_check();
223 
224 	zstream = i_new(struct zstd_ostream, 1);
225 	zstream->ostream.sendv = o_stream_zstd_sendv;
226 	zstream->ostream.flush = o_stream_zstd_flush;
227 	zstream->ostream.iostream.close = o_stream_zstd_close;
228 	zstream->cstream = ZSTD_createCStream();
229 	if (zstream->cstream == NULL)
230 		i_fatal_status(FATAL_OUTOFMEM, "zstd: Out of memory");
231 	ret = ZSTD_initCStream(zstream->cstream, level);
232 	if (ZSTD_isError(ret) != 0)
233 		o_stream_zstd_write_error(zstream, ret);
234 	else {
235 		zstream->outbuf = i_malloc(ZSTD_CStreamOutSize());
236 		zstream->output.dst = zstream->outbuf;
237 		zstream->output.size = ZSTD_CStreamOutSize();
238 	}
239 	return o_stream_create(&zstream->ostream, output,
240 			       o_stream_get_fd(output));
241 }
242 
243 #endif
244