1 //--------------------------------------------------------------------------
2 // Copyright (C) 2014-2021 Cisco and/or its affiliates. All rights reserved.
3 // Copyright (C) 2005-2013 Sourcefire, Inc.
4 //
5 // This program is free software; you can redistribute it and/or modify it
6 // under the terms of the GNU General Public License Version 2 as published
7 // by the Free Software Foundation.  You may not use, modify or distribute
8 // this program under any other version of the GNU General Public License.
9 //
10 // This program is distributed in the hope that it will be useful, but
11 // WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 // General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License along
16 // with this program; if not, write to the Free Software Foundation, Inc.,
17 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
18 //--------------------------------------------------------------------------
19 
20 #ifdef HAVE_CONFIG_H
21 #include "config.h"
22 #endif
23 
24 #include "expect_cache.h"
25 
26 #include "detection/ips_context.h"
27 #include "hash/zhash.h"
28 #include "packet_io/sfdaq_instance.h"
29 #include "packet_tracer/packet_tracer.h"
30 #include "protocols/packet.h"
31 #include "protocols/vlan.h"
32 #include "pub_sub/expect_events.h"
33 #include "sfip/sf_ip.h"
34 #include "stream/stream.h"      // FIXIT-M bad dependency
35 #include "time/packet_time.h"
36 
37 using namespace snort;
38 
/* Reasonably small, and prime */
// Table size handed to ZHash in the ExpectCache constructor (negated there so
// ZHash uses it as-is).
// FIXIT-L size based on max_tcp + max_udp?
#define MAX_HASH 1021
// Maximum number of ExpectFlows queued on one ExpectNode; also multiplies the
// node count to size the preallocated ExpectFlow pool.
#define MAX_LIST    8
// NOTE(review): not referenced by any code visible in this file section.
#define MAX_DATA    4
// Lifetime added to packet_time() when an expectation is (re)armed; the result
// is compared against packet timestamps (tv_sec), so presumably seconds.
#define MAX_WAIT  300

// Per-thread list of ExpectFlows newly created while handling the current
// packet (appended in add_flow, emptied by reset_expect_flows).  Allocated by
// the first ExpectCache constructed on the thread, freed by its destructor.
static THREAD_LOCAL std::vector<ExpectFlow*>* packet_expect_flows = nullptr;
47 
~ExpectFlow()48 ExpectFlow::~ExpectFlow()
49 {
50     clear();
51 }
52 
clear()53 void ExpectFlow::clear()
54 {
55     while (data)
56     {
57         FlowData* fd = data;
58         data = data->next;
59         delete fd;
60     }
61     data = nullptr;
62 }
63 
add_flow_data(FlowData * fd)64 int ExpectFlow::add_flow_data(FlowData* fd)
65 {
66     if (data)
67     {
68         FlowData* prev_fd;
69         for (prev_fd = data; prev_fd && prev_fd->next; prev_fd = prev_fd->next);
70 
71         prev_fd->next = fd;
72     }
73     else
74         data = fd;
75     return 0;
76 }
77 
get_expect_flows()78 std::vector<ExpectFlow*>* ExpectFlow::get_expect_flows()
79 {
80     return packet_expect_flows;
81 }
82 
reset_expect_flows()83 void ExpectFlow::reset_expect_flows()
84 {
85     if(packet_expect_flows)
86         packet_expect_flows->clear();
87 }
88 
get_flow_data(unsigned id)89 FlowData* ExpectFlow::get_flow_data(unsigned id)
90 {
91     for (FlowData* p = data; p; p = p->next)
92     {
93         if (p->get_id() == id)
94             return p;
95     }
96     return nullptr;
97 }
98 
// Hash-table entry for one expected-flow key.  Holds a FIFO list of
// ExpectFlows (appended at tail by add_flow, consumed from head by
// process_expected) plus the attributes applied to the realized flow.
struct ExpectNode
{
    // Absolute expiry time; compared against packet timestamps (tv_sec /
    // packet_time()).
    time_t expires = 0;
    // Value FlowKey::init returned when the expectation was added; compared
    // with the realizing packet's key orientation for one-way directions.
    bool reversed_key = false;
    // Requested direction (compared with SSN_DIR_* values in lookups).
    int direction = 0;
    // When set, the realized flow gets flags.app_direction_swapped.
    bool swap_app_direction = false;
    // Number of ExpectFlows currently linked from head to tail.
    unsigned count = 0;
    // Protocol id given to the realized flow; process_expected treats a zero
    // value as "no protocol id" (the flow may then be ignored).
    SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID;

    ExpectFlow* head = nullptr;   // oldest pending ExpectFlow (consumed first)
    ExpectFlow* tail = nullptr;   // newest ExpectFlow (new data appended here)

    // Move every ExpectFlow on this node onto `list` (a free list) after
    // deleting its FlowData, and reset the node to empty.
    void clear(ExpectFlow*&);
};
113 
clear(ExpectFlow * & list)114 void ExpectNode::clear(ExpectFlow*& list)
115 {
116     while (head)
117     {
118         ExpectFlow* p = head;
119         head = head->next;
120         p->clear();
121         p->next = list;
122         list = p;
123     }
124     tail = nullptr;
125     count = 0;
126 }
127 
128 //-------------------------------------------------------------------------
129 // private ExpectCache methods
130 //-------------------------------------------------------------------------
131 
prune_lru()132 void ExpectCache::prune_lru()
133 {
134     ExpectNode* node = static_cast<ExpectNode*>( hash_table->lru_first() );
135     assert(node);
136     node->clear(free_list);
137     hash_table->release();
138     ++prunes;
139 }
140 
// Find the ExpectNode, if any, that packet p realizes.
//
// key is an output parameter.  On return it holds the exact FlowKey variant
// that was last looked up: the full key, or one with a zeroed port.  A caller
// such as check() -> process_expected() can therefore release that entry.
//
// Returns nullptr in any of these cases:
//   - the table is empty;
//   - no key variant matches;
//   - the matching node is empty or expired (it is then cleared and released
//     here as a side effect);
//   - a one-way node's recorded key orientation differs from this packet's.
ExpectNode* ExpectCache::find_node_by_packet(Packet* p, FlowKey &key)
{
    // Nothing is expected at all: skip building a key.
    if (!hash_table->get_num_nodes())
        return nullptr;

    const SfIp* srcIP = p->ptrs.ip_api.get_src();
    const SfIp* dstIP = p->ptrs.ip_api.get_dst();
    uint16_t vlanId = (p->proto_bits & PROTO_BIT__VLAN) ? layer::get_vlan_layer(p)->vid() : 0;
    uint32_t mplsId = (p->proto_bits & PROTO_BIT__MPLS) ? p->ptrs.mplsHdr.label : 0;
    PktType type = p->type();
    IpProtocol ip_proto = p->get_ip_proto_next();

    // The key is built destination-first.  reversed_key reports whether
    // FlowKey::init reordered the endpoints (TODO confirm exact semantics in
    // FlowKey); it decides which port field holds the source port below.
    bool reversed_key = key.init(p->context->conf, type, ip_proto, dstIP, p->ptrs.dp,
        srcIP, p->ptrs.sp, vlanId, mplsId, *p->pkth);

    /*
        Lookup order:
            1. Full match.
            2. Unknown (zeroed) source port.
            3. Unknown (zeroed) destination port.
        If the client/server addresses were reversed during key creation, the
        source port will be in port_l.
    */
    // FIXIT-P X This should be optimized to only do full matches when full keys
    //      are present, likewise for partial keys.
    ExpectNode* node = static_cast<ExpectNode*>( hash_table->get_user_data(&key) );
    if (!node)
    {
        // FIXIT-M X This logic could fail if IPs were equal because the original key
        // would always have been created with a 0 for src or dst port and put the
        // known port in port_h.
        // port1/port2 are the (port_l, port_h) pair used for lookup 3.
        uint16_t port1;
        uint16_t port2;

        if (reversed_key)
        {
            // Source port sits in port_l: zero it for lookup 2; lookup 3
            // restores it and zeroes port_h (the destination port).
            port1 = key.port_l;
            port2 = 0;
            key.port_l = 0;
        }
        else
        {
            // Source port sits in port_h: zero it for lookup 2; lookup 3
            // restores it and zeroes port_l (the destination port).
            port1 = 0;
            port2 = key.port_h;
            key.port_h = 0;
        }
        node = static_cast<ExpectNode*> ( hash_table->get_user_data(&key) );
        if (!node)
        {
            key.port_l = port1;
            key.port_h = port2;
            node = static_cast<ExpectNode*> ( hash_table->get_user_data(&key) );
            if (!node)
                return nullptr;
        }
    }
    // An empty or expired node is stale: recycle its ExpectFlows and drop the
    // entry (key still holds the variant that found it).
    if (!node->head || (p->pkth->ts.tv_sec > node->expires))
    {
        if (node->head)
            node->clear(free_list);
        hash_table->release_node(&key);
        return nullptr;
    }
    /* Make sure the packet direction is correct */
    // For one-way expectations the packet's key orientation must equal the one
    // recorded when the expectation was added.  SSN_DIR_BOTH, and any value
    // not listed (there is no default case), accepts the packet.
    switch (node->direction)
    {
        case SSN_DIR_BOTH:
            break;

        case SSN_DIR_FROM_CLIENT:
        case SSN_DIR_FROM_SERVER:
            if (node->reversed_key != reversed_key)
                return nullptr;
            break;
    }

    return node;
}
219 
process_expected(ExpectNode * node,FlowKey & key,Packet * p,Flow * lws)220 bool ExpectCache::process_expected(ExpectNode* node, FlowKey& key, Packet* p, Flow* lws)
221 {
222     ExpectFlow* head;
223     FlowData* fd;
224     bool ignoring = false;
225 
226     assert(node->count && node->head);
227 
228     /* Pull the first set of expected flow data off of the Expect node and apply it
229         in its entirety to the target flow.  Discard the set (and potentially the
230         entire node, it empty) after this is done. */
231     node->count--;
232     head = node->head;
233     node->head = head->next;
234 
235     while ((fd = head->data))
236     {
237         head->data = fd->next;
238         lws->set_flow_data(fd);
239         ++realized;
240         fd->handle_expected(p);
241     }
242     head->next = free_list;
243     free_list = head;
244 
245     /* If this is 0, we're ignoring, otherwise setting id of new session */
246     if (!node->snort_protocol_id)
247         ignoring = node->direction != 0;
248     else
249     {
250         lws->ssn_state.snort_protocol_id = node->snort_protocol_id;
251         if ( node->swap_app_direction)
252             lws->flags.app_direction_swapped = true;
253     }
254 
255     if (!node->count)
256         hash_table->release_node(&key);
257 
258     return ignoring;
259 }
260 
261 //-------------------------------------------------------------------------
262 // public ExpectCache methods
263 //-------------------------------------------------------------------------
264 
ExpectCache(uint32_t max)265 ExpectCache::ExpectCache(uint32_t max)
266 {
267     // -size forces use of abs(size) ie w/o bumping up
268     hash_table = new ZHash(-MAX_HASH, sizeof(FlowKey));
269     nodes = new ExpectNode[max];
270     for (unsigned i = 0; i < max; ++i)
271         hash_table->push(nodes + i);
272 
273     /* Preallocate a pool of ExpectFlows big enough to handle the worst case
274         requirement (max number of nodes * max flows per node) and add them all
275         to an initial free list. */
276     max *= MAX_LIST;
277     pool = new ExpectFlow[max];
278     free_list = nullptr;
279     for (unsigned i = 0; i < max; ++i)
280     {
281         ExpectFlow* p = pool + i;
282         p->data = nullptr;
283         p->next = free_list;
284         free_list = p;
285     }
286 
287     if (packet_expect_flows == nullptr)
288         packet_expect_flows = new std::vector<ExpectFlow*>;
289 }
290 
// Free everything the constructor allocated, in construction order.
// Deleting the pool runs ~ExpectFlow, which deletes any FlowData still queued
// on pending expectations.  The thread-local per-packet list is freed and
// reset so a later ExpectCache on this thread recreates it.
ExpectCache::~ExpectCache()
{
    delete hash_table;
    delete[] nodes;
    delete[] pool;
    delete packet_expect_flows;
    packet_expect_flows = nullptr;
}
299 
300 /**Either expect or expect future session.
301  *
302  * Inspectors may add sessions to be expected altogether or to be associated
303  * with some data. For example, FTP inspector may add data channel that
304  * should be expected. Alternatively, FTP inspector may add session with
305  * snort protocol ID FTP-DATA.
306  *
 * It is assumed that only one of cliPort or srvPort is known (non-zero).
 * Violating this assumption can cause hash collisions that make a session go
 * unrecognized as expected (or be matched against the wrong expectation).
 * This should occur only rarely and is therefore an acceptable design
 * optimization.
311  *
 * Also, snort_protocol_id is assumed to be consistent between different
 * inspectors.  Each session can be assigned only one snort protocol ID.
 * When a new non-zero snort_protocol_id mismatches an existing non-zero
 * snort_protocol_id, the new ID and its associated data are not stored.
316  *
317  */
add_flow(const Packet * ctrlPkt,PktType type,IpProtocol ip_proto,const SfIp * cliIP,uint16_t cliPort,const SfIp * srvIP,uint16_t srvPort,char direction,FlowData * fd,SnortProtocolId snort_protocol_id,bool swap_app_direction,bool expect_multi,bool bidirectional)318 int ExpectCache::add_flow(const Packet *ctrlPkt, PktType type, IpProtocol ip_proto,
319     const SfIp* cliIP, uint16_t cliPort, const SfIp* srvIP, uint16_t srvPort, char direction,
320     FlowData* fd, SnortProtocolId snort_protocol_id, bool swap_app_direction, bool expect_multi,
321     bool bidirectional)
322 {
323     /* Just pull the VLAN ID, MPLS ID, and Address Space ID from the
324         control packet until we have a use case for not doing so. */
325     uint16_t vlanId = (ctrlPkt->proto_bits & PROTO_BIT__VLAN) ? layer::get_vlan_layer(ctrlPkt)->vid() : 0;
326     uint32_t mplsId = (ctrlPkt->proto_bits & PROTO_BIT__MPLS) ? ctrlPkt->ptrs.mplsHdr.label : 0;
327     FlowKey key;
328 
329     bool reversed_key = key.init(ctrlPkt->context->conf, type, ip_proto, cliIP, cliPort,
330         srvIP, srvPort, vlanId, mplsId, *ctrlPkt->pkth);
331 
332     bool new_node = false;
333     ExpectNode* node = static_cast<ExpectNode*> ( hash_table->get_user_data(&key) );
334     if ( !node )
335     {
336         if ( hash_table->full() )
337             prune_lru();
338         node = static_cast<ExpectNode*> ( hash_table->get(&key) );
339         assert(node);
340         new_node = true;
341     }
342     else if ( packet_time() > node->expires )
343     {
344         // node is past its expiration date, whack it and reuse it.
345         node->clear(free_list);
346         new_node = true;
347     }
348 
349     ExpectFlow* last = nullptr;
350     if ( !new_node )
351     {
352         //  reject if the snort_protocol_id doesn't match
353         if ( node->snort_protocol_id != snort_protocol_id )
354         {
355             if ( node->snort_protocol_id && snort_protocol_id )
356                 return -1;
357             node->snort_protocol_id = snort_protocol_id;
358         }
359 
360         last = node->tail;
361         if ( last )
362         {
363             FlowData* lfd = last->data;
364             while ( lfd )
365             {
366                 if ( lfd->get_id() == fd->get_id() )
367                 {
368                     last = nullptr;
369                     break;
370                 }
371                 lfd = lfd->next;
372             }
373         }
374     }
375     else
376     {
377         node->snort_protocol_id = snort_protocol_id;
378         node->reversed_key = reversed_key;
379         node->direction = direction;
380         node->swap_app_direction = swap_app_direction;
381         node->head = node->tail = nullptr;
382         node->count = 0;
383         last = nullptr;
384         /* Only add TCP and UDP expected flows for now via the DAQ module. */
385         if ((ip_proto == IpProtocol::TCP || ip_proto == IpProtocol::UDP) && ctrlPkt->daq_instance)
386         {
387             if (PacketTracer::is_active())
388             {
389                 SfIpString sipstr;
390                 SfIpString dipstr;
391                 cliIP->ntop(sipstr, sizeof(sipstr));
392                 srvIP->ntop(dipstr, sizeof(dipstr));
393                 PacketTracer::log("Create expected channel request sent with %s -> %s %hu %hhu\n",
394                         dipstr, sipstr, srvPort, static_cast<uint8_t>(ip_proto));
395             }
396             unsigned flag = 0;
397             if (expect_multi)
398                 flag |= DAQ_EFLOW_ALLOW_MULTIPLE;
399 
400             if (bidirectional)
401                 flag |= DAQ_EFLOW_BIDIRECTIONAL;
402 
403             ctrlPkt->daq_instance->add_expected(ctrlPkt, cliIP, cliPort, srvIP, srvPort,
404                     ip_proto, 1000, flag);
405         }
406     }
407 
408     bool new_expect_flow = false;
409     if ( !last )
410     {
411         if ( node->count >= MAX_LIST )
412         {
413             // fail when maxed out
414             ++overflows;
415             return -1;
416         }
417         last = free_list;
418         free_list = free_list->next;
419 
420         if ( !node->tail )
421             node->head = last;
422         else
423             node->tail->next = last;
424 
425         node->tail = last;
426         last->next = nullptr;
427         node->count++;
428         new_expect_flow = true;
429     }
430     last->add_flow_data(fd);
431     node->expires = packet_time() + MAX_WAIT;
432     ++expects;
433     if ( new_expect_flow )
434     {
435         // chain all expected flows created by this packet
436         packet_expect_flows->emplace_back(last);
437 
438         ExpectEvent event(ctrlPkt, last, fd);
439         DataBus::publish(EXPECT_EVENT_TYPE_EARLY_SESSION_CREATE_KEY, event, ctrlPkt->flow);
440     }
441     return 0;
442 }
443 
is_expected(Packet * p)444 bool ExpectCache::is_expected(Packet* p)
445 {
446     FlowKey key;
447     return (find_node_by_packet(p, key) != nullptr);
448 }
449 
check(Packet * p,Flow * lws)450 bool ExpectCache::check(Packet* p, Flow* lws)
451 {
452     FlowKey key;
453     ExpectNode* node = find_node_by_packet(p, key);
454 
455     if (!node)
456         return false;
457 
458     return process_expected(node, key, p, lws);
459 }
460 
461