1package restore
2
3import (
4	"crypto/rand"
5	"fmt"
6	"io"
7	"io/ioutil"
8	"os"
9	"path/filepath"
10	"strings"
11	"testing"
12
13	"github.com/hashicorp/consul/agent"
14	"github.com/hashicorp/consul/api"
15	"github.com/hashicorp/consul/sdk/testutil"
16	"github.com/mitchellh/cli"
17	"github.com/stretchr/testify/require"
18)
19
20func TestSnapshotRestoreCommand_noTabs(t *testing.T) {
21	t.Parallel()
22	if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
23		t.Fatal("help has tabs")
24	}
25}
26
27func TestSnapshotRestoreCommand_Validation(t *testing.T) {
28	t.Parallel()
29	ui := cli.NewMockUi()
30	c := New(ui)
31
32	cases := map[string]struct {
33		args   []string
34		output string
35	}{
36		"no file": {
37			[]string{},
38			"Missing FILE argument",
39		},
40		"extra args": {
41			[]string{"foo", "bar", "baz"},
42			"Too many arguments",
43		},
44	}
45
46	for name, tc := range cases {
47		// Ensure our buffer is always clear
48		if ui.ErrorWriter != nil {
49			ui.ErrorWriter.Reset()
50		}
51		if ui.OutputWriter != nil {
52			ui.OutputWriter.Reset()
53		}
54
55		code := c.Run(tc.args)
56		if code == 0 {
57			t.Errorf("%s: expected non-zero exit", name)
58		}
59
60		output := ui.ErrorWriter.String()
61		if !strings.Contains(output, tc.output) {
62			t.Errorf("%s: expected %q to contain %q", name, output, tc.output)
63		}
64	}
65}
66
67func TestSnapshotRestoreCommand(t *testing.T) {
68	t.Parallel()
69	a := agent.NewTestAgent(t, ``)
70	defer a.Shutdown()
71	client := a.Client()
72
73	ui := cli.NewMockUi()
74	c := New(ui)
75
76	dir := testutil.TempDir(t, "snapshot")
77	defer os.RemoveAll(dir)
78
79	file := filepath.Join(dir, "backup.tgz")
80	args := []string{
81		"-http-addr=" + a.HTTPAddr(),
82		file,
83	}
84
85	f, err := os.Create(file)
86	if err != nil {
87		t.Fatalf("err: %v", err)
88	}
89
90	snap, _, err := client.Snapshot().Save(nil)
91	if err != nil {
92		f.Close()
93		t.Fatalf("err: %v", err)
94	}
95	if _, err := io.Copy(f, snap); err != nil {
96		f.Close()
97		t.Fatalf("err: %v", err)
98	}
99	if err := f.Close(); err != nil {
100		t.Fatalf("err: %v", err)
101	}
102
103	code := c.Run(args)
104	if code != 0 {
105		t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
106	}
107}
108
109func TestSnapshotRestoreCommand_TruncatedSnapshot(t *testing.T) {
110	t.Parallel()
111	a := agent.NewTestAgent(t, ``)
112	defer a.Shutdown()
113	client := a.Client()
114
115	// Seed it with 64K of random data just so we have something to work with.
116	{
117		blob := make([]byte, 64*1024)
118		_, err := rand.Read(blob)
119		require.NoError(t, err)
120
121		_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: blob}, nil)
122		require.NoError(t, err)
123	}
124
125	// Do a manual snapshot so we can send back roughly reasonable data.
126	var inputData []byte
127	{
128		rc, _, err := client.Snapshot().Save(nil)
129		require.NoError(t, err)
130		defer rc.Close()
131
132		inputData, err = ioutil.ReadAll(rc)
133		require.NoError(t, err)
134	}
135
136	dir := testutil.TempDir(t, "snapshot")
137	defer os.RemoveAll(dir)
138
139	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
140		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
141			// Lop off part of the end.
142			data := inputData[0 : len(inputData)-removeBytes]
143
144			ui := cli.NewMockUi()
145			c := New(ui)
146
147			file := filepath.Join(dir, "backup.tgz")
148			require.NoError(t, ioutil.WriteFile(file, data, 0644))
149			args := []string{
150				"-http-addr=" + a.HTTPAddr(),
151				file,
152			}
153
154			code := c.Run(args)
155			require.Equal(t, 1, code, "expected non-zero exit")
156
157			output := ui.ErrorWriter.String()
158			require.Contains(t, output, "Error restoring snapshot")
159			require.Contains(t, output, "EOF")
160		})
161	}
162}
163