1package h264
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7)
8
// Frame type codes.
// NOTE(review): presumably the values meant for Parser.frameType; nothing in
// this file assigns or reads them — confirm against the callers.
const (
	i_frame byte = 0 // I frame
	p_frame byte = 1 // P frame
	b_frame byte = 2 // B frame
)
14
// NAL unit type codes. getAnnexbH264 obtains the type of a NAL unit by masking
// the unit's first byte with 0x1f and compares the result against these
// values. The trailing comment on each constant names the RBSP syntax
// structure carried by that type.
const (
	nalu_type_not_define byte = 0  // not defined
	nalu_type_slice      byte = 1  // slice_layer_without_partitioning_rbsp( ), carries a slice header
	nalu_type_dpa        byte = 2  // slice_data_partition_a_layer_rbsp( ), carries a slice header
	nalu_type_dpb        byte = 3  // slice_data_partition_b_layer_rbsp( )
	nalu_type_dpc        byte = 4  // slice_data_partition_c_layer_rbsp( )
	nalu_type_idr        byte = 5  // slice_layer_without_partitioning_rbsp( ), carries a slice header (IDR)
	nalu_type_sei        byte = 6  // sei_rbsp( )
	nalu_type_sps        byte = 7  // seq_parameter_set_rbsp( )
	nalu_type_pps        byte = 8  // pic_parameter_set_rbsp( )
	nalu_type_aud        byte = 9  // access_unit_delimiter_rbsp( )
	nalu_type_eoesq      byte = 10 // end_of_seq_rbsp( )
	nalu_type_eostream   byte = 11 // end_of_stream_rbsp( )
	nalu_type_filler     byte = 12 // filler_data_rbsp( )
)
30
const (
	// naluBytesLen is the number of bytes naluSize reads as the big-endian
	// length prefix of a NAL unit. isNaluHeader also uses it as the minimum
	// input length needed to hold the 4-byte start code.
	naluBytesLen int = 4
	// maxSpsPpsLen sizes the backing array NewParser allocates for the
	// Parser.pps buffer.
	maxSpsPpsLen int = 2 * 1024
)
35
// Sentinel errors returned by the parser.
var (
	// decDataNil: parseSpecificInfo was given fewer than 9 bytes.
	decDataNil        = fmt.Errorf("dec buf is nil")
	// spsDataError: the declared SPS length is zero or exceeds the data left.
	spsDataError      = fmt.Errorf("sps data error")
	// ppsHeaderError: fewer than 4 bytes follow the SPS.
	ppsHeaderError    = fmt.Errorf("pps header error")
	// ppsDataError: the declared PPS length is zero or exceeds the data left.
	ppsDataError      = fmt.Errorf("pps data error")
	// naluHeaderInvalid is not returned by any code in this file.
	naluHeaderInvalid = fmt.Errorf("nalu header invalid")
	// videoDataInvalid: getAnnexbH264 was given fewer than naluBytesLen bytes.
	videoDataInvalid  = fmt.Errorf("video data not match")
	// dataSizeNotMatch: getAnnexbH264 found fewer than naluBytesLen bytes
	// where the next NAL unit length prefix was expected.
	dataSizeNotMatch  = fmt.Errorf("data size not match")
	// naluBodyLenError: a NAL unit length prefix is zero or larger than the
	// data that follows it.
	naluBodyLenError  = fmt.Errorf("nalu body len error")
)
46
// startCode is the 4-byte start code (00 00 00 01) written in front of every
// NAL unit this package emits; isNaluHeader tests input for the same pattern.
var startCode = []byte{0x00, 0x00, 0x00, 0x01}

// naluAud is a complete, start-code-prefixed NAL unit whose header byte 0x09
// has type nalu_type_aud (access unit delimiter), followed by the payload
// byte 0xf0. getAnnexbH264 writes it at the start of every converted packet.
var naluAud = []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0}
49
// Parser turns H.264 packets whose NAL units are length-prefixed into a
// start-code-delimited byte stream, injecting SPS/PPS ahead of IDR units.
// Create instances with NewParser.
type Parser struct {
	// frameType is not read or written anywhere in this file.
	frameType    byte
	// specificInfo holds the start-code-prefixed SPS and PPS extracted by
	// parseSpecificInfo. getAnnexbH264 writes it before the first IDR unit of
	// a packet when no SPS/PPS unit preceded that IDR inside the packet.
	specificInfo []byte
	// pps is a scratch buffer, reset at the start of every getAnnexbH264
	// call, that collects the start-code-prefixed SPS and PPS units found
	// inside the packet being converted.
	pps          *bytes.Buffer
}
55
// sequenceHeader holds the fields parseSpecificInfo decodes from the start of
// a sequence-header packet. The trailing comments give each field's width in
// the packet. The field set appears to follow the AVCDecoderConfigurationRecord
// layout (NOTE(review): confirm against ISO/IEC 14496-15).
type sequenceHeader struct {
	configVersion        byte //8bits
	avcProfileIndication byte //8bits
	profileCompatility   byte //8bits
	avcLevelIndication   byte //8bits
	reserved1            byte //6bits, kept in the high bits (masked, not shifted)
	naluLen              byte //2bits, stored here as the wire value plus one
	reserved2            byte //3bits
	spsNum               byte //5bits
	ppsNum               byte //8bits
	spsLen               int  // 16-bit big-endian length of the SPS, in bytes
	ppsLen               int  // 16-bit big-endian length of the PPS, in bytes
}
69
70func NewParser() *Parser {
71	return &Parser{
72		pps: bytes.NewBuffer(make([]byte, maxSpsPpsLen)),
73	}
74}
75
76//return value 1:sps, value2 :pps
77func (parser *Parser) parseSpecificInfo(src []byte) error {
78	if len(src) < 9 {
79		return decDataNil
80	}
81	sps := []byte{}
82	pps := []byte{}
83
84	var seq sequenceHeader
85	seq.configVersion = src[0]
86	seq.avcProfileIndication = src[1]
87	seq.profileCompatility = src[2]
88	seq.avcLevelIndication = src[3]
89	seq.reserved1 = src[4] & 0xfc
90	seq.naluLen = src[4]&0x03 + 1
91	seq.reserved2 = src[5] >> 5
92
93	//get sps
94	seq.spsNum = src[5] & 0x1f
95	seq.spsLen = int(src[6])<<8 | int(src[7])
96
97	if len(src[8:]) < seq.spsLen || seq.spsLen <= 0 {
98		return spsDataError
99	}
100	sps = append(sps, startCode...)
101	sps = append(sps, src[8:(8+seq.spsLen)]...)
102
103	//get pps
104	tmpBuf := src[(8 + seq.spsLen):]
105	if len(tmpBuf) < 4 {
106		return ppsHeaderError
107	}
108	seq.ppsNum = tmpBuf[0]
109	seq.ppsLen = int(0)<<16 | int(tmpBuf[1])<<8 | int(tmpBuf[2])
110	if len(tmpBuf[3:]) < seq.ppsLen || seq.ppsLen <= 0 {
111		return ppsDataError
112	}
113
114	pps = append(pps, startCode...)
115	pps = append(pps, tmpBuf[3:]...)
116
117	parser.specificInfo = append(parser.specificInfo, sps...)
118	parser.specificInfo = append(parser.specificInfo, pps...)
119
120	return nil
121}
122
123func (parser *Parser) isNaluHeader(src []byte) bool {
124	if len(src) < naluBytesLen {
125		return false
126	}
127	return src[0] == 0x00 &&
128		src[1] == 0x00 &&
129		src[2] == 0x00 &&
130		src[3] == 0x01
131}
132
133func (parser *Parser) naluSize(src []byte) (int, error) {
134	if len(src) < naluBytesLen {
135		return 0, fmt.Errorf("nalusizedata invalid")
136	}
137	buf := src[:naluBytesLen]
138	size := int(0)
139	for i := 0; i < len(buf); i++ {
140		size = size<<8 + int(buf[i])
141	}
142	return size, nil
143}
144
145func (parser *Parser) getAnnexbH264(src []byte, w io.Writer) error {
146	dataSize := len(src)
147	if dataSize < naluBytesLen {
148		return videoDataInvalid
149	}
150	parser.pps.Reset()
151	_, err := w.Write(naluAud)
152	if err != nil {
153		return err
154	}
155
156	index := 0
157	nalLen := 0
158	hasSpsPps := false
159	hasWriteSpsPps := false
160
161	for dataSize > 0 {
162		nalLen, err = parser.naluSize(src[index:])
163		if err != nil {
164			return dataSizeNotMatch
165		}
166		index += naluBytesLen
167		dataSize -= naluBytesLen
168		if dataSize >= nalLen && len(src[index:]) >= nalLen && nalLen > 0 {
169			nalType := src[index] & 0x1f
170			switch nalType {
171			case nalu_type_aud:
172			case nalu_type_idr:
173				if !hasWriteSpsPps {
174					hasWriteSpsPps = true
175					if !hasSpsPps {
176						if _, err := w.Write(parser.specificInfo); err != nil {
177							return err
178						}
179					} else {
180						if _, err := w.Write(parser.pps.Bytes()); err != nil {
181							return err
182						}
183					}
184				}
185				fallthrough
186			case nalu_type_slice:
187				fallthrough
188			case nalu_type_sei:
189				_, err := w.Write(startCode)
190				if err != nil {
191					return err
192				}
193				_, err = w.Write(src[index : index+nalLen])
194				if err != nil {
195					return err
196				}
197			case nalu_type_sps:
198				fallthrough
199			case nalu_type_pps:
200				hasSpsPps = true
201				_, err := parser.pps.Write(startCode)
202				if err != nil {
203					return err
204				}
205				_, err = parser.pps.Write(src[index : index+nalLen])
206				if err != nil {
207					return err
208				}
209			}
210			index += nalLen
211			dataSize -= nalLen
212		} else {
213			return naluBodyLenError
214		}
215	}
216	return nil
217}
218
219func (parser *Parser) Parse(b []byte, isSeq bool, w io.Writer) (err error) {
220	switch isSeq {
221	case true:
222		err = parser.parseSpecificInfo(b)
223	case false:
224		// is annexb
225		if parser.isNaluHeader(b) {
226			_, err = w.Write(b)
227		} else {
228			err = parser.getAnnexbH264(b, w)
229		}
230	}
231	return
232}
233