1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package generator
16
17import (
18	"bytes"
19	"encoding/json"
20	"errors"
21	"fmt"
22	"log"
23	"os"
24	"path/filepath"
25	"regexp"
26	goruntime "runtime"
27	"sort"
28	"strings"
29
30	"github.com/go-openapi/analysis"
31	"github.com/go-openapi/loads"
32	"github.com/go-openapi/runtime"
33	"github.com/go-openapi/spec"
34	"github.com/go-openapi/swag"
35)
36
37// GenerateServer generates a server application
38func GenerateServer(name string, modelNames, operationIDs []string, opts *GenOpts) error {
39	generator, err := newAppGenerator(name, modelNames, operationIDs, opts)
40	if err != nil {
41		return err
42	}
43	return generator.Generate()
44}
45
46// GenerateSupport generates the supporting files for an API
47func GenerateSupport(name string, modelNames, operationIDs []string, opts *GenOpts) error {
48
49	generator, err := newAppGenerator(name, modelNames, operationIDs, opts)
50	if err != nil {
51		return err
52	}
53	return generator.GenerateSupport(nil)
54}
55
56func newAppGenerator(name string, modelNames, operationIDs []string, opts *GenOpts) (*appGenerator, error) {
57	if opts == nil {
58		return nil, errors.New("gen opts are required")
59	}
60	if err := opts.EnsureDefaults(false); err != nil {
61		return nil, err
62	}
63
64	if opts.TemplateDir != "" {
65		if err := templates.LoadDir(opts.TemplateDir); err != nil {
66			return nil, err
67		}
68	}
69
70	// Load the spec
71	var err error
72	var specDoc *loads.Document
73	opts.Spec, specDoc, err = loadSpec(opts.Spec)
74	if err != nil {
75		return nil, err
76	}
77
78	// Validate if needed
79	if opts.ValidateSpec {
80		if err = validateSpec(opts.Spec, specDoc); err != nil {
81			return nil, err
82		}
83	}
84
85	analyzed := analysis.New(specDoc.Spec())
86
87	models, err := gatherModels(specDoc, modelNames)
88	if err != nil {
89		return nil, err
90	}
91
92	operations := gatherOperations(analyzed, operationIDs)
93	if len(operations) == 0 {
94		return nil, errors.New("no operations were selected")
95	}
96
97	defaultScheme := opts.DefaultScheme
98	if defaultScheme == "" {
99		defaultScheme = "http"
100	}
101
102	defaultProduces := opts.DefaultProduces
103	if defaultProduces == "" {
104		defaultProduces = runtime.JSONMime
105	}
106
107	defaultConsumes := opts.DefaultConsumes
108	if defaultConsumes == "" {
109		defaultConsumes = runtime.JSONMime
110	}
111
112	apiPackage := opts.LanguageOpts.MangleName(swag.ToFileName(opts.APIPackage), "api")
113	return &appGenerator{
114		Name:       appNameOrDefault(specDoc, name, "swagger"),
115		Receiver:   "o",
116		SpecDoc:    specDoc,
117		Analyzed:   analyzed,
118		Models:     models,
119		Operations: operations,
120		Target:     opts.Target,
121		// Package:       filepath.Base(opts.Target),
122		DumpData:        opts.DumpData,
123		Package:         apiPackage,
124		APIPackage:      apiPackage,
125		ModelsPackage:   opts.LanguageOpts.MangleName(swag.ToFileName(opts.ModelPackage), "definitions"),
126		ServerPackage:   opts.LanguageOpts.MangleName(swag.ToFileName(opts.ServerPackage), "server"),
127		ClientPackage:   opts.LanguageOpts.MangleName(swag.ToFileName(opts.ClientPackage), "client"),
128		Principal:       opts.Principal,
129		DefaultScheme:   defaultScheme,
130		DefaultProduces: defaultProduces,
131		DefaultConsumes: defaultConsumes,
132		GenOpts:         opts,
133	}, nil
134}
135
// appGenerator carries the state needed to plan and render an application
// from a swagger spec. This includes the loaded and analyzed document, the
// selected models and operations, the target package names, and the
// scheme/media-type defaults. Instances are built by newAppGenerator.
type appGenerator struct {
	// Name is the application name, obtained from appNameOrDefault with
	// "swagger" passed as the default.
	Name            string
	// Receiver is the receiver name put into the plan (newAppGenerator sets "o").
	Receiver        string
	// SpecDoc is the loaded spec document.
	SpecDoc         *loads.Document
	// Analyzed is the analysis of SpecDoc. It is queried for the required
	// consumes, produces and security schemes.
	Analyzed        *analysis.Spec
	// Package is the package assigned to operations that have none; it is
	// set to the same value as APIPackage.
	Package         string
	// APIPackage is the configured API package name, normalized via
	// swag.ToFileName and LanguageOpts.MangleName.
	APIPackage      string
	// ModelsPackage is the normalized package name for model definitions.
	ModelsPackage   string
	// ServerPackage is the normalized package name for server code.
	ServerPackage   string
	// ClientPackage is the normalized package name for client code.
	ClientPackage   string
	// Principal is the Go type used for the authenticated principal; an
	// empty value is treated as "interface{}" when planning.
	Principal       string
	// Models holds the selected model schemas, keyed by definition name.
	Models          map[string]spec.Schema
	// Operations holds the selected operations, keyed by operation ID.
	Operations      map[string]opRef
	// Target is the output directory; baseImport requires it to be under
	// $GOPATH/src.
	Target          string
	// DumpData makes Generate print the plan as JSON instead of rendering it.
	DumpData        bool
	// DefaultScheme is the fallback scheme ("http" unless configured).
	DefaultScheme   string
	// DefaultProduces is the default response media type
	// (runtime.JSONMime unless configured).
	DefaultProduces string
	// DefaultConsumes is the default request media type
	// (runtime.JSONMime unless configured).
	DefaultConsumes string
	// GenOpts are the options this generator was created with.
	GenOpts         *GenOpts
}
156
// baseImport returns the import path of directory tgt relative to the first
// $GOPATH/src directory that contains it. The result uses OS-specific
// separators; callers convert it with filepath.ToSlash.
//
// Containment is decided on whole path elements, so "/go/src2/x" is not
// considered to be inside "/go/src". On Windows the comparison is
// case-insensitive.
//
// The program is terminated via log.Fatalln when tgt cannot be made absolute
// or does not reside under any $GOPATH/src.
func baseImport(tgt string) string {
	p, err := filepath.Abs(tgt)
	if err != nil {
		log.Fatalln(err)
	}

	var pth string
	for _, gp := range filepath.SplitList(os.Getenv("GOPATH")) {
		pp := filepath.Join(filepath.Clean(gp), "src")

		// Compare the real paths everywhere. Lower-case them only on Windows,
		// where paths are case-insensitive. Leaving these empty would make
		// every GOPATH entry look like a match.
		np, npp := p, pp
		if goruntime.GOOS == "windows" {
			np = strings.ToLower(p)
			npp = strings.ToLower(pp)
		}

		// Require a separator after the prefix so that sibling directories
		// sharing a name prefix (src vs src2) are not mistaken for children.
		if np != npp && !strings.HasPrefix(np, npp+string(filepath.Separator)) {
			continue
		}

		pth, err = filepath.Rel(pp, p)
		if err != nil {
			log.Fatalln(err)
		}
		break
	}

	if pth == "" {
		log.Fatalln("target must reside inside a location in the $GOPATH/src")
	}
	return pth
}
185
186func (a *appGenerator) Generate() error {
187
188	app, err := a.makeCodegenApp()
189	if err != nil {
190		return err
191	}
192
193	if a.DumpData {
194		bb, err := json.MarshalIndent(app, "", "  ")
195		if err != nil {
196			return err
197		}
198		fmt.Fprintln(os.Stdout, string(bb))
199		return nil
200	}
201
202	// IPC removed concurrent execution because of the FuncMap that is being shared
203	// templates are now lazy loaded so there is concurrent map access I can't guard
204
205	// errChan := make(chan error, 100)
206	// wg := nsync.NewControlWaitGroup(20)
207
208	if a.GenOpts.IncludeModel {
209		log.Printf("rendering %d models", len(app.Models))
210		for _, mod := range app.Models {
211			// if len(errChan) > 0 {
212			// 	wg.Wait()
213			// 	return <-errChan
214			// }
215			modCopy := mod
216			// wg.Do(func() {
217			modCopy.IncludeValidator = true // a.GenOpts.IncludeValidator
218			modCopy.IncludeModel = true
219			if err := a.GenOpts.renderDefinition(&modCopy); err != nil {
220				return err
221			}
222			// })
223		}
224	}
225	// wg.Wait()
226
227	if a.GenOpts.IncludeHandler {
228		log.Printf("rendering %d operation groups (tags)", app.OperationGroups.Len())
229		for _, opg := range app.OperationGroups {
230			opgCopy := opg
231			log.Printf("rendering %d operations for %s", opg.Operations.Len(), opg.Name)
232			for _, op := range opgCopy.Operations {
233				// if len(errChan) > 0 {
234				// 	wg.Wait()
235				// 	return <-errChan
236				// }
237				opCopy := op
238				// wg.Do(func() {
239
240				if err := a.GenOpts.renderOperation(&opCopy); err != nil {
241					return err
242				}
243				// })
244			}
245		}
246	}
247
248	if a.GenOpts.IncludeSupport {
249		log.Printf("rendering support")
250		// wg.Do(func() {
251		if err := a.GenerateSupport(&app); err != nil {
252			// errChan <- err
253			return err
254		}
255		// })
256	}
257	// wg.Wait()
258	// if len(errChan) > 0 {
259	// 	return <-errChan
260	// }
261	return nil
262}
263
264func (a *appGenerator) GenerateSupport(ap *GenApp) error {
265	var app *GenApp
266	app = ap
267	if ap == nil {
268		ca, err := a.makeCodegenApp()
269		if err != nil {
270			return err
271		}
272		app = &ca
273	}
274
275	importPath := filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ServerPackage, a.APIPackage))
276	app.DefaultImports = append(
277		app.DefaultImports,
278		filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ServerPackage)),
279		importPath,
280	)
281
282	return a.GenOpts.renderApplication(app)
283}
284
// Serializer lookup tables used when planning the application's consumers
// and producers.
var (
	// mediaTypeNames maps media type patterns to short serializer names.
	// The patterns are unanchored, and a media type may match more than one
	// of them.
	mediaTypeNames = map[*regexp.Regexp]string{
		regexp.MustCompile("application/.*json"):                "json",
		regexp.MustCompile("application/.*yaml"):                "yaml",
		regexp.MustCompile("application/.*protobuf"):            "protobuf",
		regexp.MustCompile("application/.*capnproto"):           "capnproto",
		regexp.MustCompile("application/.*thrift"):              "thrift",
		regexp.MustCompile("(?:application|text)/.*xml"):        "xml",
		regexp.MustCompile("text/.*markdown"):                   "markdown",
		regexp.MustCompile("text/.*html"):                       "html",
		regexp.MustCompile("text/.*csv"):                        "csv",
		regexp.MustCompile("text/.*tsv"):                        "tsv",
		regexp.MustCompile("text/.*javascript"):                 "js",
		regexp.MustCompile("text/.*css"):                        "css",
		regexp.MustCompile("text/.*plain"):                      "txt",
		regexp.MustCompile("application/.*octet-stream"):        "bin",
		regexp.MustCompile("application/.*tar"):                 "tar",
		regexp.MustCompile("application/.*gzip"):                "gzip",
		regexp.MustCompile("application/.*gz"):                  "gzip",
		regexp.MustCompile("application/.*raw-stream"):          "bin",
		regexp.MustCompile("application/x-www-form-urlencoded"): "urlform",
		regexp.MustCompile("multipart/form-data"):               "multipartform",
	}

	// knownProducers maps a serializer name to the Go source expression
	// recorded as the Implementation of its producer.
	knownProducers = map[string]string{
		"json":          "runtime.JSONProducer()",
		"yaml":          "yamlpc.YAMLProducer()",
		"xml":           "runtime.XMLProducer()",
		"txt":           "runtime.TextProducer()",
		"bin":           "runtime.ByteStreamProducer()",
		"urlform":       "runtime.DiscardProducer",
		"multipartform": "runtime.DiscardProducer",
	}

	// knownConsumers maps a serializer name to the Go source expression
	// recorded as the Implementation of its consumer.
	knownConsumers = map[string]string{
		"json":          "runtime.JSONConsumer()",
		"yaml":          "yamlpc.YAMLConsumer()",
		"xml":           "runtime.XMLConsumer()",
		"txt":           "runtime.TextConsumer()",
		"bin":           "runtime.ByteStreamConsumer()",
		"urlform":       "runtime.DiscardConsumer",
		"multipartform": "runtime.DiscardConsumer",
	}
)
327
328func getSerializer(sers []GenSerGroup, ext string) (*GenSerGroup, bool) {
329	for i := range sers {
330		s := &sers[i]
331		if s.Name == ext {
332			return s, true
333		}
334	}
335	return nil, false
336}
337
338func mediaTypeName(tn string) (string, bool) {
339	for k, v := range mediaTypeNames {
340		if k.MatchString(tn) {
341			return v, true
342		}
343	}
344	return "", false
345}
346
347func (a *appGenerator) makeConsumes() (consumes GenSerGroups, consumesJSON bool) {
348	for _, cons := range a.Analyzed.RequiredConsumes() {
349		cn, ok := mediaTypeName(cons)
350		if !ok {
351			continue
352		}
353		nm := swag.ToJSONName(cn)
354		if nm == "json" {
355			consumesJSON = true
356		}
357
358		if ser, ok := getSerializer(consumes, cn); ok {
359			ser.AllSerializers = append(ser.AllSerializers, GenSerializer{
360				AppName:        ser.AppName,
361				ReceiverName:   ser.ReceiverName,
362				Name:           ser.Name,
363				MediaType:      cons,
364				Implementation: knownConsumers[nm],
365			})
366			sort.Sort(ser.AllSerializers)
367			continue
368		}
369
370		ser := GenSerializer{
371			AppName:        a.Name,
372			ReceiverName:   a.Receiver,
373			Name:           nm,
374			MediaType:      cons,
375			Implementation: knownConsumers[nm],
376		}
377
378		consumes = append(consumes, GenSerGroup{
379			AppName:        ser.AppName,
380			ReceiverName:   ser.ReceiverName,
381			Name:           ser.Name,
382			MediaType:      cons,
383			AllSerializers: []GenSerializer{ser},
384			Implementation: ser.Implementation,
385		})
386	}
387	if len(consumes) == 0 {
388		consumes = append(consumes, GenSerGroup{
389			AppName:      a.Name,
390			ReceiverName: a.Receiver,
391			Name:         "json",
392			MediaType:    runtime.JSONMime,
393			AllSerializers: []GenSerializer{GenSerializer{
394				AppName:        a.Name,
395				ReceiverName:   a.Receiver,
396				Name:           "json",
397				MediaType:      runtime.JSONMime,
398				Implementation: knownConsumers["json"],
399			}},
400			Implementation: knownConsumers["json"],
401		})
402		consumesJSON = true
403	}
404	sort.Sort(consumes)
405	return
406}
407
408func (a *appGenerator) makeProduces() (produces GenSerGroups, producesJSON bool) {
409	for _, prod := range a.Analyzed.RequiredProduces() {
410		pn, ok := mediaTypeName(prod)
411		if !ok {
412			continue
413		}
414		nm := swag.ToJSONName(pn)
415		if nm == "json" {
416			producesJSON = true
417		}
418
419		if ser, ok := getSerializer(produces, pn); ok {
420			ser.AllSerializers = append(ser.AllSerializers, GenSerializer{
421				AppName:        ser.AppName,
422				ReceiverName:   ser.ReceiverName,
423				Name:           ser.Name,
424				MediaType:      prod,
425				Implementation: knownProducers[nm],
426			})
427			sort.Sort(ser.AllSerializers)
428			continue
429		}
430
431		ser := GenSerializer{
432			AppName:        a.Name,
433			ReceiverName:   a.Receiver,
434			Name:           nm,
435			MediaType:      prod,
436			Implementation: knownProducers[nm],
437		}
438		produces = append(produces, GenSerGroup{
439			AppName:        ser.AppName,
440			ReceiverName:   ser.ReceiverName,
441			Name:           ser.Name,
442			MediaType:      prod,
443			Implementation: ser.Implementation,
444			AllSerializers: []GenSerializer{ser},
445		})
446	}
447	if len(produces) == 0 {
448		produces = append(produces, GenSerGroup{
449			AppName:      a.Name,
450			ReceiverName: a.Receiver,
451			Name:         "json",
452			MediaType:    runtime.JSONMime,
453			AllSerializers: []GenSerializer{GenSerializer{
454				AppName:        a.Name,
455				ReceiverName:   a.Receiver,
456				Name:           "json",
457				MediaType:      runtime.JSONMime,
458				Implementation: knownProducers["json"],
459			}},
460			Implementation: knownProducers["json"],
461		})
462		producesJSON = true
463	}
464	sort.Sort(produces)
465	return
466}
467
468func (a *appGenerator) makeSecuritySchemes() (security GenSecuritySchemes) {
469
470	prin := a.Principal
471	if prin == "" {
472		prin = "interface{}"
473	}
474	for _, scheme := range a.Analyzed.RequiredSecuritySchemes() {
475		if req, ok := a.SpecDoc.Spec().SecurityDefinitions[scheme]; ok {
476			isOAuth2 := strings.ToLower(req.Type) == "oauth2"
477			var scopes []string
478			if isOAuth2 {
479				for k := range req.Scopes {
480					scopes = append(scopes, k)
481				}
482			}
483
484			security = append(security, GenSecurityScheme{
485				AppName:      a.Name,
486				ID:           scheme,
487				ReceiverName: a.Receiver,
488				Name:         req.Name,
489				IsBasicAuth:  strings.ToLower(req.Type) == "basic",
490				IsAPIKeyAuth: strings.ToLower(req.Type) == "apikey",
491				IsOAuth2:     isOAuth2,
492				Scopes:       scopes,
493				Principal:    prin,
494				Source:       req.In,
495			})
496		}
497	}
498	sort.Sort(security)
499	return
500}
501
502func (a *appGenerator) makeCodegenApp() (GenApp, error) {
503	log.Println("building a plan for generation")
504	sw := a.SpecDoc.Spec()
505	receiver := a.Receiver
506
507	var defaultImports []string
508
509	jsonb, _ := json.MarshalIndent(sw, "", "  ")
510
511	consumes, _ := a.makeConsumes()
512	produces, _ := a.makeProduces()
513	sort.Sort(consumes)
514	sort.Sort(produces)
515	prin := a.Principal
516	if prin == "" {
517		prin = "interface{}"
518	}
519	security := a.makeSecuritySchemes()
520
521	var genMods []GenDefinition
522	importPath := filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ModelsPackage))
523	defaultImports = append(defaultImports, importPath)
524
525	log.Println("planning definitions")
526	for mn, m := range a.Models {
527		mod, err := makeGenDefinition(
528			mn,
529			a.ModelsPackage,
530			m,
531			a.SpecDoc,
532			a.GenOpts,
533		)
534		if err != nil {
535			return GenApp{}, err
536		}
537		if mod != nil {
538			//mod.ReceiverName = receiver
539			genMods = append(genMods, *mod)
540		}
541	}
542
543	log.Println("planning operations")
544	tns := make(map[string]struct{})
545	var genOps GenOperations
546	for on, opp := range a.Operations {
547		o := opp.Op
548		o.Tags = pruneEmpty(o.Tags)
549		o.ID = on
550		var bldr codeGenOpBuilder
551		bldr.ModelsPackage = a.ModelsPackage
552		bldr.Principal = prin
553		bldr.Target = a.Target
554		bldr.DefaultImports = defaultImports
555		bldr.DefaultScheme = a.DefaultScheme
556		bldr.Doc = a.SpecDoc
557		bldr.Analyzed = a.Analyzed
558		bldr.BasePath = a.SpecDoc.BasePath()
559
560		// TODO: change operation name to something safe
561		bldr.Name = on
562		bldr.Operation = *o
563		bldr.Method = opp.Method
564		bldr.Path = opp.Path
565		bldr.Authed = len(a.Analyzed.SecurityRequirementsFor(o)) > 0
566		bldr.RootAPIPackage = swag.ToFileName(a.APIPackage)
567		bldr.WithContext = a.GenOpts != nil && a.GenOpts.WithContext
568
569		bldr.APIPackage = bldr.RootAPIPackage
570		st := o.Tags
571		if a.GenOpts != nil {
572			st = a.GenOpts.Tags
573		}
574		intersected := intersectTags(o.Tags, st)
575		if len(intersected) == 1 {
576			tag := intersected[0]
577			bldr.APIPackage = a.GenOpts.LanguageOpts.MangleName(swag.ToFileName(tag), a.APIPackage)
578			for _, t := range intersected {
579				tns[t] = struct{}{}
580			}
581		}
582		op, err := bldr.MakeOperation()
583		if err != nil {
584			return GenApp{}, err
585		}
586		op.ReceiverName = receiver
587		op.Tags = intersected
588		genOps = append(genOps, op)
589
590	}
591	for k := range tns {
592		importPath := filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ServerPackage, a.APIPackage, swag.ToFileName(k)))
593		defaultImports = append(defaultImports, importPath)
594	}
595	sort.Sort(genOps)
596
597	log.Println("grouping operations into packages")
598	opsGroupedByTag := make(map[string]GenOperations)
599	for _, operation := range genOps {
600		if operation.Package == "" {
601			operation.Package = a.Package
602		}
603		opsGroupedByTag[operation.Package] = append(opsGroupedByTag[operation.Package], operation)
604	}
605
606	var opGroups GenOperationGroups
607	for k, v := range opsGroupedByTag {
608		sort.Sort(v)
609		opGroup := GenOperationGroup{
610			Name:           k,
611			Operations:     v,
612			DefaultImports: []string{filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ModelsPackage))},
613			RootPackage:    a.APIPackage,
614			WithContext:    a.GenOpts != nil && a.GenOpts.WithContext,
615		}
616		opGroups = append(opGroups, opGroup)
617		var importPath string
618		if k == a.APIPackage {
619			importPath = filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ServerPackage, a.APIPackage))
620		} else {
621			importPath = filepath.ToSlash(filepath.Join(baseImport(a.Target), a.ServerPackage, a.APIPackage, k))
622		}
623		defaultImports = append(defaultImports, importPath)
624	}
625	sort.Sort(opGroups)
626
627	log.Println("planning meta data and facades")
628
629	var collectedSchemes []string
630	var extraSchemes []string
631	for _, op := range genOps {
632		collectedSchemes = concatUnique(collectedSchemes, op.Schemes)
633		extraSchemes = concatUnique(extraSchemes, op.ExtraSchemes)
634	}
635	sort.Strings(collectedSchemes)
636	sort.Strings(extraSchemes)
637
638	host := "localhost"
639	if sw.Host != "" {
640		host = sw.Host
641	}
642
643	basePath := "/"
644	if sw.BasePath != "" {
645		basePath = sw.BasePath
646	}
647
648	return GenApp{
649		APIPackage:          a.ServerPackage,
650		Package:             a.Package,
651		ReceiverName:        receiver,
652		Name:                a.Name,
653		Host:                host,
654		BasePath:            basePath,
655		Schemes:             schemeOrDefault(collectedSchemes, a.DefaultScheme),
656		ExtraSchemes:        extraSchemes,
657		ExternalDocs:        sw.ExternalDocs,
658		Info:                sw.Info,
659		Consumes:            consumes,
660		Produces:            produces,
661		DefaultConsumes:     a.DefaultConsumes,
662		DefaultProduces:     a.DefaultProduces,
663		DefaultImports:      defaultImports,
664		SecurityDefinitions: security,
665		Models:              genMods,
666		Operations:          genOps,
667		OperationGroups:     opGroups,
668		Principal:           prin,
669		SwaggerJSON:         generateReadableSpec(jsonb),
670		ExcludeSpec:         a.GenOpts != nil && a.GenOpts.ExcludeSpec,
671		WithContext:         a.GenOpts != nil && a.GenOpts.WithContext,
672		GenOpts:             a.GenOpts,
673	}, nil
674}
675
676// generateReadableSpec makes swagger json spec as a string instead of bytes
677// the only character that needs to be escaped is '`' symbol, since it cannot be escaped in the GO string
678// that is quoted as `string data`. The function doesn't care about the beginning or the ending of the
679// string it escapes since all data that needs to be escaped is always in the middle of the swagger spec.
680func generateReadableSpec(spec []byte) string {
681	buf := &bytes.Buffer{}
682	for _, b := range string(spec) {
683		if b == '`' {
684			buf.WriteString("`+\"`\"+`")
685		} else {
686			buf.WriteRune(b)
687		}
688	}
689	return buf.String()
690}
691