1package zmq4
2
3/*
4#include <zmq.h>
5#include "zmq4.h"
6*/
7import "C"
8
9import (
10	"fmt"
11	"time"
12)
13
// Polled is one entry in the list returned by (*Poller)Poll and
// (*Poller)PollAll: a socket together with the events reported for it.
type Polled struct {
	Socket *Socket // socket with matched event(s)
	Events State   // actual matched event(s); may be zero for entries returned by PollAll
}
19
// Poller holds a set of sockets, each paired with the event mask it
// should be polled for.
//
// items and socks are parallel slices: the entries at the same index
// describe the same registered socket.
type Poller struct {
	items []C.zmq_pollitem_t // poll items handed to zmq4_poll
	socks []*Socket          // Go-side sockets, index-aligned with items
}
24
25// Create a new Poller
26func NewPoller() *Poller {
27	return &Poller{
28		items: make([]C.zmq_pollitem_t, 0),
29		socks: make([]*Socket, 0),
30	}
31}
32
33// Add items to the poller
34//
35// Events is a bitwise OR of zmq.POLLIN and zmq.POLLOUT
36//
37// Returns the id of the item, which can be used as a handle to
38// (*Poller)Update and as an index into the result of (*Poller)PollAll
39func (p *Poller) Add(soc *Socket, events State) int {
40	var item C.zmq_pollitem_t
41	item.socket = soc.soc
42	item.fd = 0
43	item.events = C.short(events)
44	p.items = append(p.items, item)
45	p.socks = append(p.socks, soc)
46	return len(p.items) - 1
47}
48
49// Update the events mask of a socket in the poller
50//
51// Replaces the Poller's bitmask for the specified id with the events parameter passed
52//
53// Returns the previous value, or ErrorNoSocket if the id was out of range
54func (p *Poller) Update(id int, events State) (previous State, err error) {
55	if id >= 0 && id < len(p.items) {
56		previous = State(p.items[id].events)
57		p.items[id].events = C.short(events)
58		return previous, nil
59	}
60	return 0, ErrorNoSocket
61}
62
63// Update the events mask of a socket in the poller
64//
65// Replaces the Poller's bitmask for the specified socket with the events parameter passed
66//
67// Returns the previous value, or ErrorNoSocket if the socket didn't match
68func (p *Poller) UpdateBySocket(soc *Socket, events State) (previous State, err error) {
69	for id, s := range p.socks {
70		if s == soc {
71			previous = State(p.items[id].events)
72			p.items[id].events = C.short(events)
73			return previous, nil
74		}
75	}
76	return 0, ErrorNoSocket
77}
78
79// Remove a socket from the poller
80//
81// Returns ErrorNoSocket if the id was out of range
82func (p *Poller) Remove(id int) error {
83	if id >= 0 && id < len(p.items) {
84		if id == len(p.items)-1 {
85			p.items = p.items[:id]
86			p.socks = p.socks[:id]
87		} else {
88			p.items = append(p.items[:id], p.items[id+1:]...)
89			p.socks = append(p.socks[:id], p.socks[id+1:]...)
90		}
91		return nil
92	}
93	return ErrorNoSocket
94}
95
96// Remove a socket from the poller
97//
98// Returns ErrorNoSocket if the socket didn't match
99func (p *Poller) RemoveBySocket(soc *Socket) error {
100	for id, s := range p.socks {
101		if s == soc {
102			return p.Remove(id)
103		}
104	}
105	return ErrorNoSocket
106}
107
108/*
109Input/output multiplexing
110
111If timeout < 0, wait forever until a matching event is detected
112
113Only sockets with matching socket events are returned in the list.
114
115Example:
116
117    poller := zmq.NewPoller()
118    poller.Add(socket0, zmq.POLLIN)
119    poller.Add(socket1, zmq.POLLIN)
120    //  Process messages from both sockets
121    for {
122        sockets, _ := poller.Poll(-1)
123        for _, socket := range sockets {
124            switch s := socket.Socket; s {
125            case socket0:
126                msg, _ := s.Recv(0)
127                //  Process msg
128            case socket1:
129                msg, _ := s.Recv(0)
130                //  Process msg
131            }
132        }
133    }
134*/
135func (p *Poller) Poll(timeout time.Duration) ([]Polled, error) {
136	return p.poll(timeout, false)
137}
138
139/*
140This is like (*Poller)Poll, but it returns a list of all sockets,
141in the same order as they were added to the poller,
142not just those sockets that had an event.
143
144For each socket in the list, you have to check the Events field
145to see if there was actually an event.
146
147When error is not nil, the return list contains no sockets.
148*/
149func (p *Poller) PollAll(timeout time.Duration) ([]Polled, error) {
150	return p.poll(timeout, true)
151}
152
153func (p *Poller) poll(timeout time.Duration, all bool) ([]Polled, error) {
154	lst := make([]Polled, 0, len(p.items))
155
156	var ctx *Context
157	for _, soc := range p.socks {
158		if !soc.opened {
159			return lst, ErrorSocketClosed
160		}
161		// assume all sockets have the same context
162		ctx = soc.ctx
163	}
164
165	t := timeout
166	if t > 0 {
167		t = t / time.Millisecond
168	}
169	if t < 0 {
170		t = -1
171	}
172	var rv C.int
173	var err error
174	for {
175		rv, err = C.zmq4_poll(&p.items[0], C.int(len(p.items)), C.long(t))
176		if rv >= 0 || ctx == nil || !ctx.retry(err) {
177			break
178		}
179	}
180	if rv < 0 {
181		return lst, errget(err)
182	}
183	for i, it := range p.items {
184		if all || it.events&it.revents != 0 {
185			lst = append(lst, Polled{p.socks[i], State(it.revents)})
186		}
187	}
188	return lst, nil
189}
190
191// Poller as string.
192func (p *Poller) String() string {
193	str := make([]string, 0)
194	for i, poll := range p.items {
195		str = append(str, fmt.Sprintf("%v%v", p.socks[i], State(poll.events)))
196	}
197	return fmt.Sprint("Poller", str)
198}
199