1package hijacker
2
3import (
4	"crypto/tls"
5	"fmt"
6	"io"
7	"net/http"
8	"os"
9	"time"
10
11	"github.com/concourse/concourse/atc"
12	"github.com/concourse/concourse/fly/pty"
13	"github.com/concourse/concourse/fly/rc"
14	"github.com/concourse/concourse/fly/ui"
15	"github.com/gorilla/websocket"
16	"github.com/mgutz/ansi"
17	"github.com/tedsuo/rata"
18)
19
// ProcessIO bundles the streams exchanged with a hijacked process: In is
// read and forwarded as the process's stdin, while stdout and stderr payloads
// received from the process are written to Out and Err respectively.
type ProcessIO struct {
	In       io.Reader
	Out, Err io.Writer
}
25
// Hijacker opens websocket sessions to containers and relays process I/O
// over them.
type Hijacker struct {
	// tlsConfig is handed to the websocket dialer as its TLS client config.
	tlsConfig        *tls.Config
	// requestGenerator builds the hijack request whose URL and headers are
	// reused for the websocket dial.
	requestGenerator *rata.RequestGenerator
	// token is optional; when non-nil it is sent as the Authorization header.
	token            *rc.TargetToken
	// interval is the period of the ticker that drives websocket pings.
	interval         time.Duration
}
32
33func New(tlsConfig *tls.Config, requestGenerator *rata.RequestGenerator, token *rc.TargetToken) *Hijacker {
34	return &Hijacker{
35		tlsConfig:        tlsConfig,
36		requestGenerator: requestGenerator,
37		token:            token,
38		interval:         10 * time.Second,
39	}
40}
41
// SetHeartbeatInterval replaces the interval used for the ticker that drives
// websocket ping heartbeats on sessions started afterwards.
func (h *Hijacker) SetHeartbeatInterval(interval time.Duration) {
	h.interval = interval
}
45
46func (h *Hijacker) Hijack(teamName, handle string, spec atc.HijackProcessSpec, pio ProcessIO) (int, error) {
47	url, header, err := h.hijackRequestParts(teamName, handle)
48	if err != nil {
49		return -1, err
50	}
51
52	dialer := websocket.Dialer{
53		TLSClientConfig: h.tlsConfig,
54		Proxy:           http.ProxyFromEnvironment,
55	}
56	conn, response, err := dialer.Dial(url, header)
57	if err != nil {
58		return -1, fmt.Errorf("%s %w", response.Status, err)
59	}
60
61	defer conn.Close()
62
63	err = conn.WriteJSON(spec)
64	if err != nil {
65		return -1, err
66	}
67
68	inputs := make(chan atc.HijackInput, 1)
69	finished := make(chan struct{}, 1)
70
71	go h.monitorTTYSize(inputs, finished)
72	go func() {
73		io.Copy(&stdinWriter{inputs}, pio.In)
74		inputs <- atc.HijackInput{Closed: true}
75	}()
76	go h.handleInput(conn, inputs, finished)
77
78	exitStatus := h.handleOutput(conn, pio)
79
80	close(finished)
81
82	return exitStatus, nil
83}
84
85func (h *Hijacker) hijackRequestParts(teamName, handle string) (string, http.Header, error) {
86	hijackReq, err := h.requestGenerator.CreateRequest(
87		atc.HijackContainer,
88		rata.Params{"id": handle, "team_name": teamName},
89		nil,
90	)
91
92	if err != nil {
93		panic(err)
94	}
95
96	if h.token != nil {
97		hijackReq.Header.Add("Authorization", h.token.Type+" "+h.token.Value)
98	}
99
100	wsUrl := hijackReq.URL
101
102	var found bool
103	wsUrl.Scheme, found = websocketSchemeMap[wsUrl.Scheme]
104	if !found {
105		return "", nil, fmt.Errorf("unknown target scheme: %s", wsUrl.Scheme)
106	}
107
108	return wsUrl.String(), hijackReq.Header, nil
109}
110
111func (h *Hijacker) handleOutput(conn *websocket.Conn, pio ProcessIO) int {
112	var exitStatus int
113	for {
114		var output atc.HijackOutput
115		err := conn.ReadJSON(&output)
116		if err != nil {
117			if !websocket.IsCloseError(err) && !websocket.IsUnexpectedCloseError(err) {
118				fmt.Println(err)
119			}
120			break
121		}
122
123		if output.ExitStatus != nil {
124			exitStatus = *output.ExitStatus
125		} else if len(output.Error) > 0 {
126			fmt.Fprintf(ui.Stderr, "%s\n", ansi.Color(output.Error, "red+b"))
127			exitStatus = 255
128		} else if len(output.Stdout) > 0 {
129			pio.Out.Write(output.Stdout)
130		} else if len(output.Stderr) > 0 {
131			pio.Err.Write(output.Stderr)
132		}
133	}
134
135	return exitStatus
136}
137
138func (h *Hijacker) handleInput(conn *websocket.Conn, inputs <-chan atc.HijackInput, finished chan struct{}) {
139	ticker := time.NewTicker(h.interval)
140	defer ticker.Stop()
141
142	for {
143		select {
144		case input := <-inputs:
145			err := conn.WriteJSON(input)
146			if err != nil {
147				fmt.Fprintf(ui.Stderr, "failed to send input: %s", err.Error())
148				return
149			}
150		case t := <-ticker.C:
151			err := conn.WriteControl(websocket.PingMessage, []byte(t.String()), time.Now().Add(time.Second))
152			if err != nil {
153				fmt.Fprintf(ui.Stderr, "failed to send heartbeat: %s", err.Error())
154			}
155		case <-finished:
156			return
157		}
158	}
159}
160
161func (h *Hijacker) monitorTTYSize(inputs chan<- atc.HijackInput, finished chan struct{}) {
162	resized := pty.ResizeNotifier()
163
164	for {
165		select {
166		case <-resized:
167			rows, cols, err := pty.Getsize(os.Stdin)
168			if err == nil {
169				inputs <- atc.HijackInput{
170					TTYSpec: &atc.HijackTTYSpec{
171						WindowSize: atc.HijackWindowSize{
172							Columns: cols,
173							Rows:    rows,
174						},
175					},
176				}
177			}
178		case <-finished:
179			return
180		}
181	}
182}
183
// stdinWriter adapts a HijackInput channel to io.Writer: each Write is
// delivered on inputs as a message carrying the written bytes as Stdin.
type stdinWriter struct {
	inputs chan<- atc.HijackInput
}
187
188func (w *stdinWriter) Write(d []byte) (int, error) {
189	w.inputs <- atc.HijackInput{
190		Stdin: d,
191	}
192
193	return len(d), nil
194}
195
// websocketSchemeMap translates the scheme of an HTTP URL into the scheme
// used for the corresponding websocket URL: http becomes ws and https becomes
// wss. Schemes without an entry are not supported.
var websocketSchemeMap = map[string]string{
	"https": "wss",
	"http":  "ws",
}
200