1// Copyright 2012 The Go Authors.  All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
5// +build !appengine
// Package socket implements a WebSocket-based playground backend.
8// Clients connect to a websocket handler and send run/kill commands, and
9// the server sends the output and exit status of the running processes.
10// Multiple clients running multiple processes may be served concurrently.
11// The wire format is JSON and is described by the Message type.
13// This will not run on App Engine as WebSockets are not supported there.
14package socket // import "golang.org/x/tools/playground/socket"
16import (
17	"bytes"
18	"encoding/json"
19	"errors"
20	"go/parser"
21	"go/token"
22	"io"
23	"io/ioutil"
24	"log"
25	"net"
26	"net/http"
27	"net/url"
28	"os"
29	"os/exec"
30	"path/filepath"
31	"runtime"
32	"strings"
33	"time"
34	"unicode/utf8"
36	"golang.org/x/net/websocket"
37	"golang.org/x/tools/txtar"
var (
	// RunScripts reports whether the socket handler may execute shell
	// scripts, i.e. snippets whose first line is a shebang ("#!").
	RunScripts = true

	// Environ returns the environment in which binaries such as the go tool
	// are invoked. It defaults to os.Environ and may be replaced.
	Environ func() []string = os.Environ
)
// Limits on the output a single process may send to the client.
const (
	// msgLimit is the maximum number of messages relayed per process;
	// exceeding it kills the process (see limiter).
	msgLimit = 1000

	// msgDelay is the interval over which output messages of the same kind
	// are coalesced into a single message (see buffer).
	msgDelay = 10 * time.Millisecond
)
// Message is the JSON wire format exchanged with the browser over the
// websocket. The same type carries both client commands and server output;
// the Kind field tells them apart.
type Message struct {
	// Id is chosen by the client and names the process the message is about.
	Id string
	// Kind is "run" or "kill" for client commands, and "stdout", "stderr"
	// or "end" for server output.
	Kind string
	// Body is the program source for "run", the output text for "stdout"
	// and "stderr", and the error text (if any) for "end".
	Body string
	// Options carries optional settings; it is omitted from JSON when nil.
	Options *Options `json:",omitempty"`
}

// Options holds optional per-message settings.
type Options struct {
	// Race requests building the program with -race ("run" only).
	Race bool
}
71// NewHandler returns a websocket server which checks the origin of requests.
72func NewHandler(origin *url.URL) websocket.Server {
73	return websocket.Server{
74		Config:    websocket.Config{Origin: origin},
75		Handshake: handshake,
76		Handler:   websocket.Handler(socketHandler),
77	}
80// handshake checks the origin of a request during the websocket handshake.
81func handshake(c *websocket.Config, req *http.Request) error {
82	o, err := websocket.Origin(c, req)
83	if err != nil {
84		log.Println("bad websocket origin:", err)
85		return websocket.ErrBadWebSocketOrigin
86	}
87	_, port, err := net.SplitHostPort(c.Origin.Host)
88	if err != nil {
89		log.Println("bad websocket origin:", err)
90		return websocket.ErrBadWebSocketOrigin
91	}
92	ok := c.Origin.Scheme == o.Scheme && (c.Origin.Host == o.Host || c.Origin.Host == net.JoinHostPort(o.Host, port))
93	if !ok {
94		log.Println("bad websocket origin:", o)
95		return websocket.ErrBadWebSocketOrigin
96	}
97	log.Println("accepting connection from:", req.RemoteAddr)
98	return nil
101// socketHandler handles the websocket connection for a given present session.
102// It handles transcoding Messages to and from JSON format, and starting
103// and killing processes.
104func socketHandler(c *websocket.Conn) {
105	in, out := make(chan *Message), make(chan *Message)
106	errc := make(chan error, 1)
108	// Decode messages from client and send to the in channel.
109	go func() {
110		dec := json.NewDecoder(c)
111		for {
112			var m Message
113			if err := dec.Decode(&m); err != nil {
114				errc <- err
115				return
116			}
117			in <- &m
118		}
119	}()
121	// Receive messages from the out channel and encode to the client.
122	go func() {
123		enc := json.NewEncoder(c)
124		for m := range out {
125			if err := enc.Encode(m); err != nil {
126				errc <- err
127				return
128			}
129		}
130	}()
131	defer close(out)
133	// Start and kill processes and handle errors.
134	proc := make(map[string]*process)
135	for {
136		select {
137		case m := <-in:
138			switch m.Kind {
139			case "run":
140				log.Println("running snippet from:", c.Request().RemoteAddr)
141				proc[m.Id].Kill()
142				proc[m.Id] = startProcess(m.Id, m.Body, out, m.Options)
143			case "kill":
144				proc[m.Id].Kill()
145			}
146		case err := <-errc:
147			if err != io.EOF {
148				// A encode or decode has failed; bail.
149				log.Println(err)
150			}
151			// Shut down any running processes.
152			for _, p := range proc {
153				p.Kill()
154			}
155			return
156		}
157	}
160// process represents a running process.
161type process struct {
162	out  chan<- *Message
163	done chan struct{} // closed when wait completes
164	run  *exec.Cmd
165	path string
168// startProcess builds and runs the given program, sending its output
169// and end event as Messages on the provided channel.
170func startProcess(id, body string, dest chan<- *Message, opt *Options) *process {
171	var (
172		done = make(chan struct{})
173		out  = make(chan *Message)
174		p    = &process{out: out, done: done}
175	)
176	go func() {
177		defer close(done)
178		for m := range buffer(limiter(out, p), time.After) {
179			m.Id = id
180			dest <- m
181		}
182	}()
183	var err error
184	if path, args := shebang(body); path != "" {
185		if RunScripts {
186			err = p.startProcess(path, args, body)
187		} else {
188			err = errors.New("script execution is not allowed")
189		}
190	} else {
191		err = p.start(body, opt)
192	}
193	if err != nil {
194		p.end(err)
195		return nil
196	}
197	go func() {
198		p.end(p.run.Wait())
199	}()
200	return p
203// end sends an "end" message to the client, containing the process id and the
204// given error value. It also removes the binary, if present.
205func (p *process) end(err error) {
206	if p.path != "" {
207		defer os.RemoveAll(p.path)
208	}
209	m := &Message{Kind: "end"}
210	if err != nil {
211		m.Body = err.Error()
212	}
213	p.out <- m
214	close(p.out)
// killer is implemented by anything that can be terminated, such as a
// *process. Kill must not return until the target has actually exited.
type killer interface {
	Kill()
}
223// limiter returns a channel that wraps the given channel.
224// It receives Messages from the given channel and sends them to the returned
225// channel until it passes msgLimit messages, at which point it will kill the
226// process and pass only the "end" message.
227// When the given channel is closed, or when the "end" message is received,
228// it closes the returned channel.
229func limiter(in <-chan *Message, p killer) <-chan *Message {
230	out := make(chan *Message)
231	go func() {
232		defer close(out)
233		n := 0
234		for m := range in {
235			switch {
236			case n < msgLimit || m.Kind == "end":
237				out <- m
238				if m.Kind == "end" {
239					return
240				}
241			case n == msgLimit:
242				// Kill in a goroutine as Kill will not return
243				// until the process' output has been
244				// processed, and we're doing that in this loop.
245				go p.Kill()
246			default:
247				continue // don't increment
248			}
249			n++
250		}
251	}()
252	return out
255// buffer returns a channel that wraps the given channel. It receives messages
256// from the given channel and sends them to the returned channel.
257// Message bodies are gathered over the period msgDelay and coalesced into a
258// single Message before they are passed on. Messages of the same kind are
259// coalesced; when a message of a different kind is received, any buffered
260// messages are flushed. When the given channel is closed, buffer flushes the
261// remaining buffered messages and closes the returned channel.
262// The timeAfter func should be time.After. It exists for testing.
263func buffer(in <-chan *Message, timeAfter func(time.Duration) <-chan time.Time) <-chan *Message {
264	out := make(chan *Message)
265	go func() {
266		defer close(out)
267		var (
268			tc    <-chan time.Time
269			buf   []byte
270			kind  string
271			flush = func() {
272				if len(buf) == 0 {
273					return
274				}
275				out <- &Message{Kind: kind, Body: safeString(buf)}
276				buf = buf[:0] // recycle buffer
277				kind = ""
278			}
279		)
280		for {
281			select {
282			case m, ok := <-in:
283				if !ok {
284					flush()
285					return
286				}
287				if m.Kind == "end" {
288					flush()
289					out <- m
290					return
291				}
292				if kind != m.Kind {
293					flush()
294					kind = m.Kind
295					if tc == nil {
296						tc = timeAfter(msgDelay)
297					}
298				}
299				buf = append(buf, m.Body...)
300			case <-tc:
301				flush()
302				tc = nil
303			}
304		}
305	}()
306	return out
309// Kill stops the process if it is running and waits for it to exit.
310func (p *process) Kill() {
311	if p == nil || p.run == nil {
312		return
313	}
314	p.run.Process.Kill()
315	<-p.done // block until process exits
// shebang looks for a shebang ('#!') at the beginning of the passed string,
// ignoring surrounding white space. If found, it returns the interpreter path
// and the fields of the first line after "#!". args includes the command as
// args[0].
//
// A shebang line that names no interpreter (e.g. a bare "#!") is not treated
// as a script. Like a body without a shebang, it yields "" and nil.
func shebang(body string) (path string, args []string) {
	body = strings.TrimSpace(body)
	if !strings.HasPrefix(body, "#!") {
		return "", nil
	}
	if i := strings.Index(body, "\n"); i >= 0 {
		body = body[:i]
	}
	fs := strings.Fields(body[2:])
	if len(fs) == 0 {
		// Nothing after "#!": indexing fs[0] would panic.
		return "", nil
	}
	return fs[0], fs
}
333// startProcess starts a given program given its path and passing the given body
334// to the command standard input.
335func (p *process) startProcess(path string, args []string, body string) error {
336	cmd := &exec.Cmd{
337		Path:   path,
338		Args:   args,
339		Stdin:  strings.NewReader(body),
340		Stdout: &messageWriter{kind: "stdout", out: p.out},
341		Stderr: &messageWriter{kind: "stderr", out: p.out},
342	}
343	if err := cmd.Start(); err != nil {
344		return err
345	}
346	p.run = cmd
347	return nil
350// start builds and starts the given program, sending its output to p.out,
351// and stores the running *exec.Cmd in the run field.
352func (p *process) start(body string, opt *Options) error {
353	// We "go build" and then exec the binary so that the
354	// resultant *exec.Cmd is a handle to the user's program
355	// (rather than the go tool process).
356	// This makes Kill work.
358	path, err := ioutil.TempDir("", "present-")
359	if err != nil {
360		return err
361	}
362	defer os.RemoveAll(path)
364	out := "prog"
365	if runtime.GOOS == "windows" {
366		out = "prog.exe"
367	}
368	bin := filepath.Join(path, out)
370	// write body to x.go files
371	a := txtar.Parse([]byte(body))
372	if len(a.Comment) != 0 {
373		a.Files = append(a.Files, txtar.File{Name: "prog.go", Data: a.Comment})
374		a.Comment = nil
375	}
376	hasModfile := false
377	for _, f := range a.Files {
378		err = ioutil.WriteFile(filepath.Join(path, f.Name), f.Data, 0666)
379		if err != nil {
380			return err
381		}
382		if f.Name == "go.mod" {
383			hasModfile = true
384		}
385	}
387	// build x.go, creating x
388	p.path = path // to be removed by p.end
389	args := []string{"go", "build", "-tags", "OMIT"}
390	if opt != nil && opt.Race {
391		p.out <- &Message{
392			Kind: "stderr",
393			Body: "Running with race detector.\n",
394		}
395		args = append(args, "-race")
396	}
397	args = append(args, "-o", bin)
398	cmd := p.cmd(path, args...)
399	if !hasModfile {
400		cmd.Env = append(cmd.Env, "GO111MODULE=off")
401	}
402	cmd.Stdout = cmd.Stderr // send compiler output to stderr
403	if err := cmd.Run(); err != nil {
404		return err
405	}
407	// run x
408	if isNacl() {
409		cmd, err = p.naclCmd(bin)
410		if err != nil {
411			return err
412		}
413	} else {
414		cmd = p.cmd("", bin)
415	}
416	if opt != nil && opt.Race {
417		cmd.Env = append(cmd.Env, "GOMAXPROCS=2")
418	}
419	if err := cmd.Start(); err != nil {
420		// If we failed to exec, that might be because they built
421		// a non-main package instead of an executable.
422		// Check and report that.
423		if name, err := packageName(body); err == nil && name != "main" {
424			return errors.New(`executable programs must use "package main"`)
425		}
426		return err
427	}
428	p.run = cmd
429	return nil
432// cmd builds an *exec.Cmd that writes its standard output and error to the
433// process' output channel.
434func (p *process) cmd(dir string, args ...string) *exec.Cmd {
435	cmd := exec.Command(args[0], args[1:]...)
436	cmd.Dir = dir
437	cmd.Env = Environ()
438	cmd.Stdout = &messageWriter{kind: "stdout", out: p.out}
439	cmd.Stderr = &messageWriter{kind: "stderr", out: p.out}
440	return cmd
443func isNacl() bool {
444	for _, v := range append(Environ(), os.Environ()...) {
445		if v == "GOOS=nacl" {
446			return true
447		}
448	}
449	return false
452// naclCmd returns an *exec.Cmd that executes bin under native client.
453func (p *process) naclCmd(bin string) (*exec.Cmd, error) {
454	pwd, err := os.Getwd()
455	if err != nil {
456		return nil, err
457	}
458	var args []string
459	env := []string{
460		"NACLENV_GOOS=" + runtime.GOOS,
461		"NACLENV_GOROOT=/go",
462		"NACLENV_NACLPWD=" + strings.Replace(pwd, runtime.GOROOT(), "/go", 1),
463	}
464	switch runtime.GOARCH {
465	case "amd64":
466		env = append(env, "NACLENV_GOARCH=amd64p32")
467		args = []string{"sel_ldr_x86_64"}
468	case "386":
469		env = append(env, "NACLENV_GOARCH=386")
470		args = []string{"sel_ldr_x86_32"}
471	case "arm":
472		env = append(env, "NACLENV_GOARCH=arm")
473		selLdr, err := exec.LookPath("sel_ldr_arm")
474		if err != nil {
475			return nil, err
476		}
477		args = []string{"nacl_helper_bootstrap_arm", selLdr, "--reserved_at_zero=0xXXXXXXXXXXXXXXXX"}
478	default:
479		return nil, errors.New("native client does not support GOARCH=" + runtime.GOARCH)
480	}
482	cmd := p.cmd("", append(args, "-l", "/dev/null", "-S", "-e", bin)...)
483	cmd.Env = append(cmd.Env, env...)
485	return cmd, nil
// packageName parses only the package clause of body (as file "prog.go") and
// returns the declared package name, or the parse error.
func packageName(body string) (string, error) {
	fset := token.NewFileSet()
	src := strings.NewReader(body)
	file, err := parser.ParseFile(fset, "prog.go", src, parser.PackageClauseOnly)
	if err != nil {
		return "", err
	}
	return file.Name.Name, nil
}
497// messageWriter is an io.Writer that converts all writes to Message sends on
498// the out channel with the specified id and kind.
499type messageWriter struct {
500	kind string
501	out  chan<- *Message
504func (w *messageWriter) Write(b []byte) (n int, err error) {
505	w.out <- &Message{Kind: w.kind, Body: safeString(b)}
506	return len(b), nil
// safeString returns b as a valid UTF-8 string. Each byte that is not part
// of a valid encoding becomes U+FFFD; valid input is returned unchanged.
func safeString(b []byte) string {
	if utf8.Valid(b) {
		return string(b)
	}
	var out bytes.Buffer
	out.Grow(len(b))
	// Ranging over a string yields utf8.RuneError for each invalid byte and
	// advances by one, exactly as utf8.DecodeRune does.
	for _, r := range string(b) {
		out.WriteRune(r)
	}
	return out.String()
}