1// Copyright 2019 DeepMap, Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14package codegen 15 16import ( 17 "bytes" 18 "fmt" 19 "os" 20 "strings" 21 "text/template" 22 23 "github.com/labstack/echo/v4" 24) 25 26const ( 27 // These allow the case statements to be sorted later: 28 prefixMostSpecific, prefixLessSpecific, prefixLeastSpecific = "3", "6", "9" 29 responseTypeSuffix = "Response" 30) 31 32var ( 33 contentTypesJSON = []string{echo.MIMEApplicationJSON, "text/x-json"} 34 contentTypesYAML = []string{"application/yaml", "application/x-yaml", "text/yaml", "text/x-yaml"} 35 contentTypesXML = []string{echo.MIMEApplicationXML, echo.MIMETextXML} 36) 37 38// This function takes an array of Parameter definition, and generates a valid 39// Go parameter declaration from them, eg: 40// ", foo int, bar string, baz float32". The preceding comma is there to save 41// a lot of work in the template engine. 42func genParamArgs(params []ParameterDefinition) string { 43 if len(params) == 0 { 44 return "" 45 } 46 parts := make([]string, len(params)) 47 for i, p := range params { 48 paramName := p.GoVariableName() 49 parts[i] = fmt.Sprintf("%s %s", paramName, p.TypeDef()) 50 } 51 return ", " + strings.Join(parts, ", ") 52} 53 54// This function is much like the one above, except it only produces the 55// types of the parameters for a type declaration. It would produce this 56// from the same input as above: 57// ", int, string, float32". 58func genParamTypes(params []ParameterDefinition) string { 59 if len(params) == 0 { 60 return "" 61 } 62 parts := make([]string, len(params)) 63 for i, p := range params { 64 parts[i] = p.TypeDef() 65 } 66 return ", " + strings.Join(parts, ", ") 67} 68 69// This is another variation of the function above which generates only the 70// parameter names: 71// ", foo, bar, baz" 72func genParamNames(params []ParameterDefinition) string { 73 if len(params) == 0 { 74 return "" 75 } 76 parts := make([]string, len(params)) 77 for i, p := range params { 78 parts[i] = p.GoVariableName() 79 } 80 return ", " + strings.Join(parts, ", ") 81} 82 83// genResponsePayload generates the payload returned at the end of each client request function 84func genResponsePayload(operationID string) string { 85 var buffer = bytes.NewBufferString("") 86 87 // Here is where we build up a response: 88 fmt.Fprintf(buffer, "&%s{\n", genResponseTypeName(operationID)) 89 fmt.Fprintf(buffer, "Body: bodyBytes,\n") 90 fmt.Fprintf(buffer, "HTTPResponse: rsp,\n") 91 fmt.Fprintf(buffer, "}") 92 93 return buffer.String() 94} 95 96// genResponseUnmarshal generates unmarshaling steps for structured response payloads 97func genResponseUnmarshal(op *OperationDefinition) string { 98 var handledCaseClauses = make(map[string]string) 99 var unhandledCaseClauses = make(map[string]string) 100 101 // Get the type definitions from the operation: 102 typeDefinitions, err := op.GetResponseTypeDefinitions() 103 if err != nil { 104 panic(err) 105 } 106 107 if len(typeDefinitions) == 0 { 108 // No types. 109 return "" 110 } 111 112 // Add a case for each possible response: 113 buffer := new(bytes.Buffer) 114 responses := op.Spec.Responses 115 for _, typeDefinition := range typeDefinitions { 116 117 responseRef, ok := responses[typeDefinition.ResponseName] 118 if !ok { 119 continue 120 } 121 122 // We can't do much without a value: 123 if responseRef.Value == nil { 124 fmt.Fprintf(os.Stderr, "Response %s.%s has nil value\n", op.OperationId, typeDefinition.ResponseName) 125 continue 126 } 127 128 // If there is no content-type then we have no unmarshaling to do: 129 if len(responseRef.Value.Content) == 0 { 130 caseAction := "break // No content-type" 131 caseClauseKey := "case " + getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) + ":" 132 unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction) 133 continue 134 } 135 136 // If we made it this far then we need to handle unmarshaling for each content-type: 137 sortedContentKeys := SortedContentKeys(responseRef.Value.Content) 138 for _, contentTypeName := range sortedContentKeys { 139 140 // We get "interface{}" when using "anyOf" or "oneOf" (which doesn't work with Go types): 141 if typeDefinition.TypeName == "interface{}" { 142 // Unable to unmarshal this, so we leave it out: 143 continue 144 } 145 146 // Add content-types here (json / yaml / xml etc): 147 switch { 148 149 // JSON: 150 case StringInArray(contentTypeName, contentTypesJSON): 151 if typeDefinition.ContentTypeName == contentTypeName { 152 var caseAction string 153 154 caseAction = fmt.Sprintf("var dest %s\n"+ 155 "if err := json.Unmarshal(bodyBytes, &dest); err != nil { \n"+ 156 " return nil, err \n"+ 157 "}\n"+ 158 "response.%s = &dest", 159 typeDefinition.Schema.TypeDecl(), 160 typeDefinition.TypeName) 161 162 caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "json") 163 handledCaseClauses[caseKey] = caseClause 164 } 165 166 // YAML: 167 case StringInArray(contentTypeName, contentTypesYAML): 168 if typeDefinition.ContentTypeName == contentTypeName { 169 var caseAction string 170 caseAction = fmt.Sprintf("var dest %s\n"+ 171 "if err := yaml.Unmarshal(bodyBytes, &dest); err != nil { \n"+ 172 " return nil, err \n"+ 173 "}\n"+ 174 "response.%s = &dest", 175 typeDefinition.Schema.TypeDecl(), 176 typeDefinition.TypeName) 177 caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "yaml") 178 handledCaseClauses[caseKey] = caseClause 179 } 180 181 // XML: 182 case StringInArray(contentTypeName, contentTypesXML): 183 if typeDefinition.ContentTypeName == contentTypeName { 184 var caseAction string 185 caseAction = fmt.Sprintf("var dest %s\n"+ 186 "if err := xml.Unmarshal(bodyBytes, &dest); err != nil { \n"+ 187 " return nil, err \n"+ 188 "}\n"+ 189 "response.%s = &dest", 190 typeDefinition.Schema.TypeDecl(), 191 typeDefinition.TypeName) 192 caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "xml") 193 handledCaseClauses[caseKey] = caseClause 194 } 195 196 // Everything else: 197 default: 198 caseAction := fmt.Sprintf("// Content-type (%s) unsupported", contentTypeName) 199 caseClauseKey := "case " + getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) + ":" 200 unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction) 201 } 202 } 203 } 204 205 if len(handledCaseClauses)+len(unhandledCaseClauses) == 0 { 206 // switch would be empty. 207 return "" 208 } 209 210 // Now build the switch statement in order of most-to-least specific: 211 // See: https://github.com/deepmap/oapi-codegen/issues/127 for why we handle this in two separate 212 // groups. 213 fmt.Fprintf(buffer, "switch {\n") 214 for _, caseClauseKey := range SortedStringKeys(handledCaseClauses) { 215 216 fmt.Fprintf(buffer, "%s\n", handledCaseClauses[caseClauseKey]) 217 } 218 for _, caseClauseKey := range SortedStringKeys(unhandledCaseClauses) { 219 220 fmt.Fprintf(buffer, "%s\n", unhandledCaseClauses[caseClauseKey]) 221 } 222 fmt.Fprintf(buffer, "}\n") 223 224 return buffer.String() 225} 226 227// buildUnmarshalCase builds an unmarshalling case clause for different content-types: 228func buildUnmarshalCase(typeDefinition ResponseTypeDefinition, caseAction string, contentType string) (caseKey string, caseClause string) { 229 caseKey = fmt.Sprintf("%s.%s.%s", prefixLeastSpecific, contentType, typeDefinition.ResponseName) 230 caseClauseKey := getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) 231 caseClause = fmt.Sprintf("case strings.Contains(rsp.Header.Get(\"%s\"), \"%s\") && %s:\n%s\n", echo.HeaderContentType, contentType, caseClauseKey, caseAction) 232 return caseKey, caseClause 233} 234 235// genResponseTypeName creates the name of generated response types (given the operationID): 236func genResponseTypeName(operationID string) string { 237 return fmt.Sprintf("%s%s", UppercaseFirstCharacter(operationID), responseTypeSuffix) 238} 239 240func getResponseTypeDefinitions(op *OperationDefinition) []ResponseTypeDefinition { 241 td, err := op.GetResponseTypeDefinitions() 242 if err != nil { 243 panic(err) 244 } 245 return td 246} 247 248// Return the statusCode comparison clause from the response name. 249func getConditionOfResponseName(statusCodeVar, responseName string) string { 250 switch responseName { 251 case "default": 252 return "true" 253 case "1XX", "2XX", "3XX", "4XX", "5XX": 254 return fmt.Sprintf("%s / 100 == %s", statusCodeVar, responseName[:1]) 255 default: 256 return fmt.Sprintf("%s == %s", statusCodeVar, responseName) 257 } 258} 259 260// This outputs a string array 261func toStringArray(sarr []string) string { 262 return `[]string{"` + strings.Join(sarr, `","`) + `"}` 263} 264 265func stripNewLines(s string) string { 266 r := strings.NewReplacer("\n", "") 267 return r.Replace(s) 268} 269 270// This function map is passed to the template engine, and we can call each 271// function here by keyName from the template code. 272var TemplateFunctions = template.FuncMap{ 273 "genParamArgs": genParamArgs, 274 "genParamTypes": genParamTypes, 275 "genParamNames": genParamNames, 276 "genParamFmtString": ReplacePathParamsWithStr, 277 "swaggerUriToEchoUri": SwaggerUriToEchoUri, 278 "swaggerUriToChiUri": SwaggerUriToChiUri, 279 "lcFirst": LowercaseFirstCharacter, 280 "ucFirst": UppercaseFirstCharacter, 281 "camelCase": ToCamelCase, 282 "genResponsePayload": genResponsePayload, 283 "genResponseTypeName": genResponseTypeName, 284 "genResponseUnmarshal": genResponseUnmarshal, 285 "getResponseTypeDefinitions": getResponseTypeDefinitions, 286 "toStringArray": toStringArray, 287 "lower": strings.ToLower, 288 "title": strings.Title, 289 "stripNewLines": stripNewLines, 290 "sanitizeGoIdentity": SanitizeGoIdentity, 291} 292