1package marionette
2
3import (
4	"expvar"
5	"fmt"
6	"io"
7	"math/rand"
8	"os"
9	"path/filepath"
10	"strconv"
11	"sync"
12	"time"
13
14	"github.com/ooni/psiphon/oopsi/go.uber.org/zap"
15)
16
// StreamCloseTimeout is how long a stream may sit with only its read side or
// only its write side closed before its monitoring goroutine reaps it.
const StreamCloseTimeout = 5 * time.Second

// evStreams is a process-wide expvar counter of open streams, published under
// the name "streams".
var evStreams = expvar.NewInt("streams")
25
26// StreamSet represents a multiplexer for a set of streams for a connection.
27type StreamSet struct {
28	mu        sync.RWMutex
29	streams   map[int]*Stream // streams by id
30	streamIDs []int           // cached list of all stream ids
31	wnotify   chan struct{}   // notification of write changes
32
33	// Close management
34	closing chan struct{}
35	once    sync.Once
36	wg      sync.WaitGroup
37
38	// Callback executed when a new stream is created.
39	OnNewStream func(*Stream)
40
41	// Directory for storing stream traces.
42	TracePath string
43}
44
45// NewStreamSet returns a new instance of StreamSet.
46func NewStreamSet() *StreamSet {
47	ss := &StreamSet{
48		streams: make(map[int]*Stream),
49		closing: make(chan struct{}),
50		wnotify: make(chan struct{}),
51	}
52	return ss
53}
54
55// Close closes all streams in the set.
56func (ss *StreamSet) Close() (err error) {
57	ss.mu.Lock()
58	for _, stream := range ss.streams {
59		if e := stream.CloseWrite(); e != nil && err == nil {
60			err = e
61		} else if e := stream.CloseRead(); e != nil && err == nil {
62			err = e
63		}
64	}
65	ss.mu.Unlock()
66
67	ss.once.Do(func() { close(ss.closing) })
68	ss.wg.Wait()
69	return err
70}
71
72// monitorStream checks a stream until its read & write channels are closed
73// and then removes the stream from the set.
74func (ss *StreamSet) monitorStream(stream *Stream) {
75	readCloseNotify := stream.ReadCloseNotify()
76	writeCloseNotifiedNotify := stream.WriteCloseNotifiedNotify()
77	var timeout <-chan time.Time
78
79LOOP:
80	for {
81		// Wait until stream closed state is changed or the set is closed.
82		select {
83		case <-ss.closing:
84			break LOOP
85		case <-timeout:
86			break LOOP
87		case <-readCloseNotify:
88			readCloseNotify = nil
89			timeout = time.After(StreamCloseTimeout)
90		case <-writeCloseNotifiedNotify:
91			writeCloseNotifiedNotify = nil
92			timeout = time.After(StreamCloseTimeout)
93		}
94
95		// If stream is completely closed then remove from the set.
96		if stream.ReadWriteCloseNotified() {
97			break
98		}
99	}
100
101	// Ensure both sides are closed.
102	stream.CloseRead()
103	stream.CloseWrite()
104
105	ss.mu.Lock()
106	ss.remove(stream)
107	ss.mu.Unlock()
108}
109
110// Stream returns a stream by id.
111func (ss *StreamSet) Stream(id int) *Stream {
112	ss.mu.Lock()
113	defer ss.mu.Unlock()
114	return ss.streams[id]
115}
116
117// Streams returns a list of streams.
118func (ss *StreamSet) Streams() []*Stream {
119	ss.mu.Lock()
120	defer ss.mu.Unlock()
121
122	streams := make([]*Stream, 0, len(ss.streams))
123	for _, stream := range ss.streams {
124		streams = append(streams, stream)
125	}
126	return streams
127}
128
129// Create returns a new stream with a random stream id.
130func (ss *StreamSet) Create() *Stream {
131	ss.mu.Lock()
132	defer ss.mu.Unlock()
133	return ss.create(0)
134}
135
136func (ss *StreamSet) create(id int) *Stream {
137	if id == 0 {
138		id = int(rand.Int31() + 1)
139	}
140
141	stream := NewStream(id)
142
143	// Create a per-stream log if trace path is specified.
144	if ss.TracePath != "" {
145		path := filepath.Join(ss.TracePath, strconv.Itoa(id))
146		if err := os.MkdirAll(ss.TracePath, 0777); err != nil {
147			Logger.Warn("cannot create trace directory", zap.Error(err))
148		} else if w, err := os.Create(path); err != nil {
149			Logger.Warn("cannot create trace file", zap.Error(err))
150		} else {
151			fmt.Fprintf(w, "# STREAM %d\n\n", id)
152			stream.TraceWriter = &timestampWriter{Writer: w}
153		}
154		stream.TraceWriter.Write([]byte("[create]"))
155	}
156
157	// Add stream to set.
158	ss.streams[stream.id] = stream
159	ss.streamIDs = append(ss.streamIDs, stream.id)
160
161	// Add to global counter.
162	evStreams.Add(1)
163
164	// Monitor each stream closing in a separate goroutine.
165	ss.wg.Add(1)
166	go func() { defer ss.wg.Done(); ss.monitorStream(stream) }()
167
168	// Monitor read/write changes in a separate goroutine per stream.
169	ss.wg.Add(1)
170	go func() { defer ss.wg.Done(); ss.handleStream(stream) }()
171
172	// Execute callback, if exists.
173	if ss.OnNewStream != nil {
174		ss.OnNewStream(stream)
175	}
176
177	return stream
178}
179
180// remove removes stream from the set and decrements open stream count.
181// This must be called under lock.
182func (ss *StreamSet) remove(stream *Stream) {
183	streamID := stream.ID()
184
185	evStreams.Add(-1)
186
187	if stream.TraceWriter != nil {
188		stream.TraceWriter.Write([]byte("[remove]"))
189		if traceWriter, ok := stream.TraceWriter.(io.Closer); ok {
190			traceWriter.Close()
191		}
192	}
193	delete(ss.streams, streamID)
194
195	for i, id := range ss.streamIDs {
196		if id == streamID {
197			ss.streamIDs = append(ss.streamIDs[:i], ss.streamIDs[i+1:]...)
198		}
199	}
200}
201
202// Enqueue pushes a cell onto a stream's read queue.
203// If the stream doesn't exist then it is created.
204func (ss *StreamSet) Enqueue(cell *Cell) error {
205	ss.mu.Lock()
206	defer ss.mu.Unlock()
207
208	// Ignore empty cells.
209	if cell.StreamID == 0 {
210		return nil
211	}
212
213	// Create or find stream and enqueue cell.
214	stream := ss.streams[cell.StreamID]
215	if stream == nil {
216		stream = ss.create(cell.StreamID)
217	}
218	return stream.Enqueue(cell)
219}
220
221// Dequeue returns a cell containing data for a random stream's write buffer.
222func (ss *StreamSet) Dequeue(n int) *Cell {
223	ss.mu.Lock()
224	defer ss.mu.Unlock()
225
226	// Choose a random stream with data.
227	var stream *Stream
228	for _, i := range rand.Perm(len(ss.streamIDs)) {
229		s := ss.streams[ss.streamIDs[i]]
230		if s.WriteBufferLen() > 0 || s.WriteClosed() {
231			stream = s
232			break
233		}
234	}
235
236	// If there is no stream with data then send an empty
237	if stream == nil {
238		return nil
239	}
240
241	// Generate cell from stream.
242	return stream.Dequeue(n)
243}
244
245// WriteNotify returns a channel that receives a notification when a new write is available.
246func (ss *StreamSet) WriteNotify() <-chan struct{} {
247	ss.mu.RLock()
248	defer ss.mu.RUnlock()
249	return ss.wnotify
250}
251
252// notifyWrite closes previous write notification channel and creates a new one.
253// This provides a broadcast to all interested parties.
254func (ss *StreamSet) notifyWrite() {
255	ss.mu.Lock()
256	close(ss.wnotify)
257	ss.wnotify = make(chan struct{})
258	ss.mu.Unlock()
259}
260
261// handleStream continually monitors write changes for stream.
262func (ss *StreamSet) handleStream(stream *Stream) {
263	notify := stream.WriteNotify()
264	ss.notifyWrite()
265
266	for {
267		select {
268		case <-notify:
269			notify = stream.WriteNotify()
270			ss.notifyWrite()
271		case <-stream.WriteCloseNotify():
272			ss.notifyWrite()
273			return
274		}
275	}
276}
277
// timestampWriter wraps a writer and prepends a UTC timestamp & appends a
// newline to every write.
type timestampWriter struct {
	Writer io.Writer
}

// Write emits p to the underlying writer as one line of the form
// "<timestamp> <p>\n", issued as a single underlying write.
//
// As required by io.Writer, the returned count is the number of bytes of p
// that were written (never more than len(p)); the timestamp prefix and the
// trailing newline are not counted.
func (w *timestampWriter) Write(p []byte) (n int, err error) {
	prefix := time.Now().UTC().Format("2006-01-02T15:04:05.000Z") + " "
	written, err := fmt.Fprintf(w.Writer, "%s%s\n", prefix, p)

	// Translate the total byte count into the portion that belongs to p.
	n = written - len(prefix)
	if n < 0 {
		n = 0
	} else if n > len(p) {
		n = len(p)
	}
	return n, err
}

// Close closes the underlying writer when it implements io.Closer and is a
// no-op otherwise. It lets code that closes a trace writer through io.Closer
// release the wrapped file.
func (w *timestampWriter) Close() error {
	if c, ok := w.Writer.(io.Closer); ok {
		return c.Close()
	}
	return nil
}
286