1/*
2   Copyright The containerd Authors.
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17package shim
18
19import (
20	"context"
21	"flag"
22	"fmt"
23	"io"
24	"os"
25	"runtime"
26	"runtime/debug"
27	"strings"
28	"time"
29
30	"github.com/containerd/containerd/events"
31	"github.com/containerd/containerd/log"
32	"github.com/containerd/containerd/namespaces"
33	shimapi "github.com/containerd/containerd/runtime/v2/task"
34	"github.com/containerd/containerd/version"
35	"github.com/containerd/ttrpc"
36	"github.com/gogo/protobuf/proto"
37	"github.com/pkg/errors"
38	"github.com/sirupsen/logrus"
39)
40
// Client bundles a shim task service with the context and signal channel
// that Serve uses to expose it over ttrpc.
type Client struct {
	// service is the task API implementation registered with the ttrpc server.
	service shimapi.TaskService
	// context is handed to the ttrpc server and to the signal handling loop.
	context context.Context
	// signals is the channel that Serve passes on to handleSignals.
	signals chan os.Signal
}
47
// Publisher is an events.Publisher that can also be closed.
// It is the publisher type handed to a shim's Init func.
type Publisher interface {
	events.Publisher
	io.Closer
}
53
// StartOpts describes shim start configuration received from containerd.
// It is built for the "start" action and passed to Shim.StartShim.
type StartOpts struct {
	ID               string // value of the -id flag
	ContainerdBinary string // value of the -publish-binary flag
	Address          string // value of the -address flag
	TTRPCAddress     string // value of the TTRPC_ADDRESS environment variable
}
61
// Init is the constructor a shim implementation supplies to Run. It is called
// with the shim's namespaced context, the value of the -id flag, the event
// publisher and the cancel func of that context, and returns the Shim to use.
type Init func(context.Context, string, Publisher, func()) (Shim, error)
64
// Shim is the server interface a shim implementation must provide.
type Shim interface {
	// TaskService is the task API that gets registered with the ttrpc server
	// when the shim is served.
	shimapi.TaskService
	// Cleanup is invoked for the "delete" action; its response is marshaled
	// and written to stdout.
	Cleanup(ctx context.Context) (*shimapi.DeleteResponse, error)
	// StartShim is invoked for the "start" action; the returned string is
	// written to stdout.
	StartShim(ctx context.Context, opts StartOpts) (string, error)
}
71
// OptsKey is the context key for the Opts value.
// An Opts value is stored under OptsKey{} in the context given to Init.
type OptsKey struct{}
74
// Opts are context options associated with the shim invocation.
// They are stored in the shim's context under OptsKey{}.
type Opts struct {
	BundlePath string // value of the -bundle flag
	Debug      bool   // value of the -debug flag
}
80
// BinaryOpts allows the configuration of a shim's binary setup.
// Each option passed to Run is applied, in order, to a zero-valued Config.
type BinaryOpts func(*Config)
83
// Config holds the shim binary options that shim implementations set through
// BinaryOpts. Run starts from the zero value, so every feature is enabled
// unless an option turns it off.
type Config struct {
	// NoSubreaper disables setting the shim as a child subreaper
	NoSubreaper bool
	// NoReaper disables the shim binary from reaping any child process implicitly
	// NOTE(review): not read in this file; presumably consumed by setupSignals — confirm.
	NoReaper bool
	// NoSetupLogger disables automatic configuration of logrus to use the shim FIFO
	NoSetupLogger bool
}
93
// Command line state populated by parseFlags.
var (
	debugFlag            bool   // -debug: enable debug output in logs
	versionFlag          bool   // -v: show the shim version and exit
	idFlag               string // -id: id of the task
	namespaceFlag        string // -namespace: namespace that owns the shim
	socketFlag           string // -socket: socket path to serve
	bundlePath           string // -bundle: path to the bundle if not workdir
	addressFlag          string // -address: grpc address back to main containerd
	containerdBinaryFlag string // -publish-binary: path to publish binary
	action               string // first positional argument after the flags
)
105
const (
	// ttrpcAddressEnv names the environment variable that run reads to obtain
	// the address it passes to NewPublisher and to StartOpts.TTRPCAddress.
	ttrpcAddressEnv = "TTRPC_ADDRESS"
)
109
110func parseFlags() {
111	flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
112	flag.BoolVar(&versionFlag, "v", false, "show the shim version and exit")
113	flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
114	flag.StringVar(&idFlag, "id", "", "id of the task")
115	flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
116	flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir")
117
118	flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
119	flag.StringVar(&containerdBinaryFlag, "publish-binary", "containerd", "path to publish binary (used for publishing events)")
120
121	flag.Parse()
122	action = flag.Arg(0)
123}
124
// setRuntime tunes the Go runtime for the shim process: it sets the GC
// target percentage to 40, starts a background goroutine that returns
// memory to the OS every 30 seconds, and caps GOMAXPROCS at 2 unless the
// GOMAXPROCS environment variable is set to a non-empty value.
func setRuntime() {
	debug.SetGCPercent(40)

	// This goroutine runs for the remainder of the process.
	go func() {
		ticks := time.Tick(30 * time.Second)
		for {
			<-ticks
			debug.FreeOSMemory()
		}
	}()

	if os.Getenv("GOMAXPROCS") != "" {
		// An explicit setting wins over our default.
		return
	}
	// If GOMAXPROCS hasn't been set, we default to a value of 2 to reduce
	// the number of Go stacks present in the shim.
	runtime.GOMAXPROCS(2)
}
138
139func setLogger(ctx context.Context, id string) error {
140	logrus.SetFormatter(&logrus.TextFormatter{
141		TimestampFormat: log.RFC3339NanoFixed,
142		FullTimestamp:   true,
143	})
144	if debugFlag {
145		logrus.SetLevel(logrus.DebugLevel)
146	}
147	f, err := openLog(ctx, id)
148	if err != nil {
149		return err
150	}
151	logrus.SetOutput(f)
152	return nil
153}
154
155// Run initializes and runs a shim server
156func Run(id string, initFunc Init, opts ...BinaryOpts) {
157	var config Config
158	for _, o := range opts {
159		o(&config)
160	}
161	if err := run(id, initFunc, config); err != nil {
162		fmt.Fprintf(os.Stderr, "%s: %s\n", id, err)
163		os.Exit(1)
164	}
165}
166
// run is the body of Run: it parses the command line, prepares the process
// (runtime tuning, signals, subreaper), builds the shim via initFunc and then
// performs the action named by the first positional argument:
//
//   - "delete": call Cleanup and write the marshaled response to stdout.
//   - "start":  call StartShim and write the returned address to stdout.
//   - anything else: serve the task API until the signal loop returns.
//
// id is only used as the "runtime" log field here; the task id handed to
// initFunc comes from the -id flag. A nil return means success.
func run(id string, initFunc Init, config Config) error {
	parseFlags()
	// -v short-circuits everything else: print build information and stop.
	if versionFlag {
		fmt.Printf("%s:\n", os.Args[0])
		fmt.Println("  Version: ", version.Version)
		fmt.Println("  Revision:", version.Revision)
		fmt.Println("  Go version:", version.GoVersion)
		fmt.Println("")
		return nil
	}

	if namespaceFlag == "" {
		return fmt.Errorf("shim namespace cannot be empty")
	}

	setRuntime()

	// Signal handling is set up before the subreaper call and before the
	// shim implementation is constructed.
	signals, err := setupSignals(config)
	if err != nil {
		return err
	}

	if !config.NoSubreaper {
		if err := subreaper(); err != nil {
			return err
		}
	}

	// The publisher address comes from the environment, not from a flag.
	ttrpcAddress := os.Getenv(ttrpcAddressEnv)
	publisher, err := NewPublisher(ttrpcAddress)
	if err != nil {
		return err
	}
	defer publisher.Close()

	// Build the context given to the shim: namespace, invocation options,
	// a logger tagged with the runtime id, and cancellation. The cancel func
	// is handed to initFunc so the shim implementation can trigger it.
	ctx := namespaces.WithNamespace(context.Background(), namespaceFlag)
	ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag})
	ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", id))
	ctx, cancel := context.WithCancel(ctx)
	service, err := initFunc(ctx, idFlag, publisher, cancel)
	if err != nil {
		return err
	}

	switch action {
	case "delete":
		logger := logrus.WithFields(logrus.Fields{
			"pid":       os.Getpid(),
			"namespace": namespaceFlag,
		})
		// Signals are handled in the background while cleanup runs.
		go handleSignals(ctx, logger, signals)
		response, err := service.Cleanup(ctx)
		if err != nil {
			return err
		}
		// The response is written to stdout in protobuf wire format.
		data, err := proto.Marshal(response)
		if err != nil {
			return err
		}
		if _, err := os.Stdout.Write(data); err != nil {
			return err
		}
		return nil
	case "start":
		opts := StartOpts{
			ID:               idFlag,
			ContainerdBinary: containerdBinaryFlag,
			Address:          addressFlag,
			TTRPCAddress:     ttrpcAddress,
		}
		address, err := service.StartShim(ctx, opts)
		if err != nil {
			return err
		}
		// The address is written to stdout without a trailing newline.
		if _, err := os.Stdout.WriteString(address); err != nil {
			return err
		}
		return nil
	default:
		if !config.NoSetupLogger {
			if err := setLogger(ctx, idFlag); err != nil {
				return err
			}
		}
		client := NewShimClient(ctx, service, signals)
		// context.Canceled from Serve is not treated as a failure; any other
		// error aborts immediately.
		if err := client.Serve(); err != nil {
			if err != context.Canceled {
				return err
			}
		}

		// NOTE: If the shim server is down(like oom killer), the address
		// socket might be leaking.
		// Best effort: failures to read the address or remove the socket are
		// ignored.
		if address, err := ReadAddress("address"); err == nil {
			_ = RemoveSocket(address)
		}

		// Wait up to five seconds for publisher.Done() before giving up.
		// NOTE(review): presumably Done is signalled once the publisher has
		// been closed elsewhere — confirm against the publisher implementation.
		select {
		case <-publisher.Done():
			return nil
		case <-time.After(5 * time.Second):
			return errors.New("publisher not closed")
		}
	}
}
272
273// NewShimClient creates a new shim server client
274func NewShimClient(ctx context.Context, svc shimapi.TaskService, signals chan os.Signal) *Client {
275	s := &Client{
276		service: svc,
277		context: ctx,
278		signals: signals,
279	}
280	return s
281}
282
283// Serve the shim server
284func (s *Client) Serve() error {
285	dump := make(chan os.Signal, 32)
286	setupDumpStacks(dump)
287
288	path, err := os.Getwd()
289	if err != nil {
290		return err
291	}
292	server, err := newServer()
293	if err != nil {
294		return errors.Wrap(err, "failed creating server")
295	}
296
297	logrus.Debug("registering ttrpc server")
298	shimapi.RegisterTaskService(server, s.service)
299
300	if err := serve(s.context, server, socketFlag); err != nil {
301		return err
302	}
303	logger := logrus.WithFields(logrus.Fields{
304		"pid":       os.Getpid(),
305		"path":      path,
306		"namespace": namespaceFlag,
307	})
308	go func() {
309		for range dump {
310			dumpStacks(logger)
311		}
312	}()
313	return handleSignals(s.context, logger, s.signals)
314}
315
316// serve serves the ttrpc API over a unix socket at the provided path
317// this function does not block
318func serve(ctx context.Context, server *ttrpc.Server, path string) error {
319	l, err := serveListener(path)
320	if err != nil {
321		return err
322	}
323	go func() {
324		defer l.Close()
325		if err := server.Serve(ctx, l); err != nil &&
326			!strings.Contains(err.Error(), "use of closed network connection") {
327			logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure")
328		}
329	}()
330	return nil
331}
332
333func dumpStacks(logger *logrus.Entry) {
334	var (
335		buf       []byte
336		stackSize int
337	)
338	bufferLen := 16384
339	for stackSize == len(buf) {
340		buf = make([]byte, bufferLen)
341		stackSize = runtime.Stack(buf, true)
342		bufferLen *= 2
343	}
344	buf = buf[:stackSize]
345	logger.Infof("=== BEGIN goroutine stack dump ===\n%s\n=== END goroutine stack dump ===", buf)
346}
347