1package transport_test
2
3import (
4	"errors"
5	"net/http"
6	"net/url"
7	"time"
8
9	. "github.com/onsi/ginkgo"
10	. "github.com/onsi/gomega"
11
12	"github.com/concourse/concourse/atc/db"
13	"github.com/concourse/concourse/atc/db/dbfakes"
14	"github.com/concourse/concourse/atc/worker/transport"
15	"github.com/concourse/concourse/atc/worker/transport/transportfakes"
16	"github.com/concourse/retryhttp"
17	"github.com/concourse/retryhttp/retryhttpfakes"
18)
19
20var _ = Describe("HijackableClient #Do", func() {
21	var (
22		request              http.Request
23		fakeDB               *transportfakes.FakeTransportDB
24		savedWorker          *dbfakes.FakeWorker
25		savedWorkerAddress   string
26		fakeHijackableClient *retryhttpfakes.FakeHijackableClient
27		hijackableClient     retryhttp.HijackableClient
28		response             *http.Response
29		err                  error
30		fakeHijackCloser     *retryhttpfakes.FakeHijackCloser
31		actualHijackCloser   retryhttp.HijackCloser
32	)
33
34	BeforeEach(func() {
35		fakeDB = new(transportfakes.FakeTransportDB)
36		fakeHijackableClient = new(retryhttpfakes.FakeHijackableClient)
37		fakeHijackCloser = new(retryhttpfakes.FakeHijackCloser)
38		hijackableClient = transport.NewHijackableClient("some-worker", fakeDB, fakeHijackableClient)
39		requestUrl, err := url.Parse("http://1.2.3.4/something")
40		Expect(err).NotTo(HaveOccurred())
41
42		request = http.Request{
43			URL: requestUrl,
44		}
45
46		savedWorkerAddress = "some-garden-addr"
47		savedWorker = new(dbfakes.FakeWorker)
48		savedWorker.GardenAddrReturns(&savedWorkerAddress)
49		savedWorker.ExpiresAtReturns(time.Now().Add(123 * time.Minute))
50		savedWorker.StateReturns(db.WorkerStateRunning)
51
52		fakeDB.GetWorkerReturns(savedWorker, true, nil)
53
54		fakeHijackableClient.DoReturns(&http.Response{StatusCode: http.StatusTeapot}, fakeHijackCloser, nil)
55	})
56
57	JustBeforeEach(func() {
58		response, actualHijackCloser, err = hijackableClient.Do(&request)
59	})
60
61	It("returns the response", func() {
62		Expect(err).NotTo(HaveOccurred())
63		Expect(actualHijackCloser).To(Equal(fakeHijackCloser))
64		Expect(response).To(Equal(&http.Response{StatusCode: http.StatusTeapot}))
65	})
66
67	It("sends the request with worker's garden address", func() {
68		Expect(fakeHijackableClient.DoCallCount()).To(Equal(1))
69		actualRequest := fakeHijackableClient.DoArgsForCall(0)
70		Expect(actualRequest.URL.Host).To(Equal(*savedWorker.GardenAddr()))
71		Expect(actualRequest.URL.Path).To(Equal("/something"))
72	})
73
74	Context("when the lookup of the worker in the db errors", func() {
75		var expectedErr error
76		BeforeEach(func() {
77			expectedErr = errors.New("some-db-error")
78			fakeDB.GetWorkerReturns(nil, true, expectedErr)
79		})
80
81		It("throws an error", func() {
82			Expect(err).To(HaveOccurred())
83			Expect(err.Error()).To(ContainSubstring(expectedErr.Error()))
84		})
85	})
86
87	Context("when the worker is not found in the db", func() {
88		BeforeEach(func() {
89			fakeDB.GetWorkerReturns(nil, false, nil)
90		})
91
92		It("throws an error", func() {
93			Expect(err).To(HaveOccurred())
94			Expect(err).To(Equal(transport.WorkerMissingError{WorkerName: "some-worker"}))
95		})
96	})
97
98	Context("when the worker is stalled in the db", func() {
99		BeforeEach(func() {
100			stalledWorker := new(dbfakes.FakeWorker)
101			stalledWorker.StateReturns(db.WorkerStateStalled)
102			fakeDB.GetWorkerReturns(stalledWorker, true, nil)
103		})
104
105		It("throws a descriptive error", func() {
106			Expect(err).To(HaveOccurred())
107			Expect(err).To(Equal(transport.WorkerUnreachableError{
108				WorkerName:  "some-worker",
109				WorkerState: "stalled",
110			}))
111		})
112	})
113
114	It("reuses the request cached host on subsequent calls", func() {
115		Expect(fakeDB.GetWorkerCallCount()).To(Equal(1))
116		_, _, err := hijackableClient.Do(&request)
117		Expect(err).NotTo(HaveOccurred())
118		Expect(fakeDB.GetWorkerCallCount()).To(Equal(1))
119	})
120
121	Context("when inner Do fails", func() {
122		BeforeEach(func() {
123			fakeHijackableClient.DoReturns(nil, nil, errors.New("some-error"))
124		})
125
126		It("updates cached request host", func() {
127			Expect(fakeDB.GetWorkerCallCount()).To(Equal(1))
128			_, _, err := hijackableClient.Do(&request)
129			Expect(err).To(HaveOccurred())
130			Expect(fakeDB.GetWorkerCallCount()).To(Equal(2))
131		})
132	})
133})
134