1 /*
2  * nghttp2 - HTTP/2 C Library
3  *
4  * Copyright (c) 2015 Tatsuhiro Tsujikawa
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining
7  * a copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sublicense, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be
15  * included in all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24  */
25 #include "shrpx_router.h"
26 
27 #include <algorithm>
28 
29 #include "shrpx_config.h"
30 #include "shrpx_log.h"
31 
32 namespace shrpx {
33 
RNode()34 RNode::RNode() : s(nullptr), len(0), index(-1), wildcard_index(-1) {}
35 
RNode(const char * s,size_t len,ssize_t index,ssize_t wildcard_index)36 RNode::RNode(const char *s, size_t len, ssize_t index, ssize_t wildcard_index)
37     : s(s), len(len), index(index), wildcard_index(wildcard_index) {}
38 
Router()39 Router::Router() : balloc_(1024, 1024), root_{} {}
40 
~Router()41 Router::~Router() {}
42 
43 namespace {
find_next_node(const RNode * node,char c)44 RNode *find_next_node(const RNode *node, char c) {
45   auto itr = std::lower_bound(std::begin(node->next), std::end(node->next), c,
46                               [](const std::unique_ptr<RNode> &lhs,
47                                  const char c) { return lhs->s[0] < c; });
48   if (itr == std::end(node->next) || (*itr)->s[0] != c) {
49     return nullptr;
50   }
51 
52   return (*itr).get();
53 }
54 } // namespace
55 
56 namespace {
add_next_node(RNode * node,std::unique_ptr<RNode> new_node)57 void add_next_node(RNode *node, std::unique_ptr<RNode> new_node) {
58   auto itr = std::lower_bound(std::begin(node->next), std::end(node->next),
59                               new_node->s[0],
60                               [](const std::unique_ptr<RNode> &lhs,
61                                  const char c) { return lhs->s[0] < c; });
62   node->next.insert(itr, std::move(new_node));
63 }
64 } // namespace
65 
add_node(RNode * node,const char * pattern,size_t patlen,ssize_t index,ssize_t wildcard_index)66 void Router::add_node(RNode *node, const char *pattern, size_t patlen,
67                       ssize_t index, ssize_t wildcard_index) {
68   auto pat = make_string_ref(balloc_, StringRef{pattern, patlen});
69   auto new_node =
70       std::make_unique<RNode>(pat.c_str(), pat.size(), index, wildcard_index);
71   add_next_node(node, std::move(new_node));
72 }
73 
// Registers |pattern| in the radix tree rooted at root_ and associates it
// with |idx|.  When |wildcard| is true, |idx| is stored in the terminal
// node's wildcard_index (consulted by match() when the request path
// continues past the pattern); otherwise it is stored in the node's index.
//
// Returns |idx| if the pattern was newly registered.  If the same pattern
// has already been registered with the same kind (wildcard or not), the
// previously stored index is returned and the tree is left unchanged.
//
// NOTE(review): the first iteration reads pattern[0] even when |pattern| is
// empty; presumably callers never pass an empty pattern -- confirm.
size_t Router::add_route(const StringRef &pattern, size_t idx, bool wildcard) {
  // Exactly one of these ends up != -1, depending on |wildcard|.
  ssize_t index = -1, wildcard_index = -1;
  if (wildcard) {
    wildcard_index = idx;
  } else {
    index = idx;
  }

  // Walk down from the root; |i| counts the pattern bytes consumed so far.
  auto node = &root_;
  size_t i = 0;

  for (;;) {
    // Children are looked up by the first byte of their label.
    auto next_node = find_next_node(node, pattern[i]);
    if (next_node == nullptr) {
      // No child shares even the first byte: attach the whole remaining
      // suffix as a new child of |node|.
      add_node(node, pattern.c_str() + i, pattern.size() - i, index,
               wildcard_index);
      return idx;
    }

    node = next_node;

    // |j| becomes the length of the common prefix of the node's label and
    // the remaining pattern |s| (of length |slen|).
    auto slen = pattern.size() - i;
    auto s = pattern.c_str() + i;
    auto n = std::min(node->len, slen);
    size_t j;
    for (j = 0; j < n && node->s[j] == s[j]; ++j)
      ;
    if (j == n) {
      // The common prefix was matched
      if (slen == node->len) {
        // Complete match
        if (index != -1) {
          if (node->index != -1) {
            // Return the existing index for duplicates.
            return node->index;
          }
          node->index = index;
          return idx;
        }

        assert(wildcard_index != -1);

        // Same duplicate handling for wildcard patterns.
        if (node->wildcard_index != -1) {
          return node->wildcard_index;
        }
        node->wildcard_index = wildcard_index;
        return idx;
      }

      if (slen > node->len) {
        // We still have pattern to add
        i += j;

        continue;
      }
    }

    // Reaching here, the pattern either diverges inside the node's label
    // (j < n) or ends inside it (slen < node->len).
    if (node->len > j) {
      // node must be split into 2 nodes.  new_node is now the child
      // of node.
      auto new_node = std::make_unique<RNode>(
          &node->s[j], node->len - j, node->index, node->wildcard_index);
      // The split-off tail inherits the original node's children and
      // indexes.
      std::swap(node->next, new_node->next);

      node->len = j;
      node->index = -1;
      node->wildcard_index = -1;

      add_next_node(node, std::move(new_node));

      if (slen == j) {
        // The pattern ends exactly at the split point, so the shortened
        // node now represents it.
        node->index = index;
        node->wildcard_index = wildcard_index;
        return idx;
      }
    }

    i += j;

    // Attach the diverging remainder of the pattern as a sibling of the
    // split-off tail.
    assert(pattern.size() > i);
    add_node(node, pattern.c_str() + i, pattern.size() - i, index,
             wildcard_index);

    return idx;
  }
}
160 
161 namespace {
match_complete(size_t * offset,const RNode * node,const char * first,const char * last)162 const RNode *match_complete(size_t *offset, const RNode *node,
163                             const char *first, const char *last) {
164   *offset = 0;
165 
166   if (first == last) {
167     return node;
168   }
169 
170   auto p = first;
171 
172   for (;;) {
173     auto next_node = find_next_node(node, *p);
174     if (next_node == nullptr) {
175       return nullptr;
176     }
177 
178     node = next_node;
179 
180     auto n = std::min(node->len, static_cast<size_t>(last - p));
181     if (memcmp(node->s, p, n) != 0) {
182       return nullptr;
183     }
184     p += n;
185     if (p == last) {
186       *offset = n;
187       return node;
188     }
189   }
190 }
191 } // namespace
192 
namespace {
// Matches the request path [first, last) against the subtree below |node|,
// where |node| is the node in which the host match ended and |offset| is the
// number of bytes of |node|'s label already consumed by the host.
//
// Returns the best matching node, or nullptr when nothing matches.  On
// return *pattern_is_wildcard tells the caller whether to use the node's
// wildcard_index (true) or its index (false).
//
// Rules implemented below:
//  - an empty path returns |node| itself if the host consumed its label fully;
//  - a path ending exactly at a node with index != -1 returns that node;
//  - a pattern ending in '/' also matches the path without that trailing
//    '/' (e.g. pattern "/foo/" matches path "/foo");
//  - otherwise the deepest node passed along the way that has a
//    wildcard_index, or has an index and a label ending in '/', is returned
//    as a prefix match.
const RNode *match_partial(bool *pattern_is_wildcard, const RNode *node,
                           size_t offset, const char *first, const char *last) {
  *pattern_is_wildcard = false;

  if (first == last) {
    if (node->len == offset) {
      return node;
    }
    return nullptr;
  }

  auto p = first;

  // Best prefix match seen so far; returned when no better match is found.
  const RNode *found_node = nullptr;

  if (offset > 0) {
    // The host ended inside or at the end of |node|'s label: match the rest
    // of that label against the start of the path.
    auto n = std::min(node->len - offset, static_cast<size_t>(last - first));
    if (memcmp(node->s + offset, first, n) != 0) {
      return nullptr;
    }

    p += n;

    if (p == last) {
      if (node->len == offset + n) {
        if (node->index != -1) {
          return node;
        }

        // The last '/' handling, see below.
        node = find_next_node(node, '/');
        if (node != nullptr && node->index != -1 && node->len == 1) {
          return node;
        }

        return nullptr;
      }

      // The last '/' handling, see below.
      if (node->index != -1 && offset + n + 1 == node->len &&
          node->s[node->len - 1] == '/') {
        return node;
      }

      return nullptr;
    }

    // Path continues past this node: remember it as a prefix match
    // candidate.
    if (node->wildcard_index != -1) {
      found_node = node;
      *pattern_is_wildcard = true;
    } else if (node->index != -1 && node->s[node->len - 1] == '/') {
      found_node = node;
      *pattern_is_wildcard = false;
    }

    assert(node->len == offset + n);
  }

  for (;;) {
    auto next_node = find_next_node(node, *p);
    if (next_node == nullptr) {
      return found_node;
    }

    node = next_node;

    auto n = std::min(node->len, static_cast<size_t>(last - p));
    if (memcmp(node->s, p, n) != 0) {
      return found_node;
    }

    p += n;

    if (p == last) {
      if (node->len == n) {
        // Complete match with this node
        if (node->index != -1) {
          *pattern_is_wildcard = false;
          return node;
        }

        // The last '/' handling, see below.
        node = find_next_node(node, '/');
        if (node != nullptr && node->index != -1 && node->len == 1) {
          *pattern_is_wildcard = false;
          return node;
        }

        return found_node;
      }

      // We allow match without trailing "/" at the end of pattern.
      // So, if pattern ends with '/', and pattern and path matches
      // without that slash, we consider they match to deal with
      // request to the directory without trailing slash.  That is if
      // pattern is "/foo/" and path is "/foo", we consider they
      // match.
      if (node->index != -1 && n + 1 == node->len && node->s[n] == '/') {
        *pattern_is_wildcard = false;
        return node;
      }

      return found_node;
    }

    // Path continues past this node; a deeper candidate overrides any
    // shallower one recorded earlier.
    if (node->wildcard_index != -1) {
      found_node = node;
      *pattern_is_wildcard = true;
    } else if (node->index != -1 && node->s[node->len - 1] == '/') {
      // This is the case when pattern which ends with "/" is included
      // in query.
      found_node = node;
      *pattern_is_wildcard = false;
    }

    assert(node->len == n);
  }
}
} // namespace
313 
match(const StringRef & host,const StringRef & path) const314 ssize_t Router::match(const StringRef &host, const StringRef &path) const {
315   const RNode *node;
316   size_t offset;
317 
318   node = match_complete(&offset, &root_, std::begin(host), std::end(host));
319   if (node == nullptr) {
320     return -1;
321   }
322 
323   bool pattern_is_wildcard;
324   node = match_partial(&pattern_is_wildcard, node, offset, std::begin(path),
325                        std::end(path));
326   if (node == nullptr || node == &root_) {
327     return -1;
328   }
329 
330   return pattern_is_wildcard ? node->wildcard_index : node->index;
331 }
332 
match(const StringRef & s) const333 ssize_t Router::match(const StringRef &s) const {
334   const RNode *node;
335   size_t offset;
336 
337   node = match_complete(&offset, &root_, std::begin(s), std::end(s));
338   if (node == nullptr) {
339     return -1;
340   }
341 
342   if (node->len != offset) {
343     return -1;
344   }
345 
346   return node->index;
347 }
348 
349 namespace {
match_prefix(size_t * nread,const RNode * node,const char * first,const char * last)350 const RNode *match_prefix(size_t *nread, const RNode *node, const char *first,
351                           const char *last) {
352   if (first == last) {
353     return nullptr;
354   }
355 
356   auto p = first;
357 
358   for (;;) {
359     auto next_node = find_next_node(node, *p);
360     if (next_node == nullptr) {
361       return nullptr;
362     }
363 
364     node = next_node;
365 
366     auto n = std::min(node->len, static_cast<size_t>(last - p));
367     if (memcmp(node->s, p, n) != 0) {
368       return nullptr;
369     }
370 
371     p += n;
372 
373     if (p != last) {
374       if (node->index != -1) {
375         *nread = p - first;
376         return node;
377       }
378       continue;
379     }
380 
381     if (node->len == n) {
382       *nread = p - first;
383       return node;
384     }
385 
386     return nullptr;
387   }
388 }
389 } // namespace
390 
match_prefix(size_t * nread,const RNode ** last_node,const StringRef & s) const391 ssize_t Router::match_prefix(size_t *nread, const RNode **last_node,
392                              const StringRef &s) const {
393   if (*last_node == nullptr) {
394     *last_node = &root_;
395   }
396 
397   auto node =
398       ::shrpx::match_prefix(nread, *last_node, std::begin(s), std::end(s));
399   if (node == nullptr) {
400     return -1;
401   }
402 
403   *last_node = node;
404 
405   return node->index;
406 }
407 
408 namespace {
dump_node(const RNode * node,int depth)409 void dump_node(const RNode *node, int depth) {
410   fprintf(stderr, "%*ss='%.*s', len=%zu, index=%zd\n", depth, "",
411           (int)node->len, node->s, node->len, node->index);
412   for (auto &nd : node->next) {
413     dump_node(nd.get(), depth + 4);
414   }
415 }
416 } // namespace
417 
dump() const418 void Router::dump() const { dump_node(&root_, 0); }
419 
420 } // namespace shrpx
421