1// Copyright 2016 VMware, Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package dns
16
17import (
18	"fmt"
19	"sync"
20	"sync/atomic"
21	"time"
22
23	mdns "github.com/miekg/dns"
24)
25
26// Item represents an item in the cache
27type Item struct {
28	Expiration time.Time
29	Msg        *mdns.Msg
30}
31
// CacheOptions holds the tunables used to build a Cache.
//
// NOTE(review): both fields are unexported, so only code inside package dns
// can set them.
type CacheOptions struct {
	// capacity bounds the number of entries; when the cache is full, Add
	// evicts an arbitrary existing entry to make room for a new one.
	capacity int
	// ttl is how long an entry remains fresh after it is added.
	ttl time.Duration
}
39
40// Cache stores dns.Msgs and their expiration time
41type Cache struct {
42	CacheOptions
43
44	// Protects following map
45	sync.RWMutex
46	m map[string]*Item
47
48	// atomic cache hits & misses counters
49	// ^ cause we update them while holding the read lock
50	hits   uint64
51	misses uint64
52}
53
54// NewCache returns a new cache
55func NewCache(options CacheOptions) *Cache {
56	return &Cache{
57		CacheOptions: options,
58		m:            make(map[string]*Item, options.capacity),
59	}
60}
61
62// Capacity returns the capacity of the cache
63func (c *Cache) Capacity() int {
64	return c.capacity
65}
66
67// Count returns the element count of the cache
68func (c *Cache) Count() int {
69	c.RLock()
70	defer c.RUnlock()
71	return len(c.m)
72}
73
74func generateKey(q mdns.Question) string {
75	return fmt.Sprintf("%s:%s", q.Name, mdns.TypeToString[q.Qtype])
76}
77
78// Add adds dns.Msg to the cache
79func (c *Cache) Add(msg *mdns.Msg) {
80	c.Lock()
81	defer c.Unlock()
82
83	if len(c.m) >= c.capacity {
84		// pick a random key and remove it
85		for k := range c.m {
86			delete(c.m, k)
87			break
88		}
89	}
90
91	key := generateKey(msg.Question[0])
92	if _, ok := c.m[key]; !ok {
93		c.m[key] = &Item{
94			Expiration: time.Now().UTC().Add(c.ttl),
95			Msg:        msg.Copy(),
96		}
97	}
98}
99
100// Remove removes the dns.Msg from the cache
101func (c *Cache) Remove(msg *mdns.Msg) {
102	c.Lock()
103	defer c.Unlock()
104
105	if len(c.m) <= 0 {
106		return
107	}
108
109	key := generateKey(msg.Question[0])
110	delete(c.m, key)
111}
112
113// Get returns the dns.Msg from the cache
114func (c *Cache) Get(msg *mdns.Msg) *mdns.Msg {
115	key := generateKey(msg.Question[0])
116
117	c.RLock()
118	e, ok := c.m[key]
119	c.RUnlock()
120
121	if ok {
122		atomic.AddUint64(&c.hits, 1)
123
124		if time.Since(e.Expiration) < 0 {
125			return e.Msg.Copy()
126		}
127		// Expired msg, remove it from the cache
128		c.Remove(msg)
129	} else {
130		atomic.AddUint64(&c.misses, 1)
131	}
132	return nil
133}
134
135// Hits returns the number of cache hits
136func (c *Cache) Hits() uint64 {
137	return atomic.LoadUint64(&c.hits)
138}
139
140// Misses returns the number of cache misses
141func (c *Cache) Misses() uint64 {
142	return atomic.LoadUint64(&c.misses)
143}
144
145// Reset resets the cache
146func (c *Cache) Reset() {
147	c.Lock()
148	defer c.Unlock()
149
150	// drop the old map for GC and reset counters
151	c.m = make(map[string]*Item, c.capacity)
152	atomic.StoreUint64(&c.hits, 0)
153	atomic.StoreUint64(&c.misses, 0)
154
155}
156