1package getter
2
3import (
4	"io"
5	"net/http"
6	"net/http/httptest"
7	"os"
8	"path/filepath"
9	"sync"
10	"testing"
11)
12
// MockProgressTracking is a test double for a progress tracker. It records how
// many times TrackProgress was called for each src and otherwise passes the
// stream through untouched. The embedded Mutex guards downloaded.
type MockProgressTracking struct {
	sync.Mutex
	// downloaded counts TrackProgress calls per src; allocated lazily on first use.
	downloaded map[string]int
}

// TrackProgress increments the call counter for src and returns stream
// unchanged. currentSize and totalSize are ignored by this mock.
func (p *MockProgressTracking) TrackProgress(src string,
	currentSize, totalSize int64, stream io.ReadCloser) (body io.ReadCloser) {
	p.Lock()
	defer p.Unlock()

	// Lazy init keeps the zero value of MockProgressTracking usable.
	if p.downloaded == nil {
		p.downloaded = map[string]int{}
	}

	// A missing key reads as 0, so a plain increment covers the first call too.
	p.downloaded[src]++
	return stream
}
31
32func TestGet_progress(t *testing.T) {
33	s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
34		// all good
35		rw.Header().Add("X-Terraform-Get", "something")
36	}))
37	defer s.Close()
38
39	{ // dl without tracking
40		dst := tempTestFile(t)
41		defer os.RemoveAll(filepath.Dir(dst))
42		if err := GetFile(dst, s.URL+"/file?thig=this&that"); err != nil {
43			t.Fatalf("download failed: %v", err)
44		}
45	}
46
47	{ // tracking
48		p := &MockProgressTracking{}
49		dst := tempTestFile(t)
50		defer os.RemoveAll(filepath.Dir(dst))
51		if err := GetFile(dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil {
52			t.Fatalf("download failed: %v", err)
53		}
54		if err := GetFile(dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil {
55			t.Fatalf("download failed: %v", err)
56		}
57
58		if p.downloaded["file"] != 1 {
59			t.Error("Expected a file download")
60		}
61		if p.downloaded["otherfile"] != 1 {
62			t.Error("Expected a otherfile download")
63		}
64	}
65}
66