1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2016 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto
33
34import (
35	"fmt"
36	"reflect"
37	"strings"
38	"sync"
39	"sync/atomic"
40)
41
42// Merge merges the src message into dst.
43// This assumes that dst and src of the same type and are non-nil.
44func (a *InternalMessageInfo) Merge(dst, src Message) {
45	mi := atomicLoadMergeInfo(&a.merge)
46	if mi == nil {
47		mi = getMergeInfo(reflect.TypeOf(dst).Elem())
48		atomicStoreMergeInfo(&a.merge, mi)
49	}
50	mi.merge(toPointer(&dst), toPointer(&src))
51}
52
53type mergeInfo struct {
54	typ reflect.Type
55
56	initialized int32 // 0: only typ is valid, 1: everything is valid
57	lock        sync.Mutex
58
59	fields       []mergeFieldInfo
60	unrecognized field // Offset of XXX_unrecognized
61}
62
63type mergeFieldInfo struct {
64	field field // Offset of field, guaranteed to be valid
65
66	// isPointer reports whether the value in the field is a pointer.
67	// This is true for the following situations:
68	//	* Pointer to struct
69	//	* Pointer to basic type (proto2 only)
70	//	* Slice (first value in slice header is a pointer)
71	//	* String (first value in string header is a pointer)
72	isPointer bool
73
74	// basicWidth reports the width of the field assuming that it is directly
75	// embedded in the struct (as is the case for basic types in proto3).
76	// The possible values are:
77	// 	0: invalid
78	//	1: bool
79	//	4: int32, uint32, float32
80	//	8: int64, uint64, float64
81	basicWidth int
82
83	// Where dst and src are pointers to the types being merged.
84	merge func(dst, src pointer)
85}
86
87var (
88	mergeInfoMap  = map[reflect.Type]*mergeInfo{}
89	mergeInfoLock sync.Mutex
90)
91
92func getMergeInfo(t reflect.Type) *mergeInfo {
93	mergeInfoLock.Lock()
94	defer mergeInfoLock.Unlock()
95	mi := mergeInfoMap[t]
96	if mi == nil {
97		mi = &mergeInfo{typ: t}
98		mergeInfoMap[t] = mi
99	}
100	return mi
101}
102
103// merge merges src into dst assuming they are both of type *mi.typ.
104func (mi *mergeInfo) merge(dst, src pointer) {
105	if dst.isNil() {
106		panic("proto: nil destination")
107	}
108	if src.isNil() {
109		return // Nothing to do.
110	}
111
112	if atomic.LoadInt32(&mi.initialized) == 0 {
113		mi.computeMergeInfo()
114	}
115
116	for _, fi := range mi.fields {
117		sfp := src.offset(fi.field)
118
119		// As an optimization, we can avoid the merge function call cost
120		// if we know for sure that the source will have no effect
121		// by checking if it is the zero value.
122		if unsafeAllowed {
123			if fi.isPointer && sfp.getPointer().isNil() { // Could be slice or string
124				continue
125			}
126			if fi.basicWidth > 0 {
127				switch {
128				case fi.basicWidth == 1 && !*sfp.toBool():
129					continue
130				case fi.basicWidth == 4 && *sfp.toUint32() == 0:
131					continue
132				case fi.basicWidth == 8 && *sfp.toUint64() == 0:
133					continue
134				}
135			}
136		}
137
138		dfp := dst.offset(fi.field)
139		fi.merge(dfp, sfp)
140	}
141
142	// TODO: Make this faster?
143	out := dst.asPointerTo(mi.typ).Elem()
144	in := src.asPointerTo(mi.typ).Elem()
145	if emIn, err := extendable(in.Addr().Interface()); err == nil {
146		emOut, _ := extendable(out.Addr().Interface())
147		mIn, muIn := emIn.extensionsRead()
148		if mIn != nil {
149			mOut := emOut.extensionsWrite()
150			muIn.Lock()
151			mergeExtension(mOut, mIn)
152			muIn.Unlock()
153		}
154	}
155
156	if mi.unrecognized.IsValid() {
157		if b := *src.offset(mi.unrecognized).toBytes(); len(b) > 0 {
158			*dst.offset(mi.unrecognized).toBytes() = append([]byte(nil), b...)
159		}
160	}
161}
162
163func (mi *mergeInfo) computeMergeInfo() {
164	mi.lock.Lock()
165	defer mi.lock.Unlock()
166	if mi.initialized != 0 {
167		return
168	}
169	t := mi.typ
170	n := t.NumField()
171
172	props := GetProperties(t)
173	for i := 0; i < n; i++ {
174		f := t.Field(i)
175		if strings.HasPrefix(f.Name, "XXX_") {
176			continue
177		}
178
179		mfi := mergeFieldInfo{field: toField(&f)}
180		tf := f.Type
181
182		// As an optimization, we can avoid the merge function call cost
183		// if we know for sure that the source will have no effect
184		// by checking if it is the zero value.
185		if unsafeAllowed {
186			switch tf.Kind() {
187			case reflect.Ptr, reflect.Slice, reflect.String:
188				// As a special case, we assume slices and strings are pointers
189				// since we know that the first field in the SliceSlice or
190				// StringHeader is a data pointer.
191				mfi.isPointer = true
192			case reflect.Bool:
193				mfi.basicWidth = 1
194			case reflect.Int32, reflect.Uint32, reflect.Float32:
195				mfi.basicWidth = 4
196			case reflect.Int64, reflect.Uint64, reflect.Float64:
197				mfi.basicWidth = 8
198			}
199		}
200
201		// Unwrap tf to get at its most basic type.
202		var isPointer, isSlice bool
203		if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
204			isSlice = true
205			tf = tf.Elem()
206		}
207		if tf.Kind() == reflect.Ptr {
208			isPointer = true
209			tf = tf.Elem()
210		}
211		if isPointer && isSlice && tf.Kind() != reflect.Struct {
212			panic("both pointer and slice for basic type in " + tf.Name())
213		}
214
215		switch tf.Kind() {
216		case reflect.Int32:
217			switch {
218			case isSlice: // E.g., []int32
219				mfi.merge = func(dst, src pointer) {
220					// NOTE: toInt32Slice is not defined (see pointer_reflect.go).
221					/*
222						sfsp := src.toInt32Slice()
223						if *sfsp != nil {
224							dfsp := dst.toInt32Slice()
225							*dfsp = append(*dfsp, *sfsp...)
226							if *dfsp == nil {
227								*dfsp = []int64{}
228							}
229						}
230					*/
231					sfs := src.getInt32Slice()
232					if sfs != nil {
233						dfs := dst.getInt32Slice()
234						dfs = append(dfs, sfs...)
235						if dfs == nil {
236							dfs = []int32{}
237						}
238						dst.setInt32Slice(dfs)
239					}
240				}
241			case isPointer: // E.g., *int32
242				mfi.merge = func(dst, src pointer) {
243					// NOTE: toInt32Ptr is not defined (see pointer_reflect.go).
244					/*
245						sfpp := src.toInt32Ptr()
246						if *sfpp != nil {
247							dfpp := dst.toInt32Ptr()
248							if *dfpp == nil {
249								*dfpp = Int32(**sfpp)
250							} else {
251								**dfpp = **sfpp
252							}
253						}
254					*/
255					sfp := src.getInt32Ptr()
256					if sfp != nil {
257						dfp := dst.getInt32Ptr()
258						if dfp == nil {
259							dst.setInt32Ptr(*sfp)
260						} else {
261							*dfp = *sfp
262						}
263					}
264				}
265			default: // E.g., int32
266				mfi.merge = func(dst, src pointer) {
267					if v := *src.toInt32(); v != 0 {
268						*dst.toInt32() = v
269					}
270				}
271			}
272		case reflect.Int64:
273			switch {
274			case isSlice: // E.g., []int64
275				mfi.merge = func(dst, src pointer) {
276					sfsp := src.toInt64Slice()
277					if *sfsp != nil {
278						dfsp := dst.toInt64Slice()
279						*dfsp = append(*dfsp, *sfsp...)
280						if *dfsp == nil {
281							*dfsp = []int64{}
282						}
283					}
284				}
285			case isPointer: // E.g., *int64
286				mfi.merge = func(dst, src pointer) {
287					sfpp := src.toInt64Ptr()
288					if *sfpp != nil {
289						dfpp := dst.toInt64Ptr()
290						if *dfpp == nil {
291							*dfpp = Int64(**sfpp)
292						} else {
293							**dfpp = **sfpp
294						}
295					}
296				}
297			default: // E.g., int64
298				mfi.merge = func(dst, src pointer) {
299					if v := *src.toInt64(); v != 0 {
300						*dst.toInt64() = v
301					}
302				}
303			}
304		case reflect.Uint32:
305			switch {
306			case isSlice: // E.g., []uint32
307				mfi.merge = func(dst, src pointer) {
308					sfsp := src.toUint32Slice()
309					if *sfsp != nil {
310						dfsp := dst.toUint32Slice()
311						*dfsp = append(*dfsp, *sfsp...)
312						if *dfsp == nil {
313							*dfsp = []uint32{}
314						}
315					}
316				}
317			case isPointer: // E.g., *uint32
318				mfi.merge = func(dst, src pointer) {
319					sfpp := src.toUint32Ptr()
320					if *sfpp != nil {
321						dfpp := dst.toUint32Ptr()
322						if *dfpp == nil {
323							*dfpp = Uint32(**sfpp)
324						} else {
325							**dfpp = **sfpp
326						}
327					}
328				}
329			default: // E.g., uint32
330				mfi.merge = func(dst, src pointer) {
331					if v := *src.toUint32(); v != 0 {
332						*dst.toUint32() = v
333					}
334				}
335			}
336		case reflect.Uint64:
337			switch {
338			case isSlice: // E.g., []uint64
339				mfi.merge = func(dst, src pointer) {
340					sfsp := src.toUint64Slice()
341					if *sfsp != nil {
342						dfsp := dst.toUint64Slice()
343						*dfsp = append(*dfsp, *sfsp...)
344						if *dfsp == nil {
345							*dfsp = []uint64{}
346						}
347					}
348				}
349			case isPointer: // E.g., *uint64
350				mfi.merge = func(dst, src pointer) {
351					sfpp := src.toUint64Ptr()
352					if *sfpp != nil {
353						dfpp := dst.toUint64Ptr()
354						if *dfpp == nil {
355							*dfpp = Uint64(**sfpp)
356						} else {
357							**dfpp = **sfpp
358						}
359					}
360				}
361			default: // E.g., uint64
362				mfi.merge = func(dst, src pointer) {
363					if v := *src.toUint64(); v != 0 {
364						*dst.toUint64() = v
365					}
366				}
367			}
368		case reflect.Float32:
369			switch {
370			case isSlice: // E.g., []float32
371				mfi.merge = func(dst, src pointer) {
372					sfsp := src.toFloat32Slice()
373					if *sfsp != nil {
374						dfsp := dst.toFloat32Slice()
375						*dfsp = append(*dfsp, *sfsp...)
376						if *dfsp == nil {
377							*dfsp = []float32{}
378						}
379					}
380				}
381			case isPointer: // E.g., *float32
382				mfi.merge = func(dst, src pointer) {
383					sfpp := src.toFloat32Ptr()
384					if *sfpp != nil {
385						dfpp := dst.toFloat32Ptr()
386						if *dfpp == nil {
387							*dfpp = Float32(**sfpp)
388						} else {
389							**dfpp = **sfpp
390						}
391					}
392				}
393			default: // E.g., float32
394				mfi.merge = func(dst, src pointer) {
395					if v := *src.toFloat32(); v != 0 {
396						*dst.toFloat32() = v
397					}
398				}
399			}
400		case reflect.Float64:
401			switch {
402			case isSlice: // E.g., []float64
403				mfi.merge = func(dst, src pointer) {
404					sfsp := src.toFloat64Slice()
405					if *sfsp != nil {
406						dfsp := dst.toFloat64Slice()
407						*dfsp = append(*dfsp, *sfsp...)
408						if *dfsp == nil {
409							*dfsp = []float64{}
410						}
411					}
412				}
413			case isPointer: // E.g., *float64
414				mfi.merge = func(dst, src pointer) {
415					sfpp := src.toFloat64Ptr()
416					if *sfpp != nil {
417						dfpp := dst.toFloat64Ptr()
418						if *dfpp == nil {
419							*dfpp = Float64(**sfpp)
420						} else {
421							**dfpp = **sfpp
422						}
423					}
424				}
425			default: // E.g., float64
426				mfi.merge = func(dst, src pointer) {
427					if v := *src.toFloat64(); v != 0 {
428						*dst.toFloat64() = v
429					}
430				}
431			}
432		case reflect.Bool:
433			switch {
434			case isSlice: // E.g., []bool
435				mfi.merge = func(dst, src pointer) {
436					sfsp := src.toBoolSlice()
437					if *sfsp != nil {
438						dfsp := dst.toBoolSlice()
439						*dfsp = append(*dfsp, *sfsp...)
440						if *dfsp == nil {
441							*dfsp = []bool{}
442						}
443					}
444				}
445			case isPointer: // E.g., *bool
446				mfi.merge = func(dst, src pointer) {
447					sfpp := src.toBoolPtr()
448					if *sfpp != nil {
449						dfpp := dst.toBoolPtr()
450						if *dfpp == nil {
451							*dfpp = Bool(**sfpp)
452						} else {
453							**dfpp = **sfpp
454						}
455					}
456				}
457			default: // E.g., bool
458				mfi.merge = func(dst, src pointer) {
459					if v := *src.toBool(); v {
460						*dst.toBool() = v
461					}
462				}
463			}
464		case reflect.String:
465			switch {
466			case isSlice: // E.g., []string
467				mfi.merge = func(dst, src pointer) {
468					sfsp := src.toStringSlice()
469					if *sfsp != nil {
470						dfsp := dst.toStringSlice()
471						*dfsp = append(*dfsp, *sfsp...)
472						if *dfsp == nil {
473							*dfsp = []string{}
474						}
475					}
476				}
477			case isPointer: // E.g., *string
478				mfi.merge = func(dst, src pointer) {
479					sfpp := src.toStringPtr()
480					if *sfpp != nil {
481						dfpp := dst.toStringPtr()
482						if *dfpp == nil {
483							*dfpp = String(**sfpp)
484						} else {
485							**dfpp = **sfpp
486						}
487					}
488				}
489			default: // E.g., string
490				mfi.merge = func(dst, src pointer) {
491					if v := *src.toString(); v != "" {
492						*dst.toString() = v
493					}
494				}
495			}
496		case reflect.Slice:
497			isProto3 := props.Prop[i].proto3
498			switch {
499			case isPointer:
500				panic("bad pointer in byte slice case in " + tf.Name())
501			case tf.Elem().Kind() != reflect.Uint8:
502				panic("bad element kind in byte slice case in " + tf.Name())
503			case isSlice: // E.g., [][]byte
504				mfi.merge = func(dst, src pointer) {
505					sbsp := src.toBytesSlice()
506					if *sbsp != nil {
507						dbsp := dst.toBytesSlice()
508						for _, sb := range *sbsp {
509							if sb == nil {
510								*dbsp = append(*dbsp, nil)
511							} else {
512								*dbsp = append(*dbsp, append([]byte{}, sb...))
513							}
514						}
515						if *dbsp == nil {
516							*dbsp = [][]byte{}
517						}
518					}
519				}
520			default: // E.g., []byte
521				mfi.merge = func(dst, src pointer) {
522					sbp := src.toBytes()
523					if *sbp != nil {
524						dbp := dst.toBytes()
525						if !isProto3 || len(*sbp) > 0 {
526							*dbp = append([]byte{}, *sbp...)
527						}
528					}
529				}
530			}
531		case reflect.Struct:
532			switch {
533			case !isPointer:
534				panic(fmt.Sprintf("message field %s without pointer", tf))
535			case isSlice: // E.g., []*pb.T
536				mi := getMergeInfo(tf)
537				mfi.merge = func(dst, src pointer) {
538					sps := src.getPointerSlice()
539					if sps != nil {
540						dps := dst.getPointerSlice()
541						for _, sp := range sps {
542							var dp pointer
543							if !sp.isNil() {
544								dp = valToPointer(reflect.New(tf))
545								mi.merge(dp, sp)
546							}
547							dps = append(dps, dp)
548						}
549						if dps == nil {
550							dps = []pointer{}
551						}
552						dst.setPointerSlice(dps)
553					}
554				}
555			default: // E.g., *pb.T
556				mi := getMergeInfo(tf)
557				mfi.merge = func(dst, src pointer) {
558					sp := src.getPointer()
559					if !sp.isNil() {
560						dp := dst.getPointer()
561						if dp.isNil() {
562							dp = valToPointer(reflect.New(tf))
563							dst.setPointer(dp)
564						}
565						mi.merge(dp, sp)
566					}
567				}
568			}
569		case reflect.Map:
570			switch {
571			case isPointer || isSlice:
572				panic("bad pointer or slice in map case in " + tf.Name())
573			default: // E.g., map[K]V
574				mfi.merge = func(dst, src pointer) {
575					sm := src.asPointerTo(tf).Elem()
576					if sm.Len() == 0 {
577						return
578					}
579					dm := dst.asPointerTo(tf).Elem()
580					if dm.IsNil() {
581						dm.Set(reflect.MakeMap(tf))
582					}
583
584					switch tf.Elem().Kind() {
585					case reflect.Ptr: // Proto struct (e.g., *T)
586						for _, key := range sm.MapKeys() {
587							val := sm.MapIndex(key)
588							val = reflect.ValueOf(Clone(val.Interface().(Message)))
589							dm.SetMapIndex(key, val)
590						}
591					case reflect.Slice: // E.g. Bytes type (e.g., []byte)
592						for _, key := range sm.MapKeys() {
593							val := sm.MapIndex(key)
594							val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
595							dm.SetMapIndex(key, val)
596						}
597					default: // Basic type (e.g., string)
598						for _, key := range sm.MapKeys() {
599							val := sm.MapIndex(key)
600							dm.SetMapIndex(key, val)
601						}
602					}
603				}
604			}
605		case reflect.Interface:
606			// Must be oneof field.
607			switch {
608			case isPointer || isSlice:
609				panic("bad pointer or slice in interface case in " + tf.Name())
610			default: // E.g., interface{}
611				// TODO: Make this faster?
612				mfi.merge = func(dst, src pointer) {
613					su := src.asPointerTo(tf).Elem()
614					if !su.IsNil() {
615						du := dst.asPointerTo(tf).Elem()
616						typ := su.Elem().Type()
617						if du.IsNil() || du.Elem().Type() != typ {
618							du.Set(reflect.New(typ.Elem())) // Initialize interface if empty
619						}
620						sv := su.Elem().Elem().Field(0)
621						if sv.Kind() == reflect.Ptr && sv.IsNil() {
622							return
623						}
624						dv := du.Elem().Elem().Field(0)
625						if dv.Kind() == reflect.Ptr && dv.IsNil() {
626							dv.Set(reflect.New(sv.Type().Elem())) // Initialize proto message if empty
627						}
628						switch sv.Type().Kind() {
629						case reflect.Ptr: // Proto struct (e.g., *T)
630							Merge(dv.Interface().(Message), sv.Interface().(Message))
631						case reflect.Slice: // E.g. Bytes type (e.g., []byte)
632							dv.Set(reflect.ValueOf(append([]byte{}, sv.Bytes()...)))
633						default: // Basic type (e.g., string)
634							dv.Set(sv)
635						}
636					}
637				}
638			}
639		default:
640			panic(fmt.Sprintf("merger not found for type:%s", tf))
641		}
642		mi.fields = append(mi.fields, mfi)
643	}
644
645	mi.unrecognized = invalidField
646	if f, ok := t.FieldByName("XXX_unrecognized"); ok {
647		if f.Type != reflect.TypeOf([]byte{}) {
648			panic("expected XXX_unrecognized to be of type []byte")
649		}
650		mi.unrecognized = toField(&f)
651	}
652
653	atomic.StoreInt32(&mi.initialized, 1)
654}
655