1package proxy
2
3import (
4	"context"
5	"crypto/tls"
6	"errors"
7	"net"
8	"sync"
9	"sync/atomic"
10	"time"
11
12	metrics "github.com/armon/go-metrics"
13	"github.com/hashicorp/consul/api"
14	"github.com/hashicorp/consul/connect"
15	"github.com/hashicorp/consul/ipaddr"
16	"github.com/hashicorp/go-hclog"
17)
18
const (
	// Prefixes used both as the sub-logger name and as the leading metric key
	// segment for the two listener flavors (see logger.Named / metricPrefix
	// in the constructors below).
	publicListenerPrefix   = "inbound"
	upstreamListenerPrefix = "upstream"
)
23
// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it.
type Listener struct {
	// Service is the connect service instance to use.
	Service *connect.Service

	// listenFunc, dialFunc, and bindAddr are set by type-specific constructors.
	// listenFunc opens the accepting socket (plain TCP or mTLS), dialFunc
	// establishes the outbound leg for each accepted connection, and bindAddr
	// is the formatted host:port the listener binds to.
	listenFunc func() (net.Listener, error)
	dialFunc   func() (net.Conn, error)
	bindAddr   string

	// stopFlag transitions 0 -> 1 exactly once in Close (atomic); stopChan is
	// closed at the same time to broadcast shutdown to Serve and handleConn.
	stopFlag int32
	stopChan chan struct{}

	// listeningChan is closed when listener is opened successfully. It's really
	// only for use in tests where we need to coordinate wait for the Serve
	// goroutine to be running before we proceed trying to connect. On my laptop
	// this always works out anyway but on constrained VMs and especially docker
	// containers (e.g. in CI) we often see the Dial routine win the race and get
	// `connection refused`. Retry loops and sleeps are unpleasant workarounds and
	// this is cheap and correct.
	listeningChan chan struct{}

	// logger is a sub-logger named after the listener prefix (inbound/upstream).
	logger hclog.Logger

	// Gauge to track current open connections
	activeConns  int32        // updated atomically via trackConn
	connWG       sync.WaitGroup // Close waits on this for active conns to finish
	metricPrefix string
	metricLabels []metrics.Label
}
56
57// NewPublicListener returns a Listener setup to listen for public mTLS
58// connections and proxy them to the configured local application over TCP.
59func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
60	logger hclog.Logger) *Listener {
61	bindAddr := ipaddr.FormatAddressPort(cfg.BindAddress, cfg.BindPort)
62	return &Listener{
63		Service: svc,
64		listenFunc: func() (net.Listener, error) {
65			return tls.Listen("tcp", bindAddr, svc.ServerTLSConfig())
66		},
67		dialFunc: func() (net.Conn, error) {
68			return net.DialTimeout("tcp", cfg.LocalServiceAddress,
69				time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
70		},
71		bindAddr:      bindAddr,
72		stopChan:      make(chan struct{}),
73		listeningChan: make(chan struct{}),
74		logger:        logger.Named(publicListenerPrefix),
75		metricPrefix:  publicListenerPrefix,
76		// For now we only label ourselves as source - we could fetch the src
77		// service from cert on each connection and label metrics differently but it
78		// significaly complicates the active connection tracking here and it's not
79		// clear that it's very valuable - on aggregate looking at all _outbound_
80		// connections across all proxies gets you a full picture of src->dst
81		// traffic. We might expand this later for better debugging of which clients
82		// are abusing a particular service instance but we'll see how valuable that
83		// seems for the extra complication of tracking many gauges here.
84		metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}},
85	}
86}
87
88// NewUpstreamListener returns a Listener setup to listen locally for TCP
89// connections that are proxied to a discovered Connect service instance.
90func NewUpstreamListener(svc *connect.Service, client *api.Client,
91	cfg UpstreamConfig, logger hclog.Logger) *Listener {
92	return newUpstreamListenerWithResolver(svc, cfg,
93		UpstreamResolverFuncFromClient(client), logger)
94}
95
96func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig,
97	resolverFunc func(UpstreamConfig) (connect.Resolver, error),
98	logger hclog.Logger) *Listener {
99	bindAddr := ipaddr.FormatAddressPort(cfg.LocalBindAddress, cfg.LocalBindPort)
100	return &Listener{
101		Service: svc,
102		listenFunc: func() (net.Listener, error) {
103			return net.Listen("tcp", bindAddr)
104		},
105		dialFunc: func() (net.Conn, error) {
106			rf, err := resolverFunc(cfg)
107			if err != nil {
108				return nil, err
109			}
110			ctx, cancel := context.WithTimeout(context.Background(),
111				cfg.ConnectTimeout())
112			defer cancel()
113			return svc.Dial(ctx, rf)
114		},
115		bindAddr:      bindAddr,
116		stopChan:      make(chan struct{}),
117		listeningChan: make(chan struct{}),
118		logger:        logger.Named(upstreamListenerPrefix),
119		metricPrefix:  upstreamListenerPrefix,
120		metricLabels: []metrics.Label{
121			{Name: "src", Value: svc.Name()},
122			// TODO(banks): namespace support
123			{Name: "dst_type", Value: string(cfg.DestinationType)},
124			{Name: "dst", Value: cfg.DestinationName},
125		},
126	}
127}
128
129// Serve runs the listener until it is stopped. It is an error to call Serve
130// more than once for any given Listener instance.
131func (l *Listener) Serve() error {
132	// Ensure we mark state closed if we fail before Close is called externally.
133	defer l.Close()
134
135	if atomic.LoadInt32(&l.stopFlag) != 0 {
136		return errors.New("serve called on a closed listener")
137	}
138
139	listen, err := l.listenFunc()
140	if err != nil {
141		return err
142	}
143	close(l.listeningChan)
144
145	for {
146		conn, err := listen.Accept()
147		if err != nil {
148			if atomic.LoadInt32(&l.stopFlag) == 1 {
149				return nil
150			}
151			return err
152		}
153
154		go l.handleConn(conn)
155	}
156}
157
158// handleConn is the internal connection handler goroutine.
159func (l *Listener) handleConn(src net.Conn) {
160	defer src.Close()
161
162	dst, err := l.dialFunc()
163	if err != nil {
164		l.logger.Error("failed to dial", "error", err)
165		return
166	}
167
168	// Track active conn now (first function call) and defer un-counting it when
169	// it closes.
170	defer l.trackConn()()
171
172	// Make sure Close() waits for this conn to be cleaned up. Note defer is
173	// before conn.Close() so runs after defer conn.Close().
174	l.connWG.Add(1)
175	defer l.connWG.Done()
176
177	// Note no need to defer dst.Close() since conn handles that for us.
178	conn := NewConn(src, dst)
179	defer conn.Close()
180
181	connStop := make(chan struct{})
182
183	// Run another goroutine to copy the bytes.
184	go func() {
185		err = conn.CopyBytes()
186		if err != nil {
187			l.logger.Error("connection failed", "error", err)
188		}
189		close(connStop)
190	}()
191
192	// Periodically copy stats from conn to metrics (to keep metrics calls out of
193	// the path of every single packet copy). 5 seconds is probably good enough
194	// resolution - statsd and most others tend to summarize with lower resolution
195	// anyway and this amortizes the cost more.
196	var tx, rx uint64
197	statsT := time.NewTicker(5 * time.Second)
198	defer statsT.Stop()
199
200	reportStats := func() {
201		newTx, newRx := conn.Stats()
202		if delta := newTx - tx; delta > 0 {
203			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"},
204				float32(newTx-tx), l.metricLabels)
205		}
206		if delta := newRx - rx; delta > 0 {
207			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"},
208				float32(newRx-rx), l.metricLabels)
209		}
210		tx, rx = newTx, newRx
211	}
212	// Always report final stats for the conn.
213	defer reportStats()
214
215	// Wait for conn to close
216	for {
217		select {
218		case <-connStop:
219			return
220		case <-l.stopChan:
221			return
222		case <-statsT.C:
223			reportStats()
224		}
225	}
226}
227
228// trackConn increments the count of active conns and returns a func() that can
229// be deferred on to decrement the counter again on connection close.
230func (l *Listener) trackConn() func() {
231	c := atomic.AddInt32(&l.activeConns, 1)
232	metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
233		l.metricLabels)
234
235	return func() {
236		c := atomic.AddInt32(&l.activeConns, -1)
237		metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
238			l.metricLabels)
239	}
240}
241
242// Close terminates the listener and all active connections.
243func (l *Listener) Close() error {
244	oldFlag := atomic.SwapInt32(&l.stopFlag, 1)
245	if oldFlag == 0 {
246		close(l.stopChan)
247		// Wait for all conns to close
248		l.connWG.Wait()
249	}
250	return nil
251}
252
253// Wait for the listener to be ready to accept connections.
254func (l *Listener) Wait() {
255	<-l.listeningChan
256}
257
258// BindAddr returns the address the listen is bound to.
259func (l *Listener) BindAddr() string {
260	return l.bindAddr
261}
262