1// Copyright 2016 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package integration
16
17import (
18	"fmt"
19	"io"
20	"io/ioutil"
21	"net"
22	"sync"
23
24	"github.com/coreos/etcd/pkg/transport"
25)
26
27// bridge creates a unix socket bridge to another unix socket, making it possible
28// to disconnect grpc network connections without closing the logical grpc connection.
29type bridge struct {
30	inaddr  string
31	outaddr string
32	l       net.Listener
33	conns   map[*bridgeConn]struct{}
34
35	stopc      chan struct{}
36	pausec     chan struct{}
37	blackholec chan struct{}
38	wg         sync.WaitGroup
39
40	mu sync.Mutex
41}
42
43func newBridge(addr string) (*bridge, error) {
44	b := &bridge{
45		// bridge "port" is ("%05d%05d0", port, pid) since go1.8 expects the port to be a number
46		inaddr:     addr + "0",
47		outaddr:    addr,
48		conns:      make(map[*bridgeConn]struct{}),
49		stopc:      make(chan struct{}),
50		pausec:     make(chan struct{}),
51		blackholec: make(chan struct{}),
52	}
53	close(b.pausec)
54
55	l, err := transport.NewUnixListener(b.inaddr)
56	if err != nil {
57		return nil, fmt.Errorf("listen failed on socket %s (%v)", addr, err)
58	}
59	b.l = l
60	b.wg.Add(1)
61	go b.serveListen()
62	return b, nil
63}
64
65func (b *bridge) URL() string { return "unix://" + b.inaddr }
66
67func (b *bridge) Close() {
68	b.l.Close()
69	b.mu.Lock()
70	select {
71	case <-b.stopc:
72	default:
73		close(b.stopc)
74	}
75	b.mu.Unlock()
76	b.wg.Wait()
77}
78
79func (b *bridge) Reset() {
80	b.mu.Lock()
81	defer b.mu.Unlock()
82	for bc := range b.conns {
83		bc.Close()
84	}
85	b.conns = make(map[*bridgeConn]struct{})
86}
87
88func (b *bridge) Pause() {
89	b.mu.Lock()
90	b.pausec = make(chan struct{})
91	b.mu.Unlock()
92}
93
94func (b *bridge) Unpause() {
95	b.mu.Lock()
96	select {
97	case <-b.pausec:
98	default:
99		close(b.pausec)
100	}
101	b.mu.Unlock()
102}
103
104func (b *bridge) serveListen() {
105	defer func() {
106		b.l.Close()
107		b.mu.Lock()
108		for bc := range b.conns {
109			bc.Close()
110		}
111		b.mu.Unlock()
112		b.wg.Done()
113	}()
114
115	for {
116		inc, ierr := b.l.Accept()
117		if ierr != nil {
118			return
119		}
120		b.mu.Lock()
121		pausec := b.pausec
122		b.mu.Unlock()
123		select {
124		case <-b.stopc:
125			inc.Close()
126			return
127		case <-pausec:
128		}
129
130		outc, oerr := net.Dial("unix", b.outaddr)
131		if oerr != nil {
132			inc.Close()
133			return
134		}
135
136		bc := &bridgeConn{inc, outc, make(chan struct{})}
137		b.wg.Add(1)
138		b.mu.Lock()
139		b.conns[bc] = struct{}{}
140		go b.serveConn(bc)
141		b.mu.Unlock()
142	}
143}
144
145func (b *bridge) serveConn(bc *bridgeConn) {
146	defer func() {
147		close(bc.donec)
148		bc.Close()
149		b.mu.Lock()
150		delete(b.conns, bc)
151		b.mu.Unlock()
152		b.wg.Done()
153	}()
154
155	var wg sync.WaitGroup
156	wg.Add(2)
157	go func() {
158		b.ioCopy(bc, bc.out, bc.in)
159		bc.close()
160		wg.Done()
161	}()
162	go func() {
163		b.ioCopy(bc, bc.in, bc.out)
164		bc.close()
165		wg.Done()
166	}()
167	wg.Wait()
168}
169
// bridgeConn is one proxied connection: the inbound side accepted by the
// bridge paired with the outbound side dialed to the backend socket.
type bridgeConn struct {
	// in is the connection accepted by the bridge's listener; out is the
	// connection dialed to the bridge's outaddr.
	in, out net.Conn
	// donec is closed by serveConn once both copy directions have finished.
	donec chan struct{}
}
175
176func (bc *bridgeConn) Close() {
177	bc.close()
178	<-bc.donec
179}
180
181func (bc *bridgeConn) close() {
182	bc.in.Close()
183	bc.out.Close()
184}
185
186func (b *bridge) Blackhole() {
187	b.mu.Lock()
188	close(b.blackholec)
189	b.mu.Unlock()
190}
191
192func (b *bridge) Unblackhole() {
193	b.mu.Lock()
194	for bc := range b.conns {
195		bc.Close()
196	}
197	b.conns = make(map[*bridgeConn]struct{})
198	b.blackholec = make(chan struct{})
199	b.mu.Unlock()
200}
201
202// ref. https://github.com/golang/go/blob/master/src/io/io.go copyBuffer
203func (b *bridge) ioCopy(bc *bridgeConn, dst io.Writer, src io.Reader) (err error) {
204	buf := make([]byte, 32*1024)
205	for {
206		select {
207		case <-b.blackholec:
208			io.Copy(ioutil.Discard, src)
209			return nil
210		default:
211		}
212		nr, er := src.Read(buf)
213		if nr > 0 {
214			nw, ew := dst.Write(buf[0:nr])
215			if ew != nil {
216				return ew
217			}
218			if nr != nw {
219				return io.ErrShortWrite
220			}
221		}
222		if er != nil {
223			err = er
224			break
225		}
226	}
227	return err
228}
229