1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package dynamicpb creates protocol buffer messages using runtime type information.
6package dynamicpb
7
8import (
9	"math"
10
11	"google.golang.org/protobuf/internal/errors"
12	pref "google.golang.org/protobuf/reflect/protoreflect"
13	"google.golang.org/protobuf/runtime/protoiface"
14	"google.golang.org/protobuf/runtime/protoimpl"
15)
16
// A Message is a dynamically constructed protocol buffer message.
//
// Message implements the proto.Message interface, and may be used with all
// standard proto package functions such as Marshal, Unmarshal, and so forth.
//
// Message also implements the protoreflect.Message interface. See the protoreflect
// package documentation for that interface for how to get and set fields and
// otherwise interact with the contents of a Message.
//
// Reflection API functions which construct messages, such as NewField,
// return new dynamic messages of the appropriate type. Functions which take
// messages, such as Set for a message-value field, will accept any message
// with a compatible type.
//
// Operations which modify a Message are not safe for concurrent use.
type Message struct {
	// typ wraps the message descriptor this message was created with.
	typ     messageType
	// known holds stored field values (regular fields and extensions),
	// keyed by field number. A nil map marks a read-only zero-value
	// message: IsValid reports false and Set, Mutable, and SetUnknown panic.
	known   map[pref.FieldNumber]pref.Value
	// ext records, per field number, the extension descriptor whose value
	// is stored in known; Range and Get use it to resolve extension fields.
	ext     map[pref.FieldNumber]pref.FieldDescriptor
	// unknown holds the raw unknown-field bytes set via SetUnknown.
	unknown pref.RawFields
}
38
39var (
40	_ pref.Message         = (*Message)(nil)
41	_ pref.ProtoMessage    = (*Message)(nil)
42	_ protoiface.MessageV1 = (*Message)(nil)
43)
44
45// NewMessage creates a new message with the provided descriptor.
46func NewMessage(desc pref.MessageDescriptor) *Message {
47	return &Message{
48		typ:   messageType{desc},
49		known: make(map[pref.FieldNumber]pref.Value),
50		ext:   make(map[pref.FieldNumber]pref.FieldDescriptor),
51	}
52}
53
54// ProtoMessage implements the legacy message interface.
55func (m *Message) ProtoMessage() {}
56
57// ProtoReflect implements the protoreflect.ProtoMessage interface.
58func (m *Message) ProtoReflect() pref.Message {
59	return m
60}
61
62// String returns a string representation of a message.
63func (m *Message) String() string {
64	return protoimpl.X.MessageStringOf(m)
65}
66
67// Reset clears the message to be empty, but preserves the dynamic message type.
68func (m *Message) Reset() {
69	m.known = make(map[pref.FieldNumber]pref.Value)
70	m.ext = make(map[pref.FieldNumber]pref.FieldDescriptor)
71	m.unknown = nil
72}
73
74// Descriptor returns the message descriptor.
75func (m *Message) Descriptor() pref.MessageDescriptor {
76	return m.typ.desc
77}
78
79// Type returns the message type.
80func (m *Message) Type() pref.MessageType {
81	return m.typ
82}
83
84// New returns a newly allocated empty message with the same descriptor.
85// See protoreflect.Message for details.
86func (m *Message) New() pref.Message {
87	return m.Type().New()
88}
89
90// Interface returns the message.
91// See protoreflect.Message for details.
92func (m *Message) Interface() pref.ProtoMessage {
93	return m
94}
95
96// ProtoMethods is an internal detail of the protoreflect.Message interface.
97// Users should never call this directly.
98func (m *Message) ProtoMethods() *protoiface.Methods {
99	return nil
100}
101
102// Range visits every populated field in undefined order.
103// See protoreflect.Message for details.
104func (m *Message) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
105	for num, v := range m.known {
106		fd := m.ext[num]
107		if fd == nil {
108			fd = m.Descriptor().Fields().ByNumber(num)
109		}
110		if !isSet(fd, v) {
111			continue
112		}
113		if !f(fd, v) {
114			return
115		}
116	}
117}
118
119// Has reports whether a field is populated.
120// See protoreflect.Message for details.
121func (m *Message) Has(fd pref.FieldDescriptor) bool {
122	m.checkField(fd)
123	if fd.IsExtension() && m.ext[fd.Number()] != fd {
124		return false
125	}
126	v, ok := m.known[fd.Number()]
127	if !ok {
128		return false
129	}
130	return isSet(fd, v)
131}
132
133// Clear clears a field.
134// See protoreflect.Message for details.
135func (m *Message) Clear(fd pref.FieldDescriptor) {
136	m.checkField(fd)
137	num := fd.Number()
138	delete(m.known, num)
139	delete(m.ext, num)
140}
141
142// Get returns the value of a field.
143// See protoreflect.Message for details.
144func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
145	m.checkField(fd)
146	num := fd.Number()
147	if fd.IsExtension() {
148		if fd != m.ext[num] {
149			return fd.(pref.ExtensionTypeDescriptor).Type().Zero()
150		}
151		return m.known[num]
152	}
153	if v, ok := m.known[num]; ok {
154		switch {
155		case fd.IsMap():
156			if v.Map().Len() > 0 {
157				return v
158			}
159		case fd.IsList():
160			if v.List().Len() > 0 {
161				return v
162			}
163		default:
164			return v
165		}
166	}
167	switch {
168	case fd.IsMap():
169		return pref.ValueOfMap(&dynamicMap{desc: fd})
170	case fd.IsList():
171		return pref.ValueOfList(emptyList{desc: fd})
172	case fd.Message() != nil:
173		return pref.ValueOfMessage(&Message{typ: messageType{fd.Message()}})
174	case fd.Kind() == pref.BytesKind:
175		return pref.ValueOfBytes(append([]byte(nil), fd.Default().Bytes()...))
176	default:
177		return fd.Default()
178	}
179}
180
181// Mutable returns a mutable reference to a repeated, map, or message field.
182// See protoreflect.Message for details.
183func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
184	m.checkField(fd)
185	if !fd.IsMap() && !fd.IsList() && fd.Message() == nil {
186		panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
187	}
188	if m.known == nil {
189		panic(errors.New("%v: modification of read-only message", fd.FullName()))
190	}
191	num := fd.Number()
192	if fd.IsExtension() {
193		if fd != m.ext[num] {
194			m.ext[num] = fd
195			m.known[num] = fd.(pref.ExtensionTypeDescriptor).Type().New()
196		}
197		return m.known[num]
198	}
199	if v, ok := m.known[num]; ok {
200		return v
201	}
202	m.clearOtherOneofFields(fd)
203	m.known[num] = m.NewField(fd)
204	if fd.IsExtension() {
205		m.ext[num] = fd
206	}
207	return m.known[num]
208}
209
210// Set stores a value in a field.
211// See protoreflect.Message for details.
212func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) {
213	m.checkField(fd)
214	if m.known == nil {
215		panic(errors.New("%v: modification of read-only message", fd.FullName()))
216	}
217	if fd.IsExtension() {
218		isValid := true
219		switch {
220		case !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v):
221			isValid = false
222		case fd.IsList():
223			isValid = v.List().IsValid()
224		case fd.IsMap():
225			isValid = v.Map().IsValid()
226		case fd.Message() != nil:
227			isValid = v.Message().IsValid()
228		}
229		if !isValid {
230			panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
231		}
232		m.ext[fd.Number()] = fd
233	} else {
234		typecheck(fd, v)
235	}
236	m.clearOtherOneofFields(fd)
237	m.known[fd.Number()] = v
238}
239
240func (m *Message) clearOtherOneofFields(fd pref.FieldDescriptor) {
241	od := fd.ContainingOneof()
242	if od == nil {
243		return
244	}
245	num := fd.Number()
246	for i := 0; i < od.Fields().Len(); i++ {
247		if n := od.Fields().Get(i).Number(); n != num {
248			delete(m.known, n)
249		}
250	}
251}
252
253// NewField returns a new value for assignable to the field of a given descriptor.
254// See protoreflect.Message for details.
255func (m *Message) NewField(fd pref.FieldDescriptor) pref.Value {
256	m.checkField(fd)
257	switch {
258	case fd.IsExtension():
259		return fd.(pref.ExtensionTypeDescriptor).Type().New()
260	case fd.IsMap():
261		return pref.ValueOfMap(&dynamicMap{
262			desc: fd,
263			mapv: make(map[interface{}]pref.Value),
264		})
265	case fd.IsList():
266		return pref.ValueOfList(&dynamicList{desc: fd})
267	case fd.Message() != nil:
268		return pref.ValueOfMessage(NewMessage(fd.Message()).ProtoReflect())
269	default:
270		return fd.Default()
271	}
272}
273
274// WhichOneof reports which field in a oneof is populated, returning nil if none are populated.
275// See protoreflect.Message for details.
276func (m *Message) WhichOneof(od pref.OneofDescriptor) pref.FieldDescriptor {
277	for i := 0; i < od.Fields().Len(); i++ {
278		fd := od.Fields().Get(i)
279		if m.Has(fd) {
280			return fd
281		}
282	}
283	return nil
284}
285
286// GetUnknown returns the raw unknown fields.
287// See protoreflect.Message for details.
288func (m *Message) GetUnknown() pref.RawFields {
289	return m.unknown
290}
291
292// SetUnknown sets the raw unknown fields.
293// See protoreflect.Message for details.
294func (m *Message) SetUnknown(r pref.RawFields) {
295	if m.known == nil {
296		panic(errors.New("%v: modification of read-only message", m.typ.desc.FullName()))
297	}
298	m.unknown = r
299}
300
301// IsValid reports whether the message is valid.
302// See protoreflect.Message for details.
303func (m *Message) IsValid() bool {
304	return m.known != nil
305}
306
307func (m *Message) checkField(fd pref.FieldDescriptor) {
308	if fd.IsExtension() && fd.ContainingMessage().FullName() == m.Descriptor().FullName() {
309		if _, ok := fd.(pref.ExtensionTypeDescriptor); !ok {
310			panic(errors.New("%v: extension field descriptor does not implement ExtensionTypeDescriptor", fd.FullName()))
311		}
312		return
313	}
314	if fd.Parent() == m.Descriptor() {
315		return
316	}
317	fields := m.Descriptor().Fields()
318	index := fd.Index()
319	if index >= fields.Len() || fields.Get(index) != fd {
320		panic(errors.New("%v: field descriptor does not belong to this message", fd.FullName()))
321	}
322}
323
324type messageType struct {
325	desc pref.MessageDescriptor
326}
327
328// NewMessageType creates a new MessageType with the provided descriptor.
329//
330// MessageTypes created by this package are equal if their descriptors are equal.
331// That is, if md1 == md2, then NewMessageType(md1) == NewMessageType(md2).
332func NewMessageType(desc pref.MessageDescriptor) pref.MessageType {
333	return messageType{desc}
334}
335
336func (mt messageType) New() pref.Message                  { return NewMessage(mt.desc) }
337func (mt messageType) Zero() pref.Message                 { return &Message{typ: messageType{mt.desc}} }
338func (mt messageType) Descriptor() pref.MessageDescriptor { return mt.desc }
339
340type emptyList struct {
341	desc pref.FieldDescriptor
342}
343
344func (x emptyList) Len() int                  { return 0 }
345func (x emptyList) Get(n int) pref.Value      { panic(errors.New("out of range")) }
346func (x emptyList) Set(n int, v pref.Value)   { panic(errors.New("modification of immutable list")) }
347func (x emptyList) Append(v pref.Value)       { panic(errors.New("modification of immutable list")) }
348func (x emptyList) AppendMutable() pref.Value { panic(errors.New("modification of immutable list")) }
349func (x emptyList) Truncate(n int)            { panic(errors.New("modification of immutable list")) }
350func (x emptyList) NewElement() pref.Value    { return newListEntry(x.desc) }
351func (x emptyList) IsValid() bool             { return false }
352
353type dynamicList struct {
354	desc pref.FieldDescriptor
355	list []pref.Value
356}
357
358func (x *dynamicList) Len() int {
359	return len(x.list)
360}
361
362func (x *dynamicList) Get(n int) pref.Value {
363	return x.list[n]
364}
365
366func (x *dynamicList) Set(n int, v pref.Value) {
367	typecheckSingular(x.desc, v)
368	x.list[n] = v
369}
370
371func (x *dynamicList) Append(v pref.Value) {
372	typecheckSingular(x.desc, v)
373	x.list = append(x.list, v)
374}
375
376func (x *dynamicList) AppendMutable() pref.Value {
377	if x.desc.Message() == nil {
378		panic(errors.New("%v: invalid AppendMutable on list with non-message type", x.desc.FullName()))
379	}
380	v := x.NewElement()
381	x.Append(v)
382	return v
383}
384
385func (x *dynamicList) Truncate(n int) {
386	// Zero truncated elements to avoid keeping data live.
387	for i := n; i < len(x.list); i++ {
388		x.list[i] = pref.Value{}
389	}
390	x.list = x.list[:n]
391}
392
393func (x *dynamicList) NewElement() pref.Value {
394	return newListEntry(x.desc)
395}
396
397func (x *dynamicList) IsValid() bool {
398	return true
399}
400
401type dynamicMap struct {
402	desc pref.FieldDescriptor
403	mapv map[interface{}]pref.Value
404}
405
406func (x *dynamicMap) Get(k pref.MapKey) pref.Value { return x.mapv[k.Interface()] }
407func (x *dynamicMap) Set(k pref.MapKey, v pref.Value) {
408	typecheckSingular(x.desc.MapKey(), k.Value())
409	typecheckSingular(x.desc.MapValue(), v)
410	x.mapv[k.Interface()] = v
411}
412func (x *dynamicMap) Has(k pref.MapKey) bool { return x.Get(k).IsValid() }
413func (x *dynamicMap) Clear(k pref.MapKey)    { delete(x.mapv, k.Interface()) }
414func (x *dynamicMap) Mutable(k pref.MapKey) pref.Value {
415	if x.desc.MapValue().Message() == nil {
416		panic(errors.New("%v: invalid Mutable on map with non-message value type", x.desc.FullName()))
417	}
418	v := x.Get(k)
419	if !v.IsValid() {
420		v = x.NewValue()
421		x.Set(k, v)
422	}
423	return v
424}
425func (x *dynamicMap) Len() int { return len(x.mapv) }
426func (x *dynamicMap) NewValue() pref.Value {
427	if md := x.desc.MapValue().Message(); md != nil {
428		return pref.ValueOfMessage(NewMessage(md).ProtoReflect())
429	}
430	return x.desc.MapValue().Default()
431}
432func (x *dynamicMap) IsValid() bool {
433	return x.mapv != nil
434}
435
436func (x *dynamicMap) Range(f func(pref.MapKey, pref.Value) bool) {
437	for k, v := range x.mapv {
438		if !f(pref.ValueOf(k).MapKey(), v) {
439			return
440		}
441	}
442}
443
444func isSet(fd pref.FieldDescriptor, v pref.Value) bool {
445	switch {
446	case fd.IsMap():
447		return v.Map().Len() > 0
448	case fd.IsList():
449		return v.List().Len() > 0
450	case fd.ContainingOneof() != nil:
451		return true
452	case fd.Syntax() == pref.Proto3 && !fd.IsExtension():
453		switch fd.Kind() {
454		case pref.BoolKind:
455			return v.Bool()
456		case pref.EnumKind:
457			return v.Enum() != 0
458		case pref.Int32Kind, pref.Sint32Kind, pref.Int64Kind, pref.Sint64Kind, pref.Sfixed32Kind, pref.Sfixed64Kind:
459			return v.Int() != 0
460		case pref.Uint32Kind, pref.Uint64Kind, pref.Fixed32Kind, pref.Fixed64Kind:
461			return v.Uint() != 0
462		case pref.FloatKind, pref.DoubleKind:
463			return v.Float() != 0 || math.Signbit(v.Float())
464		case pref.StringKind:
465			return v.String() != ""
466		case pref.BytesKind:
467			return len(v.Bytes()) > 0
468		}
469	}
470	return true
471}
472
473func typecheck(fd pref.FieldDescriptor, v pref.Value) {
474	if err := typeIsValid(fd, v); err != nil {
475		panic(err)
476	}
477}
478
479func typeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
480	switch {
481	case !v.IsValid():
482		return errors.New("%v: assigning invalid value", fd.FullName())
483	case fd.IsMap():
484		if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd || !mapv.IsValid() {
485			return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
486		}
487		return nil
488	case fd.IsList():
489		switch list := v.Interface().(type) {
490		case *dynamicList:
491			if list.desc == fd && list.IsValid() {
492				return nil
493			}
494		case emptyList:
495			if list.desc == fd && list.IsValid() {
496				return nil
497			}
498		}
499		return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
500	default:
501		return singularTypeIsValid(fd, v)
502	}
503}
504
505func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
506	if err := singularTypeIsValid(fd, v); err != nil {
507		panic(err)
508	}
509}
510
511func singularTypeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
512	vi := v.Interface()
513	var ok bool
514	switch fd.Kind() {
515	case pref.BoolKind:
516		_, ok = vi.(bool)
517	case pref.EnumKind:
518		// We could check against the valid set of enum values, but do not.
519		_, ok = vi.(pref.EnumNumber)
520	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
521		_, ok = vi.(int32)
522	case pref.Uint32Kind, pref.Fixed32Kind:
523		_, ok = vi.(uint32)
524	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
525		_, ok = vi.(int64)
526	case pref.Uint64Kind, pref.Fixed64Kind:
527		_, ok = vi.(uint64)
528	case pref.FloatKind:
529		_, ok = vi.(float32)
530	case pref.DoubleKind:
531		_, ok = vi.(float64)
532	case pref.StringKind:
533		_, ok = vi.(string)
534	case pref.BytesKind:
535		_, ok = vi.([]byte)
536	case pref.MessageKind, pref.GroupKind:
537		var m pref.Message
538		m, ok = vi.(pref.Message)
539		if ok && m.Descriptor().FullName() != fd.Message().FullName() {
540			return errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName())
541		}
542		if dm, ok := vi.(*Message); ok && dm.known == nil {
543			return errors.New("%v: assigning invalid zero-value message", fd.FullName())
544		}
545	}
546	if !ok {
547		return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
548	}
549	return nil
550}
551
552func newListEntry(fd pref.FieldDescriptor) pref.Value {
553	switch fd.Kind() {
554	case pref.BoolKind:
555		return pref.ValueOfBool(false)
556	case pref.EnumKind:
557		return pref.ValueOfEnum(fd.Enum().Values().Get(0).Number())
558	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
559		return pref.ValueOfInt32(0)
560	case pref.Uint32Kind, pref.Fixed32Kind:
561		return pref.ValueOfUint32(0)
562	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
563		return pref.ValueOfInt64(0)
564	case pref.Uint64Kind, pref.Fixed64Kind:
565		return pref.ValueOfUint64(0)
566	case pref.FloatKind:
567		return pref.ValueOfFloat32(0)
568	case pref.DoubleKind:
569		return pref.ValueOfFloat64(0)
570	case pref.StringKind:
571		return pref.ValueOfString("")
572	case pref.BytesKind:
573		return pref.ValueOfBytes(nil)
574	case pref.MessageKind, pref.GroupKind:
575		return pref.ValueOfMessage(NewMessage(fd.Message()).ProtoReflect())
576	}
577	panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
578}
579
580// extensionType is a dynamic protoreflect.ExtensionType.
581type extensionType struct {
582	desc extensionTypeDescriptor
583}
584
585// NewExtensionType creates a new ExtensionType with the provided descriptor.
586//
587// Dynamic ExtensionTypes with the same descriptor compare as equal. That is,
588// if xd1 == xd2, then NewExtensionType(xd1) == NewExtensionType(xd2).
589//
590// The InterfaceOf and ValueOf methods of the extension type are defined as:
591//
592//	func (xt extensionType) ValueOf(iv interface{}) protoreflect.Value {
593//		return protoreflect.ValueOf(iv)
594//	}
595//
596//	func (xt extensionType) InterfaceOf(v protoreflect.Value) interface{} {
597//		return v.Interface()
598//	}
599//
600// The Go type used by the proto.GetExtension and proto.SetExtension functions
601// is determined by these methods, and is therefore equivalent to the Go type
602// used to represent a protoreflect.Value. See the protoreflect.Value
603// documentation for more details.
604func NewExtensionType(desc pref.ExtensionDescriptor) pref.ExtensionType {
605	if xt, ok := desc.(pref.ExtensionTypeDescriptor); ok {
606		desc = xt.Descriptor()
607	}
608	return extensionType{extensionTypeDescriptor{desc}}
609}
610
611func (xt extensionType) New() pref.Value {
612	switch {
613	case xt.desc.IsMap():
614		return pref.ValueOfMap(&dynamicMap{
615			desc: xt.desc,
616			mapv: make(map[interface{}]pref.Value),
617		})
618	case xt.desc.IsList():
619		return pref.ValueOfList(&dynamicList{desc: xt.desc})
620	case xt.desc.Message() != nil:
621		return pref.ValueOfMessage(NewMessage(xt.desc.Message()))
622	default:
623		return xt.desc.Default()
624	}
625}
626
627func (xt extensionType) Zero() pref.Value {
628	switch {
629	case xt.desc.IsMap():
630		return pref.ValueOfMap(&dynamicMap{desc: xt.desc})
631	case xt.desc.Cardinality() == pref.Repeated:
632		return pref.ValueOfList(emptyList{desc: xt.desc})
633	case xt.desc.Message() != nil:
634		return pref.ValueOfMessage(&Message{typ: messageType{xt.desc.Message()}})
635	default:
636		return xt.desc.Default()
637	}
638}
639
640func (xt extensionType) TypeDescriptor() pref.ExtensionTypeDescriptor {
641	return xt.desc
642}
643
644func (xt extensionType) ValueOf(iv interface{}) pref.Value {
645	v := pref.ValueOf(iv)
646	typecheck(xt.desc, v)
647	return v
648}
649
650func (xt extensionType) InterfaceOf(v pref.Value) interface{} {
651	typecheck(xt.desc, v)
652	return v.Interface()
653}
654
655func (xt extensionType) IsValidInterface(iv interface{}) bool {
656	return typeIsValid(xt.desc, pref.ValueOf(iv)) == nil
657}
658
659func (xt extensionType) IsValidValue(v pref.Value) bool {
660	return typeIsValid(xt.desc, v) == nil
661}
662
663type extensionTypeDescriptor struct {
664	pref.ExtensionDescriptor
665}
666
667func (xt extensionTypeDescriptor) Type() pref.ExtensionType {
668	return extensionType{xt}
669}
670
671func (xt extensionTypeDescriptor) Descriptor() pref.ExtensionDescriptor {
672	return xt.ExtensionDescriptor
673}
674