1// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8	"bufio"
9	"encoding/base64"
10	"errors"
11	"net"
12	"net/http"
13	"net/url"
14	"strings"
15)
16
// netDialerFunc lets a plain dialing function be used where a value with a
// Dial(network, addr) method is required (presumably the proxy_Dialer
// interface used by the proxy registry — confirm against its definition).
type netDialerFunc func(network, addr string) (net.Conn, error)

// Dial calls the underlying function with network and address and hands back
// its connection and error without modification.
func (f netDialerFunc) Dial(network, address string) (net.Conn, error) {
	conn, err := f(network, address)
	return conn, err
}
22
23func init() {
24	proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
25		return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
26	})
27}
28
29type httpProxyDialer struct {
30	proxyURL    *url.URL
31	forwardDial func(network, addr string) (net.Conn, error)
32}
33
34func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
35	hostPort, _ := hostPortNoPort(hpd.proxyURL)
36	conn, err := hpd.forwardDial(network, hostPort)
37	if err != nil {
38		return nil, err
39	}
40
41	connectHeader := make(http.Header)
42	if user := hpd.proxyURL.User; user != nil {
43		proxyUser := user.Username()
44		if proxyPassword, passwordSet := user.Password(); passwordSet {
45			credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
46			connectHeader.Set("Proxy-Authorization", "Basic "+credential)
47		}
48	}
49
50	connectReq := &http.Request{
51		Method: "CONNECT",
52		URL:    &url.URL{Opaque: addr},
53		Host:   addr,
54		Header: connectHeader,
55	}
56
57	if err := connectReq.Write(conn); err != nil {
58		conn.Close()
59		return nil, err
60	}
61
62	// Read response. It's OK to use and discard buffered reader here becaue
63	// the remote server does not speak until spoken to.
64	br := bufio.NewReader(conn)
65	resp, err := http.ReadResponse(br, connectReq)
66	if err != nil {
67		conn.Close()
68		return nil, err
69	}
70
71	if resp.StatusCode != 200 {
72		conn.Close()
73		f := strings.SplitN(resp.Status, " ", 2)
74		return nil, errors.New(f[1])
75	}
76	return conn, nil
77}
78