1package mocknet
2
3import (
4	"bytes"
5	"errors"
6	"io"
7	"net"
8	"strconv"
9	"sync"
10	"sync/atomic"
11	"time"
12
13	"github.com/libp2p/go-libp2p-core/mux"
14	"github.com/libp2p/go-libp2p-core/network"
15	protocol "github.com/libp2p/go-libp2p-core/protocol"
16)
17
// streamCounter is the source of stream IDs. newStream atomically increments
// it for every stream it builds, so the first stream receives ID 1.
var (
	streamCounter int64
)
19
20// stream implements network.Stream
21type stream struct {
22	notifLk sync.Mutex
23
24	rstream *stream
25	conn    *conn
26	id      int64
27
28	write     *io.PipeWriter
29	read      *io.PipeReader
30	toDeliver chan *transportObject
31
32	reset  chan struct{}
33	close  chan struct{}
34	closed chan struct{}
35
36	writeErr error
37
38	protocol atomic.Value
39	stat     network.Stat
40}
41
// ErrClosed is the error recorded when a stream's write side is closed
// normally. Write returns it once the stream has been closed, and CloseWrite
// treats it as a successful close rather than a failure. CloseRead also uses
// it to close the inbound pipe.
var ErrClosed = errors.New("stream closed")
43
44type transportObject struct {
45	msg         []byte
46	arrivalTime time.Time
47}
48
// newStreamPair returns two streams wired back to back: whatever the first
// (outbound) stream writes is read by the second (inbound) stream, and vice
// versa.
func newStreamPair() (*stream, *stream) {
	// First pipe: inbound writes, outbound reads.
	outR, inW := io.Pipe()
	// Second pipe: outbound writes, inbound reads.
	inR, outW := io.Pipe()

	out := newStream(outW, outR, network.DirOutbound)
	in := newStream(inW, inR, network.DirInbound)
	out.rstream, in.rstream = in, out
	return out, in
}
59
// newStream builds a stream that sends through w and receives from r, tags it
// with direction dir, assigns it the next ID from streamCounter, and starts
// its transport goroutine before returning it.
func newStream(w *io.PipeWriter, r *io.PipeReader, dir network.Direction) *stream {
	s := new(stream)
	s.read = r
	s.write = w
	s.id = atomic.AddInt64(&streamCounter, 1)
	// Capacity 1 lets Reset and CloseWrite leave their signal pending without
	// blocking (they post with a non-blocking select).
	s.reset = make(chan struct{}, 1)
	s.close = make(chan struct{}, 1)
	s.closed = make(chan struct{})
	s.toDeliver = make(chan *transportObject)
	s.stat = network.Stat{Direction: dir}

	go s.transport()
	return s
}
75
// Write queues a copy of p for the remote end and reports len(p) once the
// transport goroutine has accepted it. The copy is stamped with an arrival
// time of now plus the link's latency and rate-limit delay.
//
// Errors that occur after the hand-off cannot be reported by this call. Write
// only returns an error when the transport goroutine has already exited, in
// which case it returns that goroutine's terminal error (s.writeErr).
func (s *stream) Write(p []byte) (n int, err error) {
	link := s.conn.link
	wait := link.GetLatency() + link.RateLimit(len(p))
	due := time.Now().Add(wait)

	// The caller may reuse p as soon as we return, so take a private copy.
	payload := make([]byte, len(p))
	copy(payload, p)
	obj := &transportObject{msg: payload, arrivalTime: due}

	select {
	case s.toDeliver <- obj:
		return len(p), nil
	case <-s.closed:
		// Safe to read: writeErr is set before closed is closed.
		return 0, s.writeErr
	}
}
93
94func (s *stream) ID() string {
95	return strconv.FormatInt(s.id, 10)
96}
97
98func (s *stream) Protocol() protocol.ID {
99	// Ignore type error. It means that the protocol is unset.
100	p, _ := s.protocol.Load().(protocol.ID)
101	return p
102}
103
104func (s *stream) Stat() network.Stat {
105	return s.stat
106}
107
108func (s *stream) SetProtocol(proto protocol.ID) {
109	s.protocol.Store(proto)
110}
111
// CloseWrite asks the transport goroutine to flush any buffered data and
// close the write side, then blocks until that goroutine has exited.
//
// It returns nil when the write side ended with ErrClosed (a normal close).
// Otherwise it returns the error that terminated the transport goroutine,
// e.g. mux.ErrReset after Reset.
func (s *stream) CloseWrite() error {
	// Post the close signal without blocking; if one is already pending,
	// that is enough.
	select {
	case s.close <- struct{}{}:
	default:
	}
	<-s.closed
	// writeErr is safe to read here: it is set before s.closed is closed.
	// Compare against the sentinel with errors.Is rather than ==.
	if errors.Is(s.writeErr, ErrClosed) {
		return nil
	}
	return s.writeErr
}
123
124func (s *stream) CloseRead() error {
125	return s.read.CloseWithError(ErrClosed)
126}
127
128func (s *stream) Close() error {
129	_ = s.CloseRead()
130	return s.CloseWrite()
131}
132
133func (s *stream) Reset() error {
134	// Cancel any pending reads/writes with an error.
135	s.write.CloseWithError(mux.ErrReset)
136	s.read.CloseWithError(mux.ErrReset)
137
138	select {
139	case s.reset <- struct{}{}:
140	default:
141	}
142	<-s.closed
143
144	// No meaningful error case here.
145	return nil
146}
147
148func (s *stream) teardown() {
149	// at this point, no streams are writing.
150	s.conn.removeStream(s)
151
152	// Mark as closed.
153	close(s.closed)
154}
155
156func (s *stream) Conn() network.Conn {
157	return s.conn
158}
159
// SetDeadline is not supported by mock streams. It ignores t and always
// returns a *net.OpError saying so.
func (s *stream) SetDeadline(t time.Time) error {
	return &net.OpError{
		Op:     "set",
		Net:    "pipe",
		Source: nil,
		Addr:   nil,
		Err:    errors.New("deadline not supported"),
	}
}

// SetReadDeadline is not supported; it returns the same error as SetDeadline.
func (s *stream) SetReadDeadline(t time.Time) error {
	return s.SetDeadline(t)
}

// SetWriteDeadline is not supported; it returns the same error as SetDeadline.
func (s *stream) SetWriteDeadline(t time.Time) error {
	return s.SetDeadline(t)
}
171
172func (s *stream) Read(b []byte) (int, error) {
173	return s.read.Read(b)
174}
175
// transport is the per-stream delivery goroutine started by newStream. It
// receives timestamped payloads from Write via s.toDeliver and writes each one
// to s.write no earlier than its arrival time. Small payloads are coalesced in
// a buffer and flushed when the timer fires. The goroutine exits on reset, on
// close, or on the first pipe write error. On exit it records the terminal
// error in s.writeErr, and the deferred teardown then closes s.closed.
func (s *stream) transport() {
	defer s.teardown()

	// Payloads are buffered until the pending byte count reaches bufsize; at
	// that point delivery blocks until the arrival time instead.
	bufsize := 256
	buf := new(bytes.Buffer)
	// Start from a stopped timer with an empty channel so later Stop/Reset
	// calls begin from a clean state.
	timer := time.NewTimer(0)
	if !timer.Stop() {
		select {
		case <-timer.C:
		default:
		}
	}

	// Stop the timer when the goroutine exits.
	defer timer.Stop()

	// drainBuf writes any buffered bytes through to s.write and empties buf.
	// It is called when the timer fires, before a large payload is written,
	// and on close.
	drainBuf := func() error {
		if buf.Len() > 0 {
			_, err := s.write.Write(buf.Bytes())
			if err != nil {
				return err
			}
			buf.Reset()
		}
		return nil
	}

	// deliverOrWait handles one payload from Write. It re-arms the timer for
	// the payload's arrival time. If the pending bytes (buffer plus payload)
	// stay below bufsize, the payload is only buffered. Otherwise it waits for
	// the timer (or a reset), flushes the buffer, and writes the payload
	// directly.
	// NOTE(review): while waiting here, close signals are not observed until
	// the wait ends.
	deliverOrWait := func(o *transportObject) error {
		buffered := len(o.msg) + buf.Len()

		// Yes, we can end up extending a timer multiple times if we
		// keep on making small writes but that shouldn't be too much of an
		// issue. Fixing that would be painful.
		if !timer.Stop() {
			// Stop returns false if the timer already fired. If that tick
			// was never received it is still sitting in timer.C and would
			// end the next wait early, so drain it. This is the documented
			// time.Timer Stop/Reset pattern, not a Go bug. The default case
			// is needed because the main loop's `case <-timer.C` may already
			// have consumed the tick, and a blocking receive would hang.
			select {
			case <-timer.C:
			default:
			}
		}
		delay := time.Until(o.arrivalTime)
		if delay >= 0 {
			timer.Reset(delay)
		} else {
			timer.Reset(0)
		}

		if buffered >= bufsize {
			select {
			case <-timer.C:
			case <-s.reset:
				// Put the signal back (non-blocking) so it stays pending.
				// The returned mux.ErrReset makes the main loop call
				// cancelWrite and exit.
				select {
				case s.reset <- struct{}{}:
				default:
				}
				return mux.ErrReset
			}
			if err := drainBuf(); err != nil {
				return err
			}
			// Write this payload directly, after the earlier buffered bytes.
			_, err := s.write.Write(o.msg)
			if err != nil {
				return err
			}
		} else {
			buf.Write(o.msg)
		}
		return nil
	}

	for {
		// Reset takes precedence. Check it first without blocking so a
		// pending reset wins over close or data, which the randomly-ordered
		// select below might otherwise choose.
		select {
		case <-s.reset:
			s.writeErr = mux.ErrReset
			return
		default:
		}

		select {
		case <-s.reset:
			s.writeErr = mux.ErrReset
			return
		case <-s.close:
			// Flush buffered bytes and close the pipe. Record ErrClosed when
			// Close itself reports no error.
			// NOTE(review): the flush happens immediately, without waiting
			// for the buffered payloads' arrival times.
			if err := drainBuf(); err != nil {
				s.cancelWrite(err)
				return
			}
			s.writeErr = s.write.Close()
			if s.writeErr == nil {
				s.writeErr = ErrClosed
			}
			return
		case o := <-s.toDeliver:
			if err := deliverOrWait(o); err != nil {
				s.cancelWrite(err)
				return
			}
		case <-timer.C: // the buffered payloads are due; flush them.
			if err := drainBuf(); err != nil {
				s.cancelWrite(err)
				return
			}
		}
	}
}
291
292func (s *stream) cancelWrite(err error) {
293	s.write.CloseWithError(err)
294	s.writeErr = err
295}
296