1 /*
2  * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  *
26  */
27 
28 import java.net.*;
29 import java.io.*;
30 import java.nio.*;
31 import java.nio.channels.*;
32 import sun.net.www.MessageHeader;
33 import java.util.*;
34 
35 public class TunnelProxy {
36 
37     ServerSocketChannel schan;
38     int threads;
39     int cperthread;
40     Server[] servers;
41 
42     /**
43      * Create a <code>TunnelProxy<code> instance with the specified callback object
44      * for handling requests. One thread is created to handle requests,
45      * and up to ten TCP connections will be handled simultaneously.
46      * @param cb the callback object which is invoked to handle each
47      *  incoming request
48      */
49 
TunnelProxy()50     public TunnelProxy () throws IOException {
51         this (1, 10, 0);
52     }
53 
54     /**
55      * Create a <code>TunnelProxy<code> instance with the specified number of
56      * threads and maximum number of connections per thread. This functions
57      * the same as the 4 arg constructor, where the port argument is set to zero.
58      * @param cb the callback object which is invoked to handle each
59      *     incoming request
60      * @param threads the number of threads to create to handle requests
61      *     in parallel
62      * @param cperthread the number of simultaneous TCP connections to
63      *     handle per thread
64      */
65 
TunnelProxy(int threads, int cperthread)66     public TunnelProxy (int threads, int cperthread)
67         throws IOException {
68         this (threads, cperthread, 0);
69     }
70 
71     /**
72      * Create a <code>TunnelProxy<code> instance with the specified number
73      * of threads and maximum number of connections per thread and running on
74      * the specified port. The specified number of threads are created to
75      * handle incoming requests, and each thread is allowed
76      * to handle a number of simultaneous TCP connections.
77      * @param cb the callback object which is invoked to handle
78      *  each incoming request
79      * @param threads the number of threads to create to handle
80      *  requests in parallel
81      * @param cperthread the number of simultaneous TCP connections
82      *  to handle per thread
83      * @param port the port number to bind the server to. <code>Zero</code>
84      *  means choose any free port.
85      */
86 
TunnelProxy(int threads, int cperthread, int port)87     public TunnelProxy (int threads, int cperthread, int port)
88         throws IOException {
89         schan = ServerSocketChannel.open ();
90         InetSocketAddress addr = new InetSocketAddress (port);
91         schan.socket().bind (addr);
92         this.threads = threads;
93         this.cperthread = cperthread;
94         servers = new Server [threads];
95         for (int i=0; i<threads; i++) {
96             servers[i] = new Server (schan, cperthread);
97             servers[i].start();
98         }
99     }
100 
101     /** Tell all threads in the server to exit within 5 seconds.
102      *  This is an abortive termination. Just prior to the thread exiting
103      *  all channels in that thread waiting to be closed are forceably closed.
104      */
105 
terminate()106     public void terminate () {
107         for (int i=0; i<threads; i++) {
108             servers[i].terminate ();
109         }
110     }
111 
112     /**
113      * return the local port number to which the server is bound.
114      * @return the local port number
115      */
116 
getLocalPort()117     public int getLocalPort () {
118         return schan.socket().getLocalPort ();
119     }
120 
121     static class Server extends Thread {
122 
123         ServerSocketChannel schan;
124         Selector selector;
125         SelectionKey listenerKey;
126         SelectionKey key; /* the current key being processed */
127         ByteBuffer consumeBuffer;
128         int maxconn;
129         int nconn;
130         ClosedChannelList clist;
131         boolean shutdown;
132         Pipeline pipe1 = null;
133         Pipeline pipe2 = null;
134 
Server(ServerSocketChannel schan, int maxconn)135         Server (ServerSocketChannel schan, int maxconn) {
136             this.schan = schan;
137             this.maxconn = maxconn;
138             nconn = 0;
139             consumeBuffer = ByteBuffer.allocate (512);
140             clist = new ClosedChannelList ();
141             try {
142                 selector = Selector.open ();
143                 schan.configureBlocking (false);
144                 listenerKey = schan.register (selector, SelectionKey.OP_ACCEPT);
145             } catch (IOException e) {
146                 System.err.println ("Server could not start: " + e);
147             }
148         }
149 
150         /* Stop the thread as soon as possible */
terminate()151         public synchronized void terminate () {
152             shutdown = true;
153             if (pipe1 != null) pipe1.terminate();
154             if (pipe2 != null) pipe2.terminate();
155         }
156 
run()157         public void run ()  {
158             try {
159                 while (true) {
160                     selector.select (1000);
161                     Set selected = selector.selectedKeys();
162                     Iterator iter = selected.iterator();
163                     while (iter.hasNext()) {
164                         key = (SelectionKey)iter.next();
165                         if (key.equals (listenerKey)) {
166                             SocketChannel sock = schan.accept ();
167                             if (sock == null) {
168                                 /* false notification */
169                                 iter.remove();
170                                 continue;
171                             }
172                             sock.configureBlocking (false);
173                             sock.register (selector, SelectionKey.OP_READ);
174                             nconn ++;
175                             if (nconn == maxconn) {
176                                 /* deregister */
177                                 listenerKey.cancel ();
178                                 listenerKey = null;
179                             }
180                         } else {
181                             if (key.isReadable()) {
182                                 boolean closed;
183                                 SocketChannel chan = (SocketChannel) key.channel();
184                                 if (key.attachment() != null) {
185                                     closed = consume (chan);
186                                 } else {
187                                     closed = read (chan, key);
188                                 }
189                                 if (closed) {
190                                     chan.close ();
191                                     key.cancel ();
192                                     if (nconn == maxconn) {
193                                         listenerKey = schan.register (selector, SelectionKey.OP_ACCEPT);
194                                     }
195                                     nconn --;
196                                 }
197                             }
198                         }
199                         iter.remove();
200                     }
201                     clist.check();
202                     if (shutdown) {
203                         clist.terminate ();
204                         return;
205                     }
206                 }
207             } catch (IOException e) {
208                 System.out.println ("Server exception: " + e);
209                 // TODO finish
210             }
211         }
212 
213         /* read all the data off the channel without looking at it
214              * return true if connection closed
215              */
consume(SocketChannel chan)216         boolean consume (SocketChannel chan) {
217             try {
218                 consumeBuffer.clear ();
219                 int c = chan.read (consumeBuffer);
220                 if (c == -1)
221                     return true;
222             } catch (IOException e) {
223                 return true;
224             }
225             return false;
226         }
227 
228         /* return true if the connection is closed, false otherwise */
229 
read(SocketChannel chan, SelectionKey key)230         private boolean read (SocketChannel chan, SelectionKey key) {
231             HttpTransaction msg;
232             boolean res;
233             try {
234                 InputStream is = new BufferedInputStream (new NioInputStream (chan));
235                 String requestline = readLine (is);
236                 MessageHeader mhead = new MessageHeader (is);
237                 String[] req = requestline.split (" ");
238                 if (req.length < 2) {
239                     /* invalid request line */
240                     return false;
241                 }
242                 String cmd = req[0];
243                 URI uri = null;
244                 if (!("CONNECT".equalsIgnoreCase(cmd))) {
245                     // we expect CONNECT command
246                     return false;
247                 }
248                 try {
249                     uri = new URI("http://" + req[1]);
250                 } catch (URISyntaxException e) {
251                     System.err.println ("Invalid URI: " + e);
252                     res = true;
253                 }
254 
255                 // CONNECT ack
256                 OutputStream os = new BufferedOutputStream(new NioOutputStream(chan));
257                 byte[] ack = "HTTP/1.1 200 Connection established\r\n\r\n".getBytes();
258                 os.write(ack, 0, ack.length);
259                 os.flush();
260 
261                 // tunnel anything else
262                 tunnel(is, os, uri);
263 
264                 res = false;
265             } catch (IOException e) {
266                 res = true;
267             }
268             return res;
269         }
270 
tunnel(InputStream fromClient, OutputStream toClient, URI serverURI)271         private void tunnel(InputStream fromClient, OutputStream toClient, URI serverURI) throws IOException {
272             Socket sockToServer = new Socket(serverURI.getHost(), serverURI.getPort());
273             OutputStream toServer = sockToServer.getOutputStream();
274             InputStream fromServer = sockToServer.getInputStream();
275 
276             pipe1 = new Pipeline(fromClient, toServer);
277             pipe2 = new Pipeline(fromServer, toClient);
278             // start pump
279             pipe1.start();
280             pipe2.start();
281             // wait them to end
282             try {
283                 pipe1.join();
284             } catch (InterruptedException e) {
285                 // No-op
286             } finally {
287                 sockToServer.close();
288             }
289         }
290 
readLine(InputStream is)291         private String readLine (InputStream is) throws IOException {
292             boolean done=false, readCR=false;
293             byte[] b = new byte [512];
294             int c, l = 0;
295 
296             while (!done) {
297                 c = is.read ();
298                 if (c == '\n' && readCR) {
299                     done = true;
300                 } else {
301                     if (c == '\r' && !readCR) {
302                         readCR = true;
303                     } else {
304                         b[l++] = (byte)c;
305                     }
306                 }
307             }
308             return new String (b);
309         }
310 
311         /** close the channel associated with the current key by:
312          * 1. shutdownOutput (send a FIN)
313          * 2. mark the key so that incoming data is to be consumed and discarded
314          * 3. After a period, close the socket
315          */
316 
orderlyCloseChannel(SelectionKey key)317         synchronized void orderlyCloseChannel (SelectionKey key) throws IOException {
318             SocketChannel ch = (SocketChannel)key.channel ();
319             ch.socket().shutdownOutput();
320             key.attach (this);
321             clist.add (key);
322         }
323 
abortiveCloseChannel(SelectionKey key)324         synchronized void abortiveCloseChannel (SelectionKey key) throws IOException {
325             SocketChannel ch = (SocketChannel)key.channel ();
326             Socket s = ch.socket ();
327             s.setSoLinger (true, 0);
328             ch.close();
329         }
330     }
331 
332 
333     /**
334      * Implements blocking reading semantics on top of a non-blocking channel
335      */
336 
337     static class NioInputStream extends InputStream {
338         SocketChannel channel;
339         Selector selector;
340         ByteBuffer chanbuf;
341         SelectionKey key;
342         int available;
343         byte[] one;
344         boolean closed;
345         ByteBuffer markBuf; /* reads may be satisifed from this buffer */
346         boolean marked;
347         boolean reset;
348         int readlimit;
349 
NioInputStream(SocketChannel chan)350         public NioInputStream (SocketChannel chan) throws IOException {
351             this.channel = chan;
352             selector = Selector.open();
353             chanbuf = ByteBuffer.allocate (1024);
354             key = chan.register (selector, SelectionKey.OP_READ);
355             available = 0;
356             one = new byte[1];
357             closed = marked = reset = false;
358         }
359 
read(byte[] b)360         public synchronized int read (byte[] b) throws IOException {
361             return read (b, 0, b.length);
362         }
363 
read()364         public synchronized int read () throws IOException {
365             return read (one, 0, 1);
366         }
367 
read(byte[] b, int off, int srclen)368         public synchronized int read (byte[] b, int off, int srclen) throws IOException {
369 
370             int canreturn, willreturn;
371 
372             if (closed)
373                 return -1;
374 
375             if (reset) { /* satisfy from markBuf */
376                 canreturn = markBuf.remaining ();
377                 willreturn = canreturn>srclen ? srclen : canreturn;
378                 markBuf.get(b, off, willreturn);
379                 if (canreturn == willreturn) {
380                     reset = false;
381                 }
382             } else { /* satisfy from channel */
383                 canreturn = available();
384                 if (canreturn == 0) {
385                     block ();
386                     canreturn = available();
387                 }
388                 willreturn = canreturn>srclen ? srclen : canreturn;
389                 chanbuf.get(b, off, willreturn);
390                 available -= willreturn;
391 
392                 if (marked) { /* copy into markBuf */
393                     try {
394                         markBuf.put (b, off, willreturn);
395                     } catch (BufferOverflowException e) {
396                         marked = false;
397                     }
398                 }
399             }
400             return willreturn;
401         }
402 
available()403         public synchronized int available () throws IOException {
404             if (closed)
405                 throw new IOException ("Stream is closed");
406 
407             if (reset)
408                 return markBuf.remaining();
409 
410             if (available > 0)
411                 return available;
412 
413             chanbuf.clear ();
414             available = channel.read (chanbuf);
415             if (available > 0)
416                 chanbuf.flip();
417             else if (available == -1)
418                 throw new IOException ("Stream is closed");
419             return available;
420         }
421 
422         /**
423          * block() only called when available==0 and buf is empty
424          */
block()425         private synchronized void block () throws IOException {
426             //assert available == 0;
427             int n = selector.select ();
428             //assert n == 1;
429             selector.selectedKeys().clear();
430             available ();
431         }
432 
close()433         public void close () throws IOException {
434             if (closed)
435                 return;
436             channel.close ();
437             closed = true;
438         }
439 
mark(int readlimit)440         public synchronized void mark (int readlimit) {
441             if (closed)
442                 return;
443             this.readlimit = readlimit;
444             markBuf = ByteBuffer.allocate (readlimit);
445             marked = true;
446             reset = false;
447         }
448 
reset()449         public synchronized void reset () throws IOException {
450             if (closed )
451                 return;
452             if (!marked)
453                 throw new IOException ("Stream not marked");
454             marked = false;
455             reset = true;
456             markBuf.flip ();
457         }
458     }
459 
460     static class NioOutputStream extends OutputStream {
461         SocketChannel channel;
462         ByteBuffer buf;
463         SelectionKey key;
464         Selector selector;
465         boolean closed;
466         byte[] one;
467 
NioOutputStream(SocketChannel channel)468         public NioOutputStream (SocketChannel channel) throws IOException {
469             this.channel = channel;
470             selector = Selector.open ();
471             key = channel.register (selector, SelectionKey.OP_WRITE);
472             closed = false;
473             one = new byte [1];
474         }
475 
write(int b)476         public synchronized void write (int b) throws IOException {
477             one[0] = (byte)b;
478             write (one, 0, 1);
479         }
480 
write(byte[] b)481         public synchronized void write (byte[] b) throws IOException {
482             write (b, 0, b.length);
483         }
484 
write(byte[] b, int off, int len)485         public synchronized void write (byte[] b, int off, int len) throws IOException {
486             if (closed)
487                 throw new IOException ("stream is closed");
488 
489             buf = ByteBuffer.allocate (len);
490             buf.put (b, off, len);
491             buf.flip ();
492             int n;
493             while ((n = channel.write (buf)) < len) {
494                 len -= n;
495                 if (len == 0)
496                     return;
497                 selector.select ();
498                 selector.selectedKeys().clear ();
499             }
500         }
501 
close()502         public void close () throws IOException {
503             if (closed)
504                 return;
505             channel.close ();
506             closed = true;
507         }
508     }
509 
510     /*
511      * Pipeline object :-
512      * 1) Will pump every byte from its input stream to output stream
513      * 2) Is an 'active object'
514      */
515     static class Pipeline implements Runnable {
516         InputStream in;
517         OutputStream out;
518         Thread t;
519 
Pipeline(InputStream is, OutputStream os)520         public Pipeline(InputStream is, OutputStream os) {
521             in = is;
522             out = os;
523         }
524 
start()525         public void start() {
526             t = new Thread(this);
527             t.start();
528         }
529 
join()530         public void join() throws InterruptedException {
531             t.join();
532         }
533 
terminate()534         public void terminate() {
535             t.interrupt();
536         }
537 
run()538         public void run() {
539             byte[] buffer = new byte[10000];
540             try {
541                 while (!Thread.interrupted()) {
542                     int len;
543                     while ((len = in.read(buffer)) != -1) {
544                         out.write(buffer, 0, len);
545                         out.flush();
546                     }
547                 }
548             } catch(IOException e) {
549                 // No-op
550             } finally {
551             }
552         }
553     }
554 
555     /**
556      * Utilities for synchronization. A condition is
557      * identified by a string name, and is initialized
558      * upon first use (ie. setCondition() or waitForCondition()). Threads
559      * are blocked until some thread calls (or has called) setCondition() for the same
560      * condition.
561      * <P>
562      * A rendezvous built on a condition is also provided for synchronizing
563      * N threads.
564      */
565 
566     private static HashMap conditions = new HashMap();
567 
568     /*
569      * Modifiable boolean object
570      */
571     private static class BValue {
572         boolean v;
573     }
574 
575     /*
576      * Modifiable int object
577      */
578     private static class IValue {
579         int v;
IValue(int i)580         IValue (int i) {
581             v =i;
582         }
583     }
584 
585 
getCond(String condition)586     private static BValue getCond (String condition) {
587         synchronized (conditions) {
588             BValue cond = (BValue) conditions.get (condition);
589             if (cond == null) {
590                 cond = new BValue();
591                 conditions.put (condition, cond);
592             }
593             return cond;
594         }
595     }
596 
597     /**
598      * Set the condition to true. Any threads that are currently blocked
599      * waiting on the condition, will be unblocked and allowed to continue.
600      * Threads that subsequently call waitForCondition() will not block.
601      * If the named condition did not exist prior to the call, then it is created
602      * first.
603      */
604 
setCondition(String condition)605     public static void setCondition (String condition) {
606         BValue cond = getCond (condition);
607         synchronized (cond) {
608             if (cond.v) {
609                 return;
610             }
611             cond.v = true;
612             cond.notifyAll();
613         }
614     }
615 
616     /**
617      * If the named condition does not exist, then it is created and initialized
618      * to false. If the condition exists or has just been created and its value
619      * is false, then the thread blocks until another thread sets the condition.
620      * If the condition exists and is already set to true, then this call returns
621      * immediately without blocking.
622      */
623 
waitForCondition(String condition)624     public static void waitForCondition (String condition) {
625         BValue cond = getCond (condition);
626         synchronized (cond) {
627             if (!cond.v) {
628                 try {
629                     cond.wait();
630                 } catch (InterruptedException e) {}
631             }
632         }
633     }
634 
635     /* conditions must be locked when accessing this */
636     static HashMap rv = new HashMap();
637 
638     /**
639      * Force N threads to rendezvous (ie. wait for each other) before proceeding.
640      * The first thread(s) to call are blocked until the last
641      * thread makes the call. Then all threads continue.
642      * <p>
643      * All threads that call with the same condition name, must use the same value
644      * for N (or the results may be not be as expected).
645      * <P>
646      * Obviously, if fewer than N threads make the rendezvous then the result
647      * will be a hang.
648      */
649 
rendezvous(String condition, int N)650     public static void rendezvous (String condition, int N) {
651         BValue cond;
652         IValue iv;
653         String name = "RV_"+condition;
654 
655         /* get the condition */
656 
657         synchronized (conditions) {
658             cond = (BValue)conditions.get (name);
659             if (cond == null) {
660                 /* we are first caller */
661                 if (N < 2) {
662                     throw new RuntimeException ("rendezvous must be called with N >= 2");
663                 }
664                 cond = new BValue ();
665                 conditions.put (name, cond);
666                 iv = new IValue (N-1);
667                 rv.put (name, iv);
668             } else {
669                 /* already initialised, just decrement the counter */
670                 iv = (IValue) rv.get (name);
671                 iv.v --;
672             }
673         }
674 
675         if (iv.v > 0) {
676             waitForCondition (name);
677         } else {
678             setCondition (name);
679             synchronized (conditions) {
680                 clearCondition (name);
681                 rv.remove (name);
682             }
683         }
684     }
685 
686     /**
687      * If the named condition exists and is set then remove it, so it can
688      * be re-initialized and used again. If the condition does not exist, or
689      * exists but is not set, then the call returns without doing anything.
690      * Note, some higher level synchronization
691      * may be needed between clear and the other operations.
692      */
693 
clearCondition(String condition)694     public static void clearCondition(String condition) {
695         BValue cond;
696         synchronized (conditions) {
697             cond = (BValue) conditions.get (condition);
698             if (cond == null) {
699                 return;
700             }
701             synchronized (cond) {
702                 if (cond.v) {
703                     conditions.remove (condition);
704                 }
705             }
706         }
707     }
708 }
709