1 /*
2  * Authored by Alex Hultman, 2018-2020.
3  * Intellectual property of third-party.
4 
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8 
9  *     http://www.apache.org/licenses/LICENSE-2.0
10 
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 /* This Server Name Indication hostname tree is written in C++ but could be ported to C.
19  * Overall it looks like crap, but has no memory allocations in fast path and is O(log n). */
20 
21 #ifndef SNI_TREE_H
22 #define SNI_TREE_H
23 
24 #ifndef LIBUS_NO_SSL
25 
26 #include <map>
27 #include <memory>
28 #include <string_view>
29 #include <cstring>
30 #include <cstdlib>
31 #include <algorithm>
32 
/* We only handle a maximum of 10 labels per hostname */
#define MAX_LABELS 10

/* Destructor for user data, published by sni_free for the duration of a tree teardown.
 * Node destructors take no arguments, so this global is how they reach the callback.
 * It is thread_local so that teardowns running on different threads do not clobber each other. */
thread_local void (*sni_free_cb)(void *) = nullptr;

/* One node per hostname label. The labels of a hostname are inserted left to right,
 * so "example.com" is stored as root -> "example" -> "com". */
struct sni_node {
    /* Data stored for the hostname ending at this node. Empty nodes must always hold null */
    void *user = nullptr;
    /* Children keyed by label. The key string_views point at malloc'd copies owned by this node */
    std::map<std::string_view, std::unique_ptr<sni_node>> children;

    sni_node() = default;

    /* A copy would free the same malloc'd keys and user data twice */
    sni_node(const sni_node &) = delete;
    sni_node &operator=(const sni_node &) = delete;

    ~sni_node() {
        /* Release our own data. This also covers the root, which holds the user of the empty
         * hostname. Nodes erased by removeUser always hold null, so they never reach the callback.
         * A null callback means there is nothing we can call; skip instead of crashing */
        if (user && sni_free_cb) {
            sni_free_cb(user);
        }

        /* The data of our keys are managed by malloc. The child nodes themselves are destroyed
         * (recursively running this destructor) when the map member is destroyed after this body;
         * destroying a string_view key never touches its data */
        for (auto &child : children) {
            free((void *) child.first.data());
        }
    }
};
57 
58 // this can only delete ONE single node, but may cull "empty nodes with null as data"
removeUser(struct sni_node * root,unsigned int label,std::string_view * labels,unsigned int numLabels)59 void *removeUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) {
60 
61     /* If we are in the bottom (past bottom by one), there is nothing to remove */
62     if (label == numLabels) {
63         void *user = root->user;
64         /* Mark us for culling on the way up */
65         root->user = nullptr;
66         return user;
67     }
68 
69     /* Is this label a child of root? */
70     auto it = root->children.find(labels[label]);
71     if (it == root->children.end()) {
72         /* We cannot continue */
73         return nullptr;
74     }
75 
76     void *removedUser = removeUser(it->second.get(), label + 1, labels, numLabels);
77 
78     /* On the way back up, we may cull empty nodes with no children.
79      * This ends up being where we remove all nodes */
80     if (it->second.get()->children.empty() && it->second.get()->user == nullptr) {
81 
82         /* The data of our string_views are managed by malloc */
83         free((void *) it->first.data());
84 
85         /* This can only happen with user set to null, otherwise we use sni_free_cb which is unset by sni_remove */
86         root->children.erase(it);
87     }
88 
89     return removedUser;
90 }
91 
getUser(struct sni_node * root,unsigned int label,std::string_view * labels,unsigned int numLabels)92 void *getUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) {
93 
94     /* Do we have labels to match? Otherwise, return where we stand */
95     if (label == numLabels) {
96         return root->user;
97     }
98 
99     /* Try and match by our label */
100     auto it = root->children.find(labels[label]);
101     if (it != root->children.end()) {
102         void *user = getUser(it->second.get(), label + 1, labels, numLabels);
103         if (user) {
104             return user;
105         }
106     }
107 
108     /* Try and match by wildcard */
109     it = root->children.find("*");
110     if (it == root->children.end()) {
111         /* Matching has failed for both label and wildcard */
112         return nullptr;
113     }
114 
115     /* We matched by wildcard */
116     return getUser(it->second.get(), label + 1, labels, numLabels);
117 }
118 
119 extern "C" {
120 
sni_new()121     void *sni_new() {
122         return new sni_node;
123     }
124 
sni_free(void * sni,void (* cb)(void *))125     void sni_free(void *sni, void (*cb)(void *)) {
126         /* We want to run this callback for every remaining name */
127         sni_free_cb = cb;
128 
129         delete (sni_node *) sni;
130     }
131 
132     /* Returns non-null if this name already exists */
sni_add(void * sni,const char * hostname,void * user)133     int sni_add(void *sni, const char *hostname, void *user) {
134         struct sni_node *root = (struct sni_node *) sni;
135 
136         /* Traverse all labels in hostname */
137         for (std::string_view view(hostname, strlen(hostname)), label;
138             view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
139             /* Label is the token separated by dot */
140             label = view.substr(0, view.find('.', 0));
141 
142             auto it = root->children.find(label);
143             if (it == root->children.end()) {
144                 /* Duplicate this label for our kept string_view of it */
145                 void *labelString = malloc(label.length());
146                 memcpy(labelString, label.data(), label.length());
147 
148                 it = root->children.emplace(std::string_view((char *) labelString, label.length()),
149                                             std::make_unique<sni_node>()).first;
150             }
151 
152             root = it->second.get();
153         }
154 
155         /* We must never add multiple contexts for the same name, as that would overwrite and leak */
156         if (root->user) {
157             return 1;
158         }
159 
160         root->user = user;
161 
162         return 0;
163     }
164 
165     /* Removes the exact match. Wildcards are treated as the verbatim asterisk char, not as an actual wildcard */
sni_remove(void * sni,const char * hostname)166     void *sni_remove(void *sni, const char *hostname) {
167         struct sni_node *root = (struct sni_node *) sni;
168 
169         /* I guess 10 labels is an okay limit */
170         std::string_view labels[10];
171         unsigned int numLabels = 0;
172 
173         /* We traverse all labels first of all */
174         for (std::string_view view(hostname, strlen(hostname)), label;
175             view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
176             /* Label is the token separated by dot */
177             label = view.substr(0, view.find('.', 0));
178 
179             /* Anything longer than 10 labels is forbidden */
180             if (numLabels == 10) {
181                 return nullptr;
182             }
183 
184             labels[numLabels++] = label;
185         }
186 
187         return removeUser(root, 0, labels, numLabels);
188     }
189 
sni_find(void * sni,const char * hostname)190     void *sni_find(void *sni, const char *hostname) {
191         struct sni_node *root = (struct sni_node *) sni;
192 
193         /* I guess 10 labels is an okay limit */
194         std::string_view labels[10];
195         unsigned int numLabels = 0;
196 
197         /* We traverse all labels first of all */
198         for (std::string_view view(hostname, strlen(hostname)), label;
199             view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
200             /* Label is the token separated by dot */
201             label = view.substr(0, view.find('.', 0));
202 
203             /* Anything longer than 10 labels is forbidden */
204             if (numLabels == 10) {
205                 return nullptr;
206             }
207 
208             labels[numLabels++] = label;
209         }
210 
211         return getUser(root, 0, labels, numLabels);
212     }
213 
214 }
215 
216 #endif
217 
218 #endif