# SPDX-License-Identifier: BSD-2-Clause
import pytest
from unittest import mock

import sarge

import zsm_lib.zfs


def get_mocked_readline(lines):
    """Build a fake ``readline`` that returns *lines* one at a time as UTF-8 bytes.

    An empty string is appended after the given lines, because an empty read
    is how ``readline`` signals that there are no more lines. Once everything
    has been returned, the fake keeps returning ``b""`` (like a real stream at
    EOF) instead of raising ``StopIteration``.

    The caller's list is copied, never mutated, so the same list can safely
    back several fakes (e.g. when a patched command is run more than once).

    :param lines: list of ``str`` lines to return, or ``None`` for no output.
    :return: a zero-argument callable returning the next line as ``bytes``.
    """
    # Build a new list rather than appending to the caller's: appending in
    # place added another "" sentinel to the shared list on every call.
    all_lines = [*(lines or []), ""]
    encoded_lines = (line.encode("utf-8") for line in all_lines)

    def mocked_readline():
        # Default of b"" mirrors readline() at EOF on a real binary stream.
        return next(encoded_lines, b"")

    return mocked_readline


def test_run_with_success():
    """``zfs.run`` calls ``sarge.run`` and returns normally on exit code 0."""
    mocked_pipeline = mock.MagicMock()
    mocked_pipeline.returncode = 0

    # patch.object restores sarge.run on exit; assigning sarge.run directly
    # would leak the mock into every test that runs afterwards.
    with mock.patch.object(
        sarge, "run", return_value=mocked_pipeline
    ) as mocked_sarge_run:
        zsm_lib.zfs.run("asdf")

    mocked_sarge_run.assert_called()


def test_run_with_failure():
    """A non-zero exit code is expected to raise ``ZfsOperationFailed``.

    The exception message is expected to equal the command's stderr output.
    """
    expected_error = "a problem?"

    mocked_pipeline = mock.MagicMock()
    mocked_pipeline.returncode = 1
    mocked_pipeline.stderr.readline = get_mocked_readline([expected_error])

    # Scoped patch so the fake sarge.run does not outlive this test.
    with mock.patch.object(
        sarge, "run", return_value=mocked_pipeline
    ) as mocked_sarge_run:
        with pytest.raises(zsm_lib.zfs.ZfsOperationFailed) as e:
            zsm_lib.zfs.run("asdf")

    assert str(e.value) == expected_error

    mocked_sarge_run.assert_called()


def patch_run(stdout_lines=None, stderr_lines=None):
    """Return a ``mock.patch`` of ``zsm_lib.zfs.run`` that fakes command output.

    Each call to the patched ``run`` gets a fresh fake pipeline. Its
    ``stdout``/``stderr`` ``readline`` methods replay *stdout_lines* /
    *stderr_lines*.

    :param stdout_lines: lines the fake stdout should produce, or ``None``.
    :param stderr_lines: lines the fake stderr should produce, or ``None``.
    :return: a patcher usable as a context manager; entering it yields the
        mock that replaced ``zsm_lib.zfs.run``.
    """

    def mocked_run(cmd):
        mocked_stdout_capture = mock.MagicMock()
        mocked_stdout_capture.readline = get_mocked_readline(stdout_lines)

        mocked_stderr_capture = mock.MagicMock()
        mocked_stderr_capture.readline = get_mocked_readline(stderr_lines)

        mocked_pipeline = mock.MagicMock()
        mocked_pipeline.stdout = mocked_stdout_capture
        mocked_pipeline.stderr = mocked_stderr_capture

        return mocked_pipeline

    return mock.patch("zsm_lib.zfs.run", side_effect=mocked_run)


def test_get_datasets():
    """Each non-header stdout line becomes a dataset with that name."""
    stdout_lines = ["HEADER", "tank/a", "tank/b"]

    with patch_run(stdout_lines=stdout_lines):
        datasets = zsm_lib.zfs.get_datasets()

    # Start counting stdout_lines at 1 since the header is supposed to be skipped.
    assert datasets[0].name == stdout_lines[1]
    assert datasets[1].name == stdout_lines[2]


def test_get_snapshots():
    """Snapshot names are the part of each stdout line after ``<dataset>@``."""
    dataset_name = "tank/a"
    stdout_lines = ["HEADER", f"{dataset_name}@a", f"{dataset_name}@b"]

    with patch_run(stdout_lines=stdout_lines):
        # Reuse dataset_name so the input and the expected lines cannot drift.
        snapshots = zsm_lib.zfs.get_snapshots(
            dataset=zsm_lib.zfs.Dataset(name=dataset_name)
        )

    # Start counting stdout_lines at 1 since the header is supposed to be skipped.
    assert f"{dataset_name}@{snapshots[0].name}" == stdout_lines[1]
    assert f"{dataset_name}@{snapshots[1].name}" == stdout_lines[2]


def test_create_snapshot():
    """Creating a snapshot goes through ``zsm_lib.zfs.run``."""
    with patch_run() as mock_run:
        zsm_lib.zfs.create_snapshot(
            dataset=zsm_lib.zfs.Dataset(name="tank/a"), name="asdf"
        )
        mock_run.assert_called()


def test_destroy_snapshot():
    """Destroying a snapshot goes through ``zsm_lib.zfs.run``."""
    with patch_run() as mock_run:
        zsm_lib.zfs.destroy_snapshot(
            snapshot=zsm_lib.zfs.Snapshot(
                dataset=zsm_lib.zfs.Dataset(name="tank/a"), name="asdf"
            )
        )
        mock_run.assert_called()