1package descriptor 2 3import ( 4 "fmt" 5 "strings" 6 7 "github.com/golang/glog" 8 "github.com/golang/protobuf/proto" 9 descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" 10 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule" 11 options "google.golang.org/genproto/googleapis/api/annotations" 12) 13 14// loadServices registers services and their methods from "targetFile" to "r". 15// It must be called after loadFile is called for all files so that loadServices 16// can resolve names of message types and their fields. 17func (r *Registry) loadServices(file *File) error { 18 glog.V(1).Infof("Loading services from %s", file.GetName()) 19 var svcs []*Service 20 for _, sd := range file.GetService() { 21 glog.V(2).Infof("Registering %s", sd.GetName()) 22 svc := &Service{ 23 File: file, 24 ServiceDescriptorProto: sd, 25 } 26 for _, md := range sd.GetMethod() { 27 glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName()) 28 opts, err := extractAPIOptions(md) 29 if err != nil { 30 glog.Errorf("Failed to extract ApiMethodOptions from %s.%s: %v", svc.GetName(), md.GetName(), err) 31 return err 32 } 33 if opts == nil { 34 glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), md.GetName()) 35 } 36 meth, err := r.newMethod(svc, md, opts) 37 if err != nil { 38 return err 39 } 40 svc.Methods = append(svc.Methods, meth) 41 } 42 if len(svc.Methods) == 0 { 43 continue 44 } 45 glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods)) 46 svcs = append(svcs, svc) 47 } 48 file.Services = svcs 49 return nil 50} 51 52func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) { 53 requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType()) 54 if err != nil { 55 return nil, err 56 } 57 responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType()) 58 if err != nil { 59 return nil, err 60 } 61 meth := &Method{ 62 Service: svc, 63 MethodDescriptorProto: md, 64 RequestType: requestType, 65 ResponseType: responseType, 66 } 67 68 newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) { 69 var ( 70 httpMethod string 71 pathTemplate string 72 ) 73 switch { 74 case opts.GetGet() != "": 75 httpMethod = "GET" 76 pathTemplate = opts.GetGet() 77 if opts.Body != "" { 78 return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName()) 79 } 80 81 case opts.GetPut() != "": 82 httpMethod = "PUT" 83 pathTemplate = opts.GetPut() 84 85 case opts.GetPost() != "": 86 httpMethod = "POST" 87 pathTemplate = opts.GetPost() 88 89 case opts.GetDelete() != "": 90 httpMethod = "DELETE" 91 pathTemplate = opts.GetDelete() 92 if opts.Body != "" && !r.allowDeleteBody { 93 return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) 94 } 95 96 case opts.GetPatch() != "": 97 httpMethod = "PATCH" 98 pathTemplate = opts.GetPatch() 99 100 case opts.GetCustom() != nil: 101 custom := opts.GetCustom() 102 httpMethod = custom.Kind 103 pathTemplate = custom.Path 104 105 default: 106 glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName()) 107 return nil, nil 108 } 109 110 parsed, err := httprule.Parse(pathTemplate) 111 if err != nil { 112 return nil, err 113 } 114 tmpl := parsed.Compile() 115 116 if md.GetClientStreaming() && len(tmpl.Fields) > 0 { 117 return nil, fmt.Errorf("cannot use path parameter in client streaming") 118 } 119 120 b := &Binding{ 121 Method: meth, 122 Index: idx, 123 PathTmpl: tmpl, 124 HTTPMethod: httpMethod, 125 } 126 127 for _, f := range tmpl.Fields { 128 param, err := r.newParam(meth, f) 129 if err != nil { 130 return nil, err 131 } 132 b.PathParams = append(b.PathParams, param) 133 } 134 135 // TODO(yugui) Handle query params 136 137 b.Body, err = r.newBody(meth, opts.Body) 138 if err != nil { 139 return nil, err 140 } 141 142 return b, nil 143 } 144 b, err := newBinding(opts, 0) 145 if err != nil { 146 return nil, err 147 } 148 149 if b != nil { 150 meth.Bindings = append(meth.Bindings, b) 151 } 152 for i, additional := range opts.GetAdditionalBindings() { 153 if len(additional.AdditionalBindings) > 0 { 154 return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName()) 155 } 156 b, err := newBinding(additional, i+1) 157 if err != nil { 158 return nil, err 159 } 160 meth.Bindings = append(meth.Bindings, b) 161 } 162 163 return meth, nil 164} 165 166func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) { 167 if meth.Options == nil { 168 return nil, nil 169 } 170 if !proto.HasExtension(meth.Options, options.E_Http) { 171 return nil, nil 172 } 173 ext, err := proto.GetExtension(meth.Options, options.E_Http) 174 if err != nil { 175 return nil, err 176 } 177 opts, ok := ext.(*options.HttpRule) 178 if !ok { 179 return nil, fmt.Errorf("extension is %T; want an HttpRule", ext) 180 } 181 return opts, nil 182} 183 184func (r *Registry) newParam(meth *Method, path string) (Parameter, error) { 185 msg := meth.RequestType 186 fields, err := r.resolveFiledPath(msg, path) 187 if err != nil { 188 return Parameter{}, err 189 } 190 l := len(fields) 191 if l == 0 { 192 return Parameter{}, fmt.Errorf("invalid field access list for %s", path) 193 } 194 target := fields[l-1].Target 195 switch target.GetType() { 196 case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP: 197 return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", target.Type, meth.Service.GetName(), meth.GetName(), path) 198 } 199 return Parameter{ 200 FieldPath: FieldPath(fields), 201 Method: meth, 202 Target: fields[l-1].Target, 203 }, nil 204} 205 206func (r *Registry) newBody(meth *Method, path string) (*Body, error) { 207 msg := meth.RequestType 208 switch path { 209 case "": 210 return nil, nil 211 case "*": 212 return &Body{FieldPath: nil}, nil 213 } 214 fields, err := r.resolveFiledPath(msg, path) 215 if err != nil { 216 return nil, err 217 } 218 return &Body{FieldPath: FieldPath(fields)}, nil 219} 220 221// lookupField looks up a field named "name" within "msg". 222// It returns nil if no such field found. 223func lookupField(msg *Message, name string) *Field { 224 for _, f := range msg.Fields { 225 if f.GetName() == name { 226 return f 227 } 228 } 229 return nil 230} 231 232// resolveFieldPath resolves "path" into a list of fieldDescriptor, starting from "msg". 233func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathComponent, error) { 234 if path == "" { 235 return nil, nil 236 } 237 238 root := msg 239 var result []FieldPathComponent 240 for i, c := range strings.Split(path, ".") { 241 if i > 0 { 242 f := result[i-1].Target 243 switch f.GetType() { 244 case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP: 245 var err error 246 msg, err = r.LookupMsg(msg.FQMN(), f.GetTypeName()) 247 if err != nil { 248 return nil, err 249 } 250 default: 251 return nil, fmt.Errorf("not an aggregate type: %s in %s", f.GetName(), path) 252 } 253 } 254 255 glog.V(2).Infof("Lookup %s in %s", c, msg.FQMN()) 256 f := lookupField(msg, c) 257 if f == nil { 258 return nil, fmt.Errorf("no field %q found in %s", path, root.GetName()) 259 } 260 if f.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { 261 return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path) 262 } 263 result = append(result, FieldPathComponent{Name: c, Target: f}) 264 } 265 return result, nil 266} 267