1 /* Copyright (c) 2008, 2015, Oracle and/or its affiliates. All rights reserved.
2
3 This program is free software; you can redistribute it and/or modify
4 it under the terms of the GNU General Public License, version 2.0,
5 as published by the Free Software Foundation.
6
7 This program is also distributed with certain software (including
8 but not limited to OpenSSL) that is licensed under separate terms,
9 as designated in a particular file or component or in included license
10 documentation. The authors of MySQL hereby grant you an additional
11 permission to link the program and your derivative works with the
12 separately licensed software that they have included with MySQL.
13
14 This program is distributed in the hope that it will be useful,
15 but WITHOUT ANY WARRANTY; without even the implied warranty of
16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 GNU General Public License, version 2.0, for more details.
18
19 You should have received a copy of the GNU General Public License
20 along with this program; if not, write to the Free Software
21 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA */
22
23 #include "SafeMutex.hpp"
24
25 #define JAM_FILE_ID 265
26
27
28 int
create()29 SafeMutex::create()
30 {
31 int ret;
32 if (m_initdone)
33 return err(ErrState, __LINE__);
34 ret = native_mutex_init(&m_mutex, 0);
35 if (ret != 0)
36 return err(ret, __LINE__);
37 ret = native_cond_init(&m_cond);
38 if (ret != 0)
39 return err(ret, __LINE__);
40 m_initdone = true;
41 return 0;
42 }
43
44 int
destroy()45 SafeMutex::destroy()
46 {
47 int ret;
48 if (!m_initdone)
49 return err(ErrState, __LINE__);
50 ret = native_cond_destroy(&m_cond);
51 if (ret != 0)
52 return err(ret, __LINE__);
53 ret = native_mutex_destroy(&m_mutex);
54 if (ret != 0)
55 return err(ret, __LINE__);
56 m_initdone = false;
57 return 0;
58 }
59
60 int
lock()61 SafeMutex::lock()
62 {
63 int ret;
64 if (m_simple) {
65 ret = native_mutex_lock(&m_mutex);
66 if (ret != 0)
67 return err(ret, __LINE__);
68 return 0;
69 }
70 ret = native_mutex_lock(&m_mutex);
71 if (ret != 0)
72 return err(ret, __LINE__);
73 return lock_impl();
74 }
75
76 int
lock_impl()77 SafeMutex::lock_impl()
78 {
79 int ret;
80 my_thread_t self = my_thread_self();
81 assert(self != 0);
82 while (1) {
83 if (m_level == 0) {
84 assert(m_owner == 0);
85 m_owner = self;
86 } else if (m_owner != self) {
87 ret = native_cond_wait(&m_cond, &m_mutex);
88 if (ret != 0)
89 return err(ret, __LINE__);
90 continue;
91 }
92 if (!(m_level < m_limit))
93 return err(ErrLevel, __LINE__);
94 m_level++;
95 if (m_usage < m_level)
96 m_usage = m_level;
97 ret = native_cond_signal(&m_cond);
98 if (ret != 0)
99 return err(ret, __LINE__);
100 ret = native_mutex_unlock(&m_mutex);
101 if (ret != 0)
102 return err(ret, __LINE__);
103 break;
104 }
105 return 0;
106 }
107
108 int
unlock()109 SafeMutex::unlock()
110 {
111 int ret;
112 if (m_simple) {
113 ret = native_mutex_unlock(&m_mutex);
114 if (ret != 0)
115 return err(ret, __LINE__);
116 return 0;
117 }
118 ret = native_mutex_lock(&m_mutex);
119 if (ret != 0)
120 return err(ret, __LINE__);
121 return unlock_impl();
122 }
123
124 int
unlock_impl()125 SafeMutex::unlock_impl()
126 {
127 int ret;
128 my_thread_t self = my_thread_self();
129 assert(self != 0);
130 if (m_owner != self)
131 return err(ErrOwner, __LINE__);
132 if (m_level == 0)
133 return err(ErrNolock, __LINE__);
134 m_level--;
135 if (m_level == 0) {
136 m_owner = 0;
137 ret = native_cond_signal(&m_cond);
138 if (ret != 0)
139 return err(ret, __LINE__);
140 }
141 ret = native_mutex_unlock(&m_mutex);
142 if (ret != 0)
143 return err(ret, __LINE__);
144 return 0;
145 }
146
147 int
err(int errcode,int errline)148 SafeMutex::err(int errcode, int errline)
149 {
150 assert(errcode != 0);
151 m_errcode = errcode;
152 m_errline = errline;
153 ndbout << *this << endl;
154 #ifdef UNIT_TEST
155 abort();
156 #endif
157 return errcode;
158 }
159
160 NdbOut&
operator <<(NdbOut & out,const SafeMutex & sm)161 operator<<(NdbOut& out, const SafeMutex& sm)
162 {
163 out << sm.m_name << ":";
164 out << " level=" << sm.m_level;
165 out << " usage=" << sm.m_usage;
166 if (sm.m_errcode != 0) {
167 out << " errcode=" << sm.m_errcode;
168 out << " errline=" << sm.m_errline;
169 }
170 return out;
171 }
172
173 #ifdef UNIT_TEST
174
175 struct sm_thr {
176 SafeMutex* sm_ptr;
177 uint index;
178 uint loops;
179 uint limit;
180 pthread_t id;
sm_thrsm_thr181 sm_thr() : sm_ptr(0), index(0), loops(0), limit(0), id(0) {}
~sm_thrsm_thr182 ~sm_thr() {}
183 };
184
185 extern "C" { static void* sm_run(void* arg); }
186
187 static void*
sm_run(void * arg)188 sm_run(void* arg)
189 {
190 sm_thr& thr = *(sm_thr*)arg;
191 assert(thr.sm_ptr != 0);
192 SafeMutex& sm = *thr.sm_ptr;
193 uint level = 0;
194 int dir = 0;
195 uint i;
196 for (i = 0; i < thr.loops; i++) {
197 int op = 0;
198 uint sel = uint(random()) % 10;
199 if (level == 0) {
200 dir = +1;
201 op = +1;
202 } else if (level == thr.limit) {
203 dir = -1;
204 op = -1;
205 } else if (dir == +1) {
206 op = sel != 0 ? +1 : -1;
207 } else if (dir == -1) {
208 op = sel != 0 ? -1 : +1;
209 } else {
210 assert(false);
211 }
212 if (op == +1) {
213 assert(level < thr.limit);
214 //ndbout << thr.index << ": lock" << endl;
215 int ret = sm.lock();
216 assert(ret == 0);
217 level++;
218 } else if (op == -1) {
219 //ndbout << thr.index << ": unlock" << endl;
220 int ret = sm.unlock();
221 assert(ret == 0);
222 assert(level != 0);
223 level--;
224 } else {
225 assert(false);
226 }
227 }
228 while (level > 0) {
229 int ret = sm.unlock();
230 assert(ret == 0);
231 level--;
232 }
233 return 0;
234 }
235
236 int
main(int argc,char ** argv)237 main(int argc, char** argv)
238 {
239 const uint max_thr = 128;
240 struct sm_thr thr[max_thr];
241
242 // threads - loops - max level - debug
243 uint num_thr = argc > 1 ? atoi(argv[1]) : 4;
244 assert(num_thr != 0 && num_thr <= max_thr);
245 uint loops = argc > 2 ? atoi(argv[2]) : 1000000;
246 uint limit = argc > 3 ? atoi(argv[3]) : 10;
247 assert(limit != 0);
248 bool debug = argc > 4 ? atoi(argv[4]) : true;
249
250 ndbout << "threads=" << num_thr;
251 ndbout << " loops=" << loops;
252 ndbout << " max level=" << limit << endl;
253
254 SafeMutex sm("test-mutex", limit, debug);
255 int ret;
256 ret = sm.create();
257 assert(ret == 0);
258
259 uint i;
260 for (i = 0; i < num_thr; i++) {
261 thr[i].sm_ptr = &sm;
262 thr[i].index = i;
263 thr[i].loops = loops;
264 thr[i].limit = limit;
265 pthread_create(&thr[i].id, 0, &sm_run, &thr[i]);
266 ndbout << "create " << i << " id=" << thr[i].id << endl;
267 }
268 for (i = 0; i < num_thr; i++) {
269 void* value;
270 pthread_join(thr[i].id, &value);
271 ndbout << "join " << i << " id=" << thr[i].id << endl;
272 }
273
274 ret = sm.destroy();
275 assert(ret == 0);
276 return 0;
277 }
278
279 #endif
280