1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"sync"
9	"sync/atomic"
10
11	"google.golang.org/protobuf/encoding/protowire"
12	"google.golang.org/protobuf/internal/errors"
13	pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
// extensionFieldInfo holds precomputed coding information for a single
// extension field. Instances are built by makeExtensionFieldInfo.
type extensionFieldInfo struct {
	// wiretag is the encoded tag for the field: its number combined with
	// BytesType when packed, or with its kind's wire type otherwise.
	wiretag             uint64
	// tagsize is the varint-encoded length of wiretag in bytes.
	tagsize             int
	// unmarshalNeedsValue reports whether funcs.unmarshal must be given an
	// existing value (message, group, enum, or repeated fields).
	unmarshalNeedsValue bool
	// funcs are the value coder functions for the field's kind/cardinality.
	funcs               valueCoderFuncs
	// validation is not populated by makeExtensionFieldInfo.
	// NOTE(review): presumably filled in by the caller — confirm.
	validation          validationInfo
}
23
// legacyExtensionFieldInfoCache is the cache consulted by
// legacyLoadExtensionFieldInfo for extension types that are not *ExtensionInfo.
var legacyExtensionFieldInfoCache sync.Map // map[protoreflect.ExtensionType]*extensionFieldInfo
25
26func getExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
27	if xi, ok := xt.(*ExtensionInfo); ok {
28		xi.lazyInit()
29		return xi.info
30	}
31	return legacyLoadExtensionFieldInfo(xt)
32}
33
34// legacyLoadExtensionFieldInfo dynamically loads a *ExtensionInfo for xt.
35func legacyLoadExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
36	if xi, ok := legacyExtensionFieldInfoCache.Load(xt); ok {
37		return xi.(*extensionFieldInfo)
38	}
39	e := makeExtensionFieldInfo(xt.TypeDescriptor())
40	if e, ok := legacyMessageTypeCache.LoadOrStore(xt, e); ok {
41		return e.(*extensionFieldInfo)
42	}
43	return e
44}
45
46func makeExtensionFieldInfo(xd pref.ExtensionDescriptor) *extensionFieldInfo {
47	var wiretag uint64
48	if !xd.IsPacked() {
49		wiretag = protowire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
50	} else {
51		wiretag = protowire.EncodeTag(xd.Number(), protowire.BytesType)
52	}
53	e := &extensionFieldInfo{
54		wiretag: wiretag,
55		tagsize: protowire.SizeVarint(wiretag),
56		funcs:   encoderFuncsForValue(xd),
57	}
58	// Does the unmarshal function need a value passed to it?
59	// This is true for composite types, where we pass in a message, list, or map to fill in,
60	// and for enums, where we pass in a prototype value to specify the concrete enum type.
61	switch xd.Kind() {
62	case pref.MessageKind, pref.GroupKind, pref.EnumKind:
63		e.unmarshalNeedsValue = true
64	default:
65		if xd.Cardinality() == pref.Repeated {
66			e.unmarshalNeedsValue = true
67		}
68	}
69	return e
70}
71
// lazyExtensionValue is an extension value whose computation is deferred.
// It is produced either from buffered wire bytes (xi and b, accumulated by
// appendLazyBytes) or from a function (fn, set by SetLazy), and is computed
// at most once by lazyInit.
type lazyExtensionValue struct {
	// atomicOnce is read without holding mu by Value and canLazy; lazyInit
	// sets it to 1 only after value has been written.
	atomicOnce uint32 // atomically set if value is valid
	// mu serializes lazyInit.
	mu         sync.Mutex
	// xi decodes b; cleared after initialization.
	xi         *extensionFieldInfo
	// value is the computed value, valid once atomicOnce == 1.
	value      pref.Value
	// b holds concatenated tag+payload records awaiting decoding; cleared after initialization.
	b          []byte
	// fn computes the value when xi is nil; cleared after initialization.
	fn         func() pref.Value
}
80
// ExtensionField holds the type and value of a single extension field.
// The value is either stored directly (Set) or deferred in a
// lazyExtensionValue and computed on first access (SetLazy, or wire bytes
// buffered by appendLazyBytes).
type ExtensionField struct {
	// typ is the extension type; nil means the field is unset.
	typ pref.ExtensionType

	// value is the value stored by Set; Value returns it when lazy is nil.
	// lazy, when non-nil, supplies the value instead (see Value).
	value pref.Value
	lazy  *lazyExtensionValue
}
89
90func (f *ExtensionField) appendLazyBytes(xt pref.ExtensionType, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, b []byte) {
91	if f.lazy == nil {
92		f.lazy = &lazyExtensionValue{xi: xi}
93	}
94	f.typ = xt
95	f.lazy.xi = xi
96	f.lazy.b = protowire.AppendTag(f.lazy.b, num, wtyp)
97	f.lazy.b = append(f.lazy.b, b...)
98}
99
100func (f *ExtensionField) canLazy(xt pref.ExtensionType) bool {
101	if f.typ == nil {
102		return true
103	}
104	if f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
105		return true
106	}
107	return false
108}
109
// lazyInit computes f.lazy.value at most once. It either decodes the wire
// records buffered in f.lazy.b using f.lazy.xi, or, when xi is nil, calls
// f.lazy.fn. It panics if a buffered tag or payload cannot be decoded.
//
// Value checks f.lazy.atomicOnce without locking before calling this; the
// flag is re-checked here under f.lazy.mu so concurrent callers compute the
// value only once.
func (f *ExtensionField) lazyInit() {
	f.lazy.mu.Lock()
	defer f.lazy.mu.Unlock()
	// Another goroutine may have completed initialization while we waited.
	if atomic.LoadUint32(&f.lazy.atomicOnce) == 1 {
		return
	}
	if f.lazy.xi != nil {
		// Decode each buffered (tag, payload) record into val in order.
		b := f.lazy.b
		val := f.typ.New()
		for len(b) > 0 {
			var tag uint64
			// Fast paths for 1- and 2-byte varint tags; longer tags fall
			// back to protowire.ConsumeVarint.
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = protowire.ConsumeVarint(b)
				if n < 0 {
					panic(errors.New("bad tag in lazy extension decoding"))
				}
				b = b[n:]
			}
			// The tag packs the field number above the low 3 wire-type bits.
			num := protowire.Number(tag >> 3)
			wtyp := protowire.Type(tag & 7)
			var out unmarshalOutput
			var err error
			// unmarshal returns the updated value, which is fed back in for
			// the next record (presumably accumulating repeated elements).
			val, out, err = f.lazy.xi.funcs.unmarshal(b, val, num, wtyp, lazyUnmarshalOptions)
			if err != nil {
				panic(errors.New("decode failure in lazy extension decoding: %v", err))
			}
			// Advance past the payload bytes unmarshal consumed.
			b = b[out.n:]
		}
		f.lazy.value = val
	} else {
		f.lazy.value = f.lazy.fn()
	}
	// Release the inputs; they are no longer needed once value is computed.
	f.lazy.xi = nil
	f.lazy.fn = nil
	f.lazy.b = nil
	// Publish the flag last, so readers that observe 1 without the lock
	// also observe the value written above.
	atomic.StoreUint32(&f.lazy.atomicOnce, 1)
}
154
155// Set sets the type and value of the extension field.
156// This must not be called concurrently.
157func (f *ExtensionField) Set(t pref.ExtensionType, v pref.Value) {
158	f.typ = t
159	f.value = v
160	f.lazy = nil
161}
162
163// SetLazy sets the type and a value that is to be lazily evaluated upon first use.
164// This must not be called concurrently.
165func (f *ExtensionField) SetLazy(t pref.ExtensionType, fn func() pref.Value) {
166	f.typ = t
167	f.lazy = &lazyExtensionValue{fn: fn}
168}
169
170// Value returns the value of the extension field.
171// This may be called concurrently.
172func (f *ExtensionField) Value() pref.Value {
173	if f.lazy != nil {
174		if atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
175			f.lazyInit()
176		}
177		return f.lazy.value
178	}
179	return f.value
180}
181
// Type returns the type of the extension field, or nil if the field is unset.
// This may be called concurrently.
func (f ExtensionField) Type() pref.ExtensionType {
	return f.typ
}
187
// IsSet reports whether the extension field is set, i.e. whether it has a
// non-nil type.
// This may be called concurrently.
func (f ExtensionField) IsSet() bool {
	return f.typ != nil
}
193
194// IsLazy reports whether a field is lazily encoded.
195// It is exported for testing.
196func IsLazy(m pref.Message, fd pref.FieldDescriptor) bool {
197	var mi *MessageInfo
198	var p pointer
199	switch m := m.(type) {
200	case *messageState:
201		mi = m.messageInfo()
202		p = m.pointer()
203	case *messageReflectWrapper:
204		mi = m.messageInfo()
205		p = m.pointer()
206	default:
207		return false
208	}
209	xd, ok := fd.(pref.ExtensionTypeDescriptor)
210	if !ok {
211		return false
212	}
213	xt := xd.Type()
214	ext := mi.extensionMap(p)
215	if ext == nil {
216		return false
217	}
218	f, ok := (*ext)[int32(fd.Number())]
219	if !ok {
220		return false
221	}
222	return f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0
223}
224