1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sync_test
6
7import (
8	"sync"
9	"sync/atomic"
10)
11
12// This file contains reference map implementations for unit-tests.
13
// mapInterface is the interface Map implements.
//
// Keys and values are untyped; the reference implementations in this file
// store them in a built-in map keyed by interface{}.
type mapInterface interface {
	// Load returns the value stored for key and reports whether an entry
	// was present.
	Load(key interface{}) (value interface{}, ok bool)

	// Store sets the value for key, replacing any previous entry.
	Store(key, value interface{})

	// LoadOrStore returns the existing value for key if there is one;
	// otherwise it stores value and returns it. loaded reports whether the
	// returned value was already present.
	LoadOrStore(key, value interface{}) (actual interface{}, loaded bool)

	// LoadAndDelete removes the entry for key, returning the value it held
	// and whether there was an entry to remove.
	LoadAndDelete(key interface{}) (value interface{}, loaded bool)

	// Delete removes the entry for key, if any.
	Delete(key interface{})

	// Range calls f once per entry and stops as soon as f returns false.
	Range(f func(key, value interface{}) (shouldContinue bool))
}
23
// RWMutexMap is an implementation of mapInterface using a sync.RWMutex.
//
// The zero value is an empty map ready for use: the backing map is allocated
// lazily by the first Store or LoadOrStore. Methods that index the map with a
// caller-supplied key release the mutex with defer, so that a panic raised by
// the built-in map (for example when key has an unhashable dynamic type such
// as a slice) does not leave the lock held.
type RWMutexMap struct {
	mu    sync.RWMutex
	dirty map[interface{}]interface{} // guarded by mu; nil until the first insertion
}

// Load returns the value stored under key and reports whether an entry was
// present.
func (m *RWMutexMap) Load(key interface{}) (value interface{}, ok bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	// Reading from a nil map is safe and simply reports a miss.
	value, ok = m.dirty[key]
	return value, ok
}

// Store sets the value for key, replacing any previous entry.
func (m *RWMutexMap) Store(key, value interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.dirty == nil {
		m.dirty = make(map[interface{}]interface{})
	}
	m.dirty[key] = value
}

// LoadOrStore returns the existing value for key if present. Otherwise it
// stores value and returns it. loaded is true if the value was already
// present and false if it was stored by this call.
func (m *RWMutexMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if actual, loaded = m.dirty[key]; loaded {
		return actual, true
	}
	if m.dirty == nil {
		m.dirty = make(map[interface{}]interface{})
	}
	m.dirty[key] = value
	return value, false
}

// LoadAndDelete removes the entry for key and returns the value it held.
// loaded reports whether the key was present; when it was not, the returned
// value is nil.
func (m *RWMutexMap) LoadAndDelete(key interface{}) (value interface{}, loaded bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	value, loaded = m.dirty[key]
	if !loaded {
		return nil, false
	}
	delete(m.dirty, key)
	return value, true
}

// Delete removes the entry for key, if any.
func (m *RWMutexMap) Delete(key interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.dirty, key) // deleting from a nil map is a no-op
}

// Range calls f for each key captured in a snapshot taken at the start of the
// call, stopping early if f returns false.
//
// f runs without the mutex held. Each value is re-read through Load
// immediately before f is called, so keys removed after the snapshot are
// skipped, and keys added after the snapshot are not visited.
func (m *RWMutexMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
	// Only the key set is copied under the read lock; nothing in this section
	// touches caller-supplied keys, so a plain unlock is sufficient here.
	m.mu.RLock()
	keys := make([]interface{}, 0, len(m.dirty))
	for k := range m.dirty {
		keys = append(keys, k)
	}
	m.mu.RUnlock()

	for _, k := range keys {
		v, ok := m.Load(k)
		if !ok {
			continue
		}
		if !f(k, v) {
			break
		}
	}
}
96
// DeepCopyMap is an implementation of mapInterface using a Mutex and
// atomic.Value. It makes deep copies of the map on every write to avoid
// acquiring the Mutex in Load.
//
// Writers serialize on mu, build a private copy of the current map, mutate
// the copy, and then publish it through clean. A published map is never
// modified again, which is what lets Load and Range read it without locking.
// Writers release mu with defer so that a panic raised by the built-in map
// (for example when key has an unhashable dynamic type such as a slice) does
// not leave the mutex held.
type DeepCopyMap struct {
	mu    sync.Mutex   // serializes writers
	clean atomic.Value // holds a map[interface{}]interface{}; empty until the first write
}

// Load returns the value stored under key and reports whether an entry was
// present. It reads the most recently published map and never takes mu.
func (m *DeepCopyMap) Load(key interface{}) (value interface{}, ok bool) {
	// Before the first write the type assertion fails and clean is a nil map,
	// which is safe to read from.
	clean, _ := m.clean.Load().(map[interface{}]interface{})
	value, ok = clean[key]
	return value, ok
}

// Store sets the value for key, replacing any previous entry.
func (m *DeepCopyMap) Store(key, value interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	dirty := m.dirty()
	dirty[key] = value
	m.clean.Store(dirty)
}

// LoadOrStore returns the existing value for key if present. Otherwise it
// stores value and returns it. loaded is true if the value was already
// present and false if it was stored by this call.
func (m *DeepCopyMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
	// Fast path: a hit in the published map needs no lock.
	clean, _ := m.clean.Load().(map[interface{}]interface{})
	if actual, loaded = clean[key]; loaded {
		return actual, true
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	// Reload clean in case it changed while we were waiting on m.mu.
	clean, _ = m.clean.Load().(map[interface{}]interface{})
	if actual, loaded = clean[key]; loaded {
		return actual, true
	}
	dirty := m.dirty()
	dirty[key] = value
	m.clean.Store(dirty)
	return value, false
}

// LoadAndDelete removes the entry for key and returns the value it held.
// loaded reports whether the key was present. A new copy of the map is
// published whether or not the key was present.
func (m *DeepCopyMap) LoadAndDelete(key interface{}) (value interface{}, loaded bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	dirty := m.dirty()
	value, loaded = dirty[key]
	delete(dirty, key)
	m.clean.Store(dirty)
	return value, loaded
}

// Delete removes the entry for key, if any. A new copy of the map is
// published whether or not the key was present.
func (m *DeepCopyMap) Delete(key interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	dirty := m.dirty()
	delete(dirty, key)
	m.clean.Store(dirty)
}

// Range calls f for each entry of the map that was published when Range
// started, stopping early if f returns false. It does not take mu; writes
// made while Range is running go to a new copy and are not observed.
func (m *DeepCopyMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
	clean, _ := m.clean.Load().(map[interface{}]interface{})
	for k, v := range clean {
		if !f(k, v) {
			break
		}
	}
}

// dirty returns a freshly allocated copy of the currently published map,
// sized with room for one additional entry. The copy is private to the caller
// until it is stored back into m.clean. Every caller in this file invokes it
// with m.mu held.
func (m *DeepCopyMap) dirty() map[interface{}]interface{} {
	clean, _ := m.clean.Load().(map[interface{}]interface{})
	dirty := make(map[interface{}]interface{}, len(clean)+1)
	for k, v := range clean {
		dirty[k] = v
	}
	return dirty
}
175