1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2022 Alexander V. Chernikov <melifaro@FreeBSD.org>
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18  * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25  * SUCH DAMAGE.
26  */
27 
28 #ifndef _NETLINK_NETLINK_MESSAGE_PARSER_H_
29 #define _NETLINK_NETLINK_MESSAGE_PARSER_H_
30 
31 #ifdef _KERNEL
32 
33 #include <sys/bitset.h>
34 
/*
 * This header is not meant to be included directly.
 */
38 
/*
 * Parsing state: scratch memory handed out sequentially by lb_alloc(),
 * starting at base and growing offset towards size.
 */
struct linear_buffer {
	char		*base;		/* Start of the backing memory */
	uint32_t	offset;		/* Bytes already handed out */
	uint32_t	size;		/* Capacity of the backing memory */
};
45 
46 static inline void *
47 lb_alloc(struct linear_buffer *lb, int len)
48 {
49 	len = roundup2(len, sizeof(uint64_t));
50 	if (lb->offset + len > lb->size)
51 		return (NULL);
52 	void *data = (void *)(lb->base + lb->offset);
53 	lb->offset += len;
54 	return (data);
55 }
56 
57 static inline void
58 lb_clear(struct linear_buffer *lb)
59 {
60 	memset(lb->base, 0, lb->size);
61 	lb->offset = 0;
62 }
63 
64 #define	NL_MAX_ERROR_BUF	128
65 #define	SCRATCH_BUFFER_SIZE	(1024 + NL_MAX_ERROR_BUF)
66 struct nl_pstate {
67         struct linear_buffer    lb;		/* Per-message scratch buffer */
68         struct nlpcb		*nlp;		/* Originator socket */
69 	struct nl_writer	*nw;		/* Message writer to use */
70 	struct nlmsghdr		*hdr;		/* Current parsed message header */
71 	uint32_t		err_off;	/* error offset from hdr start */
72         int			error;		/* last operation error */
73 	char			*err_msg;	/* Description of last error */
74 	struct nlattr		*cookie;	/* NLA to return to the userspace */
75 	bool			strict;		/* Strict parsing required */
76 };
77 
78 static inline void *
79 npt_alloc(struct nl_pstate *npt, int len)
80 {
81 	return (lb_alloc(&npt->lb, len));
82 }
83 #define npt_alloc_sockaddr(_npt, _len)  ((struct sockaddr *)(npt_alloc(_npt, _len)))
84 
/*
 * Header field decoder.  nl_parse_header() calls it with a pointer to one
 * field inside the fixed message header and a pointer to the matching
 * destination inside the target structure; a non-zero return aborts
 * header parsing with that error.
 */
typedef int parse_field_f(void *hdr, struct nl_pstate *npt,
    void *target);

/* Maps one field of the fixed header onto the target structure. */
struct nlfield_parser {
	uint16_t	off_in;		/* Field offset within the fixed header */
	uint16_t	off_out;	/* Destination offset within the target */
	parse_field_f	*cb;		/* Decoder invoked for this field */
};

/* Field parser list for headers with nothing to extract. */
static const struct nlfield_parser nlf_p_empty[] = {};

/*
 * Stock field decoders, defined elsewhere.  Declaring them through the
 * parse_field_f typedef keeps their signatures in lockstep with it.
 */
parse_field_f nlf_get_ifp;
parse_field_f nlf_get_ifpz;
parse_field_f nlf_get_u8;
parse_field_f nlf_get_u16;
parse_field_f nlf_get_u32;
parse_field_f nlf_get_u8_u32;
100 
101 
/*
 * Attribute decoder callback.  arg is the per-entry opaque argument from
 * struct nlattr_parser; target presumably points at the entry's `off`
 * within the output structure (see nl_parse_attrs_raw() to confirm).
 */
typedef int parse_attr_f(struct nlattr *attr, struct nl_pstate *npt,
    const void *arg, void *target);

/* Binds one attribute type to the decoder that fills the target. */
struct nlattr_parser {
	uint16_t	type;	/* Attribute type */
	uint16_t	off;	/* Field offset in the target structure */
	parse_attr_f	*cb;	/* Parser function to call */
	const void	*arg;	/* Opaque argument passed to cb */
};
111 
/*
 * Optional whole-header validator consulted only in strict mode; a false
 * return makes nl_parse_header() fail with EINVAL.
 */
typedef bool strict_parser_f(void *hdr, struct nl_pstate *npt);

/*
 * Full description of a message layout: the fixed header size, the header
 * field decoders and the attribute decoders.
 */
struct nlhdr_parser {
	int				nl_hdr_off;	/* Fixed header size; attributes start here */
	int				out_hdr_off;	/* Target header size */
	int				fp_size;	/* Number of entries in fp */
	int				np_size;	/* Number of entries in np */
	const struct nlfield_parser	*fp;		/* Header field parsers */
	const struct nlattr_parser	*np;		/* Attribute parsers */
	strict_parser_f			*sp;		/* Strict-mode header validator */
};
123 
/*
 * Helpers that define a static const struct nlhdr_parser.
 *   _name - identifier of the parser object to define
 *   _t    - type of the fixed header that precedes the attributes
 *   _o    - (ARR variant) type whose size is stored in out_hdr_off
 *   _sp   - (STRICT variant) strict_parser_f header validator
 *   _fp   - array of struct nlfield_parser
 *   _np   - array of struct nlattr_parser
 * Each expands to a definition; the caller supplies the trailing ';'.
 */

#define	NL_DECLARE_STRICT_PARSER(_name, _t, _sp, _fp, _np)	\
static const struct nlhdr_parser _name = {			\
	.nl_hdr_off = sizeof(_t),				\
	.sp = _sp,						\
	.fp = &((_fp)[0]),					\
	.fp_size = NL_ARRAY_LEN(_fp),				\
	.np = &((_np)[0]),					\
	.np_size = NL_ARRAY_LEN(_np),				\
}

/* NL_DECLARE_STRICT_PARSER() without a header validator. */
#define	NL_DECLARE_PARSER(_name, _t, _fp, _np)			\
	NL_DECLARE_STRICT_PARSER(_name, _t, NULL, _fp, _np)

#define	NL_DECLARE_ARR_PARSER(_name, _t, _o, _fp, _np)		\
static const struct nlhdr_parser _name = {			\
	.nl_hdr_off = sizeof(_t),				\
	.out_hdr_off = sizeof(_o),				\
	.fp = &((_fp)[0]),					\
	.fp_size = NL_ARRAY_LEN(_fp),				\
	.np = &((_np)[0]),					\
	.np_size = NL_ARRAY_LEN(_np),				\
}

/* Attribute-only parser: no fixed header and no header fields. */
#define	NL_DECLARE_ATTR_PARSER(_name, _np)			\
static const struct nlhdr_parser _name = {			\
	.np = &((_np)[0]),					\
	.np_size = NL_ARRAY_LEN(_np),				\
}
158 
159 #define	NL_ATTR_BMASK_SIZE	128
160 BITSET_DEFINE(nlattr_bmask, NL_ATTR_BMASK_SIZE);
161 
162 void nl_get_attrs_bmask_raw(struct nlattr *nla_head, int len, struct nlattr_bmask *bm);
163 bool nl_has_attr(const struct nlattr_bmask *bm, unsigned int nla_type);
164 
165 int nl_parse_attrs_raw(struct nlattr *nla_head, int len, const struct nlattr_parser *ps,
166     int pslen, struct nl_pstate *npt, void *target);
167 
/*
 * Stock attribute decoders, defined elsewhere.  Their signatures match
 * parse_attr_f, so they can serve as struct nlattr_parser callbacks.
 * Grouped by the payload kind their names suggest; see the definitions
 * for exact semantics.
 */

/* Flags and fixed-width integers. */
int nlattr_get_flag(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint16(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint32(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint64(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);

/* Addresses and interfaces. */
int nlattr_get_ip(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ipvia(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ifp(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ifpz(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);

/* Strings. */
int nlattr_get_string(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_stringn(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);

/* Raw and nested attributes. */
int nlattr_get_nla(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_nested(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
192 
bool nlmsg_report_err_msg(struct nl_pstate *npt, const char *fmt, ...);

/*
 * Record an error message in the parser state via nlmsg_report_err_msg()
 * and also log it at LOG_DEBUG for the originating socket.
 *
 * Wrapped in do { } while (0) so the macro acts as a single statement:
 * the former bare-brace body broke
 * "if (cond) NLMSG_REPORT_ERR_MSG(...); else ..." because the trailing
 * semicolon terminated the if statement.
 *
 * NOTE: _npt and the variadic arguments are evaluated twice; do not pass
 * expressions with side effects.
 */
#define	NLMSG_REPORT_ERR_MSG(_npt, _fmt, ...) do {			\
	nlmsg_report_err_msg((_npt), _fmt, ## __VA_ARGS__);		\
	NLP_LOG(LOG_DEBUG, (_npt)->nlp, _fmt, ## __VA_ARGS__);		\
} while (0)
199 
/*
 * Report the offset of the offending data, measured from the start of the
 * message header (presumably stored in npt->err_off).
 */
bool nlmsg_report_err_offset(struct nl_pstate *npt, uint32_t off);

/*
 * Attach a cookie to return to userspace: either an existing attribute or
 * a 32-bit value (presumably kept in npt->cookie).
 */
void nlmsg_report_cookie(struct nl_pstate *npt, struct nlattr *nla);
void nlmsg_report_cookie_u32(struct nl_pstate *npt, uint32_t val);
204 
205 /*
206  * Have it inline so compiler can optimize field accesses into
207  * the list of direct function calls without iteration.
208  */
209 static inline int
210 nl_parse_header(void *hdr, int len, const struct nlhdr_parser *parser,
211     struct nl_pstate *npt, void *target)
212 {
213 	int error;
214 
215 	if (__predict_false(len < parser->nl_hdr_off)) {
216 		if (npt->strict) {
217 			nlmsg_report_err_msg(npt, "header too short: expected %d, got %d",
218 			    parser->nl_hdr_off, len);
219 			return (EINVAL);
220 		}
221 
222 		/* Compat with older applications: pretend there's a full header */
223 		void *tmp_hdr = npt_alloc(npt, parser->nl_hdr_off);
224 		if (tmp_hdr == NULL)
225 			return (EINVAL);
226 		memcpy(tmp_hdr, hdr, len);
227 		hdr = tmp_hdr;
228 		len = parser->nl_hdr_off;
229 	}
230 
231 	if (npt->strict && parser->sp != NULL && !parser->sp(hdr, npt))
232 		return (EINVAL);
233 
234 	/* Extract fields first */
235 	for (int i = 0; i < parser->fp_size; i++) {
236 		const struct nlfield_parser *fp = &parser->fp[i];
237 		void *src = (char *)hdr + fp->off_in;
238 		void *dst = (char *)target + fp->off_out;
239 
240 		error = fp->cb(src, npt, dst);
241 		if (error != 0)
242 			return (error);
243 	}
244 
245 	struct nlattr *nla_head = (struct nlattr *)((char *)hdr + parser->nl_hdr_off);
246 	error = nl_parse_attrs_raw(nla_head, len - parser->nl_hdr_off, parser->np,
247 	    parser->np_size, npt, target);
248 
249 	return (error);
250 }
251 
252 static inline int
253 nl_parse_nested(struct nlattr *nla, const struct nlhdr_parser *parser,
254     struct nl_pstate *npt, void *target)
255 {
256 	struct nlattr *nla_head = (struct nlattr *)NLA_DATA(nla);
257 
258 	return (nl_parse_attrs_raw(nla_head, NLA_DATA_LEN(nla), parser->np,
259 	    parser->np_size, npt, target));
260 }
261 
/*
 * Checks that attributes are sorted by attribute type.
 *
 * For each of the count parsers, the attribute types in np[] must be
 * strictly increasing and greater than zero (asserted with MPASS).  The
 * check is compiled only with INVARIANTS; otherwise this is a no-op.
 *
 * The parameter accepts both mutable and const arrays of parser pointers.
 * The former redundant non-static redeclaration of this static inline
 * function has been dropped.
 */
static inline void
nl_verify_parsers(const struct nlhdr_parser *const *parser, int count)
{
#ifdef INVARIANTS
	for (int i = 0; i < count; i++) {
		const struct nlhdr_parser *p = parser[i];
		int attr_type = 0;
		for (int j = 0; j < p->np_size; j++) {
			MPASS(p->np[j].type > attr_type);
			attr_type = p->np[j].type;
		}
	}
#endif
}
#define	NL_VERIFY_PARSERS(_p)	nl_verify_parsers((_p), NL_ARRAY_LEN(_p))
281 
282 static inline int
283 nl_parse_nlmsg(struct nlmsghdr *hdr, const struct nlhdr_parser *parser,
284     struct nl_pstate *npt, void *target)
285 {
286 	return (nl_parse_header(hdr + 1, hdr->nlmsg_len - sizeof(*hdr), parser, npt, target));
287 }
288 
289 static inline void
290 nl_get_attrs_bmask_nlmsg(struct nlmsghdr *hdr, const struct nlhdr_parser *parser,
291     struct nlattr_bmask *bm)
292 {
293 	struct nlattr *nla_head;
294 
295 	nla_head = (struct nlattr *)((char *)(hdr + 1) + parser->nl_hdr_off);
296 	int len = hdr->nlmsg_len - sizeof(*hdr) - parser->nl_hdr_off;
297 
298 	nl_get_attrs_bmask_raw(nla_head, len, bm);
299 }
300 
301 #endif
302 #endif
303