1# Derived from libmux, available in Plan 9 under /sys/src/libmux
2# under the following terms:
3#
4# Copyright (C) 2003-2006 Russ Cox, Massachusetts Institute of Technology
5#
6# Permission is hereby granted, free of charge, to any person obtaining
7# a copy of this software and associated documentation files (the
8# "Software"), to deal in the Software without restriction, including
9# without limitation the rights to use, copy, modify, merge, publish,
10# distribute, sublicense, and/or sell copies of the Software, and to
11# permit persons to whom the Software is furnished to do so, subject to
12# the following conditions:
13#
14# The above copyright notice and this permission notice shall be
15# included in all copies or substantial portions of the Software.
16
17import sys
18import traceback
19
20from pyxp import fields
21from pyxp.dial import dial
22from threading import *
# In Python 2, threading.Condition is a factory function, not a class, so it
# cannot be subclassed directly.  Rebind the name to the class of a Condition
# instance so that Rpc can inherit from it (under Python 3 this evaluates to
# threading.Condition itself).
Condition = type(Condition())

__all__ = ('Mux',)
26
27class Mux(object):
28    def __init__(self, con, process, flush=None, mintag=0, maxtag=1<<16 - 1):
29        self.queue = set()
30        self.lock = RLock()
31        self.rendez = Condition(self.lock)
32        self.outlock = RLock()
33        self.inlock = RLock()
34        self.process = process
35        self.flush = flush
36        self.wait = {}
37        self.free = set(range(mintag, maxtag))
38        self.mintag = mintag
39        self.maxtag = maxtag
40        self.muxer = None
41
42        if isinstance(con, basestring):
43            con = dial(con)
44        self.fd = con
45
46        if self.fd is None:
47            raise Exception("No connection")
48
49    def mux(self, rpc):
50        try:
51            rpc.waiting = True
52            self.lock.acquire()
53            while self.muxer and self.muxer != rpc and rpc.data is None:
54                rpc.wait()
55
56            if rpc.data is None:
57                assert not self.muxer or self.muxer is rpc
58                self.muxer = rpc
59                self.lock.release()
60                try:
61                    while rpc.data is None:
62                        data = self.recv()
63                        if data is None:
64                            self.lock.acquire()
65                            self.queue.remove(rpc)
66                            raise Exception("unexpected eof")
67                        self.dispatch(data)
68                finally:
69                    self.lock.acquire()
70                    self.electmuxer()
71        except Exception, e:
72            traceback.print_exc(sys.stdout)
73            if self.flush:
74                self.flush(self, rpc.data)
75            raise e
76        finally:
77            if self.lock._is_owned():
78                self.lock.release()
79
80        if rpc.async:
81            if callable(rpc.async):
82                rpc.async(self, rpc.data)
83        else:
84            return rpc.data
85
86    def rpc(self, dat, async=None):
87        rpc = self.newrpc(dat, async)
88        if async:
89            with self.lock:
90                if self.muxer is None:
91                    self.electmuxer()
92        else:
93            return self.mux(rpc)
94
95    def electmuxer(self):
96        async = None
97        for rpc in self.queue:
98            if self.muxer != rpc:
99                if rpc.async:
100                    async = rpc
101                else:
102                    self.muxer = rpc
103                    rpc.notify()
104                    return
105        self.muxer = None
106        if async:
107            self.muxer = async
108            t = Thread(target=self.mux, args=(async,))
109            t.daemon = True
110            t.start()
111
112    def dispatch(self, dat):
113        tag = dat.tag
114        rpc = None
115        with self.lock:
116            rpc = self.wait.get(tag, None)
117            if rpc is None or rpc not in self.queue:
118                #print "bad rpc tag: %u (no one waiting on it)" % dat.tag
119                return
120            self.puttag(rpc)
121            self.queue.remove(rpc)
122            rpc.dispatch(dat)
123
124    def gettag(self, r):
125        tag = 0
126
127        while not self.free:
128            self.rendez.wait()
129
130        tag = self.free.pop()
131
132        if tag in self.wait:
133            raise Exception("nwait botch")
134
135        self.wait[tag] = r
136
137        r.tag = tag
138        r.orig.tag = r.tag
139        return r.tag
140
141    def puttag(self, rpc):
142        if rpc.tag in self.wait:
143            del self.wait[rpc.tag]
144        self.free.add(rpc.tag)
145        self.rendez.notify()
146
147    def send(self, dat):
148        data = ''.join(dat.marshall())
149        n = self.fd.send(data)
150        return n == len(data)
151    def recv(self):
152        try:
153            with self.inlock:
154                data = self.fd.recv(4)
155                if data:
156                    len = fields.Int.decoders[4](data, 0)
157                    data += self.fd.recv(len - 4)
158                    return self.process(data)
159        except Exception, e:
160            traceback.print_exc(sys.stdout)
161            print repr(data)
162            return None
163
164    def newrpc(self, dat, async=None):
165        rpc = Rpc(self, dat, async)
166        tag = None
167
168        with self.lock:
169            self.gettag(rpc)
170            self.queue.add(rpc)
171
172        if rpc.tag >= 0 and self.send(dat):
173            return rpc
174
175        with self.lock:
176            self.queue.remove(rpc)
177            self.puttag(rpc)
178
179class Rpc(Condition):
180    def __init__(self, mux, data, async=None):
181        super(Rpc, self).__init__(mux.lock)
182        self.mux = mux
183        self.orig = data
184        self.data = None
185        self.waiting = False
186        self.async = async
187
188    def dispatch(self, data=None):
189        self.data = data
190        if not self.async or self.waiting:
191            self.notify()
192        elif callable(self.async):
193            Thread(target=self.async, args=(self.mux, data)).start()
194
195# vim:se sts=4 sw=4 et:
196