1package tsacmd
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"net/http"
8	"time"
9
10	"code.cloudfoundry.org/clock"
11	"code.cloudfoundry.org/lager"
12	"code.cloudfoundry.org/lager/lagerctx"
13	bclient "github.com/concourse/baggageclaim/client"
14	"github.com/concourse/concourse/atc"
15	"github.com/concourse/concourse/atc/worker/gclient"
16	"github.com/concourse/concourse/tsa"
17	"golang.org/x/crypto/ssh"
18)
19
20type request interface {
21	Handle(context.Context, ConnState, ssh.Channel) error
22}
23
24type forwardWorkerRequest struct {
25	server *server
26
27	gardenAddr       string
28	baggageclaimAddr string
29}
30
31func (req forwardWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
32	logger := lagerctx.FromContext(ctx)
33
34	var worker atc.Worker
35	err := json.NewDecoder(channel).Decode(&worker)
36	if err != nil {
37		return err
38	}
39
40	if err := checkTeam(state, worker); err != nil {
41		return err
42	}
43
44	forwards := map[string]ForwardedTCPIP{}
45	for i := 0; i < 2; i++ {
46		select {
47		case forwarded := <-state.ForwardedTCPIPs:
48			logger.Info("forwarded-tcpip", lager.Data{
49				"bind-addr":  forwarded.BindAddr,
50				"bound-port": forwarded.BoundPort,
51			})
52
53			forwards[forwarded.BindAddr] = forwarded
54
55		case <-time.After(10 * time.Second):
56			logger.Info("never-forwarded-tcpip")
57		}
58	}
59
60	gardenForward, found := forwards[req.gardenAddr]
61	if !found {
62		return fmt.Errorf("garden address (%s) not forwarded", req.gardenAddr)
63	}
64
65	baggageclaimForward, found := forwards[req.baggageclaimAddr]
66	if !found {
67		return fmt.Errorf("baggageclaim address (%s) not forwarded", req.baggageclaimAddr)
68	}
69
70	worker.GardenAddr = fmt.Sprintf("%s:%d", req.server.forwardHost, gardenForward.BoundPort)
71	worker.BaggageclaimURL = fmt.Sprintf("http://%s:%d", req.server.forwardHost, baggageclaimForward.BoundPort)
72
73	heartbeater := tsa.NewHeartbeater(
74		clock.NewClock(),
75		req.server.heartbeatInterval,
76		req.server.cprInterval,
77		gclient.BasicGardenClientWithRequestTimeout(
78			lagerctx.WithSession(ctx, "garden-connection"),
79			req.server.gardenRequestTimeout,
80			gardenURL(worker.GardenAddr),
81		),
82		bclient.NewWithHTTPClient(worker.BaggageclaimURL, &http.Client{
83			Transport: &http.Transport{
84				DisableKeepAlives:     true,
85				ResponseHeaderTimeout: 1 * time.Minute,
86			},
87		}),
88		req.server.atcEndpointPicker,
89		req.server.httpClient,
90		worker,
91		tsa.NewEventWriter(channel),
92	)
93
94	err = heartbeater.Heartbeat(ctx)
95	if err != nil {
96		logger.Error("failed-to-heartbeat", err)
97		return err
98	}
99
100	for _, forward := range forwards {
101		// prevent new connections from being accepted
102		close(forward.Drain)
103	}
104
105	// only drain if heartbeating was interrupted; otherwise the worker landed or
106	// retired, so it's time to go away
107	if ctx.Err() != nil {
108		logger.Info("draining-forwarded-connections")
109
110		for _, forward := range forwards {
111			// wait for connections to drain
112			forward.Wait()
113
114			logger.Info("forward-process-exited", lager.Data{
115				"bind-addr":  forward.BindAddr,
116				"bound-port": forward.BoundPort,
117			})
118		}
119	}
120
121	return nil
122}
123
124func (r forwardWorkerRequest) expectedForwards() int {
125	expected := 0
126
127	// Garden should always be forwarded;
128	// if not explicitly given, the only given forward is used
129	expected++
130
131	if r.baggageclaimAddr != "" {
132		expected++
133	}
134
135	return expected
136}
137
138type landWorkerRequest struct {
139	server *server
140}
141
142func checkTeam(state ConnState, worker atc.Worker) error {
143	if state.Team == "" {
144		// global keys can be used for all teams
145		return nil
146	}
147
148	if worker.Team == "" && state.Team != "" {
149		return fmt.Errorf("key is authorized for team %s, but worker is global", state.Team)
150	}
151
152	if worker.Team != state.Team {
153		return fmt.Errorf("key is authorized for team %s, but worker belongs to team %s", state.Team, worker.Team)
154	}
155
156	return nil
157}
158
159func (req landWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
160	var worker atc.Worker
161	err := json.NewDecoder(channel).Decode(&worker)
162	if err != nil {
163		return err
164	}
165
166	if err := checkTeam(state, worker); err != nil {
167		return err
168	}
169
170	return (&tsa.Lander{
171		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
172		HTTPClient:  req.server.httpClient,
173	}).Land(ctx, worker)
174}
175
176type retireWorkerRequest struct {
177	server *server
178}
179
180func (req retireWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
181	var worker atc.Worker
182	err := json.NewDecoder(channel).Decode(&worker)
183	if err != nil {
184		return err
185	}
186
187	if err := checkTeam(state, worker); err != nil {
188		return err
189	}
190
191	return (&tsa.Retirer{
192		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
193		HTTPClient:  req.server.httpClient,
194	}).Retire(ctx, worker)
195}
196
197type deleteWorkerRequest struct {
198	server *server
199}
200
201func (req deleteWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
202	var worker atc.Worker
203	err := json.NewDecoder(channel).Decode(&worker)
204	if err != nil {
205		return err
206	}
207
208	if err := checkTeam(state, worker); err != nil {
209		return err
210	}
211
212	return (&tsa.Deleter{
213		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
214		HTTPClient:  req.server.httpClient,
215	}).Delete(ctx, worker)
216}
217
218type sweepContainersRequest struct {
219	server *server
220}
221
222func (req sweepContainersRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
223	var worker atc.Worker
224	err := json.NewDecoder(channel).Decode(&worker)
225	if err != nil {
226		return err
227	}
228
229	if err := checkTeam(state, worker); err != nil {
230		return err
231	}
232
233	sweeper := &tsa.Sweeper{
234		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
235		HTTPClient:  req.server.httpClient,
236	}
237
238	handles, err := sweeper.Sweep(ctx, worker, tsa.SweepContainers)
239	if err != nil {
240		return err
241	}
242
243	_, err = channel.Write(handles)
244	if err != nil {
245		return err
246	}
247
248	return nil
249}
250
251type reportContainersRequest struct {
252	server           *server
253	containerHandles []string
254}
255
256func (req reportContainersRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
257	var worker atc.Worker
258	err := json.NewDecoder(channel).Decode(&worker)
259	if err != nil {
260		return err
261	}
262
263	if err := checkTeam(state, worker); err != nil {
264		return err
265	}
266
267	return (&tsa.WorkerStatus{
268		ATCEndpoint:      req.server.atcEndpointPicker.Pick(),
269		HTTPClient:       req.server.httpClient,
270		ContainerHandles: req.containerHandles,
271	}).WorkerStatus(ctx, worker, tsa.ReportContainers)
272}
273
274type sweepVolumesRequest struct {
275	server *server
276}
277
278func (req sweepVolumesRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
279	var worker atc.Worker
280	err := json.NewDecoder(channel).Decode(&worker)
281	if err != nil {
282		return err
283	}
284
285	if err := checkTeam(state, worker); err != nil {
286		return err
287	}
288
289	sweeper := &tsa.Sweeper{
290		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
291		HTTPClient:  req.server.httpClient,
292	}
293
294	handles, err := sweeper.Sweep(ctx, worker, tsa.SweepVolumes)
295	if err != nil {
296		return err
297	}
298
299	_, err = channel.Write(handles)
300	if err != nil {
301		return err
302	}
303
304	return nil
305}
306
307type reportVolumesRequest struct {
308	server        *server
309	volumeHandles []string
310}
311
312func (req reportVolumesRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
313	var worker atc.Worker
314	err := json.NewDecoder(channel).Decode(&worker)
315	if err != nil {
316		return err
317	}
318
319	if err := checkTeam(state, worker); err != nil {
320		return err
321	}
322
323	return (&tsa.WorkerStatus{
324		ATCEndpoint:   req.server.atcEndpointPicker.Pick(),
325		HTTPClient:    req.server.httpClient,
326		VolumeHandles: req.volumeHandles,
327	}).WorkerStatus(ctx, worker, tsa.ReportVolumes)
328}
329
// gardenURL turns a host:port address into a plain-HTTP URL by prefixing it
// with "http://". The address is not validated.
func gardenURL(addr string) string {
	// fmt.Sprint inserts no separator between adjacent string operands.
	return fmt.Sprint("http://", addr)
}
333