1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3# Copyright 2012 Matt Martz
4# All Rights Reserved.
5#
6#    Licensed under the Apache License, Version 2.0 (the "License"); you may
7#    not use this file except in compliance with the License. You may obtain
8#    a copy of the License at
9#
10#         http://www.apache.org/licenses/LICENSE-2.0
11#
12#    Unless required by applicable law or agreed to in writing, software
13#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15#    License for the specific language governing permissions and limitations
16#    under the License.
17
18import os
19import re
20import csv
21import sys
22import math
23import errno
24import signal
25import socket
26import timeit
27import datetime
28import platform
29import threading
30import xml.parsers.expat
31
32try:
33    import gzip
34    GZIP_BASE = gzip.GzipFile
35except ImportError:
36    gzip = None
37    GZIP_BASE = object
38
# Release version of this module; also embedded in the User-Agent string
# built by build_user_agent().
__version__ = '2.1.3'
40
41
class FakeShutdownEvent(object):
    """Stand-in for a ``threading.Event`` that is never set.

    Classes in this module poll ``shutdown_event.isSet()`` to decide whether
    to stop early. When the caller does not register its own
    ``threading.Event()``, an instance of this class is used instead so the
    check always reports "not set".
    """

    @staticmethod
    def isSet():
        """Dummy method that always returns False"""
        return False

    # Alias matching the modern ``threading.Event.is_set`` spelling;
    # ``isSet`` is kept for existing callers.
    is_set = isSet
51
52
# Some global variables we use
# Debug flag; presumably consulted by printer() (not in this chunk) when it
# is called with debug=True — confirm against printer().
DEBUG = False
# Sentinel meaning "no timeout supplied" for create_connection(). It is
# compared by identity, so any other value (including None) counts as an
# explicit timeout.
_GLOBAL_DEFAULT_TIMEOUT = object()
# Interpreter version gates used to choose compatible APIs
PY25PLUS = sys.version_info[:2] >= (2, 5)
PY26PLUS = sys.version_info[:2] >= (2, 6)
PY32PLUS = sys.version_info[:2] >= (3, 2)
59
60# Begin import game to handle Python 2 and Python 3
61try:
62    import json
63except ImportError:
64    try:
65        import simplejson as json
66    except ImportError:
67        json = None
68
69try:
70    import xml.etree.ElementTree as ET
71    try:
72        from xml.etree.ElementTree import _Element as ET_Element
73    except ImportError:
74        pass
75except ImportError:
76    from xml.dom import minidom as DOM
77    from xml.parsers.expat import ExpatError
78    ET = None
79
80try:
81    from urllib2 import (urlopen, Request, HTTPError, URLError,
82                         AbstractHTTPHandler, ProxyHandler,
83                         HTTPDefaultErrorHandler, HTTPRedirectHandler,
84                         HTTPErrorProcessor, OpenerDirector)
85except ImportError:
86    from urllib.request import (urlopen, Request, HTTPError, URLError,
87                                AbstractHTTPHandler, ProxyHandler,
88                                HTTPDefaultErrorHandler, HTTPRedirectHandler,
89                                HTTPErrorProcessor, OpenerDirector)
90
91try:
92    from httplib import HTTPConnection, BadStatusLine
93except ImportError:
94    from http.client import HTTPConnection, BadStatusLine
95
96try:
97    from httplib import HTTPSConnection
98except ImportError:
99    try:
100        from http.client import HTTPSConnection
101    except ImportError:
102        HTTPSConnection = None
103
104try:
105    from httplib import FakeSocket
106except ImportError:
107    FakeSocket = None
108
109try:
110    from Queue import Queue
111except ImportError:
112    from queue import Queue
113
114try:
115    from urlparse import urlparse
116except ImportError:
117    from urllib.parse import urlparse
118
119try:
120    from urlparse import parse_qs
121except ImportError:
122    try:
123        from urllib.parse import parse_qs
124    except ImportError:
125        from cgi import parse_qs
126
127try:
128    from hashlib import md5
129except ImportError:
130    from md5 import md5
131
132try:
133    from argparse import ArgumentParser as ArgParser
134    from argparse import SUPPRESS as ARG_SUPPRESS
135    PARSER_TYPE_INT = int
136    PARSER_TYPE_STR = str
137    PARSER_TYPE_FLOAT = float
138except ImportError:
139    from optparse import OptionParser as ArgParser
140    from optparse import SUPPRESS_HELP as ARG_SUPPRESS
141    PARSER_TYPE_INT = 'int'
142    PARSER_TYPE_STR = 'string'
143    PARSER_TYPE_FLOAT = 'float'
144
145try:
146    from cStringIO import StringIO
147    BytesIO = None
148except ImportError:
149    try:
150        from StringIO import StringIO
151        BytesIO = None
152    except ImportError:
153        from io import StringIO, BytesIO
154
155try:
156    import __builtin__
157except ImportError:
158    import builtins
159    from io import TextIOWrapper, FileIO
160
161    class _Py3Utf8Output(TextIOWrapper):
162        """UTF-8 encoded wrapper around stdout for py3, to override
163        ASCII stdout
164        """
165        def __init__(self, f, **kwargs):
166            buf = FileIO(f.fileno(), 'w')
167            super(_Py3Utf8Output, self).__init__(
168                buf,
169                encoding='utf8',
170                errors='strict'
171            )
172
173        def write(self, s):
174            super(_Py3Utf8Output, self).write(s)
175            self.flush()
176
    # Keep a handle on the builtin print; the py3 print_() delegates to it.
    _py3_print = getattr(builtins, 'print')
    try:
        # UTF-8 wrappers around the real stdout/stderr descriptors, used by
        # print_() as its default and redirected-stderr targets.
        _py3_utf8_stdout = _Py3Utf8Output(sys.stdout)
        _py3_utf8_stderr = _Py3Utf8Output(sys.stderr)
    except OSError:
        # sys.stdout/sys.stderr is not a compatible stdout/stderr object
        # (no usable fileno()), so just use it and hope things go ok
        # NOTE(review): if sys.stdout is None, None.fileno() raises
        # AttributeError, which is not caught here — confirm whether that
        # environment matters.
        _py3_utf8_stdout = sys.stdout
        _py3_utf8_stderr = sys.stderr
186
187    def to_utf8(v):
188        """No-op encode to utf-8 for py3"""
189        return v
190
191    def print_(*args, **kwargs):
192        """Wrapper function for py3 to print, with a utf-8 encoded stdout"""
193        if kwargs.get('file') == sys.stderr:
194            kwargs['file'] = _py3_utf8_stderr
195        else:
196            kwargs['file'] = kwargs.get('file', _py3_utf8_stdout)
197        _py3_print(*args, **kwargs)
198else:
199    del __builtin__
200
201    def to_utf8(v):
202        """Encode value to utf-8 if possible for py2"""
203        try:
204            return v.encode('utf8', 'strict')
205        except AttributeError:
206            return v
207
    def print_(*args, **kwargs):
        """The new-style print function for Python 2.4 and 2.5.

        Taken from https://pypi.python.org/pypi/six/

        Modified to set encoding to UTF-8 always, and to flush after write

        Supports the ``file`` (default ``sys.stdout``; ``None`` prints
        nothing), ``sep`` and ``end`` keyword arguments. Raises TypeError
        when ``sep``/``end`` are neither None nor strings, or when any
        other keyword argument is passed.
        """
        fp = kwargs.pop("file", sys.stdout)
        if fp is None:
            return

        def write(data):
            # Non-strings are converted with str() first
            if not isinstance(data, basestring):
                data = str(data)
            # Unicode written to a real file object is encoded first; the
            # encoding is always UTF-8, whatever the file itself declares.
            encoding = 'utf8'  # Always trust UTF-8 for output
            if (isinstance(fp, file) and
                    isinstance(data, unicode) and
                    encoding is not None):
                # Use the file's own error handler if it declares one
                errors = getattr(fp, "errors", None)
                if errors is None:
                    errors = "strict"
                data = data.encode(encoding, errors)
            fp.write(data)
            # Flush after every piece so output appears immediately
            fp.flush()
        # Switch to unicode separators when sep, end or any argument is
        # unicode, so the default separators match the argument types.
        want_unicode = False
        sep = kwargs.pop("sep", None)
        if sep is not None:
            if isinstance(sep, unicode):
                want_unicode = True
            elif not isinstance(sep, str):
                raise TypeError("sep must be None or a string")
        end = kwargs.pop("end", None)
        if end is not None:
            if isinstance(end, unicode):
                want_unicode = True
            elif not isinstance(end, str):
                raise TypeError("end must be None or a string")
        if kwargs:
            raise TypeError("invalid keyword arguments to print()")
        if not want_unicode:
            for arg in args:
                if isinstance(arg, unicode):
                    want_unicode = True
                    break
        if want_unicode:
            newline = unicode("\n")
            space = unicode(" ")
        else:
            newline = "\n"
            space = " "
        if sep is None:
            sep = space
        if end is None:
            end = newline
        # Emit the args separated by sep, then end
        for i, arg in enumerate(args):
            if i:
                write(sep)
            write(arg)
        write(end)
268
# Version-appropriate spellings of two renamed APIs: ElementTree's
# Element.iter (Python 3.2+) versus getiterator, and Thread.is_alive
# (Python 2.6+) versus isAlive.
if PY25PLUS:
    if PY32PLUS:
        etree_iter = ET.Element.iter
    else:
        etree_iter = ET_Element.getiterator
# NOTE(review): before Python 2.5 etree_iter is never defined; presumably
# only the minidom code path is used there — confirm.

thread_is_alive = (threading.Thread.is_alive if PY26PLUS
                   else threading.Thread.isAlive)
278
279
# Exception "constants" to support Python 2 through Python 3
try:
    import ssl
except ImportError:
    # No SSL support compiled in: only plain HTTP/socket failures apply
    ssl = None
    HTTP_ERRORS = (HTTPError, URLError, socket.error, BadStatusLine)
else:
    # ssl.CertificateError is missing on older interpreters
    if hasattr(ssl, 'CertificateError'):
        CERT_ERROR = (ssl.CertificateError,)
    else:
        CERT_ERROR = tuple()

    HTTP_ERRORS = (HTTPError, URLError, socket.error, ssl.SSLError,
                   BadStatusLine) + CERT_ERROR
295
296
class SpeedtestException(Exception):
    """Common base class for every error raised by this module"""


class SpeedtestCLIError(SpeedtestException):
    """Catch-all error raised while running the command line interface"""


class SpeedtestHTTPError(SpeedtestException):
    """Base class for HTTP related failures in this module"""


class SpeedtestConfigError(SpeedtestException):
    """The configuration XML could not be understood"""


class SpeedtestServersError(SpeedtestException):
    """The server list XML could not be understood"""


class ConfigRetrievalError(SpeedtestHTTPError):
    """config.php could not be downloaded"""


class ServersRetrievalError(SpeedtestHTTPError):
    """speedtest-servers.php could not be downloaded"""


class InvalidServerIDType(SpeedtestException):
    """A server ID given for filtering is not an integer"""


class NoMatchedServers(SpeedtestException):
    """Filtering left no servers to choose from"""


class SpeedtestMiniConnectFailure(SpeedtestException):
    """The given speedtest mini server could not be reached"""


class InvalidSpeedtestMiniServer(SpeedtestException):
    """The server given as a speedtest mini server does not look like
    one
    """


class ShareResultsConnectFailure(SpeedtestException):
    """The speedtest.net API could not be reached to POST results"""


class ShareResultsSubmitFailure(SpeedtestException):
    """The speedtest.net API was reached, but posting the results did not
    succeed
    """


class SpeedtestUploadTimeout(SpeedtestException):
    """Raised once the configured upload test length has elapsed

    Signals the upload to stop because no further data should be sent
    """


class SpeedtestBestServerFailure(SpeedtestException):
    """The best server could not be determined"""


class SpeedtestMissingBestServer(SpeedtestException):
    """No best server is known: get_best_server was not called or could not
    determine one
    """
365
366
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
                      source_address=None):
    """Connect to *address* and return the socket object.

    Convenience function.  Connect to *address* (a 2-tuple ``(host,
    port)``) and return the socket object.  Passing the optional
    *timeout* parameter will set the timeout on the socket instance
    before attempting to connect; ``None`` selects blocking mode, as in
    the standard library's ``socket.create_connection``.  If no *timeout*
    is supplied, the global default timeout setting returned by
    :func:`getdefaulttimeout` is used.  If *source_address* is set it must
    be a tuple of (host, port) for the socket to bind as a source address
    before making the connection. An host of '' or port 0 tells the OS to
    use the default.

    Each address returned by ``getaddrinfo`` is tried in turn; if all of
    them fail, the last ``socket.error`` is re-raised, and ``socket.error``
    is raised if name resolution returned no addresses at all.

    Largely vendored from Python 2.7, modified to work with Python 2.4
    """

    host, port = address
    err = None
    for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
        af, socktype, proto, _canonname, sa = res
        sock = None
        try:
            sock = socket.socket(af, socktype, proto)
            if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
                # None means blocking mode; float(None) would raise
                if timeout is None:
                    sock.settimeout(None)
                else:
                    sock.settimeout(float(timeout))
            if source_address:
                sock.bind(source_address)
            sock.connect(sa)
            return sock

        except socket.error:
            err = get_exception()
            if sock is not None:
                sock.close()

    if err is not None:
        raise err
    else:
        raise socket.error("getaddrinfo returns an empty list")
406
407
class SpeedtestHTTPConnection(HTTPConnection):
    """``HTTPConnection`` that honours ``source_address`` and ``timeout``
    on every interpreter from Python 2.4 to Python 3.

    Both keyword arguments are removed before the base constructor runs
    and stored on the instance; ``timeout`` defaults to 10 seconds.
    """
    def __init__(self, *args, **kwargs):
        bind_to = kwargs.pop('source_address', None)
        seconds = kwargs.pop('timeout', 10)

        # Initialised up front so connect() can always test it, presumably
        # for base classes that do not set it themselves
        self._tunnel_host = None

        HTTPConnection.__init__(self, *args, **kwargs)

        self.source_address = bind_to
        self.timeout = seconds

    def connect(self):
        """Connect to the host and port specified in __init__."""
        endpoint = (self.host, self.port)
        try:
            self.sock = socket.create_connection(
                endpoint, self.timeout, self.source_address
            )
        except (AttributeError, TypeError):
            # Presumably a missing socket.create_connection, or one that
            # rejects source_address: use the vendored helper instead
            self.sock = create_connection(
                endpoint, self.timeout, self.source_address
            )

        if self._tunnel_host:
            self._tunnel()
440
441
# HTTPSConnection is None when neither httplib nor http.client provides it
# (see the import fallbacks near the top of the file); in that case
# SpeedtestHTTPSConnection is not defined at all.
if HTTPSConnection:
    class SpeedtestHTTPSConnection(HTTPSConnection):
        """Custom HTTPSConnection to support source_address across
        Python 2.4 - Python 3

        ``source_address`` and ``timeout`` (default 10 seconds) are removed
        from the keyword arguments before the base constructor runs and are
        stored on the instance for use by ``connect``.
        """
        default_port = 443

        def __init__(self, *args, **kwargs):
            source_address = kwargs.pop('source_address', None)
            timeout = kwargs.pop('timeout', 10)

            # Initialised up front so connect() can always test it
            self._tunnel_host = None

            HTTPSConnection.__init__(self, *args, **kwargs)

            self.timeout = timeout
            self.source_address = source_address

        def connect(self):
            """Connect to a host on a given (SSL) port.

            Opens the TCP connection (honouring ``source_address`` and
            ``timeout``), sets up any configured tunnel, then wraps the
            socket in SSL using whichever API this interpreter offers.

            Raises SpeedtestException when no SSL support is available.
            """
            try:
                self.sock = socket.create_connection(
                    (self.host, self.port),
                    self.timeout,
                    self.source_address
                )
            except (AttributeError, TypeError):
                # Presumably socket.create_connection is missing or rejects
                # source_address on this interpreter: use the vendored helper
                self.sock = create_connection(
                    (self.host, self.port),
                    self.timeout,
                    self.source_address
                )

            if self._tunnel_host:
                self._tunnel()

            if ssl:
                try:
                    kwargs = {}
                    if hasattr(ssl, 'SSLContext'):
                        # server_hostname (SNI / hostname checking): the
                        # tunnel target when tunnelling, else our host
                        if self._tunnel_host:
                            kwargs['server_hostname'] = self._tunnel_host
                        else:
                            kwargs['server_hostname'] = self.host
                    self.sock = self._context.wrap_socket(self.sock, **kwargs)
                except AttributeError:
                    # Reached when self._context does not exist. This also
                    # catches an AttributeError raised inside wrap_socket.
                    # NOTE(review): ssl.wrap_socket was removed in Python
                    # 3.12 and does no hostname verification; confirm this
                    # path is only reachable on legacy interpreters.
                    self.sock = ssl.wrap_socket(self.sock)
                    try:
                        self.sock.server_hostname = self.host
                    except AttributeError:
                        pass
            elif FakeSocket:
                # Python 2.4/2.5 support
                try:
                    self.sock = FakeSocket(self.sock, socket.ssl(self.sock))
                except AttributeError:
                    raise SpeedtestException(
                        'This version of Python does not support HTTPS/SSL '
                        'functionality'
                    )
            else:
                raise SpeedtestException(
                    'This version of Python does not support HTTPS/SSL '
                    'functionality'
                )
507
508
def _build_connection(connection, source_address, timeout, context=None):
    """Return a factory ``inner(host, **kwargs)`` that instantiates
    *connection* with our ``source_address`` and ``timeout`` forced in.

    Values passed by the caller under those keys are overridden. ``context``
    is only added when it is truthy. Used as the connection class argument
    of ``do_open`` in ``SpeedtestHTTPHandler`` / ``SpeedtestHTTPSHandler``,
    working across Python 2.4 - Python 3.
    """
    def inner(host, **kwargs):
        forced = {'source_address': source_address, 'timeout': timeout}
        if context:
            forced['context'] = context
        kwargs.update(forced)
        return connection(host, **kwargs)
    return inner
525
526
class SpeedtestHTTPHandler(AbstractHTTPHandler):
    """``HTTPHandler`` replacement that opens ``http://`` URLs through
    ``SpeedtestHTTPConnection``, so the configured ``source_address`` and
    ``timeout`` apply to every connection it makes.
    """
    def __init__(self, debuglevel=0, source_address=None, timeout=10):
        AbstractHTTPHandler.__init__(self, debuglevel)
        self.timeout = timeout
        self.source_address = source_address

    def http_open(self, req):
        """Open *req* with a connection factory bound to our settings."""
        factory = _build_connection(SpeedtestHTTPConnection,
                                    self.source_address,
                                    self.timeout)
        return self.do_open(factory, req)

    # Reuse urllib's standard request preprocessing
    http_request = AbstractHTTPHandler.do_request_
547
548
class SpeedtestHTTPSHandler(AbstractHTTPHandler):
    """``HTTPSHandler`` replacement that opens ``https://`` URLs through
    ``SpeedtestHTTPSConnection``, applying the configured
    ``source_address``, ``timeout`` and optional SSL ``context``.
    """
    def __init__(self, debuglevel=0, context=None, source_address=None,
                 timeout=10):
        AbstractHTTPHandler.__init__(self, debuglevel)
        self.timeout = timeout
        self.source_address = source_address
        self._context = context

    def https_open(self, req):
        """Open *req* with a connection factory bound to our settings."""
        factory = _build_connection(SpeedtestHTTPSConnection,
                                    self.source_address,
                                    self.timeout,
                                    context=self._context)
        return self.do_open(factory, req)

    # Reuse urllib's standard request preprocessing
    https_request = AbstractHTTPHandler.do_request_
572
573
def build_opener(source_address=None, timeout=10):
    """Function similar to ``urllib2.build_opener`` that will build
    an ``OpenerDirector`` with the explicit handlers we want,
    ``source_address`` for binding, ``timeout`` and our custom
    `User-Agent`

    ``source_address``, when given, is a local host/IP that is bound with
    port 0. Only the handlers listed below are installed: proxy, HTTP,
    HTTPS, default error, redirect and error processing.
    """

    # %s rather than %d: the timeout may be fractional (e.g. 2.5), and %d
    # would silently truncate it in the debug output.
    printer('Timeout set to %s' % timeout, debug=True)

    if source_address:
        # Port 0 tells the OS to pick the local port
        source_address_tuple = (source_address, 0)
        printer('Binding to source address: %r' % (source_address_tuple,),
                debug=True)
    else:
        source_address_tuple = None

    handlers = [
        ProxyHandler(),
        SpeedtestHTTPHandler(source_address=source_address_tuple,
                             timeout=timeout),
        SpeedtestHTTPSHandler(source_address=source_address_tuple,
                              timeout=timeout),
        HTTPDefaultErrorHandler(),
        HTTPRedirectHandler(),
        HTTPErrorProcessor()
    ]

    opener = OpenerDirector()
    opener.addheaders = [('User-agent', build_user_agent())]

    for handler in handlers:
        opener.add_handler(handler)

    return opener
608
609
class GzipDecodedResponse(GZIP_BASE):
    """File-like reader that transparently decodes a response body encoded
    with the gzip method, as described in RFC 1952.

    The whole compressed body is first copied into an in-memory buffer,
    because the response itself does not support the tell()/read() usage
    GzipFile requires. Largely copied from ``xmlrpclib``/``xmlrpc.client``
    and modified to work for py2.4-py3.

    Raises SpeedtestHTTPError when gzip support is unavailable.
    """
    def __init__(self, response):
        if not gzip:
            raise SpeedtestHTTPError('HTTP response body is gzip encoded, '
                                     'but gzip support is not available')
        buffer_cls = BytesIO or StringIO
        self.io = buffer_cls()
        chunk = response.read(1024)
        while len(chunk) != 0:
            self.io.write(chunk)
            chunk = response.read(1024)
        # Rewind so decoding starts at the first compressed byte
        self.io.seek(0)
        gzip.GzipFile.__init__(self, mode='rb', fileobj=self.io)

    def close(self):
        """Close the gzip reader, then always release the buffer."""
        try:
            gzip.GzipFile.close(self)
        finally:
            self.io.close()
638
639
def get_exception():
    """Return the exception instance currently being handled.

    Lets code running on py2.4-py3 fetch the active exception from inside
    an ``except`` block without version-specific ``except`` syntax.
    """
    _exc_type, exc_value, _traceback = sys.exc_info()
    return exc_value
645
646
def distance(origin, destination, radius=6371):
    """Determine the great-circle distance between 2 sets of [lat,lon].

    Uses the haversine formula. ``origin`` and ``destination`` are
    ``(lat, lon)`` pairs in degrees. ``radius`` is the sphere radius. The
    default 6371 is the Earth radius in km, so results are in km unless
    another radius is given; the result is then in that radius' unit.
    """

    lat1, lon1 = origin
    lat2, lon2 = destination

    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (math.sin(dlat / 2) * math.sin(dlat / 2) +
         math.cos(math.radians(lat1)) *
         math.cos(math.radians(lat2)) * math.sin(dlon / 2) *
         math.sin(dlon / 2))
    # Floating-point rounding can push ``a`` a hair outside [0, 1] for
    # (near-)antipodal points, which would make sqrt(1 - a) raise
    # ValueError; clamp it back into range.
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    return radius * c
664
665
def build_user_agent():
    """Build a Mozilla/5.0 compatible User-Agent string

    Combines the OS platform string, the interpreter's bit architecture,
    the Python version and this module's version. The string is echoed as
    a debug message and returned.
    """
    platform_part = '(%s; U; %s; en-us)' % (platform.platform(),
                                            platform.architecture()[0])
    pieces = [
        'Mozilla/5.0',
        platform_part,
        'Python/%s' % platform.python_version(),
        '(KHTML, like Gecko)',
        'speedtest-cli/%s' % __version__,
    ]
    user_agent = ' '.join(pieces)
    printer('User-Agent: %s' % user_agent, debug=True)
    return user_agent
680
681
def build_request(url, data=None, headers=None, bump='0', secure=False):
    """Build a urllib2 request object

    ``url`` may start with ``:`` (e.g. ``'://host/path'``), in which case
    ``http`` or ``https`` is prefixed depending on ``secure``. A
    cache-busting ``x=<milliseconds>.<bump>`` query parameter is appended
    and a ``Cache-Control: no-cache`` header is added. (The User-Agent
    comes from the opener built by ``build_opener``, not from here.)

    ``headers`` is copied, so the caller's dict is never modified.
    ``data``, when given, becomes the request body.
    """

    # Copy so the Cache-Control header added below never leaks into a dict
    # the caller may reuse for other requests.
    headers = dict(headers or {})

    # startswith() instead of url[0]: no IndexError on an empty string
    if url.startswith(':'):
        scheme = ('http', 'https')[bool(secure)]
        schemed_url = '%s%s' % (scheme, url)
    else:
        schemed_url = url

    if '?' in url:
        delim = '&'
    else:
        delim = '?'

    # WHO YOU GONNA CALL? CACHE BUSTERS!
    final_url = '%s%sx=%s.%s' % (schemed_url, delim,
                                 int(timeit.time.time() * 1000),
                                 bump)

    headers.update({
        'Cache-Control': 'no-cache',
    })

    printer('%s %s' % (('GET', 'POST')[bool(data)], final_url),
            debug=True)

    return Request(final_url, data=data, headers=headers)
716
717
def catch_request(request, opener=None):
    """Open *request*, turning the usual connection failures into a return
    value instead of an exception.

    *opener*, when given, supplies ``open``; otherwise ``urlopen`` is used.
    Returns ``(response, False)`` on success, or ``(None, exception)`` when
    one of ``HTTP_ERRORS`` is raised. A debug message is printed when the
    response URL differs from the requested one (a redirect).
    """
    open_url = opener.open if opener else urlopen

    try:
        handle = open_url(request)
        redirected = request.get_full_url() != handle.geturl()
        if redirected:
            printer('Redirected to %s' % handle.geturl(), debug=True)
        return handle, False
    except HTTP_ERRORS:
        return None, get_exception()
737
738
def get_response_stream(response):
    """Helper function to return either a Gzip reader if
    ``Content-Encoding`` is ``gzip`` otherwise the response itself

    Content codings are case-insensitive and ``x-gzip`` is an equivalent
    alias of ``gzip`` (RFC 9110, section 8.4.1), so the header value is
    normalised before comparing. A missing header leaves the response as
    is.
    """

    try:
        # py2: the headers object provides getheader
        getheader = response.headers.getheader
    except AttributeError:
        # py3: http.client responses provide getheader directly
        getheader = response.getheader

    encoding = (getheader('content-encoding') or '').strip().lower()
    if encoding in ('gzip', 'x-gzip'):
        return GzipDecodedResponse(response)

    return response
754
755
def get_attributes_by_tag_name(dom, tag_name):
    """Return the attributes of the first *tag_name* element in *dom* as a
    plain ``{name: value}`` dict.

    Only used with xml.dom.minidom, which is likely only to be used
    with python versions older than 2.5. Raises IndexError when the
    document has no such element.
    """
    matches = dom.getElementsByTagName(tag_name)
    first = matches[0]
    return dict(first.attributes.items())
765
766
def print_dots(shutdown_event):
    """Return a progress callback that prints one dot per call.

    The callback takes ``(current, total, start=False, end=False)``. It
    prints nothing once *shutdown_event* is set, and adds a newline after
    the dot when ``current`` is the last index (``total - 1``) and ``end``
    is True. ``start`` is accepted but unused.
    """
    def inner(current, total, start=False, end=False):
        if not shutdown_event.isSet():
            sys.stdout.write('.')
            finished = current + 1 == total and end is True
            if finished:
                sys.stdout.write('\n')
            sys.stdout.flush()
    return inner
780
781
def do_nothing(*args, **kwargs):
    """No-op callback: accept any arguments, do nothing, return None."""
    return None
784
785
class HTTPDownloader(threading.Thread):
    """Thread class for retrieving a URL

    The response is read in 10240-byte chunks until the body is exhausted,
    ``timeout`` seconds have passed since ``start`` (a
    ``timeit.default_timer()`` value), or the shutdown event is set. The
    length of every chunk read is appended to ``self.result``, which starts
    as ``[0]``, so summing it gives the number of bytes read. Connection
    and read errors are swallowed; the thread simply stops early.
    """

    def __init__(self, i, request, start, timeout, opener=None,
                 shutdown_event=None):
        threading.Thread.__init__(self)
        self.request = request
        self.result = [0]
        self.starttime = start
        self.timeout = timeout
        self.i = i
        if opener:
            self._opener = opener.open
        else:
            self._opener = urlopen

        if shutdown_event:
            self._shutdown_event = shutdown_event
        else:
            self._shutdown_event = FakeShutdownEvent()

    def run(self):
        try:
            if (timeit.default_timer() - self.starttime) <= self.timeout:
                f = self._opener(self.request)
                # Always release the connection, even when a read fails
                # part-way through the body.
                try:
                    while (not self._shutdown_event.isSet() and
                            (timeit.default_timer() - self.starttime) <=
                            self.timeout):
                        self.result.append(len(f.read(10240)))
                        if self.result[-1] == 0:
                            break
                finally:
                    f.close()
        except IOError:
            # Best effort: a failed download just contributes fewer bytes
            pass
        except HTTP_ERRORS:
            pass
822
823
class HTTPUploaderData(object):
    """File like object to improve cutting off the upload once the timeout
    has been reached

    ``length`` is the body size in bytes and is what ``len()`` reports. The
    body is ``content1=`` followed by a repeating ``0-9A-Z`` pattern. It is
    built on first access to ``data`` unless ``pre_allocate()`` is called
    up front. ``read()`` raises ``SpeedtestUploadTimeout`` once ``timeout``
    seconds have elapsed since ``start`` or the shutdown event is set.
    """

    def __init__(self, length, start, timeout, shutdown_event=None):
        self.length = length
        self.start = start
        self.timeout = timeout

        if shutdown_event:
            self._shutdown_event = shutdown_event
        else:
            self._shutdown_event = FakeShutdownEvent()

        # Body buffer, created lazily by pre_allocate()
        self._data = None

        # Sizes of the chunks handed out by read(); HTTPUploader sums this
        self.total = [0]

    def pre_allocate(self):
        """Build the request body in memory: ``length`` bytes in total
        (for ``length >= 9``).

        Raises SpeedtestCLIError when there is not enough memory.
        """
        chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        # Characters needed after the 9-byte 'content1=' prefix
        fill = max(0, int(self.length) - 9)
        # Round the repeat count *up* so the pattern is never shorter than
        # ``fill``. Rounding to nearest could come up short (a 65536-byte
        # body produced only 65529 bytes), leaving the body smaller than
        # len(self).
        multiplier = int(math.ceil(fill / 36.0))
        IO = BytesIO or StringIO
        try:
            self._data = IO(
                ('content1=%s' % (chars * multiplier)[0:fill]).encode()
            )
        except MemoryError:
            raise SpeedtestCLIError(
                'Insufficient memory to pre-allocate upload data. Please '
                'use --no-pre-allocate'
            )

    @property
    def data(self):
        """The body buffer, allocated on first access if needed."""
        if not self._data:
            self.pre_allocate()
        return self._data

    def read(self, n=10240):
        """Return up to *n* bytes of the body and record the chunk size.

        Raises SpeedtestUploadTimeout once ``timeout`` seconds have passed
        since ``start`` or the shutdown event is set.
        """
        if ((timeit.default_timer() - self.start) <= self.timeout and
                not self._shutdown_event.isSet()):
            chunk = self.data.read(n)
            self.total.append(len(chunk))
            return chunk
        else:
            raise SpeedtestUploadTimeout()

    def __len__(self):
        """Advertised body size: the ``length`` given at construction."""
        return self.length
876
877
class HTTPUploader(threading.Thread):
    """Thread class for putting a URL

    ``request.data`` is expected to behave like ``HTTPUploaderData``: its
    ``start`` is set to ``start`` here, and after the upload its ``total``
    list is summed into ``self.result``, the number of bytes sent. The
    result is 0 when the upload never started or failed with an HTTP
    error.
    """

    def __init__(self, i, request, start, size, timeout, opener=None,
                 shutdown_event=None):
        threading.Thread.__init__(self)
        self.request = request
        self.request.data.start = self.starttime = start
        self.size = size
        self.result = 0
        self.timeout = timeout
        self.i = i

        if opener:
            self._opener = opener.open
        else:
            self._opener = urlopen

        if shutdown_event:
            self._shutdown_event = shutdown_event
        else:
            self._shutdown_event = FakeShutdownEvent()

    def run(self):
        request = self.request
        try:
            if ((timeit.default_timer() - self.starttime) <= self.timeout and
                    not self._shutdown_event.isSet()):
                try:
                    f = self._opener(request)
                except TypeError:
                    # PY24 expects a string or buffer
                    # This also causes issues with Ctrl-C, but we will concede
                    # for the moment that Ctrl-C on PY24 isn't immediate
                    request = build_request(self.request.get_full_url(),
                                            data=request.data.read(self.size))
                    f = self._opener(request)
                # Release the connection even if reading the reply fails
                try:
                    f.read(11)
                finally:
                    f.close()
                self.result = sum(self.request.data.total)
            else:
                self.result = 0
        except (IOError, SpeedtestUploadTimeout):
            # Timed out or connection dropped: count what was sent so far
            self.result = sum(self.request.data.total)
        except HTTP_ERRORS:
            self.result = 0
924
925
926class SpeedtestResults(object):
927    """Class for holding the results of a speedtest, including:
928
929    Download speed
930    Upload speed
931    Ping/Latency to test server
932    Data about server that the test was run against
933
934    Additionally this class can return a result data as a dictionary or CSV,
935    as well as submit a POST of the result data to the speedtest.net API
936    to get a share results image link.
937    """
938
939    def __init__(self, download=0, upload=0, ping=0, server=None, client=None,
940                 opener=None, secure=False):
941        self.download = download
942        self.upload = upload
943        self.ping = ping
944        if server is None:
945            self.server = {}
946        else:
947            self.server = server
948        self.client = client or {}
949
950        self._share = None
951        self.timestamp = '%sZ' % datetime.datetime.utcnow().isoformat()
952        self.bytes_received = 0
953        self.bytes_sent = 0
954
955        if opener:
956            self._opener = opener
957        else:
958            self._opener = build_opener()
959
960        self._secure = secure
961
962    def __repr__(self):
963        return repr(self.dict())
964
    def share(self):
        """POST data to the speedtest.net API to obtain a share results
        link

        Once a link has been obtained it is cached in ``self._share`` and
        returned without contacting the API again. ``download`` and
        ``upload`` are divided by 1000 and, like ``ping``, rounded to whole
        numbers before being submitted.

        Returns the ``http://www.speedtest.net/result/<resultid>.png`` URL.

        Raises ShareResultsConnectFailure when the request cannot be made,
        and ShareResultsSubmitFailure when the API answers with a non-200
        status or without exactly one ``resultid``.
        """

        if self._share:
            return self._share

        download = int(round(self.download / 1000.0, 0))
        ping = int(round(self.ping, 0))
        upload = int(round(self.upload / 1000.0, 0))

        # Build the request to send results back to speedtest.net
        # We use a list instead of a dict because the API expects parameters
        # in a certain order
        # NOTE(review): self.server['id'] raises KeyError when no server was
        # set (server defaults to {}); confirm callers guarantee it.
        api_data = [
            'recommendedserverid=%s' % self.server['id'],
            'ping=%s' % ping,
            'screenresolution=',
            'promo=',
            'download=%s' % download,
            'screendpi=',
            'upload=%s' % upload,
            'testmethod=http',
            # MD5 over "ping-upload-download-<fixed salt>", sent alongside
            # the values it covers
            'hash=%s' % md5(('%s-%s-%s-%s' %
                             (ping, upload, download, '297aae72'))
                            .encode()).hexdigest(),
            'touchscreen=none',
            'startmode=pingselect',
            'accuracy=1',
            'bytesreceived=%s' % self.bytes_received,
            'bytessent=%s' % self.bytes_sent,
            'serverid=%s' % self.server['id'],
        ]

        headers = {'Referer': 'http://c.speedtest.net/flash/speedtest.swf'}
        request = build_request('://www.speedtest.net/api/api.php',
                                data='&'.join(api_data).encode(),
                                headers=headers, secure=self._secure)
        f, e = catch_request(request, opener=self._opener)
        if e:
            raise ShareResultsConnectFailure(e)

        # NOTE(review): if f.read() raises, f is never closed.
        response = f.read()
        code = f.code
        f.close()

        if int(code) != 200:
            raise ShareResultsSubmitFailure('Could not submit results to '
                                            'speedtest.net')

        # The API replies with a query string; exactly one resultid is
        # expected
        qsargs = parse_qs(response.decode())
        resultid = qsargs.get('resultid')
        if not resultid or len(resultid) != 1:
            raise ShareResultsSubmitFailure('Could not submit results to '
                                            'speedtest.net')

        self._share = 'http://www.speedtest.net/result/%s.png' % resultid[0]

        return self._share
1025
1026    def dict(self):
1027        """Return dictionary of result data"""
1028
1029        return {
1030            'download': self.download,
1031            'upload': self.upload,
1032            'ping': self.ping,
1033            'server': self.server,
1034            'timestamp': self.timestamp,
1035            'bytes_sent': self.bytes_sent,
1036            'bytes_received': self.bytes_received,
1037            'share': self._share,
1038            'client': self.client,
1039        }
1040
1041    @staticmethod
1042    def csv_header(delimiter=','):
1043        """Return CSV Headers"""
1044
1045        row = ['Server ID', 'Sponsor', 'Server Name', 'Timestamp', 'Distance',
1046               'Ping', 'Download', 'Upload', 'Share', 'IP Address']
1047        out = StringIO()
1048        writer = csv.writer(out, delimiter=delimiter, lineterminator='')
1049        writer.writerow([to_utf8(v) for v in row])
1050        return out.getvalue()
1051
1052    def csv(self, delimiter=','):
1053        """Return data in CSV format"""
1054
1055        data = self.dict()
1056        out = StringIO()
1057        writer = csv.writer(out, delimiter=delimiter, lineterminator='')
1058        row = [data['server']['id'], data['server']['sponsor'],
1059               data['server']['name'], data['timestamp'],
1060               data['server']['d'], data['ping'], data['download'],
1061               data['upload'], self._share or '', self.client['ip']]
1062        writer.writerow([to_utf8(v) for v in row])
1063        return out.getvalue()
1064
1065    def json(self, pretty=False):
1066        """Return data in JSON format"""
1067
1068        kwargs = {}
1069        if pretty:
1070            kwargs.update({
1071                'indent': 4,
1072                'sort_keys': True
1073            })
1074        return json.dumps(self.dict(), **kwargs)
1075
1076
class Speedtest(object):
    """Class for performing standard speedtest.net testing operations"""

    def __init__(self, config=None, source_address=None, timeout=10,
                 secure=False, shutdown_event=None):
        """Prepare an opener, download the speedtest.net configuration and
        create an empty ``SpeedtestResults`` for this run.

        :param config: optional dict merged over the downloaded config
        :param source_address: optional source IP to bind connections to
        :param timeout: timeout handed to ``build_opener``
        :param secure: use HTTPS for speedtest.net operated endpoints
        :param shutdown_event: event-like object (e.g. ``threading.Event``)
            handed to the transfer threads; defaults to a
            ``FakeShutdownEvent`` that is never set
        """
        self.config = {}

        self._source_address = source_address
        self._timeout = timeout
        self._opener = build_opener(source_address, timeout)

        self._secure = secure

        if shutdown_event:
            self._shutdown_event = shutdown_event
        else:
            self._shutdown_event = FakeShutdownEvent()

        # Caller supplied values override the downloaded configuration
        self.get_config()
        if config is not None:
            self.config.update(config)

        self.servers = {}
        self.closest = []
        self._best = {}

        self.results = SpeedtestResults(
            client=self.config['client'],
            opener=self._opener,
            secure=secure,
        )

    @property
    def best(self):
        """The lowest latency server, determined via get_best_server() on
        first access"""
        if not self._best:
            self.get_best_server()
        return self._best

    @staticmethod
    def _read_stream(uh, error_cls):
        """Read the complete body of response ``uh`` through
        ``get_response_stream`` and return it as bytes.

        The stream and ``uh`` are closed afterwards even when reading fails.
        ``OSError``/``EOFError`` raised while reading are re-raised as
        ``error_cls``.
        """
        stream = get_response_stream(uh)
        chunks = []
        try:
            while 1:
                try:
                    chunk = stream.read(1024)
                except (OSError, EOFError):
                    raise error_cls(get_exception())
                if not chunk:
                    break
                chunks.append(chunk)
        finally:
            stream.close()
            uh.close()
        return ''.encode().join(chunks)

    def get_config(self):
        """Download the speedtest.net configuration and return only the data
        we are interested in

        Updates and returns ``self.config`` and sets ``self.lat_lon``.
        Returns None, leaving ``self.config`` untouched, when the server
        answers with a non-200 status.

        Raises ConfigRetrievalError on connection or read failures and
        SpeedtestConfigError on malformed XML or a non-numeric client
        location.
        """

        headers = {}
        if gzip:
            headers['Accept-Encoding'] = 'gzip'
        request = build_request('://www.speedtest.net/speedtest-config.php',
                                headers=headers, secure=self._secure)
        uh, e = catch_request(request, opener=self._opener)
        if e:
            raise ConfigRetrievalError(e)

        configxml = self._read_stream(uh, ConfigRetrievalError)

        if int(uh.code) != 200:
            return None

        printer('Config XML:\n%s' % configxml, debug=True)

        # Prefer ElementTree; an AttributeError falls back to minidom
        try:
            try:
                root = ET.fromstring(configxml)
            except ET.ParseError:
                e = get_exception()
                raise SpeedtestConfigError(
                    'Malformed speedtest.net configuration: %s' % e
                )
            server_config = root.find('server-config').attrib
            download = root.find('download').attrib
            upload = root.find('upload').attrib
            client = root.find('client').attrib

        except AttributeError:
            try:
                root = DOM.parseString(configxml)
            except ExpatError:
                e = get_exception()
                raise SpeedtestConfigError(
                    'Malformed speedtest.net configuration: %s' % e
                )
            server_config = get_attributes_by_tag_name(root, 'server-config')
            download = get_attributes_by_tag_name(root, 'download')
            upload = get_attributes_by_tag_name(root, 'upload')
            client = get_attributes_by_tag_name(root, 'client')

        ignore_servers = [
            int(i) for i in server_config['ignoreids'].split(',') if i
        ]

        # ``ratio`` selects the smallest upload size used (1-based index)
        ratio = int(upload['ratio'])
        upload_max = int(upload['maxchunkcount'])
        up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032]
        sizes = {
            'upload': up_sizes[ratio - 1:],
            'download': [350, 500, 750, 1000, 1500, 2000, 2500,
                         3000, 3500, 4000]
        }

        size_count = len(sizes['upload'])

        upload_count = int(math.ceil(upload_max / size_count))

        counts = {
            'upload': upload_count,
            'download': int(download['threadsperurl'])
        }

        threads = {
            'upload': int(upload['threads']),
            'download': int(server_config['threadcount']) * 2
        }

        length = {
            'upload': int(upload['testlength']),
            'download': int(download['testlength'])
        }

        self.config.update({
            'client': client,
            'ignore_servers': ignore_servers,
            'sizes': sizes,
            'counts': counts,
            'threads': threads,
            'length': length,
            'upload_max': upload_count * size_count
        })

        try:
            self.lat_lon = (float(client['lat']), float(client['lon']))
        except ValueError:
            raise SpeedtestConfigError(
                'Unknown location: lat=%r lon=%r' %
                (client.get('lat'), client.get('lon'))
            )

        printer('Config:\n%r' % self.config, debug=True)

        return self.config

    @staticmethod
    def _server_ids(ids):
        """Return a new list holding every entry of ``ids`` as an int.

        ``None`` yields an empty list. Raises InvalidServerIDType for any
        entry ``int()`` rejects with ValueError.
        """
        converted = []
        for s in ids or []:
            try:
                converted.append(int(s))
            except ValueError:
                raise InvalidServerIDType(
                    '%s is an invalid server type, must be int' % s
                )
        return converted

    def get_servers(self, servers=None, exclude=None):
        """Retrieve the list of speedtest.net servers, optionally filtered
        to servers matching those specified in the ``servers`` argument

        Servers listed in ``exclude`` or in the config's ``ignore_servers``
        are always dropped. The server list URLs are tried in turn until one
        is fetched and parsed. Results are stored in ``self.servers`` as
        ``{distance: [server attribute dict, ...]}`` and returned. The
        caller's ``servers``/``exclude`` sequences are not modified.

        Raises InvalidServerIDType for IDs that are not integers and
        NoMatchedServers when a filter was given but nothing matched.
        """
        self.servers.clear()

        servers = self._server_ids(servers)
        exclude = self._server_ids(exclude)
        wanted = set(servers)
        skipped = set(self.config['ignore_servers']) | set(exclude)

        urls = [
            '://www.speedtest.net/speedtest-servers-static.php',
            'http://c.speedtest.net/speedtest-servers-static.php',
            '://www.speedtest.net/speedtest-servers.php',
            'http://c.speedtest.net/speedtest-servers.php',
        ]

        headers = {}
        if gzip:
            headers['Accept-Encoding'] = 'gzip'

        for url in urls:
            try:
                request = build_request(
                    '%s?threads=%s' % (url,
                                       self.config['threads']['download']),
                    headers=headers,
                    secure=self._secure
                )
                uh, e = catch_request(request, opener=self._opener)
                if e:
                    printer('ERROR: %r' % e, debug=True)
                    raise ServersRetrievalError()

                serversxml = self._read_stream(uh, ServersRetrievalError)

                if int(uh.code) != 200:
                    raise ServersRetrievalError()

                printer('Servers XML:\n%s' % serversxml, debug=True)

                try:
                    try:
                        try:
                            root = ET.fromstring(serversxml)
                        except ET.ParseError:
                            e = get_exception()
                            raise SpeedtestServersError(
                                'Malformed speedtest.net server list: %s' % e
                            )
                        elements = etree_iter(root, 'server')
                    except AttributeError:
                        try:
                            root = DOM.parseString(serversxml)
                        except ExpatError:
                            e = get_exception()
                            raise SpeedtestServersError(
                                'Malformed speedtest.net server list: %s' % e
                            )
                        elements = root.getElementsByTagName('server')
                except (SyntaxError, xml.parsers.expat.ExpatError):
                    raise ServersRetrievalError()

                for server in elements:
                    try:
                        attrib = server.attrib
                    except AttributeError:
                        attrib = dict(list(server.attributes.items()))

                    server_id = int(attrib.get('id'))
                    if wanted and server_id not in wanted:
                        continue
                    if server_id in skipped:
                        continue

                    try:
                        d = distance(self.lat_lon,
                                     (float(attrib.get('lat')),
                                      float(attrib.get('lon'))))
                    except Exception:
                        # Servers without a usable location are skipped
                        continue

                    attrib['d'] = d

                    try:
                        self.servers[d].append(attrib)
                    except KeyError:
                        self.servers[d] = [attrib]

                break

            except ServersRetrievalError:
                continue

        if (servers or exclude) and not self.servers:
            raise NoMatchedServers()

        return self.servers

    def set_mini_server(self, server):
        """Instead of querying for a list of servers, set a link to a
        speedtest mini server

        A trailing file name in ``server`` (any path with an extension) is
        stripped. The upload extension is read from the page
        (``upload_extension: "..."``) or, failing that, probed among
        php/asp/aspx/jsp. Replaces ``self.servers`` with a one-element list
        describing the Mini server and returns it.

        Raises SpeedtestMiniConnectFailure if the page cannot be fetched and
        InvalidSpeedtestMiniServer if no upload extension can be found.
        """

        urlparts = urlparse(server)

        if os.path.splitext(urlparts[2])[1]:
            url = os.path.dirname(server)
        else:
            url = server

        request = build_request(url)
        uh, e = catch_request(request, opener=self._opener)
        if e:
            raise SpeedtestMiniConnectFailure('Failed to connect to %s' %
                                              server)
        text = uh.read()
        uh.close()

        extension = re.findall('upload_?[Ee]xtension: "([^"]+)"',
                               text.decode())
        if not extension:
            for ext in ['php', 'asp', 'aspx', 'jsp']:
                try:
                    f = self._opener.open(
                        '%s/speedtest/upload.%s' % (url, ext)
                    )
                except Exception:
                    continue
                # Close every probe response, matching or not
                try:
                    data = f.read().strip().decode()
                    code = f.code
                finally:
                    f.close()
                if (code == 200 and
                        len(data.splitlines()) == 1 and
                        re.match('size=[0-9]', data)):
                    extension = [ext]
                    break
        if not extension:
            raise InvalidSpeedtestMiniServer('Invalid Speedtest Mini Server: '
                                             '%s' % server)

        self.servers = [{
            'sponsor': 'Speedtest Mini',
            'name': urlparts[1],
            'd': 0,
            'url': '%s/speedtest/upload.%s' % (url.rstrip('/'), extension[0]),
            'latency': 0,
            'id': 0
        }]

        return self.servers

    def get_closest_servers(self, limit=5):
        """Limit servers to the closest speedtest.net servers based on
        geographic distance

        Fetches the server list first when ``self.servers`` is empty.
        ``self.closest`` is rebuilt on every call (in place, so existing
        references to it stay valid) with at most ``limit`` servers ordered
        by distance; a ``limit`` of None or <= 0 keeps every server.
        Returns ``self.closest``.
        """

        if not self.servers:
            self.get_servers()

        ordered = []
        for d in sorted(self.servers.keys()):
            ordered.extend(self.servers[d])

        if limit is not None and limit > 0:
            ordered = ordered[:limit]

        self.closest[:] = ordered

        printer('Closest Servers:\n%r' % self.closest, debug=True)
        return self.closest

    def get_best_server(self, servers=None):
        """Perform a speedtest.net "ping" to determine which speedtest.net
        server has the lowest latency

        ``servers`` defaults to ``self.closest`` (filled through
        get_closest_servers() when empty). Each server's ``latency.txt`` is
        fetched three times; failed or invalid responses count as 3600
        seconds. The winner receives a ``latency`` key, is recorded on
        ``self.results`` (``ping``/``server``) and in ``self._best``, and is
        returned.

        Raises SpeedtestBestServerFailure when there are no servers to try.
        """

        if not servers:
            if not self.closest:
                self.get_closest_servers()
            servers = self.closest

        if self._source_address:
            source_address_tuple = (self._source_address, 0)
        else:
            source_address_tuple = None

        headers = {'User-Agent': build_user_agent()}

        results = {}
        for server in servers:
            cum = []
            url = os.path.dirname(server['url'])
            stamp = int(timeit.time.time() * 1000)
            latency_url = '%s/latency.txt?x=%s' % (url, stamp)
            for i in range(0, 3):
                this_latency_url = '%s.%s' % (latency_url, i)
                printer('%s %s' % ('GET', this_latency_url),
                        debug=True)
                # Request the per-attempt URL that was just logged so the
                # three attempts use distinct URLs (presumably to avoid
                # cached responses)
                urlparts = urlparse(this_latency_url)
                h = None
                try:
                    if urlparts[0] == 'https':
                        h = SpeedtestHTTPSConnection(
                            urlparts[1],
                            source_address=source_address_tuple
                        )
                    else:
                        h = SpeedtestHTTPConnection(
                            urlparts[1],
                            source_address=source_address_tuple
                        )
                    path = '%s?%s' % (urlparts[2], urlparts[4])
                    start = timeit.default_timer()
                    h.request("GET", path, headers=headers)
                    r = h.getresponse()
                    total = (timeit.default_timer() - start)
                except HTTP_ERRORS:
                    e = get_exception()
                    printer('ERROR: %r' % e, debug=True)
                    cum.append(3600)
                    # Don't leak the connection of a failed attempt
                    if h is not None:
                        h.close()
                    continue

                text = r.read(9)
                if int(r.status) == 200 and text == 'test=test'.encode():
                    cum.append(total)
                else:
                    cum.append(3600)
                h.close()

            # NOTE(review): three samples are summed but divided by 6; kept
            # unchanged because it defines the reported ping value -- confirm
            # whether halving is intended
            avg = round((sum(cum) / 6) * 1000.0, 3)
            results[avg] = server

        try:
            fastest = sorted(results.keys())[0]
        except IndexError:
            raise SpeedtestBestServerFailure('Unable to connect to servers to '
                                             'test latency.')
        best = results[fastest]
        best['latency'] = fastest

        self.results.ping = fastest
        self.results.server = best

        self._best.update(best)
        printer('Best Server:\n%r' % best, debug=True)
        return best

    def download(self, callback=do_nothing, threads=None):
        """Test download speed against speedtest.net

        A ``threads`` value of ``None`` will fall back to those dictated
        by the speedtest.net configuration

        Fetches every configured ``random<size>x<size>.jpg`` from the best
        server's directory ``counts['download']`` times, keeping at most
        ``max_threads`` HTTPDownloader threads in flight. ``callback`` is
        invoked as each request starts and finishes. Stores the byte count
        and speed (bytes * 8 per second) on ``self.results`` and returns
        the speed.
        """

        urls = []
        for size in self.config['sizes']['download']:
            for _ in range(0, self.config['counts']['download']):
                urls.append('%s/random%sx%s.jpg' %
                            (os.path.dirname(self.best['url']), size, size))

        request_count = len(urls)
        requests = []
        for i, url in enumerate(urls):
            requests.append(
                build_request(url, bump=i, secure=self._secure)
            )

        max_threads = threads or self.config['threads']['download']
        # Shared counter (a dict so the nested functions can mutate it)
        in_flight = {'threads': 0}

        def producer(q, requests, request_count):
            for i, request in enumerate(requests):
                thread = HTTPDownloader(
                    i,
                    request,
                    start,
                    self.config['length']['download'],
                    opener=self._opener,
                    shutdown_event=self._shutdown_event
                )
                while in_flight['threads'] >= max_threads:
                    timeit.time.sleep(0.001)
                thread.start()
                q.put(thread, True)
                in_flight['threads'] += 1
                callback(i, request_count, start=True)

        finished = []

        def consumer(q, request_count):
            _is_alive = thread_is_alive
            while len(finished) < request_count:
                thread = q.get(True)
                while _is_alive(thread):
                    thread.join(timeout=0.001)
                in_flight['threads'] -= 1
                finished.append(sum(thread.result))
                callback(thread.i, request_count, end=True)

        q = Queue(max_threads)
        prod_thread = threading.Thread(target=producer,
                                       args=(q, requests, request_count))
        cons_thread = threading.Thread(target=consumer,
                                       args=(q, request_count))
        # ``start`` is read by producer() through its closure
        start = timeit.default_timer()
        prod_thread.start()
        cons_thread.start()
        _is_alive = thread_is_alive
        while _is_alive(prod_thread):
            prod_thread.join(timeout=0.001)
        while _is_alive(cons_thread):
            cons_thread.join(timeout=0.001)

        stop = timeit.default_timer()
        self.results.bytes_received = sum(finished)
        self.results.download = (
            (self.results.bytes_received / (stop - start)) * 8.0
        )
        # Faster connections get more upload threads
        if self.results.download > 100000:
            self.config['threads']['upload'] = 8
        return self.results.download

    def upload(self, callback=do_nothing, pre_allocate=True, threads=None):
        """Test upload speed against speedtest.net

        A ``threads`` value of ``None`` will fall back to those dictated
        by the speedtest.net configuration

        Posts ``config['upload_max']`` payloads (each configured size,
        ``counts['upload']`` times) to the best server's URL, keeping at
        most ``max_threads`` HTTPUploader threads in flight. With
        ``pre_allocate`` the payload data is built up front. Stores the byte
        count and speed (bytes * 8 per second) on ``self.results`` and
        returns the speed.
        """

        sizes = []

        for size in self.config['sizes']['upload']:
            for _ in range(0, self.config['counts']['upload']):
                sizes.append(size)

        # The configured maximum rather than len(sizes) bounds the test
        request_count = self.config['upload_max']

        requests = []
        for size in sizes:
            # We set ``0`` for ``start`` and handle setting the actual
            # ``start`` in ``HTTPUploader`` to get better measurements
            data = HTTPUploaderData(
                size,
                0,
                self.config['length']['upload'],
                shutdown_event=self._shutdown_event
            )
            if pre_allocate:
                data.pre_allocate()

            headers = {'Content-length': size}
            requests.append(
                (
                    build_request(self.best['url'], data, secure=self._secure,
                                  headers=headers),
                    size
                )
            )

        max_threads = threads or self.config['threads']['upload']
        # Shared counter (a dict so the nested functions can mutate it)
        in_flight = {'threads': 0}

        def producer(q, requests, request_count):
            for i, request in enumerate(requests[:request_count]):
                thread = HTTPUploader(
                    i,
                    request[0],
                    start,
                    request[1],
                    self.config['length']['upload'],
                    opener=self._opener,
                    shutdown_event=self._shutdown_event
                )
                while in_flight['threads'] >= max_threads:
                    timeit.time.sleep(0.001)
                thread.start()
                q.put(thread, True)
                in_flight['threads'] += 1
                callback(i, request_count, start=True)

        finished = []

        def consumer(q, request_count):
            _is_alive = thread_is_alive
            while len(finished) < request_count:
                thread = q.get(True)
                while _is_alive(thread):
                    thread.join(timeout=0.001)
                in_flight['threads'] -= 1
                finished.append(thread.result)
                callback(thread.i, request_count, end=True)

        q = Queue(max_threads)
        prod_thread = threading.Thread(target=producer,
                                       args=(q, requests, request_count))
        cons_thread = threading.Thread(target=consumer,
                                       args=(q, request_count))
        # ``start`` is read by producer() through its closure
        start = timeit.default_timer()
        prod_thread.start()
        cons_thread.start()
        _is_alive = thread_is_alive
        while _is_alive(prod_thread):
            prod_thread.join(timeout=0.1)
        while _is_alive(cons_thread):
            cons_thread.join(timeout=0.1)

        stop = timeit.default_timer()
        self.results.bytes_sent = sum(finished)
        self.results.upload = (
            (self.results.bytes_sent / (stop - start)) * 8.0
        )
        return self.results.upload
1673
1674
def ctrl_c(shutdown_event):
    """Build a SIGINT handler bound to ``shutdown_event``.

    The returned handler first sets the event so the threaded operations
    can notice the cancellation, then prints a notice to stderr and exits
    with status 0.
    """
    def handler(signum, frame):
        shutdown_event.set()
        printer('\nCancelling...', error=True)
        sys.exit(0)

    return handler
1684
1685
def version():
    """Print the speedtest-cli version and the Python version, then exit
    with status 0"""
    lines = (
        'speedtest-cli %s' % __version__,
        'Python %s' % sys.version.replace('\n', ''),
    )
    for line in lines:
        printer(line)
    sys.exit(0)
1692
1693
def csv_header(delimiter=','):
    """Print the CSV header row built by SpeedtestResults.csv_header and
    exit with status 0"""
    header = SpeedtestResults.csv_header(delimiter=delimiter)
    printer(header)
    sys.exit(0)
1699
1700
def parse_args():
    """Build the command line parser, parse ``sys.argv`` and return the
    parsed options.

    ``ArgParser`` may be argparse- or optparse-style. For an optparse
    parser ``add_argument`` is aliased to ``add_option``, and only the
    options object of the ``(options, args)`` tuple is returned.
    """
    description = (
        'Command line interface for testing internet bandwidth using '
        'speedtest.net.\n'
        '------------------------------------------------------------'
        '--------------\n'
        'https://github.com/sivel/speedtest-cli')

    parser = ArgParser(description=description)
    # Give optparse.OptionParser an `add_argument` method for
    # compatibility with argparse.ArgumentParser
    try:
        parser.add_argument = parser.add_option
    except AttributeError:
        pass
    add = parser.add_argument

    add('--no-download', dest='download', default=True,
        action='store_const', const=False,
        help='Do not perform download test')
    add('--no-upload', dest='upload', default=True,
        action='store_const', const=False,
        help='Do not perform upload test')
    add('--single', default=False, action='store_true',
        help='Only use a single connection instead of '
             'multiple. This simulates a typical file '
             'transfer.')
    add('--bytes', dest='units', action='store_const',
        const=('byte', 8), default=('bit', 1),
        help='Display values in bytes instead of bits. Does '
             'not affect the image generated by --share, nor '
             'output from --json or --csv')
    add('--share', action='store_true',
        help='Generate and provide a URL to the speedtest.net '
             'share results image, not displayed with --csv')
    add('--simple', action='store_true', default=False,
        help='Suppress verbose output, only show basic '
             'information')
    add('--csv', action='store_true', default=False,
        help='Suppress verbose output, only show basic '
             'information in CSV format. Speeds listed in '
             'bit/s and not affected by --bytes')
    add('--csv-delimiter', default=',', type=PARSER_TYPE_STR,
        help='Single character delimiter to use in CSV '
             'output. Default ","')
    add('--csv-header', action='store_true', default=False,
        help='Print CSV headers')
    add('--json', action='store_true', default=False,
        help='Suppress verbose output, only show basic '
             'information in JSON format. Speeds listed in '
             'bit/s and not affected by --bytes')
    add('--list', action='store_true',
        help='Display a list of speedtest.net servers '
             'sorted by distance')
    add('--server', type=PARSER_TYPE_INT, action='append',
        help='Specify a server ID to test against. Can be '
             'supplied multiple times')
    add('--exclude', type=PARSER_TYPE_INT, action='append',
        help='Exclude a server from selection. Can be '
             'supplied multiple times')
    add('--mini', help='URL of the Speedtest Mini server')
    add('--source', help='Source IP address to bind to')
    add('--timeout', default=10, type=PARSER_TYPE_FLOAT,
        help='HTTP timeout in seconds. Default 10')
    add('--secure', action='store_true',
        help='Use HTTPS instead of HTTP when communicating '
             'with speedtest.net operated servers')
    add('--no-pre-allocate', dest='pre_allocate',
        action='store_const', default=True, const=False,
        help='Do not pre allocate upload data. Pre allocation '
             'is enabled by default to improve upload '
             'performance. To support systems with '
             'insufficient memory, use this option to avoid a '
             'MemoryError')
    add('--version', action='store_true',
        help='Show the version number and exit')
    add('--debug', action='store_true',
        help=ARG_SUPPRESS, default=ARG_SUPPRESS)

    parsed = parser.parse_args()
    # optparse returns (options, positional_args); argparse a Namespace
    if isinstance(parsed, tuple):
        return parsed[0]
    return parsed
1785
1786
def validate_optional_args(args):
    """Abort early when a requested option depends on a missing module.

    ``requirements`` maps an option attribute on ``args`` to a human readable
    description and the object that backs it.  ``json`` is set to ``None``
    at import time when neither json nor simplejson is available; the
    HTTPSConnection name is presumably handled the same way elsewhere in the
    file (TODO confirm).  If an option is set while its backing object is
    ``None``, exit via ``SystemExit`` naming what must be installed.
    """
    requirements = {
        'json': ('json/simplejson python module', json),
        'secure': ('SSL support', HTTPSConnection),
    }

    for option, (description, provider) in requirements.items():
        if provider is not None:
            # Backing module is available; nothing to check for this option.
            continue
        if getattr(args, option, False):
            raise SystemExit('%s is not installed. --%s is '
                             'unavailable' % (description, option))
1803
1804
def printer(string, quiet=False, debug=False, error=False, **kwargs):
    """Print ``string`` via print_, honouring the quiet/debug/error flags.

    :param string: text to emit.
    :param quiet: when true, nothing is printed.
    :param debug: treat the text as a debug message.  It is dropped unless
        the module-level DEBUG flag is set, and otherwise gets a ``DEBUG:``
        prefix, wrapped in ANSI colour codes when stdout is a TTY.
    :param error: write to ``sys.stderr`` instead of the default stream.
    :param kwargs: extra keyword arguments forwarded to print_ (e.g. ``end``).
    """
    if debug:
        if not DEBUG:
            return
        if sys.stdout.isatty():
            template = '\033[1;30mDEBUG: %s\033[0m'
        else:
            template = 'DEBUG: %s'
        string = template % string

    if quiet:
        return

    if error:
        kwargs['file'] = sys.stderr
    print_(string, **kwargs)
1824
1825
def _to_megaunits(bits_per_second, units):
    """Scale a bit/s figure for display as ``M<unit>/s``.

    ``units`` is ``args.units``: index 0 is the label printed after "M" and
    index 1 is the divisor applied after converting to mega-units.
    """
    return (bits_per_second / 1000.0 / 1000.0) / units[1]


def _print_server_list(speedtest):
    """Print every retrieved server, one per line, for ``--list``.

    If stdout is a pipe whose reader has gone away, every write fails with
    EPIPE, so stop printing at the first such error instead of retrying for
    each remaining server.  Any other I/O error propagates.
    """
    try:
        for _, servers in sorted(speedtest.servers.items()):
            for server in servers:
                printer('%(id)5s) %(sponsor)s (%(name)s, %(country)s) '
                        '[%(d)0.2f km]' % server)
    except IOError:
        e = get_exception()
        if e.errno != errno.EPIPE:
            raise


def shell():
    """Run the full speedtest.net test from the command line.

    Parses CLI options, retrieves the configuration and server list (or
    uses a Speedtest Mini server), runs whichever of the download/upload
    tests were not disabled, and prints the results as human readable text,
    ``--simple``, ``--csv`` or ``--json``.

    Raises:
        SpeedtestCLIError: on invalid option combinations, or when the
            configuration or server list cannot be retrieved.
    """

    global DEBUG
    shutdown_event = threading.Event()

    # Install the SIGINT handler built by ctrl_c() around shutdown_event.
    signal.signal(signal.SIGINT, ctrl_c(shutdown_event))

    args = parse_args()

    # Print the version and exit
    if args.version:
        version()

    if not args.download and not args.upload:
        raise SpeedtestCLIError('Cannot supply both --no-download and '
                                '--no-upload')

    if len(args.csv_delimiter) != 1:
        raise SpeedtestCLIError('--csv-delimiter must be a single character')

    if args.csv_header:
        csv_header(args.csv_delimiter)

    validate_optional_args(args)

    # --debug uses ARG_SUPPRESS as its default.  Under argparse an unused
    # flag with a SUPPRESS default leaves no attribute at all, hence getattr.
    # 'SUPPRESSHELP' equals optparse.SUPPRESS_HELP, which an optparse
    # fallback would presumably store; it is truthy, so normalise it.
    debug = getattr(args, 'debug', False)
    if debug == 'SUPPRESSHELP':
        debug = False
    if debug:
        DEBUG = True

    # Progress chatter is suppressed for every compact output mode, but only
    # CSV and JSON count as machine-readable formats.
    quiet = bool(args.simple or args.csv or args.json)
    machine_format = bool(args.csv or args.json)

    # Don't set a callback if we are running quietly
    if quiet or debug:
        callback = do_nothing
    else:
        callback = print_dots(shutdown_event)

    printer('Retrieving speedtest.net configuration...', quiet)
    try:
        speedtest = Speedtest(
            source_address=args.source,
            timeout=args.timeout,
            secure=args.secure
        )
    except (ConfigRetrievalError,) + HTTP_ERRORS:
        printer('Cannot retrieve speedtest configuration', error=True)
        raise SpeedtestCLIError(get_exception())

    if args.list:
        try:
            speedtest.get_servers()
        except (ServersRetrievalError,) + HTTP_ERRORS:
            printer('Cannot retrieve speedtest server list', error=True)
            raise SpeedtestCLIError(get_exception())

        _print_server_list(speedtest)
        sys.exit(0)

    printer('Testing from %(isp)s (%(ip)s)...' % speedtest.config['client'],
            quiet)

    if not args.mini:
        printer('Retrieving speedtest.net server list...', quiet)
        try:
            speedtest.get_servers(servers=args.server, exclude=args.exclude)
        except NoMatchedServers:
            # args.server is None when --server was never given (the
            # 'append' action's default).  Guard it so that, should
            # get_servers() report no matches in that case (e.g. everything
            # excluded), this handler does not itself fail with TypeError.
            if args.server:
                raise SpeedtestCLIError(
                    'No matched servers: %s' %
                    ', '.join('%s' % s for s in args.server)
                )
            raise SpeedtestCLIError('No matched servers')
        except (ServersRetrievalError,) + HTTP_ERRORS:
            printer('Cannot retrieve speedtest server list', error=True)
            raise SpeedtestCLIError(get_exception())
        except InvalidServerIDType:
            # get_servers() receives both --server and --exclude IDs, and
            # either option may be unset (None), so list whatever was given.
            supplied = (args.server or []) + (args.exclude or [])
            raise SpeedtestCLIError(
                '%s is an invalid server type, must '
                'be an int' % ', '.join('%s' % s for s in supplied)
            )

        if args.server and len(args.server) == 1:
            printer('Retrieving information for the selected server...', quiet)
        else:
            printer('Selecting best server based on ping...', quiet)
        speedtest.get_best_server()
    else:
        speedtest.get_best_server(speedtest.set_mini_server(args.mini))

    results = speedtest.results

    printer('Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: '
            '%(latency)s ms' % results.server, quiet)

    # In debug mode each status message ends its own line; otherwise the line
    # is left open (presumably for the print_dots progress output).
    status_end = ('', '\n')[bool(debug)]
    threads = (None, 1)[args.single]

    if args.download:
        printer('Testing download speed', quiet, end=status_end)
        speedtest.download(
            callback=callback,
            threads=threads
        )
        printer('Download: %0.2f M%s/s' %
                (_to_megaunits(results.download, args.units), args.units[0]),
                quiet)
    else:
        printer('Skipping download test', quiet)

    if args.upload:
        printer('Testing upload speed', quiet, end=status_end)
        speedtest.upload(
            callback=callback,
            pre_allocate=args.pre_allocate,
            threads=threads
        )
        printer('Upload: %0.2f M%s/s' %
                (_to_megaunits(results.upload, args.units), args.units[0]),
                quiet)
    else:
        printer('Skipping upload test', quiet)

    printer('Results:\n%r' % results.dict(), debug=True)

    # Share before emitting CSV/JSON, presumably so those formats can carry
    # the share result.
    if not args.simple and args.share:
        results.share()

    if args.simple:
        printer('Ping: %s ms\nDownload: %0.2f M%s/s\nUpload: %0.2f M%s/s' %
                (results.ping,
                 _to_megaunits(results.download, args.units),
                 args.units[0],
                 _to_megaunits(results.upload, args.units),
                 args.units[0]))
    elif args.csv:
        printer(results.csv(delimiter=args.csv_delimiter))
    elif args.json:
        printer(results.json())

    if args.share and not machine_format:
        printer('Share results: %s' % results.share())
1986
1987
def main():
    """Command-line entry point: run shell() and translate how it ends.

    * Ctrl-C (KeyboardInterrupt) prints "Cancelling..." to stderr.
    * ``SystemExit`` with code 0 or ``None`` is a successful exit and ends
      quietly.
    * ``SystemExit`` with code 2 is the argparse/optparse usage-error exit;
      the parser has already reported the problem, so the exception is
      re-raised unchanged to keep the non-zero exit status.
    * Any other ``SystemExit`` or ``SpeedtestException`` becomes
      ``SystemExit('ERROR: <message>')``, i.e. exit status 1.
    """
    try:
        shell()
    except KeyboardInterrupt:
        printer('\nCancelling...', error=True)
    except (SpeedtestException, SystemExit):
        e = get_exception()
        # Exceptions without a ``code`` attribute are treated as failures.
        code = getattr(e, 'code', 1)
        if code in (0, None):
            # Successful exit, e.g. sys.exit(0) after --list.
            return
        if code == 2:
            # Usage error already printed by the argument parser; keep its
            # exit status instead of reporting success.
            raise
        msg = '%s' % e
        if not msg:
            msg = '%r' % e
        raise SystemExit('ERROR: %s' % msg)
2001
2002
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
2005