1package nomad
2
3import (
4	"context"
5	"fmt"
6	"net"
7	"sync"
8	"time"
9
10	"github.com/hashicorp/nomad/helper/pool"
11	"github.com/hashicorp/nomad/helper/tlsutil"
12	"github.com/hashicorp/raft"
13)
14
// RaftLayer implements the raft.StreamLayer interface,
// so that we can use a single RPC layer for Raft and Nomad.
// Inbound Raft connections are delivered via Handoff and consumed via
// Accept; outbound connections are created by Dial.
type RaftLayer struct {
	// addr is the listener address reported by Addr.
	addr net.Addr

	// connCh carries connections from Handoff to Accept. NewRaftLayer
	// creates it unbuffered, so a Handoff blocks until an Accept takes it.
	connCh chan net.Conn

	// tlsWrap, when non-nil, makes Dial switch outbound connections into
	// TLS mode. It can be swapped at runtime by ReloadTLS, so all access
	// goes through tlsWrapLock.
	tlsWrap     tlsutil.Wrapper
	tlsWrapLock sync.RWMutex

	// closed records whether Close has already run; closeCh is closed at
	// that moment to wake any Handoff/Accept blocked on it. closeLock
	// makes the check-and-close in Close happen at most once.
	closed    bool
	closeCh   chan struct{}
	closeLock sync.Mutex
}
33
34// NewRaftLayer is used to initialize a new RaftLayer which can
35// be used as a StreamLayer for Raft. If a tlsConfig is provided,
36// then the connection will use TLS.
37func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
38	layer := &RaftLayer{
39		addr:    addr,
40		connCh:  make(chan net.Conn),
41		tlsWrap: tlsWrap,
42		closeCh: make(chan struct{}),
43	}
44	return layer
45}
46
47// Handoff is used to hand off a connection to the
48// RaftLayer. This allows it to be Accept()'ed
49func (l *RaftLayer) Handoff(ctx context.Context, c net.Conn) error {
50	select {
51	case l.connCh <- c:
52		return nil
53	case <-l.closeCh:
54		return fmt.Errorf("Raft RPC layer closed")
55	case <-ctx.Done():
56		return nil
57	}
58}
59
60// Accept is used to return connection which are
61// dialed to be used with the Raft layer
62func (l *RaftLayer) Accept() (net.Conn, error) {
63	select {
64	case conn := <-l.connCh:
65		return conn, nil
66	case <-l.closeCh:
67		return nil, fmt.Errorf("Raft RPC layer closed")
68	}
69}
70
71// Close is used to stop listening for Raft connections
72func (l *RaftLayer) Close() error {
73	l.closeLock.Lock()
74	defer l.closeLock.Unlock()
75
76	if !l.closed {
77		l.closed = true
78		close(l.closeCh)
79	}
80	return nil
81}
82
83// getTLSWrapper is used to retrieve the current TLS wrapper
84func (l *RaftLayer) getTLSWrapper() tlsutil.Wrapper {
85	l.tlsWrapLock.RLock()
86	defer l.tlsWrapLock.RUnlock()
87	return l.tlsWrap
88}
89
90// ReloadTLS swaps the TLS wrapper. This is useful when upgrading or
91// downgrading TLS connections.
92func (l *RaftLayer) ReloadTLS(tlsWrap tlsutil.Wrapper) {
93	l.tlsWrapLock.Lock()
94	defer l.tlsWrapLock.Unlock()
95	l.tlsWrap = tlsWrap
96}
97
98// Addr is used to return the address of the listener
99func (l *RaftLayer) Addr() net.Addr {
100	return l.addr
101}
102
103// Dial is used to create a new outgoing connection
104func (l *RaftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
105	conn, err := net.DialTimeout("tcp", string(address), timeout)
106	if err != nil {
107		return nil, err
108	}
109
110	tlsWrapper := l.getTLSWrapper()
111
112	// Check for tls mode
113	if tlsWrapper != nil {
114		// Switch the connection into TLS mode
115		if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
116			conn.Close()
117			return nil, err
118		}
119
120		// Wrap the connection in a TLS client
121		conn, err = tlsWrapper(conn)
122		if err != nil {
123			return nil, err
124		}
125	}
126
127	// Write the Raft byte to set the mode
128	_, err = conn.Write([]byte{byte(pool.RpcRaft)})
129	if err != nil {
130		conn.Close()
131		return nil, err
132	}
133	return conn, err
134}
135