1// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package conntrack
5
6import (
7	"context"
8	"fmt"
9	"net"
10	"sync"
11
12	"golang.org/x/net/trace"
13)
14
// dialerNameKeyType is the unexported type of the context key under which DialNameToContext stores a dialer name.
// A private type (rather than a built-in string) guarantees that no other package can read or overwrite the value
// by accident, as recommended by the `context.WithValue` documentation.
type dialerNameKeyType struct{}

var (
	dialerNameKey = dialerNameKeyType{}
)
18
type (
	// dialerOpts is the configuration assembled by applying dialerOpt functions.
	dialerOpts struct {
		name                  string            // dialer name handed to tracing and metrics reporting
		monitoring            bool              // whether Prometheus monitoring is reported
		tracing               bool              // whether /debug/events tracing is recorded
		parentDialContextFunc dialerContextFunc // performs the real dial
	}

	// dialerOpt is a functional option that mutates dialerOpts.
	dialerOpt func(*dialerOpts)

	// dialerContextFunc has the shape of `net.Dialer.DialContext` and performs the underlying dial.
	dialerContextFunc func(context.Context, string, string) (net.Conn, error)
)
29
30// DialWithName sets the name of the dialer for tracking and monitoring.
31// This is the name for the dialer (default is `default`), but for `NewDialContextFunc` can be overwritten from the
32// Context using `DialNameToContext`.
33func DialWithName(name string) dialerOpt {
34	return func(opts *dialerOpts) {
35		opts.name = name
36	}
37}
38
39// DialWithoutMonitoring turns *off* Prometheus monitoring for this dialer.
40func DialWithoutMonitoring() dialerOpt {
41	return func(opts *dialerOpts) {
42		opts.monitoring = false
43	}
44}
45
46// DialWithTracing turns *on* the /debug/events tracing of the dial calls.
47func DialWithTracing() dialerOpt {
48	return func(opts *dialerOpts) {
49		opts.tracing = true
50	}
51}
52
53// DialWithDialer allows you to override the `net.Dialer` instance used to actually conduct the dials.
54func DialWithDialer(parentDialer *net.Dialer) dialerOpt {
55	return DialWithDialContextFunc(parentDialer.DialContext)
56}
57
58// DialWithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
59func DialWithDialContextFunc(parentDialerFunc dialerContextFunc) dialerOpt {
60	return func(opts *dialerOpts) {
61		opts.parentDialContextFunc = parentDialerFunc
62	}
63}
64
65// DialNameFromContext returns the name of the dialer from the context of the DialContext func, if any.
66func DialNameFromContext(ctx context.Context) string {
67	val, ok := ctx.Value(dialerNameKey).(string)
68	if !ok {
69		return ""
70	}
71	return val
72}
73
74// DialNameToContext returns a context that will contain a dialer name override.
75func DialNameToContext(ctx context.Context, dialerName string) context.Context {
76	return context.WithValue(ctx, dialerNameKey, dialerName)
77}
78
79// NewDialContextFunc returns a `DialContext` function that tracks outbound connections.
80// The signature is compatible with `http.Tranport.DialContext` and is meant to be used there.
81func NewDialContextFunc(optFuncs ...dialerOpt) func(context.Context, string, string) (net.Conn, error) {
82	opts := &dialerOpts{name: defaultName, monitoring: true, parentDialContextFunc: (&net.Dialer{}).DialContext}
83	for _, f := range optFuncs {
84		f(opts)
85	}
86	if opts.monitoring {
87		PreRegisterDialerMetrics(opts.name)
88	}
89	return func(ctx context.Context, network string, addr string) (net.Conn, error) {
90		name := opts.name
91		if ctxName := DialNameFromContext(ctx); ctxName != "" {
92			name = ctxName
93		}
94		return dialClientConnTracker(ctx, network, addr, name, opts)
95	}
96}
97
98// NewDialFunc returns a `Dial` function that tracks outbound connections.
99// The signature is compatible with `http.Tranport.Dial` and is meant to be used there for Go < 1.7.
100func NewDialFunc(optFuncs ...dialerOpt) func(string, string) (net.Conn, error) {
101	dialContextFunc := NewDialContextFunc(optFuncs...)
102	return func(network string, addr string) (net.Conn, error) {
103		return dialContextFunc(context.TODO(), network, addr)
104	}
105}
106
107type clientConnTracker struct {
108	net.Conn
109	opts       *dialerOpts
110	dialerName string
111	event      trace.EventLog
112	mu         sync.Mutex
113}
114
115func dialClientConnTracker(ctx context.Context, network string, addr string, dialerName string, opts *dialerOpts) (net.Conn, error) {
116	var event trace.EventLog
117	if opts.tracing {
118		event = trace.NewEventLog(fmt.Sprintf("net.ClientConn.%s", dialerName), fmt.Sprintf("%v", addr))
119	}
120	if opts.monitoring {
121		reportDialerConnAttempt(dialerName)
122	}
123	conn, err := opts.parentDialContextFunc(ctx, network, addr)
124	if err != nil {
125		if event != nil {
126			event.Errorf("failed dialing: %v", err)
127			event.Finish()
128		}
129		if opts.monitoring {
130			reportDialerConnFailed(dialerName, err)
131		}
132		return nil, err
133	}
134	if event != nil {
135		event.Printf("established: %s -> %s", conn.LocalAddr(), conn.RemoteAddr())
136	}
137	if opts.monitoring {
138		reportDialerConnEstablished(dialerName)
139	}
140	tracker := &clientConnTracker{
141		Conn:       conn,
142		opts:       opts,
143		dialerName: dialerName,
144		event:      event,
145	}
146	return tracker, nil
147}
148
149func (ct *clientConnTracker) Close() error {
150	err := ct.Conn.Close()
151	ct.mu.Lock()
152	if ct.event != nil {
153		if err != nil {
154			ct.event.Errorf("failed closing: %v", err)
155		} else {
156			ct.event.Printf("closing")
157		}
158		ct.event.Finish()
159		ct.event = nil
160	}
161	ct.mu.Unlock()
162	if ct.opts.monitoring {
163		reportDialerConnClosed(ct.dialerName)
164	}
165	return err
166}
167