1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"google.golang.org/protobuf/reflect/protoreflect"
9)
10
11// DiscardUnknown recursively discards all unknown fields from this message
12// and all embedded messages.
13//
14// When unmarshaling a message with unrecognized fields, the tags and values
15// of such fields are preserved in the Message. This allows a later call to
16// marshal to be able to produce a message that continues to have those
17// unrecognized fields. To avoid this, DiscardUnknown is used to
18// explicitly clear the unknown fields after unmarshaling.
19func DiscardUnknown(m Message) {
20	if m != nil {
21		discardUnknown(MessageReflect(m))
22	}
23}
24
25func discardUnknown(m protoreflect.Message) {
26	m.Range(func(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool {
27		switch {
28		// Handle singular message.
29		case fd.Cardinality() != protoreflect.Repeated:
30			if fd.Message() != nil {
31				discardUnknown(m.Get(fd).Message())
32			}
33		// Handle list of messages.
34		case fd.IsList():
35			if fd.Message() != nil {
36				ls := m.Get(fd).List()
37				for i := 0; i < ls.Len(); i++ {
38					discardUnknown(ls.Get(i).Message())
39				}
40			}
41		// Handle map of messages.
42		case fd.IsMap():
43			if fd.MapValue().Message() != nil {
44				ms := m.Get(fd).Map()
45				ms.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool {
46					discardUnknown(v.Message())
47					return true
48				})
49			}
50		}
51		return true
52	})
53
54	// Discard unknown fields.
55	if len(m.GetUnknown()) > 0 {
56		m.SetUnknown(nil)
57	}
58}
59