1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonpb
6
7import (
8	"encoding/json"
9	"errors"
10	"fmt"
11	"io"
12	"math"
13	"reflect"
14	"sort"
15	"strconv"
16	"strings"
17	"time"
18
19	"github.com/golang/protobuf/proto"
20	"google.golang.org/protobuf/encoding/protojson"
21	protoV2 "google.golang.org/protobuf/proto"
22	"google.golang.org/protobuf/reflect/protoreflect"
23	"google.golang.org/protobuf/reflect/protoregistry"
24)
25
// wrapJSONMarshalV2 selects the marshaling backend used by Marshaler. When
// true, encoding is delegated to protojson.MarshalOptions; when false (the
// current setting), the jsonWriter-based encoder in this file is used.
const (
	wrapJSONMarshalV2 = false
)
27
// Marshaler is a configurable object for marshaling protocol buffer messages
// to the specified JSON representation. The zero value produces compact JSON,
// uses each field's JSON name, renders enums by value name, and omits
// unpopulated fields.
type Marshaler struct {
	// OrigName specifies whether to use the original protobuf name for fields
	// (for group fields, the group's message name) instead of their JSON
	// names. Extension fields are always written as "[full.name]".
	OrigName bool

	// EnumsAsInts specifies whether to render enum values as integers,
	// as opposed to string values. Enum numbers that have no declared name
	// are rendered as integers regardless of this setting.
	EnumsAsInts bool

	// EmitDefaults specifies whether to render fields with zero values.
	// With the jsonWriter encoder, unpopulated singular message fields and
	// proto2 scalar fields are then rendered as null, while unset oneof
	// members are still omitted.
	EmitDefaults bool

	// Indent controls whether the output is compact or not.
	// If empty, the output is compact JSON. If non-empty, every JSON object
	// entry and JSON array value will be on its own line.
	// Each line will be preceded by repeated copies of Indent, where the
	// number of copies is the current indentation depth.
	Indent string

	// AnyResolver is used to resolve the google.protobuf.Any well-known type.
	// If unset, the global registry (protoregistry.GlobalTypes) is used by
	// default.
	AnyResolver AnyResolver
}
52
// JSONPBMarshaler is implemented by protobuf messages that customize the
// way they are marshaled to JSON. Messages that implement this should also
// implement JSONPBUnmarshaler so that the custom format can be parsed.
//
// The method is consulted both for the top-level message passed to a
// Marshaler and for every nested message encountered while marshaling. When
// the message is the payload of a google.protobuf.Any, its output must be a
// JSON object so that an "@type" member can be added to it.
//
// The JSON marshaling must follow the proto to JSON specification:
//	https://developers.google.com/protocol-buffers/docs/proto3#json
//
// Deprecated: Custom types should implement protobuf reflection instead.
type JSONPBMarshaler interface {
	MarshalJSONPB(*Marshaler) ([]byte, error)
}
64
65// Marshal marshals a protocol buffer into JSON.
66func (jm *Marshaler) Marshal(w io.Writer, m proto.Message) error {
67	b, err := jm.marshal(m)
68	if len(b) > 0 {
69		if _, err := w.Write(b); err != nil {
70			return err
71		}
72	}
73	return err
74}
75
76// MarshalToString converts a protocol buffer object to JSON string.
77func (jm *Marshaler) MarshalToString(m proto.Message) (string, error) {
78	b, err := jm.marshal(m)
79	if err != nil {
80		return "", err
81	}
82	return string(b), nil
83}
84
85func (jm *Marshaler) marshal(m proto.Message) ([]byte, error) {
86	v := reflect.ValueOf(m)
87	if m == nil || (v.Kind() == reflect.Ptr && v.IsNil()) {
88		return nil, errors.New("Marshal called with nil")
89	}
90
91	// Check for custom marshalers first since they may not properly
92	// implement protobuf reflection that the logic below relies on.
93	if jsm, ok := m.(JSONPBMarshaler); ok {
94		return jsm.MarshalJSONPB(jm)
95	}
96
97	if wrapJSONMarshalV2 {
98		opts := protojson.MarshalOptions{
99			UseProtoNames:   jm.OrigName,
100			UseEnumNumbers:  jm.EnumsAsInts,
101			EmitUnpopulated: jm.EmitDefaults,
102			Indent:          jm.Indent,
103		}
104		if jm.AnyResolver != nil {
105			opts.Resolver = anyResolver{jm.AnyResolver}
106		}
107		return opts.Marshal(proto.MessageReflect(m).Interface())
108	} else {
109		// Check for unpopulated required fields first.
110		m2 := proto.MessageReflect(m)
111		if err := protoV2.CheckInitialized(m2.Interface()); err != nil {
112			return nil, err
113		}
114
115		w := jsonWriter{Marshaler: jm}
116		err := w.marshalMessage(m2, "", "")
117		return w.buf, err
118	}
119}
120
// jsonWriter accumulates the JSON encoding of a message in buf. The embedded
// *Marshaler exposes the encoding options (Indent, OrigName, EnumsAsInts,
// EmitDefaults, AnyResolver) directly to the writer's methods.
type jsonWriter struct {
	*Marshaler
	// buf holds the JSON bytes written so far.
	buf []byte
}
125
126func (w *jsonWriter) write(s string) {
127	w.buf = append(w.buf, s...)
128}
129
130func (w *jsonWriter) marshalMessage(m protoreflect.Message, indent, typeURL string) error {
131	if jsm, ok := proto.MessageV1(m.Interface()).(JSONPBMarshaler); ok {
132		b, err := jsm.MarshalJSONPB(w.Marshaler)
133		if err != nil {
134			return err
135		}
136		if typeURL != "" {
137			// we are marshaling this object to an Any type
138			var js map[string]*json.RawMessage
139			if err = json.Unmarshal(b, &js); err != nil {
140				return fmt.Errorf("type %T produced invalid JSON: %v", m.Interface(), err)
141			}
142			turl, err := json.Marshal(typeURL)
143			if err != nil {
144				return fmt.Errorf("failed to marshal type URL %q to JSON: %v", typeURL, err)
145			}
146			js["@type"] = (*json.RawMessage)(&turl)
147			if b, err = json.Marshal(js); err != nil {
148				return err
149			}
150		}
151		w.write(string(b))
152		return nil
153	}
154
155	md := m.Descriptor()
156	fds := md.Fields()
157
158	// Handle well-known types.
159	const secondInNanos = int64(time.Second / time.Nanosecond)
160	switch wellKnownType(md.FullName()) {
161	case "Any":
162		return w.marshalAny(m, indent)
163	case "BoolValue", "BytesValue", "StringValue",
164		"Int32Value", "UInt32Value", "FloatValue",
165		"Int64Value", "UInt64Value", "DoubleValue":
166		fd := fds.ByNumber(1)
167		return w.marshalValue(fd, m.Get(fd), indent)
168	case "Duration":
169		// "Generated output always contains 0, 3, 6, or 9 fractional digits,
170		//  depending on required precision."
171		s := m.Get(fds.ByNumber(1)).Int()
172		ns := m.Get(fds.ByNumber(2)).Int()
173		if ns <= -secondInNanos || ns >= secondInNanos {
174			return fmt.Errorf("ns out of range (%v, %v)", -secondInNanos, secondInNanos)
175		}
176		if (s > 0 && ns < 0) || (s < 0 && ns > 0) {
177			return errors.New("signs of seconds and nanos do not match")
178		}
179		if s < 0 {
180			ns = -ns
181		}
182		x := fmt.Sprintf("%d.%09d", s, ns)
183		x = strings.TrimSuffix(x, "000")
184		x = strings.TrimSuffix(x, "000")
185		x = strings.TrimSuffix(x, ".000")
186		w.write(fmt.Sprintf(`"%vs"`, x))
187		return nil
188	case "Timestamp":
189		// "RFC 3339, where generated output will always be Z-normalized
190		//  and uses 0, 3, 6 or 9 fractional digits."
191		s := m.Get(fds.ByNumber(1)).Int()
192		ns := m.Get(fds.ByNumber(2)).Int()
193		if ns < 0 || ns >= secondInNanos {
194			return fmt.Errorf("ns out of range [0, %v)", secondInNanos)
195		}
196		t := time.Unix(s, ns).UTC()
197		// time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits).
198		x := t.Format("2006-01-02T15:04:05.000000000")
199		x = strings.TrimSuffix(x, "000")
200		x = strings.TrimSuffix(x, "000")
201		x = strings.TrimSuffix(x, ".000")
202		w.write(fmt.Sprintf(`"%vZ"`, x))
203		return nil
204	case "Value":
205		// JSON value; which is a null, number, string, bool, object, or array.
206		od := md.Oneofs().Get(0)
207		fd := m.WhichOneof(od)
208		if fd == nil {
209			return errors.New("nil Value")
210		}
211		return w.marshalValue(fd, m.Get(fd), indent)
212	case "Struct", "ListValue":
213		// JSON object or array.
214		fd := fds.ByNumber(1)
215		return w.marshalValue(fd, m.Get(fd), indent)
216	}
217
218	w.write("{")
219	if w.Indent != "" {
220		w.write("\n")
221	}
222
223	firstField := true
224	if typeURL != "" {
225		if err := w.marshalTypeURL(indent, typeURL); err != nil {
226			return err
227		}
228		firstField = false
229	}
230
231	for i := 0; i < fds.Len(); {
232		fd := fds.Get(i)
233		if od := fd.ContainingOneof(); od != nil {
234			fd = m.WhichOneof(od)
235			i += od.Fields().Len()
236			if fd == nil {
237				continue
238			}
239		} else {
240			i++
241		}
242
243		v := m.Get(fd)
244
245		if !m.Has(fd) {
246			if !w.EmitDefaults || fd.ContainingOneof() != nil {
247				continue
248			}
249			if fd.Cardinality() != protoreflect.Repeated && (fd.Message() != nil || fd.Syntax() == protoreflect.Proto2) {
250				v = protoreflect.Value{} // use "null" for singular messages or proto2 scalars
251			}
252		}
253
254		if !firstField {
255			w.writeComma()
256		}
257		if err := w.marshalField(fd, v, indent); err != nil {
258			return err
259		}
260		firstField = false
261	}
262
263	// Handle proto2 extensions.
264	if md.ExtensionRanges().Len() > 0 {
265		// Collect a sorted list of all extension descriptor and values.
266		type ext struct {
267			desc protoreflect.FieldDescriptor
268			val  protoreflect.Value
269		}
270		var exts []ext
271		m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
272			if fd.IsExtension() {
273				exts = append(exts, ext{fd, v})
274			}
275			return true
276		})
277		sort.Slice(exts, func(i, j int) bool {
278			return exts[i].desc.Number() < exts[j].desc.Number()
279		})
280
281		for _, ext := range exts {
282			if !firstField {
283				w.writeComma()
284			}
285			if err := w.marshalField(ext.desc, ext.val, indent); err != nil {
286				return err
287			}
288			firstField = false
289		}
290	}
291
292	if w.Indent != "" {
293		w.write("\n")
294		w.write(indent)
295	}
296	w.write("}")
297	return nil
298}
299
300func (w *jsonWriter) writeComma() {
301	if w.Indent != "" {
302		w.write(",\n")
303	} else {
304		w.write(",")
305	}
306}
307
308func (w *jsonWriter) marshalAny(m protoreflect.Message, indent string) error {
309	// "If the Any contains a value that has a special JSON mapping,
310	//  it will be converted as follows: {"@type": xxx, "value": yyy}.
311	//  Otherwise, the value will be converted into a JSON object,
312	//  and the "@type" field will be inserted to indicate the actual data type."
313	md := m.Descriptor()
314	typeURL := m.Get(md.Fields().ByNumber(1)).String()
315	rawVal := m.Get(md.Fields().ByNumber(2)).Bytes()
316
317	var m2 protoreflect.Message
318	if w.AnyResolver != nil {
319		mi, err := w.AnyResolver.Resolve(typeURL)
320		if err != nil {
321			return err
322		}
323		m2 = proto.MessageReflect(mi)
324	} else {
325		mt, err := protoregistry.GlobalTypes.FindMessageByURL(typeURL)
326		if err != nil {
327			return err
328		}
329		m2 = mt.New()
330	}
331
332	if err := protoV2.Unmarshal(rawVal, m2.Interface()); err != nil {
333		return err
334	}
335
336	if wellKnownType(m2.Descriptor().FullName()) == "" {
337		return w.marshalMessage(m2, indent, typeURL)
338	}
339
340	w.write("{")
341	if w.Indent != "" {
342		w.write("\n")
343	}
344	if err := w.marshalTypeURL(indent, typeURL); err != nil {
345		return err
346	}
347	w.writeComma()
348	if w.Indent != "" {
349		w.write(indent)
350		w.write(w.Indent)
351		w.write(`"value": `)
352	} else {
353		w.write(`"value":`)
354	}
355	if err := w.marshalMessage(m2, indent+w.Indent, ""); err != nil {
356		return err
357	}
358	if w.Indent != "" {
359		w.write("\n")
360		w.write(indent)
361	}
362	w.write("}")
363	return nil
364}
365
366func (w *jsonWriter) marshalTypeURL(indent, typeURL string) error {
367	if w.Indent != "" {
368		w.write(indent)
369		w.write(w.Indent)
370	}
371	w.write(`"@type":`)
372	if w.Indent != "" {
373		w.write(" ")
374	}
375	b, err := json.Marshal(typeURL)
376	if err != nil {
377		return err
378	}
379	w.write(string(b))
380	return nil
381}
382
383// marshalField writes field description and value to the Writer.
384func (w *jsonWriter) marshalField(fd protoreflect.FieldDescriptor, v protoreflect.Value, indent string) error {
385	if w.Indent != "" {
386		w.write(indent)
387		w.write(w.Indent)
388	}
389	w.write(`"`)
390	switch {
391	case fd.IsExtension():
392		// For message set, use the fname of the message as the extension name.
393		name := string(fd.FullName())
394		if isMessageSet(fd.ContainingMessage()) {
395			name = strings.TrimSuffix(name, ".message_set_extension")
396		}
397
398		w.write("[" + name + "]")
399	case w.OrigName:
400		name := string(fd.Name())
401		if fd.Kind() == protoreflect.GroupKind {
402			name = string(fd.Message().Name())
403		}
404		w.write(name)
405	default:
406		w.write(string(fd.JSONName()))
407	}
408	w.write(`":`)
409	if w.Indent != "" {
410		w.write(" ")
411	}
412	return w.marshalValue(fd, v, indent)
413}
414
415func (w *jsonWriter) marshalValue(fd protoreflect.FieldDescriptor, v protoreflect.Value, indent string) error {
416	switch {
417	case fd.IsList():
418		w.write("[")
419		comma := ""
420		lv := v.List()
421		for i := 0; i < lv.Len(); i++ {
422			w.write(comma)
423			if w.Indent != "" {
424				w.write("\n")
425				w.write(indent)
426				w.write(w.Indent)
427				w.write(w.Indent)
428			}
429			if err := w.marshalSingularValue(fd, lv.Get(i), indent+w.Indent); err != nil {
430				return err
431			}
432			comma = ","
433		}
434		if w.Indent != "" {
435			w.write("\n")
436			w.write(indent)
437			w.write(w.Indent)
438		}
439		w.write("]")
440		return nil
441	case fd.IsMap():
442		kfd := fd.MapKey()
443		vfd := fd.MapValue()
444		mv := v.Map()
445
446		// Collect a sorted list of all map keys and values.
447		type entry struct{ key, val protoreflect.Value }
448		var entries []entry
449		mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
450			entries = append(entries, entry{k.Value(), v})
451			return true
452		})
453		sort.Slice(entries, func(i, j int) bool {
454			switch kfd.Kind() {
455			case protoreflect.BoolKind:
456				return !entries[i].key.Bool() && entries[j].key.Bool()
457			case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
458				return entries[i].key.Int() < entries[j].key.Int()
459			case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
460				return entries[i].key.Uint() < entries[j].key.Uint()
461			case protoreflect.StringKind:
462				return entries[i].key.String() < entries[j].key.String()
463			default:
464				panic("invalid kind")
465			}
466		})
467
468		w.write(`{`)
469		comma := ""
470		for _, entry := range entries {
471			w.write(comma)
472			if w.Indent != "" {
473				w.write("\n")
474				w.write(indent)
475				w.write(w.Indent)
476				w.write(w.Indent)
477			}
478
479			s := fmt.Sprint(entry.key.Interface())
480			b, err := json.Marshal(s)
481			if err != nil {
482				return err
483			}
484			w.write(string(b))
485
486			w.write(`:`)
487			if w.Indent != "" {
488				w.write(` `)
489			}
490
491			if err := w.marshalSingularValue(vfd, entry.val, indent+w.Indent); err != nil {
492				return err
493			}
494			comma = ","
495		}
496		if w.Indent != "" {
497			w.write("\n")
498			w.write(indent)
499			w.write(w.Indent)
500		}
501		w.write(`}`)
502		return nil
503	default:
504		return w.marshalSingularValue(fd, v, indent)
505	}
506}
507
508func (w *jsonWriter) marshalSingularValue(fd protoreflect.FieldDescriptor, v protoreflect.Value, indent string) error {
509	switch {
510	case !v.IsValid():
511		w.write("null")
512		return nil
513	case fd.Message() != nil:
514		return w.marshalMessage(v.Message(), indent+w.Indent, "")
515	case fd.Enum() != nil:
516		if fd.Enum().FullName() == "google.protobuf.NullValue" {
517			w.write("null")
518			return nil
519		}
520
521		vd := fd.Enum().Values().ByNumber(v.Enum())
522		if vd == nil || w.EnumsAsInts {
523			w.write(strconv.Itoa(int(v.Enum())))
524		} else {
525			w.write(`"` + string(vd.Name()) + `"`)
526		}
527		return nil
528	default:
529		switch v.Interface().(type) {
530		case float32, float64:
531			switch {
532			case math.IsInf(v.Float(), +1):
533				w.write(`"Infinity"`)
534				return nil
535			case math.IsInf(v.Float(), -1):
536				w.write(`"-Infinity"`)
537				return nil
538			case math.IsNaN(v.Float()):
539				w.write(`"NaN"`)
540				return nil
541			}
542		case int64, uint64:
543			w.write(fmt.Sprintf(`"%d"`, v.Interface()))
544			return nil
545		}
546
547		b, err := json.Marshal(v.Interface())
548		if err != nil {
549			return err
550		}
551		w.write(string(b))
552		return nil
553	}
554}
555