1import os
2import sys
3from tempfile import mkstemp
4
5import db_utils as dbutils
6import fixture_utils as fixutils
7import pexpect
8
9from steps.wrappers import run_cli, wait_prompt
10
# Log file that the scenario hooks below truncate before, and scan after, each
# scenario. expanduser('~') is used instead of os.environ['HOME'] so importing
# this module does not raise KeyError on platforms where HOME is not set.
test_log_file = os.path.join(os.path.expanduser('~'), '.mycli.test.log')
12
13
def before_all(context):
    """Prepare the environment, configuration and database for the test run.

    Steps performed, in order:

    * export the environment variables the spawned CLI relies on
      (terminal size, editor, locale, history file, login file, coverage);
    * build ``context.conf`` from behave userdata, falling back to the
      ``PYTEST_*`` environment variables and finally to hard-coded defaults;
    * write a temporary MySQL defaults file whose ``pager`` option runs
      ``test/features/wrappager.py`` with the configured boundary marker;
    * create the test database through ``dbutils.create_db`` and keep the
      returned connection on ``context.cn``;
    * load the fixture data into ``context.fixture_data``.
    """
    os.environ['LINES'] = "100"
    os.environ['COLUMNS'] = "100"
    os.environ['EDITOR'] = 'ex'
    os.environ['LC_ALL'] = 'en_US.UTF-8'
    os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
    os.environ['MYCLI_HISTFILE'] = os.devnull

    # This file lives two directory levels below the test directory.
    test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    login_path_file = os.path.join(test_dir, 'mylogin.cnf')
    os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file

    # ...and three levels below the package root.
    context.package_root = os.path.abspath(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

    os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
                                                        '.coveragerc')

    context.exit_sent = False

    # Interpreter version, e.g. "3_11_4", used to build a per-version name.
    vi = '_'.join([str(x) for x in sys.version_info[:3]])
    db_name = context.config.userdata.get(
        'my_test_db', None) or "mycli_behave_tests"
    db_name_full = '{0}_{1}'.format(db_name, vi)

    userdata = context.config.userdata

    # Connection parameters: behave userdata first, then environment
    # variables, then defaults.
    context.conf = {
        'host': userdata.get(
            'my_test_host',
            os.getenv('PYTEST_HOST', 'localhost')
        ),
        # Coerce to int whichever source wins: the environment fallback is a
        # string, and a userdata override may be one as well (presumably it
        # is when given on the behave command line -- verify), whereas the
        # original only converted the fallback.
        'port': int(userdata.get(
            'my_test_port',
            os.getenv('PYTEST_PORT', '3306')
        )),
        'user': userdata.get(
            'my_test_user',
            os.getenv('PYTEST_USER', 'root')
        ),
        'pass': userdata.get(
            'my_test_pass',
            os.getenv('PYTEST_PASSWORD', None)
        ),
        # Default command starts mycli in-process with coverage enabled.
        'cli_command': userdata.get('my_cli_command', None) or (
            sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"'
        ),
        'dbname': db_name,
        'dbname_tmp': db_name_full + '_tmp',
        'vi': vi,
        'pager_boundary': '---boundary---',
    }

    # mkstemp() hands back an already-open OS-level descriptor. Wrap it with
    # os.fdopen so it is closed when the block exits instead of leaking for
    # the rest of the run (the original discarded it and reopened by path).
    fd, my_cnf = mkstemp()
    with os.fdopen(fd, 'w') as f:
        f.write(
            '[client]\n'
            'pager={0} {1} {2}\n'.format(
                sys.executable, os.path.join(context.package_root,
                                             'test/features/wrappager.py'),
                context.conf['pager_boundary'])
        )
    context.conf['defaults-file'] = my_cnf
    context.conf['myclirc'] = os.path.join(context.package_root, 'test',
                                           'myclirc')

    context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
                                   context.conf['user'],
                                   context.conf['pass'],
                                   context.conf['dbname'])

    context.fixture_data = fixutils.read_fixture_files()
86
87
def after_all(context):
    """Release the resources that ``before_all`` created.

    Closes the shared database connection, drops the test database and
    deletes the temporary MySQL defaults file. The environment variables
    exported by ``before_all`` are not restored.
    """
    dbutils.close_cn(context.cn)
    dbutils.drop_db(context.conf['host'], context.conf['port'],
                    context.conf['user'], context.conf['pass'],
                    context.conf['dbname'])

    # The defaults file is created with mkstemp(), which never removes the
    # file on its own, so delete it here to avoid piling up temp files.
    defaults_file = context.conf.get('defaults-file')
    if defaults_file:
        try:
            os.remove(defaults_file)
        except OSError:
            # Best-effort cleanup: a file that is already gone or cannot be
            # removed must not turn a finished test run into a failure.
            pass
101
102
def before_step(context, _):
    """Clear the ``atprompt`` flag on the context before every step.

    The second argument (the step about to run) is ignored.
    """
    context.atprompt = False
105
106
def before_scenario(context, _):
    """Empty the test log, then launch the CLI and wait for its prompt.

    The second argument (the scenario about to run) is ignored.
    """
    # Opening in write mode truncates the file; nothing has to be written.
    with open(test_log_file, 'w'):
        pass
    run_cli(context)
    wait_prompt(context)
112
113
def after_scenario(context, _):
    """Check the log for errors and shut the CLI down after each scenario.

    The log file is scanned first; any line containing "error"
    (case-insensitive) raises ``RuntimeError``. The CLI shutdown runs in a
    ``finally`` clause so the spawned process is still terminated when the
    scan raises; previously the raise skipped the shutdown and left the
    child process running.

    Raises:
        RuntimeError: if a log line contains "error".
    """
    try:
        with open(test_log_file) as f:
            for line in f:
                if 'error' in line.lower():
                    raise RuntimeError(f'Error in log file: {line}')
    finally:
        if hasattr(context, 'cli') and not context.exit_sent:
            # Quit nicely.
            if not context.atprompt:
                # Wait until the prompt is shown before sending control
                # characters to the CLI.
                user = context.conf['user']
                host = context.conf['host']
                dbname = context.currentdb
                context.cli.expect_exact(
                    '{0}@{1}:{2}>'.format(
                        user, host, dbname
                    ),
                    timeout=5
                )
            # Ctrl-C, then Ctrl-D, then wait for the child to reach EOF.
            context.cli.sendcontrol('c')
            context.cli.sendcontrol('d')
            context.cli.expect_exact(pexpect.EOF, timeout=5)
136
137# TODO: uncomment to debug a failure
138# def after_step(context, step):
139#     if step.status == "failed":
140#         import ipdb; ipdb.set_trace()
141