1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build !windows && !solaris && !js
6// +build !windows,!solaris,!js
7
8package test
9
10// direct-tcpip and direct-streamlocal functional tests
11
12import (
13	"fmt"
14	"io"
15	"io/ioutil"
16	"net"
17	"strings"
18	"testing"
19)
20
// dialTester lets testDial verify the endpoint addresses of a forwarded
// connection for one particular network type. Both methods receive the
// running test's *testing.T.
type dialTester interface {
	// TestClientConn inspects the client-side conn returned by the SSH
	// client's Dial. testDial calls it on the test goroutine.
	TestClientConn(t *testing.T, conn net.Conn)
	// TestServerConn inspects a connection accepted by the target listener.
	// testDial calls it from its accept goroutine rather than the test
	// goroutine, so implementations should report failures with t.Error
	// instead of t.Fatal.
	TestServerConn(t *testing.T, conn net.Conn)
}
25
26func testDial(t *testing.T, n, listenAddr string, x dialTester) {
27	server := newServer(t)
28	defer server.Shutdown()
29	sshConn := server.Dial(clientConfig())
30	defer sshConn.Close()
31
32	l, err := net.Listen(n, listenAddr)
33	if err != nil {
34		t.Fatalf("Listen: %v", err)
35	}
36	defer l.Close()
37
38	testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
39	go func() {
40		for {
41			c, err := l.Accept()
42			if err != nil {
43				break
44			}
45			x.TestServerConn(t, c)
46
47			io.WriteString(c, testData)
48			c.Close()
49		}
50	}()
51
52	conn, err := sshConn.Dial(n, l.Addr().String())
53	if err != nil {
54		t.Fatalf("Dial: %v", err)
55	}
56	x.TestClientConn(t, conn)
57	defer conn.Close()
58	b, err := ioutil.ReadAll(conn)
59	if err != nil {
60		t.Fatalf("ReadAll: %v", err)
61	}
62	t.Logf("got %q", string(b))
63	if string(b) != testData {
64		t.Fatalf("expected %q, got %q", testData, string(b))
65	}
66}
67
// tcpDialTester checks the endpoint addresses of a direct-tcpip connection
// forwarded through the SSH server to a TCP listener.
type tcpDialTester struct {
	// listenAddr is the host:port the target listener is bound to.
	listenAddr string
}

// TestServerConn checks that both the local and the remote address of the
// connection accepted by the target listener are on the listener's host.
//
// testDial calls this from its accept goroutine, so failures are reported
// with t.Errorf: t.Fatalf may only be called from the test goroutine.
func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
	host, _, err := net.SplitHostPort(x.listenAddr)
	if err != nil {
		t.Errorf("SplitHostPort(%q): %v", x.listenAddr, err)
		return
	}
	// JoinHostPort brackets IPv6 hosts, so the prefix also matches the
	// "[::1]:port" form that net.Addr.String produces.
	prefix := net.JoinHostPort(host, "")
	for _, addr := range []func() net.Addr{c.LocalAddr, c.RemoteAddr} {
		if got := addr().String(); !strings.HasPrefix(got, prefix) {
			t.Errorf("expected to start with %q, got %q", prefix, got)
		}
	}
}

// TestClientConn checks that the client-side conn reports zero addresses
// for both ends. It runs on the test goroutine, so t.Fatalf is safe here.
func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
	// we use zero addresses. see *Client.Dial.
	const zeroAddr = "0.0.0.0:0"
	for _, addr := range []func() net.Addr{c.LocalAddr, c.RemoteAddr} {
		if got := addr().String(); got != zeroAddr {
			t.Fatalf("expected %q, got %q", zeroAddr, got)
		}
	}
}
92
93func TestDialTCP(t *testing.T) {
94	x := &tcpDialTester{
95		listenAddr: "127.0.0.1:0",
96	}
97	testDial(t, "tcp", x.listenAddr, x)
98}
99
// unixDialTester checks the endpoint addresses of a direct-streamlocal
// connection forwarded through the SSH server to a Unix socket listener.
type unixDialTester struct {
	// listenAddr is the socket address the target listener is bound to
	// (in TestDialUnix, the value returned by newTempSocket).
	listenAddr string
}

// TestServerConn checks that the accepted connection's local address is the
// listener's address and that its peer address prints as "@" or "".
//
// testDial calls this from its accept goroutine, so failures are reported
// with t.Errorf: t.Fatalf may only be called from the test goroutine.
func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
	if got := c.LocalAddr().String(); got != x.listenAddr {
		t.Errorf("expected %q, got %q", x.listenAddr, got)
	}
	switch got := c.RemoteAddr().String(); got {
	case "@", "":
		// Either form is accepted for the peer's address.
	default:
		t.Errorf("expected \"@\" or \"\", got %q", got)
	}
}

// TestClientConn checks that the client-side conn reports the listener's
// address as its remote end and "@" as its local end. It runs on the test
// goroutine, so t.Fatalf is safe here.
func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
	if got := c.RemoteAddr().String(); got != x.listenAddr {
		t.Fatalf("expected %q, got %q", x.listenAddr, got)
	}
	if got := c.LocalAddr().String(); got != "@" {
		t.Fatalf("expected \"@\", got %q", got)
	}
}
121
122func TestDialUnix(t *testing.T) {
123	addr, cleanup := newTempSocket(t)
124	defer cleanup()
125	x := &unixDialTester{
126		listenAddr: addr,
127	}
128	testDial(t, "unix", x.listenAddr, x)
129}
130