1// Copyright 2020 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package protorange 6 7import ( 8 "testing" 9 "time" 10 11 "github.com/google/go-cmp/cmp" 12 "github.com/google/go-cmp/cmp/cmpopts" 13 "google.golang.org/protobuf/proto" 14 "google.golang.org/protobuf/reflect/protopath" 15 "google.golang.org/protobuf/reflect/protoreflect" 16 "google.golang.org/protobuf/reflect/protoregistry" 17 "google.golang.org/protobuf/testing/protocmp" 18 19 newspb "google.golang.org/protobuf/internal/testprotos/news" 20 anypb "google.golang.org/protobuf/types/known/anypb" 21 timestamppb "google.golang.org/protobuf/types/known/timestamppb" 22) 23 24func mustMarshal(m proto.Message) []byte { 25 b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) 26 if err != nil { 27 panic(err) 28 } 29 return b 30} 31 32var transformReflectValue = cmp.Transformer("", func(v protoreflect.Value) interface{} { 33 switch v := v.Interface().(type) { 34 case protoreflect.Message: 35 return v.Interface() 36 case protoreflect.Map: 37 ms := map[interface{}]protoreflect.Value{} 38 v.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { 39 ms[k.Interface()] = v 40 return true 41 }) 42 return ms 43 case protoreflect.List: 44 ls := []protoreflect.Value{} 45 for i := 0; i < v.Len(); i++ { 46 ls = append(ls, v.Get(i)) 47 } 48 return ls 49 default: 50 return v 51 } 52}) 53 54func TestRange(t *testing.T) { 55 m2 := (&newspb.KeyValueAttachment{ 56 Name: "checksums.txt", 57 Data: map[string]string{ 58 "go1.10.src.tar.gz": "07cbb9d0091b846c6aea40bf5bc0cea7", 59 "go1.10.darwin-amd64.pkg": "cbb38bb6ff6ea86279e01745984445bf", 60 "go1.10.linux-amd64.tar.gz": "6b3d0e4a5c77352cf4275573817f7566", 61 "go1.10.windows-amd64.msi": "57bda02030f58f5d2bf71943e1390123", 62 }, 63 }).ProtoReflect() 64 m := (&newspb.Article{ 65 Author: "Brad Fitzpatrick", 66 Date: timestamppb.New(time.Date(2018, time.February, 16, 0, 0, 0, 0, time.UTC)), 67 Title: "Go 1.10 is released", 68 Content: "Happy Friday, happy weekend! Today the Go team is happy to announce the release of Go 1.10...", 69 Status: newspb.Article_PUBLISHED, 70 Tags: []string{"go1.10", "release"}, 71 Attachments: []*anypb.Any{{ 72 TypeUrl: "google.golang.org.KeyValueAttachment", 73 Value: mustMarshal(m2.Interface()), 74 }}, 75 }).ProtoReflect() 76 77 // Nil push and pop functions should not panic. 78 noop := func(protopath.Values) error { return nil } 79 Options{}.Range(m, nil, nil) 80 Options{}.Range(m, noop, nil) 81 Options{}.Range(m, nil, noop) 82 83 getByName := func(m protoreflect.Message, s protoreflect.Name) protoreflect.Value { 84 fds := m.Descriptor().Fields() 85 return m.Get(fds.ByName(s)) 86 } 87 88 wantPaths := []string{ 89 ``, 90 `.author`, 91 `.date`, 92 `.date.seconds`, 93 `.title`, 94 `.content`, 95 `.attachments`, 96 `.attachments[0]`, 97 `.attachments[0].(google.golang.org.KeyValueAttachment)`, 98 `.attachments[0].(google.golang.org.KeyValueAttachment).name`, 99 `.attachments[0].(google.golang.org.KeyValueAttachment).data`, 100 `.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.darwin-amd64.pkg"]`, 101 `.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.linux-amd64.tar.gz"]`, 102 `.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.src.tar.gz"]`, 103 `.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.windows-amd64.msi"]`, 104 `.tags`, 105 `.tags[0]`, 106 `.tags[1]`, 107 `.status`, 108 } 109 wantValues := []protoreflect.Value{ 110 protoreflect.ValueOfMessage(m), 111 getByName(m, "author"), 112 getByName(m, "date"), 113 getByName(getByName(m, "date").Message(), "seconds"), 114 getByName(m, `title`), 115 getByName(m, `content`), 116 getByName(m, `attachments`), 117 getByName(m, `attachments`).List().Get(0), 118 protoreflect.ValueOfMessage(m2), 119 getByName(m2, `name`), 120 getByName(m2, `data`), 121 getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.darwin-amd64.pkg").MapKey()), 122 getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.linux-amd64.tar.gz").MapKey()), 123 getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.src.tar.gz").MapKey()), 124 getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.windows-amd64.msi").MapKey()), 125 getByName(m, `tags`), 126 getByName(m, `tags`).List().Get(0), 127 getByName(m, `tags`).List().Get(1), 128 getByName(m, `status`), 129 } 130 131 tests := []struct { 132 resolver interface { 133 protoregistry.ExtensionTypeResolver 134 protoregistry.MessageTypeResolver 135 } 136 137 errorAt int 138 breakAt int 139 terminateAt int 140 141 wantPaths []string 142 wantValues []protoreflect.Value 143 wantError error 144 }{{ 145 wantPaths: wantPaths, 146 wantValues: wantValues, 147 }, { 148 resolver: (*protoregistry.Types)(nil), 149 wantPaths: append(append(wantPaths[:8:8], 150 `.attachments[0].type_url`, 151 `.attachments[0].value`, 152 ), wantPaths[15:]...), 153 wantValues: append(append(wantValues[:8:8], 154 protoreflect.ValueOfString("google.golang.org.KeyValueAttachment"), 155 protoreflect.ValueOfBytes(mustMarshal(m2.Interface())), 156 ), wantValues[15:]...), 157 }, { 158 errorAt: 5, // return error within newspb.Article 159 wantPaths: wantPaths[:5], 160 wantValues: wantValues[:5], 161 wantError: cmpopts.AnyError, 162 }, { 163 terminateAt: 11, // terminate within newspb.KeyValueAttachment 164 wantPaths: wantPaths[:11], 165 wantValues: wantValues[:11], 166 }, { 167 breakAt: 11, // break within newspb.KeyValueAttachment 168 wantPaths: append(wantPaths[:11:11], wantPaths[15:]...), 169 wantValues: append(wantValues[:11:11], wantValues[15:]...), 170 }, { 171 errorAt: 17, // return error within newspb.Article.Tags 172 wantPaths: wantPaths[:17], 173 wantValues: wantValues[:17], 174 wantError: cmpopts.AnyError, 175 }, { 176 breakAt: 17, // break within newspb.Article.Tags 177 wantPaths: append(wantPaths[:17:17], wantPaths[18:]...), 178 wantValues: append(wantValues[:17:17], wantValues[18:]...), 179 }, { 180 terminateAt: 17, // terminate within newspb.Article.Tags 181 wantPaths: wantPaths[:17], 182 wantValues: wantValues[:17], 183 }, { 184 errorAt: 13, // return error within newspb.KeyValueAttachment.Data 185 wantPaths: wantPaths[:13], 186 wantValues: wantValues[:13], 187 wantError: cmpopts.AnyError, 188 }, { 189 breakAt: 13, // break within newspb.KeyValueAttachment.Data 190 wantPaths: append(wantPaths[:13:13], wantPaths[15:]...), 191 wantValues: append(wantValues[:13:13], wantValues[15:]...), 192 }, { 193 terminateAt: 13, // terminate within newspb.KeyValueAttachment.Data 194 wantPaths: wantPaths[:13], 195 wantValues: wantValues[:13], 196 }} 197 for _, tt := range tests { 198 t.Run("", func(t *testing.T) { 199 var gotPaths []string 200 var gotValues []protoreflect.Value 201 var stackPaths []string 202 var stackValues []protoreflect.Value 203 gotError := Options{ 204 Stable: true, 205 Resolver: tt.resolver, 206 }.Range(m, 207 func(p protopath.Values) error { 208 gotPaths = append(gotPaths, p.Path[1:].String()) 209 stackPaths = append(stackPaths, p.Path[1:].String()) 210 gotValues = append(gotValues, p.Index(-1).Value) 211 stackValues = append(stackValues, p.Index(-1).Value) 212 switch { 213 case tt.errorAt > 0 && tt.errorAt == len(gotPaths): 214 return cmpopts.AnyError 215 case tt.breakAt > 0 && tt.breakAt == len(gotPaths): 216 return Break 217 case tt.terminateAt > 0 && tt.terminateAt == len(gotPaths): 218 return Terminate 219 default: 220 return nil 221 } 222 }, 223 func(p protopath.Values) error { 224 gotPath := p.Path[1:].String() 225 wantPath := stackPaths[len(stackPaths)-1] 226 if wantPath != gotPath { 227 t.Errorf("%d: pop path mismatch: got %v, want %v", len(gotPaths), gotPath, wantPath) 228 } 229 gotValue := p.Index(-1).Value 230 wantValue := stackValues[len(stackValues)-1] 231 if diff := cmp.Diff(wantValue, gotValue, transformReflectValue, protocmp.Transform()); diff != "" { 232 t.Errorf("%d: pop value mismatch (-want +got):\n%v", len(gotValues), diff) 233 } 234 stackPaths = stackPaths[:len(stackPaths)-1] 235 stackValues = stackValues[:len(stackValues)-1] 236 return nil 237 }, 238 ) 239 if n := len(stackPaths) + len(stackValues); n > 0 { 240 t.Errorf("stack mismatch: got %d unpopped items", n) 241 } 242 if diff := cmp.Diff(tt.wantPaths, gotPaths); diff != "" { 243 t.Errorf("paths mismatch (-want +got):\n%s", diff) 244 } 245 if diff := cmp.Diff(tt.wantValues, gotValues, transformReflectValue, protocmp.Transform()); diff != "" { 246 t.Errorf("values mismatch (-want +got):\n%s", diff) 247 } 248 if !cmp.Equal(gotError, tt.wantError, cmpopts.EquateErrors()) { 249 t.Errorf("error mismatch: got %v, want %v", gotError, tt.wantError) 250 } 251 }) 252 } 253} 254