1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proxy
6
7import (
8	"bytes"
9	"context"
10	"errors"
11	"fmt"
12	"net"
13	"net/url"
14	"os"
15	"strings"
16	"testing"
17
18	"golang.org/x/net/internal/socks"
19	"golang.org/x/net/internal/sockstest"
20)
21
22type proxyFromEnvTest struct {
23	allProxyEnv string
24	noProxyEnv  string
25	wantTypeOf  Dialer
26}
27
28func (t proxyFromEnvTest) String() string {
29	var buf bytes.Buffer
30	space := func() {
31		if buf.Len() > 0 {
32			buf.WriteByte(' ')
33		}
34	}
35	if t.allProxyEnv != "" {
36		fmt.Fprintf(&buf, "all_proxy=%q", t.allProxyEnv)
37	}
38	if t.noProxyEnv != "" {
39		space()
40		fmt.Fprintf(&buf, "no_proxy=%q", t.noProxyEnv)
41	}
42	return strings.TrimSpace(buf.String())
43}
44
45func TestFromEnvironment(t *testing.T) {
46	ResetProxyEnv()
47
48	type dummyDialer struct {
49		direct
50	}
51
52	RegisterDialerType("irc", func(_ *url.URL, _ Dialer) (Dialer, error) {
53		return dummyDialer{}, nil
54	})
55
56	proxyFromEnvTests := []proxyFromEnvTest{
57		{allProxyEnv: "127.0.0.1:8080", noProxyEnv: "localhost, 127.0.0.1", wantTypeOf: direct{}},
58		{allProxyEnv: "ftp://example.com:8000", noProxyEnv: "localhost, 127.0.0.1", wantTypeOf: direct{}},
59		{allProxyEnv: "socks5://example.com:8080", noProxyEnv: "localhost, 127.0.0.1", wantTypeOf: &PerHost{}},
60		{allProxyEnv: "socks5h://example.com", wantTypeOf: &socks.Dialer{}},
61		{allProxyEnv: "irc://example.com:8000", wantTypeOf: dummyDialer{}},
62		{noProxyEnv: "localhost, 127.0.0.1", wantTypeOf: direct{}},
63		{wantTypeOf: direct{}},
64	}
65
66	for _, tt := range proxyFromEnvTests {
67		os.Setenv("ALL_PROXY", tt.allProxyEnv)
68		os.Setenv("NO_PROXY", tt.noProxyEnv)
69		ResetCachedEnvironment()
70
71		d := FromEnvironment()
72		if got, want := fmt.Sprintf("%T", d), fmt.Sprintf("%T", tt.wantTypeOf); got != want {
73			t.Errorf("%v: got type = %T, want %T", tt, d, tt.wantTypeOf)
74		}
75	}
76}
77
78func TestFromURL(t *testing.T) {
79	ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
80	if err != nil {
81		t.Fatal(err)
82	}
83	defer ss.Close()
84	url, err := url.Parse("socks5://user:password@" + ss.Addr().String())
85	if err != nil {
86		t.Fatal(err)
87	}
88	proxy, err := FromURL(url, nil)
89	if err != nil {
90		t.Fatal(err)
91	}
92	c, err := proxy.Dial("tcp", "fqdn.doesnotexist:5963")
93	if err != nil {
94		t.Fatal(err)
95	}
96	c.Close()
97}
98
99func TestSOCKS5(t *testing.T) {
100	ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
101	if err != nil {
102		t.Fatal(err)
103	}
104	defer ss.Close()
105	proxy, err := SOCKS5("tcp", ss.Addr().String(), nil, nil)
106	if err != nil {
107		t.Fatal(err)
108	}
109	c, err := proxy.Dial("tcp", ss.TargetAddr().String())
110	if err != nil {
111		t.Fatal(err)
112	}
113	c.Close()
114}
115
// funcFailDialer adapts a function to a dialer. DialContext never produces a
// connection; it only reports the error returned by the function.
type funcFailDialer func(context.Context) error

// Dial is not expected to be reached and panics if it is called.
func (f funcFailDialer) Dial(network, address string) (net.Conn, error) {
	panic("shouldn't see a call to Dial")
}

// DialContext passes ctx to the underlying function and returns a nil
// connection together with that function's error.
func (f funcFailDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	err := f(ctx)
	return nil, err
}
125
126// Check that FromEnvironmentUsing uses our dialer.
127func TestFromEnvironmentUsing(t *testing.T) {
128	ResetProxyEnv()
129	errFoo := errors.New("some error to check our dialer was used)")
130	type key string
131	ctx := context.WithValue(context.Background(), key("foo"), "bar")
132	dialer := FromEnvironmentUsing(funcFailDialer(func(ctx context.Context) error {
133		if got := ctx.Value(key("foo")); got != "bar" {
134			t.Errorf("Resolver context = %T %v, want %q", got, got, "bar")
135		}
136		return errFoo
137	}))
138	_, err := dialer.(ContextDialer).DialContext(ctx, "tcp", "foo.tld:123")
139	if err == nil {
140		t.Fatalf("unexpected success")
141	}
142	if !strings.Contains(err.Error(), errFoo.Error()) {
143		t.Errorf("got unexpected error %q; want substr %q", err, errFoo)
144	}
145}
146
147func ResetProxyEnv() {
148	for _, env := range []*envOnce{allProxyEnv, noProxyEnv} {
149		for _, v := range env.names {
150			os.Setenv(v, "")
151		}
152	}
153	ResetCachedEnvironment()
154}
155
156func ResetCachedEnvironment() {
157	allProxyEnv.reset()
158	noProxyEnv.reset()
159}
160