1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package untyped
16
17import (
18	"fmt"
19	"net/http"
20	"sort"
21	"strings"
22
23	"github.com/go-openapi/analysis"
24	"github.com/go-openapi/errors"
25	"github.com/go-openapi/loads"
26	"github.com/go-openapi/runtime"
27	"github.com/go-openapi/spec"
28	"github.com/go-openapi/strfmt"
29)
30
31// NewAPI creates the default untyped API
32func NewAPI(spec *loads.Document) *API {
33	var an *analysis.Spec
34	if spec != nil && spec.Spec() != nil {
35		an = analysis.New(spec.Spec())
36	}
37	api := &API{
38		spec:           spec,
39		analyzer:       an,
40		consumers:      make(map[string]runtime.Consumer, 10),
41		producers:      make(map[string]runtime.Producer, 10),
42		authenticators: make(map[string]runtime.Authenticator),
43		operations:     make(map[string]map[string]runtime.OperationHandler),
44		ServeError:     errors.ServeError,
45		Models:         make(map[string]func() interface{}),
46		formats:        strfmt.NewFormats(),
47	}
48	return api.WithJSONDefaults()
49}
50
// API represents an untyped mux for a swagger spec.
//
// It records the swagger document together with the media-type consumers and
// producers, authenticators and operation handlers registered against it, so
// that Validate can compare those registrations with what the spec requires.
type API struct {
	// spec is the document given to NewAPI; NewAPI accepts nil here.
	spec            *loads.Document
	// analyzer is built from spec by NewAPI; it stays nil when spec (or its
	// parsed swagger document) is nil.
	analyzer        *analysis.Spec
	// DefaultProduces is set to runtime.JSONMime by WithJSONDefaults and
	// cleared by WithoutJSONDefaults. It is not read in this file; presumably
	// callers use it as the fallback response media type (verify).
	DefaultProduces string
	// DefaultConsumes mirrors DefaultProduces for request bodies (same
	// caveat: not read in this file).
	DefaultConsumes string
	// consumers maps a lower-cased media type to its consumer.
	consumers       map[string]runtime.Consumer
	// producers maps a lower-cased media type to its producer.
	producers       map[string]runtime.Producer
	// authenticators maps a security scheme name to its authenticator.
	authenticators  map[string]runtime.Authenticator
	// operations maps an upper-cased HTTP method, then a path, to its handler.
	operations      map[string]map[string]runtime.OperationHandler
	// ServeError defaults to errors.ServeError in NewAPI; presumably used by
	// the serving middleware to write error responses (not used in this file).
	ServeError      func(http.ResponseWriter, *http.Request, error)
	// Models is created empty by NewAPI and not otherwise touched here;
	// presumably it maps a model name to a constructor — TODO confirm.
	Models          map[string]func() interface{}
	// formats is the string-format registry, created on demand by Formats
	// and RegisterFormat when nil.
	formats         strfmt.Registry
}
65
66// WithJSONDefaults loads the json defaults for this api
67func (d *API) WithJSONDefaults() *API {
68	d.DefaultConsumes = runtime.JSONMime
69	d.DefaultProduces = runtime.JSONMime
70	d.consumers[runtime.JSONMime] = runtime.JSONConsumer()
71	d.producers[runtime.JSONMime] = runtime.JSONProducer()
72	return d
73}
74
75// WithoutJSONDefaults clears the json defaults for this api
76func (d *API) WithoutJSONDefaults() *API {
77	d.DefaultConsumes = ""
78	d.DefaultProduces = ""
79	delete(d.consumers, runtime.JSONMime)
80	delete(d.producers, runtime.JSONMime)
81	return d
82}
83
84// Formats returns the registered string formats
85func (d *API) Formats() strfmt.Registry {
86	if d.formats == nil {
87		d.formats = strfmt.NewFormats()
88	}
89	return d.formats
90}
91
92// RegisterFormat registers a custom format validator
93func (d *API) RegisterFormat(name string, format strfmt.Format, validator strfmt.Validator) {
94	if d.formats == nil {
95		d.formats = strfmt.NewFormats()
96	}
97	d.formats.Add(name, format, validator)
98}
99
100// RegisterAuth registers an auth handler in this api
101func (d *API) RegisterAuth(scheme string, handler runtime.Authenticator) {
102	if d.authenticators == nil {
103		d.authenticators = make(map[string]runtime.Authenticator)
104	}
105	d.authenticators[scheme] = handler
106}
107
108// RegisterConsumer registers a consumer for a media type.
109func (d *API) RegisterConsumer(mediaType string, handler runtime.Consumer) {
110	if d.consumers == nil {
111		d.consumers = make(map[string]runtime.Consumer, 10)
112	}
113	d.consumers[strings.ToLower(mediaType)] = handler
114}
115
116// RegisterProducer registers a producer for a media type
117func (d *API) RegisterProducer(mediaType string, handler runtime.Producer) {
118	if d.producers == nil {
119		d.producers = make(map[string]runtime.Producer, 10)
120	}
121	d.producers[strings.ToLower(mediaType)] = handler
122}
123
124// RegisterOperation registers an operation handler for an operation name
125func (d *API) RegisterOperation(method, path string, handler runtime.OperationHandler) {
126	if d.operations == nil {
127		d.operations = make(map[string]map[string]runtime.OperationHandler, 30)
128	}
129	um := strings.ToUpper(method)
130	if b, ok := d.operations[um]; !ok || b == nil {
131		d.operations[um] = make(map[string]runtime.OperationHandler)
132	}
133	d.operations[um][path] = handler
134}
135
136// OperationHandlerFor returns the operation handler for the specified id if it can be found
137func (d *API) OperationHandlerFor(method, path string) (runtime.OperationHandler, bool) {
138	if d.operations == nil {
139		return nil, false
140	}
141	if pi, ok := d.operations[strings.ToUpper(method)]; ok {
142		h, ok := pi[path]
143		return h, ok
144	}
145	return nil, false
146}
147
148// ConsumersFor gets the consumers for the specified media types
149func (d *API) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
150	result := make(map[string]runtime.Consumer)
151	for _, mt := range mediaTypes {
152		if consumer, ok := d.consumers[mt]; ok {
153			result[mt] = consumer
154		}
155	}
156	return result
157}
158
159// ProducersFor gets the producers for the specified media types
160func (d *API) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
161	result := make(map[string]runtime.Producer)
162	for _, mt := range mediaTypes {
163		if producer, ok := d.producers[mt]; ok {
164			result[mt] = producer
165		}
166	}
167	return result
168}
169
170// AuthenticatorsFor gets the authenticators for the specified security schemes
171func (d *API) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
172	result := make(map[string]runtime.Authenticator)
173	for k := range schemes {
174		if a, ok := d.authenticators[k]; ok {
175			result[k] = a
176		}
177	}
178	return result
179}
180
181// Validate validates this API for any missing items
182func (d *API) Validate() error {
183	return d.validate()
184}
185
186// validateWith validates the registrations in this API against the provided spec analyzer
187func (d *API) validate() error {
188	var consumes []string
189	for k := range d.consumers {
190		consumes = append(consumes, k)
191	}
192
193	var produces []string
194	for k := range d.producers {
195		produces = append(produces, k)
196	}
197
198	var authenticators []string
199	for k := range d.authenticators {
200		authenticators = append(authenticators, k)
201	}
202
203	var operations []string
204	for m, v := range d.operations {
205		for p := range v {
206			operations = append(operations, fmt.Sprintf("%s %s", strings.ToUpper(m), p))
207		}
208	}
209
210	var definedAuths []string
211	for k := range d.spec.Spec().SecurityDefinitions {
212		definedAuths = append(definedAuths, k)
213	}
214
215	if err := d.verify("consumes", consumes, d.analyzer.RequiredConsumes()); err != nil {
216		return err
217	}
218	if err := d.verify("produces", produces, d.analyzer.RequiredProduces()); err != nil {
219		return err
220	}
221	if err := d.verify("operation", operations, d.analyzer.OperationMethodPaths()); err != nil {
222		return err
223	}
224
225	requiredAuths := d.analyzer.RequiredSecuritySchemes()
226	if err := d.verify("auth scheme", authenticators, requiredAuths); err != nil {
227		return err
228	}
229	if err := d.verify("security definitions", definedAuths, requiredAuths); err != nil {
230		return err
231	}
232	return nil
233}
234
235func (d *API) verify(name string, registrations []string, expectations []string) error {
236
237	sort.Sort(sort.StringSlice(registrations))
238	sort.Sort(sort.StringSlice(expectations))
239
240	expected := map[string]struct{}{}
241	seen := map[string]struct{}{}
242
243	for _, v := range expectations {
244		expected[v] = struct{}{}
245	}
246
247	var unspecified []string
248	for _, v := range registrations {
249		seen[v] = struct{}{}
250		if _, ok := expected[v]; !ok {
251			unspecified = append(unspecified, v)
252		}
253	}
254
255	for k := range seen {
256		delete(expected, k)
257	}
258
259	var unregistered []string
260	for k := range expected {
261		unregistered = append(unregistered, k)
262	}
263	sort.Sort(sort.StringSlice(unspecified))
264	sort.Sort(sort.StringSlice(unregistered))
265
266	if len(unregistered) > 0 || len(unspecified) > 0 {
267		return &errors.APIVerificationFailed{
268			Section:              name,
269			MissingSpecification: unspecified,
270			MissingRegistration:  unregistered,
271		}
272	}
273
274	return nil
275}
276