1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"reflect"
10	"sort"
11
12	"google.golang.org/protobuf/encoding/protowire"
13	"google.golang.org/protobuf/internal/encoding/messageset"
14	"google.golang.org/protobuf/internal/fieldsort"
15	pref "google.golang.org/protobuf/reflect/protoreflect"
16	piface "google.golang.org/protobuf/runtime/protoiface"
17)
18
19// coderMessageInfo contains per-message information used by the fast-path functions.
20// This is a different type from MessageInfo to keep MessageInfo as general-purpose as
21// possible.
22type coderMessageInfo struct {
23	methods piface.Methods
24
25	orderedCoderFields []*coderFieldInfo
26	denseCoderFields   []*coderFieldInfo
27	coderFields        map[protowire.Number]*coderFieldInfo
28	sizecacheOffset    offset
29	unknownOffset      offset
30	extensionOffset    offset
31	needsInitCheck     bool
32	isMessageSet       bool
33	numRequiredFields  uint8
34}
35
36type coderFieldInfo struct {
37	funcs      pointerCoderFuncs // fast-path per-field functions
38	mi         *MessageInfo      // field's message
39	ft         reflect.Type
40	validation validationInfo   // information used by message validation
41	num        pref.FieldNumber // field number
42	offset     offset           // struct field offset
43	wiretag    uint64           // field tag (number + wire type)
44	tagsize    int              // size of the varint-encoded tag
45	isPointer  bool             // true if IsNil may be called on the struct field
46	isRequired bool             // true if field is required
47}
48
49func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
50	mi.sizecacheOffset = si.sizecacheOffset
51	mi.unknownOffset = si.unknownOffset
52	mi.extensionOffset = si.extensionOffset
53
54	mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
55	fields := mi.Desc.Fields()
56	preallocFields := make([]coderFieldInfo, fields.Len())
57	for i := 0; i < fields.Len(); i++ {
58		fd := fields.Get(i)
59
60		fs := si.fieldsByNumber[fd.Number()]
61		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
62		if isOneof {
63			fs = si.oneofsByName[fd.ContainingOneof().Name()]
64		}
65		ft := fs.Type
66		var wiretag uint64
67		if !fd.IsPacked() {
68			wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
69		} else {
70			wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType)
71		}
72		var fieldOffset offset
73		var funcs pointerCoderFuncs
74		var childMessage *MessageInfo
75		switch {
76		case isOneof:
77			fieldOffset = offsetOf(fs, mi.Exporter)
78		case fd.IsWeak():
79			fieldOffset = si.weakOffset
80			funcs = makeWeakMessageFieldCoder(fd)
81		default:
82			fieldOffset = offsetOf(fs, mi.Exporter)
83			childMessage, funcs = fieldCoder(fd, ft)
84		}
85		cf := &preallocFields[i]
86		*cf = coderFieldInfo{
87			num:        fd.Number(),
88			offset:     fieldOffset,
89			wiretag:    wiretag,
90			ft:         ft,
91			tagsize:    protowire.SizeVarint(wiretag),
92			funcs:      funcs,
93			mi:         childMessage,
94			validation: newFieldValidationInfo(mi, si, fd, ft),
95			isPointer:  fd.Cardinality() == pref.Repeated || fd.HasPresence(),
96			isRequired: fd.Cardinality() == pref.Required,
97		}
98		mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
99		mi.coderFields[cf.num] = cf
100	}
101	for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
102		if od := oneofs.Get(i); !od.IsSynthetic() {
103			mi.initOneofFieldCoders(od, si)
104		}
105	}
106	if messageset.IsMessageSet(mi.Desc) {
107		if !mi.extensionOffset.IsValid() {
108			panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
109		}
110		if !mi.unknownOffset.IsValid() {
111			panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName()))
112		}
113		mi.isMessageSet = true
114	}
115	sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
116		return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
117	})
118
119	var maxDense pref.FieldNumber
120	for _, cf := range mi.orderedCoderFields {
121		if cf.num >= 16 && cf.num >= 2*maxDense {
122			break
123		}
124		maxDense = cf.num
125	}
126	mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
127	for _, cf := range mi.orderedCoderFields {
128		if int(cf.num) >= len(mi.denseCoderFields) {
129			break
130		}
131		mi.denseCoderFields[cf.num] = cf
132	}
133
134	// To preserve compatibility with historic wire output, marshal oneofs last.
135	if mi.Desc.Oneofs().Len() > 0 {
136		sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
137			fi := fields.ByNumber(mi.orderedCoderFields[i].num)
138			fj := fields.ByNumber(mi.orderedCoderFields[j].num)
139			return fieldsort.Less(fi, fj)
140		})
141	}
142
143	mi.needsInitCheck = needsInitCheck(mi.Desc)
144	if mi.methods.Marshal == nil && mi.methods.Size == nil {
145		mi.methods.Flags |= piface.SupportMarshalDeterministic
146		mi.methods.Marshal = mi.marshal
147		mi.methods.Size = mi.size
148	}
149	if mi.methods.Unmarshal == nil {
150		mi.methods.Flags |= piface.SupportUnmarshalDiscardUnknown
151		mi.methods.Unmarshal = mi.unmarshal
152	}
153	if mi.methods.CheckInitialized == nil {
154		mi.methods.CheckInitialized = mi.checkInitialized
155	}
156	if mi.methods.Merge == nil {
157		mi.methods.Merge = mi.merge
158	}
159}
160