1package descriptor
2
3import (
4	"fmt"
5	"strings"
6
7	"github.com/golang/glog"
8	"github.com/golang/protobuf/proto"
9	descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
10	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
11	options "google.golang.org/genproto/googleapis/api/annotations"
12)
13
14// loadServices registers services and their methods from "targetFile" to "r".
15// It must be called after loadFile is called for all files so that loadServices
16// can resolve names of message types and their fields.
17func (r *Registry) loadServices(file *File) error {
18	glog.V(1).Infof("Loading services from %s", file.GetName())
19	var svcs []*Service
20	for _, sd := range file.GetService() {
21		glog.V(2).Infof("Registering %s", sd.GetName())
22		svc := &Service{
23			File: file,
24			ServiceDescriptorProto: sd,
25		}
26		for _, md := range sd.GetMethod() {
27			glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName())
28			opts, err := extractAPIOptions(md)
29			if err != nil {
30				glog.Errorf("Failed to extract ApiMethodOptions from %s.%s: %v", svc.GetName(), md.GetName(), err)
31				return err
32			}
33			if opts == nil {
34				glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), md.GetName())
35			}
36			meth, err := r.newMethod(svc, md, opts)
37			if err != nil {
38				return err
39			}
40			svc.Methods = append(svc.Methods, meth)
41		}
42		if len(svc.Methods) == 0 {
43			continue
44		}
45		glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods))
46		svcs = append(svcs, svc)
47	}
48	file.Services = svcs
49	return nil
50}
51
52func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) {
53	requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType())
54	if err != nil {
55		return nil, err
56	}
57	responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType())
58	if err != nil {
59		return nil, err
60	}
61	meth := &Method{
62		Service:               svc,
63		MethodDescriptorProto: md,
64		RequestType:           requestType,
65		ResponseType:          responseType,
66	}
67
68	newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) {
69		var (
70			httpMethod   string
71			pathTemplate string
72		)
73		switch {
74		case opts.GetGet() != "":
75			httpMethod = "GET"
76			pathTemplate = opts.GetGet()
77			if opts.Body != "" {
78				return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName())
79			}
80
81		case opts.GetPut() != "":
82			httpMethod = "PUT"
83			pathTemplate = opts.GetPut()
84
85		case opts.GetPost() != "":
86			httpMethod = "POST"
87			pathTemplate = opts.GetPost()
88
89		case opts.GetDelete() != "":
90			httpMethod = "DELETE"
91			pathTemplate = opts.GetDelete()
92			if opts.Body != "" && !r.allowDeleteBody {
93				return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName())
94			}
95
96		case opts.GetPatch() != "":
97			httpMethod = "PATCH"
98			pathTemplate = opts.GetPatch()
99
100		case opts.GetCustom() != nil:
101			custom := opts.GetCustom()
102			httpMethod = custom.Kind
103			pathTemplate = custom.Path
104
105		default:
106			glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName())
107			return nil, nil
108		}
109
110		parsed, err := httprule.Parse(pathTemplate)
111		if err != nil {
112			return nil, err
113		}
114		tmpl := parsed.Compile()
115
116		if md.GetClientStreaming() && len(tmpl.Fields) > 0 {
117			return nil, fmt.Errorf("cannot use path parameter in client streaming")
118		}
119
120		b := &Binding{
121			Method:     meth,
122			Index:      idx,
123			PathTmpl:   tmpl,
124			HTTPMethod: httpMethod,
125		}
126
127		for _, f := range tmpl.Fields {
128			param, err := r.newParam(meth, f)
129			if err != nil {
130				return nil, err
131			}
132			b.PathParams = append(b.PathParams, param)
133		}
134
135		// TODO(yugui) Handle query params
136
137		b.Body, err = r.newBody(meth, opts.Body)
138		if err != nil {
139			return nil, err
140		}
141
142		return b, nil
143	}
144	b, err := newBinding(opts, 0)
145	if err != nil {
146		return nil, err
147	}
148
149	if b != nil {
150		meth.Bindings = append(meth.Bindings, b)
151	}
152	for i, additional := range opts.GetAdditionalBindings() {
153		if len(additional.AdditionalBindings) > 0 {
154			return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName())
155		}
156		b, err := newBinding(additional, i+1)
157		if err != nil {
158			return nil, err
159		}
160		meth.Bindings = append(meth.Bindings, b)
161	}
162
163	return meth, nil
164}
165
166func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) {
167	if meth.Options == nil {
168		return nil, nil
169	}
170	if !proto.HasExtension(meth.Options, options.E_Http) {
171		return nil, nil
172	}
173	ext, err := proto.GetExtension(meth.Options, options.E_Http)
174	if err != nil {
175		return nil, err
176	}
177	opts, ok := ext.(*options.HttpRule)
178	if !ok {
179		return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
180	}
181	return opts, nil
182}
183
184func (r *Registry) newParam(meth *Method, path string) (Parameter, error) {
185	msg := meth.RequestType
186	fields, err := r.resolveFiledPath(msg, path)
187	if err != nil {
188		return Parameter{}, err
189	}
190	l := len(fields)
191	if l == 0 {
192		return Parameter{}, fmt.Errorf("invalid field access list for %s", path)
193	}
194	target := fields[l-1].Target
195	switch target.GetType() {
196	case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP:
197		return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", target.Type, meth.Service.GetName(), meth.GetName(), path)
198	}
199	return Parameter{
200		FieldPath: FieldPath(fields),
201		Method:    meth,
202		Target:    fields[l-1].Target,
203	}, nil
204}
205
206func (r *Registry) newBody(meth *Method, path string) (*Body, error) {
207	msg := meth.RequestType
208	switch path {
209	case "":
210		return nil, nil
211	case "*":
212		return &Body{FieldPath: nil}, nil
213	}
214	fields, err := r.resolveFiledPath(msg, path)
215	if err != nil {
216		return nil, err
217	}
218	return &Body{FieldPath: FieldPath(fields)}, nil
219}
220
221// lookupField looks up a field named "name" within "msg".
222// It returns nil if no such field found.
223func lookupField(msg *Message, name string) *Field {
224	for _, f := range msg.Fields {
225		if f.GetName() == name {
226			return f
227		}
228	}
229	return nil
230}
231
232// resolveFieldPath resolves "path" into a list of fieldDescriptor, starting from "msg".
233func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathComponent, error) {
234	if path == "" {
235		return nil, nil
236	}
237
238	root := msg
239	var result []FieldPathComponent
240	for i, c := range strings.Split(path, ".") {
241		if i > 0 {
242			f := result[i-1].Target
243			switch f.GetType() {
244			case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP:
245				var err error
246				msg, err = r.LookupMsg(msg.FQMN(), f.GetTypeName())
247				if err != nil {
248					return nil, err
249				}
250			default:
251				return nil, fmt.Errorf("not an aggregate type: %s in %s", f.GetName(), path)
252			}
253		}
254
255		glog.V(2).Infof("Lookup %s in %s", c, msg.FQMN())
256		f := lookupField(msg, c)
257		if f == nil {
258			return nil, fmt.Errorf("no field %q found in %s", path, root.GetName())
259		}
260		if f.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED {
261			return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path)
262		}
263		result = append(result, FieldPathComponent{Name: c, Target: f})
264	}
265	return result, nil
266}
267