1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"google.golang.org/protobuf/reflect/protoreflect"
9)
10
11// SetDefaults sets unpopulated scalar fields to their default values.
12// Fields within a oneof are not set even if they have a default value.
13// SetDefaults is recursively called upon any populated message fields.
14func SetDefaults(m Message) {
15	if m != nil {
16		setDefaults(MessageReflect(m))
17	}
18}
19
20func setDefaults(m protoreflect.Message) {
21	fds := m.Descriptor().Fields()
22	for i := 0; i < fds.Len(); i++ {
23		fd := fds.Get(i)
24		if !m.Has(fd) {
25			if fd.HasDefault() && fd.ContainingOneof() == nil {
26				v := fd.Default()
27				if fd.Kind() == protoreflect.BytesKind {
28					v = protoreflect.ValueOf(append([]byte(nil), v.Bytes()...)) // copy the default bytes
29				}
30				m.Set(fd, v)
31			}
32			continue
33		}
34	}
35
36	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
37		switch {
38		// Handle singular message.
39		case fd.Cardinality() != protoreflect.Repeated:
40			if fd.Message() != nil {
41				setDefaults(m.Get(fd).Message())
42			}
43		// Handle list of messages.
44		case fd.IsList():
45			if fd.Message() != nil {
46				ls := m.Get(fd).List()
47				for i := 0; i < ls.Len(); i++ {
48					setDefaults(ls.Get(i).Message())
49				}
50			}
51		// Handle map of messages.
52		case fd.IsMap():
53			if fd.MapValue().Message() != nil {
54				ms := m.Get(fd).Map()
55				ms.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool {
56					setDefaults(v.Message())
57					return true
58				})
59			}
60		}
61		return true
62	})
63}
64