1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package main 6 7import ( 8 "strings" 9 "text/template" 10) 11 12type WireType string 13 14const ( 15 WireVarint WireType = "Varint" 16 WireFixed32 WireType = "Fixed32" 17 WireFixed64 WireType = "Fixed64" 18 WireBytes WireType = "Bytes" 19 WireGroup WireType = "Group" 20) 21 22func (w WireType) Expr() Expr { 23 if w == WireGroup { 24 return "protowire.StartGroupType" 25 } 26 return "protowire." + Expr(w) + "Type" 27} 28 29func (w WireType) Packable() bool { 30 return w == WireVarint || w == WireFixed32 || w == WireFixed64 31} 32 33func (w WireType) ConstSize() bool { 34 return w == WireFixed32 || w == WireFixed64 35} 36 37type GoType string 38 39var GoTypes = []GoType{ 40 GoBool, 41 GoInt32, 42 GoUint32, 43 GoInt64, 44 GoUint64, 45 GoFloat32, 46 GoFloat64, 47 GoString, 48 GoBytes, 49} 50 51const ( 52 GoBool = "bool" 53 GoInt32 = "int32" 54 GoUint32 = "uint32" 55 GoInt64 = "int64" 56 GoUint64 = "uint64" 57 GoFloat32 = "float32" 58 GoFloat64 = "float64" 59 GoString = "string" 60 GoBytes = "[]byte" 61) 62 63func (g GoType) Zero() Expr { 64 switch g { 65 case GoBool: 66 return "false" 67 case GoString: 68 return `""` 69 case GoBytes: 70 return "nil" 71 } 72 return "0" 73} 74 75// Kind is the reflect.Kind of the type. 76func (g GoType) Kind() Expr { 77 if g == "" || g == GoBytes { 78 return "" 79 } 80 return "reflect." + Expr(strings.ToUpper(string(g[:1]))+string(g[1:])) 81} 82 83// PointerMethod is the "internal/impl".pointer method used to access a pointer to this type. 84func (g GoType) PointerMethod() Expr { 85 if g == GoBytes { 86 return "Bytes" 87 } 88 return Expr(strings.ToUpper(string(g[:1])) + string(g[1:])) 89} 90 91type ProtoKind struct { 92 Name string 93 WireType WireType 94 95 // Conversions to/from protoreflect.Value. 96 ToValue Expr 97 FromValue Expr 98 99 // Conversions to/from generated structures. 100 GoType GoType 101 ToGoType Expr 102 ToGoTypeNoZero Expr 103 FromGoType Expr 104 NoPointer bool 105 NoValueCodec bool 106} 107 108func (k ProtoKind) Expr() Expr { 109 return "protoreflect." + Expr(k.Name) + "Kind" 110} 111 112var ProtoKinds = []ProtoKind{ 113 { 114 Name: "Bool", 115 WireType: WireVarint, 116 ToValue: "protoreflect.ValueOfBool(protowire.DecodeBool(v))", 117 FromValue: "protowire.EncodeBool(v.Bool())", 118 GoType: GoBool, 119 ToGoType: "protowire.DecodeBool(v)", 120 FromGoType: "protowire.EncodeBool(v)", 121 }, 122 { 123 Name: "Enum", 124 WireType: WireVarint, 125 ToValue: "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))", 126 FromValue: "uint64(v.Enum())", 127 }, 128 { 129 Name: "Int32", 130 WireType: WireVarint, 131 ToValue: "protoreflect.ValueOfInt32(int32(v))", 132 FromValue: "uint64(int32(v.Int()))", 133 GoType: GoInt32, 134 ToGoType: "int32(v)", 135 FromGoType: "uint64(v)", 136 }, 137 { 138 Name: "Sint32", 139 WireType: WireVarint, 140 ToValue: "protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))", 141 FromValue: "protowire.EncodeZigZag(int64(int32(v.Int())))", 142 GoType: GoInt32, 143 ToGoType: "int32(protowire.DecodeZigZag(v & math.MaxUint32))", 144 FromGoType: "protowire.EncodeZigZag(int64(v))", 145 }, 146 { 147 Name: "Uint32", 148 WireType: WireVarint, 149 ToValue: "protoreflect.ValueOfUint32(uint32(v))", 150 FromValue: "uint64(uint32(v.Uint()))", 151 GoType: GoUint32, 152 ToGoType: "uint32(v)", 153 FromGoType: "uint64(v)", 154 }, 155 { 156 Name: "Int64", 157 WireType: WireVarint, 158 ToValue: "protoreflect.ValueOfInt64(int64(v))", 159 FromValue: "uint64(v.Int())", 160 GoType: GoInt64, 161 ToGoType: "int64(v)", 162 FromGoType: "uint64(v)", 163 }, 164 { 165 Name: "Sint64", 166 WireType: WireVarint, 167 ToValue: "protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))", 168 FromValue: "protowire.EncodeZigZag(v.Int())", 169 GoType: GoInt64, 170 ToGoType: "protowire.DecodeZigZag(v)", 171 FromGoType: "protowire.EncodeZigZag(v)", 172 }, 173 { 174 Name: "Uint64", 175 WireType: WireVarint, 176 ToValue: "protoreflect.ValueOfUint64(v)", 177 FromValue: "v.Uint()", 178 GoType: GoUint64, 179 ToGoType: "v", 180 FromGoType: "v", 181 }, 182 { 183 Name: "Sfixed32", 184 WireType: WireFixed32, 185 ToValue: "protoreflect.ValueOfInt32(int32(v))", 186 FromValue: "uint32(v.Int())", 187 GoType: GoInt32, 188 ToGoType: "int32(v)", 189 FromGoType: "uint32(v)", 190 }, 191 { 192 Name: "Fixed32", 193 WireType: WireFixed32, 194 ToValue: "protoreflect.ValueOfUint32(uint32(v))", 195 FromValue: "uint32(v.Uint())", 196 GoType: GoUint32, 197 ToGoType: "v", 198 FromGoType: "v", 199 }, 200 { 201 Name: "Float", 202 WireType: WireFixed32, 203 ToValue: "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))", 204 FromValue: "math.Float32bits(float32(v.Float()))", 205 GoType: GoFloat32, 206 ToGoType: "math.Float32frombits(v)", 207 FromGoType: "math.Float32bits(v)", 208 }, 209 { 210 Name: "Sfixed64", 211 WireType: WireFixed64, 212 ToValue: "protoreflect.ValueOfInt64(int64(v))", 213 FromValue: "uint64(v.Int())", 214 GoType: GoInt64, 215 ToGoType: "int64(v)", 216 FromGoType: "uint64(v)", 217 }, 218 { 219 Name: "Fixed64", 220 WireType: WireFixed64, 221 ToValue: "protoreflect.ValueOfUint64(v)", 222 FromValue: "v.Uint()", 223 GoType: GoUint64, 224 ToGoType: "v", 225 FromGoType: "v", 226 }, 227 { 228 Name: "Double", 229 WireType: WireFixed64, 230 ToValue: "protoreflect.ValueOfFloat64(math.Float64frombits(v))", 231 FromValue: "math.Float64bits(v.Float())", 232 GoType: GoFloat64, 233 ToGoType: "math.Float64frombits(v)", 234 FromGoType: "math.Float64bits(v)", 235 }, 236 { 237 Name: "String", 238 WireType: WireBytes, 239 ToValue: "protoreflect.ValueOfString(string(v))", 240 FromValue: "v.String()", 241 GoType: GoString, 242 ToGoType: "string(v)", 243 FromGoType: "v", 244 }, 245 { 246 Name: "Bytes", 247 WireType: WireBytes, 248 ToValue: "protoreflect.ValueOfBytes(append(emptyBuf[:], v...))", 249 FromValue: "v.Bytes()", 250 GoType: GoBytes, 251 ToGoType: "append(emptyBuf[:], v...)", 252 ToGoTypeNoZero: "append(([]byte)(nil), v...)", 253 FromGoType: "v", 254 NoPointer: true, 255 }, 256 { 257 Name: "Message", 258 WireType: WireBytes, 259 ToValue: "protoreflect.ValueOfBytes(v)", 260 FromValue: "v", 261 NoValueCodec: true, 262 }, 263 { 264 Name: "Group", 265 WireType: WireGroup, 266 ToValue: "protoreflect.ValueOfBytes(v)", 267 FromValue: "v", 268 NoValueCodec: true, 269 }, 270} 271 272func generateProtoDecode() string { 273 return mustExecute(protoDecodeTemplate, ProtoKinds) 274} 275 276var protoDecodeTemplate = template.Must(template.New("").Parse(` 277// unmarshalScalar decodes a value of the given kind. 278// 279// Message values are decoded into a []byte which aliases the input data. 280func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) { 281 switch fd.Kind() { 282 {{- range .}} 283 case {{.Expr}}: 284 if wtyp != {{.WireType.Expr}} { 285 return val, 0, errUnknown 286 } 287 {{if (eq .WireType "Group") -}} 288 v, n := protowire.ConsumeGroup(fd.Number(), b) 289 {{- else -}} 290 v, n := protowire.Consume{{.WireType}}(b) 291 {{- end}} 292 if n < 0 { 293 return val, 0, errDecode 294 } 295 {{if (eq .Name "String") -}} 296 if strs.EnforceUTF8(fd) && !utf8.Valid(v) { 297 return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName())) 298 } 299 {{end -}} 300 return {{.ToValue}}, n, nil 301 {{- end}} 302 default: 303 return val, 0, errUnknown 304 } 305} 306 307func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) { 308 switch fd.Kind() { 309 {{- range .}} 310 case {{.Expr}}: 311 {{- if .WireType.Packable}} 312 if wtyp == protowire.BytesType { 313 buf, n := protowire.ConsumeBytes(b) 314 if n < 0 { 315 return 0, errDecode 316 } 317 for len(buf) > 0 { 318 v, n := protowire.Consume{{.WireType}}(buf) 319 if n < 0 { 320 return 0, errDecode 321 } 322 buf = buf[n:] 323 list.Append({{.ToValue}}) 324 } 325 return n, nil 326 } 327 {{- end}} 328 if wtyp != {{.WireType.Expr}} { 329 return 0, errUnknown 330 } 331 {{if (eq .WireType "Group") -}} 332 v, n := protowire.ConsumeGroup(fd.Number(), b) 333 {{- else -}} 334 v, n := protowire.Consume{{.WireType}}(b) 335 {{- end}} 336 if n < 0 { 337 return 0, errDecode 338 } 339 {{if (eq .Name "String") -}} 340 if strs.EnforceUTF8(fd) && !utf8.Valid(v) { 341 return 0, errors.InvalidUTF8(string(fd.FullName())) 342 } 343 {{end -}} 344 {{if or (eq .Name "Message") (eq .Name "Group") -}} 345 m := list.NewElement() 346 if err := o.unmarshalMessage(v, m.Message()); err != nil { 347 return 0, err 348 } 349 list.Append(m) 350 {{- else -}} 351 list.Append({{.ToValue}}) 352 {{- end}} 353 return n, nil 354 {{- end}} 355 default: 356 return 0, errUnknown 357 } 358} 359 360// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices. 361var emptyBuf [0]byte 362`)) 363 364func generateProtoEncode() string { 365 return mustExecute(protoEncodeTemplate, ProtoKinds) 366} 367 368var protoEncodeTemplate = template.Must(template.New("").Parse(` 369var wireTypes = map[protoreflect.Kind]protowire.Type{ 370{{- range .}} 371 {{.Expr}}: {{.WireType.Expr}}, 372{{- end}} 373} 374 375func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) { 376 switch fd.Kind() { 377 {{- range .}} 378 case {{.Expr}}: 379 {{- if (eq .Name "String") }} 380 if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) { 381 return b, errors.InvalidUTF8(string(fd.FullName())) 382 } 383 b = protowire.AppendString(b, {{.FromValue}}) 384 {{- else if (eq .Name "Message") -}} 385 var pos int 386 var err error 387 b, pos = appendSpeculativeLength(b) 388 b, err = o.marshalMessage(b, v.Message()) 389 if err != nil { 390 return b, err 391 } 392 b = finishSpeculativeLength(b, pos) 393 {{- else if (eq .Name "Group") -}} 394 var err error 395 b, err = o.marshalMessage(b, v.Message()) 396 if err != nil { 397 return b, err 398 } 399 b = protowire.AppendVarint(b, protowire.EncodeTag(fd.Number(), protowire.EndGroupType)) 400 {{- else -}} 401 b = protowire.Append{{.WireType}}(b, {{.FromValue}}) 402 {{- end}} 403 {{- end}} 404 default: 405 return b, errors.New("invalid kind %v", fd.Kind()) 406 } 407 return b, nil 408} 409`)) 410 411func generateProtoSize() string { 412 return mustExecute(protoSizeTemplate, ProtoKinds) 413} 414 415var protoSizeTemplate = template.Must(template.New("").Parse(` 416func (o MarshalOptions) sizeSingular(num protowire.Number, kind protoreflect.Kind, v protoreflect.Value) int { 417 switch kind { 418 {{- range .}} 419 case {{.Expr}}: 420 {{if (eq .Name "Message") -}} 421 return protowire.SizeBytes(o.size(v.Message())) 422 {{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}} 423 return protowire.Size{{.WireType}}() 424 {{- else if (eq .WireType "Bytes") -}} 425 return protowire.Size{{.WireType}}(len({{.FromValue}})) 426 {{- else if (eq .WireType "Group") -}} 427 return protowire.Size{{.WireType}}(num, o.size(v.Message())) 428 {{- else -}} 429 return protowire.Size{{.WireType}}({{.FromValue}}) 430 {{- end}} 431 {{- end}} 432 default: 433 return 0 434 } 435} 436`)) 437