1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonpb
6
7import (
8	"encoding/json"
9	"errors"
10	"fmt"
11	"io"
12	"math"
13	"reflect"
14	"sort"
15	"strconv"
16	"strings"
17	"time"
18
19	"github.com/golang/protobuf/proto"
20	"google.golang.org/protobuf/encoding/protojson"
21	protoV2 "google.golang.org/protobuf/proto"
22	"google.golang.org/protobuf/reflect/protoreflect"
23	"google.golang.org/protobuf/reflect/protoregistry"
24)
25
// wrapJSONMarshalV2 selects the marshaling backend at compile time. When
// true, Marshaler delegates to protojson.MarshalOptions; when false (the
// current setting) the hand-written jsonWriter in this file is used.
const (
	wrapJSONMarshalV2 = false
)
27
// Marshaler is a configurable object for marshaling protocol buffer messages
// to the specified JSON representation.
//
// With every option at its zero value, output is compact, fields are keyed by
// their JSON name, enums are written by value name, unpopulated fields are
// omitted, and google.protobuf.Any payloads are resolved through the global
// type registry.
type Marshaler struct {
	// OrigName specifies whether to use the original protobuf name for fields.
	// For group fields the name of the group's message is used instead.
	// Extensions are always keyed as "[full.extension.name]" regardless of
	// this setting.
	OrigName bool

	// EnumsAsInts specifies whether to render enum values as integers,
	// as opposed to string values. Enum numbers with no declared name are
	// always rendered as integers.
	EnumsAsInts bool

	// EmitDefaults specifies whether to render fields with zero values.
	// When set, unpopulated singular message fields and unpopulated proto2
	// scalar fields are written as null, while unset oneof members are still
	// omitted.
	EmitDefaults bool

	// Indent controls whether the output is compact or not.
	// If empty, the output is compact JSON. Otherwise, every JSON object
	// entry and JSON array value will be on its own line.
	// Each line will be preceded by repeated copies of Indent, where the
	// number of copies is the current indentation depth.
	Indent string

	// AnyResolver is used to resolve the google.protobuf.Any well-known type.
	// If unset, the global registry is used by default.
	AnyResolver AnyResolver
}
52
// JSONPBMarshaler is implemented by protobuf messages that customize the
// way they are marshaled to JSON. Messages that implement this should also
// implement JSONPBUnmarshaler so that the custom format can be parsed.
//
// The interface is checked before any protobuf reflection is used, both for
// the top-level message passed to Marshal and for nested messages. The
// method receives the Marshaler whose options are in effect. When the
// message is the payload of a google.protobuf.Any, its output must decode
// as a JSON object so that an "@type" entry can be added to it.
//
// The JSON marshaling must follow the proto to JSON specification:
//	https://developers.google.com/protocol-buffers/docs/proto3#json
//
// Deprecated: Custom types should implement protobuf reflection instead.
type JSONPBMarshaler interface {
	MarshalJSONPB(*Marshaler) ([]byte, error)
}
64
// Marshal serializes a protobuf message as JSON into w.
//
// Whatever bytes were produced are written to w even when marshaling also
// reports an error. If writing to w fails, that write error is returned in
// preference to any marshaling error.
func (jm *Marshaler) Marshal(w io.Writer, m proto.Message) error {
	out, marshalErr := jm.marshal(m)
	if len(out) == 0 {
		return marshalErr
	}
	if _, writeErr := w.Write(out); writeErr != nil {
		return writeErr
	}
	return marshalErr
}
75
76// MarshalToString serializes a protobuf message as JSON in string form.
77func (jm *Marshaler) MarshalToString(m proto.Message) (string, error) {
78	b, err := jm.marshal(m)
79	if err != nil {
80		return "", err
81	}
82	return string(b), nil
83}
84
// marshal renders m as JSON according to jm's options.
//
// A nil message, or a nil pointer stored in the interface, is rejected.
// Messages implementing JSONPBMarshaler are delegated to directly.
// Otherwise the message must have all required fields populated. If the
// writer fails partway through, the bytes produced so far are returned
// together with the error.
func (jm *Marshaler) marshal(m proto.Message) ([]byte, error) {
	if m == nil {
		return nil, errors.New("Marshal called with nil")
	}
	if rv := reflect.ValueOf(m); rv.Kind() == reflect.Ptr && rv.IsNil() {
		return nil, errors.New("Marshal called with nil")
	}

	// Custom marshalers take priority because they may not implement the
	// protobuf reflection that the rest of this function depends on.
	if custom, ok := m.(JSONPBMarshaler); ok {
		return custom.MarshalJSONPB(jm)
	}

	if wrapJSONMarshalV2 {
		opts := protojson.MarshalOptions{
			UseProtoNames:   jm.OrigName,
			UseEnumNumbers:  jm.EnumsAsInts,
			EmitUnpopulated: jm.EmitDefaults,
			Indent:          jm.Indent,
		}
		if jm.AnyResolver != nil {
			opts.Resolver = anyResolver{jm.AnyResolver}
		}
		return opts.Marshal(proto.MessageReflect(m).Interface())
	}

	// Reject messages with unpopulated required fields before writing anything.
	msg := proto.MessageReflect(m)
	if err := protoV2.CheckInitialized(msg.Interface()); err != nil {
		return nil, err
	}

	writer := jsonWriter{Marshaler: jm}
	err := writer.marshalMessage(msg, "", "")
	return writer.buf, err
}
120
// jsonWriter accumulates the JSON encoding of a message in buf, reading its
// formatting options from the embedded Marshaler.
type jsonWriter struct {
	*Marshaler
	buf []byte // output produced so far; marshal returns it even on error
}
125
126func (w *jsonWriter) write(s string) {
127	w.buf = append(w.buf, s...)
128}
129
// marshalMessage writes m as JSON to w.buf.
//
// indent is the indentation already in effect for the line holding m's
// opening brace. When w.Indent is non-empty, entries are written at
// indent+w.Indent and the closing brace at indent. A non-empty typeURL means
// m is the payload of a google.protobuf.Any, so an "@type" entry carrying
// typeURL is emitted as the first member of the object.
//
// Messages implementing JSONPBMarshaler are delegated to. The well-known
// types Any, the wrapper types, Duration, Timestamp, Value, Struct and
// ListValue get their special JSON mappings. Every other message becomes a
// JSON object of its populated fields (or all non-oneof fields when
// EmitDefaults is set), followed by its extensions in field-number order.
func (w *jsonWriter) marshalMessage(m protoreflect.Message, indent, typeURL string) error {
	if jsm, ok := proto.MessageV1(m.Interface()).(JSONPBMarshaler); ok {
		b, err := jsm.MarshalJSONPB(w.Marshaler)
		if err != nil {
			return err
		}
		if typeURL != "" {
			// We are marshaling this object as the payload of an Any, so
			// splice an "@type" entry into the object the custom marshaler
			// produced.
			var js map[string]*json.RawMessage
			if err = json.Unmarshal(b, &js); err != nil {
				return fmt.Errorf("type %T produced invalid JSON: %v", m.Interface(), err)
			}
			if js == nil {
				// json.Unmarshal turns a literal null into a nil map without
				// error. Assigning into that map below would panic, so report
				// it as invalid output instead.
				return fmt.Errorf("type %T produced invalid JSON: expected an object, got null", m.Interface())
			}
			turl, err := json.Marshal(typeURL)
			if err != nil {
				return fmt.Errorf("failed to marshal type URL %q to JSON: %v", typeURL, err)
			}
			js["@type"] = (*json.RawMessage)(&turl)
			// Re-encode honoring Indent so the object uses the same layout as
			// every other message: one entry per line at indent+w.Indent, and
			// the closing brace at indent.
			if w.Indent != "" {
				b, err = json.MarshalIndent(js, indent, w.Indent)
			} else {
				b, err = json.Marshal(js)
			}
			if err != nil {
				return err
			}
		}
		w.write(string(b))
		return nil
	}

	md := m.Descriptor()
	fds := md.Fields()

	// Handle well-known types.
	const secondInNanos = int64(time.Second / time.Nanosecond)
	switch wellKnownType(md.FullName()) {
	case "Any":
		return w.marshalAny(m, indent)
	case "BoolValue", "BytesValue", "StringValue",
		"Int32Value", "UInt32Value", "FloatValue",
		"Int64Value", "UInt64Value", "DoubleValue":
		// Wrapper types are written as the bare value of their field 1.
		fd := fds.ByNumber(1)
		return w.marshalValue(fd, m.Get(fd), indent)
	case "Duration":
		const maxSecondsInDuration = 315576000000
		// "Generated output always contains 0, 3, 6, or 9 fractional digits,
		//  depending on required precision."
		s := m.Get(fds.ByNumber(1)).Int()
		ns := m.Get(fds.ByNumber(2)).Int()
		if s < -maxSecondsInDuration || s > maxSecondsInDuration {
			return fmt.Errorf("seconds out of range %v", s)
		}
		if ns <= -secondInNanos || ns >= secondInNanos {
			return fmt.Errorf("ns out of range (%v, %v)", -secondInNanos, secondInNanos)
		}
		if (s > 0 && ns < 0) || (s < 0 && ns > 0) {
			return errors.New("signs of seconds and nanos do not match")
		}
		// Format the magnitude and prefix the sign once; this also covers a
		// negative sub-second duration where s == 0 and ns < 0.
		var sign string
		if s < 0 || ns < 0 {
			sign, s, ns = "-", -1*s, -1*ns
		}
		x := fmt.Sprintf("%s%d.%09d", sign, s, ns)
		// Drop trailing zero groups to leave 9, 6, 3 or 0 fractional digits.
		x = strings.TrimSuffix(x, "000")
		x = strings.TrimSuffix(x, "000")
		x = strings.TrimSuffix(x, ".000")
		w.write(fmt.Sprintf(`"%vs"`, x))
		return nil
	case "Timestamp":
		// "RFC 3339, where generated output will always be Z-normalized
		//  and uses 0, 3, 6 or 9 fractional digits."
		s := m.Get(fds.ByNumber(1)).Int()
		ns := m.Get(fds.ByNumber(2)).Int()
		if ns < 0 || ns >= secondInNanos {
			return fmt.Errorf("ns out of range [0, %v)", secondInNanos)
		}
		t := time.Unix(s, ns).UTC()
		// time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits).
		x := t.Format("2006-01-02T15:04:05.000000000")
		x = strings.TrimSuffix(x, "000")
		x = strings.TrimSuffix(x, "000")
		x = strings.TrimSuffix(x, ".000")
		w.write(fmt.Sprintf(`"%vZ"`, x))
		return nil
	case "Value":
		// JSON value; which is a null, number, string, bool, object, or array.
		od := md.Oneofs().Get(0)
		fd := m.WhichOneof(od)
		if fd == nil {
			return errors.New("nil Value")
		}
		return w.marshalValue(fd, m.Get(fd), indent)
	case "Struct", "ListValue":
		// JSON object or array.
		fd := fds.ByNumber(1)
		return w.marshalValue(fd, m.Get(fd), indent)
	}

	w.write("{")
	if w.Indent != "" {
		w.write("\n")
	}

	firstField := true
	if typeURL != "" {
		if err := w.marshalTypeURL(indent, typeURL); err != nil {
			return err
		}
		firstField = false
	}

	for i := 0; i < fds.Len(); {
		fd := fds.Get(i)
		if od := fd.ContainingOneof(); od != nil {
			// Emit at most the one populated member of a oneof, then skip
			// past the rest of its fields.
			fd = m.WhichOneof(od)
			i += od.Fields().Len()
			if fd == nil {
				continue
			}
		} else {
			i++
		}

		v := m.Get(fd)

		if !m.Has(fd) {
			if !w.EmitDefaults || fd.ContainingOneof() != nil {
				continue
			}
			if fd.Cardinality() != protoreflect.Repeated && (fd.Message() != nil || fd.Syntax() == protoreflect.Proto2) {
				v = protoreflect.Value{} // use "null" for singular messages or proto2 scalars
			}
		}

		if !firstField {
			w.writeComma()
		}
		if err := w.marshalField(fd, v, indent); err != nil {
			return err
		}
		firstField = false
	}

	// Handle proto2 extensions.
	if md.ExtensionRanges().Len() > 0 {
		// Collect a sorted list of all extension descriptor and values.
		type ext struct {
			desc protoreflect.FieldDescriptor
			val  protoreflect.Value
		}
		var exts []ext
		m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
			if fd.IsExtension() {
				exts = append(exts, ext{fd, v})
			}
			return true
		})
		sort.Slice(exts, func(i, j int) bool {
			return exts[i].desc.Number() < exts[j].desc.Number()
		})

		for _, ext := range exts {
			if !firstField {
				w.writeComma()
			}
			if err := w.marshalField(ext.desc, ext.val, indent); err != nil {
				return err
			}
			firstField = false
		}
	}

	if w.Indent != "" {
		w.write("\n")
		w.write(indent)
	}
	w.write("}")
	return nil
}
304
305func (w *jsonWriter) writeComma() {
306	if w.Indent != "" {
307		w.write(",\n")
308	} else {
309		w.write(",")
310	}
311}
312
313func (w *jsonWriter) marshalAny(m protoreflect.Message, indent string) error {
314	// "If the Any contains a value that has a special JSON mapping,
315	//  it will be converted as follows: {"@type": xxx, "value": yyy}.
316	//  Otherwise, the value will be converted into a JSON object,
317	//  and the "@type" field will be inserted to indicate the actual data type."
318	md := m.Descriptor()
319	typeURL := m.Get(md.Fields().ByNumber(1)).String()
320	rawVal := m.Get(md.Fields().ByNumber(2)).Bytes()
321
322	var m2 protoreflect.Message
323	if w.AnyResolver != nil {
324		mi, err := w.AnyResolver.Resolve(typeURL)
325		if err != nil {
326			return err
327		}
328		m2 = proto.MessageReflect(mi)
329	} else {
330		mt, err := protoregistry.GlobalTypes.FindMessageByURL(typeURL)
331		if err != nil {
332			return err
333		}
334		m2 = mt.New()
335	}
336
337	if err := protoV2.Unmarshal(rawVal, m2.Interface()); err != nil {
338		return err
339	}
340
341	if wellKnownType(m2.Descriptor().FullName()) == "" {
342		return w.marshalMessage(m2, indent, typeURL)
343	}
344
345	w.write("{")
346	if w.Indent != "" {
347		w.write("\n")
348	}
349	if err := w.marshalTypeURL(indent, typeURL); err != nil {
350		return err
351	}
352	w.writeComma()
353	if w.Indent != "" {
354		w.write(indent)
355		w.write(w.Indent)
356		w.write(`"value": `)
357	} else {
358		w.write(`"value":`)
359	}
360	if err := w.marshalMessage(m2, indent+w.Indent, ""); err != nil {
361		return err
362	}
363	if w.Indent != "" {
364		w.write("\n")
365		w.write(indent)
366	}
367	w.write("}")
368	return nil
369}
370
// marshalTypeURL writes the "@type" entry of an Any, with typeURL
// JSON-quoted. In indented mode the entry starts at indent+w.Indent and a
// space follows the colon.
func (w *jsonWriter) marshalTypeURL(indent, typeURL string) error {
	pretty := w.Indent != ""
	if pretty {
		w.write(indent + w.Indent)
	}
	w.write(`"@type":`)
	if pretty {
		w.write(" ")
	}
	quoted, err := json.Marshal(typeURL)
	if err != nil {
		return err
	}
	w.write(string(quoted))
	return nil
}
387
388// marshalField writes field description and value to the Writer.
389func (w *jsonWriter) marshalField(fd protoreflect.FieldDescriptor, v protoreflect.Value, indent string) error {
390	if w.Indent != "" {
391		w.write(indent)
392		w.write(w.Indent)
393	}
394	w.write(`"`)
395	switch {
396	case fd.IsExtension():
397		// For message set, use the fname of the message as the extension name.
398		name := string(fd.FullName())
399		if isMessageSet(fd.ContainingMessage()) {
400			name = strings.TrimSuffix(name, ".message_set_extension")
401		}
402
403		w.write("[" + name + "]")
404	case w.OrigName:
405		name := string(fd.Name())
406		if fd.Kind() == protoreflect.GroupKind {
407			name = string(fd.Message().Name())
408		}
409		w.write(name)
410	default:
411		w.write(string(fd.JSONName()))
412	}
413	w.write(`":`)
414	if w.Indent != "" {
415		w.write(" ")
416	}
417	return w.marshalValue(fd, v, indent)
418}
419
// marshalValue writes v, the value of field fd, as JSON.
//
// A repeated field becomes a JSON array. A map field becomes a JSON object
// whose entries are sorted by key: false before true, numerically for
// integer keys, and by string comparison for string keys. Each key is
// written as a JSON string and each element goes through
// marshalSingularValue. Any other field is handed to marshalSingularValue
// directly. In indented mode each element starts on its own line at
// indent+2*w.Indent, and the closing bracket sits at indent+w.Indent.
func (w *jsonWriter) marshalValue(fd protoreflect.FieldDescriptor, v protoreflect.Value, indent string) error {
	pretty := w.Indent != ""

	if fd.IsList() {
		w.write("[")
		list := v.List()
		for i := 0; i < list.Len(); i++ {
			if i > 0 {
				w.write(",")
			}
			if pretty {
				w.write("\n" + indent + w.Indent + w.Indent)
			}
			if err := w.marshalSingularValue(fd, list.Get(i), indent+w.Indent); err != nil {
				return err
			}
		}
		if pretty {
			w.write("\n" + indent + w.Indent)
		}
		w.write("]")
		return nil
	}

	if !fd.IsMap() {
		return w.marshalSingularValue(fd, v, indent)
	}

	keyKind := fd.MapKey().Kind()
	valueField := fd.MapValue()
	mv := v.Map()

	// Snapshot the entries so they can be emitted in a deterministic order.
	type pair struct{ key, val protoreflect.Value }
	pairs := make([]pair, 0, mv.Len())
	mv.Range(func(k protoreflect.MapKey, val protoreflect.Value) bool {
		pairs = append(pairs, pair{k.Value(), val})
		return true
	})

	var less func(a, b protoreflect.Value) bool
	switch keyKind {
	case protoreflect.BoolKind:
		less = func(a, b protoreflect.Value) bool { return !a.Bool() && b.Bool() }
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		less = func(a, b protoreflect.Value) bool { return a.Int() < b.Int() }
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		less = func(a, b protoreflect.Value) bool { return a.Uint() < b.Uint() }
	case protoreflect.StringKind:
		less = func(a, b protoreflect.Value) bool { return a.String() < b.String() }
	default:
		// Only panics if two keys actually have to be compared.
		less = func(a, b protoreflect.Value) bool { panic("invalid kind") }
	}
	sort.Slice(pairs, func(i, j int) bool { return less(pairs[i].key, pairs[j].key) })

	w.write("{")
	for i, p := range pairs {
		if i > 0 {
			w.write(",")
		}
		if pretty {
			w.write("\n" + indent + w.Indent + w.Indent)
		}

		// JSON object keys are always strings, whatever the proto key kind.
		quotedKey, err := json.Marshal(fmt.Sprint(p.key.Interface()))
		if err != nil {
			return err
		}
		w.write(string(quotedKey))
		if pretty {
			w.write(": ")
		} else {
			w.write(":")
		}

		if err := w.marshalSingularValue(valueField, p.val, indent+w.Indent); err != nil {
			return err
		}
	}
	if pretty {
		w.write("\n" + indent + w.Indent)
	}
	w.write("}")
	return nil
}
512
// marshalSingularValue writes a single (non-list, non-map) value v of field fd.
//
// An invalid zero Value is written as null, and messages recurse into
// marshalMessage one indentation level deeper. Enums are handled as follows:
// google.protobuf.NullValue is written as null, a value is written as a
// number when EnumsAsInts is set or its number has no declared name, and
// otherwise as its quoted name. 64-bit integers become quoted decimal
// strings. Non-finite floats become "Infinity", "-Infinity" or "NaN".
// Everything else is encoded with encoding/json.
func (w *jsonWriter) marshalSingularValue(fd protoreflect.FieldDescriptor, v protoreflect.Value, indent string) error {
	if !v.IsValid() {
		w.write("null")
		return nil
	}
	if fd.Message() != nil {
		return w.marshalMessage(v.Message(), indent+w.Indent, "")
	}
	if ed := fd.Enum(); ed != nil {
		if ed.FullName() == "google.protobuf.NullValue" {
			w.write("null")
			return nil
		}
		num := v.Enum()
		if vd := ed.Values().ByNumber(num); vd != nil && !w.EnumsAsInts {
			w.write(`"` + string(vd.Name()) + `"`)
		} else {
			w.write(strconv.Itoa(int(num)))
		}
		return nil
	}

	raw := v.Interface()
	switch x := raw.(type) {
	case int64:
		w.write(`"` + strconv.FormatInt(x, 10) + `"`)
		return nil
	case uint64:
		w.write(`"` + strconv.FormatUint(x, 10) + `"`)
		return nil
	case float32, float64:
		// encoding/json rejects non-finite floats, so spell them out.
		f := v.Float()
		switch {
		case math.IsNaN(f):
			w.write(`"NaN"`)
			return nil
		case math.IsInf(f, +1):
			w.write(`"Infinity"`)
			return nil
		case math.IsInf(f, -1):
			w.write(`"-Infinity"`)
			return nil
		}
	}

	encoded, err := json.Marshal(raw)
	if err != nil {
		return err
	}
	w.write(string(encoded))
	return nil
}
560