1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package messageset encodes and decodes the obsolete MessageSet wire format. 6package messageset 7 8import ( 9 "math" 10 11 "google.golang.org/protobuf/encoding/protowire" 12 "google.golang.org/protobuf/internal/errors" 13 pref "google.golang.org/protobuf/reflect/protoreflect" 14 preg "google.golang.org/protobuf/reflect/protoregistry" 15) 16 17// The MessageSet wire format is equivalent to a message defiend as follows, 18// where each Item defines an extension field with a field number of 'type_id' 19// and content of 'message'. MessageSet extensions must be non-repeated message 20// fields. 21// 22// message MessageSet { 23// repeated group Item = 1 { 24// required int32 type_id = 2; 25// required string message = 3; 26// } 27// } 28const ( 29 FieldItem = protowire.Number(1) 30 FieldTypeID = protowire.Number(2) 31 FieldMessage = protowire.Number(3) 32) 33 34// ExtensionName is the field name for extensions of MessageSet. 35// 36// A valid MessageSet extension must be of the form: 37// message MyMessage { 38// extend proto2.bridge.MessageSet { 39// optional MyMessage message_set_extension = 1234; 40// } 41// ... 42// } 43const ExtensionName = "message_set_extension" 44 45// IsMessageSet returns whether the message uses the MessageSet wire format. 46func IsMessageSet(md pref.MessageDescriptor) bool { 47 xmd, ok := md.(interface{ IsMessageSet() bool }) 48 return ok && xmd.IsMessageSet() 49} 50 51// IsMessageSetExtension reports this field extends a MessageSet. 52func IsMessageSetExtension(fd pref.FieldDescriptor) bool { 53 if fd.Name() != ExtensionName { 54 return false 55 } 56 if fd.FullName().Parent() != fd.Message().FullName() { 57 return false 58 } 59 return IsMessageSet(fd.ContainingMessage()) 60} 61 62// FindMessageSetExtension locates a MessageSet extension field by name. 63// In text and JSON formats, the extension name used is the message itself. 64// The extension field name is derived by appending ExtensionName. 65func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) { 66 name := s.Append(ExtensionName) 67 xt, err := r.FindExtensionByName(name) 68 if err != nil { 69 if err == preg.NotFound { 70 return nil, err 71 } 72 return nil, errors.Wrap(err, "%q", name) 73 } 74 if !IsMessageSetExtension(xt.TypeDescriptor()) { 75 return nil, preg.NotFound 76 } 77 return xt, nil 78} 79 80// SizeField returns the size of a MessageSet item field containing an extension 81// with the given field number, not counting the contents of the message subfield. 82func SizeField(num protowire.Number) int { 83 return 2*protowire.SizeTag(FieldItem) + protowire.SizeTag(FieldTypeID) + protowire.SizeVarint(uint64(num)) 84} 85 86// Unmarshal parses a MessageSet. 87// 88// It calls fn with the type ID and value of each item in the MessageSet. 89// Unknown fields are discarded. 90// 91// If wantLen is true, the item values include the varint length prefix. 92// This is ugly, but simplifies the fast-path decoder in internal/impl. 93func Unmarshal(b []byte, wantLen bool, fn func(typeID protowire.Number, value []byte) error) error { 94 for len(b) > 0 { 95 num, wtyp, n := protowire.ConsumeTag(b) 96 if n < 0 { 97 return protowire.ParseError(n) 98 } 99 b = b[n:] 100 if num != FieldItem || wtyp != protowire.StartGroupType { 101 n := protowire.ConsumeFieldValue(num, wtyp, b) 102 if n < 0 { 103 return protowire.ParseError(n) 104 } 105 b = b[n:] 106 continue 107 } 108 typeID, value, n, err := ConsumeFieldValue(b, wantLen) 109 if err != nil { 110 return err 111 } 112 b = b[n:] 113 if typeID == 0 { 114 continue 115 } 116 if err := fn(typeID, value); err != nil { 117 return err 118 } 119 } 120 return nil 121} 122 123// ConsumeFieldValue parses b as a MessageSet item field value until and including 124// the trailing end group marker. It assumes the start group tag has already been parsed. 125// It returns the contents of the type_id and message subfields and the total 126// item length. 127// 128// If wantLen is true, the returned message value includes the length prefix. 129func ConsumeFieldValue(b []byte, wantLen bool) (typeid protowire.Number, message []byte, n int, err error) { 130 ilen := len(b) 131 for { 132 num, wtyp, n := protowire.ConsumeTag(b) 133 if n < 0 { 134 return 0, nil, 0, protowire.ParseError(n) 135 } 136 b = b[n:] 137 switch { 138 case num == FieldItem && wtyp == protowire.EndGroupType: 139 if wantLen && len(message) == 0 { 140 // The message field was missing, which should never happen. 141 // Be prepared for this case anyway. 142 message = protowire.AppendVarint(message, 0) 143 } 144 return typeid, message, ilen - len(b), nil 145 case num == FieldTypeID && wtyp == protowire.VarintType: 146 v, n := protowire.ConsumeVarint(b) 147 if n < 0 { 148 return 0, nil, 0, protowire.ParseError(n) 149 } 150 b = b[n:] 151 if v < 1 || v > math.MaxInt32 { 152 return 0, nil, 0, errors.New("invalid type_id in message set") 153 } 154 typeid = protowire.Number(v) 155 case num == FieldMessage && wtyp == protowire.BytesType: 156 m, n := protowire.ConsumeBytes(b) 157 if n < 0 { 158 return 0, nil, 0, protowire.ParseError(n) 159 } 160 if message == nil { 161 if wantLen { 162 message = b[:n:n] 163 } else { 164 message = m[:len(m):len(m)] 165 } 166 } else { 167 // This case should never happen in practice, but handle it for 168 // correctness: The MessageSet item contains multiple message 169 // fields, which need to be merged. 170 // 171 // In the case where we're returning the length, this becomes 172 // quite inefficient since we need to strip the length off 173 // the existing data and reconstruct it with the combined length. 174 if wantLen { 175 _, nn := protowire.ConsumeVarint(message) 176 m0 := message[nn:] 177 message = nil 178 message = protowire.AppendVarint(message, uint64(len(m0)+len(m))) 179 message = append(message, m0...) 180 message = append(message, m...) 181 } else { 182 message = append(message, m...) 183 } 184 } 185 b = b[n:] 186 default: 187 // We have no place to put it, so we just ignore unknown fields. 188 n := protowire.ConsumeFieldValue(num, wtyp, b) 189 if n < 0 { 190 return 0, nil, 0, protowire.ParseError(n) 191 } 192 b = b[n:] 193 } 194 } 195} 196 197// AppendFieldStart appends the start of a MessageSet item field containing 198// an extension with the given number. The caller must add the message 199// subfield (including the tag). 200func AppendFieldStart(b []byte, num protowire.Number) []byte { 201 b = protowire.AppendTag(b, FieldItem, protowire.StartGroupType) 202 b = protowire.AppendTag(b, FieldTypeID, protowire.VarintType) 203 b = protowire.AppendVarint(b, uint64(num)) 204 return b 205} 206 207// AppendFieldEnd appends the trailing end group marker for a MessageSet item field. 208func AppendFieldEnd(b []byte) []byte { 209 return protowire.AppendTag(b, FieldItem, protowire.EndGroupType) 210} 211 212// SizeUnknown returns the size of an unknown fields section in MessageSet format. 213// 214// See AppendUnknown. 215func SizeUnknown(unknown []byte) (size int) { 216 for len(unknown) > 0 { 217 num, typ, n := protowire.ConsumeTag(unknown) 218 if n < 0 || typ != protowire.BytesType { 219 return 0 220 } 221 unknown = unknown[n:] 222 _, n = protowire.ConsumeBytes(unknown) 223 if n < 0 { 224 return 0 225 } 226 unknown = unknown[n:] 227 size += SizeField(num) + protowire.SizeTag(FieldMessage) + n 228 } 229 return size 230} 231 232// AppendUnknown appends unknown fields to b in MessageSet format. 233// 234// For historic reasons, unresolved items in a MessageSet are stored in a 235// message's unknown fields section in non-MessageSet format. That is, an 236// unknown item with typeID T and value V appears in the unknown fields as 237// a field with number T and value V. 238// 239// This function converts the unknown fields back into MessageSet form. 240func AppendUnknown(b, unknown []byte) ([]byte, error) { 241 for len(unknown) > 0 { 242 num, typ, n := protowire.ConsumeTag(unknown) 243 if n < 0 || typ != protowire.BytesType { 244 return nil, errors.New("invalid data in message set unknown fields") 245 } 246 unknown = unknown[n:] 247 _, n = protowire.ConsumeBytes(unknown) 248 if n < 0 { 249 return nil, errors.New("invalid data in message set unknown fields") 250 } 251 b = AppendFieldStart(b, num) 252 b = protowire.AppendTag(b, FieldMessage, protowire.BytesType) 253 b = append(b, unknown[:n]...) 254 b = AppendFieldEnd(b) 255 unknown = unknown[n:] 256 } 257 return b, nil 258} 259