1 #include <chrono>
2 #include <dmlc/io.h>
3 #include <dmlc/logging.h>
4 #include <dmlc/threadediter.h>
5 #include <gtest/gtest.h>
6 
// Selects which failure mechanism the test producers below use.
enum ExcType {
  kDMLCException = 0,  // producer fails through LOG(FATAL)
  kStdException = 1,   // producer fails through `throw std::exception()`
};
11 
12 using namespace dmlc;
13 namespace producer_test {
// Blocks the calling thread for some milliseconds.
//   sleep >= 0 : sleep for exactly `sleep` milliseconds.
//   sleep <  0 : sleep for rand() % (-sleep) milliseconds, i.e. a
//                pseudo-random duration in the range [0, -sleep).
inline void delay(int sleep) {
  const int millis = (sleep < 0) ? rand() % (-sleep) : sleep;
  std::this_thread::sleep_for(std::chrono::milliseconds(millis));
}
22 
// int is used here only as an example payload; in real use the
// iterator would carry a large data blob.
25 struct IntProducerNextExc : public ThreadedIter<int>::Producer {
26   int counter;
27   int maxcap;
28   int sleep;
29   ExcType exc_type;
30 
IntProducerNextExcproducer_test::IntProducerNextExc31   IntProducerNextExc(int maxcap, int sleep, ExcType exc_type = ExcType::kDMLCException)
32       : counter(0), maxcap(maxcap), sleep(sleep), exc_type(exc_type) {}
33   virtual ~IntProducerNextExc() = default;
BeforeFirstproducer_test::IntProducerNextExc34   virtual void BeforeFirst(void) { counter = 0; }
Nextproducer_test::IntProducerNextExc35   virtual bool Next(int **inout_dptr) {
36     if (counter == maxcap)
37       return false;
38     if (counter == (maxcap - 1)) {
39       counter++;
40       if (exc_type == kDMLCException) {
41         LOG(FATAL) << "Test Throw exception";
42       } else {
43         LOG(WARNING) << "Throw std::exception";
44         throw std::exception();
45       }
46     }
47     // allocate space if not exist
48     if (*inout_dptr == NULL) {
49       *inout_dptr = new int();
50     }
51     delay(sleep);
52     **inout_dptr = counter++;
53     return true;
54   }
55 };
56 
57 struct IntProducerBeforeFirst : public ThreadedIter<int>::Producer {
58   ExcType exc_type;
IntProducerBeforeFirstproducer_test::IntProducerBeforeFirst59   IntProducerBeforeFirst(ExcType exc_type = ExcType::kDMLCException)
60       : exc_type(exc_type) {}
61   virtual ~IntProducerBeforeFirst() = default;
BeforeFirstproducer_test::IntProducerBeforeFirst62   virtual void BeforeFirst(void) {
63     if (exc_type == ExcType::kDMLCException) {
64       LOG(FATAL) << "Throw exception in before first";
65     } else {
66       throw std::exception();
67     }
68   }
Nextproducer_test::IntProducerBeforeFirst69   virtual bool Next(int **inout_dptr) { return true; }
70 };
71 }
72 
// Checks that failures raised by a producer through LOG(FATAL) reach the
// consumer as dmlc::Error from Recycle(), Next() and BeforeFirst().
TEST(ThreadedIter, dmlc_exception) {
  using namespace producer_test;
  int *cell = nullptr;
  ThreadedIter<int> next_iter;
  next_iter.set_max_capacity(7);
  auto failing_next = std::make_shared<IntProducerNextExc>(5, 100);
  bool got_error = false;
  // Per the original note: the producer thread is created inside Init, and
  // ownership of the producer is not handed over.
  next_iter.Init(failing_next);
  next_iter.BeforeFirst();

  // Stage 1: wait one second (presumably long enough for the producer to
  // reach its failing element), then expect Recycle() to report the error.
  try {
    delay(1000);
    next_iter.Recycle(&cell);
  } catch (const dmlc::Error &) {
    got_error = true;
    LOG(INFO) << "recycle exception caught";
  }
  CHECK(got_error);

  // Stage 2: re-initialise with the same producer and consume until the
  // error surfaces from Next() or Recycle().
  next_iter.Init(failing_next);
  got_error = false;
  next_iter.BeforeFirst();
  try {
    while (next_iter.Next(&cell)) {
      next_iter.Recycle(&cell);
    }
  } catch (const dmlc::Error &) {
    got_error = true;
    LOG(INFO) << "next exception caught";
  }
  CHECK(got_error);
  LOG(INFO) << "finish";

  // Stage 3: a producer whose BeforeFirst() always fails.
  ThreadedIter<int> before_first_iter;
  before_first_iter.set_max_capacity(1);
  auto failing_before_first = std::make_shared<IntProducerBeforeFirst>();
  before_first_iter.Init(failing_before_first);
  got_error = false;
  // The outcome of this first call is not asserted; the flag is reset
  // immediately afterwards.
  try {
    before_first_iter.BeforeFirst();
  } catch (const dmlc::Error &) {
    got_error = true;
    LOG(INFO) << "beforefirst exception caught";
  }
  got_error = false;
  // The second call is the one required to raise dmlc::Error.
  try {
    before_first_iter.BeforeFirst();
  } catch (const dmlc::Error &) {
    LOG(INFO) << "beforefirst exception thrown/caught";
    got_error = true;
  }
  CHECK(got_error);
  delete cell;
}
124 
// Same scenario as the dmlc_exception test, but the producers fail with a
// plain std::exception. The consumer side still catches dmlc::Error, so the
// iterator is expected to surface the failure as that type.
TEST(ThreadedIter, std_exception) {
  using namespace producer_test;
  int *cell = nullptr;
  ThreadedIter<int> next_iter;
  next_iter.set_max_capacity(7);
  auto failing_next =
      std::make_shared<IntProducerNextExc>(5, 100, ExcType::kStdException);
  bool got_error = false;
  next_iter.Init(failing_next);
  next_iter.BeforeFirst();

  // Stage 1: wait one second (presumably long enough for the producer to
  // reach its failing element), then expect Recycle() to report the error.
  try {
    delay(1000);
    next_iter.Recycle(&cell);
  } catch (const dmlc::Error &) {
    got_error = true;
    LOG(INFO) << "recycle exception caught";
  }
  CHECK(got_error);

  // Stage 2: re-initialise with the same producer and consume until the
  // error surfaces from Next() or Recycle().
  next_iter.Init(failing_next);
  got_error = false;
  next_iter.BeforeFirst();
  try {
    while (next_iter.Next(&cell)) {
      next_iter.Recycle(&cell);
    }
  } catch (const dmlc::Error &) {
    got_error = true;
    LOG(INFO) << "next exception caught";
  }
  CHECK(got_error);
  LOG(INFO) << "finish";

  // Stage 3: a producer whose BeforeFirst() always throws std::exception.
  ThreadedIter<int> before_first_iter;
  before_first_iter.set_max_capacity(1);
  auto failing_before_first =
      std::make_shared<IntProducerBeforeFirst>(ExcType::kStdException);
  before_first_iter.Init(failing_before_first);
  got_error = false;
  // The outcome of this first call is not asserted; the flag is reset
  // immediately afterwards.
  try {
    before_first_iter.BeforeFirst();
  } catch (const dmlc::Error &) {
    got_error = true;
    LOG(INFO) << "beforefirst exception caught";
  }
  got_error = false;
  // The second call is the one required to raise dmlc::Error.
  try {
    before_first_iter.BeforeFirst();
  } catch (const dmlc::Error &) {
    LOG(INFO) << "beforefirst exception thrown/caught";
    got_error = true;
  }
  CHECK(got_error);
  delete cell;
}
176