1 /*
2  * Copyright (c) 2012 by Farsight Security, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /* Import. */
18 
19 #include "private.h"
20 
21 /* Private declarations. */
22 
23 struct nmsg_container {
24 	Nmsg__Nmsg	*nmsg;
25 	size_t		bufsz;
26 	size_t		estsz;
27 	bool		do_sequence;
28 };
29 
30 /* Export. */
31 
32 struct nmsg_container *
nmsg_container_init(size_t bufsz)33 nmsg_container_init(size_t bufsz) {
34 	struct nmsg_container *c;
35 
36 	c = calloc(1, sizeof(*c));
37 	if (c == NULL)
38 		return (NULL);
39 
40 	c->nmsg = calloc(1, sizeof(Nmsg__Nmsg));
41 	if (c->nmsg == NULL) {
42 		free(c);
43 		return (NULL);
44 	}
45 	nmsg__nmsg__init(c->nmsg);
46 
47 	c->bufsz = bufsz;
48 	if (c->bufsz < NMSG_WBUFSZ_MIN) {
49 		nmsg_container_destroy(&c);
50 		return (NULL);
51 	}
52 	c->estsz = NMSG_HDRLSZ_V2;
53 
54 	return (c);
55 }
56 
57 void
nmsg_container_destroy(struct nmsg_container ** c)58 nmsg_container_destroy(struct nmsg_container **c) {
59 	if (*c != NULL) {
60 		nmsg__nmsg__free_unpacked((*c)->nmsg, NULL);
61 		free(*c);
62 		*c = NULL;
63 	}
64 }
65 
66 void
nmsg_container_set_sequence(struct nmsg_container * c,bool do_sequence)67 nmsg_container_set_sequence(struct nmsg_container *c, bool do_sequence) {
68 	c->do_sequence = do_sequence;
69 }
70 
71 nmsg_res
nmsg_container_add(struct nmsg_container * c,nmsg_message_t msg)72 nmsg_container_add(struct nmsg_container *c, nmsg_message_t msg) {
73 	Nmsg__NmsgPayload *np;
74 	nmsg_res res;
75 	size_t np_len;
76 	void *tmp;
77 
78 	/* ensure that msg->np is up-to-date */
79 	res = _nmsg_message_serialize(msg);
80 	if (res != nmsg_res_success)
81 		return (res);
82 	assert(msg->np != NULL);
83 
84 	/* calculate size of serialized payload */
85 	np_len = _nmsg_payload_size(msg->np);
86 
87 	/* check for overflow */
88 	if (c->estsz != NMSG_HDRLSZ_V2 && c->estsz + np_len + 32 >= c->bufsz)
89 		return (nmsg_res_container_full);
90 
91 	/* allocate payload pointer */
92 	tmp = c->nmsg->payloads;
93 	c->nmsg->payloads = realloc(c->nmsg->payloads,
94 				    ++(c->nmsg->n_payloads) * sizeof(void *));
95 	if (c->nmsg->payloads == NULL) {
96 		c->nmsg->payloads = tmp;
97 		return (nmsg_res_memfail);
98 	}
99 
100 	/* detach payload from msg object */
101 	np = msg->np;
102 	msg->np = NULL;
103 
104 	/* add payload to container */
105 	c->nmsg->payloads[c->nmsg->n_payloads - 1] = np;
106 
107 	/* update estsz */
108 	c->estsz += np_len;
109 	/* payload field tag, length */
110 	c->estsz += 1+1;
111 	c->estsz += ((np_len >= (1 << 7)) ? 1 : 0);
112 	c->estsz += ((np_len >= (1 << 14)) ? 1 : 0);
113 	c->estsz += ((np_len >= (1 << 21)) ? 1 : 0);
114 	/* crc field */
115 	c->estsz += 6;
116 	/* sequence field, sequence_id field */
117 	c->estsz += (c->do_sequence ? (6+12) : 0);
118 
119 	/* check if container may need to be fragmented */
120 	if (c->estsz > c->bufsz)
121 		return (nmsg_res_container_overfull);
122 
123 	return (nmsg_res_success);
124 }
125 
126 size_t
nmsg_container_get_num_payloads(struct nmsg_container * c)127 nmsg_container_get_num_payloads(struct nmsg_container *c) {
128 	return (c->nmsg->n_payloads);
129 }
130 
131 nmsg_res
nmsg_container_serialize(struct nmsg_container * c,uint8_t ** pbuf,size_t * buf_len,bool do_header,bool do_zlib,uint32_t sequence,uint64_t sequence_id)132 nmsg_container_serialize(struct nmsg_container *c,
133 			 uint8_t **pbuf, size_t *buf_len,
134 			 bool do_header, bool do_zlib,
135 			 uint32_t sequence, uint64_t sequence_id)
136 {
137 	static const char magic[] = NMSG_MAGIC;
138 	size_t len = 0;
139 	uint8_t flags;
140 	uint8_t *buf;
141 	uint8_t *len_wire = NULL;
142 	uint16_t version;
143 
144 	*pbuf = buf = malloc((do_zlib) ? (2 * c->estsz) : (c->estsz));
145 	if (buf == NULL)
146 		return (nmsg_res_memfail);
147 
148 	if (do_header) {
149 		/* serialize header */
150 		memcpy(buf, magic, sizeof(magic));
151 		buf += sizeof(magic);
152 		flags = (do_zlib) ? NMSG_FLAG_ZLIB : 0;
153 		version = NMSG_PROTOCOL_VERSION | (flags << 8);
154 		version = htons(version);
155 		memcpy(buf, &version, sizeof(version));
156 		buf += sizeof(version);
157 
158 		/* save location where length of serialized NMSG container will be written */
159 		len_wire = buf;
160 		buf += sizeof(uint32_t);
161 	}
162 
163 	/* calculate payload CRCs */
164 	_nmsg_payload_calc_crcs(c->nmsg);
165 
166 	if (c->do_sequence) {
167 		c->nmsg->sequence = sequence;
168 		c->nmsg->sequence_id = sequence_id;
169 		c->nmsg->has_sequence = true;
170 		c->nmsg->has_sequence_id = true;
171 	}
172 
173 	/* serialize the container */
174 	if (do_zlib == false) {
175 		len = nmsg__nmsg__pack(c->nmsg, buf);
176 	} else {
177 		nmsg_res res;
178 		nmsg_zbuf_t zbuf;
179 		size_t ulen;
180 		u_char *zb_tmp;
181 
182 		zb_tmp = malloc(c->estsz);
183 		if (zb_tmp == NULL) {
184 			free(*pbuf);
185 			return (nmsg_res_memfail);
186 		}
187 
188 		zbuf = nmsg_zbuf_deflate_init();
189 		if (zbuf == NULL) {
190 			free(zb_tmp);
191 			free(*pbuf);
192 			return (nmsg_res_memfail);
193 		}
194 
195 		ulen = nmsg__nmsg__pack(c->nmsg, zb_tmp);
196 		len = 2 * c->estsz;
197 		res = nmsg_zbuf_deflate(zbuf, ulen, zb_tmp, &len, buf);
198 		nmsg_zbuf_destroy(&zbuf);
199 		free(zb_tmp);
200 		if (res != nmsg_res_success)
201 			return (res);
202 	}
203 
204 	if (do_header) {
205 		/* write the length of the container data */
206 		store_net32(len_wire, len);
207 		*buf_len = NMSG_HDRLSZ_V2 + len;
208 	} else {
209 		*buf_len = len;
210 	}
211 
212 	_nmsg_dprintf(6, "%s: buf= %p len= %zd\n", __func__, buf, len);
213 
214 	return (nmsg_res_success);
215 }
216 
217 nmsg_res
nmsg_container_deserialize(const uint8_t * buf,size_t buf_len,nmsg_message_t ** msgarray,size_t * n_msg)218 nmsg_container_deserialize(const uint8_t *buf, size_t buf_len,
219 			   nmsg_message_t **msgarray, size_t *n_msg)
220 {
221 	Nmsg__Nmsg *nmsg;
222 	nmsg_res res;
223 	ssize_t msgsize;
224 	unsigned flags;
225 
226 	/* deserialize the NMSG header */
227 	res = _input_nmsg_deserialize_header(buf, buf_len, &msgsize, &flags);
228 	if (res != nmsg_res_success)
229 		return (res);
230 	buf += NMSG_HDRLSZ_V2;
231 	buf_len -= NMSG_HDRLSZ_V2;
232 
233 	/* the entire NMSG container must be present */
234 	if ((size_t) msgsize != buf_len)
235 		return (nmsg_res_failure);
236 
237 	/* unpack message container */
238 	res = _input_nmsg_unpack_container2(buf, buf_len, flags, &nmsg);
239 	if (res != nmsg_res_success)
240 		return (res);
241 
242 	if (nmsg != NULL) {
243 		*msgarray = malloc(nmsg->n_payloads * sizeof(void *));
244 		if (*msgarray == NULL) {
245 			nmsg__nmsg__free_unpacked(nmsg, NULL);
246 			return (nmsg_res_memfail);
247 		}
248 		*n_msg = nmsg->n_payloads;
249 
250 		for (unsigned i = 0; i < nmsg->n_payloads; i++) {
251 			Nmsg__NmsgPayload *np;
252 			nmsg_message_t msg;
253 
254 			/* detach payload */
255 			np = nmsg->payloads[i];
256 			nmsg->payloads[i] = NULL;
257 
258 			/* convert payload to message object */
259 			msg = _nmsg_message_from_payload(np);
260 			if (msg == NULL) {
261 				free(*msgarray);
262 				*msgarray = NULL;
263 				*n_msg = 0;
264 				nmsg__nmsg__free_unpacked(nmsg, NULL);
265 				return (nmsg_res_memfail);
266 			}
267 			(*msgarray)[i] = msg;
268 		}
269 		nmsg->n_payloads = 0;
270 		free(nmsg->payloads);
271 		nmsg->payloads = NULL;
272 		nmsg__nmsg__free_unpacked(nmsg, NULL);
273 	}
274 
275 	return (nmsg_res_success);
276 }
277