1/*
2 *
3 * Copyright 2017 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package bufconn provides a net.Conn implemented by a buffer and related
20// dialing and listening functionality.
21package bufconn
22
23import (
24	"context"
25	"fmt"
26	"io"
27	"net"
28	"sync"
29	"time"
30)
31
// Listener implements a net.Listener that creates local, buffered net.Conns
// via its Accept and Dial methods. It can only be contacted by its own
// Dial/DialContext methods; no real network address is involved.
type Listener struct {
	// sz is the buffer size passed to newPipe for each direction of every
	// connection created by DialContext.
	sz int

	// ch hands the server half of a freshly dialed connection to Accept.
	// Listen creates it unbuffered, so a Dial completes only when an Accept
	// receives the connection.
	ch chan net.Conn

	// done is closed by Close; Accept and DialContext select on it so they
	// fail with errClosed once the listener is shut down.
	done chan struct{}

	// mu serializes Close so that done is closed at most once.
	mu sync.Mutex
}
40
// netErrorTimeout wraps an error so that it satisfies net.Error and always
// reports itself as a timeout that is not temporary.
type netErrorTimeout struct {
	error
}

// Timeout always reports true.
func (netErrorTimeout) Timeout() bool { return true }

// Temporary always reports false.
func (netErrorTimeout) Temporary() bool { return false }

var (
	// errClosed is returned by Accept and DialContext once the Listener has
	// been closed.
	errClosed = fmt.Errorf("closed")

	// errTimeout is returned by pipe reads and writes that cannot make
	// progress after their deadline has fired.
	errTimeout net.Error = netErrorTimeout{error: fmt.Errorf("i/o timeout")}
)
51
52// Listen returns a Listener that can only be contacted by its own Dialers and
53// creates buffered connections between the two.
54func Listen(sz int) *Listener {
55	return &Listener{sz: sz, ch: make(chan net.Conn), done: make(chan struct{})}
56}
57
58// Accept blocks until Dial is called, then returns a net.Conn for the server
59// half of the connection.
60func (l *Listener) Accept() (net.Conn, error) {
61	select {
62	case <-l.done:
63		return nil, errClosed
64	case c := <-l.ch:
65		return c, nil
66	}
67}
68
69// Close stops the listener.
70func (l *Listener) Close() error {
71	l.mu.Lock()
72	defer l.mu.Unlock()
73	select {
74	case <-l.done:
75		// Already closed.
76		break
77	default:
78		close(l.done)
79	}
80	return nil
81}
82
83// Addr reports the address of the listener.
84func (l *Listener) Addr() net.Addr { return addr{} }
85
86// Dial creates an in-memory full-duplex network connection, unblocks Accept by
87// providing it the server half of the connection, and returns the client half
88// of the connection.
89func (l *Listener) Dial() (net.Conn, error) {
90	return l.DialContext(context.Background())
91}
92
93// DialContext creates an in-memory full-duplex network connection, unblocks Accept by
94// providing it the server half of the connection, and returns the client half
95// of the connection.  If ctx is Done, returns ctx.Err()
96func (l *Listener) DialContext(ctx context.Context) (net.Conn, error) {
97	p1, p2 := newPipe(l.sz), newPipe(l.sz)
98	select {
99	case <-ctx.Done():
100		return nil, ctx.Err()
101	case <-l.done:
102		return nil, errClosed
103	case l.ch <- &conn{p1, p2}:
104		return &conn{p2, p1}, nil
105	}
106}
107
// pipe is a one-directional, in-memory byte stream backed by a fixed-capacity
// ring buffer. Once the pipe is in use, its fields are read and written with
// mu held.
type pipe struct {
	mu sync.Mutex

	// buf is the ring buffer; its capacity never changes, only its length.
	// r is the offset of the next byte to read, w the offset of the next byte
	// to write.
	//
	// Unread data occupies [r, w) and free space occupies [w, r), either range
	// wrapping around the end of the slice if necessary.
	//
	// The buffer is empty when r == len(buf); otherwise it is full when r == w.
	// w stays within [0, cap(buf)) and r within [0, len(buf)].
	buf  []byte
	r, w int

	// rwait and wwait are the condition variables that readers (waiting for
	// data) and writers (waiting for space) block on.
	rwait sync.Cond
	wwait sync.Cond

	// rtimedout and wtimedout record that the read or write deadline fired.
	rtimedout bool
	wtimedout bool

	// rtimer and wtimer drive the current read and write deadlines.
	rtimer *time.Timer
	wtimer *time.Timer

	// closed is set by Close and makes all further reads and writes fail.
	// writeClosed is set by closeWrite: reads drain what is left and then see
	// io.EOF, while writes fail.
	closed      bool
	writeClosed bool
}
136
137func newPipe(sz int) *pipe {
138	p := &pipe{buf: make([]byte, 0, sz)}
139	p.wwait.L = &p.mu
140	p.rwait.L = &p.mu
141
142	p.wtimer = time.AfterFunc(0, func() {})
143	p.rtimer = time.AfterFunc(0, func() {})
144	return p
145}
146
147func (p *pipe) empty() bool {
148	return p.r == len(p.buf)
149}
150
151func (p *pipe) full() bool {
152	return p.r < len(p.buf) && p.r == p.w
153}
154
155func (p *pipe) Read(b []byte) (n int, err error) {
156	p.mu.Lock()
157	defer p.mu.Unlock()
158	// Block until p has data.
159	for {
160		if p.closed {
161			return 0, io.ErrClosedPipe
162		}
163		if !p.empty() {
164			break
165		}
166		if p.writeClosed {
167			return 0, io.EOF
168		}
169		if p.rtimedout {
170			return 0, errTimeout
171		}
172
173		p.rwait.Wait()
174	}
175	wasFull := p.full()
176
177	n = copy(b, p.buf[p.r:len(p.buf)])
178	p.r += n
179	if p.r == cap(p.buf) {
180		p.r = 0
181		p.buf = p.buf[:p.w]
182	}
183
184	// Signal a blocked writer, if any
185	if wasFull {
186		p.wwait.Signal()
187	}
188
189	return n, nil
190}
191
192func (p *pipe) Write(b []byte) (n int, err error) {
193	p.mu.Lock()
194	defer p.mu.Unlock()
195	if p.closed {
196		return 0, io.ErrClosedPipe
197	}
198	for len(b) > 0 {
199		// Block until p is not full.
200		for {
201			if p.closed || p.writeClosed {
202				return 0, io.ErrClosedPipe
203			}
204			if !p.full() {
205				break
206			}
207			if p.wtimedout {
208				return 0, errTimeout
209			}
210
211			p.wwait.Wait()
212		}
213		wasEmpty := p.empty()
214
215		end := cap(p.buf)
216		if p.w < p.r {
217			end = p.r
218		}
219		x := copy(p.buf[p.w:end], b)
220		b = b[x:]
221		n += x
222		p.w += x
223		if p.w > len(p.buf) {
224			p.buf = p.buf[:p.w]
225		}
226		if p.w == cap(p.buf) {
227			p.w = 0
228		}
229
230		// Signal a blocked reader, if any.
231		if wasEmpty {
232			p.rwait.Signal()
233		}
234	}
235	return n, nil
236}
237
238func (p *pipe) Close() error {
239	p.mu.Lock()
240	defer p.mu.Unlock()
241	p.closed = true
242	// Signal all blocked readers and writers to return an error.
243	p.rwait.Broadcast()
244	p.wwait.Broadcast()
245	return nil
246}
247
248func (p *pipe) closeWrite() error {
249	p.mu.Lock()
250	defer p.mu.Unlock()
251	p.writeClosed = true
252	// Signal all blocked readers and writers to return an error.
253	p.rwait.Broadcast()
254	p.wwait.Broadcast()
255	return nil
256}
257
// conn is one end of a bufconn connection. It reads from one pipe and writes
// to another. DialContext gives the two ends the same two pipes in swapped
// roles, so bytes written on one end are read on the other. Both embedded
// values are *pipe; the methods below type-assert them. Reader comes first,
// then Writer.
type conn struct {
	io.Reader
	io.Writer
}
262
263func (c *conn) Close() error {
264	err1 := c.Reader.(*pipe).Close()
265	err2 := c.Writer.(*pipe).closeWrite()
266	if err1 != nil {
267		return err1
268	}
269	return err2
270}
271
272func (c *conn) SetDeadline(t time.Time) error {
273	c.SetReadDeadline(t)
274	c.SetWriteDeadline(t)
275	return nil
276}
277
278func (c *conn) SetReadDeadline(t time.Time) error {
279	p := c.Reader.(*pipe)
280	p.mu.Lock()
281	defer p.mu.Unlock()
282	p.rtimer.Stop()
283	p.rtimedout = false
284	if !t.IsZero() {
285		p.rtimer = time.AfterFunc(time.Until(t), func() {
286			p.mu.Lock()
287			defer p.mu.Unlock()
288			p.rtimedout = true
289			p.rwait.Broadcast()
290		})
291	}
292	return nil
293}
294
295func (c *conn) SetWriteDeadline(t time.Time) error {
296	p := c.Writer.(*pipe)
297	p.mu.Lock()
298	defer p.mu.Unlock()
299	p.wtimer.Stop()
300	p.wtimedout = false
301	if !t.IsZero() {
302		p.wtimer = time.AfterFunc(time.Until(t), func() {
303			p.mu.Lock()
304			defer p.mu.Unlock()
305			p.wtimedout = true
306			p.wwait.Broadcast()
307		})
308	}
309	return nil
310}
311
312func (*conn) LocalAddr() net.Addr  { return addr{} }
313func (*conn) RemoteAddr() net.Addr { return addr{} }
314
// addr is the placeholder net.Addr used by bufconn listeners and
// connections. It carries no state, and both Network and String report
// "bufconn".
type addr struct{}

// Network implements net.Addr.
func (addr) Network() string {
	return "bufconn"
}

// String implements net.Addr and returns the same text as Network.
func (a addr) String() string {
	return a.Network()
}
319