1package rtmp
2
3import (
4	"fmt"
5	"net"
6	"net/url"
7	"reflect"
8	"strings"
9	"time"
10
11	"github.com/gwuhaolin/livego/utils/uid"
12
13	"github.com/gwuhaolin/livego/av"
14	"github.com/gwuhaolin/livego/configure"
15	"github.com/gwuhaolin/livego/container/flv"
16	"github.com/gwuhaolin/livego/protocol/rtmp/core"
17
18	log "github.com/sirupsen/logrus"
19)
20
const (
	// maxQueueNum is the capacity of every VirWriter packet queue; the
	// congestion thresholds used by Write and DropPacket are expressed
	// relative to it.
	maxQueueNum           = 1024
	// SAVE_STATICS_INTERVAL is the minimum time, in milliseconds, between two
	// bandwidth-rate recomputations in the SaveStatics methods.
	SAVE_STATICS_INTERVAL = 5000
)
25
// readTimeout and writeTimeout hold the configured "read_timeout" and
// "write_timeout" values, read from configure.Config once when the package is
// initialized. They are presumably seconds: writeTimeout is multiplied by
// time.Second when av.RWBaser values are built in this file.
var (
	readTimeout  = configure.Config.GetInt("read_timeout")
	writeTimeout = configure.Config.GetInt("write_timeout")
)
30
// Client dials outbound RTMP connections (see Dial) and registers the
// resulting virtual readers/writers with its handler.
type Client struct {
	// handler receives the readers and writers created by Dial.
	handler av.Handler
	// getter, when non-nil, supplies an extra writer for every stream that
	// Dial pulls with the PLAY method.
	getter  av.GetWriter
}
35
36func NewRtmpClient(h av.Handler, getter av.GetWriter) *Client {
37	return &Client{
38		handler: h,
39		getter:  getter,
40	}
41}
42
43func (c *Client) Dial(url string, method string) error {
44	connClient := core.NewConnClient()
45	if err := connClient.Start(url, method); err != nil {
46		return err
47	}
48	if method == av.PUBLISH {
49		writer := NewVirWriter(connClient)
50		log.Debugf("client Dial call NewVirWriter url=%s, method=%s", url, method)
51		c.handler.HandleWriter(writer)
52	} else if method == av.PLAY {
53		reader := NewVirReader(connClient)
54		log.Debugf("client Dial call NewVirReader url=%s, method=%s", url, method)
55		c.handler.HandleReader(reader)
56		if c.getter != nil {
57			writer := c.getter.GetWriter(reader.Info())
58			c.handler.HandleWriter(writer)
59		}
60	}
61	return nil
62}
63
64func (c *Client) GetHandle() av.Handler {
65	return c.handler
66}
67
// Server accepts inbound RTMP connections (see Serve) and registers the
// resulting virtual readers/writers with its handler.
type Server struct {
	// handler receives the readers and writers created for each connection.
	handler av.Handler
	// getter, when non-nil, supplies an extra writer for every published
	// stream (see handleConn).
	getter  av.GetWriter
}
72
73func NewRtmpServer(h av.Handler, getter av.GetWriter) *Server {
74	return &Server{
75		handler: h,
76		getter:  getter,
77	}
78}
79
80func (s *Server) Serve(listener net.Listener) (err error) {
81	defer func() {
82		if r := recover(); r != nil {
83			log.Error("rtmp serve panic: ", r)
84		}
85	}()
86
87	for {
88		var netconn net.Conn
89		netconn, err = listener.Accept()
90		if err != nil {
91			return
92		}
93		conn := core.NewConn(netconn, 4*1024)
94		log.Debug("new client, connect remote: ", conn.RemoteAddr().String(),
95			"local:", conn.LocalAddr().String())
96		go s.handleConn(conn)
97	}
98}
99
100func (s *Server) handleConn(conn *core.Conn) error {
101	if err := conn.HandshakeServer(); err != nil {
102		conn.Close()
103		log.Error("handleConn HandshakeServer err: ", err)
104		return err
105	}
106	connServer := core.NewConnServer(conn)
107
108	if err := connServer.ReadMsg(); err != nil {
109		conn.Close()
110		log.Error("handleConn read msg err: ", err)
111		return err
112	}
113
114	appname, name, _ := connServer.GetInfo()
115
116	if ret := configure.CheckAppName(appname); !ret {
117		err := fmt.Errorf("application name=%s is not configured", appname)
118		conn.Close()
119		log.Error("CheckAppName err: ", err)
120		return err
121	}
122
123	log.Debugf("handleConn: IsPublisher=%v", connServer.IsPublisher())
124	if connServer.IsPublisher() {
125		if configure.Config.GetBool("rtmp_noauth") {
126			key, err := configure.RoomKeys.GetKey(name)
127			if err != nil {
128				err := fmt.Errorf("Cannot create key err=%s", err.Error())
129				conn.Close()
130				log.Error("GetKey err: ", err)
131				return err
132			}
133			name = key
134		}
135		channel, err := configure.RoomKeys.GetChannel(name)
136		if err != nil {
137			err := fmt.Errorf("invalid key err=%s", err.Error())
138			conn.Close()
139			log.Error("CheckKey err: ", err)
140			return err
141		}
142		connServer.PublishInfo.Name = channel
143		if pushlist, ret := configure.GetStaticPushUrlList(appname); ret && (pushlist != nil) {
144			log.Debugf("GetStaticPushUrlList: %v", pushlist)
145		}
146		reader := NewVirReader(connServer)
147		s.handler.HandleReader(reader)
148		log.Debugf("new publisher: %+v", reader.Info())
149
150		if s.getter != nil {
151			writeType := reflect.TypeOf(s.getter)
152			log.Debugf("handleConn:writeType=%v", writeType)
153			writer := s.getter.GetWriter(reader.Info())
154			s.handler.HandleWriter(writer)
155		}
156		if configure.Config.GetBool("flv_archive") {
157			flvWriter := new(flv.FlvDvr)
158			s.handler.HandleWriter(flvWriter.GetWriter(reader.Info()))
159		}
160	} else {
161		writer := NewVirWriter(connServer)
162		log.Debugf("new player: %+v", writer.Info())
163		s.handler.HandleWriter(writer)
164	}
165
166	return nil
167}
168
// GetInFo is implemented by connections that can report their stream
// identity. In this file the three results are consumed as
// (app name, stream name, _) by handleConn and as (_, _, URL) by the Info
// methods of VirWriter and VirReader.
type GetInFo interface {
	GetInfo() (string, string, string)
}
172
// StreamReadWriteCloser is the chunk-level connection wrapped by VirWriter and
// VirReader. In this file it is satisfied by the values returned from
// core.NewConnServer (server side) and core.NewConnClient (client side).
type StreamReadWriteCloser interface {
	GetInFo
	// Close shuts the connection down; callers in this file pass the reason
	// for closing as the error.
	Close(error)
	// Write sends one chunk stream.
	Write(core.ChunkStream) error
	// Read fills c with the next received chunk stream.
	Read(c *core.ChunkStream) error
}
179
// StaticsBW holds per-connection traffic counters and the rates derived from
// them by the SaveStatics methods.
type StaticsBW struct {
	// StreamId is the RTMP stream ID of the most recently counted packet.
	StreamId               uint32
	// VideoDatainBytes is the cumulative byte count of video packets.
	VideoDatainBytes       uint64
	// LastVideoDatainBytes is VideoDatainBytes as of the last rate update.
	LastVideoDatainBytes   uint64
	// VideoSpeedInBytesperMS is the video rate over the last interval.
	// Despite its name, SaveStatics stores kilobits per second here
	// (bytes * 8 / elapsed whole seconds / 1000).
	VideoSpeedInBytesperMS uint64

	// AudioDatainBytes is the cumulative byte count of every non-video
	// packet (SaveStatics counts metadata here as well).
	AudioDatainBytes       uint64
	// LastAudioDatainBytes is AudioDatainBytes as of the last rate update.
	LastAudioDatainBytes   uint64
	// AudioSpeedInBytesperMS is the non-video rate over the last interval,
	// also in kilobits per second despite its name.
	AudioSpeedInBytesperMS uint64

	// LastTimestamp is the wall-clock time (Unix milliseconds) of the last
	// rate update; 0 means nothing has been counted yet.
	LastTimestamp int64
}
192
// VirWriter forwards av.Packets to an RTMP connection. Write enqueues packets
// on packetQueue; NewVirWriter starts SendPacket to drain the queue onto conn
// and Check to watch conn for read errors.
type VirWriter struct {
	// Uid is a unique identifier generated by uid.NewId.
	Uid    string
	// closed is set once the writer has been closed or a write has failed.
	// NOTE(review): it is read and written from several goroutines (Write's
	// caller, SendPacket, Check via Close) without synchronization.
	closed bool
	// av.RWBaser provides the SetPreTime / BaseTimeStamp / RecTimeStamp
	// bookkeeping used by SendPacket.
	av.RWBaser
	// conn is the underlying chunk-level connection.
	conn        StreamReadWriteCloser
	// packetQueue buffers up to maxQueueNum packets awaiting delivery.
	packetQueue chan *av.Packet
	// WriteBWInfo holds the outgoing traffic statistics.
	WriteBWInfo StaticsBW
}
201
202func NewVirWriter(conn StreamReadWriteCloser) *VirWriter {
203	ret := &VirWriter{
204		Uid:         uid.NewId(),
205		conn:        conn,
206		RWBaser:     av.NewRWBaser(time.Second * time.Duration(writeTimeout)),
207		packetQueue: make(chan *av.Packet, maxQueueNum),
208		WriteBWInfo: StaticsBW{0, 0, 0, 0, 0, 0, 0, 0},
209	}
210
211	go ret.Check()
212	go func() {
213		err := ret.SendPacket()
214		if err != nil {
215			log.Warning(err)
216		}
217	}()
218	return ret
219}
220
221func (v *VirWriter) SaveStatics(streamid uint32, length uint64, isVideoFlag bool) {
222	nowInMS := int64(time.Now().UnixNano() / 1e6)
223
224	v.WriteBWInfo.StreamId = streamid
225	if isVideoFlag {
226		v.WriteBWInfo.VideoDatainBytes = v.WriteBWInfo.VideoDatainBytes + length
227	} else {
228		v.WriteBWInfo.AudioDatainBytes = v.WriteBWInfo.AudioDatainBytes + length
229	}
230
231	if v.WriteBWInfo.LastTimestamp == 0 {
232		v.WriteBWInfo.LastTimestamp = nowInMS
233	} else if (nowInMS - v.WriteBWInfo.LastTimestamp) >= SAVE_STATICS_INTERVAL {
234		diffTimestamp := (nowInMS - v.WriteBWInfo.LastTimestamp) / 1000
235
236		v.WriteBWInfo.VideoSpeedInBytesperMS = (v.WriteBWInfo.VideoDatainBytes - v.WriteBWInfo.LastVideoDatainBytes) * 8 / uint64(diffTimestamp) / 1000
237		v.WriteBWInfo.AudioSpeedInBytesperMS = (v.WriteBWInfo.AudioDatainBytes - v.WriteBWInfo.LastAudioDatainBytes) * 8 / uint64(diffTimestamp) / 1000
238
239		v.WriteBWInfo.LastVideoDatainBytes = v.WriteBWInfo.VideoDatainBytes
240		v.WriteBWInfo.LastAudioDatainBytes = v.WriteBWInfo.AudioDatainBytes
241		v.WriteBWInfo.LastTimestamp = nowInMS
242	}
243}
244
245func (v *VirWriter) Check() {
246	var c core.ChunkStream
247	for {
248		if err := v.conn.Read(&c); err != nil {
249			v.Close(err)
250			return
251		}
252	}
253}
254
255func (v *VirWriter) DropPacket(pktQue chan *av.Packet, info av.Info) {
256	log.Warningf("[%v] packet queue max!!!", info)
257	for i := 0; i < maxQueueNum-84; i++ {
258		tmpPkt, ok := <-pktQue
259		// try to don't drop audio
260		if ok && tmpPkt.IsAudio {
261			if len(pktQue) > maxQueueNum-2 {
262				log.Debug("drop audio pkt")
263				<-pktQue
264			} else {
265				pktQue <- tmpPkt
266			}
267
268		}
269
270		if ok && tmpPkt.IsVideo {
271			videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader)
272			// dont't drop sps config and dont't drop key frame
273			if ok && (videoPkt.IsSeq() || videoPkt.IsKeyFrame()) {
274				pktQue <- tmpPkt
275			}
276			if len(pktQue) > maxQueueNum-10 {
277				log.Debug("drop video pkt")
278				<-pktQue
279			}
280		}
281
282	}
283	log.Debug("packet queue len: ", len(pktQue))
284}
285
286//
287func (v *VirWriter) Write(p *av.Packet) (err error) {
288	err = nil
289
290	if v.closed {
291		err = fmt.Errorf("VirWriter closed")
292		return
293	}
294	defer func() {
295		if e := recover(); e != nil {
296			err = fmt.Errorf("VirWriter has already been closed:%v", e)
297		}
298	}()
299	if len(v.packetQueue) >= maxQueueNum-24 {
300		v.DropPacket(v.packetQueue, v.Info())
301	} else {
302		v.packetQueue <- p
303	}
304
305	return
306}
307
308func (v *VirWriter) SendPacket() error {
309	Flush := reflect.ValueOf(v.conn).MethodByName("Flush")
310	var cs core.ChunkStream
311	for {
312		p, ok := <-v.packetQueue
313		if ok {
314			cs.Data = p.Data
315			cs.Length = uint32(len(p.Data))
316			cs.StreamID = p.StreamID
317			cs.Timestamp = p.TimeStamp
318			cs.Timestamp += v.BaseTimeStamp()
319
320			if p.IsVideo {
321				cs.TypeID = av.TAG_VIDEO
322			} else {
323				if p.IsMetadata {
324					cs.TypeID = av.TAG_SCRIPTDATAAMF0
325				} else {
326					cs.TypeID = av.TAG_AUDIO
327				}
328			}
329
330			v.SaveStatics(p.StreamID, uint64(cs.Length), p.IsVideo)
331			v.SetPreTime()
332			v.RecTimeStamp(cs.Timestamp, cs.TypeID)
333			err := v.conn.Write(cs)
334			if err != nil {
335				v.closed = true
336				return err
337			}
338			Flush.Call(nil)
339		} else {
340			return fmt.Errorf("closed")
341		}
342
343	}
344}
345
346func (v *VirWriter) Info() (ret av.Info) {
347	ret.UID = v.Uid
348	_, _, URL := v.conn.GetInfo()
349	ret.URL = URL
350	_url, err := url.Parse(URL)
351	if err != nil {
352		log.Warning(err)
353	}
354	ret.Key = strings.TrimLeft(_url.Path, "/")
355	ret.Inter = true
356	return
357}
358
359func (v *VirWriter) Close(err error) {
360	log.Warning("player ", v.Info(), "closed: "+err.Error())
361	if !v.closed {
362		close(v.packetQueue)
363	}
364	v.closed = true
365	v.conn.Close(err)
366}
367
// VirReader turns chunks received on an RTMP connection into av.Packets
// (see Read), passing each through an FLV demuxer.
type VirReader struct {
	// Uid is a unique identifier generated by uid.NewId.
	Uid string
	// av.RWBaser provides the SetPreTime bookkeeping used by Read.
	av.RWBaser
	// demuxer is applied to every packet returned by Read (DemuxH).
	demuxer    *flv.Demuxer
	// conn is the underlying chunk-level connection.
	conn       StreamReadWriteCloser
	// ReadBWInfo holds the incoming traffic statistics.
	ReadBWInfo StaticsBW
}
375
376func NewVirReader(conn StreamReadWriteCloser) *VirReader {
377	return &VirReader{
378		Uid:        uid.NewId(),
379		conn:       conn,
380		RWBaser:    av.NewRWBaser(time.Second * time.Duration(writeTimeout)),
381		demuxer:    flv.NewDemuxer(),
382		ReadBWInfo: StaticsBW{0, 0, 0, 0, 0, 0, 0, 0},
383	}
384}
385
386func (v *VirReader) SaveStatics(streamid uint32, length uint64, isVideoFlag bool) {
387	nowInMS := int64(time.Now().UnixNano() / 1e6)
388
389	v.ReadBWInfo.StreamId = streamid
390	if isVideoFlag {
391		v.ReadBWInfo.VideoDatainBytes = v.ReadBWInfo.VideoDatainBytes + length
392	} else {
393		v.ReadBWInfo.AudioDatainBytes = v.ReadBWInfo.AudioDatainBytes + length
394	}
395
396	if v.ReadBWInfo.LastTimestamp == 0 {
397		v.ReadBWInfo.LastTimestamp = nowInMS
398	} else if (nowInMS - v.ReadBWInfo.LastTimestamp) >= SAVE_STATICS_INTERVAL {
399		diffTimestamp := (nowInMS - v.ReadBWInfo.LastTimestamp) / 1000
400
401		//log.Printf("now=%d, last=%d, diff=%d", nowInMS, v.ReadBWInfo.LastTimestamp, diffTimestamp)
402		v.ReadBWInfo.VideoSpeedInBytesperMS = (v.ReadBWInfo.VideoDatainBytes - v.ReadBWInfo.LastVideoDatainBytes) * 8 / uint64(diffTimestamp) / 1000
403		v.ReadBWInfo.AudioSpeedInBytesperMS = (v.ReadBWInfo.AudioDatainBytes - v.ReadBWInfo.LastAudioDatainBytes) * 8 / uint64(diffTimestamp) / 1000
404
405		v.ReadBWInfo.LastVideoDatainBytes = v.ReadBWInfo.VideoDatainBytes
406		v.ReadBWInfo.LastAudioDatainBytes = v.ReadBWInfo.AudioDatainBytes
407		v.ReadBWInfo.LastTimestamp = nowInMS
408	}
409}
410
411func (v *VirReader) Read(p *av.Packet) (err error) {
412	defer func() {
413		if r := recover(); r != nil {
414			log.Warning("rtmp read packet panic: ", r)
415		}
416	}()
417
418	v.SetPreTime()
419	var cs core.ChunkStream
420	for {
421		err = v.conn.Read(&cs)
422		if err != nil {
423			return err
424		}
425		if cs.TypeID == av.TAG_AUDIO ||
426			cs.TypeID == av.TAG_VIDEO ||
427			cs.TypeID == av.TAG_SCRIPTDATAAMF0 ||
428			cs.TypeID == av.TAG_SCRIPTDATAAMF3 {
429			break
430		}
431	}
432
433	p.IsAudio = cs.TypeID == av.TAG_AUDIO
434	p.IsVideo = cs.TypeID == av.TAG_VIDEO
435	p.IsMetadata = cs.TypeID == av.TAG_SCRIPTDATAAMF0 || cs.TypeID == av.TAG_SCRIPTDATAAMF3
436	p.StreamID = cs.StreamID
437	p.Data = cs.Data
438	p.TimeStamp = cs.Timestamp
439
440	v.SaveStatics(p.StreamID, uint64(len(p.Data)), p.IsVideo)
441	v.demuxer.DemuxH(p)
442	return err
443}
444
445func (v *VirReader) Info() (ret av.Info) {
446	ret.UID = v.Uid
447	_, _, URL := v.conn.GetInfo()
448	ret.URL = URL
449	_url, err := url.Parse(URL)
450	if err != nil {
451		log.Warning(err)
452	}
453	ret.Key = strings.TrimLeft(_url.Path, "/")
454	return
455}
456
457func (v *VirReader) Close(err error) {
458	log.Debug("publisher ", v.Info(), "closed: "+err.Error())
459	v.conn.Close(err)
460}
461