1package getter 2 3import ( 4 "io" 5 "net/http" 6 "net/http/httptest" 7 "os" 8 "path/filepath" 9 "sync" 10 "testing" 11) 12 13type MockProgressTracking struct { 14 sync.Mutex 15 downloaded map[string]int 16} 17 18func (p *MockProgressTracking) TrackProgress(src string, 19 currentSize, totalSize int64, stream io.ReadCloser) (body io.ReadCloser) { 20 p.Lock() 21 defer p.Unlock() 22 23 if p.downloaded == nil { 24 p.downloaded = map[string]int{} 25 } 26 27 v, _ := p.downloaded[src] 28 p.downloaded[src] = v + 1 29 return stream 30} 31 32func TestGet_progress(t *testing.T) { 33 s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 34 // all good 35 rw.Header().Add("X-Terraform-Get", "something") 36 })) 37 defer s.Close() 38 39 { // dl without tracking 40 dst := tempTestFile(t) 41 defer os.RemoveAll(filepath.Dir(dst)) 42 if err := GetFile(dst, s.URL+"/file?thig=this&that"); err != nil { 43 t.Fatalf("download failed: %v", err) 44 } 45 } 46 47 { // tracking 48 p := &MockProgressTracking{} 49 dst := tempTestFile(t) 50 defer os.RemoveAll(filepath.Dir(dst)) 51 if err := GetFile(dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil { 52 t.Fatalf("download failed: %v", err) 53 } 54 if err := GetFile(dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil { 55 t.Fatalf("download failed: %v", err) 56 } 57 58 if p.downloaded["file"] != 1 { 59 t.Error("Expected a file download") 60 } 61 if p.downloaded["otherfile"] != 1 { 62 t.Error("Expected a otherfile download") 63 } 64 } 65} 66