#include <chrono>
#include <cstdlib>
#include <exception>
#include <memory>
#include <thread>

#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/threadediter.h>
#include <gtest/gtest.h>
6
/*! \brief Selects how a test producer raises its failure. */
enum ExcType {
  kDMLCException = 0,  // raised through LOG(FATAL)
  kStdException = 1,   // raised by throwing a plain std::exception
};
11
12 using namespace dmlc;
13 namespace producer_test {
delay(int sleep)14 inline void delay(int sleep) {
15 if (sleep < 0) {
16 int d = rand() % (-sleep);
17 std::this_thread::sleep_for(std::chrono::milliseconds(d));
18 } else {
19 std::this_thread::sleep_for(std::chrono::milliseconds(sleep));
20 }
21 }
22
23 // int was only used as example, in real life
24 // use big data blob
25 struct IntProducerNextExc : public ThreadedIter<int>::Producer {
26 int counter;
27 int maxcap;
28 int sleep;
29 ExcType exc_type;
30
IntProducerNextExcproducer_test::IntProducerNextExc31 IntProducerNextExc(int maxcap, int sleep, ExcType exc_type = ExcType::kDMLCException)
32 : counter(0), maxcap(maxcap), sleep(sleep), exc_type(exc_type) {}
33 virtual ~IntProducerNextExc() = default;
BeforeFirstproducer_test::IntProducerNextExc34 virtual void BeforeFirst(void) { counter = 0; }
Nextproducer_test::IntProducerNextExc35 virtual bool Next(int **inout_dptr) {
36 if (counter == maxcap)
37 return false;
38 if (counter == (maxcap - 1)) {
39 counter++;
40 if (exc_type == kDMLCException) {
41 LOG(FATAL) << "Test Throw exception";
42 } else {
43 LOG(WARNING) << "Throw std::exception";
44 throw std::exception();
45 }
46 }
47 // allocate space if not exist
48 if (*inout_dptr == NULL) {
49 *inout_dptr = new int();
50 }
51 delay(sleep);
52 **inout_dptr = counter++;
53 return true;
54 }
55 };
56
57 struct IntProducerBeforeFirst : public ThreadedIter<int>::Producer {
58 ExcType exc_type;
IntProducerBeforeFirstproducer_test::IntProducerBeforeFirst59 IntProducerBeforeFirst(ExcType exc_type = ExcType::kDMLCException)
60 : exc_type(exc_type) {}
61 virtual ~IntProducerBeforeFirst() = default;
BeforeFirstproducer_test::IntProducerBeforeFirst62 virtual void BeforeFirst(void) {
63 if (exc_type == ExcType::kDMLCException) {
64 LOG(FATAL) << "Throw exception in before first";
65 } else {
66 throw std::exception();
67 }
68 }
Nextproducer_test::IntProducerBeforeFirst69 virtual bool Next(int **inout_dptr) { return true; }
70 };
71 }
72
TEST(ThreadedIter,dmlc_exception)73 TEST(ThreadedIter, dmlc_exception) {
74 using namespace producer_test;
75 int* value = nullptr;
76 ThreadedIter<int> iter2;
77 iter2.set_max_capacity(7);
78 auto prod = std::make_shared<IntProducerNextExc>(5, 100);
79 bool caught = false;
80 iter2.Init(prod); // t1 is created in here, not passing ownership
81 iter2.BeforeFirst();
82 try {
83 delay(1000);
84 iter2.Recycle(&value);
85 } catch (dmlc::Error &e) {
86 caught = true;
87 LOG(INFO) << "recycle exception caught";
88 }
89 CHECK(caught);
90 iter2.Init(prod);
91 caught = false;
92 iter2.BeforeFirst();
93 try {
94 while (iter2.Next(&value)) {
95 iter2.Recycle(&value);
96 }
97 } catch (dmlc::Error &e) {
98 caught = true;
99 LOG(INFO) << "next exception caught";
100 }
101 CHECK(caught);
102 LOG(INFO) << "finish";
103 ThreadedIter<int> iter3;
104 iter3.set_max_capacity(1);
105 auto prod2 = std::make_shared<IntProducerBeforeFirst>();
106 iter3.Init(prod2);
107 caught = false;
108 try {
109 iter3.BeforeFirst();
110 } catch (dmlc::Error &e) {
111 caught = true;
112 LOG(INFO) << "beforefirst exception caught";
113 }
114 caught = false;
115 try {
116 iter3.BeforeFirst();
117 } catch (dmlc::Error &e) {
118 LOG(INFO) << "beforefirst exception thrown/caught";
119 caught = true;
120 }
121 CHECK(caught);
122 delete(value);
123 }
124
TEST(ThreadedIter,std_exception)125 TEST(ThreadedIter, std_exception) {
126 using namespace producer_test;
127 int *value = nullptr;
128 ThreadedIter<int> iter2;
129 iter2.set_max_capacity(7);
130 auto prod =std::make_shared<IntProducerNextExc>(5, 100, ExcType::kStdException);
131 bool caught = false;
132 iter2.Init(prod);
133 iter2.BeforeFirst();
134 try {
135 delay(1000);
136 iter2.Recycle(&value);
137 } catch (dmlc::Error &e) {
138 caught = true;
139 LOG(INFO) << "recycle exception caught";
140 }
141 CHECK(caught);
142 iter2.Init(prod);
143 caught = false;
144 iter2.BeforeFirst();
145 try {
146 while (iter2.Next(&value)) {
147 iter2.Recycle(&value);
148 }
149 } catch (dmlc::Error &e) {
150 caught = true;
151 LOG(INFO) << "next exception caught";
152 }
153 CHECK(caught);
154 LOG(INFO) << "finish";
155 ThreadedIter<int> iter3;
156 iter3.set_max_capacity(1);
157 auto prod2 = std::make_shared<IntProducerBeforeFirst>(ExcType::kStdException);
158 iter3.Init(prod2);
159 caught = false;
160 try {
161 iter3.BeforeFirst();
162 } catch (dmlc::Error &e) {
163 caught = true;
164 LOG(INFO) << "beforefirst exception caught";
165 }
166 caught = false;
167 try {
168 iter3.BeforeFirst();
169 } catch (dmlc::Error &e) {
170 LOG(INFO) << "beforefirst exception thrown/caught";
171 caught = true;
172 }
173 CHECK(caught);
174 delete(value);
175 }
176