1 /* Copyright (c) 2008, 2015, Oracle and/or its affiliates. All rights reserved.
2 
3    This program is free software; you can redistribute it and/or modify
4    it under the terms of the GNU General Public License, version 2.0,
5    as published by the Free Software Foundation.
6 
7    This program is also distributed with certain software (including
8    but not limited to OpenSSL) that is licensed under separate terms,
9    as designated in a particular file or component or in included license
10    documentation.  The authors of MySQL hereby grant you an additional
11    permission to link the program and your derivative works with the
12    separately licensed software that they have included with MySQL.
13 
14    This program is distributed in the hope that it will be useful,
15    but WITHOUT ANY WARRANTY; without even the implied warranty of
16    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17    GNU General Public License, version 2.0, for more details.
18 
19    You should have received a copy of the GNU General Public License
20    along with this program; if not, write to the Free Software
21    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA */
22 
23 #include "SafeMutex.hpp"
24 
25 #define JAM_FILE_ID 265
26 
27 
28 int
create()29 SafeMutex::create()
30 {
31   int ret;
32   if (m_initdone)
33     return err(ErrState, __LINE__);
34   ret = native_mutex_init(&m_mutex, 0);
35   if (ret != 0)
36     return err(ret, __LINE__);
37   ret = native_cond_init(&m_cond);
38   if (ret != 0)
39     return err(ret, __LINE__);
40   m_initdone = true;
41   return 0;
42 }
43 
44 int
destroy()45 SafeMutex::destroy()
46 {
47   int ret;
48   if (!m_initdone)
49     return err(ErrState, __LINE__);
50   ret = native_cond_destroy(&m_cond);
51   if (ret != 0)
52     return err(ret, __LINE__);
53   ret = native_mutex_destroy(&m_mutex);
54   if (ret != 0)
55     return err(ret, __LINE__);
56   m_initdone = false;
57   return 0;
58 }
59 
60 int
lock()61 SafeMutex::lock()
62 {
63   int ret;
64   if (m_simple) {
65     ret = native_mutex_lock(&m_mutex);
66     if (ret != 0)
67       return err(ret, __LINE__);
68     return 0;
69   }
70   ret = native_mutex_lock(&m_mutex);
71   if (ret != 0)
72     return err(ret, __LINE__);
73   return lock_impl();
74 }
75 
76 int
lock_impl()77 SafeMutex::lock_impl()
78 {
79   int ret;
80   my_thread_t self = my_thread_self();
81   assert(self != 0);
82   while (1) {
83     if (m_level == 0) {
84       assert(m_owner == 0);
85       m_owner = self;
86     } else if (m_owner != self) {
87       ret = native_cond_wait(&m_cond, &m_mutex);
88       if (ret != 0)
89         return err(ret, __LINE__);
90       continue;
91     }
92     if (!(m_level < m_limit))
93       return err(ErrLevel, __LINE__);
94     m_level++;
95     if (m_usage < m_level)
96       m_usage = m_level;
97     ret = native_cond_signal(&m_cond);
98     if (ret != 0)
99       return err(ret, __LINE__);
100     ret = native_mutex_unlock(&m_mutex);
101     if (ret != 0)
102       return err(ret, __LINE__);
103     break;
104   }
105   return 0;
106 }
107 
108 int
unlock()109 SafeMutex::unlock()
110 {
111   int ret;
112   if (m_simple) {
113     ret = native_mutex_unlock(&m_mutex);
114     if (ret != 0)
115       return err(ret, __LINE__);
116     return 0;
117   }
118   ret = native_mutex_lock(&m_mutex);
119   if (ret != 0)
120     return err(ret, __LINE__);
121   return unlock_impl();
122 }
123 
124 int
unlock_impl()125 SafeMutex::unlock_impl()
126 {
127   int ret;
128   my_thread_t self = my_thread_self();
129   assert(self != 0);
130   if (m_owner != self)
131     return err(ErrOwner, __LINE__);
132   if (m_level == 0)
133     return err(ErrNolock, __LINE__);
134   m_level--;
135   if (m_level == 0) {
136     m_owner = 0;
137     ret = native_cond_signal(&m_cond);
138     if (ret != 0)
139       return err(ret, __LINE__);
140   }
141   ret = native_mutex_unlock(&m_mutex);
142   if (ret != 0)
143     return err(ret, __LINE__);
144   return 0;
145 }
146 
147 int
err(int errcode,int errline)148 SafeMutex::err(int errcode, int errline)
149 {
150   assert(errcode != 0);
151   m_errcode = errcode;
152   m_errline = errline;
153   ndbout << *this << endl;
154 #ifdef UNIT_TEST
155   abort();
156 #endif
157   return errcode;
158 }
159 
160 NdbOut&
operator <<(NdbOut & out,const SafeMutex & sm)161 operator<<(NdbOut& out, const SafeMutex& sm)
162 {
163   out << sm.m_name << ":";
164   out << " level=" << sm.m_level;
165   out << " usage=" << sm.m_usage;
166   if (sm.m_errcode != 0) {
167     out << " errcode=" << sm.m_errcode;
168     out << " errline=" << sm.m_errline;
169   }
170   return out;
171 }
172 
173 #ifdef UNIT_TEST
174 
175 struct sm_thr {
176   SafeMutex* sm_ptr;
177   uint index;
178   uint loops;
179   uint limit;
180   pthread_t id;
sm_thrsm_thr181   sm_thr() : sm_ptr(0), index(0), loops(0), limit(0), id(0) {}
~sm_thrsm_thr182   ~sm_thr() {}
183 };
184 
185 extern "C" { static void* sm_run(void* arg); }
186 
187 static void*
sm_run(void * arg)188 sm_run(void* arg)
189 {
190   sm_thr& thr = *(sm_thr*)arg;
191   assert(thr.sm_ptr != 0);
192   SafeMutex& sm = *thr.sm_ptr;
193   uint level = 0;
194   int dir = 0;
195   uint i;
196   for (i = 0; i < thr.loops; i++) {
197     int op = 0;
198     uint sel = uint(random()) % 10;
199     if (level == 0) {
200       dir = +1;
201       op = +1;
202     } else if (level == thr.limit) {
203       dir = -1;
204       op = -1;
205     } else if (dir == +1) {
206       op = sel != 0 ? +1 : -1;
207     } else if (dir == -1) {
208       op = sel != 0 ? -1 : +1;
209     } else {
210       assert(false);
211     }
212     if (op == +1) {
213       assert(level < thr.limit);
214       //ndbout << thr.index << ": lock" << endl;
215       int ret = sm.lock();
216       assert(ret == 0);
217       level++;
218     } else if (op == -1) {
219       //ndbout << thr.index << ": unlock" << endl;
220       int ret = sm.unlock();
221       assert(ret == 0);
222       assert(level != 0);
223       level--;
224     } else {
225       assert(false);
226     }
227   }
228   while (level > 0) {
229     int ret = sm.unlock();
230     assert(ret == 0);
231     level--;
232   }
233   return 0;
234 }
235 
236 int
main(int argc,char ** argv)237 main(int argc, char** argv)
238 {
239   const uint max_thr = 128;
240   struct sm_thr thr[max_thr];
241 
242   // threads - loops - max level - debug
243   uint num_thr = argc > 1 ? atoi(argv[1]) : 4;
244   assert(num_thr != 0 && num_thr <= max_thr);
245   uint loops = argc > 2 ? atoi(argv[2]) : 1000000;
246   uint limit = argc > 3 ? atoi(argv[3]) : 10;
247   assert(limit != 0);
248   bool debug = argc > 4 ? atoi(argv[4]) : true;
249 
250   ndbout << "threads=" << num_thr;
251   ndbout << " loops=" << loops;
252   ndbout << " max level=" << limit << endl;
253 
254   SafeMutex sm("test-mutex", limit, debug);
255   int ret;
256   ret = sm.create();
257   assert(ret == 0);
258 
259   uint i;
260   for (i = 0; i < num_thr; i++) {
261     thr[i].sm_ptr = &sm;
262     thr[i].index = i;
263     thr[i].loops = loops;
264     thr[i].limit = limit;
265     pthread_create(&thr[i].id, 0, &sm_run, &thr[i]);
266     ndbout << "create " << i << " id=" << thr[i].id << endl;
267   }
268   for (i = 0; i < num_thr; i++) {
269     void* value;
270     pthread_join(thr[i].id, &value);
271     ndbout << "join " << i << " id=" << thr[i].id << endl;
272   }
273 
274   ret = sm.destroy();
275   assert(ret == 0);
276   return 0;
277 }
278 
279 #endif
280