1package tsacmd 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "time" 9 10 "code.cloudfoundry.org/clock" 11 "code.cloudfoundry.org/lager" 12 "code.cloudfoundry.org/lager/lagerctx" 13 bclient "github.com/concourse/baggageclaim/client" 14 "github.com/concourse/concourse/atc" 15 "github.com/concourse/concourse/atc/worker/gclient" 16 "github.com/concourse/concourse/tsa" 17 "golang.org/x/crypto/ssh" 18) 19 20type request interface { 21 Handle(context.Context, ConnState, ssh.Channel) error 22} 23 24type forwardWorkerRequest struct { 25 server *server 26 27 gardenAddr string 28 baggageclaimAddr string 29} 30 31func (req forwardWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 32 logger := lagerctx.FromContext(ctx) 33 34 var worker atc.Worker 35 err := json.NewDecoder(channel).Decode(&worker) 36 if err != nil { 37 return err 38 } 39 40 if err := checkTeam(state, worker); err != nil { 41 return err 42 } 43 44 forwards := map[string]ForwardedTCPIP{} 45 for i := 0; i < 2; i++ { 46 select { 47 case forwarded := <-state.ForwardedTCPIPs: 48 logger.Info("forwarded-tcpip", lager.Data{ 49 "bind-addr": forwarded.BindAddr, 50 "bound-port": forwarded.BoundPort, 51 }) 52 53 forwards[forwarded.BindAddr] = forwarded 54 55 case <-time.After(10 * time.Second): 56 logger.Info("never-forwarded-tcpip") 57 } 58 } 59 60 gardenForward, found := forwards[req.gardenAddr] 61 if !found { 62 return fmt.Errorf("garden address (%s) not forwarded", req.gardenAddr) 63 } 64 65 baggageclaimForward, found := forwards[req.baggageclaimAddr] 66 if !found { 67 return fmt.Errorf("baggageclaim address (%s) not forwarded", req.baggageclaimAddr) 68 } 69 70 worker.GardenAddr = fmt.Sprintf("%s:%d", req.server.forwardHost, gardenForward.BoundPort) 71 worker.BaggageclaimURL = fmt.Sprintf("http://%s:%d", req.server.forwardHost, baggageclaimForward.BoundPort) 72 73 heartbeater := tsa.NewHeartbeater( 74 clock.NewClock(), 75 req.server.heartbeatInterval, 76 req.server.cprInterval, 77 gclient.BasicGardenClientWithRequestTimeout( 78 lagerctx.WithSession(ctx, "garden-connection"), 79 req.server.gardenRequestTimeout, 80 gardenURL(worker.GardenAddr), 81 ), 82 bclient.NewWithHTTPClient(worker.BaggageclaimURL, &http.Client{ 83 Transport: &http.Transport{ 84 DisableKeepAlives: true, 85 ResponseHeaderTimeout: 1 * time.Minute, 86 }, 87 }), 88 req.server.atcEndpointPicker, 89 req.server.httpClient, 90 worker, 91 tsa.NewEventWriter(channel), 92 ) 93 94 err = heartbeater.Heartbeat(ctx) 95 if err != nil { 96 logger.Error("failed-to-heartbeat", err) 97 return err 98 } 99 100 for _, forward := range forwards { 101 // prevent new connections from being accepted 102 close(forward.Drain) 103 } 104 105 // only drain if heartbeating was interrupted; otherwise the worker landed or 106 // retired, so it's time to go away 107 if ctx.Err() != nil { 108 logger.Info("draining-forwarded-connections") 109 110 for _, forward := range forwards { 111 // wait for connections to drain 112 forward.Wait() 113 114 logger.Info("forward-process-exited", lager.Data{ 115 "bind-addr": forward.BindAddr, 116 "bound-port": forward.BoundPort, 117 }) 118 } 119 } 120 121 return nil 122} 123 124func (r forwardWorkerRequest) expectedForwards() int { 125 expected := 0 126 127 // Garden should always be forwarded; 128 // if not explicitly given, the only given forward is used 129 expected++ 130 131 if r.baggageclaimAddr != "" { 132 expected++ 133 } 134 135 return expected 136} 137 138type landWorkerRequest struct { 139 server *server 140} 141 142func checkTeam(state ConnState, worker atc.Worker) error { 143 if state.Team == "" { 144 // global keys can be used for all teams 145 return nil 146 } 147 148 if worker.Team == "" && state.Team != "" { 149 return fmt.Errorf("key is authorized for team %s, but worker is global", state.Team) 150 } 151 152 if worker.Team != state.Team { 153 return fmt.Errorf("key is authorized for team %s, but worker belongs to team %s", state.Team, worker.Team) 154 } 155 156 return nil 157} 158 159func (req landWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 160 var worker atc.Worker 161 err := json.NewDecoder(channel).Decode(&worker) 162 if err != nil { 163 return err 164 } 165 166 if err := checkTeam(state, worker); err != nil { 167 return err 168 } 169 170 return (&tsa.Lander{ 171 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 172 HTTPClient: req.server.httpClient, 173 }).Land(ctx, worker) 174} 175 176type retireWorkerRequest struct { 177 server *server 178} 179 180func (req retireWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 181 var worker atc.Worker 182 err := json.NewDecoder(channel).Decode(&worker) 183 if err != nil { 184 return err 185 } 186 187 if err := checkTeam(state, worker); err != nil { 188 return err 189 } 190 191 return (&tsa.Retirer{ 192 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 193 HTTPClient: req.server.httpClient, 194 }).Retire(ctx, worker) 195} 196 197type deleteWorkerRequest struct { 198 server *server 199} 200 201func (req deleteWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 202 var worker atc.Worker 203 err := json.NewDecoder(channel).Decode(&worker) 204 if err != nil { 205 return err 206 } 207 208 if err := checkTeam(state, worker); err != nil { 209 return err 210 } 211 212 return (&tsa.Deleter{ 213 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 214 HTTPClient: req.server.httpClient, 215 }).Delete(ctx, worker) 216} 217 218type sweepContainersRequest struct { 219 server *server 220} 221 222func (req sweepContainersRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 223 var worker atc.Worker 224 err := json.NewDecoder(channel).Decode(&worker) 225 if err != nil { 226 return err 227 } 228 229 if err := checkTeam(state, worker); err != nil { 230 return err 231 } 232 233 sweeper := &tsa.Sweeper{ 234 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 235 HTTPClient: req.server.httpClient, 236 } 237 238 handles, err := sweeper.Sweep(ctx, worker, tsa.SweepContainers) 239 if err != nil { 240 return err 241 } 242 243 _, err = channel.Write(handles) 244 if err != nil { 245 return err 246 } 247 248 return nil 249} 250 251type reportContainersRequest struct { 252 server *server 253 containerHandles []string 254} 255 256func (req reportContainersRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 257 var worker atc.Worker 258 err := json.NewDecoder(channel).Decode(&worker) 259 if err != nil { 260 return err 261 } 262 263 if err := checkTeam(state, worker); err != nil { 264 return err 265 } 266 267 return (&tsa.WorkerStatus{ 268 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 269 HTTPClient: req.server.httpClient, 270 ContainerHandles: req.containerHandles, 271 }).WorkerStatus(ctx, worker, tsa.ReportContainers) 272} 273 274type sweepVolumesRequest struct { 275 server *server 276} 277 278func (req sweepVolumesRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 279 var worker atc.Worker 280 err := json.NewDecoder(channel).Decode(&worker) 281 if err != nil { 282 return err 283 } 284 285 if err := checkTeam(state, worker); err != nil { 286 return err 287 } 288 289 sweeper := &tsa.Sweeper{ 290 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 291 HTTPClient: req.server.httpClient, 292 } 293 294 handles, err := sweeper.Sweep(ctx, worker, tsa.SweepVolumes) 295 if err != nil { 296 return err 297 } 298 299 _, err = channel.Write(handles) 300 if err != nil { 301 return err 302 } 303 304 return nil 305} 306 307type reportVolumesRequest struct { 308 server *server 309 volumeHandles []string 310} 311 312func (req reportVolumesRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error { 313 var worker atc.Worker 314 err := json.NewDecoder(channel).Decode(&worker) 315 if err != nil { 316 return err 317 } 318 319 if err := checkTeam(state, worker); err != nil { 320 return err 321 } 322 323 return (&tsa.WorkerStatus{ 324 ATCEndpoint: req.server.atcEndpointPicker.Pick(), 325 HTTPClient: req.server.httpClient, 326 VolumeHandles: req.volumeHandles, 327 }).WorkerStatus(ctx, worker, tsa.ReportVolumes) 328} 329 330func gardenURL(addr string) string { 331 return fmt.Sprintf("http://%s", addr) 332} 333