1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto_test
6
7import (
8	"bytes"
9	"fmt"
10	"reflect"
11	"testing"
12
13	"google.golang.org/protobuf/encoding/prototext"
14	"google.golang.org/protobuf/proto"
15	"google.golang.org/protobuf/reflect/protoreflect"
16	"google.golang.org/protobuf/testing/protopack"
17
18	"google.golang.org/protobuf/internal/errors"
19	testpb "google.golang.org/protobuf/internal/testprotos/test"
20	test3pb "google.golang.org/protobuf/internal/testprotos/test3"
21)
22
23func TestDecode(t *testing.T) {
24	for _, test := range testValidMessages {
25		if len(test.decodeTo) == 0 {
26			t.Errorf("%v: no test message types", test.desc)
27		}
28		for _, want := range test.decodeTo {
29			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
30				opts := test.unmarshalOptions
31				opts.AllowPartial = test.partial
32				wire := append(([]byte)(nil), test.wire...)
33				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
34				if err := opts.Unmarshal(wire, got); err != nil {
35					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, prototext.Format(want))
36					return
37				}
38
39				// Aliasing check: Unmarshal shouldn't modify the original wire
40				// bytes, and modifying the original wire bytes shouldn't affect
41				// the unmarshaled message.
42				if !bytes.Equal(test.wire, wire) {
43					t.Errorf("Unmarshal unexpectedly modified its input")
44				}
45				for i := range wire {
46					wire[i] = 0
47				}
48				if !proto.Equal(got, want) && got.ProtoReflect().IsValid() && want.ProtoReflect().IsValid() {
49					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", prototext.Format(got), prototext.Format(want))
50				}
51			})
52		}
53	}
54}
55
56func TestDecodeRequiredFieldChecks(t *testing.T) {
57	for _, test := range testValidMessages {
58		if !test.partial {
59			continue
60		}
61		for _, m := range test.decodeTo {
62			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
63				opts := test.unmarshalOptions
64				opts.AllowPartial = false
65				got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
66				if err := proto.Unmarshal(test.wire, got); err == nil {
67					t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", prototext.Format(got))
68				}
69			})
70		}
71	}
72}
73
74func TestDecodeInvalidMessages(t *testing.T) {
75	for _, test := range testInvalidMessages {
76		if len(test.decodeTo) == 0 {
77			t.Errorf("%v: no test message types", test.desc)
78		}
79		for _, want := range test.decodeTo {
80			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
81				opts := test.unmarshalOptions
82				opts.AllowPartial = test.partial
83				got := want.ProtoReflect().New().Interface()
84				if err := opts.Unmarshal(test.wire, got); err == nil {
85					t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, prototext.Format(got))
86				} else if !errors.Is(err, proto.Error) {
87					t.Errorf("Unmarshal error is not a proto.Error: %v", err)
88				}
89			})
90		}
91	}
92}
93
94func TestDecodeZeroLengthBytes(t *testing.T) {
95	// Verify that proto3 bytes fields don't give the mistaken
96	// impression that they preserve presence.
97	wire := protopack.Message{
98		protopack.Tag{94, protopack.BytesType}, protopack.Bytes(nil),
99	}.Marshal()
100	m := &test3pb.TestAllTypes{}
101	if err := proto.Unmarshal(wire, m); err != nil {
102		t.Fatal(err)
103	}
104	if m.OptionalBytes != nil {
105		t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
106	}
107}
108
109func TestDecodeOneofNilWrapper(t *testing.T) {
110	wire := protopack.Message{
111		protopack.Tag{111, protopack.VarintType}, protopack.Varint(1111),
112	}.Marshal()
113	m := &testpb.TestAllTypes{OneofField: (*testpb.TestAllTypes_OneofUint32)(nil)}
114	if err := proto.Unmarshal(wire, m); err != nil {
115		t.Fatal(err)
116	}
117	if got := m.GetOneofUint32(); got != 1111 {
118		t.Errorf("GetOneofUint32() = %v, want %v", got, 1111)
119	}
120}
121
122func TestDecodeEmptyBytes(t *testing.T) {
123	// There's really nothing wrong with a nil entry in a [][]byte,
124	// but we take care to produce non-nil []bytes for zero-length
125	// byte strings, so test for it.
126	m := &testpb.TestAllTypes{}
127	b := protopack.Message{
128		protopack.Tag{45, protopack.BytesType}, protopack.Bytes(nil),
129	}.Marshal()
130	if err := proto.Unmarshal(b, m); err != nil {
131		t.Fatal(err)
132	}
133	if m.RepeatedBytes[0] == nil {
134		t.Errorf("unmarshaling repeated bytes field containing zero-length value: Got nil bytes, want non-nil")
135	}
136}
137
138func build(m proto.Message, opts ...buildOpt) proto.Message {
139	for _, opt := range opts {
140		opt(m)
141	}
142	return m
143}
144
// buildOpt is a function that is handed a message by build; the options
// defined in this file use it to modify that message in place.
type buildOpt func(proto.Message)
146
147func unknown(raw protoreflect.RawFields) buildOpt {
148	return func(m proto.Message) {
149		m.ProtoReflect().SetUnknown(raw)
150	}
151}
152
153func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
154	return func(m proto.Message) {
155		proto.SetExtension(m, desc, value)
156	}
157}
158