1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build linux
6// +build linux
7
8package unix_test
9
10import (
11	"bytes"
12	"net"
13	"os"
14	"testing"
15
16	"golang.org/x/sys/unix"
17)
18
19// TestSCMCredentials tests the sending and receiving of credentials
20// (PID, UID, GID) in an ancillary message between two UNIX
21// sockets. The SO_PASSCRED socket option is enabled on the sending
22// socket for this to work.
23func TestSCMCredentials(t *testing.T) {
24	socketTypeTests := []struct {
25		socketType int
26		dataLen    int
27	}{
28		{
29			unix.SOCK_STREAM,
30			1,
31		}, {
32			unix.SOCK_DGRAM,
33			0,
34		},
35	}
36
37	for _, tt := range socketTypeTests {
38		fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0)
39		if err != nil {
40			t.Fatalf("Socketpair: %v", err)
41		}
42
43		err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1)
44		if err != nil {
45			unix.Close(fds[0])
46			unix.Close(fds[1])
47			t.Fatalf("SetsockoptInt: %v", err)
48		}
49
50		srvFile := os.NewFile(uintptr(fds[0]), "server")
51		cliFile := os.NewFile(uintptr(fds[1]), "client")
52		defer srvFile.Close()
53		defer cliFile.Close()
54
55		srv, err := net.FileConn(srvFile)
56		if err != nil {
57			t.Errorf("FileConn: %v", err)
58			return
59		}
60		defer srv.Close()
61
62		cli, err := net.FileConn(cliFile)
63		if err != nil {
64			t.Errorf("FileConn: %v", err)
65			return
66		}
67		defer cli.Close()
68
69		var ucred unix.Ucred
70		ucred.Pid = int32(os.Getpid())
71		ucred.Uid = uint32(os.Getuid())
72		ucred.Gid = uint32(os.Getgid())
73		oob := unix.UnixCredentials(&ucred)
74
75		// On SOCK_STREAM, this is internally going to send a dummy byte
76		n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
77		if err != nil {
78			t.Fatalf("WriteMsgUnix: %v", err)
79		}
80		if n != 0 {
81			t.Fatalf("WriteMsgUnix n = %d, want 0", n)
82		}
83		if oobn != len(oob) {
84			t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
85		}
86
87		oob2 := make([]byte, 10*len(oob))
88		n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
89		if err != nil {
90			t.Fatalf("ReadMsgUnix: %v", err)
91		}
92		if flags != 0 && flags != unix.MSG_CMSG_CLOEXEC {
93			t.Fatalf("ReadMsgUnix flags = %#x, want 0 or %#x (MSG_CMSG_CLOEXEC)", flags, unix.MSG_CMSG_CLOEXEC)
94		}
95		if n != tt.dataLen {
96			t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
97		}
98		if oobn2 != oobn {
99			// without SO_PASSCRED set on the socket, ReadMsgUnix will
100			// return zero oob bytes
101			t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
102		}
103		oob2 = oob2[:oobn2]
104		if !bytes.Equal(oob, oob2) {
105			t.Fatal("ReadMsgUnix oob bytes don't match")
106		}
107
108		scm, err := unix.ParseSocketControlMessage(oob2)
109		if err != nil {
110			t.Fatalf("ParseSocketControlMessage: %v", err)
111		}
112		newUcred, err := unix.ParseUnixCredentials(&scm[0])
113		if err != nil {
114			t.Fatalf("ParseUnixCredentials: %v", err)
115		}
116		if *newUcred != ucred {
117			t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
118		}
119	}
120}
121