package cidranger

import (
	"fmt"
	"net"
	"strings"

	rnet "github.com/yl2chen/cidranger/net"
)

// prefixTrie is a path-compressed (PC) trie implementation of the
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents the current CIDR
// block.
//
// For IPv4, the trie structure guarantees a max depth of 32 as IPv4 addresses
// are 32 bits long and each bit represents a prefix tree starting at that bit.
// This property also guarantees constant lookup time in Big-O notation.
//
// Path compression collapses a chain of nodes that each have only one child
// into a single node, decreasing the number of lookups necessary during
// containment tests.
//
// Level compression dictates the number of direct children of a node by
// allowing it to handle multiple bits in the path. The heuristic (based on
// children population) to decide when compression and decompression happen
// is outlined in the blog post linked above, and will be explored in more
// depth in this project in the future.
//
// Note: IPv4 and IPv6 network addresses cannot be inserted into the same
// prefix trie; use the versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	// parent is the node this trie hangs off; nil only for the root.
	parent *prefixTrie

	// children is indexed by the value of the bit at targetBitPosition() of a
	// network number. The constructors in this file allocate exactly 2 slots.
	children []*prefixTrie

	// numBitsSkipped is this node's prefix length: network is masked to this
	// many leading bits, and the bit inspected to choose a child sits at
	// position totalNumberOfBits()-numBitsSkipped-1 (see targetBitPosition).
	numBitsSkipped uint

	// numBitsHandled is set to 1 by newPrefixTree, matching the two child
	// slots. NOTE(review): not read by any code visible in this file —
	// presumably reserved for level compression (see TODO above).
	numBitsHandled uint

	// network is the CIDR block represented by this node.
	network rnet.Network

	// entry is the value stored for exactly this network, or nil when the
	// node only exists as a path/branch point.
	entry RangerEntry

	size int // This is only maintained in the root trie.
}

// newPrefixTree creates a new prefixTrie.
50func newPrefixTree(version rnet.IPVersion) Ranger { 51 _, rootNet, _ := net.ParseCIDR("0.0.0.0/0") 52 if version == rnet.IPv6 { 53 _, rootNet, _ = net.ParseCIDR("0::0/0") 54 } 55 return &prefixTrie{ 56 children: make([]*prefixTrie, 2, 2), 57 numBitsSkipped: 0, 58 numBitsHandled: 1, 59 network: rnet.NewNetwork(*rootNet), 60 } 61} 62 63func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie { 64 version := rnet.IPv4 65 if len(network.Number) == rnet.IPv6Uint32Count { 66 version = rnet.IPv6 67 } 68 path := newPrefixTree(version).(*prefixTrie) 69 path.numBitsSkipped = numBitsSkipped 70 path.network = network.Masked(int(numBitsSkipped)) 71 return path 72} 73 74func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { 75 ones, _ := network.IPNet.Mask.Size() 76 leaf := newPathprefixTrie(network, uint(ones)) 77 leaf.entry = entry 78 return leaf 79} 80 81// Insert inserts a RangerEntry into prefix trie. 82func (p *prefixTrie) Insert(entry RangerEntry) error { 83 network := entry.Network() 84 sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry) 85 if sizeIncreased { 86 p.size++ 87 } 88 return err 89} 90 91// Remove removes RangerEntry identified by given network from trie. 92func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) { 93 entry, err := p.remove(rnet.NewNetwork(network)) 94 if entry != nil { 95 p.size-- 96 } 97 return entry, err 98} 99 100// Contains returns boolean indicating whether given ip is contained in any 101// of the inserted networks. 102func (p *prefixTrie) Contains(ip net.IP) (bool, error) { 103 nn := rnet.NewNetworkNumber(ip) 104 if nn == nil { 105 return false, ErrInvalidNetworkNumberInput 106 } 107 return p.contains(nn) 108} 109 110// ContainingNetworks returns the list of RangerEntry(s) the given ip is 111// contained in in ascending prefix order. 
112func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { 113 nn := rnet.NewNetworkNumber(ip) 114 if nn == nil { 115 return nil, ErrInvalidNetworkNumberInput 116 } 117 return p.containingNetworks(nn) 118} 119 120// CoveredNetworks returns the list of RangerEntry(s) the given ipnet 121// covers. That is, the networks that are completely subsumed by the 122// specified network. 123func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { 124 net := rnet.NewNetwork(network) 125 return p.coveredNetworks(net) 126} 127 128// Len returns number of networks in ranger. 129func (p *prefixTrie) Len() int { 130 return p.size 131} 132 133// String returns string representation of trie, mainly for visualization and 134// debugging. 135func (p *prefixTrie) String() string { 136 children := []string{} 137 padding := strings.Repeat("| ", p.level()+1) 138 for bits, child := range p.children { 139 if child == nil { 140 continue 141 } 142 childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String()) 143 children = append(children, childStr) 144 } 145 return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network, 146 p.targetBitPosition(), p.hasEntry(), strings.Join(children, "")) 147} 148 149func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) { 150 if !p.network.Contains(number) { 151 return false, nil 152 } 153 if p.hasEntry() { 154 return true, nil 155 } 156 if p.targetBitPosition() < 0 { 157 return false, nil 158 } 159 bit, err := p.targetBitFromIP(number) 160 if err != nil { 161 return false, err 162 } 163 child := p.children[bit] 164 if child != nil { 165 return child.contains(number) 166 } 167 return false, nil 168} 169 170func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) { 171 results := []RangerEntry{} 172 if !p.network.Contains(number) { 173 return results, nil 174 } 175 if p.hasEntry() { 176 results = []RangerEntry{p.entry} 177 } 178 if p.targetBitPosition() < 
0 { 179 return results, nil 180 } 181 bit, err := p.targetBitFromIP(number) 182 if err != nil { 183 return nil, err 184 } 185 child := p.children[bit] 186 if child != nil { 187 ranges, err := child.containingNetworks(number) 188 if err != nil { 189 return nil, err 190 } 191 if len(ranges) > 0 { 192 if len(results) > 0 { 193 results = append(results, ranges...) 194 } else { 195 results = ranges 196 } 197 } 198 } 199 return results, nil 200} 201 202func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) { 203 var results []RangerEntry 204 if network.Covers(p.network) { 205 for entry := range p.walkDepth() { 206 results = append(results, entry) 207 } 208 } else if p.targetBitPosition() >= 0 { 209 bit, err := p.targetBitFromIP(network.Number) 210 if err != nil { 211 return results, err 212 } 213 child := p.children[bit] 214 if child != nil { 215 return child.coveredNetworks(network) 216 } 217 } 218 return results, nil 219} 220 221func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) { 222 if p.network.Equal(network) { 223 sizeIncreased := p.entry == nil 224 p.entry = entry 225 return sizeIncreased, nil 226 } 227 228 bit, err := p.targetBitFromIP(network.Number) 229 if err != nil { 230 return false, err 231 } 232 existingChild := p.children[bit] 233 234 // No existing child, insert new leaf trie. 235 if existingChild == nil { 236 p.appendTrie(bit, newEntryTrie(network, entry)) 237 return true, nil 238 } 239 240 // Check whether it is necessary to insert additional path prefix between current trie and existing child, 241 // in the case that inserted network diverges on its path to existing child. 
242 lcb, err := network.LeastCommonBitPosition(existingChild.network) 243 divergingBitPos := int(lcb) - 1 244 if divergingBitPos > existingChild.targetBitPosition() { 245 pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb) 246 err := p.insertPrefix(bit, pathPrefix, existingChild) 247 if err != nil { 248 return false, err 249 } 250 // Update new child 251 existingChild = pathPrefix 252 } 253 return existingChild.insert(network, entry) 254} 255 256func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) { 257 p.children[bit] = prefix 258 prefix.parent = p 259} 260 261func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error { 262 // Set parent/child relationship between current trie and inserted pathPrefix 263 p.children[bit] = pathPrefix 264 pathPrefix.parent = p 265 266 // Set parent/child relationship between inserted pathPrefix and original child 267 pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number) 268 if err != nil { 269 return err 270 } 271 pathPrefix.children[pathPrefixBit] = child 272 child.parent = pathPrefix 273 return nil 274} 275 276func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) { 277 if p.hasEntry() && p.network.Equal(network) { 278 entry := p.entry 279 p.entry = nil 280 281 err := p.compressPathIfPossible() 282 if err != nil { 283 return nil, err 284 } 285 return entry, nil 286 } 287 if p.targetBitPosition() < 0 { 288 return nil, nil 289 } 290 bit, err := p.targetBitFromIP(network.Number) 291 if err != nil { 292 return nil, err 293 } 294 child := p.children[bit] 295 if child != nil { 296 return child.remove(network) 297 } 298 return nil, nil 299} 300 301func (p *prefixTrie) qualifiesForPathCompression() bool { 302 // Current prefix trie can be path compressed if it meets all following. 303 // 1. records no CIDR entry 304 // 2. has single or no child 305 // 3. 
is not root trie 306 return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil 307} 308 309func (p *prefixTrie) compressPathIfPossible() error { 310 if !p.qualifiesForPathCompression() { 311 // Does not qualify to be compressed 312 return nil 313 } 314 315 // Find lone child. 316 var loneChild *prefixTrie 317 for _, child := range p.children { 318 if child != nil { 319 loneChild = child 320 break 321 } 322 } 323 324 // Find root of currnt single child lineage. 325 parent := p.parent 326 for ; parent.qualifiesForPathCompression(); parent = parent.parent { 327 } 328 parentBit, err := parent.targetBitFromIP(p.network.Number) 329 if err != nil { 330 return err 331 } 332 parent.children[parentBit] = loneChild 333 334 // Attempts to furthur apply path compression at current lineage parent, in case current lineage 335 // compressed into parent. 336 return parent.compressPathIfPossible() 337} 338 339func (p *prefixTrie) childrenCount() int { 340 count := 0 341 for _, child := range p.children { 342 if child != nil { 343 count++ 344 } 345 } 346 return count 347} 348 349func (p *prefixTrie) totalNumberOfBits() uint { 350 return rnet.BitsPerUint32 * uint(len(p.network.Number)) 351} 352 353func (p *prefixTrie) targetBitPosition() int { 354 return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1 355} 356 357func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) { 358 // This is a safe uint boxing of int since we should never attempt to get 359 // target bit at a negative position. 360 return n.Bit(uint(p.targetBitPosition())) 361} 362 363func (p *prefixTrie) hasEntry() bool { 364 return p.entry != nil 365} 366 367func (p *prefixTrie) level() int { 368 if p.parent == nil { 369 return 0 370 } 371 return p.parent.level() + 1 372} 373 374// walkDepth walks the trie in depth order, for unit testing. 
375func (p *prefixTrie) walkDepth() <-chan RangerEntry { 376 entries := make(chan RangerEntry) 377 go func() { 378 if p.hasEntry() { 379 entries <- p.entry 380 } 381 childEntriesList := []<-chan RangerEntry{} 382 for _, trie := range p.children { 383 if trie == nil { 384 continue 385 } 386 childEntriesList = append(childEntriesList, trie.walkDepth()) 387 } 388 for _, childEntries := range childEntriesList { 389 for entry := range childEntries { 390 entries <- entry 391 } 392 } 393 close(entries) 394 }() 395 return entries 396} 397