1/*
2 *
3 * Copyright 2017 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package bufconn provides a net.Conn implemented by a buffer and related
20// dialing and listening functionality.
21package bufconn
22
23import (
24	"fmt"
25	"io"
26	"net"
27	"sync"
28	"time"
29)
30
// Listener implements a net.Listener that creates local, buffered net.Conns
// via its Accept and Dial methods.
//
// Fields:
//   - mu serializes Close so that done is closed at most once.
//   - sz is the requested capacity, in bytes, of each pipe created by Dial.
//   - ch hands the server half of a new connection from Dial to Accept; Listen
//     creates it unbuffered, so Dial and Accept rendezvous.
//   - done is closed by Close to make pending and future Accept and Dial
//     calls fail.
type Listener struct {
	mu   sync.Mutex
	sz   int
	ch   chan net.Conn
	done chan struct{}
}
39
// netErrorTimeout wraps an error so that it satisfies net.Error and reports
// itself as a timeout. It never reports itself as temporary.
type netErrorTimeout struct {
	error
}

// Timeout always reports true.
func (netErrorTimeout) Timeout() bool {
	return true
}

// Temporary always reports false.
func (netErrorTimeout) Temporary() bool {
	return false
}
47
48var errClosed = fmt.Errorf("closed")
49var errTimeout net.Error = netErrorTimeout{error: fmt.Errorf("i/o timeout")}
50
51// Listen returns a Listener that can only be contacted by its own Dialers and
52// creates buffered connections between the two.
53func Listen(sz int) *Listener {
54	return &Listener{sz: sz, ch: make(chan net.Conn), done: make(chan struct{})}
55}
56
57// Accept blocks until Dial is called, then returns a net.Conn for the server
58// half of the connection.
59func (l *Listener) Accept() (net.Conn, error) {
60	select {
61	case <-l.done:
62		return nil, errClosed
63	case c := <-l.ch:
64		return c, nil
65	}
66}
67
68// Close stops the listener.
69func (l *Listener) Close() error {
70	l.mu.Lock()
71	defer l.mu.Unlock()
72	select {
73	case <-l.done:
74		// Already closed.
75		break
76	default:
77		close(l.done)
78	}
79	return nil
80}
81
82// Addr reports the address of the listener.
83func (l *Listener) Addr() net.Addr { return addr{} }
84
85// Dial creates an in-memory full-duplex network connection, unblocks Accept by
86// providing it the server half of the connection, and returns the client half
87// of the connection.
88func (l *Listener) Dial() (net.Conn, error) {
89	p1, p2 := newPipe(l.sz), newPipe(l.sz)
90	select {
91	case <-l.done:
92		return nil, errClosed
93	case l.ch <- &conn{p1, p2}:
94		return &conn{p2, p1}, nil
95	}
96}
97
// pipe is a one-directional, in-memory byte stream backed by a fixed-capacity
// ring buffer. Its mutable state is read and written with mu held; newPipe
// makes mu the lock of both wwait and rwait.
type pipe struct {
	mu sync.Mutex

	// buf contains the data in the pipe. It is a ring buffer of fixed capacity,
	// with r and w pointing to the offset to read and write, respectively.
	//
	// Data is read between [r, w) and written to [w, r), wrapping around the end
	// of the slice if necessary.
	//
	// The buffer is empty if r == len(buf), otherwise if r == w, it is full.
	//
	// w is always in the range [0, cap(buf)) and r in the range [0, len(buf)].
	buf  []byte
	w, r int

	// Writers wait on wwait while the pipe is full; readers wait on rwait
	// while it is empty. Both are broadcast on close, half-close and timeout.
	wwait sync.Cond
	rwait sync.Cond

	// Indicate that a write/read timeout has occurred. Set by the deadline
	// timers and cleared whenever a new deadline is set.
	wtimedout bool
	rtimedout bool

	// Timers for the current write/read deadline.
	wtimer *time.Timer
	rtimer *time.Timer

	// closed: both Read and Write fail with io.ErrClosedPipe.
	// writeClosed: Write fails, and Read returns io.EOF once the pipe is drained.
	closed      bool
	writeClosed bool
}
126
127func newPipe(sz int) *pipe {
128	p := &pipe{buf: make([]byte, 0, sz)}
129	p.wwait.L = &p.mu
130	p.rwait.L = &p.mu
131
132	p.wtimer = time.AfterFunc(0, func() {})
133	p.rtimer = time.AfterFunc(0, func() {})
134	return p
135}
136
137func (p *pipe) empty() bool {
138	return p.r == len(p.buf)
139}
140
141func (p *pipe) full() bool {
142	return p.r < len(p.buf) && p.r == p.w
143}
144
145func (p *pipe) Read(b []byte) (n int, err error) {
146	p.mu.Lock()
147	defer p.mu.Unlock()
148	// Block until p has data.
149	for {
150		if p.closed {
151			return 0, io.ErrClosedPipe
152		}
153		if !p.empty() {
154			break
155		}
156		if p.writeClosed {
157			return 0, io.EOF
158		}
159		if p.rtimedout {
160			return 0, errTimeout
161		}
162
163		p.rwait.Wait()
164	}
165	wasFull := p.full()
166
167	n = copy(b, p.buf[p.r:len(p.buf)])
168	p.r += n
169	if p.r == cap(p.buf) {
170		p.r = 0
171		p.buf = p.buf[:p.w]
172	}
173
174	// Signal a blocked writer, if any
175	if wasFull {
176		p.wwait.Signal()
177	}
178
179	return n, nil
180}
181
182func (p *pipe) Write(b []byte) (n int, err error) {
183	p.mu.Lock()
184	defer p.mu.Unlock()
185	if p.closed {
186		return 0, io.ErrClosedPipe
187	}
188	for len(b) > 0 {
189		// Block until p is not full.
190		for {
191			if p.closed || p.writeClosed {
192				return 0, io.ErrClosedPipe
193			}
194			if !p.full() {
195				break
196			}
197			if p.wtimedout {
198				return 0, errTimeout
199			}
200
201			p.wwait.Wait()
202		}
203		wasEmpty := p.empty()
204
205		end := cap(p.buf)
206		if p.w < p.r {
207			end = p.r
208		}
209		x := copy(p.buf[p.w:end], b)
210		b = b[x:]
211		n += x
212		p.w += x
213		if p.w > len(p.buf) {
214			p.buf = p.buf[:p.w]
215		}
216		if p.w == cap(p.buf) {
217			p.w = 0
218		}
219
220		// Signal a blocked reader, if any.
221		if wasEmpty {
222			p.rwait.Signal()
223		}
224	}
225	return n, nil
226}
227
228func (p *pipe) Close() error {
229	p.mu.Lock()
230	defer p.mu.Unlock()
231	p.closed = true
232	// Signal all blocked readers and writers to return an error.
233	p.rwait.Broadcast()
234	p.wwait.Broadcast()
235	return nil
236}
237
238func (p *pipe) closeWrite() error {
239	p.mu.Lock()
240	defer p.mu.Unlock()
241	p.writeClosed = true
242	// Signal all blocked readers and writers to return an error.
243	p.rwait.Broadcast()
244	p.wwait.Broadcast()
245	return nil
246}
247
// conn is one end of a bufconn connection. Read is promoted from the embedded
// Reader and Write from the embedded Writer. Listener.Dial builds both ends
// from the same two pipes with the roles swapped. Both fields must hold *pipe
// values, because Close and the deadline setters type-assert them.
type conn struct {
	io.Reader
	io.Writer
}
252
253func (c *conn) Close() error {
254	err1 := c.Reader.(*pipe).Close()
255	err2 := c.Writer.(*pipe).closeWrite()
256	if err1 != nil {
257		return err1
258	}
259	return err2
260}
261
262func (c *conn) SetDeadline(t time.Time) error {
263	c.SetReadDeadline(t)
264	c.SetWriteDeadline(t)
265	return nil
266}
267
268func (c *conn) SetReadDeadline(t time.Time) error {
269	p := c.Reader.(*pipe)
270	p.mu.Lock()
271	defer p.mu.Unlock()
272	p.rtimer.Stop()
273	p.rtimedout = false
274	if !t.IsZero() {
275		p.rtimer = time.AfterFunc(time.Until(t), func() {
276			p.mu.Lock()
277			defer p.mu.Unlock()
278			p.rtimedout = true
279			p.rwait.Broadcast()
280		})
281	}
282	return nil
283}
284
285func (c *conn) SetWriteDeadline(t time.Time) error {
286	p := c.Writer.(*pipe)
287	p.mu.Lock()
288	defer p.mu.Unlock()
289	p.wtimer.Stop()
290	p.wtimedout = false
291	if !t.IsZero() {
292		p.wtimer = time.AfterFunc(time.Until(t), func() {
293			p.mu.Lock()
294			defer p.mu.Unlock()
295			p.wtimedout = true
296			p.wwait.Broadcast()
297		})
298	}
299	return nil
300}
301
302func (*conn) LocalAddr() net.Addr  { return addr{} }
303func (*conn) RemoteAddr() net.Addr { return addr{} }
304
// addr is the net.Addr reported by a Listener and by both ends of every
// bufconn connection.
type addr struct{}

// Network returns the network name, "bufconn".
func (addr) Network() string {
	return "bufconn"
}

// String returns the address in string form, "bufconn".
func (addr) String() string {
	return "bufconn"
}
309