1 /*-
2  * Copyright (c) 2012 Alexander Nasonov.
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in
13  *    the documentation and/or other materials provided with the
14  *    distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
20  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
21  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
22  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
26  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <bpfjit.h>
31 
32 #include <stdint.h>
33 #include <string.h>
34 
35 #include "util.h"
36 #include "tests.h"
37 
38 static void
test_st1(void)39 test_st1(void)
40 {
41 	static struct bpf_insn insns[] = {
42 		BPF_STMT(BPF_LD+BPF_W+BPF_LEN, 0),
43 		BPF_STMT(BPF_ST, 0),
44 		BPF_STMT(BPF_LD+BPF_MEM, 0),
45 		BPF_STMT(BPF_RET+BPF_A, 0)
46 	};
47 
48 	size_t i;
49 	bpfjit_func_t code;
50 	uint8_t pkt[16]; /* the program doesn't read any data */
51 
52 	size_t insn_count = sizeof(insns) / sizeof(insns[0]);
53 
54 	CHECK(bpf_validate(insns, insn_count));
55 
56 	code = bpfjit_generate_code(NULL, insns, insn_count);
57 	REQUIRE(code != NULL);
58 
59 	for (i = 1; i <= sizeof(pkt); i++)
60 		CHECK(jitcall(code, pkt, i, sizeof(pkt)) == i);
61 
62 	bpfjit_free_code(code);
63 }
64 
65 static void
test_st2(void)66 test_st2(void)
67 {
68 	static struct bpf_insn insns[] = {
69 		BPF_STMT(BPF_LD+BPF_W+BPF_LEN, 0),
70 		BPF_STMT(BPF_ST, BPF_MEMWORDS-1),
71 		BPF_STMT(BPF_LD+BPF_MEM, 0),
72 		BPF_STMT(BPF_RET+BPF_A, 0)
73 	};
74 
75 	bpfjit_func_t code;
76 	uint8_t pkt[1]; /* the program doesn't read any data */
77 
78 	size_t insn_count = sizeof(insns) / sizeof(insns[0]);
79 
80 	CHECK(bpf_validate(insns, insn_count));
81 
82 	code = bpfjit_generate_code(NULL, insns, insn_count);
83 	REQUIRE(code != NULL);
84 
85 	CHECK(jitcall(code, pkt, 1, 1) == 0);
86 
87 	bpfjit_free_code(code);
88 }
89 
90 static void
test_st3(void)91 test_st3(void)
92 {
93 	static struct bpf_insn insns[] = {
94 		BPF_STMT(BPF_LD+BPF_W+BPF_LEN, 0),
95 		BPF_STMT(BPF_ST, 0),
96 		BPF_STMT(BPF_ALU+BPF_ADD+BPF_K, 100),
97 		BPF_STMT(BPF_ST, BPF_MEMWORDS-1),
98 		BPF_STMT(BPF_ALU+BPF_ADD+BPF_K, 200),
99 		BPF_JUMP(BPF_JMP+BPF_JEQ+BPF_K, 301, 2, 0),
100 		BPF_STMT(BPF_LD+BPF_MEM, BPF_MEMWORDS-1),
101 		BPF_STMT(BPF_RET+BPF_A, 0),
102 		BPF_STMT(BPF_LD+BPF_MEM, 0),
103 		BPF_STMT(BPF_RET+BPF_A, 0)
104 	};
105 
106 	bpfjit_func_t code;
107 	uint8_t pkt[2]; /* the program doesn't read any data */
108 
109 	size_t insn_count = sizeof(insns) / sizeof(insns[0]);
110 
111 	REQUIRE(BPF_MEMWORDS > 1);
112 
113 	CHECK(bpf_validate(insns, insn_count));
114 
115 	code = bpfjit_generate_code(NULL, insns, insn_count);
116 	REQUIRE(code != NULL);
117 
118 	CHECK(jitcall(code, pkt, 1, 1) == 1);
119 	CHECK(jitcall(code, pkt, 2, 2) == 102);
120 
121 	bpfjit_free_code(code);
122 }
123 
124 static void
test_st4(void)125 test_st4(void)
126 {
127 	static struct bpf_insn insns[] = {
128 		BPF_STMT(BPF_LD+BPF_W+BPF_LEN, 0),
129 		BPF_STMT(BPF_ST, 5),
130 		BPF_STMT(BPF_ALU+BPF_ADD+BPF_K, 100),
131 		BPF_STMT(BPF_ST, BPF_MEMWORDS-1),
132 		BPF_STMT(BPF_ALU+BPF_ADD+BPF_K, 200),
133 		BPF_JUMP(BPF_JMP+BPF_JEQ+BPF_K, 301, 2, 0),
134 		BPF_STMT(BPF_LD+BPF_MEM, BPF_MEMWORDS-1),
135 		BPF_STMT(BPF_RET+BPF_A, 0),
136 		BPF_STMT(BPF_LD+BPF_MEM, 5),
137 		BPF_STMT(BPF_RET+BPF_A, 0)
138 	};
139 
140 	bpfjit_func_t code;
141 	uint8_t pkt[2]; /* the program doesn't read any data */
142 
143 	size_t insn_count = sizeof(insns) / sizeof(insns[0]);
144 
145 	REQUIRE(BPF_MEMWORDS > 6);
146 
147 	CHECK(bpf_validate(insns, insn_count));
148 
149 	code = bpfjit_generate_code(NULL, insns, insn_count);
150 	REQUIRE(code != NULL);
151 
152 	CHECK(jitcall(code, pkt, 1, 1) == 1);
153 	CHECK(jitcall(code, pkt, 2, 2) == 102);
154 
155 	bpfjit_free_code(code);
156 }
157 
158 
159 static void
test_st5(void)160 test_st5(void)
161 {
162 	struct bpf_insn insns[5*BPF_MEMWORDS+2];
163 	size_t insn_count = sizeof(insns) / sizeof(insns[0]);
164 
165 	size_t k;
166 	bpfjit_func_t code;
167 	uint8_t pkt[BPF_MEMWORDS]; /* the program doesn't read any data */
168 
169 	memset(insns, 0, sizeof(insns));
170 
171 	/* for each k do M[k] = k */
172 	for (k = 0; k < BPF_MEMWORDS; k++) {
173 		insns[2*k].code   = BPF_LD+BPF_IMM;
174 		insns[2*k].k      = 3*k;
175 		insns[2*k+1].code = BPF_ST;
176 		insns[2*k+1].k    = k;
177 	}
178 
179 	/* load wirelen into A */
180 	insns[2*BPF_MEMWORDS].code = BPF_LD+BPF_W+BPF_LEN;
181 
182 	/* for each k, if (A == k + 1) return M[k] */
183 	for (k = 0; k < BPF_MEMWORDS; k++) {
184 		insns[2*BPF_MEMWORDS+3*k+1].code = BPF_JMP+BPF_JEQ+BPF_K;
185 		insns[2*BPF_MEMWORDS+3*k+1].k    = k+1;
186 		insns[2*BPF_MEMWORDS+3*k+1].jt   = 0;
187 		insns[2*BPF_MEMWORDS+3*k+1].jf   = 2;
188 		insns[2*BPF_MEMWORDS+3*k+2].code = BPF_LD+BPF_MEM;
189 		insns[2*BPF_MEMWORDS+3*k+2].k    = k;
190 		insns[2*BPF_MEMWORDS+3*k+3].code = BPF_RET+BPF_A;
191 		insns[2*BPF_MEMWORDS+3*k+3].k    = 0;
192 	}
193 
194 	insns[5*BPF_MEMWORDS+1].code = BPF_RET+BPF_K;
195 	insns[5*BPF_MEMWORDS+1].k    = UINT32_MAX;
196 
197 	CHECK(bpf_validate(insns, insn_count));
198 
199 	code = bpfjit_generate_code(NULL, insns, insn_count);
200 	REQUIRE(code != NULL);
201 
202 	for (k = 1; k <= sizeof(pkt); k++)
203 		CHECK(jitcall(code, pkt, k, k) == 3*(k-1));
204 
205 	bpfjit_free_code(code);
206 }
207 
208 void
test_st(void)209 test_st(void)
210 {
211 
212 	test_st1();
213 	test_st2();
214 	test_st3();
215 	test_st4();
216 	test_st5();
217 }
218