1#!/usr/bin/env python
2"""
3Copyright 2016 Google Inc. All Rights Reserved.
4
5Licensed under the Apache License, Version 2.0 (the "License");
6you may not use this file except in compliance with the License.
7You may obtain a copy of the License at
8
9    http://www.apache.org/licenses/LICENSE-2.0
10
11Unless required by applicable law or agreed to in writing, software
12distributed under the License is distributed on an "AS IS" BASIS,
13WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14See the License for the specific language governing permissions and
15limitations under the License.
16"""
17import asyncore
18import gc
19import logging
20import platform
21import Queue
22import re
23import signal
24import socket
25import sys
26import threading
27import time
28
# ---- Proxy plumbing (populated by main()) ----
server = None          # Socks5Server listening for browser connections
options = None         # parsed command-line options
in_pipe = None         # TSPipe delivering messages to the browser-side ('client') peer
out_pipe = None        # TSPipe delivering messages to the 'server' peer

# ---- Connection bookkeeping ----
connections = {}       # client_id -> {'client': Socks5Connection, 'server': TCPConnection or None}
dns_cache = {}         # hostname -> {'addresses': getaddrinfo result, 'localhost': bool}
dest_addresses = None  # getaddrinfo result for --desthost, when given
port_mappings = None   # str(source port) or 'default' -> destination port
map_localhost = False  # also remap connections already aimed at localhost

# ---- Run-loop coordination ----
must_exit = False      # set by the SIGINT handler to stop run_loop()
needs_flush = False    # set by the command thread; run_loop turns it into flush_pipes
flush_pipes = False    # while True, pipes deliver every queued message immediately
last_activity = None   # time of the last message delivered to a peer (drives idle GC)
lock = threading.Lock()        # guards background_activity_count
background_activity_count = 0  # number of AsyncDNS lookups in flight

# TCP payload fraction: 1460 bytes of payload per 1500-byte ethernet frame.
REMOVE_TCP_OVERHEAD = 1460.0 / 1500.0

# Clock used for all timing; prefer the optional monotonic module when it imports.
if sys.platform == "win32":
  current_time = time.clock
else:
  current_time = time.time
try:
  import monotonic
  current_time = monotonic.monotonic
except Exception:
  pass
51
52
def PrintMessage(msg):
  """Write msg to stdout and flush right away.

  The explicit flush keeps the message from sitting in a buffer when tsproxy
  is run as a subprocess and its stdout is a pipe.
  """
  stdout = sys.stdout
  print >> stdout, msg
  stdout.flush()
58
59########################################################################################################################
#   Traffic-shaping pipe (applies per-direction latency and bandwidth limits)
61########################################################################################################################
class TSPipe():
  """One direction of traffic shaping between the browser and server sides.

  Messages are dicts carrying at least 'message' (the type) and 'connection'
  (the client id); 'data' messages also carry the payload in 'data'. Each
  message is held until its 'time' (submission time plus the one-way latency)
  has passed and, when a bandwidth limit is set, until enough byte credit has
  accumulated. It is then delivered to the 'client' (PIPE_IN) or 'server'
  (PIPE_OUT) peer of its connection.
  """
  PIPE_IN = 0
  PIPE_OUT = 1

  def __init__(self, direction, latency, kbps):
    """Create a pipe.

    Args:
      direction: PIPE_IN (deliver to the 'client' peer) or PIPE_OUT (deliver
        to the 'server' peer).
      latency: one-way delay in seconds added to every non-'closed' message.
      kbps: bandwidth limit in kilobits per second; 0 disables it.
    """
    self.direction = direction
    self.latency = latency
    self.kbps = kbps
    self.queue = Queue.Queue()
    self.last_tick = current_time()
    self.next_message = None  # head of the queue, already taken off it
    self.available_bytes = .0  # byte credit accumulated for bandwidth shaping
    self.peer = 'server'
    if self.direction == self.PIPE_IN:
      self.peer = 'client'

  def SendMessage(self, message, main_thread = True):
    """Stamp message with 'time' and 'size' and submit it for delivery.

    When shaping is disabled and the caller is on the main (asyncore) thread
    the message is delivered immediately; otherwise it is queued for tick().
    Background threads pass main_thread=False so delivery always happens on
    the main thread.
    """
    now = current_time()
    if message['message'] == 'closed':
      # No extra latency for close notifications (they still queue behind
      # earlier messages, since the queue is FIFO).
      message['time'] = now
    else:
      message['time'] = now + self.latency
    message['size'] = float(len(message['data'])) if 'data' in message else .0
    message_sent = False
    try:
      connection_id = message['connection']
      # Send messages directly, bypassing the queue, if throttling is disabled and we are on the main thread
      if (main_thread and self.latency == 0 and self.kbps == .0 and
          connection_id in connections and self.peer in connections[connection_id]):
        message_sent = self.SendPeerMessage(message)
    except Exception:
      logging.exception('Direct delivery of a {0} message failed'.format(message.get('message')))
    if not message_sent:
      self.queue.put(message)

  def SendPeerMessage(self, message):
    """Deliver message to this pipe's peer on its connection.

    Returns True if the peer's handle_message() completed. If it raises, both
    ends of the connection are closed (best effort) and the connection is
    removed from the connections table. Returns False when the connection or
    peer is unknown.
    """
    global last_activity
    last_activity = current_time()
    connection_id = message['connection']
    endpoints = connections.get(connection_id)
    if endpoints is None or self.peer not in endpoints:
      return False
    try:
      endpoints[self.peer].handle_message(message)
      return True
    except Exception:
      # Clean up any disconnected connections
      for side in ('server', 'client'):
        try:
          endpoints[side].close()
        except Exception:
          pass
      # The handler may already have removed the entry (e.g. via handle_close),
      # so do not assume it is still there.
      connections.pop(connection_id, None)
      return False

  def tick(self):
    """Deliver every queued message that is currently allowed to go.

    A message may go once its 'time' has passed and, with a bandwidth limit,
    once its size fits in the accumulated byte credit. While the module-level
    flush_pipes flag is set everything queued is delivered. Returns True if at
    least one message was delivered.
    """
    processed_messages = False
    now = current_time()
    try:
      if self.next_message is None:
        self.next_message = self.queue.get_nowait()

      # Accumulate bandwidth if an available packet/message was waiting since our last tick
      if self.next_message is not None and self.kbps > .0 and self.next_message['time'] <= now:
        elapsed = now - self.last_tick
        self.available_bytes += elapsed * self.kbps * 1000.0 / 8.0

      # Process messages as long as the next message is sendable (latency and available bytes)
      while self.next_message is not None and (
          flush_pipes or (self.next_message['time'] <= now and
                          (self.kbps <= .0 or self.next_message['size'] <= self.available_bytes))):
        # Detach the head before delivering it so a failed delivery is never
        # retried and task_done() is called exactly once per message.
        message = self.next_message
        self.next_message = None
        self.queue.task_done()
        processed_messages = True
        if self.kbps > .0:
          self.available_bytes -= message['size']
        self.SendPeerMessage(message)
        self.next_message = self.queue.get_nowait()
    except Queue.Empty:
      pass
    except Exception:
      logging.exception('TSPipe tick failed')

    # Only accumulate bytes while we have messages that are ready to send
    if self.next_message is None or self.next_message['time'] > now:
      self.available_bytes = .0
    self.last_tick = now

    return processed_messages
160
161
162########################################################################################################################
163#   Threaded DNS resolver
164########################################################################################################################
class AsyncDNS(threading.Thread):
  """Background thread that resolves one hostname with getaddrinfo().

  The result is posted to result_pipe as a 'resolved' message (with an empty
  address tuple on failure). The thread then connects to the proxy's own
  listening socket so the main loop wakes up and delivers the message.
  """
  def __init__(self, client_id, hostname, port, is_localhost, result_pipe):
    threading.Thread.__init__(self)
    self.hostname = hostname
    self.port = port
    self.client_id = client_id
    self.is_localhost = is_localhost
    self.result_pipe = result_pipe

  def run(self):
    global background_activity_count
    try:
      logging.debug('[{0:d}] AsyncDNS - calling getaddrinfo for {1}:{2:d}'.format(self.client_id, self.hostname, self.port))
      addresses = socket.getaddrinfo(self.hostname, self.port)
      logging.info('[{0:d}] Resolving {1}:{2:d} Completed'.format(self.client_id, self.hostname, self.port))
    except Exception:
      addresses = ()
      logging.info('[{0:d}] Resolving {1}:{2:d} Failed'.format(self.client_id, self.hostname, self.port))
    message = {'message': 'resolved', 'connection': self.client_id, 'addresses': addresses, 'localhost': self.is_localhost}
    # main_thread=False: always queue, so delivery happens on the main thread.
    self.result_pipe.SendMessage(message, False)
    with lock:
      if background_activity_count > 0:
        background_activity_count -= 1
    # open and close a local socket which will interrupt the long polling loop to process the message
    s = socket.socket()
    try:
      s.connect((server.ipaddr, server.port))
    except socket.error:
      # The loop still picks the message up on its next (at most 1s) poll.
      logging.debug('[{0:d}] AsyncDNS - unable to wake the main loop'.format(self.client_id))
    finally:
      s.close()
193
194
195########################################################################################################################
196#   TCP Client
197########################################################################################################################
class TCPConnection(asyncore.dispatcher):
  """Outbound (server-side) half of one proxied connection.

  Messages from the browser side arrive through out_pipe and are dispatched
  by handle_message(): 'resolve', 'connect', 'data' and 'closed'. Results and
  server data go back to the browser side through in_pipe as 'resolved',
  'connected', 'data' and 'closed' messages.
  """
  STATE_ERROR = -1
  STATE_IDLE = 0
  STATE_RESOLVING = 1
  STATE_CONNECTING = 2
  STATE_CONNECTED = 3

  def __init__(self, client_id):
    asyncore.dispatcher.__init__(self)
    self.client_id = client_id
    self.state = self.STATE_IDLE
    self.buffer = ''            # data waiting to be written to the server
    self.addr = None            # destination, reported in 'connected' messages
    self.dns_thread = None      # AsyncDNS thread while a lookup is running
    self.hostname = None
    self.port = None
    self.needs_config = True    # socket options are applied on the first write
    self.needs_close = False    # close once the buffer has drained
    self.did_resolve = False    # set once a 'resolve' request has been handled

  def SendMessage(self, type, message):
    """Tag message with its type and our client id and submit it to in_pipe."""
    message['message'] = type
    message['connection'] = self.client_id
    in_pipe.SendMessage(message)

  def handle_message(self, message):
    """Dispatch a message delivered by out_pipe."""
    if message['message'] == 'data' and 'data' in message and len(message['data']):
      self.buffer += message['data']
      if self.state == self.STATE_CONNECTED:
        self.handle_write()
    elif message['message'] == 'resolve':
      self.HandleResolve(message)
    elif message['message'] == 'connect':
      self.HandleConnect(message)
    elif message['message'] == 'closed':
      # Drain anything still buffered before closing.
      if len(self.buffer) == 0:
        self.handle_close()
      else:
        self.needs_close = True

  def handle_error(self):
    """Report a failed connect (if one was in progress) and close the channel.

    asyncore's default handle_error() also closes the channel; without that a
    failed socket would stay registered in the asyncore map and keep being
    polled.
    """
    logging.warning('[{0:d}] Error'.format(self.client_id))
    if self.state == self.STATE_CONNECTING:
      self.SendMessage('connected', {'success': False, 'address': self.addr})
    self.handle_close()

  def handle_close(self):
    """Close the server socket and notify the browser side if it is still present."""
    logging.info('[{0:d}] Server Connection Closed'.format(self.client_id))
    self.state = self.STATE_ERROR
    self.close()
    try:
      if self.client_id in connections:
        if 'server' in connections[self.client_id]:
          del connections[self.client_id]['server']
        if 'client' in connections[self.client_id]:
          self.SendMessage('closed', {})
        else:
          # Neither side is left; forget the connection.
          del connections[self.client_id]
    except Exception:
      pass

  def handle_connect(self):
    """asyncore callback: the non-blocking connect has completed."""
    if self.state == self.STATE_CONNECTING:
      self.state = self.STATE_CONNECTED
      self.SendMessage('connected', {'success': True, 'address': self.addr})
      logging.info('[{0:d}] Connected'.format(self.client_id))
    self.handle_write()

  def writable(self):
    # Stay writable while connecting so asyncore can detect connect completion.
    if self.state == self.STATE_CONNECTING:
      return True
    return len(self.buffer) > 0

  def handle_write(self):
    """Apply socket options on first use and write as much buffered data as possible."""
    if self.needs_config:
      self.needs_config = False
      self.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
      self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 128 * 1024)
      self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 128 * 1024)
    if len(self.buffer) > 0:
      sent = self.send(self.buffer)
      logging.debug('[{0:d}] TCP => {1:d} byte(s)'.format(self.client_id, sent))
      self.buffer = self.buffer[sent:]
      if self.needs_close and len(self.buffer) == 0:
        self.needs_close = False
        self.handle_close()

  def handle_read(self):
    """Drain the socket in 1460-byte chunks and forward data to the browser side."""
    try:
      while True:
        data = self.recv(1460)
        if not data:
          return
        if self.state == self.STATE_CONNECTED:
          logging.debug('[{0:d}] TCP <= {1:d} byte(s)'.format(self.client_id, len(data)))
          self.SendMessage('data', {'data': data})
    except Exception:
      # recv() raises (e.g. EWOULDBLOCK) once the non-blocking socket is
      # drained; that is what ends this loop.
      pass

  def HandleResolve(self, message):
    """Resolve message['hostname'] for message['port'].

    localhost/127.0.0.1 requests are flagged as localhost. When a --desthost
    mapping applies, the mapped addresses are reported immediately; otherwise
    the lookup runs on an AsyncDNS thread that reports back through in_pipe.
    """
    global background_activity_count
    self.did_resolve = True
    is_localhost = False
    if 'hostname' in message:
      self.hostname = message['hostname']
    self.port = 0
    if 'port' in message:
      self.port = message['port']
    logging.info('[{0:d}] Resolving {1}:{2:d}'.format(self.client_id, self.hostname, self.port))
    if self.hostname == 'localhost':
      self.hostname = '127.0.0.1'
    if self.hostname == '127.0.0.1':
      logging.info('[{0:d}] Connection to localhost detected'.format(self.client_id))
      is_localhost = True
    if (dest_addresses is not None) and (not is_localhost or map_localhost):
      logging.info('[{0:d}] Resolving {1}:{2:d} to mapped address {3}'.format(self.client_id, self.hostname, self.port, dest_addresses))
      self.SendMessage('resolved', {'addresses': dest_addresses, 'localhost': False})
    else:
      with lock:
        background_activity_count += 1
      self.state = self.STATE_RESOLVING
      self.dns_thread = AsyncDNS(self.client_id, self.hostname, self.port, is_localhost, in_pipe)
      self.dns_thread.start()

  def HandleConnect(self, message):
    """Open the outbound socket for a 'connect' request.

    message['addresses'] is a getaddrinfo() result list; its first entry (or
    the first --desthost address when remapping applies) is used. The port
    goes through GetDestPort() unless this is an unmapped localhost request.
    """
    if 'addresses' in message and len(message['addresses']):
      self.state = self.STATE_CONNECTING
      is_localhost = False
      if 'localhost' in message:
        is_localhost = message['localhost']
      elif not self.did_resolve and message['addresses'][0][4][0] == '127.0.0.1':
        # Entries are getaddrinfo() 5-tuples; the host is the sockaddr's first item.
        logging.info('[{0:d}] Connection to localhost detected'.format(self.client_id))
        is_localhost = True
      if (dest_addresses is not None) and (not is_localhost or map_localhost):
        self.addr = dest_addresses[0]
      else:
        self.addr = message['addresses'][0]
      self.create_socket(self.addr[0], socket.SOCK_STREAM)
      addr = self.addr[4][0]
      if not is_localhost or map_localhost:
        port = GetDestPort(message['port'])
      else:
        port = message['port']
      logging.info('[{0:d}] Connecting to {1}:{2:d}'.format(self.client_id, addr, port))
      self.connect((addr, port))
346
347
348########################################################################################################################
349#   Socks5 Server
350########################################################################################################################
class Socks5Server(asyncore.dispatcher):
  """Listening socket that accepts browser connections.

  Every accepted socket gets the next client id and a Socks5Connection,
  recorded in the connections table with no server side yet.
  """

  def __init__(self, host, port):
    """Bind and listen on (host, port); exits the process with status 1 on failure."""
    asyncore.dispatcher.__init__(self)
    self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      self.set_reuse_addr()
      self.bind((host, port))
      self.listen(socket.SOMAXCONN)
      # Record the actual address (port 0 means one was assigned for us).
      self.ipaddr, self.port = self.getsockname()
      self.current_client_id = 0
    except Exception:
      PrintMessage("Unable to listen on {0}:{1}. Is the port already in use?".format(host, port))
      # sys.exit() rather than the site-provided exit(), which is meant for
      # the interactive interpreter and is not guaranteed to exist.
      sys.exit(1)

  def handle_accept(self):
    """Register a Socks5Connection for each newly accepted browser socket."""
    pair = self.accept()
    if pair is not None:
      sock, addr = pair
      self.current_client_id += 1
      logging.info('[{0:d}] Incoming connection from {1}'.format(self.current_client_id, repr(addr)))
      connections[self.current_client_id] = {
        'client' : Socks5Connection(sock, self.current_client_id),
        'server' : None
      }
377
378
379# Socks5 reference: https://en.wikipedia.org/wiki/SOCKS#SOCKS5
class Socks5Connection(asyncore.dispatcher):
  """Browser-facing (client-side) half of one proxied connection.

  Speaks the server side of SOCKS5: method negotiation (only "no
  authentication" is accepted) and CONNECT requests. Browser data goes to the
  server side through out_pipe; data and status messages delivered by in_pipe
  are buffered and written back to the browser.
  """
  STATE_ERROR = -1
  STATE_WAITING_FOR_HANDSHAKE = 0
  STATE_WAITING_FOR_CONNECT_REQUEST = 1
  STATE_RESOLVING = 2
  STATE_CONNECTING = 3
  STATE_CONNECTED = 4

  def __init__(self, connected_socket, client_id):
    asyncore.dispatcher.__init__(self, connected_socket)
    self.client_id = client_id
    self.state = self.STATE_WAITING_FOR_HANDSHAKE
    self.ip = None                 # literal IPv4/IPv6 destination, if one was requested
    self.addresses = None          # getaddrinfo() result for the destination
    self.hostname = None           # domain-name destination, if one was requested
    self.port = None
    self.requested_address = None  # ATYP, address and port bytes echoed in replies
    self.buffer = ''               # data waiting to be written to the browser
    self.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 128 * 1024)
    self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 128 * 1024)
    self.needs_close = False       # close once the buffer has drained

  def SendMessage(self, type, message):
    """Tag message with its type and our client id and submit it to out_pipe."""
    message['message'] = type
    message['connection'] = self.client_id
    out_pipe.SendMessage(message)

  def handle_message(self, message):
    """Dispatch a message delivered by in_pipe."""
    if message['message'] == 'data' and 'data' in message and len(message['data']) > 0:
      self.buffer += message['data']
      if self.state == self.STATE_CONNECTED:
        self.handle_write()
    elif message['message'] == 'resolved':
      self.HandleResolved(message)
    elif message['message'] == 'connected':
      self.HandleConnected(message)
      self.handle_write()
    elif message['message'] == 'closed':
      if len(self.buffer) == 0:
        logging.info('[{0:d}] Server connection close being processed, closing Browser connection'.format(self.client_id))
        self.handle_close()
      else:
        logging.info('[{0:d}] Server connection close being processed, queuing browser connection close'.format(self.client_id))
        self.needs_close = True

  def writable(self):
    return len(self.buffer) > 0

  def handle_write(self):
    """Write as much buffered data as possible; close if a close was queued and we drained."""
    if len(self.buffer) > 0:
      sent = self.send(self.buffer)
      logging.debug('[{0:d}] SOCKS <= {1:d} byte(s)'.format(self.client_id, sent))
      self.buffer = self.buffer[sent:]
      if self.needs_close and len(self.buffer) == 0:
        logging.info('[{0:d}] queued browser connection close being processed, closing Browser connection'.format(self.client_id))
        self.needs_close = False
        self.handle_close()

  def handle_read(self):
    """Drain the browser socket, feeding the SOCKS state machine or forwarding data."""
    try:
      while True:
        # Consume in up-to packet-sized chunks (TCP packet payload as 1460 bytes from 1500 byte ethernet frames)
        data = self.recv(1460)
        if not data:
          return
        if self.state == self.STATE_CONNECTED:
          logging.debug('[{0:d}] SOCKS => {1:d} byte(s)'.format(self.client_id, len(data)))
          self.SendMessage('data', {'data': data})
        elif self.state == self.STATE_WAITING_FOR_HANDSHAKE:
          self._HandleHandshake(data)
        elif self.state == self.STATE_WAITING_FOR_CONNECT_REQUEST:
          self._HandleConnectRequest(data)
    except Exception:
      # recv() raises (e.g. EWOULDBLOCK) once the non-blocking socket is
      # drained, which ends the loop. Parse errors also land here, leaving
      # the connection in STATE_ERROR.
      pass

  def _HandleHandshake(self, data):
    """Handle the greeting (VER, NMETHODS, METHODS...); accept only method 0x00."""
    data_len = len(data)
    self.state = self.STATE_ERROR #default to an error state, set correctly if things work out
    if data_len >= 2 and ord(data[0]) == 0x05:
      auth_count = ord(data[1])
      supports_no_auth = (data_len == auth_count + 2 and
                          any(ord(method) == 0 for method in data[2:]))
      if supports_no_auth:
        # Respond with a message that "No Authentication" was agreed to
        logging.info('[{0:d}] New Socks5 client'.format(self.client_id))
        self.state = self.STATE_WAITING_FOR_CONNECT_REQUEST
        self.buffer += chr(0x05) + chr(0x00)
        self.handle_write()

  def _HandleConnectRequest(self, data):
    """Handle a request (VER, CMD, RSV, ATYP, DST.ADDR, DST.PORT).

    Only CONNECT (CMD 0x01) creates a server side. IPv4/IPv6 literals are
    resolved synchronously and sent as 'connect'; domain names come from
    dns_cache or are sent as a 'resolve' request.
    """
    data_len = len(data)
    self.state = self.STATE_ERROR #default to an error state, set correctly if things work out
    if data_len < 10 or ord(data[0]) != 0x05 or ord(data[2]) != 0x00:
      return
    if ord(data[1]) == 0x01: #TCP connection (only supported method for now)
      connections[self.client_id]['server'] = TCPConnection(self.client_id)
    self.requested_address = data[3:]
    port_offset = 0
    address_type = ord(data[3])
    if address_type == 0x01:
      port_offset = 8
      self.ip = '{0:d}.{1:d}.{2:d}.{3:d}'.format(ord(data[4]), ord(data[5]), ord(data[6]), ord(data[7]))
    elif address_type == 0x03:
      name_len = ord(data[4])
      # 4 header bytes + 1 length byte + name + 2 port bytes.
      if data_len >= 7 + name_len:
        port_offset = 5 + name_len
        self.hostname = data[5:5 + name_len]
    elif address_type == 0x04 and data_len >= 22:
      port_offset = 20
      self.ip = ''
      for i in range(16):
        self.ip += '{0:02x}'.format(ord(data[4 + i]))
        if i % 2 and i < 15:
          self.ip += ':'
    if not port_offset or connections[self.client_id]['server'] is None:
      return
    self.port = 256 * ord(data[port_offset]) + ord(data[port_offset + 1])
    if not self.port:
      return
    if self.ip is None and self.hostname is not None:
      if self.hostname in dns_cache:
        self.state = self.STATE_CONNECTING
        cache_entry = dns_cache[self.hostname]
        self.addresses = cache_entry['addresses']
        self.SendMessage('connect', {'addresses': self.addresses, 'port': self.port, 'localhost': cache_entry['localhost']})
      else:
        self.state = self.STATE_RESOLVING
        self.SendMessage('resolve', {'hostname': self.hostname, 'port': self.port})
    elif self.ip is not None:
      self.state = self.STATE_CONNECTING
      logging.debug('[{0:d}] Socks Connect - calling getaddrinfo for {1}:{2:d}'.format(self.client_id, self.ip, self.port))
      self.addresses = socket.getaddrinfo(self.ip, self.port)
      self.SendMessage('connect', {'addresses': self.addresses, 'port': self.port})

  def handle_close(self):
    """Close the browser socket and notify the server side if it is still present."""
    logging.info('[{0:d}] Browser Connection Closed by browser'.format(self.client_id))
    self.state = self.STATE_ERROR
    self.close()
    try:
      if self.client_id in connections:
        if 'client' in connections[self.client_id]:
          del connections[self.client_id]['client']
        if 'server' in connections[self.client_id]:
          self.SendMessage('closed', {})
        else:
          del connections[self.client_id]
    except Exception:
      pass

  def HandleResolved(self, message):
    """Continue a CONNECT once the server side reports the DNS result."""
    if self.state == self.STATE_RESOLVING:
      if 'addresses' in message and len(message['addresses']):
        self.state = self.STATE_CONNECTING
        self.addresses = message['addresses']
        dns_cache[self.hostname] = {'addresses': self.addresses, 'localhost': message['localhost']}
        logging.debug('[{0:d}] Resolved {1}, Connecting'.format(self.client_id, self.hostname))
        self.SendMessage('connect', {'addresses': self.addresses, 'port': self.port, 'localhost': message['localhost']})
      else:
        # Send host unreachable error: VER, REP=0x04, RSV=0x00, then the
        # requested ATYP/address/port (same layout as HandleConnected's reply).
        self.state = self.STATE_ERROR
        self.buffer += chr(0x05) + chr(0x04) + chr(0x00) + self.requested_address
        self.handle_write()

  def HandleConnected(self, message):
    """Send the SOCKS reply (success or host unreachable) for a pending CONNECT."""
    if 'success' in message and self.state == self.STATE_CONNECTING:
      response = chr(0x05)
      if message['success']:
        response += chr(0x00)
        logging.debug('[{0:d}] Connected to {1}'.format(self.client_id, self.hostname))
        self.state = self.STATE_CONNECTED
      else:
        response += chr(0x04)
        self.state = self.STATE_ERROR
      response += chr(0x00)
      response += self.requested_address
      self.buffer += response
      self.handle_write()
557
558
559########################################################################################################################
560#   stdin command processor
561########################################################################################################################
class CommandProcessor():
  """Reads control commands from stdin on a daemon thread.

  One command per line: 'flush', 'set rtt|inkbps|outkbps|mapports <value>'
  or 'reset rtt|inkbps|outkbps|mapports|all'. Unrecognized or malformed
  commands print 'ERROR'; accepted ones request a pipe flush, after which the
  run loop prints 'OK'.
  """
  def __init__(self):
    thread = threading.Thread(target = self.run, args=())
    thread.daemon = True
    thread.start()

  def run(self):
    """Process stdin lines until must_exit is set."""
    while not must_exit:
      line = sys.stdin.readline()
      if line:
        self.ProcessCommand(line.strip())
      else:
        # readline() returns '' at EOF (stdin closed or redirected). Back off
        # instead of spinning on the exhausted stream; a terminal may still
        # deliver more input after Ctrl-D.
        time.sleep(0.5)

  def ProcessCommand(self, input):
    """Apply one (already stripped) command line and wake the run loop if needed."""
    global needs_flush
    global port_mappings
    if len(input):
      ok = False
      try:
        command = input.split()
        if len(command) and len(command[0]):
          verb = command[0].lower()
          if verb == 'flush':
            ok = True
          elif verb == 'set' and len(command) >= 3:
            target = command[1].lower()
            value = command[2]
            if target == 'rtt' and len(value):
              # Half of the RTT in each direction, converted from ms to seconds.
              latency = float(value) / 2000.0
              in_pipe.latency = latency
              out_pipe.latency = latency
              ok = True
            elif target == 'inkbps' and len(value):
              in_pipe.kbps = float(value) * REMOVE_TCP_OVERHEAD
              ok = True
            elif target == 'outkbps' and len(value):
              out_pipe.kbps = float(value) * REMOVE_TCP_OVERHEAD
              ok = True
            elif target == 'mapports' and len(value):
              SetPortMappings(value)
              ok = True
          elif verb == 'reset' and len(command) >= 2:
            target = command[1].lower()
            if target in ('rtt', 'all'):
              in_pipe.latency = 0
              out_pipe.latency = 0
              ok = True
            if target in ('inkbps', 'all'):
              in_pipe.kbps = 0
              ok = True
            if target in ('outkbps', 'all'):
              out_pipe.kbps = 0
              ok = True
            if target in ('mapports', 'all'):
              port_mappings = {}
              ok = True

          if ok:
            needs_flush = True
      except Exception:
        # Malformed values (e.g. a non-numeric rtt) are reported as ERROR below.
        pass
      if not ok:
        PrintMessage('ERROR')
      # open and close a local socket which will interrupt the long polling loop to process the flush
      if needs_flush:
        s = socket.socket()
        try:
          s.connect((server.ipaddr, server.port))
        except socket.error:
          # The loop still notices needs_flush on its next (at most 1s) poll.
          logging.warning('Unable to wake the main loop for a flush')
        finally:
          s.close()
630
631
632########################################################################################################################
633#   Main Entry Point
634########################################################################################################################
def main():
  """Parse the command line, configure logging and shaping, then run the proxy."""
  global server
  global options
  global in_pipe
  global out_pipe
  global dest_addresses
  global map_localhost
  import argparse
  parser = argparse.ArgumentParser(description='Traffic-shaping socks5 proxy.',
                                   prog='tsproxy')
  parser.add_argument('-v', '--verbose', action='count', default=0, help="Increase verbosity (specify multiple times for more). -vvvv for full debug output.")
  parser.add_argument('--logfile', help="Write log messages to given file instead of stdout.")
  parser.add_argument('-b', '--bind', default='localhost', help="Server interface address (defaults to localhost).")
  parser.add_argument('-p', '--port', type=int, default=1080, help="Server port (defaults to 1080, use 0 for randomly assigned).")
  parser.add_argument('-r', '--rtt', type=float, default=.0, help="Round Trip Time Latency (in ms).")
  parser.add_argument('-i', '--inkbps', type=float, default=.0, help="Download Bandwidth (in 1000 bits/s - Kbps).")
  parser.add_argument('-o', '--outkbps', type=float, default=.0, help="Upload Bandwidth (in 1000 bits/s - Kbps).")
  parser.add_argument('-w', '--window', type=int, default=10, help="Emulated TCP initial congestion window (defaults to 10).")
  parser.add_argument('-d', '--desthost', help="Redirect all outbound connections to the specified host.")
  parser.add_argument('-m', '--mapports', help="Remap outbound ports. Comma-separated list of original:new with * as a wildcard. --mapports '443:8443,*:8080'")
  parser.add_argument('-l', '--localhost', action='store_true', default=False,
                      help="Include connections already destined for localhost/127.0.0.1 in the host and port remapping.")
  options = parser.parse_args()

  # Set up logging: each -v lowers the threshold one level; -vvvv or more is DEBUG.
  log_levels = [logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG]
  verbosity = options.verbose or 0
  log_config = {'level': log_levels[min(verbosity, len(log_levels) - 1)],
                'format': "%(asctime)s.%(msecs)03d - %(message)s",
                'datefmt': "%H:%M:%S"}
  if options.logfile is not None:
    log_config['filename'] = options.logfile
  logging.basicConfig(**log_config)

  # Parse any port mappings
  if options.mapports:
    SetPortMappings(options.mapports)

  map_localhost = options.localhost

  # Resolve the address for a rewrite destination host if one was specified
  if options.desthost:
    dest_port = GetDestPort(80)
    logging.debug('Startup - calling getaddrinfo for {0}:{1:d}'.format(options.desthost, dest_port))
    dest_addresses = socket.getaddrinfo(options.desthost, dest_port)

  # Set up the pipes.  1/2 of the latency gets applied in each direction (and /1000 to convert to seconds)
  one_way_latency = options.rtt / 2000.0
  in_pipe = TSPipe(TSPipe.PIPE_IN, one_way_latency, options.inkbps * REMOVE_TCP_OVERHEAD)
  out_pipe = TSPipe(TSPipe.PIPE_OUT, one_way_latency, options.outkbps * REMOVE_TCP_OVERHEAD)

  signal.signal(signal.SIGINT, signal_handler)
  server = Socks5Server(options.bind, options.port)
  command_processor = CommandProcessor()  # starts the stdin command thread
  PrintMessage('Started Socks5 proxy server on {0}:{1:d}\nHit Ctrl-C to exit.'.format(server.ipaddr, server.port))
  run_loop()
697
def signal_handler(signal, frame):
  """SIGINT handler: ask run_loop() to stop.

  The module-level server is deliberately left bound. Deleting the global
  made a second Ctrl-C raise NameError inside this handler, and the AsyncDNS
  and CommandProcessor threads still read server.ipaddr/server.port to wake
  the main loop.
  """
  global must_exit
  logging.error('Exiting...')
  must_exit = True
704
705
# Wrapper around the asyncore loop that ticks the in/out pipes every 1ms while shaped traffic or DNS lookups are pending (every 1s otherwise)
def run_loop():
  """Drive asyncore and the traffic-shaping pipes until must_exit is set.

  The poll timeout is 1ms while shaped traffic is pending or DNS lookups are
  running, 1s otherwise (other threads connect to the listening socket to wake
  the loop early). Automatic garbage collection is disabled while running; a
  manual collection runs after 5 seconds without peer activity. The Windows
  timer resolution and gc state are restored on exit, even on errors.
  """
  global needs_flush
  global flush_pipes
  global last_activity
  winmm = None

  # increase the windows timer resolution to 1ms
  if platform.system() == "Windows":
    try:
      import ctypes
      winmm = ctypes.WinDLL('winmm')
      winmm.timeBeginPeriod(1)
    except Exception:
      logging.debug('Unable to raise the Windows timer resolution')

  last_activity = current_time()
  last_check = current_time()
  # disable gc to avoid pauses during traffic shaping/proxying
  gc.disable()
  try:
    while not must_exit:
      # Tick every 1ms if traffic-shaping is enabled and we have data or are doing background dns lookups, every 1 second otherwise
      with lock:
        tick_interval = 0.001
        if background_activity_count == 0:
          pipes_idle = (in_pipe.next_message is None and in_pipe.queue.empty() and
                        out_pipe.next_message is None and out_pipe.queue.empty())
          shaping_off = (in_pipe.kbps == .0 and in_pipe.latency == 0 and
                         out_pipe.kbps == .0 and out_pipe.latency == 0)
          if pipes_idle or shaping_off:
            tick_interval = 1.0
      asyncore.poll(tick_interval, asyncore.socket_map)
      if needs_flush:
        flush_pipes = True
        needs_flush = False
      out_pipe.tick()
      in_pipe.tick()
      if flush_pipes:
        PrintMessage('OK')
        flush_pipes = False
      # Every 500 ms check to see if it is a good time to do a gc
      now = current_time()
      if now - last_check > 0.5:
        last_check = now
        # manually gc after 5 seconds of idle
        if now - last_activity >= 5:
          last_activity = now
          logging.debug("Triggering manual GC")
          gc.collect()
  finally:
    gc.enable()
    if winmm is not None:
      winmm.timeEndPeriod(1)
760
def GetDestPort(port):
  """Translate an outbound port through the configured port_mappings.

  Returns the explicit mapping for str(port) when one exists, otherwise the
  wildcard ('default') mapping when one exists, otherwise port itself. With no
  mappings configured (None) the port is returned unchanged.
  """
  if port_mappings is None:
    return port
  key = str(port)
  if key in port_mappings:
    return port_mappings[key]
  return port_mappings.get('default', port)
770
771
def SetPortMappings(map_string):
  """Replace port_mappings from a spec such as '443:8443,*:8080'.

  Each comma-separated entry is 'source:destination'; a source of '*' sets
  the default used for ports without an explicit entry. Quotes/whitespace
  around the whole string and whitespace around entries and ports are
  ignored, and empty entries (e.g. from a trailing comma) are skipped.

  Raises:
    ValueError: an entry is not 'source:destination' or a destination is not
      an integer. port_mappings is left unchanged in that case.
  """
  global port_mappings
  # Build into a local dict so a malformed spec cannot leave partial mappings.
  mappings = {}
  map_string = map_string.strip('\'" \t\r\n')
  for pair in map_string.split(','):
    pair = pair.strip()
    if not pair:
      continue
    (src, dest) = [part.strip() for part in pair.split(':')]
    if src == '*':
      mappings['default'] = int(dest)
      logging.debug("Default port mapped to port {0}".format(dest))
    else:
      logging.debug("Port {0} mapped to port {1}".format(src, dest))
      mappings[src] = int(dest)
  port_mappings = mappings
784
785
# Run the proxy only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
788