1// Copyright 2012 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5//go:build !windows && !solaris && !js 6// +build !windows,!solaris,!js 7 8package test 9 10// direct-tcpip and direct-streamlocal functional tests 11 12import ( 13 "fmt" 14 "io" 15 "io/ioutil" 16 "net" 17 "strings" 18 "testing" 19) 20 21type dialTester interface { 22 TestServerConn(t *testing.T, c net.Conn) 23 TestClientConn(t *testing.T, c net.Conn) 24} 25 26func testDial(t *testing.T, n, listenAddr string, x dialTester) { 27 server := newServer(t) 28 defer server.Shutdown() 29 sshConn := server.Dial(clientConfig()) 30 defer sshConn.Close() 31 32 l, err := net.Listen(n, listenAddr) 33 if err != nil { 34 t.Fatalf("Listen: %v", err) 35 } 36 defer l.Close() 37 38 testData := fmt.Sprintf("hello from %s, %s", n, listenAddr) 39 go func() { 40 for { 41 c, err := l.Accept() 42 if err != nil { 43 break 44 } 45 x.TestServerConn(t, c) 46 47 io.WriteString(c, testData) 48 c.Close() 49 } 50 }() 51 52 conn, err := sshConn.Dial(n, l.Addr().String()) 53 if err != nil { 54 t.Fatalf("Dial: %v", err) 55 } 56 x.TestClientConn(t, conn) 57 defer conn.Close() 58 b, err := ioutil.ReadAll(conn) 59 if err != nil { 60 t.Fatalf("ReadAll: %v", err) 61 } 62 t.Logf("got %q", string(b)) 63 if string(b) != testData { 64 t.Fatalf("expected %q, got %q", testData, string(b)) 65 } 66} 67 68type tcpDialTester struct { 69 listenAddr string 70} 71 72func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) { 73 host := strings.Split(x.listenAddr, ":")[0] 74 prefix := host + ":" 75 if !strings.HasPrefix(c.LocalAddr().String(), prefix) { 76 t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String()) 77 } 78 if !strings.HasPrefix(c.RemoteAddr().String(), prefix) { 79 t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String()) 80 } 81} 82 83func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) { 84 // we use zero addresses. see *Client.Dial. 85 if c.LocalAddr().String() != "0.0.0.0:0" { 86 t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String()) 87 } 88 if c.RemoteAddr().String() != "0.0.0.0:0" { 89 t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String()) 90 } 91} 92 93func TestDialTCP(t *testing.T) { 94 x := &tcpDialTester{ 95 listenAddr: "127.0.0.1:0", 96 } 97 testDial(t, "tcp", x.listenAddr, x) 98} 99 100type unixDialTester struct { 101 listenAddr string 102} 103 104func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) { 105 if c.LocalAddr().String() != x.listenAddr { 106 t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String()) 107 } 108 if c.RemoteAddr().String() != "@" && c.RemoteAddr().String() != "" { 109 t.Fatalf("expected \"@\" or \"\", got %q", c.RemoteAddr().String()) 110 } 111} 112 113func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) { 114 if c.RemoteAddr().String() != x.listenAddr { 115 t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String()) 116 } 117 if c.LocalAddr().String() != "@" { 118 t.Fatalf("expected \"@\", got %q", c.LocalAddr().String()) 119 } 120} 121 122func TestDialUnix(t *testing.T) { 123 addr, cleanup := newTempSocket(t) 124 defer cleanup() 125 x := &unixDialTester{ 126 listenAddr: addr, 127 } 128 testDial(t, "unix", x.listenAddr, x) 129} 130