1package hijacker 2 3import ( 4 "crypto/tls" 5 "fmt" 6 "io" 7 "net/http" 8 "os" 9 "time" 10 11 "github.com/concourse/concourse/atc" 12 "github.com/concourse/concourse/fly/pty" 13 "github.com/concourse/concourse/fly/rc" 14 "github.com/concourse/concourse/fly/ui" 15 "github.com/gorilla/websocket" 16 "github.com/mgutz/ansi" 17 "github.com/tedsuo/rata" 18) 19 20type ProcessIO struct { 21 In io.Reader 22 Out io.Writer 23 Err io.Writer 24} 25 26type Hijacker struct { 27 tlsConfig *tls.Config 28 requestGenerator *rata.RequestGenerator 29 token *rc.TargetToken 30 interval time.Duration 31} 32 33func New(tlsConfig *tls.Config, requestGenerator *rata.RequestGenerator, token *rc.TargetToken) *Hijacker { 34 return &Hijacker{ 35 tlsConfig: tlsConfig, 36 requestGenerator: requestGenerator, 37 token: token, 38 interval: 10 * time.Second, 39 } 40} 41 42func (h *Hijacker) SetHeartbeatInterval(interval time.Duration) { 43 h.interval = interval 44} 45 46func (h *Hijacker) Hijack(teamName, handle string, spec atc.HijackProcessSpec, pio ProcessIO) (int, error) { 47 url, header, err := h.hijackRequestParts(teamName, handle) 48 if err != nil { 49 return -1, err 50 } 51 52 dialer := websocket.Dialer{ 53 TLSClientConfig: h.tlsConfig, 54 Proxy: http.ProxyFromEnvironment, 55 } 56 conn, response, err := dialer.Dial(url, header) 57 if err != nil { 58 return -1, fmt.Errorf("%s %w", response.Status, err) 59 } 60 61 defer conn.Close() 62 63 err = conn.WriteJSON(spec) 64 if err != nil { 65 return -1, err 66 } 67 68 inputs := make(chan atc.HijackInput, 1) 69 finished := make(chan struct{}, 1) 70 71 go h.monitorTTYSize(inputs, finished) 72 go func() { 73 io.Copy(&stdinWriter{inputs}, pio.In) 74 inputs <- atc.HijackInput{Closed: true} 75 }() 76 go h.handleInput(conn, inputs, finished) 77 78 exitStatus := h.handleOutput(conn, pio) 79 80 close(finished) 81 82 return exitStatus, nil 83} 84 85func (h *Hijacker) hijackRequestParts(teamName, handle string) (string, http.Header, error) { 86 hijackReq, err := h.requestGenerator.CreateRequest( 87 atc.HijackContainer, 88 rata.Params{"id": handle, "team_name": teamName}, 89 nil, 90 ) 91 92 if err != nil { 93 panic(err) 94 } 95 96 if h.token != nil { 97 hijackReq.Header.Add("Authorization", h.token.Type+" "+h.token.Value) 98 } 99 100 wsUrl := hijackReq.URL 101 102 var found bool 103 wsUrl.Scheme, found = websocketSchemeMap[wsUrl.Scheme] 104 if !found { 105 return "", nil, fmt.Errorf("unknown target scheme: %s", wsUrl.Scheme) 106 } 107 108 return wsUrl.String(), hijackReq.Header, nil 109} 110 111func (h *Hijacker) handleOutput(conn *websocket.Conn, pio ProcessIO) int { 112 var exitStatus int 113 for { 114 var output atc.HijackOutput 115 err := conn.ReadJSON(&output) 116 if err != nil { 117 if !websocket.IsCloseError(err) && !websocket.IsUnexpectedCloseError(err) { 118 fmt.Println(err) 119 } 120 break 121 } 122 123 if output.ExitStatus != nil { 124 exitStatus = *output.ExitStatus 125 } else if len(output.Error) > 0 { 126 fmt.Fprintf(ui.Stderr, "%s\n", ansi.Color(output.Error, "red+b")) 127 exitStatus = 255 128 } else if len(output.Stdout) > 0 { 129 pio.Out.Write(output.Stdout) 130 } else if len(output.Stderr) > 0 { 131 pio.Err.Write(output.Stderr) 132 } 133 } 134 135 return exitStatus 136} 137 138func (h *Hijacker) handleInput(conn *websocket.Conn, inputs <-chan atc.HijackInput, finished chan struct{}) { 139 ticker := time.NewTicker(h.interval) 140 defer ticker.Stop() 141 142 for { 143 select { 144 case input := <-inputs: 145 err := conn.WriteJSON(input) 146 if err != nil { 147 fmt.Fprintf(ui.Stderr, "failed to send input: %s", err.Error()) 148 return 149 } 150 case t := <-ticker.C: 151 err := conn.WriteControl(websocket.PingMessage, []byte(t.String()), time.Now().Add(time.Second)) 152 if err != nil { 153 fmt.Fprintf(ui.Stderr, "failed to send heartbeat: %s", err.Error()) 154 } 155 case <-finished: 156 return 157 } 158 } 159} 160 161func (h *Hijacker) monitorTTYSize(inputs chan<- atc.HijackInput, finished chan struct{}) { 162 resized := pty.ResizeNotifier() 163 164 for { 165 select { 166 case <-resized: 167 rows, cols, err := pty.Getsize(os.Stdin) 168 if err == nil { 169 inputs <- atc.HijackInput{ 170 TTYSpec: &atc.HijackTTYSpec{ 171 WindowSize: atc.HijackWindowSize{ 172 Columns: cols, 173 Rows: rows, 174 }, 175 }, 176 } 177 } 178 case <-finished: 179 return 180 } 181 } 182} 183 184type stdinWriter struct { 185 inputs chan<- atc.HijackInput 186} 187 188func (w *stdinWriter) Write(d []byte) (int, error) { 189 w.inputs <- atc.HijackInput{ 190 Stdin: d, 191 } 192 193 return len(d), nil 194} 195 196var websocketSchemeMap = map[string]string{ 197 "http": "ws", 198 "https": "wss", 199} 200