1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protobuild constructs messages.
6//
7// This package is used to construct multiple types of message with a similar shape
8// from a common template.
9package protobuild
10
11import (
12	"fmt"
13	"math"
14	"reflect"
15
16	pref "google.golang.org/protobuf/reflect/protoreflect"
17	"google.golang.org/protobuf/reflect/protoregistry"
18)
19
// A Value is a value assignable to a field.
// A Value may be a value accepted by protoreflect.ValueOf. In addition:
//
// • An int may be assigned to any numeric or enum field.
//
// • A float64 may be assigned to a float or double field.
//
// • Either a string or []byte may be assigned to a string or bytes field.
//
// • A string containing the value name may be assigned to an enum field.
//
// • A Message may be assigned to a message field.
//
// • A slice may be assigned to a list, and a map may be assigned to a map.
type Value interface{}
33
// A Message is a template to apply to a message. Keys are field names, including
// extension names.
//
// An extension is identified by its short name (not its full name) and must be
// registered in protoregistry.GlobalTypes for Build to find it. The value for a
// message-typed field, or for each element of a message-typed list or map, is
// itself a Message and is applied recursively by Build.
type Message map[pref.Name]Value

// Unknown is a key associated with the unknown fields of a message.
// The value should be a []byte holding the raw encoded fields. Build stores it
// with SetUnknown, which replaces any unknown fields already on the message.
const Unknown = "@unknown"
41
42// Build applies the template to a message.
43func (template Message) Build(m pref.Message) {
44	md := m.Descriptor()
45	fields := md.Fields()
46	exts := make(map[pref.Name]pref.FieldDescriptor)
47	protoregistry.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(xt pref.ExtensionType) bool {
48		xd := xt.TypeDescriptor()
49		exts[xd.Name()] = xd
50		return true
51	})
52	for k, v := range template {
53		if k == Unknown {
54			m.SetUnknown(pref.RawFields(v.([]byte)))
55			continue
56		}
57		fd := fields.ByName(k)
58		if fd == nil {
59			fd = exts[k]
60		}
61		if fd == nil {
62			panic(fmt.Sprintf("%v.%v: not found", md.FullName(), k))
63		}
64		switch {
65		case fd.IsList():
66			list := m.Mutable(fd).List()
67			s := reflect.ValueOf(v)
68			for i := 0; i < s.Len(); i++ {
69				if fd.Message() == nil {
70					list.Append(fieldValue(fd, s.Index(i).Interface()))
71				} else {
72					e := list.NewElement()
73					s.Index(i).Interface().(Message).Build(e.Message())
74					list.Append(e)
75				}
76			}
77		case fd.IsMap():
78			mapv := m.Mutable(fd).Map()
79			rm := reflect.ValueOf(v)
80			for _, k := range rm.MapKeys() {
81				mk := fieldValue(fd.MapKey(), k.Interface()).MapKey()
82				if fd.MapValue().Message() == nil {
83					mv := fieldValue(fd.MapValue(), rm.MapIndex(k).Interface())
84					mapv.Set(mk, mv)
85				} else if mapv.Has(mk) {
86					mv := mapv.Get(mk).Message()
87					rm.MapIndex(k).Interface().(Message).Build(mv)
88				} else {
89					mv := mapv.NewValue()
90					rm.MapIndex(k).Interface().(Message).Build(mv.Message())
91					mapv.Set(mk, mv)
92				}
93			}
94		default:
95			if fd.Message() == nil {
96				m.Set(fd, fieldValue(fd, v))
97			} else {
98				v.(Message).Build(m.Mutable(fd).Message())
99			}
100		}
101	}
102}
103
104func fieldValue(fd pref.FieldDescriptor, v interface{}) pref.Value {
105	switch o := v.(type) {
106	case int:
107		switch fd.Kind() {
108		case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
109			if o < math.MinInt32 || math.MaxInt32 < o {
110				panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, int32(math.MinInt32), int32(math.MaxInt32)))
111			}
112			v = int32(o)
113		case pref.Uint32Kind, pref.Fixed32Kind:
114			if o < 0 || math.MaxUint32 < 0 {
115				panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, uint32(0), uint32(math.MaxUint32)))
116			}
117			v = uint32(o)
118		case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
119			v = int64(o)
120		case pref.Uint64Kind, pref.Fixed64Kind:
121			if o < 0 {
122				panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, uint64(0), uint64(math.MaxUint64)))
123			}
124			v = uint64(o)
125		case pref.FloatKind:
126			v = float32(o)
127		case pref.DoubleKind:
128			v = float64(o)
129		case pref.EnumKind:
130			v = pref.EnumNumber(o)
131		default:
132			panic(fmt.Sprintf("%v: invalid value type int", fd.FullName()))
133		}
134	case float64:
135		switch fd.Kind() {
136		case pref.FloatKind:
137			v = float32(o)
138		}
139	case string:
140		switch fd.Kind() {
141		case pref.BytesKind:
142			v = []byte(o)
143		case pref.EnumKind:
144			v = fd.Enum().Values().ByName(pref.Name(o)).Number()
145		}
146	case []byte:
147		return pref.ValueOf(append([]byte{}, o...))
148	}
149	return pref.ValueOf(v)
150}
151