1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"bytes"
9	"encoding"
10	"fmt"
11	"io"
12	"math"
13	"sort"
14	"strings"
15
16	"google.golang.org/protobuf/encoding/prototext"
17	"google.golang.org/protobuf/encoding/protowire"
18	"google.golang.org/protobuf/proto"
19	"google.golang.org/protobuf/reflect/protoreflect"
20	"google.golang.org/protobuf/reflect/protoregistry"
21)
22
// wrapTextMarshalV2 selects the rendering backend for TextMarshaler: when
// true, output is delegated to prototext; when false, the legacy textWriter
// in this file is used.
const (
	wrapTextMarshalV2 = false
)
24
// TextMarshaler renders messages in the protobuf text format. The zero value
// produces multi-line output and leaves google.protobuf.Any messages
// unexpanded.
type TextMarshaler struct {
	// Compact requests single-line output.
	Compact bool
	// ExpandAny requests that google.protobuf.Any messages whose payload type
	// is known be written in expanded form.
	ExpandAny bool
}
30
31// Marshal writes the proto text format of m to w.
32func (tm *TextMarshaler) Marshal(w io.Writer, m Message) error {
33	b, err := tm.marshal(m)
34	if len(b) > 0 {
35		if _, err := w.Write(b); err != nil {
36			return err
37		}
38	}
39	return err
40}
41
42// Text returns a proto text formatted string of m.
43func (tm *TextMarshaler) Text(m Message) string {
44	b, _ := tm.marshal(m)
45	return string(b)
46}
47
48func (tm *TextMarshaler) marshal(m Message) ([]byte, error) {
49	mr := MessageReflect(m)
50	if mr == nil || !mr.IsValid() {
51		return []byte("<nil>"), nil
52	}
53
54	if wrapTextMarshalV2 {
55		if m, ok := m.(encoding.TextMarshaler); ok {
56			return m.MarshalText()
57		}
58
59		opts := prototext.MarshalOptions{
60			AllowPartial: true,
61			EmitUnknown:  true,
62		}
63		if !tm.Compact {
64			opts.Indent = "  "
65		}
66		if !tm.ExpandAny {
67			opts.Resolver = (*protoregistry.Types)(nil)
68		}
69		return opts.Marshal(mr.Interface())
70	} else {
71		w := &textWriter{
72			compact:   tm.Compact,
73			expandAny: tm.ExpandAny,
74			complete:  true,
75		}
76
77		if m, ok := m.(encoding.TextMarshaler); ok {
78			b, err := m.MarshalText()
79			if err != nil {
80				return nil, err
81			}
82			w.Write(b)
83			return w.buf, nil
84		}
85
86		err := w.writeMessage(mr)
87		return w.buf, err
88	}
89}
90
91var (
92	defaultTextMarshaler = TextMarshaler{}
93	compactTextMarshaler = TextMarshaler{Compact: true}
94)
95
96// MarshalText writes the proto text format of m to w.
97func MarshalText(w io.Writer, m Message) error { return defaultTextMarshaler.Marshal(w, m) }
98
99// MarshalTextString returns a proto text formatted string of m.
100func MarshalTextString(m Message) string { return defaultTextMarshaler.Text(m) }
101
102// CompactText writes the compact proto text format of m to w.
103func CompactText(w io.Writer, m Message) error { return compactTextMarshaler.Marshal(w, m) }
104
105// CompactTextString returns a compact proto text formatted string of m.
106func CompactTextString(m Message) string { return compactTextMarshaler.Text(m) }
107
// Byte sequences reused by textWriter when emitting output.
var (
	newline         = []byte{'\n'}
	endBraceNewline = []byte{'}', '\n'}

	// Text-format spellings of non-finite floating-point values.
	posInf = []byte("inf")
	negInf = []byte("-inf")
	nan    = []byte("nan")
)
115
// textWriter accumulates text-format output in buf. It also tracks the
// current indentation level and whether output is at the start of a line.
type textWriter struct {
	// compact mirrors TextMarshaler.Compact.
	compact bool
	// expandAny mirrors TextMarshaler.ExpandAny.
	expandAny bool
	// complete reports whether the current position is at the start of a
	// fresh line, i.e. indentation is still owed before the next output.
	complete bool
	// indent is the nesting depth; each level is rendered as two spaces.
	// It is intended never to be negative.
	indent int
	// buf holds everything written so far.
	buf []byte
}
124
125func (w *textWriter) Write(p []byte) (n int, _ error) {
126	newlines := bytes.Count(p, newline)
127	if newlines == 0 {
128		if !w.compact && w.complete {
129			w.writeIndent()
130		}
131		w.buf = append(w.buf, p...)
132		w.complete = false
133		return len(p), nil
134	}
135
136	frags := bytes.SplitN(p, newline, newlines+1)
137	if w.compact {
138		for i, frag := range frags {
139			if i > 0 {
140				w.buf = append(w.buf, ' ')
141				n++
142			}
143			w.buf = append(w.buf, frag...)
144			n += len(frag)
145		}
146		return n, nil
147	}
148
149	for i, frag := range frags {
150		if w.complete {
151			w.writeIndent()
152		}
153		w.buf = append(w.buf, frag...)
154		n += len(frag)
155		if i+1 < len(frags) {
156			w.buf = append(w.buf, '\n')
157			n++
158		}
159	}
160	w.complete = len(frags[len(frags)-1]) == 0
161	return n, nil
162}
163
164func (w *textWriter) WriteByte(c byte) error {
165	if w.compact && c == '\n' {
166		c = ' '
167	}
168	if !w.compact && w.complete {
169		w.writeIndent()
170	}
171	w.buf = append(w.buf, c)
172	w.complete = c == '\n'
173	return nil
174}
175
176func (w *textWriter) writeName(fd protoreflect.FieldDescriptor) {
177	if !w.compact && w.complete {
178		w.writeIndent()
179	}
180	w.complete = false
181
182	if fd.Kind() != protoreflect.GroupKind {
183		w.buf = append(w.buf, fd.Name()...)
184		w.WriteByte(':')
185	} else {
186		// Use message type name for group field name.
187		w.buf = append(w.buf, fd.Message().Name()...)
188	}
189
190	if !w.compact {
191		w.WriteByte(' ')
192	}
193}
194
// requiresQuotes reports whether the Any type URL u must be quoted when
// written inside "[...]" in expanded Any output.
//
// URLs consisting only of characters from [0-9A-Za-z./_] are written bare.
// Any other character forces quoting, including '-', whitespace, quotes and
// every non-ASCII rune. The empty string needs no quoting.
func requiresQuotes(u string) bool {
	for _, ch := range u {
		switch {
		case ch == '.', ch == '/', ch == '_':
		case '0' <= ch && ch <= '9':
		case 'A' <= ch && ch <= 'Z':
		case 'a' <= ch && ch <= 'z':
		default:
			return true
		}
	}
	return false
}
213
214// writeProto3Any writes an expanded google.protobuf.Any message.
215//
216// It returns (false, nil) if sv value can't be unmarshaled (e.g. because
217// required messages are not linked in).
218//
219// It returns (true, error) when sv was written in expanded format or an error
220// was encountered.
221func (w *textWriter) writeProto3Any(m protoreflect.Message) (bool, error) {
222	md := m.Descriptor()
223	fdURL := md.Fields().ByName("type_url")
224	fdVal := md.Fields().ByName("value")
225
226	url := m.Get(fdURL).String()
227	mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
228	if err != nil {
229		return false, nil
230	}
231
232	b := m.Get(fdVal).Bytes()
233	m2 := mt.New()
234	if err := proto.Unmarshal(b, m2.Interface()); err != nil {
235		return false, nil
236	}
237	w.Write([]byte("["))
238	if requiresQuotes(url) {
239		w.writeQuotedString(url)
240	} else {
241		w.Write([]byte(url))
242	}
243	if w.compact {
244		w.Write([]byte("]:<"))
245	} else {
246		w.Write([]byte("]: <\n"))
247		w.indent++
248	}
249	if err := w.writeMessage(m2); err != nil {
250		return true, err
251	}
252	if w.compact {
253		w.Write([]byte("> "))
254	} else {
255		w.indent--
256		w.Write([]byte(">\n"))
257	}
258	return true, nil
259}
260
// writeMessage writes the contents of m in text format, in this order:
// populated regular fields in declaration order, then unknown fields, then
// extensions.
//
// When expandAny is set and m is a google.protobuf.Any that writeProto3Any
// can expand, only the expanded form is written.
//
// It returns the first error produced while writing a field value or an
// extension.
func (w *textWriter) writeMessage(m protoreflect.Message) error {
	md := m.Descriptor()
	if w.expandAny && md.FullName() == "google.protobuf.Any" {
		// When expansion is not possible (canExpand == false), fall through
		// and write the Any field by field like any other message.
		if canExpand, err := w.writeProto3Any(m); canExpand {
			return err
		}
	}

	fds := md.Fields()
	for i := 0; i < fds.Len(); {
		fd := fds.Get(i)
		if od := fd.ContainingOneof(); od != nil {
			// Visit each oneof once: pick its populated member (nil if none)
			// and skip past all of its fields.
			// NOTE(review): skipping od.Fields().Len() entries assumes the
			// oneof's fields are consecutive in md.Fields().
			fd = m.WhichOneof(od)
			i += od.Fields().Len()
		} else {
			i++
		}
		if fd == nil || !m.Has(fd) {
			continue
		}

		switch {
		case fd.IsList():
			// Repeated fields: one "name: value" line per element.
			lv := m.Get(fd).List()
			for j := 0; j < lv.Len(); j++ {
				w.writeName(fd)
				v := lv.Get(j)
				if err := w.writeSingularValue(v, fd); err != nil {
					return err
				}
				w.WriteByte('\n')
			}
		case fd.IsMap():
			// Map fields: one "name <key: ... value: ...>" block per entry,
			// sorted by key so the output is deterministic.
			kfd := fd.MapKey()
			vfd := fd.MapValue()
			mv := m.Get(fd).Map()

			type entry struct{ key, val protoreflect.Value }
			var entries []entry
			mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
				entries = append(entries, entry{k.Value(), v})
				return true
			})
			sort.Slice(entries, func(i, j int) bool {
				switch kfd.Kind() {
				case protoreflect.BoolKind:
					// false sorts before true.
					return !entries[i].key.Bool() && entries[j].key.Bool()
				case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
					return entries[i].key.Int() < entries[j].key.Int()
				case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
					return entries[i].key.Uint() < entries[j].key.Uint()
				case protoreflect.StringKind:
					return entries[i].key.String() < entries[j].key.String()
				default:
					// Other key kinds are not expected here.
					panic("invalid kind")
				}
			})
			for _, entry := range entries {
				w.writeName(fd)
				w.WriteByte('<')
				if !w.compact {
					w.WriteByte('\n')
				}
				w.indent++
				w.writeName(kfd)
				if err := w.writeSingularValue(entry.key, kfd); err != nil {
					return err
				}
				w.WriteByte('\n')
				w.writeName(vfd)
				if err := w.writeSingularValue(entry.val, vfd); err != nil {
					return err
				}
				w.WriteByte('\n')
				w.indent--
				w.WriteByte('>')
				w.WriteByte('\n')
			}
		default:
			// Singular field: a single "name: value" line.
			w.writeName(fd)
			if err := w.writeSingularValue(m.Get(fd), fd); err != nil {
				return err
			}
			w.WriteByte('\n')
		}
	}

	if b := m.GetUnknown(); len(b) > 0 {
		w.writeUnknownFields(b)
	}
	return w.writeExtensions(m)
}
353
354func (w *textWriter) writeSingularValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
355	switch fd.Kind() {
356	case protoreflect.FloatKind, protoreflect.DoubleKind:
357		switch vf := v.Float(); {
358		case math.IsInf(vf, +1):
359			w.Write(posInf)
360		case math.IsInf(vf, -1):
361			w.Write(negInf)
362		case math.IsNaN(vf):
363			w.Write(nan)
364		default:
365			fmt.Fprint(w, v.Interface())
366		}
367	case protoreflect.StringKind:
368		// NOTE: This does not validate UTF-8 for historical reasons.
369		w.writeQuotedString(string(v.String()))
370	case protoreflect.BytesKind:
371		w.writeQuotedString(string(v.Bytes()))
372	case protoreflect.MessageKind, protoreflect.GroupKind:
373		var bra, ket byte = '<', '>'
374		if fd.Kind() == protoreflect.GroupKind {
375			bra, ket = '{', '}'
376		}
377		w.WriteByte(bra)
378		if !w.compact {
379			w.WriteByte('\n')
380		}
381		w.indent++
382		m := v.Message()
383		if m2, ok := m.Interface().(encoding.TextMarshaler); ok {
384			b, err := m2.MarshalText()
385			if err != nil {
386				return err
387			}
388			w.Write(b)
389		} else {
390			w.writeMessage(m)
391		}
392		w.indent--
393		w.WriteByte(ket)
394	case protoreflect.EnumKind:
395		if ev := fd.Enum().Values().ByNumber(v.Enum()); ev != nil {
396			fmt.Fprint(w, ev.Name())
397		} else {
398			fmt.Fprint(w, v.Enum())
399		}
400	default:
401		fmt.Fprint(w, v.Interface())
402	}
403	return nil
404}
405
406// writeQuotedString writes a quoted string in the protocol buffer text format.
407func (w *textWriter) writeQuotedString(s string) {
408	w.WriteByte('"')
409	for i := 0; i < len(s); i++ {
410		switch c := s[i]; c {
411		case '\n':
412			w.buf = append(w.buf, `\n`...)
413		case '\r':
414			w.buf = append(w.buf, `\r`...)
415		case '\t':
416			w.buf = append(w.buf, `\t`...)
417		case '"':
418			w.buf = append(w.buf, `\"`...)
419		case '\\':
420			w.buf = append(w.buf, `\\`...)
421		default:
422			if isPrint := c >= 0x20 && c < 0x7f; isPrint {
423				w.buf = append(w.buf, c)
424			} else {
425				w.buf = append(w.buf, fmt.Sprintf(`\%03o`, c)...)
426			}
427		}
428	}
429	w.WriteByte('"')
430}
431
432func (w *textWriter) writeUnknownFields(b []byte) {
433	if !w.compact {
434		fmt.Fprintf(w, "/* %d unknown bytes */\n", len(b))
435	}
436
437	for len(b) > 0 {
438		num, wtyp, n := protowire.ConsumeTag(b)
439		if n < 0 {
440			return
441		}
442		b = b[n:]
443
444		if wtyp == protowire.EndGroupType {
445			w.indent--
446			w.Write(endBraceNewline)
447			continue
448		}
449		fmt.Fprint(w, num)
450		if wtyp != protowire.StartGroupType {
451			w.WriteByte(':')
452		}
453		if !w.compact || wtyp == protowire.StartGroupType {
454			w.WriteByte(' ')
455		}
456		switch wtyp {
457		case protowire.VarintType:
458			v, n := protowire.ConsumeVarint(b)
459			if n < 0 {
460				return
461			}
462			b = b[n:]
463			fmt.Fprint(w, v)
464		case protowire.Fixed32Type:
465			v, n := protowire.ConsumeFixed32(b)
466			if n < 0 {
467				return
468			}
469			b = b[n:]
470			fmt.Fprint(w, v)
471		case protowire.Fixed64Type:
472			v, n := protowire.ConsumeFixed64(b)
473			if n < 0 {
474				return
475			}
476			b = b[n:]
477			fmt.Fprint(w, v)
478		case protowire.BytesType:
479			v, n := protowire.ConsumeBytes(b)
480			if n < 0 {
481				return
482			}
483			b = b[n:]
484			fmt.Fprintf(w, "%q", v)
485		case protowire.StartGroupType:
486			w.WriteByte('{')
487			w.indent++
488		default:
489			fmt.Fprintf(w, "/* unknown wire type %d */", wtyp)
490		}
491		w.WriteByte('\n')
492	}
493}
494
495// writeExtensions writes all the extensions in m.
496func (w *textWriter) writeExtensions(m protoreflect.Message) error {
497	md := m.Descriptor()
498	if md.ExtensionRanges().Len() == 0 {
499		return nil
500	}
501
502	type ext struct {
503		desc protoreflect.FieldDescriptor
504		val  protoreflect.Value
505	}
506	var exts []ext
507	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
508		if fd.IsExtension() {
509			exts = append(exts, ext{fd, v})
510		}
511		return true
512	})
513	sort.Slice(exts, func(i, j int) bool {
514		return exts[i].desc.Number() < exts[j].desc.Number()
515	})
516
517	for _, ext := range exts {
518		// For message set, use the name of the message as the extension name.
519		name := string(ext.desc.FullName())
520		if isMessageSet(ext.desc.ContainingMessage()) {
521			name = strings.TrimSuffix(name, ".message_set_extension")
522		}
523
524		if !ext.desc.IsList() {
525			if err := w.writeSingularExtension(name, ext.val, ext.desc); err != nil {
526				return err
527			}
528		} else {
529			lv := ext.val.List()
530			for i := 0; i < lv.Len(); i++ {
531				if err := w.writeSingularExtension(name, lv.Get(i), ext.desc); err != nil {
532					return err
533				}
534			}
535		}
536	}
537	return nil
538}
539
540func (w *textWriter) writeSingularExtension(name string, v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
541	fmt.Fprintf(w, "[%s]:", name)
542	if !w.compact {
543		w.WriteByte(' ')
544	}
545	if err := w.writeSingularValue(v, fd); err != nil {
546		return err
547	}
548	w.WriteByte('\n')
549	return nil
550}
551
552func (w *textWriter) writeIndent() {
553	if !w.complete {
554		return
555	}
556	for i := 0; i < w.indent*2; i++ {
557		w.buf = append(w.buf, ' ')
558	}
559	w.complete = false
560}
561