1 //--------------------------------------------------------------------------
2 // Copyright (C) 2017-2021 Cisco and/or its affiliates. All rights reserved.
3 //
4 // This program is free software; you can redistribute it and/or modify it
5 // under the terms of the GNU General Public License Version 2 as published
6 // by the Free Software Foundation.  You may not use, modify or distribute
7 // this program under any other version of the GNU General Public License.
8 //
9 // This program is distributed in the hope that it will be useful, but
10 // WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 // General Public License for more details.
13 //
14 // You should have received a copy of the GNU General Public License along
15 // with this program; if not, write to the Free Software Foundation, Inc.,
16 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
17 //--------------------------------------------------------------------------
18 
19 // base64_encoder.cc author Russ Combs <rucombs@cisco.com>
20 
21 // this is based on the excellent work by devolve found at
22 // https://sourceforge.net/projects/libb64/.
23 
24 #ifdef HAVE_CONFIG_H
25 #include "config.h"
26 #endif
27 
28 #include "base64_encoder.h"
29 
30 #include <cassert>
31 
32 using namespace snort;
33 
// Translate a 6-bit value (0..63) into its character in the standard
// base64 alphabet: A-Z, a-z, 0-9, '+', '/'.
static inline char b64(uint8_t idx)
{
    static const char alphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // callers are expected to pass only the low six bits of a value
    assert(idx < 64);

    const char out = alphabet[idx];
    return out;
}
41 
encode(const uint8_t * plain_text,unsigned length,char * buf)42 unsigned Base64Encoder::encode(
43     const uint8_t* plain_text, unsigned length, char* buf)
44 {
45     const uint8_t* data = plain_text;
46     const uint8_t* const data_end = plain_text + length;
47     char* p = buf;
48 
49     switch (step)
50     {
51         while (true)
52         {
53             uint8_t fragment;
54     case step_A:
55             if (data == data_end)
56             {
57                 step = step_A;
58                 return p - buf;
59             }
60             fragment = *data++;
61             state = (fragment & 0x0fc) >> 2;
62             *p++ = b64(state);
63             state = (fragment & 0x003) << 4;
64             // fallthrough
65     case step_B:
66             if (data == data_end)
67             {
68                 step = step_B;
69                 return p - buf;
70             }
71             fragment = *data++;
72             state |= (fragment & 0x0f0) >> 4;
73             *p++ = b64(state);
74             state = (fragment & 0x00f) << 2;
75             // fallthrough
76     case step_C:
77             if (data == data_end)
78             {
79                 step = step_C;
80                 return p - buf;
81             }
82             fragment = *data++;
83             state |= (fragment & 0x0c0) >> 6;
84             *p++ = b64(state);
85             state  = (fragment & 0x03f) >> 0;
86             *p++ = b64(state);
87         }
88     }
89     /* control should not reach here */
90     assert(false);
91     return p - buf;
92 }
93 
finish(char * buf)94 unsigned Base64Encoder::finish(char* buf)
95 {
96     char* p = buf;
97 
98     switch (step)
99     {
100     case step_B:
101         *p++ = b64(state);
102         *p++ = '=';
103         *p++ = '=';
104         break;
105     case step_C:
106         *p++ = b64(state);
107         *p++ = '=';
108         break;
109     case step_A:
110         break;
111     }
112     return p - buf;
113 }
114 
115 //--------------------------------------------------------------------------
116 // unit tests
// expected encodings were generated with: echo <text> | base64
// echo appends a \n to the input, so each test string ends with \n.
119 //--------------------------------------------------------------------------
120 
121 #ifdef CATCH_TEST_BUILD
122 
123 #include <cstring>
124 
125 #include "catch/catch.hpp"
126 
127 TEST_CASE("b64 decode", "[Base64Encoder]")
128 {
129     Base64Encoder b64e;
130 
131     const char* text = "The quick brown segment jumped over the lazy dogs.\n";
132     const char* code = "VGhlIHF1aWNrIGJyb3duIHNlZ21lbnQganVtcGVkIG92ZXIgdGhlIGxhenkgZG9ncy4K";
133 
134     char buf[256];
135 
136     SECTION("no decode")
137     {
138         CHECK(!b64e.finish(buf));
139     }
140     SECTION("null data")
141     {
142         CHECK(!b64e.encode(nullptr, 0, buf));
143         CHECK(!b64e.finish(buf));
144     }
145     SECTION("zero length data")
146     {
147         CHECK(!b64e.encode((const uint8_t*)"ignore", 0, buf));
148         CHECK(!b64e.finish(buf));
149     }
150     SECTION("finish states")
151     {
152         const char* txt[] = { "test0\n",  "test01\n",     "test012\n" };
153         const char* exp[] = { "dGVzdDAK", "dGVzdDAxCg==", "dGVzdDAxMgo=" };
154 
155         const unsigned to_do = sizeof(txt)/sizeof(txt[0]);
156 
157         for ( unsigned i = 0; i < to_do; ++i )
158         {
159             unsigned n = b64e.encode((const uint8_t*)txt[i], strlen(txt[i]), buf);
160             n += b64e.finish(buf+n);
161 
162             REQUIRE(n < sizeof(buf));
163             buf[n] = 0;
164 
165             CHECK(!strcmp(buf, exp[i]));
166             b64e.reset();
167         }
168     }
169     SECTION("one shot")
170     {
171         unsigned n = b64e.encode((const uint8_t*)text, strlen(text), buf);
172         n += b64e.finish(buf+n);
173 
174         REQUIRE(n < sizeof(buf));
175         buf[n] = 0;
176 
177         CHECK(!strcmp(buf, code));
178     }
179     SECTION("slice and dice")
180     {
181         unsigned len = strlen(text);
182 
183         for ( unsigned chunk = 1; chunk < len; ++chunk )
184         {
185             memset(buf, 0, sizeof(buf));
186             unsigned offset = 0;
187             unsigned n = 0;
188 
189             while ( offset < len )
190             {
191                 unsigned k = (offset + chunk > len) ? len - offset : chunk;
192                 n += b64e.encode((const uint8_t*)text+offset, k, buf+n);
193                 offset += k;
194             }
195             n += b64e.finish(buf+n);
196 
197             REQUIRE(n < sizeof(buf));
198             buf[n] = 0;
199 
200             CHECK(!strcmp(buf, code));
201             b64e.reset();
202         }
203     }
204 }
205 
206 #endif
207 
208