// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto_test
6
7import (
8	"google.golang.org/protobuf/encoding/protowire"
9	"google.golang.org/protobuf/internal/flags"
10	"google.golang.org/protobuf/proto"
11	"google.golang.org/protobuf/testing/protopack"
12
13	messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
14	msetextpb "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb"
15)
16
17func init() {
18	if flags.ProtoLegacy {
19		testValidMessages = append(testValidMessages, messageSetTestProtos...)
20		testInvalidMessages = append(testInvalidMessages, messageSetInvalidTestProtos...)
21	}
22}
23
24var messageSetTestProtos = []testProto{
25	{
26		desc: "MessageSet type_id before message content",
27		decodeTo: []proto.Message{func() proto.Message {
28			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
29			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
30				Ext1Field1: proto.Int32(10),
31			})
32			return m
33		}()},
34		wire: protopack.Message{
35			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
36				protopack.Tag{1, protopack.StartGroupType},
37				protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
38				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
39					protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
40				}),
41				protopack.Tag{1, protopack.EndGroupType},
42			}),
43		}.Marshal(),
44	},
45	{
46		desc: "MessageSet type_id after message content",
47		decodeTo: []proto.Message{func() proto.Message {
48			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
49			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
50				Ext1Field1: proto.Int32(10),
51			})
52			return m
53		}()},
54		wire: protopack.Message{
55			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
56				protopack.Tag{1, protopack.StartGroupType},
57				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
58					protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
59				}),
60				protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
61				protopack.Tag{1, protopack.EndGroupType},
62			}),
63		}.Marshal(),
64	},
65	{
66		desc: "MessageSet does not preserve unknown field",
67		decodeTo: []proto.Message{build(
68			&messagesetpb.MessageSet{},
69			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
70				Ext1Field1: proto.Int32(10),
71			}),
72		)},
73		wire: protopack.Message{
74			protopack.Tag{1, protopack.StartGroupType},
75			protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
76			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
77				protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
78			}),
79			protopack.Tag{1, protopack.EndGroupType},
80			// Unknown field
81			protopack.Tag{4, protopack.VarintType}, protopack.Varint(30),
82		}.Marshal(),
83	},
84	{
85		desc: "MessageSet with unknown type_id",
86		decodeTo: []proto.Message{build(
87			&messagesetpb.MessageSet{},
88			unknown(protopack.Message{
89				protopack.Tag{999, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
90					protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
91				}),
92			}.Marshal()),
93		)},
94		wire: protopack.Message{
95			protopack.Tag{1, protopack.StartGroupType},
96			protopack.Tag{2, protopack.VarintType}, protopack.Varint(999),
97			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
98				protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
99			}),
100			protopack.Tag{1, protopack.EndGroupType},
101		}.Marshal(),
102	},
103	{
104		desc: "MessageSet merges repeated message fields in item",
105		decodeTo: []proto.Message{build(
106			&messagesetpb.MessageSet{},
107			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
108				Ext1Field1: proto.Int32(10),
109				Ext1Field2: proto.Int32(20),
110			}),
111		)},
112		wire: protopack.Message{
113			protopack.Tag{1, protopack.StartGroupType},
114			protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
115			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
116				protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
117			}),
118			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
119				protopack.Tag{2, protopack.VarintType}, protopack.Varint(20),
120			}),
121			protopack.Tag{1, protopack.EndGroupType},
122		}.Marshal(),
123	},
124	{
125		desc: "MessageSet merges message fields in repeated items",
126		decodeTo: []proto.Message{build(
127			&messagesetpb.MessageSet{},
128			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
129				Ext1Field1: proto.Int32(10),
130				Ext1Field2: proto.Int32(20),
131			}),
132			extend(msetextpb.E_Ext2_MessageSetExtension, &msetextpb.Ext2{
133				Ext2Field1: proto.Int32(30),
134			}),
135		)},
136		wire: protopack.Message{
137			// Ext1, field1
138			protopack.Tag{1, protopack.StartGroupType},
139			protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
140			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
141				protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
142			}),
143			protopack.Tag{1, protopack.EndGroupType},
144			// Ext2, field1
145			protopack.Tag{1, protopack.StartGroupType},
146			protopack.Tag{2, protopack.VarintType}, protopack.Varint(1001),
147			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
148				protopack.Tag{1, protopack.VarintType}, protopack.Varint(30),
149			}),
150			protopack.Tag{1, protopack.EndGroupType},
151			// Ext2, field2
152			protopack.Tag{1, protopack.StartGroupType},
153			protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
154			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
155				protopack.Tag{2, protopack.VarintType}, protopack.Varint(20),
156			}),
157			protopack.Tag{1, protopack.EndGroupType},
158		}.Marshal(),
159	},
160	{
161		desc: "MessageSet with missing type_id",
162		decodeTo: []proto.Message{build(
163			&messagesetpb.MessageSet{},
164		)},
165		wire: protopack.Message{
166			protopack.Tag{1, protopack.StartGroupType},
167			protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
168				protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
169			}),
170			protopack.Tag{1, protopack.EndGroupType},
171		}.Marshal(),
172	},
173	{
174		desc: "MessageSet with missing message",
175		decodeTo: []proto.Message{build(
176			&messagesetpb.MessageSet{},
177			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{}),
178		)},
179		wire: protopack.Message{
180			protopack.Tag{1, protopack.StartGroupType},
181			protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
182			protopack.Tag{1, protopack.EndGroupType},
183		}.Marshal(),
184	},
185	{
186		desc: "MessageSet with type id out of valid field number range",
187		decodeTo: []proto.Message{func() proto.Message {
188			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
189			proto.SetExtension(m.MessageSet, msetextpb.E_ExtLargeNumber_MessageSetExtension, &msetextpb.ExtLargeNumber{})
190			return m
191		}()},
192		wire: protopack.Message{
193			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
194				protopack.Tag{1, protopack.StartGroupType},
195				protopack.Tag{2, protopack.VarintType}, protopack.Varint(protowire.MaxValidNumber + 1),
196				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{}),
197				protopack.Tag{1, protopack.EndGroupType},
198			}),
199		}.Marshal(),
200	},
201	{
202		desc: "MessageSet with unknown type id out of valid field number range",
203		decodeTo: []proto.Message{func() proto.Message {
204			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
205			m.MessageSet.ProtoReflect().SetUnknown(
206				protopack.Message{
207					protopack.Tag{protowire.MaxValidNumber + 2, protopack.BytesType}, protopack.LengthPrefix{},
208				}.Marshal(),
209			)
210			return m
211		}()},
212		wire: protopack.Message{
213			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
214				protopack.Tag{1, protopack.StartGroupType},
215				protopack.Tag{2, protopack.VarintType}, protopack.Varint(protowire.MaxValidNumber + 2),
216				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{}),
217				protopack.Tag{1, protopack.EndGroupType},
218			}),
219		}.Marshal(),
220	},
221	{
222		desc: "MessageSet with unknown field",
223		decodeTo: []proto.Message{func() proto.Message {
224			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
225			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
226				Ext1Field1: proto.Int32(10),
227			})
228			return m
229		}()},
230		wire: protopack.Message{
231			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
232				protopack.Tag{1, protopack.StartGroupType},
233				protopack.Tag{2, protopack.VarintType}, protopack.Varint(1000),
234				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
235					protopack.Tag{1, protopack.VarintType}, protopack.Varint(10),
236				}),
237				protopack.Tag{4, protopack.VarintType}, protopack.Varint(0),
238				protopack.Tag{1, protopack.EndGroupType},
239			}),
240		}.Marshal(),
241	},
242	{
243		desc:          "MessageSet with required field set",
244		checkFastInit: true,
245		decodeTo: []proto.Message{func() proto.Message {
246			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
247			proto.SetExtension(m.MessageSet, msetextpb.E_ExtRequired_MessageSetExtension, &msetextpb.ExtRequired{
248				RequiredField1: proto.Int32(1),
249			})
250			return m
251		}()},
252		wire: protopack.Message{
253			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
254				protopack.Tag{1, protopack.StartGroupType},
255				protopack.Tag{2, protopack.VarintType}, protopack.Varint(1002),
256				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
257					protopack.Tag{1, protopack.VarintType}, protopack.Varint(1),
258				}),
259				protopack.Tag{1, protopack.EndGroupType},
260			}),
261		}.Marshal(),
262	},
263	{
264		desc:          "MessageSet with required field unset",
265		checkFastInit: true,
266		partial:       true,
267		decodeTo: []proto.Message{func() proto.Message {
268			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
269			proto.SetExtension(m.MessageSet, msetextpb.E_ExtRequired_MessageSetExtension, &msetextpb.ExtRequired{})
270			return m
271		}()},
272		wire: protopack.Message{
273			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
274				protopack.Tag{1, protopack.StartGroupType},
275				protopack.Tag{2, protopack.VarintType}, protopack.Varint(1002),
276				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{}),
277				protopack.Tag{1, protopack.EndGroupType},
278			}),
279		}.Marshal(),
280	},
281}
282
283var messageSetInvalidTestProtos = []testProto{
284	{
285		desc: "MessageSet with type id 0",
286		decodeTo: []proto.Message{
287			(*messagesetpb.MessageSetContainer)(nil),
288		},
289		wire: protopack.Message{
290			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
291				protopack.Tag{1, protopack.StartGroupType},
292				protopack.Tag{2, protopack.VarintType}, protopack.Uvarint(0),
293				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{}),
294				protopack.Tag{1, protopack.EndGroupType},
295			}),
296		}.Marshal(),
297	},
298	{
299		desc: "MessageSet with type id overflowing int32",
300		decodeTo: []proto.Message{
301			(*messagesetpb.MessageSetContainer)(nil),
302		},
303		wire: protopack.Message{
304			protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
305				protopack.Tag{1, protopack.StartGroupType},
306				protopack.Tag{2, protopack.VarintType}, protopack.Uvarint(0x80000000),
307				protopack.Tag{3, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{}),
308				protopack.Tag{1, protopack.EndGroupType},
309			}),
310		}.Marshal(),
311	},
312}
313