1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4#
5
6"""
This module contains the implementation of TCP forwarding, which allows
clients and servers to forward arbitrary TCP data across an SSH connection.
9
10Maintainer: Paul Swartz
11"""
12
13import struct
14
15from twisted.internet import protocol, reactor
16from twisted.python import log
17
18import common, channel
19
class SSHListenForwardingFactory(protocol.Factory):
    """
    Factory for locally accepted TCP connections that are forwarded over an
    SSH connection.

    For every accepted connection a new channel of class C{klass} is created
    on C{conn}, paired with an L{SSHForwardingClient}, and opened with the
    packed destination C{hostport} plus the peer address of the accepted
    connection.
    """

    def __init__(self, connection, hostport, klass):
        self.conn = connection
        self.hostport = hostport  # (host, port) tuple
        self.klass = klass

    def buildProtocol(self, addr):
        # Named 'chan' so the imported 'channel' module is not shadowed.
        chan = self.klass(conn=self.conn)
        forwarder = SSHForwardingClient(chan)
        chan.client = forwarder
        peer = (addr.host, addr.port)
        self.conn.openChannel(chan, packOpen_direct_tcpip(self.hostport, peer))
        return forwarder
34
class SSHListenForwardingChannel(channel.SSHChannel):
    """
    Base channel relaying data for a locally accepted TCP connection.

    The listening factory assigns C{client} (an L{SSHForwardingClient})
    before asking the connection to open the channel; data is relayed
    between this channel and C{client.transport}.
    """

    def channelOpen(self, specificData):
        log.msg('opened forwarding channel %s' % self.id)
        # client.buf begins with a one-byte sentinel; anything after it is
        # TCP data that arrived before the channel was open.
        pending = self.client.buf[1:]
        if pending:
            self.write(pending)
        # An empty buffer tells the client to write straight to the channel.
        self.client.buf = ''

    def openFailed(self, reason):
        # A refused open is handled like a close: drop the local connection.
        self.closed()

    def dataReceived(self, data):
        self.client.transport.write(data)

    def eofReceived(self):
        self.client.transport.loseConnection()

    def closed(self):
        # 'client' is deleted after the first close, so later calls no-op.
        if not hasattr(self, 'client'):
            return
        log.msg('closing local forwarding channel %s' % self.id)
        self.client.transport.loseConnection()
        del self.client
58
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
    """
    Listen-forwarding channel opened with the C{direct-tcpip} channel type.
    """
    name = 'direct-tcpip'
62
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
    """
    Listen-forwarding channel opened with the C{forwarded-tcpip} channel type.
    """
    name = 'forwarded-tcpip'
66
class SSHConnectForwardingChannel(channel.SSHChannel):
    """
    Channel that opens a TCP connection to C{hostport} and relays data
    between that connection and the SSH channel.

    @ivar hostport: C{(host, port)} to connect to.
    @ivar client: the L{SSHForwardingClient} connected to C{hostport}, or
        C{None} while the TCP connection is pending and after the channel
        has been closed.
    @ivar clientBuf: channel data received before C{client} was connected.
    """

    # Set by closed(); lets _setClient drop a TCP connection that completes
    # after the channel has already gone away instead of leaking it.
    _forwardingClosed = False

    def __init__(self, hostport, *args, **kw):
        channel.SSHChannel.__init__(self, *args, **kw)
        self.hostport = hostport
        self.client = None
        self.clientBuf = ''

    def channelOpen(self, specificData):
        """
        Start connecting to C{hostport} once the channel is open.
        """
        cc = protocol.ClientCreator(reactor, SSHForwardingClient, self)
        log.msg("connecting to %s:%i" % self.hostport)
        cc.connectTCP(*self.hostport).addCallbacks(self._setClient, self._close)

    def _setClient(self, client):
        """
        Called when the TCP connection to C{hostport} succeeds: flush data
        buffered in both directions and switch to direct relaying.

        @param client: the connected L{SSHForwardingClient}.
        """
        if self._forwardingClosed:
            # The channel closed while we were still connecting; detach the
            # client so its connectionLost does not touch this channel, and
            # drop the connection rather than leaving it open.
            log.msg("forwarding channel closed before %s:%i connected"
                    % self.hostport)
            client.channel = None
            client.transport.loseConnection()
            return
        self.client = client
        log.msg("connected to %s:%i" % self.hostport)
        if self.clientBuf:
            self.client.transport.write(self.clientBuf)
            # Reset to an empty string (not None) so the buffer stays a
            # valid string if dataReceived ever buffers again.
            self.clientBuf = ''
        # client.buf starts with a one-byte sentinel; the rest is TCP data
        # received before this callback ran.
        if self.client.buf[1:]:
            self.write(self.client.buf[1:])
        # An empty buffer tells the client to write straight to the channel.
        self.client.buf = ''

    def _close(self, reason):
        """
        Called when connecting to C{hostport} fails; closes the channel.

        @param reason: the connection failure.
        """
        log.msg("failed to connect: %s" % reason)
        self.loseConnection()

    def dataReceived(self, data):
        """
        Relay channel data to the TCP connection, buffering it until the
        connection exists.
        """
        if self.client:
            self.client.transport.write(data)
        else:
            self.clientBuf += data

    def closed(self):
        """
        Tear down the TCP side when the channel closes.
        """
        self._forwardingClosed = True
        if self.client:
            log.msg('closed remote forwarding channel %s' % self.id)
            # client.channel is cleared by the client's connectionLost, so
            # this only runs when the TCP side is still up.
            if self.client.channel:
                self.loseConnection()
            self.client.transport.loseConnection()
            # Reset rather than 'del': the attribute has no class-level
            # fallback, so deleting it makes later dataReceived/closed calls
            # raise AttributeError.
            self.client = None
107
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
    """
    Build an L{SSHConnectForwardingChannel} for a channel-open request.

    @param data: the request's type-specific data, decoded with
        L{unpackOpen_direct_tcpip}. Only the destination C{(host, port)} is
        used; the originator address is decoded but ignored.
    @return: a new L{SSHConnectForwardingChannel}.
    """
    destination = unpackOpen_direct_tcpip(data)[0]
    options = dict(remoteWindow=remoteWindow,
                   remoteMaxPacket=remoteMaxPacket,
                   avatar=avatar)
    return SSHConnectForwardingChannel(destination, **options)
114
class SSHForwardingClient(protocol.Protocol):
    """
    Protocol for the TCP side of a forwarded connection.

    C{buf} starts as a one-byte C{'\\000'} sentinel, and received data is
    appended to it while it is non-empty. Once the channel owner resets
    C{buf} to an empty string, incoming data goes straight to C{channel}.
    """

    def __init__(self, channel):
        self.channel = channel
        self.buf = '\000'

    def dataReceived(self, data):
        if not self.buf:
            self.channel.write(data)
            return
        self.buf += data

    def connectionLost(self, reason):
        if not self.channel:
            return
        # Close the channel before detaching it: the channel's close path
        # may still inspect self.channel.
        self.channel.loseConnection()
        self.channel = None
131
132
def packOpen_direct_tcpip(destination, source):
    """Pack the type-specific data for a direct-tcpip or forwarded-tcpip
    CHANNEL_OPEN packet.

    The result is NS(connHost), uint32 connPort, NS(origHost), uint32
    origPort, with the integers packed big-endian.

    @param destination: C{(connHost, connPort)} -- the address to connect
        to (or that was connected to).
    @param source: C{(origHost, origPort)} -- the originator's address.
    @return: the packed string.
    """
    # Unpack in the body instead of the signature: tuple parameters were
    # removed in Python 3 (PEP 3113). Positional callers are unaffected.
    connHost, connPort = destination
    origHost, origPort = source
    conn = common.NS(connHost) + struct.pack('>L', connPort)
    orig = common.NS(origHost) + struct.pack('>L', origPort)
    return conn + orig

packOpen_forwarded_tcpip = packOpen_direct_tcpip
141
def _unpackHostPort(data):
    """Read one NS-encoded host followed by a big-endian uint32 port.

    Returns C{((host, port), remainder)}.
    """
    host, tail = common.getNS(data)
    (port,) = struct.unpack('>L', tail[:4])
    return (host, int(port)), tail[4:]


def unpackOpen_direct_tcpip(data):
    """Unpack direct-tcpip / forwarded-tcpip CHANNEL_OPEN data.

    Returns C{((connHost, connPort), (origHost, origPort))}.
    """
    destination, remainder = _unpackHostPort(data)
    originator = _unpackHostPort(remainder)[0]
    return destination, originator

unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
152
def packGlobal_tcpip_forward(hostport):
    """Pack the data for a tcpip-forward global request.

    @param hostport: C{(host, port)} to bind on the remote side.
    @return: NS(host) followed by the big-endian uint32 port.
    """
    # Unpack in the body: tuple parameters are invalid in Python 3 (PEP 3113).
    host, port = hostport
    return common.NS(host) + struct.pack('>L', port)
155
def unpackGlobal_tcpip_forward(data):
    """Unpack tcpip-forward global request data into C{(host, port)}.

    The data is an NS-encoded host followed by a big-endian uint32 port.
    """
    host, remainder = common.getNS(data)
    (port,) = struct.unpack('>L', remainder[:4])
    return host, int(port)
160
161"""This is how the data -> eof -> close stuff /should/ work.
162
163debug3: channel 1: waiting for connection
164debug1: channel 1: connected
165debug1: channel 1: read<=0 rfd 7 len 0
166debug1: channel 1: read failed
167debug1: channel 1: close_read
168debug1: channel 1: input open -> drain
169debug1: channel 1: ibuf empty
170debug1: channel 1: send eof
171debug1: channel 1: input drain -> closed
172debug1: channel 1: rcvd eof
173debug1: channel 1: output open -> drain
174debug1: channel 1: obuf empty
175debug1: channel 1: close_write
176debug1: channel 1: output drain -> closed
177debug1: channel 1: rcvd close
178debug3: channel 1: will not send data after close
179debug1: channel 1: send close
180debug1: channel 1: is dead
181"""
182