1package cgroups
2
3import (
4	"bytes"
5	"os"
6	"strings"
7	"sync"
8
9	"github.com/pkg/errors"
10	"github.com/sirupsen/logrus"
11	"golang.org/x/sys/unix"
12)
13
14// OpenFile opens a cgroup file in a given dir with given flags.
15// It is supposed to be used for cgroup files only.
16func OpenFile(dir, file string, flags int) (*os.File, error) {
17	if dir == "" {
18		return nil, errors.Errorf("no directory specified for %s", file)
19	}
20	return openFile(dir, file, flags)
21}
22
23// ReadFile reads data from a cgroup file in dir.
24// It is supposed to be used for cgroup files only.
25func ReadFile(dir, file string) (string, error) {
26	fd, err := OpenFile(dir, file, unix.O_RDONLY)
27	if err != nil {
28		return "", err
29	}
30	defer fd.Close()
31	var buf bytes.Buffer
32
33	_, err = buf.ReadFrom(fd)
34	return buf.String(), err
35}
36
37// WriteFile writes data to a cgroup file in dir.
38// It is supposed to be used for cgroup files only.
39func WriteFile(dir, file, data string) error {
40	fd, err := OpenFile(dir, file, unix.O_WRONLY)
41	if err != nil {
42		return err
43	}
44	defer fd.Close()
45	if err := retryingWriteFile(fd, data); err != nil {
46		return errors.Wrapf(err, "failed to write %q", data)
47	}
48	return nil
49}
50
51func retryingWriteFile(fd *os.File, data string) error {
52	for {
53		_, err := fd.Write([]byte(data))
54		if errors.Is(err, unix.EINTR) {
55			logrus.Infof("interrupted while writing %s to %s", data, fd.Name())
56			continue
57		}
58		return err
59	}
60}
61
// Paths used by the openat2-based open path in this file.
//
// cgroupfsDir is the directory prepareOpenat2 opens and keeps a descriptor to.
// cgroupfsPrefix is cgroupfsDir with a trailing slash; openFile strips it from
// the requested directory to obtain a path relative to that descriptor.
const (
	cgroupfsDir    = "/sys/fs/cgroup"
	cgroupfsPrefix = cgroupfsDir + "/"
)
66
var (
	// TestMode is set to true by unit tests that need "fake" cgroupfs.
	TestMode bool

	// State for the openat2-based open path, filled in by prepareOpenat2:
	//   - cgroupFd is the descriptor for cgroupfsDir (-1 until it is set);
	//   - prepOnce ensures the setup in prepareOpenat2 runs only once;
	//   - prepErr is the error recorded by that setup, nil on success;
	//   - resolveFlags holds the RESOLVE_* flags openFile passes to openat2.
	cgroupFd     int = -1
	prepOnce     sync.Once
	prepErr      error
	resolveFlags uint64
)
76
77func prepareOpenat2() error {
78	prepOnce.Do(func() {
79		fd, err := unix.Openat2(-1, cgroupfsDir, &unix.OpenHow{
80			Flags: unix.O_DIRECTORY | unix.O_PATH,
81		})
82		if err != nil {
83			prepErr = &os.PathError{Op: "openat2", Path: cgroupfsDir, Err: err}
84			if err != unix.ENOSYS {
85				logrus.Warnf("falling back to securejoin: %s", prepErr)
86			} else {
87				logrus.Debug("openat2 not available, falling back to securejoin")
88			}
89			return
90		}
91		var st unix.Statfs_t
92		if err = unix.Fstatfs(fd, &st); err != nil {
93			prepErr = &os.PathError{Op: "statfs", Path: cgroupfsDir, Err: err}
94			logrus.Warnf("falling back to securejoin: %s", prepErr)
95			return
96		}
97
98		cgroupFd = fd
99
100		resolveFlags = unix.RESOLVE_BENEATH | unix.RESOLVE_NO_MAGICLINKS
101		if st.Type == unix.CGROUP2_SUPER_MAGIC {
102			// cgroupv2 has a single mountpoint and no "cpu,cpuacct" symlinks
103			resolveFlags |= unix.RESOLVE_NO_XDEV | unix.RESOLVE_NO_SYMLINKS
104		}
105	})
106
107	return prepErr
108}
109
110// OpenFile opens a cgroup file in a given dir with given flags.
111// It is supposed to be used for cgroup files only.
112func openFile(dir, file string, flags int) (*os.File, error) {
113	mode := os.FileMode(0)
114	if TestMode && flags&os.O_WRONLY != 0 {
115		// "emulate" cgroup fs for unit tests
116		flags |= os.O_TRUNC | os.O_CREATE
117		mode = 0o600
118	}
119	if prepareOpenat2() != nil {
120		return openFallback(dir, file, flags, mode)
121	}
122	reldir := strings.TrimPrefix(dir, cgroupfsPrefix)
123	if len(reldir) == len(dir) { // non-standard path, old system?
124		return openFallback(dir, file, flags, mode)
125	}
126
127	relname := reldir + "/" + file
128	fd, err := unix.Openat2(cgroupFd, relname,
129		&unix.OpenHow{
130			Resolve: resolveFlags,
131			Flags:   uint64(flags) | unix.O_CLOEXEC,
132			Mode:    uint64(mode),
133		})
134	if err != nil {
135		return nil, &os.PathError{Op: "openat2", Path: dir + "/" + file, Err: err}
136	}
137
138	return os.NewFile(uintptr(fd), cgroupfsPrefix+relname), nil
139}
140
// errNotCgroupfs is the underlying error reported by openFallback when the
// opened file is not on a cgroup (v1 or v2) filesystem.
var errNotCgroupfs = errors.New("not a cgroup file")
142
143// openFallback is used when openat2(2) is not available. It checks the opened
144// file is on cgroupfs, returning an error otherwise.
145func openFallback(dir, file string, flags int, mode os.FileMode) (*os.File, error) {
146	path := dir + "/" + file
147	fd, err := os.OpenFile(path, flags, mode)
148	if err != nil {
149		return nil, err
150	}
151	if TestMode {
152		return fd, nil
153	}
154	// Check this is a cgroupfs file.
155	var st unix.Statfs_t
156	if err := unix.Fstatfs(int(fd.Fd()), &st); err != nil {
157		_ = fd.Close()
158		return nil, &os.PathError{Op: "statfs", Path: path, Err: err}
159	}
160	if st.Type != unix.CGROUP_SUPER_MAGIC && st.Type != unix.CGROUP2_SUPER_MAGIC {
161		_ = fd.Close()
162		return nil, &os.PathError{Op: "open", Path: path, Err: errNotCgroupfs}
163	}
164
165	return fd, nil
166}
167