1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"errors"
9	"reflect"
10	"sort"
11
12	"google.golang.org/protobuf/encoding/protowire"
13	pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
// mapInfo holds the per-field state used to size, marshal, unmarshal, and
// check initialization of a map field. encoderFuncsForMap builds one instance
// per field and captures it in the coder closures it returns.
type mapInfo struct {
	goType     reflect.Type    // Go map type of the field (the ft passed to encoderFuncsForMap)
	keyWiretag uint64          // encoded tag for the entry key, field 1
	valWiretag uint64          // encoded tag for the entry value, field 2
	keyFuncs   valueCoderFuncs // coders for keys in pref.Value form
	valFuncs   valueCoderFuncs // coders for values in pref.Value form
	keyZero    pref.Value      // key used when a decoded entry omits field 1 (the key field's default)
	keyKind    pref.Kind       // protobuf kind of the key field
	conv       *mapConverter   // converts keys/values between Go reflect.Value and pref.Value
}
26
27func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
28	// TODO: Consider generating specialized map coders.
29	keyField := fd.MapKey()
30	valField := fd.MapValue()
31	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
32	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
33	keyFuncs := encoderFuncsForValue(keyField)
34	valFuncs := encoderFuncsForValue(valField)
35	conv := newMapConverter(ft, fd)
36
37	mapi := &mapInfo{
38		goType:     ft,
39		keyWiretag: keyWiretag,
40		valWiretag: valWiretag,
41		keyFuncs:   keyFuncs,
42		valFuncs:   valFuncs,
43		keyZero:    keyField.Default(),
44		keyKind:    keyField.Kind(),
45		conv:       conv,
46	}
47	if valField.Kind() == pref.MessageKind {
48		valueMessage = getMessageInfo(ft.Elem())
49	}
50
51	funcs = pointerCoderFuncs{
52		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
53			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
54		},
55		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
56			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
57		},
58		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
59			mp := p.AsValueOf(ft)
60			if mp.Elem().IsNil() {
61				mp.Elem().Set(reflect.MakeMap(mapi.goType))
62			}
63			if f.mi == nil {
64				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
65			} else {
66				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
67			}
68		},
69	}
70	switch valField.Kind() {
71	case pref.MessageKind:
72		funcs.merge = mergeMapOfMessage
73	case pref.BytesKind:
74		funcs.merge = mergeMapOfBytes
75	default:
76		funcs.merge = mergeMap
77	}
78	if valFuncs.isInit != nil {
79		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
80			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
81		}
82	}
83	return valueMessage, funcs
84}
85
// Encoded sizes of the tags for the key (field 1) and value (field 2) of a map
// entry. A tag is the varint (num<<3)|wiretype; for field numbers 1 and 2 it
// is below 0x80 for every wire type, so each tag occupies a single byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
90
91func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
92	if mapv.Len() == 0 {
93		return 0
94	}
95	n := 0
96	iter := mapRange(mapv)
97	for iter.Next() {
98		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
99		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
100		var valSize int
101		value := mapi.conv.valConv.PBValueOf(iter.Value())
102		if f.mi == nil {
103			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
104		} else {
105			p := pointerOfValue(iter.Value())
106			valSize += mapValTagSize
107			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
108		}
109		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
110	}
111	return n
112}
113
114func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
115	if wtyp != protowire.BytesType {
116		return out, errUnknown
117	}
118	b, n := protowire.ConsumeBytes(b)
119	if n < 0 {
120		return out, protowire.ParseError(n)
121	}
122	var (
123		key = mapi.keyZero
124		val = mapi.conv.valConv.New()
125	)
126	for len(b) > 0 {
127		num, wtyp, n := protowire.ConsumeTag(b)
128		if n < 0 {
129			return out, protowire.ParseError(n)
130		}
131		if num > protowire.MaxValidNumber {
132			return out, errors.New("invalid field number")
133		}
134		b = b[n:]
135		err := errUnknown
136		switch num {
137		case 1:
138			var v pref.Value
139			var o unmarshalOutput
140			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
141			if err != nil {
142				break
143			}
144			key = v
145			n = o.n
146		case 2:
147			var v pref.Value
148			var o unmarshalOutput
149			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
150			if err != nil {
151				break
152			}
153			val = v
154			n = o.n
155		}
156		if err == errUnknown {
157			n = protowire.ConsumeFieldValue(num, wtyp, b)
158			if n < 0 {
159				return out, protowire.ParseError(n)
160			}
161		} else if err != nil {
162			return out, err
163		}
164		b = b[n:]
165	}
166	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
167	out.n = n
168	return out, nil
169}
170
// consumeMapOfMessage decodes one length-delimited map entry from b and
// inserts it into mapv. It is used for maps whose values are messages decoded
// through a MessageInfo (f.mi != nil). The value is unmarshaled directly into a
// newly allocated message of type f.mi.GoReflectType.Elem().
//
// Within the entry, field 1 is the key and field 2 the value. A missing key
// leaves mapi.keyZero, and a missing value leaves the new, empty message. Other
// field numbers, and a field 2 that is not length-delimited, are skipped. A
// wire type other than BytesType yields errUnknown. On success out.n is the
// number of bytes of b consumed (including the entry's length prefix).
// out.initialized is true if any decoded value reported itself initialized.
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
	if wtyp != protowire.BytesType {
		return out, errUnknown
	}
	// n is the total entry length, including its length prefix; b becomes
	// just the entry's contents.
	b, n := protowire.ConsumeBytes(b)
	if n < 0 {
		return out, protowire.ParseError(n)
	}
	// Start from defaults so an entry that omits a field still yields a map item.
	var (
		key = mapi.keyZero
		val = reflect.New(f.mi.GoReflectType.Elem())
	)
	for len(b) > 0 {
		// This n shadows the entry length above: it holds the tag size, then
		// the size of the current field's payload.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return out, protowire.ParseError(n)
		}
		if num > protowire.MaxValidNumber {
			return out, errors.New("invalid field number")
		}
		b = b[n:]
		// err shadows the named result; errUnknown means "skip this field".
		err := errUnknown
		switch num {
		case 1:
			var v pref.Value
			var o unmarshalOutput
			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
			if err != nil {
				break
			}
			key = v
			n = o.n
		case 2:
			// A message value must be length-delimited; otherwise err stays
			// errUnknown and the field is skipped below.
			if wtyp != protowire.BytesType {
				break
			}
			var v []byte
			v, n = protowire.ConsumeBytes(b)
			if n < 0 {
				return out, protowire.ParseError(n)
			}
			var o unmarshalOutput
			// Repeated value fields are unmarshaled into the same val.
			// NOTE(review): the 0 argument is presumably unmarshal flags — confirm.
			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
			if o.initialized {
				// Consider this map item initialized so long as we see
				// an initialized value.
				out.initialized = true
			}
		}
		if err == errUnknown {
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, protowire.ParseError(n)
			}
		} else if err != nil {
			return out, err
		}
		b = b[n:]
	}
	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
	// The loop's n was scoped to the loop body; this is the entry length.
	out.n = n
	return out, nil
}
234
235func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
236	if f.mi == nil {
237		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
238		val := mapi.conv.valConv.PBValueOf(valrv)
239		size := 0
240		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
241		size += mapi.valFuncs.size(val, mapValTagSize, opts)
242		b = protowire.AppendVarint(b, uint64(size))
243		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
244		if err != nil {
245			return nil, err
246		}
247		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
248	} else {
249		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
250		val := pointerOfValue(valrv)
251		valSize := f.mi.sizePointer(val, opts)
252		size := 0
253		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
254		size += mapValTagSize + protowire.SizeBytes(valSize)
255		b = protowire.AppendVarint(b, uint64(size))
256		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
257		if err != nil {
258			return nil, err
259		}
260		b = protowire.AppendVarint(b, mapi.valWiretag)
261		b = protowire.AppendVarint(b, uint64(valSize))
262		return f.mi.marshalAppendPointer(b, val, opts)
263	}
264}
265
266func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
267	if mapv.Len() == 0 {
268		return b, nil
269	}
270	if opts.Deterministic() {
271		return appendMapDeterministic(b, mapv, mapi, f, opts)
272	}
273	iter := mapRange(mapv)
274	for iter.Next() {
275		var err error
276		b = protowire.AppendVarint(b, f.wiretag)
277		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
278		if err != nil {
279			return b, err
280		}
281	}
282	return b, nil
283}
284
285func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
286	keys := mapv.MapKeys()
287	sort.Slice(keys, func(i, j int) bool {
288		switch keys[i].Kind() {
289		case reflect.Bool:
290			return !keys[i].Bool() && keys[j].Bool()
291		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
292			return keys[i].Int() < keys[j].Int()
293		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
294			return keys[i].Uint() < keys[j].Uint()
295		case reflect.Float32, reflect.Float64:
296			return keys[i].Float() < keys[j].Float()
297		case reflect.String:
298			return keys[i].String() < keys[j].String()
299		default:
300			panic("invalid kind: " + keys[i].Kind().String())
301		}
302	})
303	for _, key := range keys {
304		var err error
305		b = protowire.AppendVarint(b, f.wiretag)
306		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
307		if err != nil {
308			return b, err
309		}
310	}
311	return b, nil
312}
313
314func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
315	if mi := f.mi; mi != nil {
316		mi.init()
317		if !mi.needsInitCheck {
318			return nil
319		}
320		iter := mapRange(mapv)
321		for iter.Next() {
322			val := pointerOfValue(iter.Value())
323			if err := mi.checkInitializedPointer(val); err != nil {
324				return err
325			}
326		}
327	} else {
328		iter := mapRange(mapv)
329		for iter.Next() {
330			val := mapi.conv.valConv.PBValueOf(iter.Value())
331			if err := mapi.valFuncs.isInit(val); err != nil {
332				return err
333			}
334		}
335	}
336	return nil
337}
338
339func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
340	dstm := dst.AsValueOf(f.ft).Elem()
341	srcm := src.AsValueOf(f.ft).Elem()
342	if srcm.Len() == 0 {
343		return
344	}
345	if dstm.IsNil() {
346		dstm.Set(reflect.MakeMap(f.ft))
347	}
348	iter := mapRange(srcm)
349	for iter.Next() {
350		dstm.SetMapIndex(iter.Key(), iter.Value())
351	}
352}
353
354func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
355	dstm := dst.AsValueOf(f.ft).Elem()
356	srcm := src.AsValueOf(f.ft).Elem()
357	if srcm.Len() == 0 {
358		return
359	}
360	if dstm.IsNil() {
361		dstm.Set(reflect.MakeMap(f.ft))
362	}
363	iter := mapRange(srcm)
364	for iter.Next() {
365		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
366	}
367}
368
369func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
370	dstm := dst.AsValueOf(f.ft).Elem()
371	srcm := src.AsValueOf(f.ft).Elem()
372	if srcm.Len() == 0 {
373		return
374	}
375	if dstm.IsNil() {
376		dstm.Set(reflect.MakeMap(f.ft))
377	}
378	iter := mapRange(srcm)
379	for iter.Next() {
380		val := reflect.New(f.ft.Elem().Elem())
381		if f.mi != nil {
382			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
383		} else {
384			opts.Merge(asMessage(val), asMessage(iter.Value()))
385		}
386		dstm.SetMapIndex(iter.Key(), val)
387	}
388}
389