1package getter
2
3import (
4	"crypto/md5"
5	"encoding/hex"
6	"io"
7	"io/ioutil"
8	"os"
9	"path/filepath"
10	"reflect"
11	"runtime"
12	"sort"
13	"strings"
14	"time"
15
16	"github.com/mitchellh/go-testing-interface"
17)
18
// TestDecompressCase is a single test case for testing decompressors.
//
// The field order is kept stable because callers may build cases with
// positional composite literals.
type TestDecompressCase struct {
	// Input is the complete path to the input file.
	Input string

	// Dir is whether or not we're testing directory mode.
	Dir bool

	// Err is whether we expect Decompress to return an error.
	Err bool

	// DirList is the expected, sorted list of entries for Dir mode.
	DirList []string

	// FileMD5 is the expected hex MD5 of the single output file; an empty
	// string skips the digest check.
	FileMD5 string

	// Mtime is the optionally expected mtime for a single file (or for all
	// entries when in Dir mode); nil skips the mtime check.
	Mtime *time.Time
}
28
29// TestDecompressor is a helper function for testing generic decompressors.
30func TestDecompressor(t testing.T, d Decompressor, cases []TestDecompressCase) {
31	t.Helper()
32
33	for _, tc := range cases {
34		t.Logf("Testing: %s", tc.Input)
35
36		// Temporary dir to store stuff
37		td, err := ioutil.TempDir("", "getter")
38		if err != nil {
39			t.Fatalf("err: %s", err)
40		}
41
42		// Destination is always joining result so that we have a new path
43		dst := filepath.Join(td, "subdir", "result")
44
45		// We use a function so defers work
46		func() {
47			defer os.RemoveAll(td)
48
49			// Decompress
50			err := d.Decompress(dst, tc.Input, tc.Dir, 0022)
51			if (err != nil) != tc.Err {
52				t.Fatalf("err %s: %s", tc.Input, err)
53			}
54			if tc.Err {
55				return
56			}
57
58			// If it isn't a directory, then check for a single file
59			if !tc.Dir {
60				fi, err := os.Stat(dst)
61				if err != nil {
62					t.Fatalf("err %s: %s", tc.Input, err)
63				}
64				if fi.IsDir() {
65					t.Fatalf("err %s: expected file, got directory", tc.Input)
66				}
67				if tc.FileMD5 != "" {
68					actual := testMD5(t, dst)
69					expected := tc.FileMD5
70					if actual != expected {
71						t.Fatalf("err %s: expected MD5 %s, got %s", tc.Input, expected, actual)
72					}
73				}
74
75				if tc.Mtime != nil {
76					actual := fi.ModTime()
77					if tc.Mtime.Unix() > 0 {
78						expected := *tc.Mtime
79						if actual != expected {
80							t.Fatalf("err %s: expected mtime '%s' for %s, got '%s'", tc.Input, expected.String(), dst, actual.String())
81						}
82					} else if actual.Unix() <= 0 {
83						t.Fatalf("err %s: expected mtime to be > 0, got '%s'", actual.String())
84					}
85				}
86
87				return
88			}
89
90			// Convert expected for windows
91			expected := tc.DirList
92			if runtime.GOOS == "windows" {
93				for i, v := range expected {
94					expected[i] = strings.Replace(v, "/", "\\", -1)
95				}
96			}
97
98			// Directory, check for the correct contents
99			actual := testListDir(t, dst)
100			if !reflect.DeepEqual(actual, expected) {
101				t.Fatalf("bad %s\n\n%#v\n\n%#v", tc.Input, actual, expected)
102			}
103			// Check for correct atime/mtime
104			for _, dir := range actual {
105				path := filepath.Join(dst, dir)
106				if tc.Mtime != nil {
107					fi, err := os.Stat(path)
108					if err != nil {
109						t.Fatalf("err: %s", err)
110					}
111					actual := fi.ModTime()
112					if tc.Mtime.Unix() > 0 {
113						expected := *tc.Mtime
114						if actual != expected {
115							t.Fatalf("err %s: expected mtime '%s' for %s, got '%s'", tc.Input, expected.String(), path, actual.String())
116						}
117					} else if actual.Unix() < 0 {
118						t.Fatalf("err %s: expected mtime to be > 0, got '%s'", actual.String())
119					}
120
121				}
122			}
123		}()
124	}
125}
126
127func testListDir(t testing.T, path string) []string {
128	var result []string
129	err := filepath.Walk(path, func(sub string, info os.FileInfo, err error) error {
130		if err != nil {
131			return err
132		}
133
134		sub = strings.TrimPrefix(sub, path)
135		if sub == "" {
136			return nil
137		}
138		sub = sub[1:] // Trim the leading path sep.
139
140		// If it is a dir, add trailing sep
141		if info.IsDir() {
142			sub += string(os.PathSeparator)
143		}
144
145		result = append(result, sub)
146		return nil
147	})
148	if err != nil {
149		t.Fatalf("err: %s", err)
150	}
151
152	sort.Strings(result)
153	return result
154}
155
156func testMD5(t testing.T, path string) string {
157	f, err := os.Open(path)
158	if err != nil {
159		t.Fatalf("err: %s", err)
160	}
161	defer f.Close()
162
163	h := md5.New()
164	_, err = io.Copy(h, f)
165	if err != nil {
166		t.Fatalf("err: %s", err)
167	}
168
169	result := h.Sum(nil)
170	return hex.EncodeToString(result)
171}
172