1package unix_transport
2
3import (
4	"crypto/tls"
5	"errors"
6	"fmt"
7	"net"
8	"net/http"
9	"net/http/httputil"
10	"net/url"
11	"strings"
12)
13
14func NewWithTLS(socketPath string, tlsConfig *tls.Config) *http.Transport {
15	unixTransport := &http.Transport{TLSClientConfig: tlsConfig}
16
17	unixTransport.RegisterProtocol("unix", NewUnixRoundTripperTls(socketPath, tlsConfig))
18	return unixTransport
19}
20
21func New(socketPath string) *http.Transport {
22	unixTransport := &http.Transport{}
23	unixTransport.RegisterProtocol("unix", NewUnixRoundTripper(socketPath))
24	return unixTransport
25}
26
// UnixRoundTripper is an http.RoundTripper that sends each request over a
// fresh connection to a unix domain socket, optionally wrapped in TLS.
// Construct it with NewUnixRoundTripper or NewUnixRoundTripperTls.
//
// The struct deliberately holds no connection state: RoundTrip has a value
// receiver, so any lock-bearing field (such as the previously unused
// httputil.ClientConn) would be copied on every call.
type UnixRoundTripper struct {
	// path is the filesystem path of the socket to dial; request URL paths
	// are expected to begin with it.
	path string
	// useTls selects tls.Dial instead of a plain net.Dial.
	useTls bool
	// tlsConfig is passed to tls.Dial when useTls is set.
	tlsConfig *tls.Config
}
33
34func NewUnixRoundTripper(path string) *UnixRoundTripper {
35	return &UnixRoundTripper{path: path}
36}
37
38func NewUnixRoundTripperTls(path string, tlsConfig *tls.Config) *UnixRoundTripper {
39	return &UnixRoundTripper{
40		path:      path,
41		useTls:    true,
42		tlsConfig: tlsConfig,
43	}
44}
45
46// The RoundTripper (http://golang.org/pkg/net/http/#RoundTripper) for the socket transport dials the socket
47// each time a request is made.
48func (roundTripper UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
49	var conn net.Conn
50	var err error
51	if roundTripper.useTls {
52
53		conn, err = tls.Dial("unix", roundTripper.path, roundTripper.tlsConfig)
54		if err != nil {
55			return nil, err
56		}
57		if conn == nil {
58			return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
59		}
60		if tc, ok := conn.(*tls.Conn); ok {
61			// Handshake here, in case DialTLS didn't. TLSNextProto below
62			// depends on it for knowing the connection state.
63			if err := tc.Handshake(); err != nil {
64				go conn.Close()
65				return nil, err
66			}
67		}
68	} else {
69		conn, err = net.Dial("unix", roundTripper.path)
70		if err != nil {
71			return nil, err
72		}
73	}
74
75	socketClientConn := httputil.NewClientConn(conn, nil)
76	defer socketClientConn.Close()
77
78	newReq, err := roundTripper.rewriteRequest(req)
79	if err != nil {
80		return nil, err
81	}
82
83	return socketClientConn.Do(newReq)
84}
85
86func (roundTripper *UnixRoundTripper) rewriteRequest(req *http.Request) (*http.Request, error) {
87	requestPath := req.URL.Path
88	if !strings.HasPrefix(requestPath, roundTripper.path) {
89		return nil, fmt.Errorf("Wrong unix socket [unix://%s]. Expected unix socket is [%s]", requestPath, roundTripper.path)
90	}
91
92	reqPath := strings.TrimPrefix(requestPath, roundTripper.path)
93	newReqUrl := fmt.Sprintf("unix://%s", reqPath)
94
95	var err error
96	newURL, err := url.Parse(newReqUrl)
97	if err != nil {
98		return nil, err
99	}
100
101	req.URL.Path = newURL.Path
102	req.URL.Host = roundTripper.path
103	return req, nil
104
105}
106