1 /*
2     Copyright (c) 2016-2017 ZeroMQ community
3     Copyright (c) 2016 VOCA AS / Harald Nøkland
4 
5     Permission is hereby granted, free of charge, to any person obtaining a copy
6     of this software and associated documentation files (the "Software"), to
7     deal in the Software without restriction, including without limitation the
8     rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9     sell copies of the Software, and to permit persons to whom the Software is
10     furnished to do so, subject to the following conditions:
11 
12     The above copyright notice and this permission notice shall be included in
13     all copies or substantial portions of the Software.
14 
15     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16     IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17     FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18     AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20     FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21     IN THE SOFTWARE.
22 */
23 
24 #ifndef __ZMQ_ADDON_HPP_INCLUDED__
25 #define __ZMQ_ADDON_HPP_INCLUDED__
26 
27 #include "zmq.hpp"
28 
29 #include <deque>
30 #include <iomanip>
31 #include <sstream>
32 #include <stdexcept>
33 #ifdef ZMQ_CPP11
34 #include <limits>
35 #include <functional>
36 #include <unordered_map>
37 #endif
38 
39 namespace zmq
40 {
41 #ifdef ZMQ_CPP11
42 
43 namespace detail
44 {
45 template<bool CheckN, class OutputIt>
46 recv_result_t
recv_multipart_n(socket_ref s,OutputIt out,size_t n,recv_flags flags)47 recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
48 {
49     size_t msg_count = 0;
50     message_t msg;
51     while (true) {
52         if ZMQ_CONSTEXPR_IF (CheckN) {
53             if (msg_count >= n)
54                 throw std::runtime_error(
55                   "Too many message parts in recv_multipart_n");
56         }
57         if (!s.recv(msg, flags)) {
58             // zmq ensures atomic delivery of messages
59             assert(msg_count == 0);
60             return {};
61         }
62         ++msg_count;
63         const bool more = msg.more();
64         *out++ = std::move(msg);
65         if (!more)
66             break;
67     }
68     return msg_count;
69 }
70 
is_little_endian()71 inline bool is_little_endian()
72 {
73     const uint16_t i = 0x01;
74     return *reinterpret_cast<const uint8_t *>(&i) == 0x01;
75 }
76 
write_network_order(unsigned char * buf,const uint32_t value)77 inline void write_network_order(unsigned char *buf, const uint32_t value)
78 {
79     if (is_little_endian()) {
80         ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t>::max();
81         *buf++ = static_cast<unsigned char>((value >> 24) & mask);
82         *buf++ = static_cast<unsigned char>((value >> 16) & mask);
83         *buf++ = static_cast<unsigned char>((value >> 8) & mask);
84         *buf++ = static_cast<unsigned char>(value & mask);
85     } else {
86         std::memcpy(buf, &value, sizeof(value));
87     }
88 }
89 
read_u32_network_order(const unsigned char * buf)90 inline uint32_t read_u32_network_order(const unsigned char *buf)
91 {
92     if (is_little_endian()) {
93         return (static_cast<uint32_t>(buf[0]) << 24)
94                + (static_cast<uint32_t>(buf[1]) << 16)
95                + (static_cast<uint32_t>(buf[2]) << 8)
96                + static_cast<uint32_t>(buf[3]);
97     } else {
98         uint32_t value;
99         std::memcpy(&value, buf, sizeof(value));
100         return value;
101     }
102 }
103 } // namespace detail
104 
105 /*  Receive a multipart message.
106 
107     Writes the zmq::message_t objects to OutputIterator out.
108     The out iterator must handle an unspecified number of writes,
109     e.g. by using std::back_inserter.
110 
111     Returns: the number of messages received or nullopt (on EAGAIN).
112     Throws: if recv throws. Any exceptions thrown
113     by the out iterator will be propagated and the message
114     may have been only partially received with pending
    message parts. It is advised to close this socket in that event.
116 */
117 template<class OutputIt>
recv_multipart(socket_ref s,OutputIt out,recv_flags flags=recv_flags::none)118 ZMQ_NODISCARD recv_result_t recv_multipart(socket_ref s,
119                                            OutputIt out,
120                                            recv_flags flags = recv_flags::none)
121 {
122     return detail::recv_multipart_n<false>(s, std::move(out), 0, flags);
123 }
124 
125 /*  Receive a multipart message.
126 
127     Writes at most n zmq::message_t objects to OutputIterator out.
128     If the number of message parts of the incoming message exceeds n
129     then an exception will be thrown.
130 
131     Returns: the number of messages received or nullopt (on EAGAIN).
132     Throws: if recv throws. Throws std::runtime_error if the number
133     of message parts exceeds n (exactly n messages will have been written
134     to out). Any exceptions thrown
135     by the out iterator will be propagated and the message
136     may have been only partially received with pending
    message parts. It is advised to close this socket in that event.
138 */
139 template<class OutputIt>
recv_multipart_n(socket_ref s,OutputIt out,size_t n,recv_flags flags=recv_flags::none)140 ZMQ_NODISCARD recv_result_t recv_multipart_n(socket_ref s,
141                                              OutputIt out,
142                                              size_t n,
143                                              recv_flags flags = recv_flags::none)
144 {
145     return detail::recv_multipart_n<true>(s, std::move(out), n, flags);
146 }
147 
148 /*  Send a multipart message.
149 
150     The range must be a ForwardRange of zmq::message_t,
151     zmq::const_buffer or zmq::mutable_buffer.
152     The flags may be zmq::send_flags::sndmore if there are
153     more message parts to be sent after the call to this function.
154 
155     Returns: the number of messages sent (exactly msgs.size()) or nullopt (on EAGAIN).
156     Throws: if send throws. Any exceptions thrown
157     by the msgs range will be propagated and the message
    may have been only partially sent. It is advised to close this socket in that event.
159 */
160 template<class Range
161 #ifndef ZMQ_CPP11_PARTIAL
162          ,
163          typename = typename std::enable_if<
164            detail::is_range<Range>::value
165            && (std::is_same<detail::range_value_t<Range>, message_t>::value
166                || detail::is_buffer<detail::range_value_t<Range>>::value)>::type
167 #endif
168          >
169 send_result_t
send_multipart(socket_ref s,Range && msgs,send_flags flags=send_flags::none)170 send_multipart(socket_ref s, Range &&msgs, send_flags flags = send_flags::none)
171 {
172     using std::begin;
173     using std::end;
174     auto it = begin(msgs);
175     const auto end_it = end(msgs);
176     size_t msg_count = 0;
177     while (it != end_it) {
178         const auto next = std::next(it);
179         const auto msg_flags =
180           flags | (next == end_it ? send_flags::none : send_flags::sndmore);
181         if (!s.send(*it, msg_flags)) {
182             // zmq ensures atomic delivery of messages
183             assert(it == begin(msgs));
184             return {};
185         }
186         ++msg_count;
187         it = next;
188     }
189     return msg_count;
190 }
191 
192 /* Encode a multipart message.
193 
   The range must be a ForwardRange of zmq::message_t, zmq::const_buffer
   or zmq::mutable_buffer. A zmq::multipart_t or STL container may be
   passed for encoding.
196 
197    Returns: a zmq::message_t holding the encoded multipart data.
198 
199    Throws: std::range_error is thrown if the size of any single part
200    can not fit in an unsigned 32 bit integer.
201 
202    The encoding is compatible with that used by the CZMQ function
203    zmsg_encode(), see https://rfc.zeromq.org/spec/50/.
204    Each part consists of a size followed by the data.
205    These are placed contiguously into the output message.  A part of
206    size less than 255 bytes will have a single byte size value.
207    Larger parts will have a five byte size value with the first byte
208    set to 0xFF and the remaining four bytes holding the size of the
209    part's data.
210 */
211 template<class Range
212 #ifndef ZMQ_CPP11_PARTIAL
213          ,
214          typename = typename std::enable_if<
215            detail::is_range<Range>::value
216            && (std::is_same<detail::range_value_t<Range>, message_t>::value
217                || detail::is_buffer<detail::range_value_t<Range>>::value)>::type
218 #endif
219          >
encode(const Range & parts)220 message_t encode(const Range &parts)
221 {
222     size_t mmsg_size = 0;
223 
224     // First pass check sizes
225     for (const auto &part : parts) {
226         const size_t part_size = part.size();
227         if (part_size > std::numeric_limits<std::uint32_t>::max()) {
228             // Size value must fit into uint32_t.
229             throw std::range_error("Invalid size, message part too large");
230         }
231         const size_t count_size =
232           part_size < std::numeric_limits<std::uint8_t>::max() ? 1 : 5;
233         mmsg_size += part_size + count_size;
234     }
235 
236     message_t encoded(mmsg_size);
237     unsigned char *buf = encoded.data<unsigned char>();
238     for (const auto &part : parts) {
239         const uint32_t part_size = static_cast<uint32_t>(part.size());
240         const unsigned char *part_data =
241           static_cast<const unsigned char *>(part.data());
242 
243         if (part_size < std::numeric_limits<std::uint8_t>::max()) {
244             // small part
245             *buf++ = (unsigned char) part_size;
246         } else {
247             // big part
248             *buf++ = std::numeric_limits<uint8_t>::max();
249             detail::write_network_order(buf, part_size);
250             buf += sizeof(part_size);
251         }
252         std::memcpy(buf, part_data, part_size);
253         buf += part_size;
254     }
255 
256     assert(static_cast<size_t>(buf - encoded.data<unsigned char>()) == mmsg_size);
257     return encoded;
258 }
259 
260 /*  Decode an encoded message to multiple parts.
261 
    The given output iterator must accept assignment of zmq::message_t,
    e.g. std::back_inserter() into a zmq::multipart_t or an STL container
    of zmq::message_t.

    Returns the output iterator advanced once past the last decoded
    part.
268 
269     Throws: a std::out_of_range is thrown if the encoded part sizes
270     lead to exceeding the message data bounds.
271 
272     The decoding assumes the message is encoded in the manner
273     performed by zmq::encode(), see https://rfc.zeromq.org/spec/50/.
274  */
decode(const message_t & encoded,OutputIt out)275 template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
276 {
277     const unsigned char *source = encoded.data<unsigned char>();
278     const unsigned char *const limit = source + encoded.size();
279 
280     while (source < limit) {
281         size_t part_size = *source++;
282         if (part_size == std::numeric_limits<std::uint8_t>::max()) {
283             if (static_cast<size_t>(limit - source) < sizeof(uint32_t)) {
284                 throw std::out_of_range(
285                   "Malformed encoding, overflow in reading size");
286             }
287             part_size = detail::read_u32_network_order(source);
288             // the part size is allowed to be less than 0xFF
289             source += sizeof(uint32_t);
290         }
291 
292         if (static_cast<size_t>(limit - source) < part_size) {
293             throw std::out_of_range("Malformed encoding, overflow in reading part");
294         }
295         *out = message_t(source, part_size);
296         ++out;
297         source += part_size;
298     }
299 
300     assert(source == limit);
301     return out;
302 }
303 
304 #endif
305 
306 
307 #ifdef ZMQ_HAS_RVALUE_REFS
308 
309 /*
310     This class handles multipart messaging. It is the C++ equivalent of zmsg.h,
311     which is part of CZMQ (the high-level C binding). Furthermore, it is a major
312     improvement compared to zmsg.hpp, which is part of the examples in the ØMQ
313     Guide. Unnecessary copying is avoided by using move semantics to efficiently
314     add/remove parts.
315 */
316 class multipart_t
317 {
318   private:
319     std::deque<message_t> m_parts;
320 
321   public:
322     typedef std::deque<message_t>::value_type value_type;
323 
324     typedef std::deque<message_t>::iterator iterator;
325     typedef std::deque<message_t>::const_iterator const_iterator;
326 
327     typedef std::deque<message_t>::reverse_iterator reverse_iterator;
328     typedef std::deque<message_t>::const_reverse_iterator const_reverse_iterator;
329 
330     // Default constructor
multipart_t()331     multipart_t() {}
332 
333     // Construct from socket receive
multipart_t(socket_ref socket)334     multipart_t(socket_ref socket) { recv(socket); }
335 
336     // Construct from memory block
multipart_t(const void * src,size_t size)337     multipart_t(const void *src, size_t size) { addmem(src, size); }
338 
339     // Construct from string
multipart_t(const std::string & string)340     multipart_t(const std::string &string) { addstr(string); }
341 
342     // Construct from message part
multipart_t(message_t && message)343     multipart_t(message_t &&message) { add(std::move(message)); }
344 
345     // Move constructor
multipart_t(multipart_t && other)346     multipart_t(multipart_t &&other) { m_parts = std::move(other.m_parts); }
347 
348     // Move assignment operator
operator =(multipart_t && other)349     multipart_t &operator=(multipart_t &&other)
350     {
351         m_parts = std::move(other.m_parts);
352         return *this;
353     }
354 
355     // Destructor
~multipart_t()356     virtual ~multipart_t() { clear(); }
357 
operator [](size_t n)358     message_t &operator[](size_t n) { return m_parts[n]; }
359 
operator [](size_t n) const360     const message_t &operator[](size_t n) const { return m_parts[n]; }
361 
at(size_t n)362     message_t &at(size_t n) { return m_parts.at(n); }
363 
at(size_t n) const364     const message_t &at(size_t n) const { return m_parts.at(n); }
365 
begin()366     iterator begin() { return m_parts.begin(); }
367 
begin() const368     const_iterator begin() const { return m_parts.begin(); }
369 
cbegin() const370     const_iterator cbegin() const { return m_parts.cbegin(); }
371 
rbegin()372     reverse_iterator rbegin() { return m_parts.rbegin(); }
373 
rbegin() const374     const_reverse_iterator rbegin() const { return m_parts.rbegin(); }
375 
end()376     iterator end() { return m_parts.end(); }
377 
end() const378     const_iterator end() const { return m_parts.end(); }
379 
cend() const380     const_iterator cend() const { return m_parts.cend(); }
381 
rend()382     reverse_iterator rend() { return m_parts.rend(); }
383 
rend() const384     const_reverse_iterator rend() const { return m_parts.rend(); }
385 
386     // Delete all parts
clear()387     void clear() { m_parts.clear(); }
388 
389     // Get number of parts
size() const390     size_t size() const { return m_parts.size(); }
391 
392     // Check if number of parts is zero
empty() const393     bool empty() const { return m_parts.empty(); }
394 
395     // Receive multipart message from socket
recv(socket_ref socket,int flags=0)396     bool recv(socket_ref socket, int flags = 0)
397     {
398         clear();
399         bool more = true;
400         while (more) {
401             message_t message;
402 #ifdef ZMQ_CPP11
403             if (!socket.recv(message, static_cast<recv_flags>(flags)))
404                 return false;
405 #else
406             if (!socket.recv(&message, flags))
407                 return false;
408 #endif
409             more = message.more();
410             add(std::move(message));
411         }
412         return true;
413     }
414 
415     // Send multipart message to socket
send(socket_ref socket,int flags=0)416     bool send(socket_ref socket, int flags = 0)
417     {
418         flags &= ~(ZMQ_SNDMORE);
419         bool more = size() > 0;
420         while (more) {
421             message_t message = pop();
422             more = size() > 0;
423 #ifdef ZMQ_CPP11
424             if (!socket.send(message, static_cast<send_flags>(
425                                         (more ? ZMQ_SNDMORE : 0) | flags)))
426                 return false;
427 #else
428             if (!socket.send(message, (more ? ZMQ_SNDMORE : 0) | flags))
429                 return false;
430 #endif
431         }
432         clear();
433         return true;
434     }
435 
436     // Concatenate other multipart to front
prepend(multipart_t && other)437     void prepend(multipart_t &&other)
438     {
439         while (!other.empty())
440             push(other.remove());
441     }
442 
443     // Concatenate other multipart to back
append(multipart_t && other)444     void append(multipart_t &&other)
445     {
446         while (!other.empty())
447             add(other.pop());
448     }
449 
450     // Push memory block to front
pushmem(const void * src,size_t size)451     void pushmem(const void *src, size_t size)
452     {
453         m_parts.push_front(message_t(src, size));
454     }
455 
456     // Push memory block to back
addmem(const void * src,size_t size)457     void addmem(const void *src, size_t size)
458     {
459         m_parts.push_back(message_t(src, size));
460     }
461 
462     // Push string to front
pushstr(const std::string & string)463     void pushstr(const std::string &string)
464     {
465         m_parts.push_front(message_t(string.data(), string.size()));
466     }
467 
468     // Push string to back
addstr(const std::string & string)469     void addstr(const std::string &string)
470     {
471         m_parts.push_back(message_t(string.data(), string.size()));
472     }
473 
474     // Push type (fixed-size) to front
pushtyp(const T & type)475     template<typename T> void pushtyp(const T &type)
476     {
477         static_assert(!std::is_same<T, std::string>::value,
478                       "Use pushstr() instead of pushtyp<std::string>()");
479         m_parts.push_front(message_t(&type, sizeof(type)));
480     }
481 
482     // Push type (fixed-size) to back
addtyp(const T & type)483     template<typename T> void addtyp(const T &type)
484     {
485         static_assert(!std::is_same<T, std::string>::value,
486                       "Use addstr() instead of addtyp<std::string>()");
487         m_parts.push_back(message_t(&type, sizeof(type)));
488     }
489 
490     // Push message part to front
push(message_t && message)491     void push(message_t &&message) { m_parts.push_front(std::move(message)); }
492 
493     // Push message part to back
add(message_t && message)494     void add(message_t &&message) { m_parts.push_back(std::move(message)); }
495 
496     // Alias to allow std::back_inserter()
push_back(message_t && message)497     void push_back(message_t &&message) { m_parts.push_back(std::move(message)); }
498 
499     // Pop string from front
popstr()500     std::string popstr()
501     {
502         std::string string(m_parts.front().data<char>(), m_parts.front().size());
503         m_parts.pop_front();
504         return string;
505     }
506 
507     // Pop type (fixed-size) from front
poptyp()508     template<typename T> T poptyp()
509     {
510         static_assert(!std::is_same<T, std::string>::value,
511                       "Use popstr() instead of poptyp<std::string>()");
512         if (sizeof(T) != m_parts.front().size())
513             throw std::runtime_error(
514               "Invalid type, size does not match the message size");
515         T type = *m_parts.front().data<T>();
516         m_parts.pop_front();
517         return type;
518     }
519 
520     // Pop message part from front
pop()521     message_t pop()
522     {
523         message_t message = std::move(m_parts.front());
524         m_parts.pop_front();
525         return message;
526     }
527 
528     // Pop message part from back
remove()529     message_t remove()
530     {
531         message_t message = std::move(m_parts.back());
532         m_parts.pop_back();
533         return message;
534     }
535 
536     // get message part from front
front()537     const message_t &front() { return m_parts.front(); }
538 
539     // get message part from back
back()540     const message_t &back() { return m_parts.back(); }
541 
542     // Get pointer to a specific message part
peek(size_t index) const543     const message_t *peek(size_t index) const { return &m_parts[index]; }
544 
545     // Get a string copy of a specific message part
peekstr(size_t index) const546     std::string peekstr(size_t index) const
547     {
548         std::string string(m_parts[index].data<char>(), m_parts[index].size());
549         return string;
550     }
551 
552     // Peek type (fixed-size) from front
peektyp(size_t index) const553     template<typename T> T peektyp(size_t index) const
554     {
555         static_assert(!std::is_same<T, std::string>::value,
556                       "Use peekstr() instead of peektyp<std::string>()");
557         if (sizeof(T) != m_parts[index].size())
558             throw std::runtime_error(
559               "Invalid type, size does not match the message size");
560         T type = *m_parts[index].data<T>();
561         return type;
562     }
563 
564     // Create multipart from type (fixed-size)
create(const T & type)565     template<typename T> static multipart_t create(const T &type)
566     {
567         multipart_t multipart;
568         multipart.addtyp(type);
569         return multipart;
570     }
571 
572     // Copy multipart
clone() const573     multipart_t clone() const
574     {
575         multipart_t multipart;
576         for (size_t i = 0; i < size(); i++)
577             multipart.addmem(m_parts[i].data(), m_parts[i].size());
578         return multipart;
579     }
580 
581     // Dump content to string
str() const582     std::string str() const
583     {
584         std::stringstream ss;
585         for (size_t i = 0; i < m_parts.size(); i++) {
586             const unsigned char *data = m_parts[i].data<unsigned char>();
587             size_t size = m_parts[i].size();
588 
589             // Dump the message as text or binary
590             bool isText = true;
591             for (size_t j = 0; j < size; j++) {
592                 if (data[j] < 32 || data[j] > 127) {
593                     isText = false;
594                     break;
595                 }
596             }
597             ss << "\n[" << std::dec << std::setw(3) << std::setfill('0') << size
598                << "] ";
599             if (size >= 1000) {
600                 ss << "... (too big to print)";
601                 continue;
602             }
603             for (size_t j = 0; j < size; j++) {
604                 if (isText)
605                     ss << static_cast<char>(data[j]);
606                 else
607                     ss << std::hex << std::setw(2) << std::setfill('0')
608                        << static_cast<short>(data[j]);
609             }
610         }
611         return ss.str();
612     }
613 
614     // Check if equal to other multipart
equal(const multipart_t * other) const615     bool equal(const multipart_t *other) const ZMQ_NOTHROW
616     {
617         return *this == *other;
618     }
619 
operator ==(const multipart_t & other) const620     bool operator==(const multipart_t &other) const ZMQ_NOTHROW
621     {
622         if (size() != other.size())
623             return false;
624         for (size_t i = 0; i < size(); i++)
625             if (at(i) != other.at(i))
626                 return false;
627         return true;
628     }
629 
operator !=(const multipart_t & other) const630     bool operator!=(const multipart_t &other) const ZMQ_NOTHROW
631     {
632         return !(*this == other);
633     }
634 
635 #ifdef ZMQ_CPP11
636 
637     // Return single part message_t encoded from this multipart_t.
encode() const638     message_t encode() const { return zmq::encode(*this); }
639 
640     // Decode encoded message into multiple parts and append to self.
decode_append(const message_t & encoded)641     void decode_append(const message_t &encoded)
642     {
643         zmq::decode(encoded, std::back_inserter(*this));
644     }
645 
646     // Return a new multipart_t containing the decoded message_t.
decode(const message_t & encoded)647     static multipart_t decode(const message_t &encoded)
648     {
649         multipart_t tmp;
650         zmq::decode(encoded, std::back_inserter(tmp));
651         return tmp;
652     }
653 
654 #endif
655 
656   private:
657     // Disable implicit copying (moving is more efficient)
658     multipart_t(const multipart_t &other) ZMQ_DELETED_FUNCTION;
659     void operator=(const multipart_t &other) ZMQ_DELETED_FUNCTION;
660 }; // class multipart_t
661 
operator <<(std::ostream & os,const multipart_t & msg)662 inline std::ostream &operator<<(std::ostream &os, const multipart_t &msg)
663 {
664     return os << msg.str();
665 }
666 
667 #endif // ZMQ_HAS_RVALUE_REFS
668 
669 #if defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)
670 class active_poller_t
671 {
672   public:
673     active_poller_t() = default;
674     ~active_poller_t() = default;
675 
676     active_poller_t(const active_poller_t &) = delete;
677     active_poller_t &operator=(const active_poller_t &) = delete;
678 
679     active_poller_t(active_poller_t &&src) = default;
680     active_poller_t &operator=(active_poller_t &&src) = default;
681 
682     using handler_type = std::function<void(event_flags)>;
683 
add(zmq::socket_ref socket,event_flags events,handler_type handler)684     void add(zmq::socket_ref socket, event_flags events, handler_type handler)
685     {
686         if (!handler)
687             throw std::invalid_argument("null handler in active_poller_t::add");
688         auto ret = handlers.emplace(
689           socket, std::make_shared<handler_type>(std::move(handler)));
690         if (!ret.second)
691             throw error_t(EINVAL); // already added
692         try {
693             base_poller.add(socket, events, ret.first->second.get());
694             need_rebuild = true;
695         }
696         catch (...) {
697             // rollback
698             handlers.erase(socket);
699             throw;
700         }
701     }
702 
remove(zmq::socket_ref socket)703     void remove(zmq::socket_ref socket)
704     {
705         base_poller.remove(socket);
706         handlers.erase(socket);
707         need_rebuild = true;
708     }
709 
modify(zmq::socket_ref socket,event_flags events)710     void modify(zmq::socket_ref socket, event_flags events)
711     {
712         base_poller.modify(socket, events);
713     }
714 
wait(std::chrono::milliseconds timeout)715     size_t wait(std::chrono::milliseconds timeout)
716     {
717         if (need_rebuild) {
718             poller_events.resize(handlers.size());
719             poller_handlers.clear();
720             poller_handlers.reserve(handlers.size());
721             for (const auto &handler : handlers) {
722                 poller_handlers.push_back(handler.second);
723             }
724             need_rebuild = false;
725         }
726         const auto count = base_poller.wait_all(poller_events, timeout);
727         std::for_each(poller_events.begin(),
728                       poller_events.begin() + static_cast<ptrdiff_t>(count),
729                       [](decltype(base_poller)::event_type &event) {
730                           assert(event.user_data != nullptr);
731                           (*event.user_data)(event.events);
732                       });
733         return count;
734     }
735 
empty() const736     ZMQ_NODISCARD bool empty() const noexcept { return handlers.empty(); }
737 
size() const738     size_t size() const noexcept { return handlers.size(); }
739 
740   private:
741     bool need_rebuild{false};
742 
743     poller_t<handler_type> base_poller{};
744     std::unordered_map<socket_ref, std::shared_ptr<handler_type>> handlers{};
745     std::vector<decltype(base_poller)::event_type> poller_events{};
746     std::vector<std::shared_ptr<handler_type>> poller_handlers{};
747 };     // class active_poller_t
748 #endif //  defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)
749 
750 
751 } // namespace zmq
752 
753 #endif // __ZMQ_ADDON_HPP_INCLUDED__
754