1import os
2
3import click
4from click.testing import CliRunner
5
6from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT
7from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
8from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
9
10from textwrap import dedent
11from collections import namedtuple
12
13from tempfile import NamedTemporaryFile
14from textwrap import dedent
15
16
# Locations of the fixture files that live next to this test module.
test_dir = os.path.abspath(os.path.dirname(__file__))
project_dir = os.path.dirname(test_dir)
default_config_file = os.path.join(project_dir, 'test', 'myclirc')
login_path_file = os.path.join(test_dir, 'mylogin.cnf')

# Export the fixture login file path for everything run from this module.
os.environ.update({'MYSQL_TEST_LOGIN_FILE': login_path_file})

# Shared command line: connection credentials from the test utils, the
# fixture config for both --myclirc and --defaults-file, and the test
# database name as the trailing positional argument.
CLI_ARGS = [
    '--user', USER,
    '--host', HOST,
    '--port', PORT,
    '--password', PASSWORD,
    '--myclirc', default_config_file,
    '--defaults-file', default_config_file,
    '_test_db',
]
27
28
@dbtest
def test_execute_arg(executor):
    """Both ``-e`` and ``--execute`` run the given SQL and print the row."""
    run(executor, 'create table test (a text)')
    run(executor, 'insert into test values("abc")')

    query = 'select * from test;'
    runner = CliRunner()

    for flag in ('-e', '--execute'):
        result = runner.invoke(cli, args=CLI_ARGS + [flag, query])
        assert result.exit_code == 0
        assert 'abc' in result.output

    # The last (--execute) run prints the column name followed by the value.
    assert 'a\nabc\n' in result.output
49
50
@dbtest
def test_execute_arg_with_table(executor):
    """``-e`` combined with ``--table`` renders the result as an ASCII table."""
    run(executor, 'create table test (a text)')
    run(executor, 'insert into test values("abc")')

    invocation = CLI_ARGS + ['-e', 'select * from test;', '--table']
    result = CliRunner().invoke(cli, args=invocation)

    expected_table = (
        '+-----+\n'
        '| a   |\n'
        '+-----+\n'
        '| abc |\n'
        '+-----+\n'
    )
    assert result.exit_code == 0
    assert expected_table in result.output
63
64
@dbtest
def test_execute_arg_with_csv(executor):
    """``-e`` combined with ``--csv`` prints a quoted CSV header and row."""
    run(executor, 'create table test (a text)')
    run(executor, 'insert into test values("abc")')

    sql = 'select * from test;'
    runner = CliRunner()
    result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql, '--csv'])
    expected = '"a"\n"abc"\n'

    assert result.exit_code == 0
    # result.output is already a str; joining its characters was a no-op.
    assert expected in result.output
77
78
@dbtest
def test_batch_mode(executor):
    """Statements piped on stdin are executed in order and their results printed."""
    run(executor, '''create table test(a text)''')
    run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')

    sql = (
        'select count(*) from test;\n'
        'select * from test limit 1;'
    )

    runner = CliRunner()
    result = runner.invoke(cli, args=CLI_ARGS, input=sql)

    assert result.exit_code == 0
    # result.output is already a str; no need to join it.
    assert 'count(*)\n3\na\nabc\n' in result.output
94
95
@dbtest
def test_batch_mode_table(executor):
    """Statements piped on stdin with ``-t`` are each rendered as a table."""
    run(executor, '''create table test(a text)''')
    run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')

    stdin_sql = (
        'select count(*) from test;\n'
        'select * from test limit 1;'
    )
    result = CliRunner().invoke(cli, args=CLI_ARGS + ['-t'], input=stdin_sql)

    # One table per statement, printed back to back.
    expected_tables = dedent("""\
        +----------+
        | count(*) |
        +----------+
        | 3        |
        +----------+
        +-----+
        | a   |
        +-----+
        | abc |
        +-----+""")

    assert result.exit_code == 0
    assert expected_tables in result.output
123
124
@dbtest
def test_batch_mode_csv(executor):
    """``--csv`` batch output quotes every field, including embedded newlines."""
    run(executor, '''create table test(a text, b text)''')
    # The Python "\n" below puts a real newline inside the SQL string literal.
    run(executor,
        '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')

    sql = 'select * from test;'

    runner = CliRunner()
    result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)

    expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'

    assert result.exit_code == 0
    # result.output is already a str; no need to join it.
    assert expected in result.output
140
141
def test_thanks_picker_utf8():
    """thanks_picker returns a non-empty str picked from the bundled credit files."""
    credit_files = tuple(
        os.path.join(PACKAGE_ROOT, filename)
        for filename in ('AUTHORS', 'SPONSORS')
    )
    picked = thanks_picker(credit_files)
    assert picked and isinstance(picked, str)
148
149
def test_help_strings_end_with_periods():
    """Make sure click options have help text that end with a period."""
    for param in cli.params:
        if isinstance(param, click.Option):
            # click always sets ``help`` on options (None when omitted), so a
            # hasattr() check can never fail; test the value instead, so that a
            # missing help string is reported as an assertion failure rather
            # than an AttributeError from None.endswith.
            assert param.help, 'option {!r} has no help text'.format(param.name)
            assert param.help.endswith('.'), (
                'help for option {!r} does not end with a period: {!r}'
                .format(param.name, param.help)
            )
156
157
def test_command_descriptions_end_with_periods():
    """Make sure that mycli commands' descriptions end with a period."""
    # NOTE(review): MyCli() is presumably constructed so that it registers its
    # special commands in SPECIAL_COMMANDS before they are inspected; confirm.
    MyCli()
    for name, command in SPECIAL_COMMANDS.items():
        # Index 3 is the field this test treats as the command description.
        description = command[3]
        assert description.endswith('.'), (
            'description of command {!r} does not end with a period: {!r}'
            .format(name, description)
        )
163
164
def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
    """Run ``MyCli.output`` against a fake terminal and check where it printed.

    Args:
        monkeypatch: pytest monkeypatch fixture, used to replace
            ``click.echo_via_pager`` and ``click.secho`` for this call.
        terminal_size: ``(columns, rows)`` reported by the fake terminal.
        testdata: list of output lines handed to ``MyCli.output``.
        explicit_pager: value assigned to ``MyCli.explicit_pager``.
        expect_pager: True if the output must go through
            ``click.echo_via_pager``, False if it must go through
            ``click.secho``; calling the other one fails the assertion.

    The captured text, minus one trailing newline, must equal the input lines
    joined with newlines.
    """
    m = MyCli(myclirc=default_config_file)
    # Chunks written by the patched click functions, kept local to this call
    # instead of in a module-level global.
    captured = []

    Size = namedtuple('Size', 'rows columns')
    columns, rows = terminal_size
    fake_size = Size(rows=rows, columns=columns)

    class TestOutput():
        def get_size(self):
            return fake_size

    class TestExecute():
        host = 'test'
        user = 'test'
        dbname = 'test'
        port = 0

        def server_type(self):
            return ['test']

    class PromptBuffer():
        output = TestOutput()

    m.prompt_app = PromptBuffer()
    m.sqlexecute = TestExecute()
    m.explicit_pager = explicit_pager

    def echo_via_pager(s):
        assert expect_pager
        # s may be any iterable of strings (e.g. a generator); join consumes it.
        captured.append("".join(s))

    def secho(s):
        assert not expect_pager
        captured.append(s + "\n")

    monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
    monkeypatch.setattr(click, 'secho', secho)
    m.output(testdata)

    text = "".join(captured)
    if text.endswith("\n"):
        text = text[:-1]
    assert text == "\n".join(testdata)
208
209
def test_conditional_pager(monkeypatch):
    """Pager use depends on fit and explicit config, and \\nopager disables it."""
    testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split()

    # (terminal_size, explicit_pager, expect_pager)
    scenarios = [
        # User didn't set pager, output doesn't fit screen -> pager
        ((5, 10), False, True),
        # User didn't set pager, output fits screen -> no pager
        ((20, 20), False, False),
        # User manually configured pager, output doesn't fit screen -> pager
        ((5, 10), True, True),
        # User manually configured pager, output fits screen -> pager
        ((20, 20), True, True),
    ]
    for terminal_size, explicit_pager, expect_pager in scenarios:
        output(
            monkeypatch,
            terminal_size=terminal_size,
            testdata=testdata,
            explicit_pager=explicit_pager,
            expect_pager=expect_pager
        )

    SPECIAL_COMMANDS['nopager'].handler()
    try:
        # Pager disabled via \nopager, output doesn't fit screen -> no pager
        output(
            monkeypatch,
            terminal_size=(5, 10),
            testdata=testdata,
            explicit_pager=False,
            expect_pager=False
        )
    finally:
        # Undo the \nopager above even if the check fails, so the disabled
        # pager state does not leak into later tests.
        SPECIAL_COMMANDS['pager'].handler('')
255
256
def test_reserved_space_is_integer():
    """Make sure that reserved space is returned as an integer."""
    def stub_terminal_size():
        return (5, 5)

    old_func = click.get_terminal_size
    click.get_terminal_size = stub_terminal_size
    try:
        # Named `app` so the local does not shadow the `mycli` package name.
        app = MyCli()
        assert isinstance(app.get_reserved_space(), int)
    finally:
        # Always restore the real function, even if construction or the
        # assertion fails, so later tests are not left with the stub.
        click.get_terminal_size = old_func
269
270
def test_list_dsn():
    """--list-dsn prints alias names; adding --verbose also prints each DSN."""
    config_text = dedent("""\
        [alias_dsn]
        test = mysql://test/test
        """)
    runner = CliRunner()
    with NamedTemporaryFile(mode="w") as myclirc:
        myclirc.write(config_text)
        myclirc.flush()
        base_args = ['--list-dsn', '--myclirc', myclirc.name]

        plain = runner.invoke(cli, args=base_args)
        assert plain.output == "test\n"

        verbose = runner.invoke(cli, args=base_args + ['--verbose'])
        assert verbose.output == "test : mysql://test/test\n"
284
285
def test_list_ssh_config():
    """--list-ssh-config prints host aliases; --verbose adds each hostname."""
    ssh_config_text = dedent("""\
        Host test
            Hostname test.example.com
            User joe
            Port 22222
            IdentityFile ~/.ssh/gateway
    """)
    runner = CliRunner()
    with NamedTemporaryFile(mode="w") as ssh_config:
        ssh_config.write(ssh_config_text)
        ssh_config.flush()
        base_args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]

        listing = runner.invoke(cli, args=base_args)
        assert "test\n" in listing.output

        verbose_listing = runner.invoke(cli, args=base_args + ['--verbose'])
        assert "test : test.example.com\n" in verbose_listing.output
302
303
def test_dsn(monkeypatch):
    """DSN values reach MyCli.connect, and explicit CLI options override them."""
    # Setup classes to mock mycli.main.MyCli
    class Formatter:
        format_name = None

    class Logger:
        def debug(self, *args, **args_dict):
            pass

        def warning(self, *args, **args_dict):
            pass

    class MockMyCli:
        config = {'alias_dsn': {}}
        # Keyword arguments of the most recent connect() call (None = not called).
        connect_args = None

        def __init__(self, **args):
            self.logger = Logger()
            self.destructive_warning = False
            self.formatter = Formatter()

        def connect(self, **args):
            MockMyCli.connect_args = args

        def run_query(self, query, new_line=True):
            pass

    import mycli.main
    monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
    runner = CliRunner()

    def invoke(args):
        """Run the CLI with *args* and require a clean exit."""
        # Reset first so a previous invocation's arguments cannot satisfy
        # the next check if connect() is not called this time.
        MockMyCli.connect_args = None
        result = runner.invoke(mycli.main.cli, args=args)
        assert result.exit_code == 0, result.output + " " + str(result.exception)

    def assert_connected_with(**expected):
        """Assert the last connect() call received exactly these values."""
        assert MockMyCli.connect_args is not None, 'MyCli.connect was not called'
        # Comparing dicts reports every mismatching field at once.
        actual = {key: MockMyCli.connect_args[key] for key in expected}
        assert actual == expected

    # When a user supplies a DSN as database argument to mycli,
    # use these values.
    invoke(["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"])
    assert_connected_with(
        user="dsn_user",
        passwd="dsn_passwd",
        host="dsn_host",
        port=1,
        database="dsn_database",
    )

    # When a user supplies a DSN as database argument to mycli,
    # and used command line arguments, use the command line
    # arguments.
    invoke([
        "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
        "--user", "arg_user",
        "--password", "arg_password",
        "--host", "arg_host",
        "--port", "3",
        "--database", "arg_database",
    ])
    assert_connected_with(
        user="arg_user",
        passwd="arg_password",
        host="arg_host",
        port=3,
        database="arg_database",
    )

    MockMyCli.config = {
        'alias_dsn': {
            'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
        }
    }

    # When a user uses a DSN from the configuration file (alias_dsn),
    # use these values.
    invoke(['--dsn', 'test'])
    assert_connected_with(
        user="alias_dsn_user",
        passwd="alias_dsn_passwd",
        host="alias_dsn_host",
        port=4,
        database="alias_dsn_database",
    )

    # When a user uses a DSN from the configuration file (alias_dsn)
    # and used command line arguments, use the command line arguments.
    invoke([
        '--dsn', 'test', '',
        "--user", "arg_user",
        "--password", "arg_password",
        "--host", "arg_host",
        "--port", "5",
        "--database", "arg_database",
    ])
    assert_connected_with(
        user="arg_user",
        passwd="arg_password",
        host="arg_host",
        port=5,
        database="arg_database",
    )

    # Use a DSN without password
    invoke(["mysql://dsn_user@dsn_host:6/dsn_database"])
    assert_connected_with(
        user="dsn_user",
        passwd=None,
        host="dsn_host",
        port=6,
        database="dsn_database",
    )
416
417
def test_ssh_config(monkeypatch):
    """SSH settings come from --ssh-config-host; explicit --ssh-* options win."""
    # Setup classes to mock mycli.main.MyCli
    class Formatter:
        format_name = None

    class Logger:
        def debug(self, *args, **args_dict):
            pass

        def warning(self, *args, **args_dict):
            pass

    class MockMyCli:
        config = {'alias_dsn': {}}
        # Keyword arguments of the most recent connect() call (None = not called).
        connect_args = None

        def __init__(self, **args):
            self.logger = Logger()
            self.destructive_warning = False
            self.formatter = Formatter()

        def connect(self, **args):
            MockMyCli.connect_args = args

        def run_query(self, query, new_line=True):
            pass

    import mycli.main
    monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
    runner = CliRunner()

    def invoke(args):
        """Run the CLI with *args* and require a clean exit."""
        # Reset first so the previous invocation's arguments cannot leak in.
        MockMyCli.connect_args = None
        result = runner.invoke(mycli.main.cli, args=args)
        assert result.exit_code == 0, result.output + \
            " " + str(result.exception)

    def assert_connected_with(**expected):
        """Assert the last connect() call received exactly these values."""
        assert MockMyCli.connect_args is not None, 'MyCli.connect was not called'
        actual = {key: MockMyCli.connect_args[key] for key in expected}
        assert actual == expected

    # Setup temporary configuration
    with NamedTemporaryFile(mode="w") as ssh_config:
        ssh_config.write(dedent("""\
            Host test
                Hostname test.example.com
                User joe
                Port 22222
                IdentityFile ~/.ssh/gateway
        """))
        ssh_config.flush()

        # When a user supplies a ssh config.
        invoke([
            "--ssh-config-path",
            ssh_config.name,
            "--ssh-config-host",
            "test"
        ])
        assert_connected_with(
            ssh_user="joe",
            ssh_host="test.example.com",
            ssh_port=22222,
            # Expand "~" with expanduser instead of reading $HOME directly:
            # this still works when HOME is unset, and it does not produce
            # "//" when HOME ends with a slash.
            ssh_key_filename=os.path.expanduser("~") + "/.ssh/gateway",
        )

        # When a user supplies a ssh config host as argument to mycli,
        # and used command line arguments, use the command line
        # arguments.
        invoke([
            "--ssh-config-path",
            ssh_config.name,
            "--ssh-config-host",
            "test",
            "--ssh-user", "arg_user",
            "--ssh-host", "arg_host",
            "--ssh-port", "3",
            "--ssh-key-filename", "/path/to/key"
        ])
        assert_connected_with(
            ssh_user="arg_user",
            ssh_host="arg_host",
            ssh_port=3,
            ssh_key_filename="/path/to/key",
        )
495
496
@dbtest
def test_init_command_arg(executor):
    """--init-command takes effect for the statements read from stdin."""
    cli_args = CLI_ARGS + ["--init-command", "set sql_select_limit=1000"]
    result = CliRunner().invoke(
        cli, args=cli_args, input='show variables like "sql_select_limit";'
    )

    assert result.exit_code == 0
    assert "sql_select_limit\t1000\n" in result.output
509
510
@dbtest
def test_init_command_multiple_arg(executor):
    """--init-command may hold several ';'-separated statements; all apply."""
    init_statements = 'set sql_select_limit=2000; set max_join_size=20000'
    stdin_sql = (
        'show variables like "sql_select_limit";\n'
        'show variables like "max_join_size"'
    )
    result = CliRunner().invoke(
        cli,
        args=CLI_ARGS + ['--init-command', init_statements],
        input=stdin_sql,
    )

    assert result.exit_code == 0
    for expected_row in ('sql_select_limit\t2000\n', 'max_join_size\t20000\n'):
        assert expected_row in result.output
529