1package marionette 2 3import ( 4 "expvar" 5 "fmt" 6 "io" 7 "math/rand" 8 "os" 9 "path/filepath" 10 "strconv" 11 "sync" 12 "time" 13 14 "github.com/ooni/psiphon/oopsi/go.uber.org/zap" 15) 16 17const ( 18 // StreamCloseTimeout is the amount of time before an idle read-closed or 19 // write-closed stream is reaped by a monitoring goroutine. 20 StreamCloseTimeout = 5 * time.Second 21) 22 23// evStreams is a global expvar variable for tracking open streams. 24var evStreams = expvar.NewInt("streams") 25 26// StreamSet represents a multiplexer for a set of streams for a connection. 27type StreamSet struct { 28 mu sync.RWMutex 29 streams map[int]*Stream // streams by id 30 streamIDs []int // cached list of all stream ids 31 wnotify chan struct{} // notification of write changes 32 33 // Close management 34 closing chan struct{} 35 once sync.Once 36 wg sync.WaitGroup 37 38 // Callback executed when a new stream is created. 39 OnNewStream func(*Stream) 40 41 // Directory for storing stream traces. 42 TracePath string 43} 44 45// NewStreamSet returns a new instance of StreamSet. 46func NewStreamSet() *StreamSet { 47 ss := &StreamSet{ 48 streams: make(map[int]*Stream), 49 closing: make(chan struct{}), 50 wnotify: make(chan struct{}), 51 } 52 return ss 53} 54 55// Close closes all streams in the set. 56func (ss *StreamSet) Close() (err error) { 57 ss.mu.Lock() 58 for _, stream := range ss.streams { 59 if e := stream.CloseWrite(); e != nil && err == nil { 60 err = e 61 } else if e := stream.CloseRead(); e != nil && err == nil { 62 err = e 63 } 64 } 65 ss.mu.Unlock() 66 67 ss.once.Do(func() { close(ss.closing) }) 68 ss.wg.Wait() 69 return err 70} 71 72// monitorStream checks a stream until its read & write channels are closed 73// and then removes the stream from the set. 
74func (ss *StreamSet) monitorStream(stream *Stream) { 75 readCloseNotify := stream.ReadCloseNotify() 76 writeCloseNotifiedNotify := stream.WriteCloseNotifiedNotify() 77 var timeout <-chan time.Time 78 79LOOP: 80 for { 81 // Wait until stream closed state is changed or the set is closed. 82 select { 83 case <-ss.closing: 84 break LOOP 85 case <-timeout: 86 break LOOP 87 case <-readCloseNotify: 88 readCloseNotify = nil 89 timeout = time.After(StreamCloseTimeout) 90 case <-writeCloseNotifiedNotify: 91 writeCloseNotifiedNotify = nil 92 timeout = time.After(StreamCloseTimeout) 93 } 94 95 // If stream is completely closed then remove from the set. 96 if stream.ReadWriteCloseNotified() { 97 break 98 } 99 } 100 101 // Ensure both sides are closed. 102 stream.CloseRead() 103 stream.CloseWrite() 104 105 ss.mu.Lock() 106 ss.remove(stream) 107 ss.mu.Unlock() 108} 109 110// Stream returns a stream by id. 111func (ss *StreamSet) Stream(id int) *Stream { 112 ss.mu.Lock() 113 defer ss.mu.Unlock() 114 return ss.streams[id] 115} 116 117// Streams returns a list of streams. 118func (ss *StreamSet) Streams() []*Stream { 119 ss.mu.Lock() 120 defer ss.mu.Unlock() 121 122 streams := make([]*Stream, 0, len(ss.streams)) 123 for _, stream := range ss.streams { 124 streams = append(streams, stream) 125 } 126 return streams 127} 128 129// Create returns a new stream with a random stream id. 130func (ss *StreamSet) Create() *Stream { 131 ss.mu.Lock() 132 defer ss.mu.Unlock() 133 return ss.create(0) 134} 135 136func (ss *StreamSet) create(id int) *Stream { 137 if id == 0 { 138 id = int(rand.Int31() + 1) 139 } 140 141 stream := NewStream(id) 142 143 // Create a per-stream log if trace path is specified. 
144 if ss.TracePath != "" { 145 path := filepath.Join(ss.TracePath, strconv.Itoa(id)) 146 if err := os.MkdirAll(ss.TracePath, 0777); err != nil { 147 Logger.Warn("cannot create trace directory", zap.Error(err)) 148 } else if w, err := os.Create(path); err != nil { 149 Logger.Warn("cannot create trace file", zap.Error(err)) 150 } else { 151 fmt.Fprintf(w, "# STREAM %d\n\n", id) 152 stream.TraceWriter = ×tampWriter{Writer: w} 153 } 154 stream.TraceWriter.Write([]byte("[create]")) 155 } 156 157 // Add stream to set. 158 ss.streams[stream.id] = stream 159 ss.streamIDs = append(ss.streamIDs, stream.id) 160 161 // Add to global counter. 162 evStreams.Add(1) 163 164 // Monitor each stream closing in a separate goroutine. 165 ss.wg.Add(1) 166 go func() { defer ss.wg.Done(); ss.monitorStream(stream) }() 167 168 // Monitor read/write changes in a separate goroutine per stream. 169 ss.wg.Add(1) 170 go func() { defer ss.wg.Done(); ss.handleStream(stream) }() 171 172 // Execute callback, if exists. 173 if ss.OnNewStream != nil { 174 ss.OnNewStream(stream) 175 } 176 177 return stream 178} 179 180// remove removes stream from the set and decrements open stream count. 181// This must be called under lock. 182func (ss *StreamSet) remove(stream *Stream) { 183 streamID := stream.ID() 184 185 evStreams.Add(-1) 186 187 if stream.TraceWriter != nil { 188 stream.TraceWriter.Write([]byte("[remove]")) 189 if traceWriter, ok := stream.TraceWriter.(io.Closer); ok { 190 traceWriter.Close() 191 } 192 } 193 delete(ss.streams, streamID) 194 195 for i, id := range ss.streamIDs { 196 if id == streamID { 197 ss.streamIDs = append(ss.streamIDs[:i], ss.streamIDs[i+1:]...) 198 } 199 } 200} 201 202// Enqueue pushes a cell onto a stream's read queue. 203// If the stream doesn't exist then it is created. 204func (ss *StreamSet) Enqueue(cell *Cell) error { 205 ss.mu.Lock() 206 defer ss.mu.Unlock() 207 208 // Ignore empty cells. 
209 if cell.StreamID == 0 { 210 return nil 211 } 212 213 // Create or find stream and enqueue cell. 214 stream := ss.streams[cell.StreamID] 215 if stream == nil { 216 stream = ss.create(cell.StreamID) 217 } 218 return stream.Enqueue(cell) 219} 220 221// Dequeue returns a cell containing data for a random stream's write buffer. 222func (ss *StreamSet) Dequeue(n int) *Cell { 223 ss.mu.Lock() 224 defer ss.mu.Unlock() 225 226 // Choose a random stream with data. 227 var stream *Stream 228 for _, i := range rand.Perm(len(ss.streamIDs)) { 229 s := ss.streams[ss.streamIDs[i]] 230 if s.WriteBufferLen() > 0 || s.WriteClosed() { 231 stream = s 232 break 233 } 234 } 235 236 // If there is no stream with data then send an empty 237 if stream == nil { 238 return nil 239 } 240 241 // Generate cell from stream. 242 return stream.Dequeue(n) 243} 244 245// WriteNotify returns a channel that receives a notification when a new write is available. 246func (ss *StreamSet) WriteNotify() <-chan struct{} { 247 ss.mu.RLock() 248 defer ss.mu.RUnlock() 249 return ss.wnotify 250} 251 252// notifyWrite closes previous write notification channel and creates a new one. 253// This provides a broadcast to all interested parties. 254func (ss *StreamSet) notifyWrite() { 255 ss.mu.Lock() 256 close(ss.wnotify) 257 ss.wnotify = make(chan struct{}) 258 ss.mu.Unlock() 259} 260 261// handleStream continually monitors write changes for stream. 262func (ss *StreamSet) handleStream(stream *Stream) { 263 notify := stream.WriteNotify() 264 ss.notifyWrite() 265 266 for { 267 select { 268 case <-notify: 269 notify = stream.WriteNotify() 270 ss.notifyWrite() 271 case <-stream.WriteCloseNotify(): 272 ss.notifyWrite() 273 return 274 } 275 } 276} 277 278// timestampWriter wraps a writer and prepends a timestamp & appends a newline to every write. 
type timestampWriter struct {
	// Writer is the destination that receives the timestamped lines.
	Writer io.Writer
}

// Write emits p to the underlying writer as one line. The line is prefixed
// with the current UTC time (millisecond precision) and ends with a newline.
//
// On success it reports len(p), as the io.Writer contract requires
// (0 <= n <= len(p)), even though more bytes were written. On failure it
// reports 0 together with the underlying error.
func (w *timestampWriter) Write(p []byte) (n int, err error) {
	ts := time.Now().UTC().Format("2006-01-02T15:04:05.000Z")
	if _, err = fmt.Fprintf(w.Writer, "%s %s\n", ts, p); err != nil {
		return 0, err
	}
	return len(p), nil
}

// Close closes the underlying writer if it implements io.Closer; otherwise it
// is a no-op. This lets owners that release writers through an io.Closer type
// assertion close the wrapped file instead of leaking it.
func (w *timestampWriter) Close() error {
	if c, ok := w.Writer.(io.Closer); ok {
		return c.Close()
	}
	return nil
}