1// Copyright 2018 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// +build linux
16
17package server
18
19import (
20	"fmt"
21	"os"
22	"runtime"
23	"syscall"
24	"time"
25)
26
// ptraceRun executes closures on a single dedicated OS thread. It pins the
// calling goroutine to its current OS thread and then, for every closure
// received on fc, runs it and hands the returned error back on ec. It returns
// when fc is closed.
//
// Both channels must be unbuffered. A send on fc completes only when
// ptraceRun takes the closure, and ptraceRun delivers that closure's result
// on ec before it accepts another one. Together these guarantee that the
// result reaches the goroutine that submitted the closure. With buffering,
// a different submitter could receive it instead, so ptraceRun panics if
// either channel has a buffer.
//
// UnlockOSThread is never called, so every closure runs on the same thread
// (ptrace requests must come from the thread that is tracing the process).
func ptraceRun(fc chan func() error, ec chan error) {
	buffered := cap(fc) > 0 || cap(ec) > 0
	if buffered {
		panic("ptraceRun was given buffered channels")
	}
	runtime.LockOSThread()
	for {
		job, open := <-fc
		if !open {
			return
		}
		ec <- job()
	}
}
39
40func (s *Server) startProcess(name string, argv []string, attr *os.ProcAttr) (proc *os.Process, err error) {
41	s.fc <- func() error {
42		var err1 error
43		proc, err1 = os.StartProcess(name, argv, attr)
44		return err1
45	}
46	err = <-s.ec
47	return
48}
49
50func (s *Server) ptraceCont(pid int, signal int) (err error) {
51	s.fc <- func() error {
52		return syscall.PtraceCont(pid, signal)
53	}
54	return <-s.ec
55}
56
57func (s *Server) ptraceGetRegs(pid int, regsout *syscall.PtraceRegs) (err error) {
58	s.fc <- func() error {
59		return syscall.PtraceGetRegs(pid, regsout)
60	}
61	return <-s.ec
62}
63
64func (s *Server) ptracePeek(pid int, addr uintptr, out []byte) (err error) {
65	s.fc <- func() error {
66		n, err := syscall.PtracePeekText(pid, addr, out)
67		if err != nil {
68			return err
69		}
70		if n != len(out) {
71			return fmt.Errorf("ptracePeek: peeked %d bytes, want %d", n, len(out))
72		}
73		return nil
74	}
75	return <-s.ec
76}
77
78func (s *Server) ptracePoke(pid int, addr uintptr, data []byte) (err error) {
79	s.fc <- func() error {
80		n, err := syscall.PtracePokeText(pid, addr, data)
81		if err != nil {
82			return err
83		}
84		if n != len(data) {
85			return fmt.Errorf("ptracePoke: poked %d bytes, want %d", n, len(data))
86		}
87		return nil
88	}
89	return <-s.ec
90}
91
92func (s *Server) ptraceSetOptions(pid int, options int) (err error) {
93	s.fc <- func() error {
94		return syscall.PtraceSetOptions(pid, options)
95	}
96	return <-s.ec
97}
98
99func (s *Server) ptraceSetRegs(pid int, regs *syscall.PtraceRegs) (err error) {
100	s.fc <- func() error {
101		return syscall.PtraceSetRegs(pid, regs)
102	}
103	return <-s.ec
104}
105
106func (s *Server) ptraceSingleStep(pid int) (err error) {
107	s.fc <- func() error {
108		return syscall.PtraceSingleStep(pid)
109	}
110	return <-s.ec
111}
112
113type breakpointsChangedError struct {
114	call call
115}
116
117func (*breakpointsChangedError) Error() string {
118	return "breakpoints changed"
119}
120
121func (s *Server) wait(pid int, allowBreakpointsChange bool) (wpid int, status syscall.WaitStatus, err error) {
122	// We poll syscall.Wait4 with WNOHANG, sleeping in between, as a poor man's
123	// waitpid-with-timeout. This allows adding and removing breakpoints
124	// concurrently with waiting to hit an existing breakpoint.
125	f := func() error {
126		var err1 error
127		wpid, err1 = syscall.Wait4(pid, &status, syscall.WALL|syscall.WNOHANG, nil)
128		return err1
129	}
130
131	const (
132		minSleep = 1 * time.Microsecond
133		maxSleep = 100 * time.Millisecond
134	)
135	for sleep := minSleep; ; {
136		s.fc <- f
137		err = <-s.ec
138
139		// wpid == 0 means that wait found nothing (and returned due to WNOHANG).
140		if wpid != 0 {
141			return
142		}
143
144		if allowBreakpointsChange {
145			select {
146			case c := <-s.breakpointc:
147				return 0, 0, &breakpointsChangedError{c}
148			default:
149			}
150		}
151
152		time.Sleep(sleep)
153		if sleep < maxSleep {
154			sleep *= 10
155		}
156	}
157}
158