1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonpb
6
7import (
8	"encoding/json"
9	"errors"
10	"fmt"
11	"io"
12	"math"
13	"reflect"
14	"strconv"
15	"strings"
16	"time"
17
18	"github.com/golang/protobuf/proto"
19	"google.golang.org/protobuf/encoding/protojson"
20	protoV2 "google.golang.org/protobuf/proto"
21	"google.golang.org/protobuf/reflect/protoreflect"
22	"google.golang.org/protobuf/reflect/protoregistry"
23)
24
// wrapJSONUnmarshalV2 selects the implementation used by UnmarshalNext:
// when true, decoding is delegated to the protojson (v2) unmarshaler; when
// false, the legacy decoding logic in this file is used.
const wrapJSONUnmarshalV2 bool = false
26
27// UnmarshalNext unmarshals the next JSON object from d into m.
28func UnmarshalNext(d *json.Decoder, m proto.Message) error {
29	return new(Unmarshaler).UnmarshalNext(d, m)
30}
31
// Unmarshal reads a single JSON value from r and decodes it into m using a
// zero-value Unmarshaler.
func Unmarshal(r io.Reader, m proto.Message) error {
	u := &Unmarshaler{}
	return u.Unmarshal(r, m)
}
36
// UnmarshalString decodes the JSON value held in s into m using a zero-value
// Unmarshaler.
func UnmarshalString(s string, m proto.Message) error {
	r := strings.NewReader(s)
	return (&Unmarshaler{}).Unmarshal(r, m)
}
41
42// Unmarshaler is a configurable object for converting from a JSON
43// representation to a protocol buffer object.
44type Unmarshaler struct {
45	// AllowUnknownFields specifies whether to allow messages to contain
46	// unknown JSON fields, as opposed to failing to unmarshal.
47	AllowUnknownFields bool
48
49	// AnyResolver is used to resolve the google.protobuf.Any well-known type.
50	// If unset, the global registry is used by default.
51	AnyResolver AnyResolver
52}
53
54// JSONPBUnmarshaler is implemented by protobuf messages that customize the way
55// they are unmarshaled from JSON. Messages that implement this should also
56// implement JSONPBMarshaler so that the custom format can be produced.
57//
58// The JSON unmarshaling must follow the JSON to proto specification:
59//	https://developers.google.com/protocol-buffers/docs/proto3#json
60//
61// Deprecated: Custom types should implement protobuf reflection instead.
62type JSONPBUnmarshaler interface {
63	UnmarshalJSONPB(*Unmarshaler, []byte) error
64}
65
// Unmarshal reads a single JSON value from r and decodes it into m according
// to the options set on u.
func (u *Unmarshaler) Unmarshal(r io.Reader, m proto.Message) error {
	dec := json.NewDecoder(r)
	return u.UnmarshalNext(dec, m)
}
70
// UnmarshalNext reads the next JSON value from d and decodes it into m
// according to the options set on u.
//
// A nil m is rejected with an error. If m implements JSONPBUnmarshaler, the
// raw JSON is handed to it unchanged. Otherwise a top-level JSON null leaves
// m untouched (unless m is a google.protobuf.Value). On the legacy path, m is
// validated with CheckInitialized after decoding.
func (u *Unmarshaler) UnmarshalNext(d *json.Decoder, m proto.Message) error {
	if m == nil {
		return errors.New("invalid nil message")
	}

	// Parse the next JSON object from the stream.
	raw := json.RawMessage{}
	if err := d.Decode(&raw); err != nil {
		return err
	}

	// Check for custom unmarshalers first since they may not properly
	// implement protobuf reflection that the logic below relies on.
	if jsu, ok := m.(JSONPBUnmarshaler); ok {
		return jsu.UnmarshalJSONPB(u, raw)
	}

	mr := proto.MessageReflect(m)

	// NOTE: For historical reasons, a top-level null is treated as a noop.
	// This is incorrect, but kept for compatibility.
	if string(raw) == "null" && mr.Descriptor().FullName() != "google.protobuf.Value" {
		return nil
	}

	if wrapJSONUnmarshalV2 {
		// NOTE: If input message is non-empty, we need to preserve merge semantics
		// of the old jsonpb implementation. These semantics are not supported by
		// the protobuf JSON specification.
		isEmpty := true
		mr.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
			isEmpty = false // at least one iteration implies non-empty
			return false
		})
		if !isEmpty {
			// Perform unmarshaling into a newly allocated, empty message.
			mr = mr.New()

			// Use a defer to copy all unmarshaled fields into the original message.
			dst := proto.MessageReflect(m)
			defer mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
				dst.Set(fd, v)
				return true
			})
		}

		// Unmarshal using the v2 JSON unmarshaler.
		opts := protojson.UnmarshalOptions{
			DiscardUnknown: u.AllowUnknownFields,
		}
		if u.AnyResolver != nil {
			opts.Resolver = anyResolver{u.AnyResolver}
		}
		return opts.Unmarshal(raw, mr.Interface())
	}

	// Legacy path: decode with the logic in this file, then validate the
	// result with CheckInitialized.
	if err := u.unmarshalMessage(mr, raw); err != nil {
		return err
	}
	return protoV2.CheckInitialized(mr.Interface())
}
133
// unmarshalMessage decodes the JSON value in into m. Fields that do not
// appear in the JSON keep whatever value m already holds.
//
// Decoding proceeds in this order:
//   - A JSON null is a no-op unless m is a google.protobuf.Value.
//   - If m implements JSONPBUnmarshaler, the raw bytes are handed to it.
//   - Well-known types (Any, the scalar wrappers, Duration, Timestamp, Value,
//     ListValue and Struct) are decoded from their special JSON forms.
//   - Any other message must be a JSON object. Each member is matched against
//     a field's proto name (the message name for groups) and its JSON name,
//     then against registered extensions written as "[full.name]".
//
// Unless u.AllowUnknownFields is set, a member that matches nothing is an
// error. If several members are unknown, the byte-wise smallest name is
// reported, so the error does not depend on map iteration order.
func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error {
	md := m.Descriptor()
	fds := md.Fields()

	if string(in) == "null" && md.FullName() != "google.protobuf.Value" {
		return nil
	}

	// Nested messages may also carry their own custom JSON decoding.
	if jsu, ok := proto.MessageV1(m.Interface()).(JSONPBUnmarshaler); ok {
		return jsu.UnmarshalJSONPB(u, in)
	}

	switch wellKnownType(md.FullName()) {
	case "Any":
		var jsonObject map[string]json.RawMessage
		if err := json.Unmarshal(in, &jsonObject); err != nil {
			return err
		}

		rawTypeURL, ok := jsonObject["@type"]
		if !ok {
			return errors.New("Any JSON doesn't have '@type'")
		}
		typeURL, err := unquoteString(string(rawTypeURL))
		if err != nil {
			return fmt.Errorf("can't unmarshal Any's '@type': %q", rawTypeURL)
		}
		m.Set(fds.ByNumber(1), protoreflect.ValueOfString(typeURL))

		// Resolve the embedded message type, preferring the caller's resolver
		// over the global registry.
		var m2 protoreflect.Message
		if u.AnyResolver != nil {
			mi, err := u.AnyResolver.Resolve(typeURL)
			if err != nil {
				return err
			}
			m2 = proto.MessageReflect(mi)
		} else {
			mt, err := protoregistry.GlobalTypes.FindMessageByURL(typeURL)
			if err != nil {
				if err == protoregistry.NotFound {
					return fmt.Errorf("could not resolve Any message type: %v", typeURL)
				}
				return err
			}
			m2 = mt.New()
		}

		// A well-known embedded type is carried in a "value" member; any other
		// type has its fields inlined next to "@type".
		if wellKnownType(m2.Descriptor().FullName()) != "" {
			rawValue, ok := jsonObject["value"]
			if !ok {
				return errors.New("Any JSON doesn't have 'value'")
			}
			if err := u.unmarshalMessage(m2, rawValue); err != nil {
				return fmt.Errorf("can't unmarshal Any nested proto %v: %v", typeURL, err)
			}
		} else {
			delete(jsonObject, "@type")
			rawJSON, err := json.Marshal(jsonObject)
			if err != nil {
				return fmt.Errorf("can't generate JSON for Any's nested proto to be unmarshaled: %v", err)
			}
			if err = u.unmarshalMessage(m2, rawJSON); err != nil {
				return fmt.Errorf("can't unmarshal Any nested proto %v: %v", typeURL, err)
			}
		}

		// The embedded message is stored in binary wire format in field 2.
		rawWire, err := protoV2.Marshal(m2.Interface())
		if err != nil {
			return fmt.Errorf("can't marshal proto %v into Any.Value: %v", typeURL, err)
		}
		m.Set(fds.ByNumber(2), protoreflect.ValueOfBytes(rawWire))
		return nil
	case "BoolValue", "BytesValue", "StringValue",
		"Int32Value", "UInt32Value", "FloatValue",
		"Int64Value", "UInt64Value", "DoubleValue":
		// Wrapper types are represented in JSON by their bare field-1 value.
		fd := fds.ByNumber(1)
		v, err := u.unmarshalValue(m.NewField(fd), in, fd)
		if err != nil {
			return err
		}
		m.Set(fd, v)
		return nil
	case "Duration":
		v, err := unquoteString(string(in))
		if err != nil {
			return err
		}
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("bad Duration: %v", err)
		}

		// Split into whole seconds and the nanosecond remainder. Go's
		// truncating division and remainder give both parts the sign of d.
		sec := d.Nanoseconds() / 1e9
		nsec := d.Nanoseconds() % 1e9
		m.Set(fds.ByNumber(1), protoreflect.ValueOfInt64(sec))
		m.Set(fds.ByNumber(2), protoreflect.ValueOfInt32(int32(nsec)))
		return nil
	case "Timestamp":
		v, err := unquoteString(string(in))
		if err != nil {
			return err
		}
		t, err := time.Parse(time.RFC3339Nano, v)
		if err != nil {
			return fmt.Errorf("bad Timestamp: %v", err)
		}

		sec := t.Unix()
		nsec := t.Nanosecond()
		m.Set(fds.ByNumber(1), protoreflect.ValueOfInt64(sec))
		m.Set(fds.ByNumber(2), protoreflect.ValueOfInt32(int32(nsec)))
		return nil
	case "Value":
		// Select the oneof member from the shape of the raw JSON token.
		switch {
		case string(in) == "null":
			m.Set(fds.ByNumber(1), protoreflect.ValueOfEnum(0))
		case string(in) == "true":
			m.Set(fds.ByNumber(4), protoreflect.ValueOfBool(true))
		case string(in) == "false":
			m.Set(fds.ByNumber(4), protoreflect.ValueOfBool(false))
		case hasPrefixAndSuffix('"', in, '"'):
			s, err := unquoteString(string(in))
			if err != nil {
				return fmt.Errorf("unrecognized type for Value %q", in)
			}
			m.Set(fds.ByNumber(3), protoreflect.ValueOfString(s))
		case hasPrefixAndSuffix('[', in, ']'):
			v := m.Mutable(fds.ByNumber(6))
			return u.unmarshalMessage(v.Message(), in)
		case hasPrefixAndSuffix('{', in, '}'):
			v := m.Mutable(fds.ByNumber(5))
			return u.unmarshalMessage(v.Message(), in)
		default:
			f, err := strconv.ParseFloat(string(in), 0)
			if err != nil {
				return fmt.Errorf("unrecognized type for Value %q", in)
			}
			m.Set(fds.ByNumber(2), protoreflect.ValueOfFloat64(f))
		}
		return nil
	case "ListValue":
		var jsonArray []json.RawMessage
		if err := json.Unmarshal(in, &jsonArray); err != nil {
			return fmt.Errorf("bad ListValue: %v", err)
		}

		lv := m.Mutable(fds.ByNumber(1)).List()
		for _, raw := range jsonArray {
			ve := lv.NewElement()
			if err := u.unmarshalMessage(ve.Message(), raw); err != nil {
				return err
			}
			lv.Append(ve)
		}
		return nil
	case "Struct":
		var jsonObject map[string]json.RawMessage
		if err := json.Unmarshal(in, &jsonObject); err != nil {
			return fmt.Errorf("bad StructValue: %v", err)
		}

		mv := m.Mutable(fds.ByNumber(1)).Map()
		for key, raw := range jsonObject {
			kv := protoreflect.ValueOf(key).MapKey()
			vv := mv.NewValue()
			if err := u.unmarshalMessage(vv.Message(), raw); err != nil {
				return fmt.Errorf("bad value in StructValue for key %q: %v", key, err)
			}
			mv.Set(kv, vv)
		}
		return nil
	}

	var jsonObject map[string]json.RawMessage
	if err := json.Unmarshal(in, &jsonObject); err != nil {
		return err
	}

	// Handle known fields. Every matched member is deleted from jsonObject so
	// that whatever remains afterwards is an extension or an unknown field.
	for i := 0; i < fds.Len(); i++ {
		fd := fds.Get(i)
		if fd.IsWeak() && fd.Message().IsPlaceholder() {
			continue //  weak reference is not linked in
		}

		// Search for any raw JSON value associated with this field. When both
		// the proto name and the JSON name are present, the JSON name wins.
		var raw json.RawMessage
		name := string(fd.Name())
		if fd.Kind() == protoreflect.GroupKind {
			name = string(fd.Message().Name())
		}
		if v, ok := jsonObject[name]; ok {
			delete(jsonObject, name)
			raw = v
		}
		name = string(fd.JSONName())
		if v, ok := jsonObject[name]; ok {
			delete(jsonObject, name)
			raw = v
		}

		// Unmarshal the field value.
		if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd)) {
			continue
		}
		v, err := u.unmarshalValue(m.NewField(fd), raw, fd)
		if err != nil {
			return err
		}
		m.Set(fd, v)
	}

	// Handle extension fields.
	for name, raw := range jsonObject {
		if !strings.HasPrefix(name, "[") || !strings.HasSuffix(name, "]") {
			continue
		}

		// Resolve the extension field by name.
		xname := protoreflect.FullName(name[len("[") : len(name)-len("]")])
		xt, _ := protoregistry.GlobalTypes.FindExtensionByName(xname)
		if xt == nil && isMessageSet(md) {
			xt, _ = protoregistry.GlobalTypes.FindExtensionByName(xname.Append("message_set_extension"))
		}
		if xt == nil {
			// Left in jsonObject; reported below unless unknown fields are allowed.
			continue
		}
		delete(jsonObject, name)
		fd := xt.TypeDescriptor()
		if fd.ContainingMessage().FullName() != md.FullName() {
			return fmt.Errorf("extension field %q does not extend message %q", xname, md.FullName())
		}

		// Unmarshal the field value.
		if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd)) {
			continue
		}
		v, err := u.unmarshalValue(m.NewField(fd), raw, fd)
		if err != nil {
			return err
		}
		m.Set(fd, v)
	}

	if !u.AllowUnknownFields && len(jsonObject) > 0 {
		// Map iteration order is randomized; report the smallest name so the
		// error is the same on every run for the same input.
		var unknown string
		first := true
		for name := range jsonObject {
			if first || name < unknown {
				unknown, first = name, false
			}
		}
		return fmt.Errorf("unknown field %q in %v", unknown, md.FullName())
	}
	return nil
}
385
386func isSingularWellKnownValue(fd protoreflect.FieldDescriptor) bool {
387	if md := fd.Message(); md != nil {
388		return md.FullName() == "google.protobuf.Value" && fd.Cardinality() != protoreflect.Repeated
389	}
390	return false
391}
392
// unmarshalValue decodes the JSON value in for field fd and returns the
// resulting value.
//
// For list fields, in must be a JSON array, and its elements are appended to
// v.List(). For map fields, in must be a JSON object, and its entries are
// stored into v.Map(). Any other field is decoded as a single value. Callers
// in this file pass a fresh value from m.NewField(fd) as v.
func (u *Unmarshaler) unmarshalValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
	switch {
	case fd.IsList():
		var jsonArray []json.RawMessage
		if err := json.Unmarshal(in, &jsonArray); err != nil {
			return v, err
		}
		lv := v.List()
		for _, raw := range jsonArray {
			ve, err := u.unmarshalSingularValue(lv.NewElement(), raw, fd)
			if err != nil {
				return v, err
			}
			lv.Append(ve)
		}
		return v, nil
	case fd.IsMap():
		var jsonObject map[string]json.RawMessage
		if err := json.Unmarshal(in, &jsonObject); err != nil {
			return v, err
		}
		kfd := fd.MapKey()
		vfd := fd.MapValue()
		mv := v.Map()
		for key, raw := range jsonObject {
			var kv protoreflect.MapKey
			if kfd.Kind() == protoreflect.StringKind {
				kv = protoreflect.ValueOf(key).MapKey()
			} else {
				// JSON object keys are always strings, so non-string key kinds
				// (presumably integer or bool) are parsed from the key text.
				// A distinct name keeps the map value v from being shadowed.
				keyVal, err := u.unmarshalSingularValue(kfd.Default(), []byte(key), kfd)
				if err != nil {
					return v, err
				}
				kv = keyVal.MapKey()
			}

			vv, err := u.unmarshalSingularValue(mv.NewValue(), raw, vfd)
			if err != nil {
				return v, err
			}
			mv.Set(kv, vv)
		}
		return v, nil
	default:
		return u.unmarshalSingularValue(v, in, fd)
	}
}
440
441var nonFinite = map[string]float64{
442	`"NaN"`:       math.NaN(),
443	`"Infinity"`:  math.Inf(+1),
444	`"-Infinity"`: math.Inf(-1),
445}
446
// unmarshalSingularValue decodes one JSON value (not a whole list or map) for
// a field of descriptor fd.
//
// Scalar kinds produce a new value. Message and group kinds are decoded into
// v.Message(), and v itself is returned. Numeric kinds also accept a quoted
// number (one pair of enclosing quotes is stripped). Float and double accept
// the quoted tokens "NaN", "Infinity" and "-Infinity". Enums accept either a
// quoted value name or a bare number. An unrecognized kind panics.
func (u *Unmarshaler) unmarshalSingularValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
	// Numeric kinds tolerate a quoted form; strip the quotes once up front.
	unquoted := trimQuote(in)

	switch kind := fd.Kind(); kind {
	case protoreflect.BoolKind:
		return unmarshalValue(in, new(bool))
	case protoreflect.StringKind:
		return unmarshalValue(in, new(string))
	case protoreflect.BytesKind:
		return unmarshalValue(in, new([]byte))
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		return unmarshalValue(unquoted, new(int32))
	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		return unmarshalValue(unquoted, new(int64))
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
		return unmarshalValue(unquoted, new(uint32))
	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		return unmarshalValue(unquoted, new(uint64))
	case protoreflect.FloatKind:
		if f, ok := nonFinite[string(in)]; ok {
			return protoreflect.ValueOfFloat32(float32(f)), nil
		}
		return unmarshalValue(unquoted, new(float32))
	case protoreflect.DoubleKind:
		if f, ok := nonFinite[string(in)]; ok {
			return protoreflect.ValueOfFloat64(f), nil
		}
		return unmarshalValue(unquoted, new(float64))
	case protoreflect.EnumKind:
		if !hasPrefixAndSuffix('"', in, '"') {
			// Bare JSON number: the numeric enum value.
			return unmarshalValue(in, new(protoreflect.EnumNumber))
		}
		ed := fd.Enum()
		vd := ed.Values().ByName(protoreflect.Name(unquoted))
		if vd == nil {
			return v, fmt.Errorf("unknown value %q for enum %s", in, ed.FullName())
		}
		return protoreflect.ValueOfEnum(vd.Number()), nil
	case protoreflect.MessageKind, protoreflect.GroupKind:
		err := u.unmarshalMessage(v.Message(), in)
		return v, err
	default:
		panic(fmt.Sprintf("invalid kind %v", kind))
	}
}
489
// unmarshalValue decodes in into the pointer v with encoding/json and wraps
// the pointed-to value as a protoreflect.Value. The wrapped value is returned
// even when decoding fails, so callers must check the error first.
func unmarshalValue(in []byte, v interface{}) (protoreflect.Value, error) {
	decodeErr := json.Unmarshal(in, v)
	pointee := reflect.ValueOf(v).Elem().Interface()
	return protoreflect.ValueOf(pointee), decodeErr
}
494
495func unquoteString(in string) (out string, err error) {
496	err = json.Unmarshal([]byte(in), &out)
497	return out, err
498}
499
// hasPrefixAndSuffix reports whether in has at least two bytes, starts with
// prefix and ends with suffix. The two-byte minimum keeps a lone delimiter
// (such as a single `"`) from counting as both the prefix and the suffix.
func hasPrefixAndSuffix(prefix byte, in []byte, suffix byte) bool {
	n := len(in)
	return n >= 2 && in[0] == prefix && in[n-1] == suffix
}
506
// trimQuote removes one pair of enclosing double quotes from in, if present,
// without interpreting any escape sequences inside. Unlike unquoteString, this
// is not a real JSON string decode. It matches how the legacy implementation
// behaved and is kept for compatibility.
func trimQuote(in []byte) []byte {
	n := len(in)
	if n < 2 || in[0] != '"' || in[n-1] != '"' {
		return in
	}
	return in[1 : n-1]
}
515