1package sarama
2
3import "fmt"
4
// Values of Records.recordsType:
//   - unknownRecords: format not determined yet (the zero value)
//   - legacyRecords: payload is a legacy MessageSet (magic byte below 2)
//   - defaultRecords: payload is a RecordBatch (magic byte 2 or above)
const (
	unknownRecords = 0
	legacyRecords  = 1
	defaultRecords = 2
)

// magicOffset is the byte offset of the magic (format version) byte, relative
// to the decoder's current position, that magicValue peeks at. Per the Kafka
// message-format spec both legacy messages and record batches place the magic
// byte after a 16-byte header.
const magicOffset = 16
12
// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
//
// The zero value has an unknown type. The encode and accessor methods infer
// the type from whichever of MsgSet or RecordBatch is non-nil, while decode
// infers it from the magic byte of the incoming data.
type Records struct {
	// recordsType is one of unknownRecords, legacyRecords or defaultRecords.
	recordsType int
	// MsgSet holds the payload when recordsType is legacyRecords.
	MsgSet      *MessageSet
	// RecordBatch holds the payload when recordsType is defaultRecords.
	RecordBatch *RecordBatch
}
19
20func newLegacyRecords(msgSet *MessageSet) Records {
21	return Records{recordsType: legacyRecords, MsgSet: msgSet}
22}
23
24func newDefaultRecords(batch *RecordBatch) Records {
25	return Records{recordsType: defaultRecords, RecordBatch: batch}
26}
27
28// setTypeFromFields sets type of Records depending on which of MsgSet or RecordBatch is not nil.
29// The first return value indicates whether both fields are nil (and the type is not set).
30// If both fields are not nil, it returns an error.
31func (r *Records) setTypeFromFields() (bool, error) {
32	if r.MsgSet == nil && r.RecordBatch == nil {
33		return true, nil
34	}
35	if r.MsgSet != nil && r.RecordBatch != nil {
36		return false, fmt.Errorf("both MsgSet and RecordBatch are set, but record type is unknown")
37	}
38	r.recordsType = defaultRecords
39	if r.MsgSet != nil {
40		r.recordsType = legacyRecords
41	}
42	return false, nil
43}
44
45func (r *Records) encode(pe packetEncoder) error {
46	if r.recordsType == unknownRecords {
47		if empty, err := r.setTypeFromFields(); err != nil || empty {
48			return err
49		}
50	}
51
52	switch r.recordsType {
53	case legacyRecords:
54		if r.MsgSet == nil {
55			return nil
56		}
57		return r.MsgSet.encode(pe)
58	case defaultRecords:
59		if r.RecordBatch == nil {
60			return nil
61		}
62		return r.RecordBatch.encode(pe)
63	}
64
65	return fmt.Errorf("unknown records type: %v", r.recordsType)
66}
67
68func (r *Records) setTypeFromMagic(pd packetDecoder) error {
69	magic, err := magicValue(pd)
70	if err != nil {
71		return err
72	}
73
74	r.recordsType = defaultRecords
75	if magic < 2 {
76		r.recordsType = legacyRecords
77	}
78
79	return nil
80}
81
82func (r *Records) decode(pd packetDecoder) error {
83	if r.recordsType == unknownRecords {
84		if err := r.setTypeFromMagic(pd); err != nil {
85			return err
86		}
87	}
88
89	switch r.recordsType {
90	case legacyRecords:
91		r.MsgSet = &MessageSet{}
92		return r.MsgSet.decode(pd)
93	case defaultRecords:
94		r.RecordBatch = &RecordBatch{}
95		return r.RecordBatch.decode(pd)
96	}
97	return fmt.Errorf("unknown records type: %v", r.recordsType)
98}
99
100func (r *Records) numRecords() (int, error) {
101	if r.recordsType == unknownRecords {
102		if empty, err := r.setTypeFromFields(); err != nil || empty {
103			return 0, err
104		}
105	}
106
107	switch r.recordsType {
108	case legacyRecords:
109		if r.MsgSet == nil {
110			return 0, nil
111		}
112		return len(r.MsgSet.Messages), nil
113	case defaultRecords:
114		if r.RecordBatch == nil {
115			return 0, nil
116		}
117		return len(r.RecordBatch.Records), nil
118	}
119	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
120}
121
122func (r *Records) isPartial() (bool, error) {
123	if r.recordsType == unknownRecords {
124		if empty, err := r.setTypeFromFields(); err != nil || empty {
125			return false, err
126		}
127	}
128
129	switch r.recordsType {
130	case unknownRecords:
131		return false, nil
132	case legacyRecords:
133		if r.MsgSet == nil {
134			return false, nil
135		}
136		return r.MsgSet.PartialTrailingMessage, nil
137	case defaultRecords:
138		if r.RecordBatch == nil {
139			return false, nil
140		}
141		return r.RecordBatch.PartialTrailingRecord, nil
142	}
143	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
144}
145
146func (r *Records) isControl() (bool, error) {
147	if r.recordsType == unknownRecords {
148		if empty, err := r.setTypeFromFields(); err != nil || empty {
149			return false, err
150		}
151	}
152
153	switch r.recordsType {
154	case legacyRecords:
155		return false, nil
156	case defaultRecords:
157		if r.RecordBatch == nil {
158			return false, nil
159		}
160		return r.RecordBatch.Control, nil
161	}
162	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
163}
164
165func (r *Records) isOverflow() (bool, error) {
166	if r.recordsType == unknownRecords {
167		if empty, err := r.setTypeFromFields(); err != nil || empty {
168			return false, err
169		}
170	}
171
172	switch r.recordsType {
173	case unknownRecords:
174		return false, nil
175	case legacyRecords:
176		if r.MsgSet == nil {
177			return false, nil
178		}
179		return r.MsgSet.OverflowMessage, nil
180	case defaultRecords:
181		return false, nil
182	}
183	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
184}
185
186func magicValue(pd packetDecoder) (int8, error) {
187	return pd.peekInt8(magicOffset)
188}
189
190func (r *Records) getControlRecord() (ControlRecord, error) {
191	if r.RecordBatch == nil || len(r.RecordBatch.Records) <= 0 {
192		return ControlRecord{}, fmt.Errorf("cannot get control record, record batch is empty")
193	}
194
195	firstRecord := r.RecordBatch.Records[0]
196	controlRecord := ControlRecord{}
197	err := controlRecord.decode(&realDecoder{raw: firstRecord.Key}, &realDecoder{raw: firstRecord.Value})
198	if err != nil {
199		return ControlRecord{}, err
200	}
201
202	return controlRecord, nil
203}
204