1# SPDX-License-Identifier: BSD-2-Clause
2import pytest
3from unittest import mock
4
5import sarge
6
7import zsm_lib.zfs
8
9
def get_mocked_readline(lines):
    """Return a fake ``readline`` that replays *lines* as UTF-8 encoded bytes.

    Each call returns the next line encoded as UTF-8. Once the lines are
    exhausted, every further call returns ``b""``, which is how a binary
    stream's ``readline`` signals end of stream.

    Args:
        lines: Iterable of str lines to replay, or None for an empty stream.
            The caller's object is never modified.

    Returns:
        A zero-argument callable suitable for assigning to ``.readline``.
    """
    # Snapshot and encode up front instead of appending an EOF marker to the
    # caller's list, which mutated shared fixtures on every call.
    remaining = iter([line.encode("utf-8") for line in (lines or [])])

    def mocked_readline():
        # Like a real stream at EOF, keep returning b"" rather than raising.
        return next(remaining, b"")

    return mocked_readline
28
29
def test_run_with_success():
    """zsm_lib.zfs.run should call sarge.run and return normally on exit code 0."""
    mocked_pipeline = mock.MagicMock()
    mocked_pipeline.returncode = 0

    # patch.object restores sarge.run when the block exits; assigning
    # sarge.run directly leaked the mock into every test that ran afterwards.
    with mock.patch.object(
        sarge, "run", return_value=mocked_pipeline
    ) as mocked_sarge_run:
        zsm_lib.zfs.run("asdf")

        mocked_sarge_run.assert_called()
38
39
def test_run_with_failure():
    """A non-zero exit code should raise ZfsOperationFailed carrying the stderr text."""
    expected_error = "a problem?"

    mocked_pipeline = mock.MagicMock()
    mocked_pipeline.returncode = 1
    mocked_pipeline.stderr.readline = get_mocked_readline([expected_error])

    # patch.object restores sarge.run when the block exits; assigning
    # sarge.run directly leaked the mock into every test that ran afterwards.
    with mock.patch.object(
        sarge, "run", return_value=mocked_pipeline
    ) as mocked_sarge_run:
        with pytest.raises(zsm_lib.zfs.ZfsOperationFailed) as e:
            zsm_lib.zfs.run("asdf")

        assert str(e.value) == expected_error

        mocked_sarge_run.assert_called()
54
55
def patch_run(stdout_lines=None, stderr_lines=None):
    """Patch ``zsm_lib.zfs.run`` with a fake that returns canned output.

    Every call to the patched ``run`` builds a fresh pipeline mock whose
    ``stdout.readline`` and ``stderr.readline`` replay the given lines.

    Args:
        stdout_lines: str lines to replay on stdout, or None for no output.
        stderr_lines: str lines to replay on stderr, or None for no output.

    Returns:
        An unstarted ``mock.patch`` object; use it as a context manager to get
        the mock standing in for ``zsm_lib.zfs.run``.
    """

    def mocked_run(cmd):
        # Hand get_mocked_readline a private copy: it may append an EOF marker
        # to the list it receives, which would otherwise mutate the caller's
        # list and grow it on every call to the patched run.
        mocked_stdout_capture = mock.MagicMock()
        mocked_stdout_capture.readline = get_mocked_readline(list(stdout_lines or []))

        mocked_stderr_capture = mock.MagicMock()
        mocked_stderr_capture.readline = get_mocked_readline(list(stderr_lines or []))

        mocked_pipeline = mock.MagicMock()
        mocked_pipeline.stdout = mocked_stdout_capture
        mocked_pipeline.stderr = mocked_stderr_capture

        return mocked_pipeline

    return mock.patch("zsm_lib.zfs.run", side_effect=mocked_run)
71
72
def test_get_datasets():
    """get_datasets should skip the header row and yield one dataset per line."""
    header, first_name, second_name = "HEADER", "tank/a", "tank/b"

    with patch_run(stdout_lines=[header, first_name, second_name]):
        result = zsm_lib.zfs.get_datasets()

        # The header row must not show up as a dataset.
        assert result[0].name == first_name
        assert result[1].name == second_name
82
83
def test_get_snapshots():
    """get_snapshots should skip the header and return names without the dataset prefix."""
    dataset_name = "tank/a"
    stdout_lines = ["HEADER", f"{dataset_name}@a", f"{dataset_name}@b"]

    with patch_run(stdout_lines=stdout_lines):
        # Query the same dataset name the fixture lines were built from, so the
        # two cannot silently drift apart.
        snapshots = zsm_lib.zfs.get_snapshots(
            dataset=zsm_lib.zfs.Dataset(name=dataset_name)
        )

        # Start counting stdout_lines at 1 since the header is supposed to be skipped.
        assert f"{dataset_name}@{snapshots[0].name}" == stdout_lines[1]
        assert f"{dataset_name}@{snapshots[1].name}" == stdout_lines[2]
97
98
def test_create_snapshot():
    """create_snapshot should go through zsm_lib.zfs.run."""
    with patch_run() as run_stub:
        target = zsm_lib.zfs.Dataset(name="tank/a")
        zsm_lib.zfs.create_snapshot(dataset=target, name="asdf")

        run_stub.assert_called()
105
106
def test_destroy_snapshot():
    """destroy_snapshot should go through zsm_lib.zfs.run."""
    with patch_run() as run_stub:
        parent = zsm_lib.zfs.Dataset(name="tank/a")
        doomed = zsm_lib.zfs.Snapshot(dataset=parent, name="asdf")
        zsm_lib.zfs.destroy_snapshot(snapshot=doomed)

        run_stub.assert_called()
115