1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package fakenet
6
7import (
8	"io"
9	"net"
10	"sync"
11	"time"
12)
13
14// NewConn returns a net.Conn built on top of the supplied reader and writer.
15// It decouples the read and write on the conn from the underlying stream
16// to enable Close to abort ones that are in progress.
17// It's primary use is to fake a network connection from stdin and stdout.
18func NewConn(name string, in io.ReadCloser, out io.WriteCloser) net.Conn {
19	c := &fakeConn{
20		name:   name,
21		reader: newFeeder(in.Read),
22		writer: newFeeder(out.Write),
23		in:     in,
24		out:    out,
25	}
26	go c.reader.run()
27	go c.writer.run()
28	return c
29}
30
31type fakeConn struct {
32	name   string
33	reader *connFeeder
34	writer *connFeeder
35	in     io.ReadCloser
36	out    io.WriteCloser
37}
38
39type fakeAddr string
40
type (
	// connFeeder funnels calls to source (an io.Reader.Read or io.Writer.Write
	// method value) through channels to a single worker goroutine, run, so at
	// most one source call is in flight at a time. Because callers only talk to
	// the worker over channels, a pending call can be abandoned as soon as the
	// feeder is closed, even if source itself is still blocked.
	connFeeder struct {
		source func([]byte) (int, error) // the wrapped Read or Write
		input  chan []byte               // buffers handed from do to run
		result chan feedResult           // outcomes handed from run back to do
		mu     sync.Mutex                // guards closed
		closed bool                      // set once done has been closed
		done   chan struct{}             // closed to signal shutdown
	}

	// feedResult carries the return values of one source call.
	feedResult struct {
		n   int
		err error
	}
)
58
59func (c *fakeConn) Close() error {
60	c.reader.close()
61	c.writer.close()
62	c.in.Close()
63	c.out.Close()
64	return nil
65}
66
67func (c *fakeConn) Read(b []byte) (n int, err error)   { return c.reader.do(b) }
68func (c *fakeConn) Write(b []byte) (n int, err error)  { return c.writer.do(b) }
69func (c *fakeConn) LocalAddr() net.Addr                { return fakeAddr(c.name) }
70func (c *fakeConn) RemoteAddr() net.Addr               { return fakeAddr(c.name) }
71func (c *fakeConn) SetDeadline(t time.Time) error      { return nil }
72func (c *fakeConn) SetReadDeadline(t time.Time) error  { return nil }
73func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
74func (a fakeAddr) Network() string                     { return "fake" }
75func (a fakeAddr) String() string                      { return string(a) }
76
77func newFeeder(source func([]byte) (int, error)) *connFeeder {
78	return &connFeeder{
79		source: source,
80		input:  make(chan []byte),
81		result: make(chan feedResult),
82		done:   make(chan struct{}),
83	}
84}
85
86func (f *connFeeder) close() {
87	f.mu.Lock()
88	if !f.closed {
89		f.closed = true
90		close(f.done)
91	}
92	f.mu.Unlock()
93}
94
95func (f *connFeeder) do(b []byte) (n int, err error) {
96	// send the request to the worker
97	select {
98	case f.input <- b:
99	case <-f.done:
100		return 0, io.EOF
101	}
102	// get the result from the worker
103	select {
104	case r := <-f.result:
105		return r.n, r.err
106	case <-f.done:
107		return 0, io.EOF
108	}
109}
110
111func (f *connFeeder) run() {
112	var b []byte
113	for {
114		// wait for an input request
115		select {
116		case b = <-f.input:
117		case <-f.done:
118			return
119		}
120		// invoke the underlying method
121		n, err := f.source(b)
122		// send the result back to the requester
123		select {
124		case f.result <- feedResult{n: n, err: err}:
125		case <-f.done:
126			return
127		}
128	}
129}
130