1// +build linux
2
3/*
4 *
5 * Copyright 2020 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     https://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package test
22
23import (
24	"context"
25	"fmt"
26	"net"
27	"os"
28	"strings"
29	"sync"
30	"testing"
31	"time"
32
33	"google.golang.org/grpc"
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/internal/stubserver"
36	"google.golang.org/grpc/metadata"
37	"google.golang.org/grpc/status"
38	testpb "google.golang.org/grpc/test/grpc_testing"
39)
40
41func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) {
42	md, ok := metadata.FromIncomingContext(ctx)
43	if !ok {
44		return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
45	}
46	auths, ok := md[":authority"]
47	if !ok {
48		return nil, status.Error(codes.InvalidArgument, "no authority header")
49	}
50	if len(auths) != 1 {
51		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
52	}
53	if auths[0] != expectedAuthority {
54		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
55	}
56	return &testpb.Empty{}, nil
57}
58
59func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) {
60	if !strings.HasPrefix(target, "unix-abstract:") {
61		if err := os.RemoveAll(address); err != nil {
62			t.Fatalf("Error removing socket file %v: %v\n", address, err)
63		}
64	}
65	ss := &stubserver.StubServer{
66		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
67			return authorityChecker(ctx, expectedAuthority)
68		},
69		Network: "unix",
70		Address: address,
71		Target:  target,
72	}
73	opts := []grpc.DialOption{}
74	if dialer != nil {
75		opts = append(opts, grpc.WithContextDialer(dialer))
76	}
77	if err := ss.Start(nil, opts...); err != nil {
78		t.Fatalf("Error starting endpoint server: %v", err)
79	}
80	defer ss.Stop()
81	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
82	defer cancel()
83	_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
84	if err != nil {
85		t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
86	}
87}
88
// authorityTest describes one unix target format exercised by TestUnix and
// TestUnixCustomDialer.
type authorityTest struct {
	// name is the subtest name.
	name string
	// address is the socket address the stub server listens on.
	address string
	// target is the dial target handed to the client.
	target string
	// authority is the :authority value the server expects to observe.
	authority string
	// dialTargetWant is the address the custom dialer expects to receive;
	// TestUnixCustomDialer falls back to target when it is empty.
	dialTargetWant string
}
96
// authorityTests enumerates the unix target formats exercised by TestUnix and
// TestUnixCustomDialer. Each entry pairs the address the server listens on
// with the target the client dials and the :authority the server expects.
// An empty dialTargetWant means TestUnixCustomDialer expects its dialer to
// receive target unchanged.
var authorityTests = []authorityTest{
	// Relative socket path with the "unix:" scheme.
	{
		name:      "UnixRelative",
		address:   "sock.sock",
		target:    "unix:sock.sock",
		authority: "localhost",
	},
	// Absolute socket path with the single-colon "unix:" form.
	{
		name:      "UnixAbsolute",
		address:   "/tmp/sock.sock",
		target:    "unix:/tmp/sock.sock",
		authority: "localhost",
	},
	// Absolute socket path with the "unix://" URI form.
	{
		name:      "UnixAbsoluteAlternate",
		address:   "/tmp/sock.sock",
		target:    "unix:///tmp/sock.sock",
		authority: "localhost",
	},
	// Passthrough scheme: the expected authority and dialer address are the
	// endpoint string "unix:///tmp/sock.sock" rather than "localhost".
	{
		name:           "UnixPassthrough",
		address:        "/tmp/sock.sock",
		target:         "passthrough:///unix:///tmp/sock.sock",
		authority:      "unix:///tmp/sock.sock",
		dialTargetWant: "unix:///tmp/sock.sock",
	},
	// Abstract socket: the listen address and the address the custom dialer
	// expects both carry a leading NUL byte.
	{
		name:           "UnixAbstract",
		address:        "\x00abc efg",
		target:         "unix-abstract:abc efg",
		authority:      "localhost",
		dialTargetWant: "\x00abc efg",
	},
}
131
132// TestUnix does end to end tests with the various supported unix target
133// formats, ensuring that the authority is set as expected.
134func (s) TestUnix(t *testing.T) {
135	for _, test := range authorityTests {
136		t.Run(test.name, func(t *testing.T) {
137			runUnixTest(t, test.address, test.target, test.authority, nil)
138		})
139	}
140}
141
142// TestUnixCustomDialer does end to end tests with various supported unix target
143// formats, ensuring that the target sent to the dialer does NOT have the
144// "unix:" prefix stripped.
145func (s) TestUnixCustomDialer(t *testing.T) {
146	for _, test := range authorityTests {
147		t.Run(test.name+"WithDialer", func(t *testing.T) {
148			if test.dialTargetWant == "" {
149				test.dialTargetWant = test.target
150			}
151			dialer := func(ctx context.Context, address string) (net.Conn, error) {
152				if address != test.dialTargetWant {
153					return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address)
154				}
155				if !strings.HasPrefix(test.target, "unix-abstract:") {
156					address = address[len("unix:"):]
157				}
158				return (&net.Dialer{}).DialContext(ctx, "unix", address)
159			}
160			runUnixTest(t, test.address, test.target, test.authority, dialer)
161		})
162	}
163}
164
165// TestColonPortAuthority does an end to end test with the target for grpc.Dial
166// being ":[port]". Ensures authority is "localhost:[port]".
167func (s) TestColonPortAuthority(t *testing.T) {
168	expectedAuthority := ""
169	var authorityMu sync.Mutex
170	ss := &stubserver.StubServer{
171		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
172			authorityMu.Lock()
173			defer authorityMu.Unlock()
174			return authorityChecker(ctx, expectedAuthority)
175		},
176		Network: "tcp",
177	}
178	if err := ss.Start(nil); err != nil {
179		t.Fatalf("Error starting endpoint server: %v", err)
180	}
181	defer ss.Stop()
182	_, port, err := net.SplitHostPort(ss.Address)
183	if err != nil {
184		t.Fatalf("Failed splitting host from post: %v", err)
185	}
186	authorityMu.Lock()
187	expectedAuthority = "localhost:" + port
188	authorityMu.Unlock()
189	// ss.Start dials, but not the ":[port]" target that is being tested here.
190	// Dial again, with ":[port]" as the target.
191	//
192	// Append "localhost" before calling net.Dial, in case net.Dial on certain
193	// platforms doesn't work well for address without the IP.
194	cc, err := grpc.Dial(":"+port, grpc.WithInsecure(), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
195		return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost"+addr)
196	}))
197	if err != nil {
198		t.Fatalf("grpc.Dial(%q) = %v", ss.Target, err)
199	}
200	defer cc.Close()
201	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
202	defer cancel()
203	_, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{})
204	if err != nil {
205		t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
206	}
207}
208