1package sarama
2
3import "fmt"
4
// Variants a Records value can hold. unknownRecords is the zero value, so a
// Records built without a constructor starts out with an undetermined type
// and the methods infer it lazily.
const (
	unknownRecords = iota // type not determined yet
	legacyRecords         // MsgSet holds a legacy MessageSet (magic < 2)
	defaultRecords        // RecordBatch holds a RecordBatch (magic >= 2)
)

// Location of the magic byte that magicValue peeks at to tell the two
// formats apart before decoding. Per the Kafka message-format docs, both the
// legacy message and the v2 record batch put it after an 8-byte offset, a
// 4-byte length and a 4-byte field, hence offset 16.
const (
	magicOffset = 16
	magicLength = 1
)
13
// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
//
// recordsType selects the active variant. When it is left at its zero value
// (unknownRecords), for example in a struct literal that only sets one of the
// exported fields, the methods infer it on demand. encode and the query
// methods (numRecords, isPartial, isControl, isOverflow) look at which pointer
// is non-nil. decode peeks at the magic byte instead: values below 2 mean
// legacy, anything else means a RecordBatch. Setting both MsgSet and
// RecordBatch while the type is unknown makes the field-based inference
// return an error.
type Records struct {
	recordsType int          // unknownRecords, legacyRecords or defaultRecords
	MsgSet      *MessageSet  // used when recordsType is legacyRecords
	RecordBatch *RecordBatch // used when recordsType is defaultRecords
}
20
21func newLegacyRecords(msgSet *MessageSet) Records {
22	return Records{recordsType: legacyRecords, MsgSet: msgSet}
23}
24
25func newDefaultRecords(batch *RecordBatch) Records {
26	return Records{recordsType: defaultRecords, RecordBatch: batch}
27}
28
29// setTypeFromFields sets type of Records depending on which of MsgSet or RecordBatch is not nil.
30// The first return value indicates whether both fields are nil (and the type is not set).
31// If both fields are not nil, it returns an error.
32func (r *Records) setTypeFromFields() (bool, error) {
33	if r.MsgSet == nil && r.RecordBatch == nil {
34		return true, nil
35	}
36	if r.MsgSet != nil && r.RecordBatch != nil {
37		return false, fmt.Errorf("both MsgSet and RecordBatch are set, but record type is unknown")
38	}
39	r.recordsType = defaultRecords
40	if r.MsgSet != nil {
41		r.recordsType = legacyRecords
42	}
43	return false, nil
44}
45
46func (r *Records) encode(pe packetEncoder) error {
47	if r.recordsType == unknownRecords {
48		if empty, err := r.setTypeFromFields(); err != nil || empty {
49			return err
50		}
51	}
52
53	switch r.recordsType {
54	case legacyRecords:
55		if r.MsgSet == nil {
56			return nil
57		}
58		return r.MsgSet.encode(pe)
59	case defaultRecords:
60		if r.RecordBatch == nil {
61			return nil
62		}
63		return r.RecordBatch.encode(pe)
64	}
65
66	return fmt.Errorf("unknown records type: %v", r.recordsType)
67}
68
69func (r *Records) setTypeFromMagic(pd packetDecoder) error {
70	magic, err := magicValue(pd)
71	if err != nil {
72		return err
73	}
74
75	r.recordsType = defaultRecords
76	if magic < 2 {
77		r.recordsType = legacyRecords
78	}
79
80	return nil
81}
82
83func (r *Records) decode(pd packetDecoder) error {
84	if r.recordsType == unknownRecords {
85		if err := r.setTypeFromMagic(pd); err != nil {
86			return err
87		}
88	}
89
90	switch r.recordsType {
91	case legacyRecords:
92		r.MsgSet = &MessageSet{}
93		return r.MsgSet.decode(pd)
94	case defaultRecords:
95		r.RecordBatch = &RecordBatch{}
96		return r.RecordBatch.decode(pd)
97	}
98	return fmt.Errorf("unknown records type: %v", r.recordsType)
99}
100
101func (r *Records) numRecords() (int, error) {
102	if r.recordsType == unknownRecords {
103		if empty, err := r.setTypeFromFields(); err != nil || empty {
104			return 0, err
105		}
106	}
107
108	switch r.recordsType {
109	case legacyRecords:
110		if r.MsgSet == nil {
111			return 0, nil
112		}
113		return len(r.MsgSet.Messages), nil
114	case defaultRecords:
115		if r.RecordBatch == nil {
116			return 0, nil
117		}
118		return len(r.RecordBatch.Records), nil
119	}
120	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
121}
122
123func (r *Records) isPartial() (bool, error) {
124	if r.recordsType == unknownRecords {
125		if empty, err := r.setTypeFromFields(); err != nil || empty {
126			return false, err
127		}
128	}
129
130	switch r.recordsType {
131	case unknownRecords:
132		return false, nil
133	case legacyRecords:
134		if r.MsgSet == nil {
135			return false, nil
136		}
137		return r.MsgSet.PartialTrailingMessage, nil
138	case defaultRecords:
139		if r.RecordBatch == nil {
140			return false, nil
141		}
142		return r.RecordBatch.PartialTrailingRecord, nil
143	}
144	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
145}
146
147func (r *Records) isControl() (bool, error) {
148	if r.recordsType == unknownRecords {
149		if empty, err := r.setTypeFromFields(); err != nil || empty {
150			return false, err
151		}
152	}
153
154	switch r.recordsType {
155	case legacyRecords:
156		return false, nil
157	case defaultRecords:
158		if r.RecordBatch == nil {
159			return false, nil
160		}
161		return r.RecordBatch.Control, nil
162	}
163	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
164}
165
166func (r *Records) isOverflow() (bool, error) {
167	if r.recordsType == unknownRecords {
168		if empty, err := r.setTypeFromFields(); err != nil || empty {
169			return false, err
170		}
171	}
172
173	switch r.recordsType {
174	case unknownRecords:
175		return false, nil
176	case legacyRecords:
177		if r.MsgSet == nil {
178			return false, nil
179		}
180		return r.MsgSet.OverflowMessage, nil
181	case defaultRecords:
182		return false, nil
183	}
184	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
185}
186
187func magicValue(pd packetDecoder) (int8, error) {
188	dec, err := pd.peek(magicOffset, magicLength)
189	if err != nil {
190		return 0, err
191	}
192
193	return dec.getInt8()
194}
195