1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"google.golang.org/protobuf/internal/errors"
9	"google.golang.org/protobuf/reflect/protoreflect"
10	"google.golang.org/protobuf/runtime/protoiface"
11)
12
13// CheckInitialized returns an error if any required fields in m are not set.
14func CheckInitialized(m Message) error {
15	// Treat a nil message interface as an "untyped" empty message,
16	// which we assume to have no required fields.
17	if m == nil {
18		return nil
19	}
20
21	return checkInitialized(m.ProtoReflect())
22}
23
24// CheckInitialized returns an error if any required fields in m are not set.
25func checkInitialized(m protoreflect.Message) error {
26	if methods := protoMethods(m); methods != nil && methods.CheckInitialized != nil {
27		_, err := methods.CheckInitialized(protoiface.CheckInitializedInput{
28			Message: m,
29		})
30		return err
31	}
32	return checkInitializedSlow(m)
33}
34
35func checkInitializedSlow(m protoreflect.Message) error {
36	md := m.Descriptor()
37	fds := md.Fields()
38	for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
39		fd := fds.ByNumber(nums.Get(i))
40		if !m.Has(fd) {
41			return errors.RequiredNotSet(string(fd.FullName()))
42		}
43	}
44	var err error
45	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
46		switch {
47		case fd.IsList():
48			if fd.Message() == nil {
49				return true
50			}
51			for i, list := 0, v.List(); i < list.Len() && err == nil; i++ {
52				err = checkInitialized(list.Get(i).Message())
53			}
54		case fd.IsMap():
55			if fd.MapValue().Message() == nil {
56				return true
57			}
58			v.Map().Range(func(key protoreflect.MapKey, v protoreflect.Value) bool {
59				err = checkInitialized(v.Message())
60				return err == nil
61			})
62		default:
63			if fd.Message() == nil {
64				return true
65			}
66			err = checkInitialized(v.Message())
67		}
68		return err == nil
69	})
70	return err
71}
72