1package unix_transport 2 3import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "net" 8 "net/http" 9 "net/http/httputil" 10 "net/url" 11 "strings" 12) 13 14func NewWithTLS(socketPath string, tlsConfig *tls.Config) *http.Transport { 15 unixTransport := &http.Transport{TLSClientConfig: tlsConfig} 16 17 unixTransport.RegisterProtocol("unix", NewUnixRoundTripperTls(socketPath, tlsConfig)) 18 return unixTransport 19} 20 21func New(socketPath string) *http.Transport { 22 unixTransport := &http.Transport{} 23 unixTransport.RegisterProtocol("unix", NewUnixRoundTripper(socketPath)) 24 return unixTransport 25} 26 27type UnixRoundTripper struct { 28 path string 29 conn httputil.ClientConn 30 useTls bool 31 tlsConfig *tls.Config 32} 33 34func NewUnixRoundTripper(path string) *UnixRoundTripper { 35 return &UnixRoundTripper{path: path} 36} 37 38func NewUnixRoundTripperTls(path string, tlsConfig *tls.Config) *UnixRoundTripper { 39 return &UnixRoundTripper{ 40 path: path, 41 useTls: true, 42 tlsConfig: tlsConfig, 43 } 44} 45 46// The RoundTripper (http://golang.org/pkg/net/http/#RoundTripper) for the socket transport dials the socket 47// each time a request is made. 48func (roundTripper UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 49 var conn net.Conn 50 var err error 51 if roundTripper.useTls { 52 53 conn, err = tls.Dial("unix", roundTripper.path, roundTripper.tlsConfig) 54 if err != nil { 55 return nil, err 56 } 57 if conn == nil { 58 return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)") 59 } 60 if tc, ok := conn.(*tls.Conn); ok { 61 // Handshake here, in case DialTLS didn't. TLSNextProto below 62 // depends on it for knowing the connection state. 63 if err := tc.Handshake(); err != nil { 64 go conn.Close() 65 return nil, err 66 } 67 } 68 } else { 69 conn, err = net.Dial("unix", roundTripper.path) 70 if err != nil { 71 return nil, err 72 } 73 } 74 75 socketClientConn := httputil.NewClientConn(conn, nil) 76 defer socketClientConn.Close() 77 78 newReq, err := roundTripper.rewriteRequest(req) 79 if err != nil { 80 return nil, err 81 } 82 83 return socketClientConn.Do(newReq) 84} 85 86func (roundTripper *UnixRoundTripper) rewriteRequest(req *http.Request) (*http.Request, error) { 87 requestPath := req.URL.Path 88 if !strings.HasPrefix(requestPath, roundTripper.path) { 89 return nil, fmt.Errorf("Wrong unix socket [unix://%s]. Expected unix socket is [%s]", requestPath, roundTripper.path) 90 } 91 92 reqPath := strings.TrimPrefix(requestPath, roundTripper.path) 93 newReqUrl := fmt.Sprintf("unix://%s", reqPath) 94 95 var err error 96 newURL, err := url.Parse(newReqUrl) 97 if err != nil { 98 return nil, err 99 } 100 101 req.URL.Path = newURL.Path 102 req.URL.Host = roundTripper.path 103 return req, nil 104 105} 106