1// Copyright 2016 Michal Witkowski. All Rights Reserved. 2// See LICENSE for licensing terms. 3 4package conntrack 5 6import ( 7 "context" 8 "fmt" 9 "net" 10 "sync" 11 12 "golang.org/x/net/trace" 13) 14 15var ( 16 dialerNameKey = "conntrackDialerKey" 17) 18 19type dialerOpts struct { 20 name string 21 monitoring bool 22 tracing bool 23 parentDialContextFunc dialerContextFunc 24} 25 26type dialerOpt func(*dialerOpts) 27 28type dialerContextFunc func(context.Context, string, string) (net.Conn, error) 29 30// DialWithName sets the name of the dialer for tracking and monitoring. 31// This is the name for the dialer (default is `default`), but for `NewDialContextFunc` can be overwritten from the 32// Context using `DialNameToContext`. 33func DialWithName(name string) dialerOpt { 34 return func(opts *dialerOpts) { 35 opts.name = name 36 } 37} 38 39// DialWithoutMonitoring turns *off* Prometheus monitoring for this dialer. 40func DialWithoutMonitoring() dialerOpt { 41 return func(opts *dialerOpts) { 42 opts.monitoring = false 43 } 44} 45 46// DialWithTracing turns *on* the /debug/events tracing of the dial calls. 47func DialWithTracing() dialerOpt { 48 return func(opts *dialerOpts) { 49 opts.tracing = true 50 } 51} 52 53// DialWithDialer allows you to override the `net.Dialer` instance used to actually conduct the dials. 54func DialWithDialer(parentDialer *net.Dialer) dialerOpt { 55 return DialWithDialContextFunc(parentDialer.DialContext) 56} 57 58// DialWithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`. 59func DialWithDialContextFunc(parentDialerFunc dialerContextFunc) dialerOpt { 60 return func(opts *dialerOpts) { 61 opts.parentDialContextFunc = parentDialerFunc 62 } 63} 64 65// DialNameFromContext returns the name of the dialer from the context of the DialContext func, if any. 66func DialNameFromContext(ctx context.Context) string { 67 val, ok := ctx.Value(dialerNameKey).(string) 68 if !ok { 69 return "" 70 } 71 return val 72} 73 74// DialNameToContext returns a context that will contain a dialer name override. 75func DialNameToContext(ctx context.Context, dialerName string) context.Context { 76 return context.WithValue(ctx, dialerNameKey, dialerName) 77} 78 79// NewDialContextFunc returns a `DialContext` function that tracks outbound connections. 80// The signature is compatible with `http.Tranport.DialContext` and is meant to be used there. 81func NewDialContextFunc(optFuncs ...dialerOpt) func(context.Context, string, string) (net.Conn, error) { 82 opts := &dialerOpts{name: defaultName, monitoring: true, parentDialContextFunc: (&net.Dialer{}).DialContext} 83 for _, f := range optFuncs { 84 f(opts) 85 } 86 if opts.monitoring { 87 PreRegisterDialerMetrics(opts.name) 88 } 89 return func(ctx context.Context, network string, addr string) (net.Conn, error) { 90 name := opts.name 91 if ctxName := DialNameFromContext(ctx); ctxName != "" { 92 name = ctxName 93 } 94 return dialClientConnTracker(ctx, network, addr, name, opts) 95 } 96} 97 98// NewDialFunc returns a `Dial` function that tracks outbound connections. 99// The signature is compatible with `http.Tranport.Dial` and is meant to be used there for Go < 1.7. 100func NewDialFunc(optFuncs ...dialerOpt) func(string, string) (net.Conn, error) { 101 dialContextFunc := NewDialContextFunc(optFuncs...) 102 return func(network string, addr string) (net.Conn, error) { 103 return dialContextFunc(context.TODO(), network, addr) 104 } 105} 106 107type clientConnTracker struct { 108 net.Conn 109 opts *dialerOpts 110 dialerName string 111 event trace.EventLog 112 mu sync.Mutex 113} 114 115func dialClientConnTracker(ctx context.Context, network string, addr string, dialerName string, opts *dialerOpts) (net.Conn, error) { 116 var event trace.EventLog 117 if opts.tracing { 118 event = trace.NewEventLog(fmt.Sprintf("net.ClientConn.%s", dialerName), fmt.Sprintf("%v", addr)) 119 } 120 if opts.monitoring { 121 reportDialerConnAttempt(dialerName) 122 } 123 conn, err := opts.parentDialContextFunc(ctx, network, addr) 124 if err != nil { 125 if event != nil { 126 event.Errorf("failed dialing: %v", err) 127 event.Finish() 128 } 129 if opts.monitoring { 130 reportDialerConnFailed(dialerName, err) 131 } 132 return nil, err 133 } 134 if event != nil { 135 event.Printf("established: %s -> %s", conn.LocalAddr(), conn.RemoteAddr()) 136 } 137 if opts.monitoring { 138 reportDialerConnEstablished(dialerName) 139 } 140 tracker := &clientConnTracker{ 141 Conn: conn, 142 opts: opts, 143 dialerName: dialerName, 144 event: event, 145 } 146 return tracker, nil 147} 148 149func (ct *clientConnTracker) Close() error { 150 err := ct.Conn.Close() 151 ct.mu.Lock() 152 if ct.event != nil { 153 if err != nil { 154 ct.event.Errorf("failed closing: %v", err) 155 } else { 156 ct.event.Printf("closing") 157 } 158 ct.event.Finish() 159 ct.event = nil 160 } 161 ct.mu.Unlock() 162 if ct.opts.monitoring { 163 reportDialerConnClosed(ct.dialerName) 164 } 165 return err 166} 167