1// Copyright 2016 CoreOS, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// +build cgo
16
17package util
18
19// #include <stdlib.h>
20// #include <sys/types.h>
21// #include <unistd.h>
22//
23// int
24// my_sd_pid_get_owner_uid(void *f, pid_t pid, uid_t *uid)
25// {
26//   int (*sd_pid_get_owner_uid)(pid_t, uid_t *);
27//
28//   sd_pid_get_owner_uid = (int (*)(pid_t, uid_t *))f;
29//   return sd_pid_get_owner_uid(pid, uid);
30// }
31//
32// int
33// my_sd_pid_get_unit(void *f, pid_t pid, char **unit)
34// {
35//   int (*sd_pid_get_unit)(pid_t, char **);
36//
37//   sd_pid_get_unit = (int (*)(pid_t, char **))f;
38//   return sd_pid_get_unit(pid, unit);
39// }
40//
41// int
42// my_sd_pid_get_slice(void *f, pid_t pid, char **slice)
43// {
44//   int (*sd_pid_get_slice)(pid_t, char **);
45//
46//   sd_pid_get_slice = (int (*)(pid_t, char **))f;
47//   return sd_pid_get_slice(pid, slice);
48// }
49//
50// int
51// am_session_leader()
52// {
53//   return (getsid(0) == getpid());
54// }
55import "C"
56import (
57	"fmt"
58	"syscall"
59	"unsafe"
60
61	"github.com/coreos/pkg/dlopen"
62)
63
// libsystemdNames is the list of shared-object names handed to
// dlopen.GetHandle when loading the sd-login API. The standalone
// libsystemd-login library is listed first for systems running systemd
// older than 209. From systemd 209 on, that library was merged into
// libsystemd itself, which is listed next. (Presumably GetHandle tries these
// in order and keeps the first that loads; confirm against the dlopen package.)
var libsystemdNames = []string{
	"libsystemd-login.so.0", // systemd < 209
	"libsystemd-login.so",   // systemd < 209
	"libsystemd.so.0",       // systemd >= 209
	"libsystemd.so",         // systemd >= 209
}
73
74func getRunningSlice() (slice string, err error) {
75	var h *dlopen.LibHandle
76	h, err = dlopen.GetHandle(libsystemdNames)
77	if err != nil {
78		return
79	}
80	defer func() {
81		if err1 := h.Close(); err1 != nil {
82			err = err1
83		}
84	}()
85
86	sd_pid_get_slice, err := h.GetSymbolPointer("sd_pid_get_slice")
87	if err != nil {
88		return
89	}
90
91	var s string
92	sl := C.CString(s)
93	defer C.free(unsafe.Pointer(sl))
94
95	ret := C.my_sd_pid_get_slice(sd_pid_get_slice, 0, &sl)
96	if ret < 0 {
97		err = fmt.Errorf("error calling sd_pid_get_slice: %v", syscall.Errno(-ret))
98		return
99	}
100
101	return C.GoString(sl), nil
102}
103
104func runningFromSystemService() (ret bool, err error) {
105	var h *dlopen.LibHandle
106	h, err = dlopen.GetHandle(libsystemdNames)
107	if err != nil {
108		return
109	}
110	defer func() {
111		if err1 := h.Close(); err1 != nil {
112			err = err1
113		}
114	}()
115
116	sd_pid_get_owner_uid, err := h.GetSymbolPointer("sd_pid_get_owner_uid")
117	if err != nil {
118		return
119	}
120
121	var uid C.uid_t
122	errno := C.my_sd_pid_get_owner_uid(sd_pid_get_owner_uid, 0, &uid)
123	serrno := syscall.Errno(-errno)
124	// when we're running from a unit file, sd_pid_get_owner_uid returns
125	// ENOENT (systemd <220), ENXIO (systemd 220-223), or ENODATA
126	// (systemd >=234)
127	switch {
128	case errno >= 0:
129		ret = false
130	case serrno == syscall.ENOENT, serrno == syscall.ENXIO, serrno == syscall.ENODATA:
131		// Since the implementation of sessions in systemd relies on
132		// the `pam_systemd` module, using the sd_pid_get_owner_uid
133		// heuristic alone can result in false positives if that module
134		// (or PAM itself) is not present or properly configured on the
135		// system. As such, we also check if we're the session leader,
136		// which should be the case if we're invoked from a unit file,
137		// but not if e.g. we're invoked from the command line from a
138		// user's login session
139		ret = C.am_session_leader() == 1
140	default:
141		err = fmt.Errorf("error calling sd_pid_get_owner_uid: %v", syscall.Errno(-errno))
142	}
143	return
144}
145
146func currentUnitName() (unit string, err error) {
147	var h *dlopen.LibHandle
148	h, err = dlopen.GetHandle(libsystemdNames)
149	if err != nil {
150		return
151	}
152	defer func() {
153		if err1 := h.Close(); err1 != nil {
154			err = err1
155		}
156	}()
157
158	sd_pid_get_unit, err := h.GetSymbolPointer("sd_pid_get_unit")
159	if err != nil {
160		return
161	}
162
163	var s string
164	u := C.CString(s)
165	defer C.free(unsafe.Pointer(u))
166
167	ret := C.my_sd_pid_get_unit(sd_pid_get_unit, 0, &u)
168	if ret < 0 {
169		err = fmt.Errorf("error calling sd_pid_get_unit: %v", syscall.Errno(-ret))
170		return
171	}
172
173	unit = C.GoString(u)
174	return
175}
176