1/*
2   Copyright The containerd Authors.
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17package shim
18
19import (
20	"context"
21	"flag"
22	"fmt"
23	"io"
24	"os"
25	"runtime"
26	"runtime/debug"
27	"strings"
28	"time"
29
30	"github.com/containerd/containerd/events"
31	"github.com/containerd/containerd/log"
32	"github.com/containerd/containerd/namespaces"
33	shimapi "github.com/containerd/containerd/runtime/v2/task"
34	"github.com/containerd/containerd/version"
35	"github.com/containerd/ttrpc"
36	"github.com/gogo/protobuf/proto"
37	"github.com/pkg/errors"
38	"github.com/sirupsen/logrus"
39)
40
// Client for a shim server. It pairs a shim's task service with the context
// and signal channel that Serve uses; build one with NewShimClient.
type Client struct {
	// service is registered on the ttrpc server created in Serve.
	service shimapi.TaskService
	// context is passed to the ttrpc serve loop and to handleSignals in Serve.
	context context.Context
	// signals is the channel Serve hands to handleSignals.
	signals chan os.Signal
}
47
// Publisher for events. It is an events.Publisher that can also be closed;
// run closes the publisher it creates when it returns.
type Publisher interface {
	events.Publisher
	io.Closer
}

// Init func for the creation of a shim server.
//
// run calls it with the shim's context (namespaced, carrying Opts under
// OptsKey{}), the task ID from the -id flag, the event publisher, and the
// cancel function of that context.
type Init func(context.Context, string, Publisher, func()) (Shim, error)

// Shim server interface that shim implementations return from Init.
type Shim interface {
	shimapi.TaskService
	// Cleanup is called for the "delete" action; run proto-marshals the
	// returned response and writes it to stdout.
	Cleanup(ctx context.Context) (*shimapi.DeleteResponse, error)
	// StartShim is called for the "start" action with the -id, -publish-binary
	// and -address flag values plus the TTRPC_ADDRESS environment value; run
	// writes the returned string (the shim address) to stdout.
	StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error)
}
63
// OptsKey is the context key for the Opts value.
type OptsKey struct{}

// Opts are context options associated with the shim invocation. run stores
// them in the shim's context under OptsKey{}.
type Opts struct {
	// BundlePath is the value of the -bundle flag.
	BundlePath string
	// Debug is the value of the -debug flag.
	Debug      bool
}

// BinaryOpts allows the configuration of a shim binary's setup. Run applies
// each option, in order, to a zero Config before calling run.
type BinaryOpts func(*Config)

// Config of shim binary options provided by shim implementations
type Config struct {
	// NoSubreaper disables setting the shim as a child subreaper
	// (run skips its subreaper() call).
	NoSubreaper bool
	// NoReaper disables the shim binary from reaping any child process implicitly.
	// NOTE(review): not read in this section; presumably consumed by
	// setupSignals, which receives the whole Config — confirm.
	NoReaper bool
	// NoSetupLogger disables automatic configuration of logrus to use the shim FIFO
	// (run skips setLogger on the serve path).
	NoSetupLogger bool
}
85
// Command-line state, populated by parseFlags. action holds the first
// positional argument: run handles "delete" and "start", and serves the task
// API for anything else.
var (
	debugFlag            bool
	versionFlag          bool
	idFlag               string
	namespaceFlag        string
	socketFlag           string
	bundlePath           string
	addressFlag          string
	containerdBinaryFlag string
	action               string
)

const (
	// ttrpcAddressEnv names the environment variable run reads for the
	// containerd ttrpc address; the value is passed to NewPublisher and
	// to Shim.StartShim.
	ttrpcAddressEnv = "TTRPC_ADDRESS"
)
101
102func parseFlags() {
103	flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
104	flag.BoolVar(&versionFlag, "v", false, "show the shim version and exit")
105	flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
106	flag.StringVar(&idFlag, "id", "", "id of the task")
107	flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
108	flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir")
109
110	flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
111	flag.StringVar(&containerdBinaryFlag, "publish-binary", "containerd", "path to publish binary (used for publishing events)")
112
113	flag.Parse()
114	action = flag.Arg(0)
115}
116
// setRuntime tunes the Go runtime for the shim process: it sets the GC target
// percentage to 40, starts a goroutine that calls debug.FreeOSMemory every 30
// seconds for the rest of the process lifetime, and sets GOMAXPROCS to 2
// unless the GOMAXPROCS environment variable is non-empty.
func setRuntime() {
	debug.SetGCPercent(40)

	go func() {
		ticks := time.Tick(30 * time.Second)
		for {
			<-ticks
			debug.FreeOSMemory()
		}
	}()

	// An explicit GOMAXPROCS wins; otherwise default to 2 to reduce the
	// number of Go stacks present in the shim.
	if procs := os.Getenv("GOMAXPROCS"); len(procs) == 0 {
		runtime.GOMAXPROCS(2)
	}
}
130
131func setLogger(ctx context.Context, id string) error {
132	logrus.SetFormatter(&logrus.TextFormatter{
133		TimestampFormat: log.RFC3339NanoFixed,
134		FullTimestamp:   true,
135	})
136	if debugFlag {
137		logrus.SetLevel(logrus.DebugLevel)
138	}
139	f, err := openLog(ctx, id)
140	if err != nil {
141		return err
142	}
143	logrus.SetOutput(f)
144	return nil
145}
146
147// Run initializes and runs a shim server
148func Run(id string, initFunc Init, opts ...BinaryOpts) {
149	var config Config
150	for _, o := range opts {
151		o(&config)
152	}
153	if err := run(id, initFunc, config); err != nil {
154		fmt.Fprintf(os.Stderr, "%s: %s\n", id, err)
155		os.Exit(1)
156	}
157}
158
159func run(id string, initFunc Init, config Config) error {
160	parseFlags()
161	if versionFlag {
162		fmt.Printf("%s:\n", os.Args[0])
163		fmt.Println("  Version: ", version.Version)
164		fmt.Println("  Revision:", version.Revision)
165		fmt.Println("  Go version:", version.GoVersion)
166		fmt.Println("")
167		return nil
168	}
169
170	setRuntime()
171
172	signals, err := setupSignals(config)
173	if err != nil {
174		return err
175	}
176	if !config.NoSubreaper {
177		if err := subreaper(); err != nil {
178			return err
179		}
180	}
181
182	ttrpcAddress := os.Getenv(ttrpcAddressEnv)
183
184	publisher, err := NewPublisher(ttrpcAddress)
185	if err != nil {
186		return err
187	}
188
189	defer publisher.Close()
190
191	if namespaceFlag == "" {
192		return fmt.Errorf("shim namespace cannot be empty")
193	}
194	ctx := namespaces.WithNamespace(context.Background(), namespaceFlag)
195	ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag})
196	ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", id))
197	ctx, cancel := context.WithCancel(ctx)
198	service, err := initFunc(ctx, idFlag, publisher, cancel)
199	if err != nil {
200		return err
201	}
202	switch action {
203	case "delete":
204		logger := logrus.WithFields(logrus.Fields{
205			"pid":       os.Getpid(),
206			"namespace": namespaceFlag,
207		})
208		go handleSignals(ctx, logger, signals)
209		response, err := service.Cleanup(ctx)
210		if err != nil {
211			return err
212		}
213		data, err := proto.Marshal(response)
214		if err != nil {
215			return err
216		}
217		if _, err := os.Stdout.Write(data); err != nil {
218			return err
219		}
220		return nil
221	case "start":
222		address, err := service.StartShim(ctx, idFlag, containerdBinaryFlag, addressFlag, ttrpcAddress)
223		if err != nil {
224			return err
225		}
226		if _, err := os.Stdout.WriteString(address); err != nil {
227			return err
228		}
229		return nil
230	default:
231		if !config.NoSetupLogger {
232			if err := setLogger(ctx, idFlag); err != nil {
233				return err
234			}
235		}
236		client := NewShimClient(ctx, service, signals)
237		if err := client.Serve(); err != nil {
238			if err != context.Canceled {
239				return err
240			}
241		}
242		select {
243		case <-publisher.Done():
244			return nil
245		case <-time.After(5 * time.Second):
246			return errors.New("publisher not closed")
247		}
248	}
249}
250
251// NewShimClient creates a new shim server client
252func NewShimClient(ctx context.Context, svc shimapi.TaskService, signals chan os.Signal) *Client {
253	s := &Client{
254		service: svc,
255		context: ctx,
256		signals: signals,
257	}
258	return s
259}
260
261// Serve the shim server
262func (s *Client) Serve() error {
263	dump := make(chan os.Signal, 32)
264	setupDumpStacks(dump)
265
266	path, err := os.Getwd()
267	if err != nil {
268		return err
269	}
270	server, err := newServer()
271	if err != nil {
272		return errors.Wrap(err, "failed creating server")
273	}
274
275	logrus.Debug("registering ttrpc server")
276	shimapi.RegisterTaskService(server, s.service)
277
278	if err := serve(s.context, server, socketFlag); err != nil {
279		return err
280	}
281	logger := logrus.WithFields(logrus.Fields{
282		"pid":       os.Getpid(),
283		"path":      path,
284		"namespace": namespaceFlag,
285	})
286	go func() {
287		for range dump {
288			dumpStacks(logger)
289		}
290	}()
291	return handleSignals(s.context, logger, s.signals)
292}
293
294// serve serves the ttrpc API over a unix socket at the provided path
295// this function does not block
296func serve(ctx context.Context, server *ttrpc.Server, path string) error {
297	l, err := serveListener(path)
298	if err != nil {
299		return err
300	}
301	go func() {
302		if err := server.Serve(ctx, l); err != nil &&
303			!strings.Contains(err.Error(), "use of closed network connection") {
304			logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure")
305		}
306		l.Close()
307		if address, err := ReadAddress("address"); err == nil {
308			_ = RemoveSocket(address)
309		}
310
311	}()
312	return nil
313}
314
315func dumpStacks(logger *logrus.Entry) {
316	var (
317		buf       []byte
318		stackSize int
319	)
320	bufferLen := 16384
321	for stackSize == len(buf) {
322		buf = make([]byte, bufferLen)
323		stackSize = runtime.Stack(buf, true)
324		bufferLen *= 2
325	}
326	buf = buf[:stackSize]
327	logger.Infof("=== BEGIN goroutine stack dump ===\n%s\n=== END goroutine stack dump ===", buf)
328}
329