1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"sync"
9
10	"google.golang.org/protobuf/internal/errors"
11	pref "google.golang.org/protobuf/reflect/protoreflect"
12	piface "google.golang.org/protobuf/runtime/protoiface"
13)
14
15func (mi *MessageInfo) checkInitialized(in piface.CheckInitializedInput) (piface.CheckInitializedOutput, error) {
16	var p pointer
17	if ms, ok := in.Message.(*messageState); ok {
18		p = ms.pointer()
19	} else {
20		p = in.Message.(*messageReflectWrapper).pointer()
21	}
22	return piface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
23}
24
25func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
26	mi.init()
27	if !mi.needsInitCheck {
28		return nil
29	}
30	if p.IsNil() {
31		for _, f := range mi.orderedCoderFields {
32			if f.isRequired {
33				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
34			}
35		}
36		return nil
37	}
38	if mi.extensionOffset.IsValid() {
39		e := p.Apply(mi.extensionOffset).Extensions()
40		if err := mi.isInitExtensions(e); err != nil {
41			return err
42		}
43	}
44	for _, f := range mi.orderedCoderFields {
45		if !f.isRequired && f.funcs.isInit == nil {
46			continue
47		}
48		fptr := p.Apply(f.offset)
49		if f.isPointer && fptr.Elem().IsNil() {
50			if f.isRequired {
51				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
52			}
53			continue
54		}
55		if f.funcs.isInit == nil {
56			continue
57		}
58		if err := f.funcs.isInit(fptr, f); err != nil {
59			return err
60		}
61	}
62	return nil
63}
64
65func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
66	if ext == nil {
67		return nil
68	}
69	for _, x := range *ext {
70		ei := getExtensionFieldInfo(x.Type())
71		if ei.funcs.isInit == nil {
72			continue
73		}
74		v := x.Value()
75		if !v.IsValid() {
76			continue
77		}
78		if err := ei.funcs.isInit(v); err != nil {
79			return err
80		}
81	}
82	return nil
83}
84
// needsInitCheckMu serializes the computation of needsInitCheck results.
var needsInitCheckMu sync.Mutex

// needsInitCheckMap caches needsInitCheck results keyed by message
// descriptor. Only bool values are final answers; readers must treat any
// other value as "not yet known".
var needsInitCheckMap sync.Map
89
90// needsInitCheck reports whether a message needs to be checked for partial initialization.
91//
92// It returns true if the message transitively includes any required or extension fields.
93func needsInitCheck(md pref.MessageDescriptor) bool {
94	if v, ok := needsInitCheckMap.Load(md); ok {
95		if has, ok := v.(bool); ok {
96			return has
97		}
98	}
99	needsInitCheckMu.Lock()
100	defer needsInitCheckMu.Unlock()
101	return needsInitCheckLocked(md)
102}
103
104func needsInitCheckLocked(md pref.MessageDescriptor) (has bool) {
105	if v, ok := needsInitCheckMap.Load(md); ok {
106		// If has is true, we've previously determined that this message
107		// needs init checks.
108		//
109		// If has is false, we've previously determined that it can never
110		// be uninitialized.
111		//
112		// If has is not a bool, we've just encountered a cycle in the
113		// message graph. In this case, it is safe to return false: If
114		// the message does have required fields, we'll detect them later
115		// in the graph traversal.
116		has, ok := v.(bool)
117		return ok && has
118	}
119	needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
120	defer func() {
121		needsInitCheckMap.Store(md, has)
122	}()
123	if md.RequiredNumbers().Len() > 0 {
124		return true
125	}
126	if md.ExtensionRanges().Len() > 0 {
127		return true
128	}
129	for i := 0; i < md.Fields().Len(); i++ {
130		fd := md.Fields().Get(i)
131		// Map keys are never messages, so just consider the map value.
132		if fd.IsMap() {
133			fd = fd.MapValue()
134		}
135		fmd := fd.Message()
136		if fmd != nil && needsInitCheckLocked(fmd) {
137			return true
138		}
139	}
140	return false
141}
142