1 //--------------------------------------------------------------------------
2 // Copyright (C) 2014-2021 Cisco and/or its affiliates. All rights reserved.
3 // Copyright (C) 2005-2013 Sourcefire, Inc.
4 //
5 // This program is free software; you can redistribute it and/or modify it
6 // under the terms of the GNU General Public License Version 2 as published
7 // by the Free Software Foundation. You may not use, modify or distribute
8 // this program under any other version of the GNU General Public License.
9 //
10 // This program is distributed in the hope that it will be useful, but
11 // WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 // General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License along
16 // with this program; if not, write to the Free Software Foundation, Inc.,
17 // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
18 //--------------------------------------------------------------------------
19
20 #ifdef HAVE_CONFIG_H
21 #include "config.h"
22 #endif
23
24 #include "expect_cache.h"
25
26 #include "detection/ips_context.h"
27 #include "hash/zhash.h"
28 #include "packet_io/sfdaq_instance.h"
29 #include "packet_tracer/packet_tracer.h"
30 #include "protocols/packet.h"
31 #include "protocols/vlan.h"
32 #include "pub_sub/expect_events.h"
33 #include "sfip/sf_ip.h"
34 #include "stream/stream.h" // FIXIT-M bad dependency
35 #include "time/packet_time.h"
36
37 using namespace snort;
38
// Number of rows in the expected-session hash table. The constructor passes
// it negated to ZHash so the size is used as-is.
/* Reasonably small, and prime */
// FIXIT-L size based on max_tcp + max_udp?
#define MAX_HASH 1021
// Maximum number of ExpectFlows that may be queued on one ExpectNode
// (add_flow fails with an overflow beyond this). It is also the per-node
// multiplier used to size the preallocated ExpectFlow pool.
#define MAX_LIST 8
// NOTE(review): MAX_DATA is not referenced in this section of the file.
#define MAX_DATA 4
// Lifetime of an expectation. add_flow sets expires = packet_time() + MAX_WAIT,
// and lookups compare expires against the packet timestamp's tv_sec, so this
// is presumably in seconds.
#define MAX_WAIT 300

// THREAD_LOCAL list of the ExpectFlows newly created by add_flow() while the
// current packet is handled. It is allocated by the ExpectCache constructor
// (if absent), emptied by ExpectFlow::reset_expect_flows(), and deleted by
// the ExpectCache destructor.
static THREAD_LOCAL std::vector<ExpectFlow*>* packet_expect_flows = nullptr;
47
~ExpectFlow()48 ExpectFlow::~ExpectFlow()
49 {
50 clear();
51 }
52
clear()53 void ExpectFlow::clear()
54 {
55 while (data)
56 {
57 FlowData* fd = data;
58 data = data->next;
59 delete fd;
60 }
61 data = nullptr;
62 }
63
add_flow_data(FlowData * fd)64 int ExpectFlow::add_flow_data(FlowData* fd)
65 {
66 if (data)
67 {
68 FlowData* prev_fd;
69 for (prev_fd = data; prev_fd && prev_fd->next; prev_fd = prev_fd->next);
70
71 prev_fd->next = fd;
72 }
73 else
74 data = fd;
75 return 0;
76 }
77
get_expect_flows()78 std::vector<ExpectFlow*>* ExpectFlow::get_expect_flows()
79 {
80 return packet_expect_flows;
81 }
82
reset_expect_flows()83 void ExpectFlow::reset_expect_flows()
84 {
85 if(packet_expect_flows)
86 packet_expect_flows->clear();
87 }
88
get_flow_data(unsigned id)89 FlowData* ExpectFlow::get_flow_data(unsigned id)
90 {
91 for (FlowData* p = data; p; p = p->next)
92 {
93 if (p->get_id() == id)
94 return p;
95 }
96 return nullptr;
97 }
98
99 struct ExpectNode
100 {
101 time_t expires = 0;
102 bool reversed_key = false;
103 int direction = 0;
104 bool swap_app_direction = false;
105 unsigned count = 0;
106 SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID;
107
108 ExpectFlow* head = nullptr;
109 ExpectFlow* tail = nullptr;
110
111 void clear(ExpectFlow*&);
112 };
113
clear(ExpectFlow * & list)114 void ExpectNode::clear(ExpectFlow*& list)
115 {
116 while (head)
117 {
118 ExpectFlow* p = head;
119 head = head->next;
120 p->clear();
121 p->next = list;
122 list = p;
123 }
124 tail = nullptr;
125 count = 0;
126 }
127
128 //-------------------------------------------------------------------------
129 // private ExpectCache methods
130 //-------------------------------------------------------------------------
131
prune_lru()132 void ExpectCache::prune_lru()
133 {
134 ExpectNode* node = static_cast<ExpectNode*>( hash_table->lru_first() );
135 assert(node);
136 node->clear(free_list);
137 hash_table->release();
138 ++prunes;
139 }
140
// Look up the ExpectNode that matches packet p, if any.
//
// key is an out-parameter. On a non-null return it holds the exact key
// variant (full, or with one port zeroed) that matched, so the caller can
// later release that node (see process_expected).
//
// Returns nullptr in these cases:
//  - nothing matches;
//  - the node direction requirement is not satisfied;
//  - the matching node has no queued ExpectFlows or has expired. In this
//    case the node is also cleared and released here, so even a pure
//    is_expected() query can have this side effect.
ExpectNode* ExpectCache::find_node_by_packet(Packet* p, FlowKey &key)
{
    // Nothing has been expected yet, so skip building a key at all.
    if (!hash_table->get_num_nodes())
        return nullptr;

    const SfIp* srcIP = p->ptrs.ip_api.get_src();
    const SfIp* dstIP = p->ptrs.ip_api.get_dst();
    uint16_t vlanId = (p->proto_bits & PROTO_BIT__VLAN) ? layer::get_vlan_layer(p)->vid() : 0;
    uint32_t mplsId = (p->proto_bits & PROTO_BIT__MPLS) ? p->ptrs.mplsHdr.label : 0;
    PktType type = p->type();
    IpProtocol ip_proto = p->get_ip_proto_next();

    // NOTE: the packet's destination endpoint is passed first here, whereas
    // add_flow() passes (client, server). reversed_key reflects this ordering
    // and is compared against node->reversed_key in the direction check below.
    bool reversed_key = key.init(p->context->conf, type, ip_proto, dstIP, p->ptrs.dp,
        srcIP, p->ptrs.sp, vlanId, mplsId, *p->pkth);

    /*
    Lookup order:
    1. Full match.
    2. Unknown (zeroed) source port.
    3. Unknown (zeroed) destination port.
    If the client/server addresses were reversed during key creation, the
    source port will be in port_l.
    */
    // FIXIT-P X This should be optimized to only do full matches when full keys
    // are present, likewise for partial keys.
    ExpectNode* node = static_cast<ExpectNode*>( hash_table->get_user_data(&key) );
    if (!node)
    {
        // FIXIT-M X This logic could fail if IPs were equal because the original key
        // would always have been created with a 0 for src or dst port and put the
        // known port in port_h.
        // port1/port2 are the port_l/port_h values to use for step 3.
        uint16_t port1;
        uint16_t port2;

        if (reversed_key)
        {
            // Source port is in port_l: remember it for step 3, zero it for step 2.
            port1 = key.port_l;
            port2 = 0;
            key.port_l = 0;
        }
        else
        {
            // Source port is in port_h: remember it for step 3, zero it for step 2.
            port1 = 0;
            port2 = key.port_h;
            key.port_h = 0;
        }
        // Step 2: source port zeroed.
        node = static_cast<ExpectNode*> ( hash_table->get_user_data(&key) );
        if (!node)
        {
            // Step 3: restore the source port and zero the destination port instead.
            key.port_l = port1;
            key.port_h = port2;
            node = static_cast<ExpectNode*> ( hash_table->get_user_data(&key) );
            if (!node)
                return nullptr;
        }
    }
    // An empty or expired node is useless: give its flows back to free_list
    // and drop the hash entry (key currently holds the variant that matched).
    if (!node->head || (p->pkth->ts.tv_sec > node->expires))
    {
        if (node->head)
            node->clear(free_list);
        hash_table->release_node(&key);
        return nullptr;
    }
    /* Make sure the packet direction is correct */
    // For one-directional expectations, accept only packets whose key
    // orientation matches the orientation recorded by add_flow().
    switch (node->direction)
    {
    case SSN_DIR_BOTH:
        break;

    case SSN_DIR_FROM_CLIENT:
    case SSN_DIR_FROM_SERVER:
        if (node->reversed_key != reversed_key)
            return nullptr;
        break;
    }

    return node;
}
219
process_expected(ExpectNode * node,FlowKey & key,Packet * p,Flow * lws)220 bool ExpectCache::process_expected(ExpectNode* node, FlowKey& key, Packet* p, Flow* lws)
221 {
222 ExpectFlow* head;
223 FlowData* fd;
224 bool ignoring = false;
225
226 assert(node->count && node->head);
227
228 /* Pull the first set of expected flow data off of the Expect node and apply it
229 in its entirety to the target flow. Discard the set (and potentially the
230 entire node, it empty) after this is done. */
231 node->count--;
232 head = node->head;
233 node->head = head->next;
234
235 while ((fd = head->data))
236 {
237 head->data = fd->next;
238 lws->set_flow_data(fd);
239 ++realized;
240 fd->handle_expected(p);
241 }
242 head->next = free_list;
243 free_list = head;
244
245 /* If this is 0, we're ignoring, otherwise setting id of new session */
246 if (!node->snort_protocol_id)
247 ignoring = node->direction != 0;
248 else
249 {
250 lws->ssn_state.snort_protocol_id = node->snort_protocol_id;
251 if ( node->swap_app_direction)
252 lws->flags.app_direction_swapped = true;
253 }
254
255 if (!node->count)
256 hash_table->release_node(&key);
257
258 return ignoring;
259 }
260
261 //-------------------------------------------------------------------------
262 // public ExpectCache methods
263 //-------------------------------------------------------------------------
264
ExpectCache(uint32_t max)265 ExpectCache::ExpectCache(uint32_t max)
266 {
267 // -size forces use of abs(size) ie w/o bumping up
268 hash_table = new ZHash(-MAX_HASH, sizeof(FlowKey));
269 nodes = new ExpectNode[max];
270 for (unsigned i = 0; i < max; ++i)
271 hash_table->push(nodes + i);
272
273 /* Preallocate a pool of ExpectFlows big enough to handle the worst case
274 requirement (max number of nodes * max flows per node) and add them all
275 to an initial free list. */
276 max *= MAX_LIST;
277 pool = new ExpectFlow[max];
278 free_list = nullptr;
279 for (unsigned i = 0; i < max; ++i)
280 {
281 ExpectFlow* p = pool + i;
282 p->data = nullptr;
283 p->next = free_list;
284 free_list = p;
285 }
286
287 if (packet_expect_flows == nullptr)
288 packet_expect_flows = new std::vector<ExpectFlow*>;
289 }
290
~ExpectCache()291 ExpectCache::~ExpectCache()
292 {
293 delete hash_table;
294 delete[] nodes;
295 delete[] pool;
296 delete packet_expect_flows;
297 packet_expect_flows = nullptr;
298 }
299
300 /**Either expect or expect future session.
301 *
302 * Inspectors may add sessions to be expected altogether or to be associated
303 * with some data. For example, FTP inspector may add data channel that
304 * should be expected. Alternatively, FTP inspector may add session with
305 * snort protocol ID FTP-DATA.
306 *
307 * It is assumed that only one of cliPort or srvPort should be known (!0). This
308 * violation of this assumption will cause hash collision that will cause some
309 * session to be not expected and expected. This will occur only rarely and
310 * therefore acceptable design optimization.
311 *
312 * Also, snort_protocol_id is assumed to be consistent between different
313 * inspectors. Each session can be assigned only one snort protocol ID.
314 * When new snort_protocol_id mismatches existing snort_protocol_id, new
315 * snort_protocol_id and associated data is not stored.
316 *
317 */
add_flow(const Packet * ctrlPkt,PktType type,IpProtocol ip_proto,const SfIp * cliIP,uint16_t cliPort,const SfIp * srvIP,uint16_t srvPort,char direction,FlowData * fd,SnortProtocolId snort_protocol_id,bool swap_app_direction,bool expect_multi,bool bidirectional)318 int ExpectCache::add_flow(const Packet *ctrlPkt, PktType type, IpProtocol ip_proto,
319 const SfIp* cliIP, uint16_t cliPort, const SfIp* srvIP, uint16_t srvPort, char direction,
320 FlowData* fd, SnortProtocolId snort_protocol_id, bool swap_app_direction, bool expect_multi,
321 bool bidirectional)
322 {
323 /* Just pull the VLAN ID, MPLS ID, and Address Space ID from the
324 control packet until we have a use case for not doing so. */
325 uint16_t vlanId = (ctrlPkt->proto_bits & PROTO_BIT__VLAN) ? layer::get_vlan_layer(ctrlPkt)->vid() : 0;
326 uint32_t mplsId = (ctrlPkt->proto_bits & PROTO_BIT__MPLS) ? ctrlPkt->ptrs.mplsHdr.label : 0;
327 FlowKey key;
328
329 bool reversed_key = key.init(ctrlPkt->context->conf, type, ip_proto, cliIP, cliPort,
330 srvIP, srvPort, vlanId, mplsId, *ctrlPkt->pkth);
331
332 bool new_node = false;
333 ExpectNode* node = static_cast<ExpectNode*> ( hash_table->get_user_data(&key) );
334 if ( !node )
335 {
336 if ( hash_table->full() )
337 prune_lru();
338 node = static_cast<ExpectNode*> ( hash_table->get(&key) );
339 assert(node);
340 new_node = true;
341 }
342 else if ( packet_time() > node->expires )
343 {
344 // node is past its expiration date, whack it and reuse it.
345 node->clear(free_list);
346 new_node = true;
347 }
348
349 ExpectFlow* last = nullptr;
350 if ( !new_node )
351 {
352 // reject if the snort_protocol_id doesn't match
353 if ( node->snort_protocol_id != snort_protocol_id )
354 {
355 if ( node->snort_protocol_id && snort_protocol_id )
356 return -1;
357 node->snort_protocol_id = snort_protocol_id;
358 }
359
360 last = node->tail;
361 if ( last )
362 {
363 FlowData* lfd = last->data;
364 while ( lfd )
365 {
366 if ( lfd->get_id() == fd->get_id() )
367 {
368 last = nullptr;
369 break;
370 }
371 lfd = lfd->next;
372 }
373 }
374 }
375 else
376 {
377 node->snort_protocol_id = snort_protocol_id;
378 node->reversed_key = reversed_key;
379 node->direction = direction;
380 node->swap_app_direction = swap_app_direction;
381 node->head = node->tail = nullptr;
382 node->count = 0;
383 last = nullptr;
384 /* Only add TCP and UDP expected flows for now via the DAQ module. */
385 if ((ip_proto == IpProtocol::TCP || ip_proto == IpProtocol::UDP) && ctrlPkt->daq_instance)
386 {
387 if (PacketTracer::is_active())
388 {
389 SfIpString sipstr;
390 SfIpString dipstr;
391 cliIP->ntop(sipstr, sizeof(sipstr));
392 srvIP->ntop(dipstr, sizeof(dipstr));
393 PacketTracer::log("Create expected channel request sent with %s -> %s %hu %hhu\n",
394 dipstr, sipstr, srvPort, static_cast<uint8_t>(ip_proto));
395 }
396 unsigned flag = 0;
397 if (expect_multi)
398 flag |= DAQ_EFLOW_ALLOW_MULTIPLE;
399
400 if (bidirectional)
401 flag |= DAQ_EFLOW_BIDIRECTIONAL;
402
403 ctrlPkt->daq_instance->add_expected(ctrlPkt, cliIP, cliPort, srvIP, srvPort,
404 ip_proto, 1000, flag);
405 }
406 }
407
408 bool new_expect_flow = false;
409 if ( !last )
410 {
411 if ( node->count >= MAX_LIST )
412 {
413 // fail when maxed out
414 ++overflows;
415 return -1;
416 }
417 last = free_list;
418 free_list = free_list->next;
419
420 if ( !node->tail )
421 node->head = last;
422 else
423 node->tail->next = last;
424
425 node->tail = last;
426 last->next = nullptr;
427 node->count++;
428 new_expect_flow = true;
429 }
430 last->add_flow_data(fd);
431 node->expires = packet_time() + MAX_WAIT;
432 ++expects;
433 if ( new_expect_flow )
434 {
435 // chain all expected flows created by this packet
436 packet_expect_flows->emplace_back(last);
437
438 ExpectEvent event(ctrlPkt, last, fd);
439 DataBus::publish(EXPECT_EVENT_TYPE_EARLY_SESSION_CREATE_KEY, event, ctrlPkt->flow);
440 }
441 return 0;
442 }
443
is_expected(Packet * p)444 bool ExpectCache::is_expected(Packet* p)
445 {
446 FlowKey key;
447 return (find_node_by_packet(p, key) != nullptr);
448 }
449
check(Packet * p,Flow * lws)450 bool ExpectCache::check(Packet* p, Flow* lws)
451 {
452 FlowKey key;
453 ExpectNode* node = find_node_by_packet(p, key);
454
455 if (!node)
456 return false;
457
458 return process_expected(node, key, p, lws);
459 }
460
461