1"""
2LFU cache Written by Shane Wang
3https://medium.com/@epicshane/a-python-implementation-of-lfu-least-frequently-used-cache-with-o-1-time-complexity-e16b34a3c49b
4https://github.com/luxigner/lfu_cache
5Modified by Sep Dehpour
6"""
7from collections import defaultdict
8from ordered_set import OrderedSet
9from threading import Lock
10from statistics import mean
11from deepdiff.helper import not_found, dict_
12
13
class CacheNode:
    """One cached entry, chained to its siblings under a frequency bucket.

    Attributes:
        key: The key this entry is stored under.
        content: Either the raw ``value`` (when no ``report_type`` was given)
            or a ``defaultdict(OrderedSet)`` mapping report types to the
            values added for them.
        freq_node: The frequency bucket that currently owns this entry.
        pre: Previous CacheNode in the bucket's list, or None.
        nxt: Next CacheNode in the bucket's list, or None.
    """

    def __init__(self, key, report_type, value, freq_node, pre, nxt):
        """Build the entry; a truthy ``report_type`` selects grouped content."""
        self.key = key
        if not report_type:
            self.content = value
        else:
            grouped = defaultdict(OrderedSet)
            grouped[report_type].add(value)
            self.content = grouped
        self.freq_node = freq_node
        self.pre = pre  # previous CacheNode
        self.nxt = nxt  # next CacheNode

    def free_myself(self):
        """Unlink this entry from its bucket's list and clear its own links.

        The owning bucket's ``cache_head`` / ``cache_tail`` are updated when
        this entry sits at either end of the list.
        """
        bucket = self.freq_node
        before, after = self.pre, self.nxt

        if bucket.cache_head is bucket.cache_tail:
            # Head and tail coincide: the bucket becomes empty.
            bucket.cache_head = bucket.cache_tail = None
        elif bucket.cache_head is self:
            after.pre = None
            bucket.cache_head = after
        elif bucket.cache_tail is self:
            before.nxt = None
            bucket.cache_tail = before
        else:
            # Interior entry: splice the two neighbours together.
            before.nxt = after
            after.pre = before

        self.pre = self.nxt = self.freq_node = None
42
43
class FreqNode:
    """A frequency bucket: one link in the chain of FreqNodes.

    Each bucket carries a ``freq`` value, links to its neighbouring buckets,
    and the head/tail of a doubly linked list of cache nodes.
    """

    def __init__(self, freq, pre, nxt):
        """Create an empty bucket for ``freq`` linked between ``pre`` and ``nxt``."""
        self.freq = freq
        self.pre = pre  # previous FreqNode
        self.nxt = nxt  # next FreqNode
        self.cache_head = None  # CacheNode head under this FreqNode
        self.cache_tail = None  # CacheNode tail under this FreqNode

    def count_caches(self):
        """Return 0 for an empty bucket, 1 for a single entry, else ``'2+'``."""
        first, last = self.cache_head, self.cache_tail
        if first is None and last is None:
            return 0
        if first is last:
            return 1
        return '2+'

    def remove(self):
        """Unlink this bucket from the chain and reset all of its pointers.

        Returns:
            A ``(previous, next)`` tuple of the former neighbours.
        """
        before, after = self.pre, self.nxt

        if before is not None:
            before.nxt = after
        if after is not None:
            after.pre = before

        self.pre = self.nxt = self.cache_head = self.cache_tail = None
        return (before, after)

    def pop_head_cache(self):
        """Detach and return the first cache node, or None if the bucket is empty.

        The popped node's own ``pre`` / ``nxt`` / ``freq_node`` attributes are
        left untouched.
        """
        popped = self.cache_head
        if popped is None and self.cache_tail is None:
            return None

        if popped is self.cache_tail:
            # Single entry: the bucket is now empty.
            self.cache_head = self.cache_tail = None
        else:
            successor = popped.nxt
            successor.pre = None
            self.cache_head = successor
        return popped

    def append_cache_to_tail(self, cache_node):
        """Take ownership of ``cache_node`` and put it at the end of the list."""
        cache_node.freq_node = self

        if self.cache_head is None and self.cache_tail is None:
            self.cache_head = self.cache_tail = cache_node
            return

        last = self.cache_tail
        cache_node.pre = last
        cache_node.nxt = None
        last.nxt = cache_node
        self.cache_tail = cache_node

    def insert_after_me(self, freq_node):
        """Link ``freq_node`` into the chain directly after this bucket."""
        following = self.nxt
        freq_node.pre = self
        freq_node.nxt = following

        if following is not None:
            following.pre = freq_node

        self.nxt = freq_node

    def insert_before_me(self, freq_node):
        """Link ``freq_node`` into the chain directly before this bucket."""
        preceding = self.pre
        if preceding is not None:
            preceding.nxt = freq_node

        freq_node.pre = preceding
        freq_node.nxt = self
        self.pre = freq_node
112
113
class LFUCache:
    """Least-frequently-used cache with a fixed capacity.

    Entries live in ``self.cache`` (``{key: CacheNode}``).  Every CacheNode is
    also linked under a FreqNode; FreqNodes form a chain in ascending ``freq``
    order starting at ``self.freq_link_head``.  New entries start at frequency
    0 and are promoted by one on every successful ``get``.  When the cache is
    full, the first entry of the head (lowest-frequency) bucket is evicted.

    ``get``, ``set`` and the statistics helpers all run under ``self.lock``.
    The lock is a non-reentrant ``threading.Lock``, so do not call these
    methods while already holding it.
    """

    def __init__(self, capacity):
        """Create an empty cache.

        Args:
            capacity: Maximum number of entries; must be positive.

        Raises:
            ValueError: If ``capacity`` is zero or negative.
        """
        if capacity <= 0:
            raise ValueError('Capacity of LFUCache needs to be positive.')  # pragma: no cover.
        self.cache = dict_()  # {key: cache_node}
        self.capacity = capacity
        self.freq_link_head = None
        self.lock = Lock()

    def get(self, key):
        """Return the content stored for ``key`` and bump its frequency.

        Returns:
            The stored content, or the ``not_found`` sentinel when ``key`` is
            not in the cache.
        """
        with self.lock:
            if key not in self.cache:
                return not_found

            cache_node = self.cache[key]
            # Read the content first; move_forward relinks the node.
            content = cache_node.content
            self.move_forward(cache_node, cache_node.freq_node)
            return content

    def set(self, key, report_type=None, value=None):
        """Store ``value`` under ``key``.

        With a truthy ``report_type`` the value is added to the set kept for
        that report type; otherwise the entry's content is replaced by
        ``value``.  Updating an existing key does not change its frequency.
        Inserting a new key into a full cache evicts one entry first.
        """
        with self.lock:
            if key in self.cache:
                cache_node = self.cache[key]
                if report_type:
                    cache_node.content[report_type].add(value)
                else:
                    cache_node.content = value
                return

            if len(self.cache) >= self.capacity:
                self.dump_cache()

            self.create_cache_node(key, report_type, value)

    def __contains__(self, key):
        """Return True when ``key`` is cached; does not affect frequencies."""
        return key in self.cache

    def move_forward(self, cache_node, freq_node):
        """Move ``cache_node`` from ``freq_node`` to the ``freq + 1`` bucket.

        The target bucket is created and linked right after ``freq_node`` when
        it does not exist yet; ``freq_node`` is removed if it ends up empty.
        Does not take the lock itself; ``get`` calls it while holding it.
        """
        successor = freq_node.nxt
        if successor is None or successor.freq != freq_node.freq + 1:
            target_freq_node = FreqNode(freq_node.freq + 1, None, None)
            target_empty = True
        else:
            target_freq_node = successor
            target_empty = False

        cache_node.free_myself()
        target_freq_node.append_cache_to_tail(cache_node)

        if target_empty:
            freq_node.insert_after_me(target_freq_node)

        if freq_node.count_caches() == 0:
            if self.freq_link_head == freq_node:
                # The old head is going away; the target is the new lowest bucket.
                self.freq_link_head = target_freq_node

            freq_node.remove()

    def dump_cache(self):
        """Evict the first entry of the head (lowest-frequency) bucket.

        Does not take the lock itself; ``set`` calls it while holding it.
        """
        head_freq_node = self.freq_link_head
        self.cache.pop(head_freq_node.cache_head.key)
        head_freq_node.pop_head_cache()

        if head_freq_node.count_caches() == 0:
            self.freq_link_head = head_freq_node.nxt
            head_freq_node.remove()

    def create_cache_node(self, key, report_type, value):
        """Create an entry for ``key`` and place it in the frequency-0 bucket.

        A new frequency-0 bucket is created and made the head when the current
        head is missing or has a different frequency.
        Does not take the lock itself; ``set`` calls it while holding it.
        """
        cache_node = CacheNode(
            key=key, report_type=report_type,
            value=value, freq_node=None, pre=None, nxt=None)
        self.cache[key] = cache_node

        if self.freq_link_head is None or self.freq_link_head.freq != 0:
            new_freq_node = FreqNode(0, None, None)
            new_freq_node.append_cache_to_tail(cache_node)

            if self.freq_link_head is not None:
                self.freq_link_head.insert_before_me(new_freq_node)

            self.freq_link_head = new_freq_node
        else:
            self.freq_link_head.append_cache_to_tail(cache_node)

    def get_sorted_cache_keys(self):
        """Return ``[(key, frequency), ...]`` sorted by descending frequency.

        Entries with equal frequency keep the cache dict's iteration order.
        """
        # Snapshot under the lock: a concurrent get() temporarily sets a
        # node's freq_node to None, and set() can resize the dict.
        with self.lock:
            result = [(key, node.freq_node.freq) for key, node in self.cache.items()]
        result.sort(key=lambda item: -item[1])
        return result

    def get_average_frequency(self):
        """Return the mean frequency of all cached entries, or 0 when empty."""
        # Snapshot under the lock for the same reason as get_sorted_cache_keys.
        with self.lock:
            frequencies = [node.freq_node.freq for node in self.cache.values()]
        if not frequencies:
            # statistics.mean raises StatisticsError on empty data.
            return 0
        return mean(frequencies)
207
208
class DummyLFU:
    """No-op stand-in offering the same methods as LFUCache but storing nothing.

    Every constructor, ``set`` and ``get`` call accepts any arguments, does
    nothing and returns None; membership tests are always False.  Note that
    ``get`` returns None, not the ``not_found`` sentinel.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments."""

    def set(self, *args, **kwargs):
        """Ignore the arguments; nothing is stored."""
        return None

    def get(self, *args, **kwargs):
        """Ignore the arguments and return None."""
        return None

    def __contains__(self, key):
        """Report every key as absent."""
        return False
219