1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"reflect"
10
11	"google.golang.org/protobuf/proto"
12	pref "google.golang.org/protobuf/reflect/protoreflect"
13	piface "google.golang.org/protobuf/runtime/protoiface"
14)
15
16type mergeOptions struct{}
17
18func (o mergeOptions) Merge(dst, src proto.Message) {
19	proto.Merge(dst, src)
20}
21
22// merge is protoreflect.Methods.Merge.
23func (mi *MessageInfo) merge(in piface.MergeInput) piface.MergeOutput {
24	dp, ok := mi.getPointer(in.Destination)
25	if !ok {
26		return piface.MergeOutput{}
27	}
28	sp, ok := mi.getPointer(in.Source)
29	if !ok {
30		return piface.MergeOutput{}
31	}
32	mi.mergePointer(dp, sp, mergeOptions{})
33	return piface.MergeOutput{Flags: piface.MergeComplete}
34}
35
// mergePointer merges the message at src into the message at dst, both in the
// pointer representation used by mi. NOTE(review): both are assumed to be
// messages of mi's Go type; that is not checked here (merge checks it via
// getPointer).
//
// The merge proceeds in three phases:
//   - known fields: each coder field's merge function is applied;
//   - extensions: merged per field number via the extension's merge function;
//   - unknown fields: src's raw unknown bytes are appended to dst's.
//
// It panics if dst is nil; a nil src is a no-op.
func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
	// Presumably populates the coder tables (orderedCoderFields, offsets)
	// read below.
	mi.init()
	if dst.IsNil() {
		// NOTE(review): fmt.Sprintf with no formatting arguments is flagged by
		// vet/staticcheck, but it is this file's only use of "fmt".
		panic(fmt.Sprintf("invalid value: merging into nil message"))
	}
	if src.IsNil() {
		return
	}
	// Known fields. Fields without a merge function are skipped, as are
	// pointer-typed fields that are unset in src, so nothing gets allocated
	// in dst for them.
	for _, f := range mi.orderedCoderFields {
		if f.funcs.merge == nil {
			continue
		}
		sfptr := src.Apply(f.offset)
		if f.isPointer && sfptr.Elem().IsNil() {
			continue
		}
		f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts)
	}
	// Extension fields.
	if mi.extensionOffset.IsValid() {
		sext := src.Apply(mi.extensionOffset).Extensions()
		dext := dst.Apply(mi.extensionOffset).Extensions()
		// Make sure dst has a map to write into. NOTE(review): this allocates
		// even when src has no extensions.
		if *dext == nil {
			*dext = make(map[int32]ExtensionField)
		}
		for num, sx := range *sext {
			xt := sx.Type()
			xi := getExtensionFieldInfo(xt)
			if xi.funcs.merge == nil {
				continue
			}
			dx := (*dext)[num]
			var dv pref.Value
			// Reuse dst's current value only if it holds the same extension
			// type; otherwise dv stays the invalid zero Value.
			if dx.Type() == sx.Type() {
				dv = dx.Value()
			}
			// Merge functions such as mergeListValue/mergeMessageValue mutate
			// dv in place, so they need a valid starting value. Seed one of the
			// extension's type when xi says it is required.
			if !dv.IsValid() && xi.unmarshalNeedsValue {
				dv = xt.New()
			}
			dv = xi.funcs.merge(dv, sx.Value(), opts)
			dx.Set(sx.Type(), dv)
			// dx is a local copy of the map entry, so write the update back.
			(*dext)[num] = dx
		}
	}
	// Unknown fields: append src's raw bytes after dst's existing ones.
	if mi.unknownOffset.IsValid() {
		du := dst.Apply(mi.unknownOffset).Bytes()
		su := src.Apply(mi.unknownOffset).Bytes()
		if len(*su) > 0 {
			*du = append(*du, *su...)
		}
	}
}
87
88func mergeScalarValue(dst, src pref.Value, opts mergeOptions) pref.Value {
89	return src
90}
91
92func mergeBytesValue(dst, src pref.Value, opts mergeOptions) pref.Value {
93	return pref.ValueOfBytes(append(emptyBuf[:], src.Bytes()...))
94}
95
96func mergeListValue(dst, src pref.Value, opts mergeOptions) pref.Value {
97	dstl := dst.List()
98	srcl := src.List()
99	for i, llen := 0, srcl.Len(); i < llen; i++ {
100		dstl.Append(srcl.Get(i))
101	}
102	return dst
103}
104
105func mergeBytesListValue(dst, src pref.Value, opts mergeOptions) pref.Value {
106	dstl := dst.List()
107	srcl := src.List()
108	for i, llen := 0, srcl.Len(); i < llen; i++ {
109		sb := srcl.Get(i).Bytes()
110		db := append(emptyBuf[:], sb...)
111		dstl.Append(pref.ValueOfBytes(db))
112	}
113	return dst
114}
115
116func mergeMessageListValue(dst, src pref.Value, opts mergeOptions) pref.Value {
117	dstl := dst.List()
118	srcl := src.List()
119	for i, llen := 0, srcl.Len(); i < llen; i++ {
120		sm := srcl.Get(i).Message()
121		dm := proto.Clone(sm.Interface()).ProtoReflect()
122		dstl.Append(pref.ValueOfMessage(dm))
123	}
124	return dst
125}
126
127func mergeMessageValue(dst, src pref.Value, opts mergeOptions) pref.Value {
128	opts.Merge(dst.Message().Interface(), src.Message().Interface())
129	return dst
130}
131
132func mergeMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
133	if f.mi != nil {
134		if dst.Elem().IsNil() {
135			dst.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
136		}
137		f.mi.mergePointer(dst.Elem(), src.Elem(), opts)
138	} else {
139		dm := dst.AsValueOf(f.ft).Elem()
140		sm := src.AsValueOf(f.ft).Elem()
141		if dm.IsNil() {
142			dm.Set(reflect.New(f.ft.Elem()))
143		}
144		opts.Merge(asMessage(dm), asMessage(sm))
145	}
146}
147
148func mergeMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
149	for _, sp := range src.PointerSlice() {
150		dm := reflect.New(f.ft.Elem().Elem())
151		if f.mi != nil {
152			f.mi.mergePointer(pointerOfValue(dm), sp, opts)
153		} else {
154			opts.Merge(asMessage(dm), asMessage(sp.AsValueOf(f.ft.Elem().Elem())))
155		}
156		dst.AppendPointerSlice(pointerOfValue(dm))
157	}
158}
159
160func mergeBytes(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
161	*dst.Bytes() = append(emptyBuf[:], *src.Bytes()...)
162}
163
164func mergeBytesNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
165	v := *src.Bytes()
166	if len(v) > 0 {
167		*dst.Bytes() = append(emptyBuf[:], v...)
168	}
169}
170
171func mergeBytesSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
172	ds := dst.BytesSlice()
173	for _, v := range *src.BytesSlice() {
174		*ds = append(*ds, append(emptyBuf[:], v...))
175	}
176}
177