1package goshared 2 3import ( 4 "fmt" 5 "reflect" 6 "strings" 7 "text/template" 8 9 "github.com/golang/protobuf/ptypes" 10 "github.com/golang/protobuf/ptypes/duration" 11 "github.com/golang/protobuf/ptypes/timestamp" 12 "github.com/lyft/protoc-gen-star" 13 "github.com/lyft/protoc-gen-star/lang/go" 14 "github.com/envoyproxy/protoc-gen-validate/templates/shared" 15) 16 17func Register(tpl *template.Template, params pgs.Parameters) { 18 fns := goSharedFuncs{pgsgo.InitContext(params)} 19 20 tpl.Funcs(map[string]interface{}{ 21 "accessor": fns.accessor, 22 "byteStr": fns.byteStr, 23 "cmt": pgs.C80, 24 "durGt": fns.durGt, 25 "durLit": fns.durLit, 26 "durStr": fns.durStr, 27 "err": fns.err, 28 "errCause": fns.errCause, 29 "errIdx": fns.errIdx, 30 "errIdxCause": fns.errIdxCause, 31 "errname": fns.errName, 32 "inKey": fns.inKey, 33 "inType": fns.inType, 34 "isBytes": fns.isBytes, 35 "lit": fns.lit, 36 "lookup": fns.lookup, 37 "msgTyp": fns.msgTyp, 38 "name": fns.Name, 39 "oneof": fns.oneofTypeName, 40 "pkg": fns.PackageName, 41 "tsGt": fns.tsGt, 42 "tsLit": fns.tsLit, 43 "tsStr": fns.tsStr, 44 "typ": fns.Type, 45 "unwrap": fns.unwrap, 46 "externalEnums": fns.externalEnums, 47 "enumPackages": fns.enumPackages, 48 }) 49 50 template.Must(tpl.New("msg").Parse(msgTpl)) 51 template.Must(tpl.New("const").Parse(constTpl)) 52 template.Must(tpl.New("ltgt").Parse(ltgtTpl)) 53 template.Must(tpl.New("in").Parse(inTpl)) 54 55 template.Must(tpl.New("none").Parse(noneTpl)) 56 template.Must(tpl.New("float").Parse(numTpl)) 57 template.Must(tpl.New("double").Parse(numTpl)) 58 template.Must(tpl.New("int32").Parse(numTpl)) 59 template.Must(tpl.New("int64").Parse(numTpl)) 60 template.Must(tpl.New("uint32").Parse(numTpl)) 61 template.Must(tpl.New("uint64").Parse(numTpl)) 62 template.Must(tpl.New("sint32").Parse(numTpl)) 63 template.Must(tpl.New("sint64").Parse(numTpl)) 64 template.Must(tpl.New("fixed32").Parse(numTpl)) 65 template.Must(tpl.New("fixed64").Parse(numTpl)) 66 template.Must(tpl.New("sfixed32").Parse(numTpl)) 67 template.Must(tpl.New("sfixed64").Parse(numTpl)) 68 69 template.Must(tpl.New("bool").Parse(constTpl)) 70 template.Must(tpl.New("string").Parse(strTpl)) 71 template.Must(tpl.New("bytes").Parse(bytesTpl)) 72 73 template.Must(tpl.New("email").Parse(emailTpl)) 74 template.Must(tpl.New("hostname").Parse(hostTpl)) 75 template.Must(tpl.New("address").Parse(hostTpl)) 76 77 template.Must(tpl.New("enum").Parse(enumTpl)) 78 template.Must(tpl.New("repeated").Parse(repTpl)) 79 template.Must(tpl.New("map").Parse(mapTpl)) 80 81 template.Must(tpl.New("any").Parse(anyTpl)) 82 template.Must(tpl.New("timestampcmp").Parse(timestampcmpTpl)) 83 template.Must(tpl.New("durationcmp").Parse(durationcmpTpl)) 84 85 template.Must(tpl.New("wrapper").Parse(wrapperTpl)) 86} 87 88type goSharedFuncs struct{ pgsgo.Context } 89 90func (fns goSharedFuncs) accessor(ctx shared.RuleContext) string { 91 if ctx.AccessorOverride != "" { 92 return ctx.AccessorOverride 93 } 94 95 return fmt.Sprintf("m.Get%s()", fns.Name(ctx.Field)) 96} 97 98func (fns goSharedFuncs) errName(m pgs.Message) pgs.Name { 99 return fns.Name(m) + "ValidationError" 100} 101 102func (fns goSharedFuncs) errIdxCause(ctx shared.RuleContext, idx, cause string, reason ...interface{}) string { 103 f := ctx.Field 104 n := fns.Name(f) 105 106 var fld string 107 if idx != "" { 108 fld = fmt.Sprintf(`fmt.Sprintf("%s[%%v]", %s)`, n, idx) 109 } else if ctx.Index != "" { 110 fld = fmt.Sprintf(`fmt.Sprintf("%s[%%v]", %s)`, n, ctx.Index) 111 } else { 112 fld = fmt.Sprintf("%q", n) 113 } 114 115 causeFld := "" 116 if cause != "nil" && cause != "" { 117 causeFld = fmt.Sprintf("cause: %s,", cause) 118 } 119 120 keyFld := "" 121 if ctx.OnKey { 122 keyFld = "key: true," 123 } 124 125 return fmt.Sprintf(`%s{ 126 field: %s, 127 reason: %q, 128 %s%s 129 }`, 130 fns.errName(f.Message()), 131 fld, 132 fmt.Sprint(reason...), 133 causeFld, 134 keyFld) 135} 136 137func (fns goSharedFuncs) err(ctx shared.RuleContext, reason ...interface{}) string { 138 return fns.errIdxCause(ctx, "", "nil", reason...) 139} 140 141func (fns goSharedFuncs) errCause(ctx shared.RuleContext, cause string, reason ...interface{}) string { 142 return fns.errIdxCause(ctx, "", cause, reason...) 143} 144 145func (fns goSharedFuncs) errIdx(ctx shared.RuleContext, idx string, reason ...interface{}) string { 146 return fns.errIdxCause(ctx, idx, "nil", reason...) 147} 148 149func (fns goSharedFuncs) lookup(f pgs.Field, name string) string { 150 return fmt.Sprintf( 151 "_%s_%s_%s", 152 fns.Name(f.Message()), 153 fns.Name(f), 154 name, 155 ) 156} 157 158func (fns goSharedFuncs) lit(x interface{}) string { 159 val := reflect.ValueOf(x) 160 161 if val.Kind() == reflect.Interface { 162 val = val.Elem() 163 } 164 165 if val.Kind() == reflect.Ptr { 166 val = val.Elem() 167 } 168 169 switch val.Kind() { 170 case reflect.String: 171 return fmt.Sprintf("%q", x) 172 case reflect.Uint8: 173 return fmt.Sprintf("0x%X", x) 174 case reflect.Slice: 175 els := make([]string, val.Len()) 176 for i, l := 0, val.Len(); i < l; i++ { 177 els[i] = fns.lit(val.Index(i).Interface()) 178 } 179 return fmt.Sprintf("%T{%s}", val.Interface(), strings.Join(els, ", ")) 180 default: 181 return fmt.Sprint(x) 182 } 183} 184 185func (fns goSharedFuncs) isBytes(f interface { 186 ProtoType() pgs.ProtoType 187}) bool { 188 return f.ProtoType() == pgs.BytesT 189} 190 191func (fns goSharedFuncs) byteStr(x []byte) string { 192 elms := make([]string, len(x)) 193 for i, b := range x { 194 elms[i] = fmt.Sprintf(`\x%X`, b) 195 } 196 197 return fmt.Sprintf(`"%s"`, strings.Join(elms, "")) 198} 199 200func (fns goSharedFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName { 201 return pgsgo.TypeName(fns.OneofOption(f)).Pointer() 202} 203 204func (fns goSharedFuncs) inType(f pgs.Field, x interface{}) string { 205 switch f.Type().ProtoType() { 206 case pgs.BytesT: 207 return "string" 208 case pgs.MessageT: 209 switch x.(type) { 210 case []*duration.Duration: 211 return "time.Duration" 212 default: 213 return pgsgo.TypeName(fmt.Sprintf("%T", x)).Element().String() 214 } 215 default: 216 return fns.Type(f).String() 217 } 218} 219 220func (fns goSharedFuncs) inKey(f pgs.Field, x interface{}) string { 221 switch f.Type().ProtoType() { 222 case pgs.BytesT: 223 return fns.byteStr(x.([]byte)) 224 case pgs.MessageT: 225 switch x := x.(type) { 226 case *duration.Duration: 227 dur, _ := ptypes.Duration(x) 228 return fns.lit(int64(dur)) 229 default: 230 return fns.lit(x) 231 } 232 default: 233 return fns.lit(x) 234 } 235} 236 237func (fns goSharedFuncs) durLit(dur *duration.Duration) string { 238 return fmt.Sprintf( 239 "time.Duration(%d * time.Second + %d * time.Nanosecond)", 240 dur.GetSeconds(), dur.GetNanos()) 241} 242 243func (fns goSharedFuncs) durStr(dur *duration.Duration) string { 244 d, _ := ptypes.Duration(dur) 245 return d.String() 246} 247 248func (fns goSharedFuncs) durGt(a, b *duration.Duration) bool { 249 ad, _ := ptypes.Duration(a) 250 bd, _ := ptypes.Duration(b) 251 252 return ad > bd 253} 254 255func (fns goSharedFuncs) tsLit(ts *timestamp.Timestamp) string { 256 return fmt.Sprintf( 257 "time.Unix(%d, %d)", 258 ts.GetSeconds(), ts.GetNanos(), 259 ) 260} 261 262func (fns goSharedFuncs) tsGt(a, b *timestamp.Timestamp) bool { 263 at, _ := ptypes.Timestamp(a) 264 bt, _ := ptypes.Timestamp(b) 265 266 return bt.Before(at) 267} 268 269func (fns goSharedFuncs) tsStr(ts *timestamp.Timestamp) string { 270 t, _ := ptypes.Timestamp(ts) 271 return t.String() 272} 273 274func (fns goSharedFuncs) unwrap(ctx shared.RuleContext, name string) (shared.RuleContext, error) { 275 ctx, err := ctx.Unwrap("wrapper") 276 if err != nil { 277 return ctx, err 278 } 279 280 ctx.AccessorOverride = fmt.Sprintf("%s.Get%s()", name, 281 pgsgo.PGGUpperCamelCase(ctx.Field.Type().Embed().Fields()[0].Name())) 282 283 return ctx, nil 284} 285 286func (fns goSharedFuncs) msgTyp(message pgs.Message) pgsgo.TypeName { 287 return pgsgo.TypeName(fns.Name(message)) 288} 289 290func (fns goSharedFuncs) externalEnums(file pgs.File) []pgs.Enum { 291 var out []pgs.Enum 292 293 for _, msg := range file.AllMessages() { 294 for _, fld := range msg.Fields() { 295 if en := fld.Type().Enum(); fld.Type().IsEnum() && en.Package().ProtoName() != fld.Package().ProtoName() && fns.PackageName(en) != fns.PackageName(fld) { 296 out = append(out, en) 297 } 298 } 299 } 300 301 return out 302} 303 304func (fns goSharedFuncs) enumPackages(enums []pgs.Enum) map[pgs.FilePath]pgs.Name { 305 out := make(map[pgs.FilePath]pgs.Name, len(enums)) 306 307 for _, en := range enums { 308 out[fns.ImportPath(en)] = fns.PackageName(en) 309 } 310 311 return out 312} 313