1// Copyright 2019 DeepMap, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14package codegen
15
16import (
17	"bytes"
18	"fmt"
19	"os"
20	"strings"
21	"text/template"
22
23	"github.com/labstack/echo/v4"
24)
25
// Sort prefixes for generated switch case clauses. Case clauses are stored in
// maps whose keys begin with one of these prefixes, so that sorting the keys
// later (as strings) controls the order in which the cases are emitted.
const (
	prefixMostSpecific  = "3"
	prefixLessSpecific  = "6"
	prefixLeastSpecific = "9"
)

// responseTypeSuffix is appended to an operation ID to form the name of the
// generated response type (see genResponseTypeName).
const responseTypeSuffix = "Response"
31
32var (
33	contentTypesJSON = []string{echo.MIMEApplicationJSON, "text/x-json"}
34	contentTypesYAML = []string{"application/yaml", "application/x-yaml", "text/yaml", "text/x-yaml"}
35	contentTypesXML  = []string{echo.MIMEApplicationXML, echo.MIMETextXML}
36)
37
38// This function takes an array of Parameter definition, and generates a valid
39// Go parameter declaration from them, eg:
40// ", foo int, bar string, baz float32". The preceding comma is there to save
41// a lot of work in the template engine.
42func genParamArgs(params []ParameterDefinition) string {
43	if len(params) == 0 {
44		return ""
45	}
46	parts := make([]string, len(params))
47	for i, p := range params {
48		paramName := p.GoVariableName()
49		parts[i] = fmt.Sprintf("%s %s", paramName, p.TypeDef())
50	}
51	return ", " + strings.Join(parts, ", ")
52}
53
54// This function is much like the one above, except it only produces the
55// types of the parameters for a type declaration. It would produce this
56// from the same input as above:
57// ", int, string, float32".
58func genParamTypes(params []ParameterDefinition) string {
59	if len(params) == 0 {
60		return ""
61	}
62	parts := make([]string, len(params))
63	for i, p := range params {
64		parts[i] = p.TypeDef()
65	}
66	return ", " + strings.Join(parts, ", ")
67}
68
69// This is another variation of the function above which generates only the
70// parameter names:
71// ", foo, bar, baz"
72func genParamNames(params []ParameterDefinition) string {
73	if len(params) == 0 {
74		return ""
75	}
76	parts := make([]string, len(params))
77	for i, p := range params {
78		parts[i] = p.GoVariableName()
79	}
80	return ", " + strings.Join(parts, ", ")
81}
82
83// genResponsePayload generates the payload returned at the end of each client request function
84func genResponsePayload(operationID string) string {
85	var buffer = bytes.NewBufferString("")
86
87	// Here is where we build up a response:
88	fmt.Fprintf(buffer, "&%s{\n", genResponseTypeName(operationID))
89	fmt.Fprintf(buffer, "Body: bodyBytes,\n")
90	fmt.Fprintf(buffer, "HTTPResponse: rsp,\n")
91	fmt.Fprintf(buffer, "}")
92
93	return buffer.String()
94}
95
96// genResponseUnmarshal generates unmarshaling steps for structured response payloads
97func genResponseUnmarshal(op *OperationDefinition) string {
98	var handledCaseClauses = make(map[string]string)
99	var unhandledCaseClauses = make(map[string]string)
100
101	// Get the type definitions from the operation:
102	typeDefinitions, err := op.GetResponseTypeDefinitions()
103	if err != nil {
104		panic(err)
105	}
106
107	if len(typeDefinitions) == 0 {
108		// No types.
109		return ""
110	}
111
112	// Add a case for each possible response:
113	buffer := new(bytes.Buffer)
114	responses := op.Spec.Responses
115	for _, typeDefinition := range typeDefinitions {
116
117		responseRef, ok := responses[typeDefinition.ResponseName]
118		if !ok {
119			continue
120		}
121
122		// We can't do much without a value:
123		if responseRef.Value == nil {
124			fmt.Fprintf(os.Stderr, "Response %s.%s has nil value\n", op.OperationId, typeDefinition.ResponseName)
125			continue
126		}
127
128		// If there is no content-type then we have no unmarshaling to do:
129		if len(responseRef.Value.Content) == 0 {
130			caseAction := "break // No content-type"
131			caseClauseKey := "case " + getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) + ":"
132			unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
133			continue
134		}
135
136		// If we made it this far then we need to handle unmarshaling for each content-type:
137		sortedContentKeys := SortedContentKeys(responseRef.Value.Content)
138		for _, contentTypeName := range sortedContentKeys {
139
140			// We get "interface{}" when using "anyOf" or "oneOf" (which doesn't work with Go types):
141			if typeDefinition.TypeName == "interface{}" {
142				// Unable to unmarshal this, so we leave it out:
143				continue
144			}
145
146			// Add content-types here (json / yaml / xml etc):
147			switch {
148
149			// JSON:
150			case StringInArray(contentTypeName, contentTypesJSON):
151				if typeDefinition.ContentTypeName == contentTypeName {
152					var caseAction string
153
154					caseAction = fmt.Sprintf("var dest %s\n"+
155						"if err := json.Unmarshal(bodyBytes, &dest); err != nil { \n"+
156						" return nil, err \n"+
157						"}\n"+
158						"response.%s = &dest",
159						typeDefinition.Schema.TypeDecl(),
160						typeDefinition.TypeName)
161
162					caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "json")
163					handledCaseClauses[caseKey] = caseClause
164				}
165
166			// YAML:
167			case StringInArray(contentTypeName, contentTypesYAML):
168				if typeDefinition.ContentTypeName == contentTypeName {
169					var caseAction string
170					caseAction = fmt.Sprintf("var dest %s\n"+
171						"if err := yaml.Unmarshal(bodyBytes, &dest); err != nil { \n"+
172						" return nil, err \n"+
173						"}\n"+
174						"response.%s = &dest",
175						typeDefinition.Schema.TypeDecl(),
176						typeDefinition.TypeName)
177					caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "yaml")
178					handledCaseClauses[caseKey] = caseClause
179				}
180
181			// XML:
182			case StringInArray(contentTypeName, contentTypesXML):
183				if typeDefinition.ContentTypeName == contentTypeName {
184					var caseAction string
185					caseAction = fmt.Sprintf("var dest %s\n"+
186						"if err := xml.Unmarshal(bodyBytes, &dest); err != nil { \n"+
187						" return nil, err \n"+
188						"}\n"+
189						"response.%s = &dest",
190						typeDefinition.Schema.TypeDecl(),
191						typeDefinition.TypeName)
192					caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "xml")
193					handledCaseClauses[caseKey] = caseClause
194				}
195
196			// Everything else:
197			default:
198				caseAction := fmt.Sprintf("// Content-type (%s) unsupported", contentTypeName)
199				caseClauseKey := "case " + getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) + ":"
200				unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
201			}
202		}
203	}
204
205	if len(handledCaseClauses)+len(unhandledCaseClauses) == 0 {
206		// switch would be empty.
207		return ""
208	}
209
210	// Now build the switch statement in order of most-to-least specific:
211	// See: https://github.com/deepmap/oapi-codegen/issues/127 for why we handle this in two separate
212	// groups.
213	fmt.Fprintf(buffer, "switch {\n")
214	for _, caseClauseKey := range SortedStringKeys(handledCaseClauses) {
215
216		fmt.Fprintf(buffer, "%s\n", handledCaseClauses[caseClauseKey])
217	}
218	for _, caseClauseKey := range SortedStringKeys(unhandledCaseClauses) {
219
220		fmt.Fprintf(buffer, "%s\n", unhandledCaseClauses[caseClauseKey])
221	}
222	fmt.Fprintf(buffer, "}\n")
223
224	return buffer.String()
225}
226
227// buildUnmarshalCase builds an unmarshalling case clause for different content-types:
228func buildUnmarshalCase(typeDefinition ResponseTypeDefinition, caseAction string, contentType string) (caseKey string, caseClause string) {
229	caseKey = fmt.Sprintf("%s.%s.%s", prefixLeastSpecific, contentType, typeDefinition.ResponseName)
230	caseClauseKey := getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName)
231	caseClause = fmt.Sprintf("case strings.Contains(rsp.Header.Get(\"%s\"), \"%s\") && %s:\n%s\n", echo.HeaderContentType, contentType, caseClauseKey, caseAction)
232	return caseKey, caseClause
233}
234
235// genResponseTypeName creates the name of generated response types (given the operationID):
236func genResponseTypeName(operationID string) string {
237	return fmt.Sprintf("%s%s", UppercaseFirstCharacter(operationID), responseTypeSuffix)
238}
239
240func getResponseTypeDefinitions(op *OperationDefinition) []ResponseTypeDefinition {
241	td, err := op.GetResponseTypeDefinitions()
242	if err != nil {
243		panic(err)
244	}
245	return td
246}
247
// getConditionOfResponseName returns a Go boolean expression, written in
// terms of statusCodeVar, that tests whether a status code matches the
// OpenAPI response key responseName:
//   - "default" matches every status, so the expression is "true";
//   - the range keys "1XX" through "5XX" compare the hundreds digit,
//     e.g. "rsp.StatusCode / 100 == 2";
//   - any other key is compared verbatim, e.g. "rsp.StatusCode == 404".
func getConditionOfResponseName(statusCodeVar, responseName string) string {
	if responseName == "default" {
		return "true"
	}
	isRange := len(responseName) == 3 &&
		responseName[1:] == "XX" &&
		'1' <= responseName[0] && responseName[0] <= '5'
	if isRange {
		return fmt.Sprintf("%s / 100 == %c", statusCodeVar, responseName[0])
	}
	return fmt.Sprintf("%s == %s", statusCodeVar, responseName)
}
259
// toStringArray renders sarr as Go source for a []string composite literal,
// e.g. []string{"a","b"}, for direct use in generated code.
//
// Each element is emitted as a quoted, escaped Go string literal (%q), so
// values containing quotes, backslashes or control characters still produce
// valid Go. An empty or nil slice yields []string{}, not a slice holding a
// single empty string.
func toStringArray(sarr []string) string {
	quoted := make([]string, len(sarr))
	for i, s := range sarr {
		quoted[i] = fmt.Sprintf("%q", s)
	}
	return "[]string{" + strings.Join(quoted, ",") + "}"
}
264
// stripNewLines returns s with every "\n" removed. Other characters,
// including "\r", are left untouched.
func stripNewLines(s string) string {
	return strings.ReplaceAll(s, "\n", "")
}
269
270// This function map is passed to the template engine, and we can call each
271// function here by keyName from the template code.
272var TemplateFunctions = template.FuncMap{
273	"genParamArgs":               genParamArgs,
274	"genParamTypes":              genParamTypes,
275	"genParamNames":              genParamNames,
276	"genParamFmtString":          ReplacePathParamsWithStr,
277	"swaggerUriToEchoUri":        SwaggerUriToEchoUri,
278	"swaggerUriToChiUri":         SwaggerUriToChiUri,
279	"lcFirst":                    LowercaseFirstCharacter,
280	"ucFirst":                    UppercaseFirstCharacter,
281	"camelCase":                  ToCamelCase,
282	"genResponsePayload":         genResponsePayload,
283	"genResponseTypeName":        genResponseTypeName,
284	"genResponseUnmarshal":       genResponseUnmarshal,
285	"getResponseTypeDefinitions": getResponseTypeDefinitions,
286	"toStringArray":              toStringArray,
287	"lower":                      strings.ToLower,
288	"title":                      strings.Title,
289	"stripNewLines":              stripNewLines,
290	"sanitizeGoIdentity":         SanitizeGoIdentity,
291}
292