1// Copyright 2019 DeepMap, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package codegen
16
17import (
18	"bufio"
19	"bytes"
20	"fmt"
21	"sort"
22	"strings"
23	"text/template"
24
25	"github.com/getkin/kin-openapi/openapi3"
26	"github.com/pkg/errors"
27	"golang.org/x/tools/imports"
28
29	"github.com/deepmap/oapi-codegen/pkg/codegen/templates"
30)
31
// Options selects which pieces of code Generate emits and how the generated
// output is post-processed.
type Options struct {
	// GenerateChiServer specifies whether to generate chi server boilerplate.
	GenerateChiServer bool
	// GenerateEchoServer specifies whether to generate echo server boilerplate.
	GenerateEchoServer bool
	// GenerateClient specifies whether to generate client boilerplate.
	GenerateClient bool
	// GenerateTypes specifies whether to generate type definitions.
	GenerateTypes bool
	// EmbedSpec controls whether the swagger spec is embedded in the generated code.
	EmbedSpec bool
	// SkipFmt skips running goimports over the generated code.
	SkipFmt bool
	// SkipPrune skips pruning unused components before generation.
	SkipPrune bool
	// AliasTypes requests type aliases where possible.
	AliasTypes bool
	// IncludeTags limits generation to operations carrying at least one of
	// these tags. Ignored when empty.
	IncludeTags []string
	// ExcludeTags drops operations carrying any of these tags. Ignored when empty.
	ExcludeTags []string
	// UserTemplates overrides built-in templates (keyed by template name)
	// with user-provided template text.
	UserTemplates map[string]string
	// ImportMapping gives the Go package path for each external reference
	// (keyed by spec file path / URL).
	ImportMapping map[string]string
	// ExcludeSchemas lists component schema names that must not be generated.
	// Ignored when empty.
	ExcludeSchemas []string
}
48
// goImport represents a Go package to be imported in the generated code.
type goImport struct {
	Name string // package alias; when empty the bare path is imported
	Path string // package import path
}

// String renders gi as a Go import spec: the quoted path, prefixed by the
// alias when one is set (e.g. `externalRef0 "example.com/pkg"`).
func (gi goImport) String() string {
	if gi.Name != "" {
		return fmt.Sprintf("%s %q", gi.Name, gi.Path)
	}
	return fmt.Sprintf("%q", gi.Path)
}

// importMap maps external OpenAPI specification files/URLs to the Go packages
// that hold their generated code.
type importMap map[string]goImport

// GoImports returns one import spec per distinct package in im, sorted.
//
// Several spec paths may point at the same package (and therefore share the
// same alias). Emitting that spec more than once would redeclare the alias and
// fail to compile when goimports is skipped, so duplicates are dropped. Sorting
// makes the output independent of Go's randomized map iteration order.
func (im importMap) GoImports() []string {
	seen := make(map[string]struct{}, len(im))
	goImports := make([]string, 0, len(im))
	for _, v := range im {
		spec := v.String()
		if _, dup := seen[spec]; dup {
			continue
		}
		seen[spec] = struct{}{}
		goImports = append(goImports, spec)
	}
	sort.Strings(goImports)
	return goImports
}
74
// importMapping maps external spec paths to the Go packages holding their
// types. It is package-level state that Generate overwrites at the start of
// every call (built from opts.ImportMapping).
// NOTE(review): because it is shared, concurrent Generate calls would race on it.
var importMapping importMap
76
77func constructImportMapping(input map[string]string) importMap {
78	var (
79		pathToName = map[string]string{}
80		result     = importMap{}
81	)
82
83	{
84		var packagePaths []string
85		for _, packageName := range input {
86			packagePaths = append(packagePaths, packageName)
87		}
88		sort.Strings(packagePaths)
89
90		for _, packagePath := range packagePaths {
91			if _, ok := pathToName[packagePath]; !ok {
92				pathToName[packagePath] = fmt.Sprintf("externalRef%d", len(pathToName))
93			}
94		}
95	}
96	for specPath, packagePath := range input {
97		result[specPath] = goImport{Name: pathToName[packagePath], Path: packagePath}
98	}
99	return result
100}
101
102// Uses the Go templating engine to generate all of our server wrappers from
103// the descriptions we've built up above from the schema objects.
104// opts defines
105func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (string, error) {
106	importMapping = constructImportMapping(opts.ImportMapping)
107
108	filterOperationsByTag(swagger, opts)
109	if !opts.SkipPrune {
110		pruneUnusedComponents(swagger)
111	}
112
113	// This creates the golang templates text package
114	TemplateFunctions["opts"] = func() Options { return opts }
115	t := template.New("oapi-codegen").Funcs(TemplateFunctions)
116	// This parses all of our own template files into the template object
117	// above
118	t, err := templates.Parse(t)
119	if err != nil {
120		return "", errors.Wrap(err, "error parsing oapi-codegen templates")
121	}
122
123	// Override built-in templates with user-provided versions
124	for _, tpl := range t.Templates() {
125		if _, ok := opts.UserTemplates[tpl.Name()]; ok {
126			utpl := t.New(tpl.Name())
127			if _, err := utpl.Parse(opts.UserTemplates[tpl.Name()]); err != nil {
128				return "", errors.Wrapf(err, "error parsing user-provided template %q", tpl.Name())
129			}
130		}
131	}
132
133	ops, err := OperationDefinitions(swagger)
134	if err != nil {
135		return "", errors.Wrap(err, "error creating operation definitions")
136	}
137
138	var typeDefinitions, constantDefinitions string
139	if opts.GenerateTypes {
140		typeDefinitions, err = GenerateTypeDefinitions(t, swagger, ops, opts.ExcludeSchemas)
141		if err != nil {
142			return "", errors.Wrap(err, "error generating type definitions")
143		}
144
145		constantDefinitions, err = GenerateConstants(t, ops)
146		if err != nil {
147			return "", errors.Wrap(err, "error generating constants")
148		}
149
150	}
151
152	var echoServerOut string
153	if opts.GenerateEchoServer {
154		echoServerOut, err = GenerateEchoServer(t, ops)
155		if err != nil {
156			return "", errors.Wrap(err, "error generating Go handlers for Paths")
157		}
158	}
159
160	var chiServerOut string
161	if opts.GenerateChiServer {
162		chiServerOut, err = GenerateChiServer(t, ops)
163		if err != nil {
164			return "", errors.Wrap(err, "error generating Go handlers for Paths")
165		}
166	}
167
168	var clientOut string
169	if opts.GenerateClient {
170		clientOut, err = GenerateClient(t, ops)
171		if err != nil {
172			return "", errors.Wrap(err, "error generating client")
173		}
174	}
175
176	var clientWithResponsesOut string
177	if opts.GenerateClient {
178		clientWithResponsesOut, err = GenerateClientWithResponses(t, ops)
179		if err != nil {
180			return "", errors.Wrap(err, "error generating client with responses")
181		}
182	}
183
184	var inlinedSpec string
185	if opts.EmbedSpec {
186		inlinedSpec, err = GenerateInlinedSpec(t, importMapping, swagger)
187		if err != nil {
188			return "", errors.Wrap(err, "error generating Go handlers for Paths")
189		}
190	}
191
192	var buf bytes.Buffer
193	w := bufio.NewWriter(&buf)
194
195	externalImports := importMapping.GoImports()
196	importsOut, err := GenerateImports(t, externalImports, packageName)
197	if err != nil {
198		return "", errors.Wrap(err, "error generating imports")
199	}
200
201	_, err = w.WriteString(importsOut)
202	if err != nil {
203		return "", errors.Wrap(err, "error writing imports")
204	}
205
206	_, err = w.WriteString(constantDefinitions)
207	if err != nil {
208		return "", errors.Wrap(err, "error writing constants")
209	}
210
211	_, err = w.WriteString(typeDefinitions)
212	if err != nil {
213		return "", errors.Wrap(err, "error writing type definitions")
214
215	}
216
217	if opts.GenerateClient {
218		_, err = w.WriteString(clientOut)
219		if err != nil {
220			return "", errors.Wrap(err, "error writing client")
221		}
222		_, err = w.WriteString(clientWithResponsesOut)
223		if err != nil {
224			return "", errors.Wrap(err, "error writing client")
225		}
226	}
227
228	if opts.GenerateEchoServer {
229		_, err = w.WriteString(echoServerOut)
230		if err != nil {
231			return "", errors.Wrap(err, "error writing server path handlers")
232		}
233	}
234
235	if opts.GenerateChiServer {
236		_, err = w.WriteString(chiServerOut)
237		if err != nil {
238			return "", errors.Wrap(err, "error writing server path handlers")
239		}
240	}
241
242	if opts.EmbedSpec {
243		_, err = w.WriteString(inlinedSpec)
244		if err != nil {
245			return "", errors.Wrap(err, "error writing inlined spec")
246		}
247	}
248
249	err = w.Flush()
250	if err != nil {
251		return "", errors.Wrap(err, "error flushing output buffer")
252	}
253
254	// remove any byte-order-marks which break Go-Code
255	goCode := SanitizeCode(buf.String())
256
257	// The generation code produces unindented horrors. Use the Go Imports
258	// to make it all pretty.
259	if opts.SkipFmt {
260		return goCode, nil
261	}
262
263	outBytes, err := imports.Process(packageName+".go", []byte(goCode), nil)
264	if err != nil {
265		fmt.Println(goCode)
266		return "", errors.Wrap(err, "error formatting Go code")
267	}
268	return string(outBytes), nil
269}
270
271func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.Swagger, ops []OperationDefinition, excludeSchemas []string) (string, error) {
272	schemaTypes, err := GenerateTypesForSchemas(t, swagger.Components.Schemas, excludeSchemas)
273	if err != nil {
274		return "", errors.Wrap(err, "error generating Go types for component schemas")
275	}
276
277	paramTypes, err := GenerateTypesForParameters(t, swagger.Components.Parameters)
278	if err != nil {
279		return "", errors.Wrap(err, "error generating Go types for component parameters")
280	}
281	allTypes := append(schemaTypes, paramTypes...)
282
283	responseTypes, err := GenerateTypesForResponses(t, swagger.Components.Responses)
284	if err != nil {
285		return "", errors.Wrap(err, "error generating Go types for component responses")
286	}
287	allTypes = append(allTypes, responseTypes...)
288
289	bodyTypes, err := GenerateTypesForRequestBodies(t, swagger.Components.RequestBodies)
290	if err != nil {
291		return "", errors.Wrap(err, "error generating Go types for component request bodies")
292	}
293	allTypes = append(allTypes, bodyTypes...)
294
295	paramTypesOut, err := GenerateTypesForOperations(t, ops)
296	if err != nil {
297		return "", errors.Wrap(err, "error generating Go types for operation parameters")
298	}
299
300	enumsOut, err := GenerateEnums(t, allTypes)
301	if err != nil {
302		return "", errors.Wrap(err, "error generating code for type enums")
303	}
304
305	typesOut, err := GenerateTypes(t, allTypes)
306	if err != nil {
307		return "", errors.Wrap(err, "error generating code for type definitions")
308	}
309
310	allOfBoilerplate, err := GenerateAdditionalPropertyBoilerplate(t, allTypes)
311	if err != nil {
312		return "", errors.Wrap(err, "error generating allOf boilerplate")
313	}
314
315	typeDefinitions := strings.Join([]string{enumsOut, typesOut, paramTypesOut, allOfBoilerplate}, "")
316	return typeDefinitions, nil
317}
318
319// Generates operation ids, context keys, paths, etc. to be exported as constants
320func GenerateConstants(t *template.Template, ops []OperationDefinition) (string, error) {
321	var buf bytes.Buffer
322	w := bufio.NewWriter(&buf)
323
324	constants := Constants{
325		SecuritySchemeProviderNames: []string{},
326	}
327
328	providerNameMap := map[string]struct{}{}
329	for _, op := range ops {
330		for _, def := range op.SecurityDefinitions {
331			providerName := SanitizeGoIdentity(def.ProviderName)
332			providerNameMap[providerName] = struct{}{}
333		}
334	}
335
336	var providerNames []string
337	for providerName := range providerNameMap {
338		providerNames = append(providerNames, providerName)
339	}
340
341	sort.Strings(providerNames)
342
343	for _, providerName := range providerNames {
344		constants.SecuritySchemeProviderNames = append(constants.SecuritySchemeProviderNames, providerName)
345	}
346
347	err := t.ExecuteTemplate(w, "constants.tmpl", constants)
348
349	if err != nil {
350		return "", fmt.Errorf("error generating server interface: %s", err)
351	}
352	err = w.Flush()
353	if err != nil {
354		return "", fmt.Errorf("error flushing output buffer for server interface: %s", err)
355	}
356	return buf.String(), nil
357}
358
359// Generates type definitions for any custom types defined in the
360// components/schemas section of the Swagger spec.
361func GenerateTypesForSchemas(t *template.Template, schemas map[string]*openapi3.SchemaRef, excludeSchemas []string) ([]TypeDefinition, error) {
362	var excludeSchemasMap = make(map[string]bool)
363	for _, schema := range excludeSchemas {
364		excludeSchemasMap[schema] = true
365	}
366	types := make([]TypeDefinition, 0)
367	// We're going to define Go types for every object under components/schemas
368	for _, schemaName := range SortedSchemaKeys(schemas) {
369		if _, ok := excludeSchemasMap[schemaName]; ok {
370			continue
371		}
372		schemaRef := schemas[schemaName]
373
374		goSchema, err := GenerateGoSchema(schemaRef, []string{schemaName})
375		if err != nil {
376			return nil, errors.Wrap(err, fmt.Sprintf("error converting Schema %s to Go type", schemaName))
377		}
378
379		types = append(types, TypeDefinition{
380			JsonName: schemaName,
381			TypeName: SchemaNameToTypeName(schemaName),
382			Schema:   goSchema,
383		})
384
385		types = append(types, goSchema.GetAdditionalTypeDefs()...)
386	}
387	return types, nil
388}
389
390// Generates type definitions for any custom types defined in the
391// components/parameters section of the Swagger spec.
392func GenerateTypesForParameters(t *template.Template, params map[string]*openapi3.ParameterRef) ([]TypeDefinition, error) {
393	var types []TypeDefinition
394	for _, paramName := range SortedParameterKeys(params) {
395		paramOrRef := params[paramName]
396
397		goType, err := paramToGoType(paramOrRef.Value, nil)
398		if err != nil {
399			return nil, errors.Wrap(err, fmt.Sprintf("error generating Go type for schema in parameter %s", paramName))
400		}
401
402		typeDef := TypeDefinition{
403			JsonName: paramName,
404			Schema:   goType,
405			TypeName: SchemaNameToTypeName(paramName),
406		}
407
408		if paramOrRef.Ref != "" {
409			// Generate a reference type for referenced parameters
410			refType, err := RefPathToGoType(paramOrRef.Ref)
411			if err != nil {
412				return nil, errors.Wrap(err, fmt.Sprintf("error generating Go type for (%s) in parameter %s", paramOrRef.Ref, paramName))
413			}
414			typeDef.TypeName = SchemaNameToTypeName(refType)
415		}
416
417		types = append(types, typeDef)
418	}
419	return types, nil
420}
421
422// Generates type definitions for any custom types defined in the
423// components/responses section of the Swagger spec.
424func GenerateTypesForResponses(t *template.Template, responses openapi3.Responses) ([]TypeDefinition, error) {
425	var types []TypeDefinition
426
427	for _, responseName := range SortedResponsesKeys(responses) {
428		responseOrRef := responses[responseName]
429
430		// We have to generate the response object. We're only going to
431		// handle application/json media types here. Other responses should
432		// simply be specified as strings or byte arrays.
433		response := responseOrRef.Value
434		jsonResponse, found := response.Content["application/json"]
435		if found {
436			goType, err := GenerateGoSchema(jsonResponse.Schema, []string{responseName})
437			if err != nil {
438				return nil, errors.Wrap(err, fmt.Sprintf("error generating Go type for schema in response %s", responseName))
439			}
440
441			typeDef := TypeDefinition{
442				JsonName: responseName,
443				Schema:   goType,
444				TypeName: SchemaNameToTypeName(responseName),
445			}
446
447			if responseOrRef.Ref != "" {
448				// Generate a reference type for referenced parameters
449				refType, err := RefPathToGoType(responseOrRef.Ref)
450				if err != nil {
451					return nil, errors.Wrap(err, fmt.Sprintf("error generating Go type for (%s) in parameter %s", responseOrRef.Ref, responseName))
452				}
453				typeDef.TypeName = SchemaNameToTypeName(refType)
454			}
455			types = append(types, typeDef)
456		}
457	}
458	return types, nil
459}
460
461// Generates type definitions for any custom types defined in the
462// components/requestBodies section of the Swagger spec.
463func GenerateTypesForRequestBodies(t *template.Template, bodies map[string]*openapi3.RequestBodyRef) ([]TypeDefinition, error) {
464	var types []TypeDefinition
465
466	for _, bodyName := range SortedRequestBodyKeys(bodies) {
467		bodyOrRef := bodies[bodyName]
468
469		// As for responses, we will only generate Go code for JSON bodies,
470		// the other body formats are up to the user.
471		response := bodyOrRef.Value
472		jsonBody, found := response.Content["application/json"]
473		if found {
474			goType, err := GenerateGoSchema(jsonBody.Schema, []string{bodyName})
475			if err != nil {
476				return nil, errors.Wrap(err, fmt.Sprintf("error generating Go type for schema in body %s", bodyName))
477			}
478
479			typeDef := TypeDefinition{
480				JsonName: bodyName,
481				Schema:   goType,
482				TypeName: SchemaNameToTypeName(bodyName),
483			}
484
485			if bodyOrRef.Ref != "" {
486				// Generate a reference type for referenced bodies
487				refType, err := RefPathToGoType(bodyOrRef.Ref)
488				if err != nil {
489					return nil, errors.Wrap(err, fmt.Sprintf("error generating Go type for (%s) in body %s", bodyOrRef.Ref, bodyName))
490				}
491				typeDef.TypeName = SchemaNameToTypeName(refType)
492			}
493			types = append(types, typeDef)
494		}
495	}
496	return types, nil
497}
498
499// Helper function to pass a bunch of types to the template engine, and buffer
500// its output into a string.
501func GenerateTypes(t *template.Template, types []TypeDefinition) (string, error) {
502	var buf bytes.Buffer
503	w := bufio.NewWriter(&buf)
504
505	context := struct {
506		Types []TypeDefinition
507	}{
508		Types: types,
509	}
510
511	err := t.ExecuteTemplate(w, "typedef.tmpl", context)
512	if err != nil {
513		return "", errors.Wrap(err, "error generating types")
514	}
515	err = w.Flush()
516	if err != nil {
517		return "", errors.Wrap(err, "error flushing output buffer for types")
518	}
519	return buf.String(), nil
520}
521
522func GenerateEnums(t *template.Template, types []TypeDefinition) (string, error) {
523	var buf bytes.Buffer
524	w := bufio.NewWriter(&buf)
525	c := Constants{
526		EnumDefinitions: []EnumDefinition{},
527	}
528	for _, tp := range types {
529		if len(tp.Schema.EnumValues) > 0 {
530			wrapper := ""
531			if tp.Schema.GoType == "string" {
532				wrapper = `"`
533			}
534			c.EnumDefinitions = append(c.EnumDefinitions, EnumDefinition{
535				Schema:       tp.Schema,
536				TypeName:     tp.TypeName,
537				ValueWrapper: wrapper,
538			})
539		}
540	}
541	err := t.ExecuteTemplate(w, "constants.tmpl", c)
542	if err != nil {
543		return "", errors.Wrap(err, "error generating enums")
544	}
545	err = w.Flush()
546	if err != nil {
547		return "", errors.Wrap(err, "error flushing output buffer for enums")
548	}
549	return buf.String(), nil
550}
551
552// Generate our import statements and package definition.
553func GenerateImports(t *template.Template, externalImports []string, packageName string) (string, error) {
554	var buf bytes.Buffer
555	w := bufio.NewWriter(&buf)
556	context := struct {
557		ExternalImports []string
558		PackageName     string
559	}{
560		ExternalImports: externalImports,
561		PackageName:     packageName,
562	}
563	err := t.ExecuteTemplate(w, "imports.tmpl", context)
564	if err != nil {
565		return "", errors.Wrap(err, "error generating imports")
566	}
567	err = w.Flush()
568	if err != nil {
569		return "", errors.Wrap(err, "error flushing output buffer for imports")
570	}
571	return buf.String(), nil
572}
573
574// Generate all the glue code which provides the API for interacting with
575// additional properties and JSON-ification
576func GenerateAdditionalPropertyBoilerplate(t *template.Template, typeDefs []TypeDefinition) (string, error) {
577	var buf bytes.Buffer
578	w := bufio.NewWriter(&buf)
579
580	var filteredTypes []TypeDefinition
581	for _, t := range typeDefs {
582		if t.Schema.HasAdditionalProperties {
583			filteredTypes = append(filteredTypes, t)
584		}
585	}
586
587	context := struct {
588		Types []TypeDefinition
589	}{
590		Types: filteredTypes,
591	}
592
593	err := t.ExecuteTemplate(w, "additional-properties.tmpl", context)
594	if err != nil {
595		return "", errors.Wrap(err, "error generating additional properties code")
596	}
597	err = w.Flush()
598	if err != nil {
599		return "", errors.Wrap(err, "error flushing output buffer for additional properties")
600	}
601	return buf.String(), nil
602}
603
// SanitizeCode cleans up generated Go code so that it can compile. Currently
// it strips every byte-order mark (U+FEFF), which the Go compiler rejects
// inside source files.
// See: https://groups.google.com/forum/#!topic/golang-nuts/OToNIPdfkks
func SanitizeCode(goCode string) string {
	const byteOrderMark = "\uFEFF"
	return strings.ReplaceAll(goCode, byteOrderMark, "")
}
611