1package cidranger
2
3import (
4	"fmt"
5	"net"
6	"strings"
7
8	rnet "github.com/yl2chen/cidranger/net"
9)
10
11// prefixTrie is a path-compressed (PC) trie implementation of the
12// ranger interface inspired by this blog post:
13// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
14//
15// CIDR blocks are stored using a prefix tree structure where each node has its
16// parent as prefix, and the path from the root node represents current CIDR
17// block.
18//
19// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
20// 32 bits long and each bit represents a prefix tree starting at that bit. This
21// property also guarantees constant lookup time in Big-O notation.
22//
// Path compression compresses a string of nodes with only 1 child into a single
// node, decreasing the number of lookups necessary during containment tests.
25//
26// Level compression dictates the amount of direct children of a node by
27// allowing it to handle multiple bits in the path.  The heuristic (based on
28// children population) to decide when the compression and decompression happens
29// is outlined in the prior linked blog, and will be experimented with in more
30// depth in this project in the future.
31//
// Note: IPv4 and IPv6 network addresses cannot both be inserted into the same
// prefix trie; use the versionedRanger wrapper instead.
34//
35// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	// parent is the trie this node is attached to; it is nil until the node
	// is attached, and a nil parent is what marks the root trie. children
	// holds one slot per value of the target bit (two slots, 0 and 1); a nil
	// slot means there is no subtree for that bit value.
	parent   *prefixTrie
	children []*prefixTrie

	// numBitsSkipped is the number of leading address bits already fixed by
	// this node's network prefix; the node's target bit is the next bit after
	// them. numBitsHandled is set to 1 on construction and is not read
	// elsewhere in this file (presumably reserved for the level-compression
	// TODO above; confirm before relying on it).
	numBitsSkipped uint
	numBitsHandled uint

	// network is the CIDR prefix this node represents. entry is the
	// RangerEntry stored at exactly that prefix, or nil when the node exists
	// only as an intermediate path node.
	network rnet.Network
	entry   RangerEntry

	size int // This is only maintained in the root trie.
}
48
49// newPrefixTree creates a new prefixTrie.
50func newPrefixTree(version rnet.IPVersion) Ranger {
51	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
52	if version == rnet.IPv6 {
53		_, rootNet, _ = net.ParseCIDR("0::0/0")
54	}
55	return &prefixTrie{
56		children:       make([]*prefixTrie, 2, 2),
57		numBitsSkipped: 0,
58		numBitsHandled: 1,
59		network:        rnet.NewNetwork(*rootNet),
60	}
61}
62
63func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
64	version := rnet.IPv4
65	if len(network.Number) == rnet.IPv6Uint32Count {
66		version = rnet.IPv6
67	}
68	path := newPrefixTree(version).(*prefixTrie)
69	path.numBitsSkipped = numBitsSkipped
70	path.network = network.Masked(int(numBitsSkipped))
71	return path
72}
73
74func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
75	ones, _ := network.IPNet.Mask.Size()
76	leaf := newPathprefixTrie(network, uint(ones))
77	leaf.entry = entry
78	return leaf
79}
80
81// Insert inserts a RangerEntry into prefix trie.
82func (p *prefixTrie) Insert(entry RangerEntry) error {
83	network := entry.Network()
84	sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry)
85	if sizeIncreased {
86		p.size++
87	}
88	return err
89}
90
91// Remove removes RangerEntry identified by given network from trie.
92func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
93	entry, err := p.remove(rnet.NewNetwork(network))
94	if entry != nil {
95		p.size--
96	}
97	return entry, err
98}
99
100// Contains returns boolean indicating whether given ip is contained in any
101// of the inserted networks.
102func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
103	nn := rnet.NewNetworkNumber(ip)
104	if nn == nil {
105		return false, ErrInvalidNetworkNumberInput
106	}
107	return p.contains(nn)
108}
109
110// ContainingNetworks returns the list of RangerEntry(s) the given ip is
111// contained in in ascending prefix order.
112func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
113	nn := rnet.NewNetworkNumber(ip)
114	if nn == nil {
115		return nil, ErrInvalidNetworkNumberInput
116	}
117	return p.containingNetworks(nn)
118}
119
120// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
121// covers.  That is, the networks that are completely subsumed by the
122// specified network.
123func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
124	net := rnet.NewNetwork(network)
125	return p.coveredNetworks(net)
126}
127
128// Len returns number of networks in ranger.
129func (p *prefixTrie) Len() int {
130	return p.size
131}
132
133// String returns string representation of trie, mainly for visualization and
134// debugging.
135func (p *prefixTrie) String() string {
136	children := []string{}
137	padding := strings.Repeat("| ", p.level()+1)
138	for bits, child := range p.children {
139		if child == nil {
140			continue
141		}
142		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
143		children = append(children, childStr)
144	}
145	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
146		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
147}
148
149func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
150	if !p.network.Contains(number) {
151		return false, nil
152	}
153	if p.hasEntry() {
154		return true, nil
155	}
156	if p.targetBitPosition() < 0 {
157		return false, nil
158	}
159	bit, err := p.targetBitFromIP(number)
160	if err != nil {
161		return false, err
162	}
163	child := p.children[bit]
164	if child != nil {
165		return child.contains(number)
166	}
167	return false, nil
168}
169
170func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
171	results := []RangerEntry{}
172	if !p.network.Contains(number) {
173		return results, nil
174	}
175	if p.hasEntry() {
176		results = []RangerEntry{p.entry}
177	}
178	if p.targetBitPosition() < 0 {
179		return results, nil
180	}
181	bit, err := p.targetBitFromIP(number)
182	if err != nil {
183		return nil, err
184	}
185	child := p.children[bit]
186	if child != nil {
187		ranges, err := child.containingNetworks(number)
188		if err != nil {
189			return nil, err
190		}
191		if len(ranges) > 0 {
192			if len(results) > 0 {
193				results = append(results, ranges...)
194			} else {
195				results = ranges
196			}
197		}
198	}
199	return results, nil
200}
201
202func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
203	var results []RangerEntry
204	if network.Covers(p.network) {
205		for entry := range p.walkDepth() {
206			results = append(results, entry)
207		}
208	} else if p.targetBitPosition() >= 0 {
209		bit, err := p.targetBitFromIP(network.Number)
210		if err != nil {
211			return results, err
212		}
213		child := p.children[bit]
214		if child != nil {
215			return child.coveredNetworks(network)
216		}
217	}
218	return results, nil
219}
220
221func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
222	if p.network.Equal(network) {
223		sizeIncreased := p.entry == nil
224		p.entry = entry
225		return sizeIncreased, nil
226	}
227
228	bit, err := p.targetBitFromIP(network.Number)
229	if err != nil {
230		return false, err
231	}
232	existingChild := p.children[bit]
233
234	// No existing child, insert new leaf trie.
235	if existingChild == nil {
236		p.appendTrie(bit, newEntryTrie(network, entry))
237		return true, nil
238	}
239
240	// Check whether it is necessary to insert additional path prefix between current trie and existing child,
241	// in the case that inserted network diverges on its path to existing child.
242	lcb, err := network.LeastCommonBitPosition(existingChild.network)
243	divergingBitPos := int(lcb) - 1
244	if divergingBitPos > existingChild.targetBitPosition() {
245		pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
246		err := p.insertPrefix(bit, pathPrefix, existingChild)
247		if err != nil {
248			return false, err
249		}
250		// Update new child
251		existingChild = pathPrefix
252	}
253	return existingChild.insert(network, entry)
254}
255
256func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
257	p.children[bit] = prefix
258	prefix.parent = p
259}
260
261func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
262	// Set parent/child relationship between current trie and inserted pathPrefix
263	p.children[bit] = pathPrefix
264	pathPrefix.parent = p
265
266	// Set parent/child relationship between inserted pathPrefix and original child
267	pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
268	if err != nil {
269		return err
270	}
271	pathPrefix.children[pathPrefixBit] = child
272	child.parent = pathPrefix
273	return nil
274}
275
276func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
277	if p.hasEntry() && p.network.Equal(network) {
278		entry := p.entry
279		p.entry = nil
280
281		err := p.compressPathIfPossible()
282		if err != nil {
283			return nil, err
284		}
285		return entry, nil
286	}
287	if p.targetBitPosition() < 0 {
288		return nil, nil
289	}
290	bit, err := p.targetBitFromIP(network.Number)
291	if err != nil {
292		return nil, err
293	}
294	child := p.children[bit]
295	if child != nil {
296		return child.remove(network)
297	}
298	return nil, nil
299}
300
301func (p *prefixTrie) qualifiesForPathCompression() bool {
302	// Current prefix trie can be path compressed if it meets all following.
303	//		1. records no CIDR entry
304	//		2. has single or no child
305	//		3. is not root trie
306	return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
307}
308
309func (p *prefixTrie) compressPathIfPossible() error {
310	if !p.qualifiesForPathCompression() {
311		// Does not qualify to be compressed
312		return nil
313	}
314
315	// Find lone child.
316	var loneChild *prefixTrie
317	for _, child := range p.children {
318		if child != nil {
319			loneChild = child
320			break
321		}
322	}
323
324	// Find root of currnt single child lineage.
325	parent := p.parent
326	for ; parent.qualifiesForPathCompression(); parent = parent.parent {
327	}
328	parentBit, err := parent.targetBitFromIP(p.network.Number)
329	if err != nil {
330		return err
331	}
332	parent.children[parentBit] = loneChild
333
334	// Attempts to furthur apply path compression at current lineage parent, in case current lineage
335	// compressed into parent.
336	return parent.compressPathIfPossible()
337}
338
339func (p *prefixTrie) childrenCount() int {
340	count := 0
341	for _, child := range p.children {
342		if child != nil {
343			count++
344		}
345	}
346	return count
347}
348
349func (p *prefixTrie) totalNumberOfBits() uint {
350	return rnet.BitsPerUint32 * uint(len(p.network.Number))
351}
352
353func (p *prefixTrie) targetBitPosition() int {
354	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
355}
356
357func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
358	// This is a safe uint boxing of int since we should never attempt to get
359	// target bit at a negative position.
360	return n.Bit(uint(p.targetBitPosition()))
361}
362
363func (p *prefixTrie) hasEntry() bool {
364	return p.entry != nil
365}
366
367func (p *prefixTrie) level() int {
368	if p.parent == nil {
369		return 0
370	}
371	return p.parent.level() + 1
372}
373
374// walkDepth walks the trie in depth order, for unit testing.
375func (p *prefixTrie) walkDepth() <-chan RangerEntry {
376	entries := make(chan RangerEntry)
377	go func() {
378		if p.hasEntry() {
379			entries <- p.entry
380		}
381		childEntriesList := []<-chan RangerEntry{}
382		for _, trie := range p.children {
383			if trie == nil {
384				continue
385			}
386			childEntriesList = append(childEntriesList, trie.walkDepth())
387		}
388		for _, childEntries := range childEntriesList {
389			for entry := range childEntries {
390				entries <- entry
391			}
392		}
393		close(entries)
394	}()
395	return entries
396}
397