1// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package conntrack
5
6import (
7	"fmt"
8	"net"
9	"sync"
10	"time"
11
12	"github.com/jpillora/backoff"
13	"golang.org/x/net/trace"
14)
15
// defaultName is the listener name NewListener uses for tracing and
// monitoring when no TrackWithName option is supplied.
const defaultName = "default"
19
20type listenerOpts struct {
21	name         string
22	monitoring   bool
23	tracing      bool
24	tcpKeepAlive time.Duration
25	retryBackoff *backoff.Backoff
26}
27
28type listenerOpt func(*listenerOpts)
29
30// TrackWithName sets the name of the Listener for use in tracking and monitoring.
31func TrackWithName(name string) listenerOpt {
32	return func(opts *listenerOpts) {
33		opts.name = name
34	}
35}
36
37// TrackWithoutMonitoring turns *off* Prometheus monitoring for this listener.
38func TrackWithoutMonitoring() listenerOpt {
39	return func(opts *listenerOpts) {
40		opts.monitoring = false
41	}
42}
43
44// TrackWithTracing turns *on* the /debug/events tracing of the live listener connections.
45func TrackWithTracing() listenerOpt {
46	return func(opts *listenerOpts) {
47		opts.tracing = true
48	}
49}
50
51// TrackWithRetries enables retrying of temporary Accept() errors, with the given backoff between attempts.
52// Concurrent accept calls that receive temporary errors have independent backoff scaling.
53func TrackWithRetries(b backoff.Backoff) listenerOpt {
54	return func(opts *listenerOpts) {
55		opts.retryBackoff = &b
56	}
57}
58
59// TrackWithTcpKeepAlive makes sure that any `net.TCPConn` that get accepted have a keep-alive.
60// This is useful for HTTP servers in order for, for example laptops, to not use up resources on the
61// server while they don't utilise their connection.
62// A value of 0 disables it.
63func TrackWithTcpKeepAlive(keepalive time.Duration) listenerOpt {
64	return func(opts *listenerOpts) {
65		opts.tcpKeepAlive = keepalive
66	}
67}
68
69type connTrackListener struct {
70	net.Listener
71	opts *listenerOpts
72}
73
74// NewListener returns the given listener wrapped in connection tracking listener.
75func NewListener(inner net.Listener, optFuncs ...listenerOpt) net.Listener {
76	opts := &listenerOpts{
77		name:       defaultName,
78		monitoring: true,
79		tracing:    false,
80	}
81	for _, f := range optFuncs {
82		f(opts)
83	}
84	if opts.monitoring {
85		preRegisterListenerMetrics(opts.name)
86	}
87	return &connTrackListener{
88		Listener: inner,
89		opts:     opts,
90	}
91}
92
93func (ct *connTrackListener) Accept() (net.Conn, error) {
94	// TODO(mwitkow): Add monitoring of failed accept.
95	var (
96		conn net.Conn
97		err  error
98	)
99	for attempt := 0; ; attempt++ {
100		conn, err = ct.Listener.Accept()
101		if err == nil || ct.opts.retryBackoff == nil {
102			break
103		}
104		if t, ok := err.(interface{ Temporary() bool }); !ok || !t.Temporary() {
105			break
106		}
107		time.Sleep(ct.opts.retryBackoff.ForAttempt(float64(attempt)))
108	}
109	if err != nil {
110		return nil, err
111	}
112	if tcpConn, ok := conn.(*net.TCPConn); ok && ct.opts.tcpKeepAlive > 0 {
113		tcpConn.SetKeepAlive(true)
114		tcpConn.SetKeepAlivePeriod(ct.opts.tcpKeepAlive)
115	}
116	return newServerConnTracker(conn, ct.opts), nil
117}
118
119type serverConnTracker struct {
120	net.Conn
121	opts  *listenerOpts
122	event trace.EventLog
123	mu    sync.Mutex
124}
125
126func newServerConnTracker(inner net.Conn, opts *listenerOpts) net.Conn {
127	tracker := &serverConnTracker{
128		Conn: inner,
129		opts: opts,
130	}
131	if opts.tracing {
132		tracker.event = trace.NewEventLog(fmt.Sprintf("net.ServerConn.%s", opts.name), fmt.Sprintf("%v", inner.RemoteAddr()))
133		tracker.event.Printf("accepted: %v -> %v", inner.RemoteAddr(), inner.LocalAddr())
134	}
135	if opts.monitoring {
136		reportListenerConnAccepted(opts.name)
137	}
138	return tracker
139}
140
141func (ct *serverConnTracker) Close() error {
142	err := ct.Conn.Close()
143	ct.mu.Lock()
144	if ct.event != nil {
145		if err != nil {
146			ct.event.Errorf("failed closing: %v", err)
147		} else {
148			ct.event.Printf("closing")
149		}
150		ct.event.Finish()
151		ct.event = nil
152	}
153	ct.mu.Unlock()
154	if ct.opts.monitoring {
155		reportListenerConnClosed(ct.opts.name)
156	}
157	return err
158}
159