1#!/usr/bin/env python
2
3# Licensed to the Apache Software Foundation (ASF) under one
4# or more contributor license agreements.  See the NOTICE file
5# distributed with this work for additional information
6# regarding copyright ownership.  The ASF licenses this file
7# to you under the Apache License, Version 2.0 (the
8# "License"); you may not use this file except in compliance
9# with the License.  You may obtain a copy of the License at
10#
11#   http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing,
14# software distributed under the License is distributed on an
15# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16# KIND, either express or implied.  See the License for the
17# specific language governing permissions and limitations
18# under the License.
19
20"""Diagnose script for checking OS/hardware/python/pip/mxnet/network.
The output of this script can be a very good hint when diagnosing an issue or problem.
22"""
23import platform, subprocess, sys, os
24import socket, time
25try:
26    from urllib.request import urlopen
27    from urllib.parse import urlparse
28except ImportError:
29    from urlparse import urlparse
30    from urllib2 import urlopen
31import argparse
32
# Sites probed by the network test, keyed by the label printed in the report.
URLS = {
    'MXNet': 'https://github.com/apache/incubator-mxnet',
    'Gluon Tutorial(en)': 'http://gluon.mxnet.io',
    'Gluon Tutorial(cn)': 'https://zh.gluon.ai',
    'FashionMNIST': (
        'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com'
        '/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz'
    ),
    'PYPI': 'https://pypi.python.org/pypi/pip',
    'Conda': 'https://repo.continuum.io/pkgs/free/',
}

# Additional sites grouped by lower-case region code; a region's entries are
# tested on top of URLS when that code is passed through --region.
REGIONAL_URLS = {
    'cn': {
        'PYPI(douban)': 'https://pypi.douban.com/',
        'Conda(tsinghua)': 'https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/',
    },
}
48
def test_connection(name, url, timeout=10):
    """Time a DNS lookup and an HTTP(S) open for one site and print the result.

    Parameters
    ----------
    name : str
        Label used for the site in the printed report.
    url : str
        URL to resolve and open.
    timeout : int or float or None
        Seconds to wait for the connection. ``None`` or a value <= 0 means
        "no timeout" (block until the connection succeeds or fails).

    Returns
    -------
    None
        The outcome (timings or the error) is printed, never raised.
    """
    urlinfo = urlparse(url)
    # ``netloc`` still contains any ``user:pass@`` prefix and ``:port`` suffix,
    # which are not part of the DNS name; ``hostname`` strips both. Fall back
    # to ``netloc`` for URLs where no hostname could be parsed.
    host = urlinfo.hostname or urlinfo.netloc
    start = time.time()
    try:
        socket.gethostbyname(host)
    except Exception as e:
        print('Error resolving DNS for {}: {}, {}'.format(name, url, e))
        return
    dns_elapsed = time.time() - start

    # A timeout of 0 would put the socket in non-blocking mode and make every
    # connection attempt fail immediately, so treat non-positive as "disabled".
    if timeout is not None and timeout <= 0:
        timeout = None

    start = time.time()
    try:
        response = urlopen(url, timeout=timeout)
    except Exception as e:
        print("Error open {}: {}, {}, DNS finished in {} sec.".format(name, url, e, dns_elapsed))
        return
    load_elapsed = time.time() - start
    # Release the underlying socket; only the time-to-open is of interest.
    response.close()
    print("Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.".format(name, url, dns_elapsed, load_elapsed))
67
68
def check_python():
    """Print the running interpreter's version, compiler, build and architecture."""
    print('----------Python Info----------')
    # Each probe is called right before its line is printed, in report order.
    report = (
        ('Version      :', platform.python_version),
        ('Compiler     :', platform.python_compiler),
        ('Build        :', platform.python_build),
        ('Arch         :', platform.architecture),
    )
    for label, probe in report:
        print(label, probe())
75
def check_pip():
    """Print the version and install directory of the pip importable by this Python.

    Prints a notice instead when ``pip`` cannot be imported.
    """
    print('------------Pip Info-----------')
    try:
        import pip
    except ImportError:
        print('No corresponding pip install for current python.')
        return
    print('Version      :', pip.__version__)
    install_dir = os.path.dirname(pip.__file__)
    print('Directory    :', install_dir)
84
85
def get_build_features_str():
    """Return the string form of every value in ``mxnet.runtime.Features()``, one per line."""
    import mxnet.runtime
    rendered = [str(feature) for feature in mxnet.runtime.Features().values()]
    return '\n'.join(rendered)
90
91
def check_mxnet():
    """Print details of the installed MXNet package.

    Reports the version, package directory, commit hash (read from the
    ``COMMIT_HASH`` file inside the package directory, when present), the
    library path and the runtime build features. Problems are printed rather
    than raised so that the remaining diagnostics can still run.
    """
    print('----------MXNet Info-----------')
    try:
        import mxnet
        print('Version      :', mxnet.__version__)
        mx_dir = os.path.dirname(mxnet.__file__)
        print('Directory    :', mx_dir)
        commit_hash = os.path.join(mx_dir, 'COMMIT_HASH')
        if os.path.exists(commit_hash):
            with open(commit_hash, 'r') as f:
                ch = f.read().strip()
            print('Commit Hash   :', ch)
        else:
            print('Commit hash file "{}" not found. Not installed from pre-built package or built from source.'.format(commit_hash))
        print('Library      :', mxnet.libinfo.find_lib_path())
        try:
            print('Build features:')
            print(get_build_features_str())
        except Exception:
            # Build-feature reporting is best effort; any failure is reported
            # with a single line instead of aborting the report.
            print('No runtime build feature info available')
    except ImportError as e:
        # On Python 3 the exception names the module that could not be found.
        # If that is something other than mxnet itself, MXNet is present but
        # one of the modules it imports is missing; say so instead of
        # claiming MXNet is absent. ``name`` does not exist on Python 2.
        missing = getattr(e, 'name', None)
        if missing is None or missing == 'mxnet':
            print('No MXNet installed.')
        else:
            print('MXNet is installed but could not be imported: {}'.format(e))
    except Exception as e:
        import traceback
        if not isinstance(e, IOError):
            print("An error occurred trying to import mxnet.")
            print("This is very likely due to missing or incompatible library files.")
        print(traceback.format_exc())
120
121
def check_os():
    """Print the platform string, system name, node name, release and version."""
    print('----------System Info----------')
    # Each probe is called right before its line is printed, in report order.
    report = (
        ('Platform     :', platform.platform),
        ('system       :', platform.system),
        ('node         :', platform.node),
        ('release      :', platform.release),
        ('version      :', platform.version),
    )
    for label, probe in report:
        print(label, probe())
129
130
def _run_cpu_tool(command):
    """Run an external CPU-info command that writes directly to stdout."""
    # Flush our own buffered output first so the tool's output lands after
    # the lines already printed rather than before them.
    sys.stdout.flush()
    try:
        subprocess.call(command)
    except OSError as e:
        # The tool is not installed or cannot be executed; report it and let
        # the rest of the diagnostics continue.
        print('Could not run "{}": {}'.format(' '.join(command), e))


def check_hardware():
    """Print the machine type, processor name and platform-specific CPU details.

    CPU details come from ``sysctl -a`` on macOS (only the lines mentioning
    ``brand_string`` or ``features``), ``lscpu`` on Linux and
    ``wmic cpu get name`` on Windows. A tool that cannot be launched is
    reported with a message instead of raising.
    """
    print('----------Hardware Info----------')
    print('machine      :', platform.machine())
    print('processor    :', platform.processor())
    if sys.platform.startswith('darwin'):
        try:
            pipe = subprocess.Popen(('sysctl', '-a'), stdout=subprocess.PIPE)
            output = pipe.communicate()[0]
        except OSError as e:
            print('Could not run "sysctl -a": {}'.format(e))
            return
        for line in output.split(b'\n'):
            if b'brand_string' in line or b'features' in line:
                # Decode so Python 3 prints text rather than a b'...' repr.
                print(line.strip().decode('utf-8', 'replace'))
    elif sys.platform.startswith('linux'):
        _run_cpu_tool(['lscpu'])
    elif sys.platform.startswith('win32'):
        _run_cpu_tool(['wmic', 'cpu', 'get', 'name'])
145
146
def check_network(args):
    """Run the connection test against the global sites plus any requested regions.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``timeout`` (seconds; a value <= 0 disables the timeout)
        and ``region`` (comma-separated region codes, possibly empty).
    """
    print('----------Network Test----------')
    # A non-positive value means "no timeout"; pass None rather than 0, since
    # a timeout of 0 would make every connection fail immediately.
    timeout = args.timeout if args.timeout > 0 else None
    if timeout is not None:
        print('Setting timeout: {}'.format(timeout))
        socket.setdefaulttimeout(timeout)
    # Work on a copy so regional sites do not leak into the module-level
    # table and accumulate across calls.
    urls = dict(URLS)
    for region in args.region.strip().split(','):
        r = region.strip().lower()
        if not r:
            continue
        if r in REGIONAL_URLS:
            urls.update(REGIONAL_URLS[r])
        else:
            import warnings
            warnings.warn('Region {} do not need specific test, please refer to global sites.'.format(r))
    for name, url in urls.items():
        test_connection(name, url, timeout)
163
164
def check_environment():
    """Print the environment variables named CC or CXX, or prefixed MXNET_, OMP_ or KMP_."""
    print('----------Environment----------')
    prefixes = ('MXNET_', 'OMP_', 'KMP_')
    exact_names = ('CC', 'CXX')
    for key, value in os.environ.items():
        if not (key.startswith(prefixes) or key in exact_names):
            continue
        print('{}="{}"'.format(key, value))
170
171
def parse_args(argv=None):
    """Parse the command-line options of the diagnose script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When omitted, ``sys.argv[1:]`` is used, which is
        the behaviour callers passing no argument have always had.

    Returns
    -------
    argparse.Namespace
        One integer flag per check (``python``, ``pip``, ``mxnet``, ``os``,
        ``hardware``, ``network``, ``environment``; default 1), plus
        ``region`` (str, default ``''``) and ``timeout`` (int, default 10).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Diagnose script for checking the current system.')
    choices = ['python', 'pip', 'mxnet', 'os', 'hardware', 'network', 'environment']
    for choice in choices:
        parser.add_argument('--' + choice, default=1, type=int,
                            help='Diagnose {}.'.format(choice))
    # Adjacent literals are joined, so the help text no longer carries the
    # run of indentation spaces a backslash continuation would embed.
    parser.add_argument('--region', default='', type=str,
                        help="Additional sites in which region(s) to test. "
                             "Specify 'cn' for example to test mirror sites in China.")
    parser.add_argument('--timeout', default=10, type=int,
                        help="Connection test timeout threshold, 0 to disable.")
    return parser.parse_args(argv)
188
189
if __name__ == '__main__':
    args = parse_args()
    # (flag name, check to run) pairs, in the order the sections are reported.
    # Only the network check needs the parsed arguments.
    checks = (
        ('python', check_python),
        ('pip', check_pip),
        ('mxnet', check_mxnet),
        ('os', check_os),
        ('hardware', check_hardware),
        ('network', lambda: check_network(args)),
        ('environment', check_environment),
    )
    for flag, run_check in checks:
        if getattr(args, flag):
            run_check()
212