1package restore
2
3import (
4	"crypto/rand"
5	"fmt"
6	"io"
7	"io/ioutil"
8	"os"
9	"path/filepath"
10	"strings"
11	"testing"
12
13	"github.com/hashicorp/consul/agent"
14	"github.com/hashicorp/consul/api"
15	"github.com/hashicorp/consul/sdk/testutil"
16	"github.com/mitchellh/cli"
17	"github.com/stretchr/testify/require"
18)
19
20func TestSnapshotRestoreCommand_noTabs(t *testing.T) {
21	t.Parallel()
22	if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
23		t.Fatal("help has tabs")
24	}
25}
26
27func TestSnapshotRestoreCommand_Validation(t *testing.T) {
28	t.Parallel()
29	ui := cli.NewMockUi()
30	c := New(ui)
31
32	cases := map[string]struct {
33		args   []string
34		output string
35	}{
36		"no file": {
37			[]string{},
38			"Missing FILE argument",
39		},
40		"extra args": {
41			[]string{"foo", "bar", "baz"},
42			"Too many arguments",
43		},
44	}
45
46	for name, tc := range cases {
47		// Ensure our buffer is always clear
48		if ui.ErrorWriter != nil {
49			ui.ErrorWriter.Reset()
50		}
51		if ui.OutputWriter != nil {
52			ui.OutputWriter.Reset()
53		}
54
55		code := c.Run(tc.args)
56		if code == 0 {
57			t.Errorf("%s: expected non-zero exit", name)
58		}
59
60		output := ui.ErrorWriter.String()
61		if !strings.Contains(output, tc.output) {
62			t.Errorf("%s: expected %q to contain %q", name, output, tc.output)
63		}
64	}
65}
66
67func TestSnapshotRestoreCommand(t *testing.T) {
68	if testing.Short() {
69		t.Skip("too slow for testing.Short")
70	}
71
72	t.Parallel()
73	a := agent.NewTestAgent(t, ``)
74	defer a.Shutdown()
75	client := a.Client()
76
77	ui := cli.NewMockUi()
78	c := New(ui)
79
80	dir := testutil.TempDir(t, "snapshot")
81	file := filepath.Join(dir, "backup.tgz")
82	args := []string{
83		"-http-addr=" + a.HTTPAddr(),
84		file,
85	}
86
87	f, err := os.Create(file)
88	if err != nil {
89		t.Fatalf("err: %v", err)
90	}
91
92	snap, _, err := client.Snapshot().Save(nil)
93	if err != nil {
94		f.Close()
95		t.Fatalf("err: %v", err)
96	}
97	if _, err := io.Copy(f, snap); err != nil {
98		f.Close()
99		t.Fatalf("err: %v", err)
100	}
101	if err := f.Close(); err != nil {
102		t.Fatalf("err: %v", err)
103	}
104
105	code := c.Run(args)
106	if code != 0 {
107		t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
108	}
109}
110
111func TestSnapshotRestoreCommand_TruncatedSnapshot(t *testing.T) {
112	if testing.Short() {
113		t.Skip("too slow for testing.Short")
114	}
115
116	t.Parallel()
117	a := agent.NewTestAgent(t, ``)
118	defer a.Shutdown()
119	client := a.Client()
120
121	// Seed it with 64K of random data just so we have something to work with.
122	{
123		blob := make([]byte, 64*1024)
124		_, err := rand.Read(blob)
125		require.NoError(t, err)
126
127		_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: blob}, nil)
128		require.NoError(t, err)
129	}
130
131	// Do a manual snapshot so we can send back roughly reasonable data.
132	var inputData []byte
133	{
134		rc, _, err := client.Snapshot().Save(nil)
135		require.NoError(t, err)
136		defer rc.Close()
137
138		inputData, err = ioutil.ReadAll(rc)
139		require.NoError(t, err)
140	}
141
142	dir := testutil.TempDir(t, "snapshot")
143
144	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
145		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
146			// Lop off part of the end.
147			data := inputData[0 : len(inputData)-removeBytes]
148
149			ui := cli.NewMockUi()
150			c := New(ui)
151
152			file := filepath.Join(dir, "backup.tgz")
153			require.NoError(t, ioutil.WriteFile(file, data, 0644))
154			args := []string{
155				"-http-addr=" + a.HTTPAddr(),
156				file,
157			}
158
159			code := c.Run(args)
160			require.Equal(t, 1, code, "expected non-zero exit")
161
162			output := ui.ErrorWriter.String()
163			require.Contains(t, output, "Error restoring snapshot")
164			require.Contains(t, output, "EOF")
165		})
166	}
167}
168