1// +build linux 2 3/* 4 * 5 * Copyright 2020 gRPC authors. 6 * 7 * Licensed under the Apache License, Version 2.0 (the "License"); 8 * you may not use this file except in compliance with the License. 9 * You may obtain a copy of the License at 10 * 11 * https://www.apache.org/licenses/LICENSE-2.0 12 * 13 * Unless required by applicable law or agreed to in writing, software 14 * distributed under the License is distributed on an "AS IS" BASIS, 15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 * See the License for the specific language governing permissions and 17 * limitations under the License. 18 * 19 */ 20 21package test 22 23import ( 24 "context" 25 "fmt" 26 "net" 27 "os" 28 "strings" 29 "sync" 30 "testing" 31 "time" 32 33 "google.golang.org/grpc" 34 "google.golang.org/grpc/codes" 35 "google.golang.org/grpc/internal/stubserver" 36 "google.golang.org/grpc/metadata" 37 "google.golang.org/grpc/status" 38 testpb "google.golang.org/grpc/test/grpc_testing" 39) 40 41func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) { 42 md, ok := metadata.FromIncomingContext(ctx) 43 if !ok { 44 return nil, status.Error(codes.InvalidArgument, "failed to parse metadata") 45 } 46 auths, ok := md[":authority"] 47 if !ok { 48 return nil, status.Error(codes.InvalidArgument, "no authority header") 49 } 50 if len(auths) != 1 { 51 return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths)) 52 } 53 if auths[0] != expectedAuthority { 54 return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority)) 55 } 56 return &testpb.Empty{}, nil 57} 58 59func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) { 60 if !strings.HasPrefix(target, "unix-abstract:") { 61 if err := os.RemoveAll(address); err != nil { 62 t.Fatalf("Error removing socket file %v: %v\n", address, err) 63 } 64 } 65 ss := &stubserver.StubServer{ 66 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 67 return authorityChecker(ctx, expectedAuthority) 68 }, 69 Network: "unix", 70 Address: address, 71 Target: target, 72 } 73 opts := []grpc.DialOption{} 74 if dialer != nil { 75 opts = append(opts, grpc.WithContextDialer(dialer)) 76 } 77 if err := ss.Start(nil, opts...); err != nil { 78 t.Fatalf("Error starting endpoint server: %v", err) 79 } 80 defer ss.Stop() 81 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 82 defer cancel() 83 _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) 84 if err != nil { 85 t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err) 86 } 87} 88 89type authorityTest struct { 90 name string 91 address string 92 target string 93 authority string 94 dialTargetWant string 95} 96 97var authorityTests = []authorityTest{ 98 { 99 name: "UnixRelative", 100 address: "sock.sock", 101 target: "unix:sock.sock", 102 authority: "localhost", 103 }, 104 { 105 name: "UnixAbsolute", 106 address: "/tmp/sock.sock", 107 target: "unix:/tmp/sock.sock", 108 authority: "localhost", 109 }, 110 { 111 name: "UnixAbsoluteAlternate", 112 address: "/tmp/sock.sock", 113 target: "unix:///tmp/sock.sock", 114 authority: "localhost", 115 }, 116 { 117 name: "UnixPassthrough", 118 address: "/tmp/sock.sock", 119 target: "passthrough:///unix:///tmp/sock.sock", 120 authority: "unix:///tmp/sock.sock", 121 dialTargetWant: "unix:///tmp/sock.sock", 122 }, 123 { 124 name: "UnixAbstract", 125 address: "\x00abc efg", 126 target: "unix-abstract:abc efg", 127 authority: "localhost", 128 dialTargetWant: "\x00abc efg", 129 }, 130} 131 132// TestUnix does end to end tests with the various supported unix target 133// formats, ensuring that the authority is set as expected. 134func (s) TestUnix(t *testing.T) { 135 for _, test := range authorityTests { 136 t.Run(test.name, func(t *testing.T) { 137 runUnixTest(t, test.address, test.target, test.authority, nil) 138 }) 139 } 140} 141 142// TestUnixCustomDialer does end to end tests with various supported unix target 143// formats, ensuring that the target sent to the dialer does NOT have the 144// "unix:" prefix stripped. 145func (s) TestUnixCustomDialer(t *testing.T) { 146 for _, test := range authorityTests { 147 t.Run(test.name+"WithDialer", func(t *testing.T) { 148 if test.dialTargetWant == "" { 149 test.dialTargetWant = test.target 150 } 151 dialer := func(ctx context.Context, address string) (net.Conn, error) { 152 if address != test.dialTargetWant { 153 return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address) 154 } 155 if !strings.HasPrefix(test.target, "unix-abstract:") { 156 address = address[len("unix:"):] 157 } 158 return (&net.Dialer{}).DialContext(ctx, "unix", address) 159 } 160 runUnixTest(t, test.address, test.target, test.authority, dialer) 161 }) 162 } 163} 164 165// TestColonPortAuthority does an end to end test with the target for grpc.Dial 166// being ":[port]". Ensures authority is "localhost:[port]". 167func (s) TestColonPortAuthority(t *testing.T) { 168 expectedAuthority := "" 169 var authorityMu sync.Mutex 170 ss := &stubserver.StubServer{ 171 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 172 authorityMu.Lock() 173 defer authorityMu.Unlock() 174 return authorityChecker(ctx, expectedAuthority) 175 }, 176 Network: "tcp", 177 } 178 if err := ss.Start(nil); err != nil { 179 t.Fatalf("Error starting endpoint server: %v", err) 180 } 181 defer ss.Stop() 182 _, port, err := net.SplitHostPort(ss.Address) 183 if err != nil { 184 t.Fatalf("Failed splitting host from post: %v", err) 185 } 186 authorityMu.Lock() 187 expectedAuthority = "localhost:" + port 188 authorityMu.Unlock() 189 // ss.Start dials, but not the ":[port]" target that is being tested here. 190 // Dial again, with ":[port]" as the target. 191 // 192 // Append "localhost" before calling net.Dial, in case net.Dial on certain 193 // platforms doesn't work well for address without the IP. 194 cc, err := grpc.Dial(":"+port, grpc.WithInsecure(), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { 195 return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost"+addr) 196 })) 197 if err != nil { 198 t.Fatalf("grpc.Dial(%q) = %v", ss.Target, err) 199 } 200 defer cc.Close() 201 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 202 defer cancel() 203 _, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}) 204 if err != nil { 205 t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err) 206 } 207} 208