1package drivergen
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"sort"
8	"strings"
9	"text/template"
10	"unicode"
11
12	"github.com/pelletier/go-toml"
13)
14
// Operation is the top-level configuration type. It's the direct representation of an operation
// TOML file; ParseFile builds one by unmarshaling the file and then setting the package name.
type Operation struct {
	// Name is the operation's type name. ShortName derives the generated receiver name from it
	// and Generate uses it to name the template it parses.
	Name           string
	// Documentation is the operation's doc text. NOTE(review): presumably rendered through
	// EscapeDocumentation by the templates — confirm there.
	Documentation  string
	// NOTE(review): Version and DriverInternal are not read in this file; their meaning is
	// defined by the templates and TOML files — confirm there.
	Version        int
	DriverInternal bool
	// Properties holds general operation properties such as built-ins and retryability.
	Properties     Properties
	// Command describes how the command name/parameter pair is serialized.
	Command        Command
	// Request maps each request field name to its definition.
	Request        map[string]RequestField
	// Response describes the response type generated for this operation.
	Response       Response

	// pkg is the package name for generated code. ParseFile sets it from its packagename
	// argument after unmarshaling, and PackageName returns it.
	pkg string
}
29
30// PackageName returns the package name to use when generating the operation.
31func (op Operation) PackageName() string { return op.pkg }
32
33// Generate creates the operation type and associated response types and writes them to w.
34func (op Operation) Generate(w io.Writer) error {
35	t, err := template.New(op.Name + " operation").Parse(typeTemplate)
36	if err != nil {
37		return err
38	}
39	return t.Execute(w, op)
40}
41
42// ShortName returns the receiver used for this operation.
43func (op Operation) ShortName() string {
44	name := op.Name
45	if len(name) == 0 {
46		return ""
47	}
48	short := strings.ToLower(string(name[0]))
49	idx := 1
50	for {
51		i := strings.IndexFunc(name[idx:], unicode.IsUpper)
52		if i == -1 {
53			break
54		}
55		idx += i
56		short += strings.ToLower(string(name[idx]))
57		idx++
58	}
59	return short
60}
61
62// ResultType returns the type to use as the result of running this operation.
63func (op Operation) ResultType() string {
64	return op.Response.Name
65}
66
67// Title wraps strings.Title for use in templates.
68func (op Operation) Title(name string) string { return strings.Title(name) }
69
70// EscapeDocumentation will add the required // in front of each line of documentation.
71func (Operation) EscapeDocumentation(doc string) string {
72	var slc []string
73	for _, line := range strings.Split(doc, "\n") {
74		slc = append(slc, "// "+line)
75	}
76	return strings.Join(slc, "\n")
77}
78
79// ConstructorParameters builds the parameter names and types for the operation constructor.
80func (op Operation) ConstructorParameters() string {
81	var parameters []string
82	for name, field := range op.Request {
83		if !field.Constructor {
84			continue
85		}
86		parameters = append(parameters, name+" "+field.ParameterType())
87	}
88	return strings.Join(parameters, ", ")
89}
90
91// ConstructorFields returns a slice of name name pairs that set fields in a newly instantiated
92// operation.
93func (op Operation) ConstructorFields() []string {
94	var fields []string
95	for name, field := range op.Request {
96		if !field.Constructor {
97			continue
98		}
99		// either "name: name," or "name: &name,"
100		fieldName := name
101		if field.PointerType() {
102			fieldName = "&" + name
103		}
104		fields = append(fields, name+": "+fieldName+",")
105	}
106	return fields
107}
108
109// CommandMethod returns the code required to transform the operation into a command. This code only
110// returns the contents of the command method, without the function definition and return.
111func (op Operation) CommandMethod() (string, error) {
112	var buf bytes.Buffer
113	switch op.Command.Parameter {
114	case "collection":
115		tmpl := commandCollectionTmpl
116		if op.Command.Database {
117			tmpl = commandCollectionDatabaseTmpl
118		}
119		err := tmpl.Execute(&buf, op)
120		if err != nil {
121			return "", err
122		}
123	case "database":
124		tmpl := commandDatabaseTmpl
125		err := tmpl.Execute(&buf, op)
126		if err != nil {
127			return "", err
128		}
129	default:
130		var tmpl *template.Template
131		field, ok := op.Request[op.Command.Parameter]
132		if !ok {
133			return "", fmt.Errorf(
134				"no request field named '%s' but '%s' is specified as the command parameter",
135				op.Command.Parameter, op.Command.Parameter,
136			)
137		}
138		switch field.Type {
139		case "double":
140			tmpl = commandParamDoubleTmpl
141		case "string":
142			tmpl = commandParamStringTmpl
143		case "document":
144			tmpl = commandParamDocumentTmpl
145		case "array":
146			tmpl = commandParamArrayTmpl
147		case "boolean":
148			tmpl = commandParamBooleanTmpl
149		case "int32":
150			tmpl = commandParamInt32Tmpl
151		case "int64":
152			tmpl = commandParamInt64Tmpl
153		default:
154			return "", fmt.Errorf("unknown request field type %s", field.Type)
155		}
156		var rf struct {
157			ShortName              string
158			Name                   string
159			ParameterName          string
160			MinWireVersion         int
161			MinWireVersionRequired int
162		}
163		rf.ShortName = op.ShortName()
164		rf.Name = op.Command.Parameter
165		rf.ParameterName = op.Command.Name
166		rf.MinWireVersion = field.MinWireVersion
167		rf.MinWireVersionRequired = field.MinWireVersionRequired
168		err := tmpl.Execute(&buf, rf)
169		if err != nil {
170			return "", err
171		}
172	}
173	names := make([]string, 0, len(op.Request))
174	for name := range op.Request {
175		names = append(names, name)
176	}
177	sort.Strings(names)
178	for _, name := range names {
179		field := op.Request[name]
180		if name == op.Properties.Batches || field.Skip {
181			continue
182		}
183		var tmpl *template.Template
184		switch field.Type {
185		case "double":
186			tmpl = commandParamDoubleTmpl
187		case "string":
188			tmpl = commandParamStringTmpl
189		case "document":
190			tmpl = commandParamDocumentTmpl
191		case "array":
192			tmpl = commandParamArrayTmpl
193		case "boolean":
194			tmpl = commandParamBooleanTmpl
195		case "int32":
196			tmpl = commandParamInt32Tmpl
197		case "int64":
198			tmpl = commandParamInt64Tmpl
199		case "value":
200			tmpl = commandParamValueTmpl
201		default:
202			return "", fmt.Errorf("unknown request field type %s", field.Type)
203		}
204		var rf struct {
205			ShortName              string
206			Name                   string
207			ParameterName          string
208			MinWireVersion         int
209			MinWireVersionRequired int
210		}
211		rf.ShortName = op.ShortName()
212		rf.Name = name
213		rf.ParameterName = name
214		if field.KeyName != "" {
215			rf.ParameterName = field.KeyName
216		}
217		rf.MinWireVersion = field.MinWireVersion
218		rf.MinWireVersionRequired = field.MinWireVersionRequired
219		err := tmpl.Execute(&buf, rf)
220		if err != nil {
221			return "", err
222		}
223
224	}
225	return buf.String(), nil
226}
227
// Properties represents general properties of the operation, as read from its TOML file.
type Properties struct {
	// Disabled lists default built-ins to remove; see Builtins.
	Disabled                       []Builtin
	// Enabled lists non-default built-ins to add. Builtins ignores defaults listed here.
	Enabled                        []Builtin
	// Retryable configures the operation's retryability mode and type.
	Retryable                      Retryable
	// Batches is the name of the batched request field. CommandMethod does not emit it as a
	// regular command field. NOTE(review): presumably the driver splits it across batches — confirm.
	Batches                        string
	// Legacy selects a legacy operation kind; LegacyOperationKind maps it to a driver constant.
	Legacy                         LegacyOperation
	// NOTE(review): the two minimum wire versions below are not read in this file; presumably
	// they gate sending write/read concerns to older servers — confirm against the templates.
	MinimumWriteConcernWireVersion int
	MinimumReadConcernWireVersion  int
}
238
239// Builtins returns a slice of built-ins that is the combination of the non-disabled default
240// built-ins plus any enabled non-default built-ins.
241func (p Properties) Builtins() []Builtin {
242	defaults := map[Builtin]struct{}{
243		Deployment:     {},
244		Database:       {},
245		Selector:       {},
246		CommandMonitor: {},
247		ClientSession:  {},
248		ClusterClock:   {},
249		Collection:     {},
250		Crypt:          {},
251	}
252	for _, builtin := range p.Disabled {
253		delete(defaults, builtin)
254	}
255	builtins := make([]Builtin, 0, len(defaults)+len(p.Enabled))
256	// We don't do this in a loop because we want them to be in a stable order.
257	if _, ok := defaults[Deployment]; ok {
258		builtins = append(builtins, Deployment)
259	}
260	if _, ok := defaults[Database]; ok {
261		builtins = append(builtins, Database)
262	}
263	if _, ok := defaults[Selector]; ok {
264		builtins = append(builtins, Selector)
265	}
266	if _, ok := defaults[CommandMonitor]; ok {
267		builtins = append(builtins, CommandMonitor)
268	}
269	if _, ok := defaults[ClientSession]; ok {
270		builtins = append(builtins, ClientSession)
271	}
272	if _, ok := defaults[ClusterClock]; ok {
273		builtins = append(builtins, ClusterClock)
274	}
275	if _, ok := defaults[Collection]; ok {
276		builtins = append(builtins, Collection)
277	}
278	if _, ok := defaults[Crypt]; ok {
279		builtins = append(builtins, Crypt)
280	}
281	for _, builtin := range p.Enabled {
282		switch builtin {
283		case Deployment, Database, Selector, CommandMonitor, ClientSession, ClusterClock, Collection, Crypt:
284			continue // If someone added a default to enable, just ignore it.
285		}
286		builtins = append(builtins, builtin)
287	}
288	sort.Slice(builtins, func(i, j int) bool { return builtins[i] < builtins[j] })
289	return builtins
290}
291
292// ExecuteBuiltins returns the builtins that need to be set on the driver.Operation for the
293// properties set.
294func (p Properties) ExecuteBuiltins() []Builtin {
295	builtins := p.Builtins()
296	fields := make([]Builtin, 0, len(builtins))
297	for _, builtin := range builtins {
298		if builtin == Collection {
299			continue // We don't include this in execute.
300		}
301		fields = append(fields, builtin)
302	}
303	return fields
304}
305
306// IsEnabled returns a Builtin if the string that matches that built-in is enabled. If it's not, an
307// empty string is returned.
308func (p Properties) IsEnabled(builtin string) Builtin {
309	m := p.BuiltinsMap()
310	if b := m[Builtin(builtin)]; b {
311		return Builtin(builtin)
312	}
313	return ""
314}
315
316// BuiltinsMap returns a map with the builtins that enabled.
317func (p Properties) BuiltinsMap() map[Builtin]bool {
318	builtins := make(map[Builtin]bool)
319	for _, builtin := range p.Builtins() {
320		builtins[builtin] = true
321	}
322	return builtins
323}
324
325// LegacyOperationKind returns the corresponding LegacyOperationKind value for an operation.
326func (p Properties) LegacyOperationKind() string {
327	switch p.Legacy {
328	case LegacyFind:
329		return "driver.LegacyFind"
330	case LegacyGetMore:
331		return "driver.LegacyGetMore"
332	case LegacyKillCursors:
333		return "driver.LegacyKillCursors"
334	case LegacyListCollections:
335		return "driver.LegacyListCollections"
336	case LegacyListIndexes:
337		return "driver.LegacyListIndexes"
338	default:
339		return "driver.LegacyNone"
340	}
341}
342
// Retryable represents retryable information for an operation. Mode holds one of the
// RetryableMode values and Type one of the RetryableType values; both are read from the
// operation's TOML file.
type Retryable struct {
	Mode RetryableMode
	Type RetryableType
}
348
// RetryableMode is the configuration representation of a retryability mode. Its values are the
// strings written in operation TOML files.
type RetryableMode string

// These constants are the various retryability modes.
// NOTE(review): their runtime meaning is defined by the generated code/templates — confirm there.
const (
	RetryableOnce           RetryableMode = "once"
	RetryableOncePerCommand RetryableMode = "once per command"
	RetryableContext        RetryableMode = "context"
)
358
// RetryableType instances are the configuration representation of a kind of retryability. Its
// values are the strings written in operation TOML files.
type RetryableType string

// These constants are the various retryable types: retryable writes and retryable reads.
const (
	RetryableWrites RetryableType = "writes"
	RetryableReads  RetryableType = "reads"
)
367
// LegacyOperation enables legacy versions of the find, getMore, killCursors, listCollections, or
// listIndexes operations. Properties.LegacyOperationKind maps each value to the matching
// driver.Legacy* constant name.
type LegacyOperation string

// These constants are the various legacy operations that can be generated.
const (
	LegacyFind            LegacyOperation = "find"
	LegacyGetMore         LegacyOperation = "getMore"
	LegacyKillCursors     LegacyOperation = "killCursors"
	LegacyListCollections LegacyOperation = "listCollections"
	LegacyListIndexes     LegacyOperation = "listIndexes"
)
379
// Builtin represents types that are built into the IDL. Each value is the string used to name
// the built-in in the Enabled and Disabled lists of an operation's Properties.
type Builtin string

// These constants are the built-in types. Properties.Builtins treats Deployment, Database,
// Selector, CommandMonitor, ClientSession, ClusterClock, Collection, and Crypt as defaults;
// ReadPreference, ReadConcern, and WriteConcern are only used when listed in Enabled.
const (
	Collection     Builtin = "collection"
	ReadPreference Builtin = "read preference"
	ReadConcern    Builtin = "read concern"
	WriteConcern   Builtin = "write concern"
	CommandMonitor Builtin = "command monitor"
	ClientSession  Builtin = "client session"
	ClusterClock   Builtin = "cluster clock"
	Selector       Builtin = "selector"
	Database       Builtin = "database"
	Deployment     Builtin = "deployment"
	Crypt          Builtin = "crypt"
)
397
398// ExecuteName provides the name used when setting this built-in on a driver.Operation.
399func (b Builtin) ExecuteName() string {
400	var execname string
401	switch b {
402	case ReadPreference:
403		execname = "ReadPreference"
404	case ReadConcern:
405		execname = "ReadConcern"
406	case WriteConcern:
407		execname = "WriteConcern"
408	case CommandMonitor:
409		execname = "CommandMonitor"
410	case ClientSession:
411		execname = "Client"
412	case ClusterClock:
413		execname = "Clock"
414	case Selector:
415		execname = "Selector"
416	case Database:
417		execname = "Database"
418	case Deployment:
419		execname = "Deployment"
420	case Crypt:
421		execname = "Crypt"
422	}
423	return execname
424}
425
426// ReferenceName returns the short name used to refer to this built in. It is used as the field name
427// in the struct and as the variable name for the setter.
428func (b Builtin) ReferenceName() string {
429	var refname string
430	switch b {
431	case Collection:
432		refname = "collection"
433	case ReadPreference:
434		refname = "readPreference"
435	case ReadConcern:
436		refname = "readConcern"
437	case WriteConcern:
438		refname = "writeConcern"
439	case CommandMonitor:
440		refname = "monitor"
441	case ClientSession:
442		refname = "session"
443	case ClusterClock:
444		refname = "clock"
445	case Selector:
446		refname = "selector"
447	case Database:
448		refname = "database"
449	case Deployment:
450		refname = "deployment"
451	case Crypt:
452		refname = "crypt"
453	}
454	return refname
455}
456
457// SetterName returns the name to be used when creating a setter for this built-in.
458func (b Builtin) SetterName() string {
459	var setter string
460	switch b {
461	case Collection:
462		setter = "Collection"
463	case ReadPreference:
464		setter = "ReadPreference"
465	case ReadConcern:
466		setter = "ReadConcern"
467	case WriteConcern:
468		setter = "WriteConcern"
469	case CommandMonitor:
470		setter = "CommandMonitor"
471	case ClientSession:
472		setter = "Session"
473	case ClusterClock:
474		setter = "ClusterClock"
475	case Selector:
476		setter = "ServerSelector"
477	case Database:
478		setter = "Database"
479	case Deployment:
480		setter = "Deployment"
481	case Crypt:
482		setter = "Crypt"
483	}
484	return setter
485}
486
487// Type returns the Go type for this built-in.
488func (b Builtin) Type() string {
489	var t string
490	switch b {
491	case Collection:
492		t = "string"
493	case ReadPreference:
494		t = "*readpref.ReadPref"
495	case ReadConcern:
496		t = "*readconcern.ReadConcern"
497	case WriteConcern:
498		t = "*writeconcern.WriteConcern"
499	case CommandMonitor:
500		t = "*event.CommandMonitor"
501	case ClientSession:
502		t = "*session.Client"
503	case ClusterClock:
504		t = "*session.ClusterClock"
505	case Selector:
506		t = "description.ServerSelector"
507	case Database:
508		t = "string"
509	case Deployment:
510		t = "driver.Deployment"
511	case Crypt:
512		t = "*driver.Crypt"
513	}
514	return t
515}
516
517// Documentation returns the GoDoc documentation for this built-in.
518func (b Builtin) Documentation() string {
519	var doc string
520	switch b {
521	case Collection:
522		doc = "Collection sets the collection that this command will run against."
523	case ReadPreference:
524		doc = "ReadPreference set the read prefernce used with this operation."
525	case ReadConcern:
526		doc = "ReadConcern specifies the read concern for this operation."
527	case WriteConcern:
528		doc = "WriteConcern sets the write concern for this operation."
529	case CommandMonitor:
530		doc = "CommandMonitor sets the monitor to use for APM events."
531	case ClientSession:
532		doc = "Session sets the session for this operation."
533	case ClusterClock:
534		doc = "ClusterClock sets the cluster clock for this operation."
535	case Selector:
536		doc = "ServerSelector sets the selector used to retrieve a server."
537	case Database:
538		doc = "Database sets the database to run this operation against."
539	case Deployment:
540		doc = "Deployment sets the deployment to use for this operation."
541	case Crypt:
542		doc = "Crypt sets the Crypt object to use for automatic encryption and decryption."
543	}
544	return doc
545}
546
// Command holds the command serialization specific information for an operation.
type Command struct {
	// Name is the command's name. In CommandMethod, when Parameter names a request field, Name is
	// passed to the template as the parameter (key) name.
	Name      string
	// Parameter is "collection", "database", or the name of a request field. It selects which
	// value is paired with the command name.
	Parameter string
	// Database, when Parameter is "collection", selects commandCollectionDatabaseTmpl instead
	// of commandCollectionTmpl.
	Database  bool
}
553
// RequestField represents an individual operation field.
type RequestField struct {
	// Type is the field's type name: one of "double", "string", "document", "array", "binary",
	// "boolean", "int32", "int64", or "value" (see ParameterType and DeclarationType).
	Type                   string
	// Slice makes the field's Go type a slice.
	Slice                  bool
	// Constructor includes the field as a parameter of the generated constructor.
	Constructor            bool
	// Variadic makes the parameter type variadic ("...T"); in ParameterType it takes precedence
	// over Slice.
	Variadic               bool
	// Skip omits the field from the command body built by CommandMethod.
	Skip                   bool
	// Documentation is the field's doc text.
	Documentation          string
	// MinWireVersion and MinWireVersionRequired are passed through to the command templates.
	// NOTE(review): their exact gating semantics live in those templates — confirm there.
	MinWireVersion         int
	MinWireVersionRequired int
	// KeyName overrides the key used for this field in the command; empty means the field name.
	KeyName                string
}
566
567// Command returns a string function that sets the key to name and value to the RequestField type.
568// It uses accessor to access the parameter. The accessor parameter should be the shortname of the
569// operation and the name of the field of the property used for the command name. For example, if
570// the shortname is "eo" and the field is "collection" then accessor should be "eo.collection".
571func (rf RequestField) Command(name, accessor string) string {
572	return ""
573}
574
575// ParameterType returns this field's type for use as a parameter argument.
576func (rf RequestField) ParameterType() string {
577	var param string
578	if rf.Slice && !rf.Variadic {
579		param = "[]"
580	}
581	if rf.Variadic {
582		param = "..."
583	}
584	switch rf.Type {
585	case "double":
586		param += "float64"
587	case "string":
588		param += "string"
589	case "document", "array":
590		param += "bsoncore.Document"
591	case "binary":
592	case "boolean":
593		param += "bool"
594	case "int32":
595		param += "int32"
596	case "int64":
597		param += "int64"
598	case "value":
599		param += "bsoncore.Value"
600	}
601	return param
602}
603
604// DeclarationType returns this field's type for use in a struct type declaration.
605func (rf RequestField) DeclarationType() string {
606	var decl string
607	switch rf.Type {
608	case "double", "string", "boolean", "int32", "int64":
609		decl = "*"
610	}
611	if rf.Slice {
612		decl = "[]"
613	}
614	switch rf.Type {
615	case "double":
616		decl += "float64"
617	case "string":
618		decl += "string"
619	case "document", "array":
620		decl += "bsoncore.Document"
621	case "binary":
622	case "boolean":
623		decl += "bool"
624	case "int32":
625		decl += "int32"
626	case "int64":
627		decl += "int64"
628	case "value":
629		decl += "bsoncore.Value"
630	}
631	return decl
632}
633
634// PointerType returns true if the request field is a pointer type and the setter should take the
635// address when setting via a setter method.
636func (rf RequestField) PointerType() bool {
637	switch rf.Type {
638	case "double", "string", "boolean", "int32", "int64":
639		return true
640	default:
641		return false
642	}
643}
644
// Response represents a response type to generate.
type Response struct {
	// Name is the Go type name of the response. Operation.ResultType returns it and ShortName
	// derives the builder's receiver from it.
	Name  string
	// Type is the response kind. NOTE(review): presumably a BuiltinResponseType such as
	// "batch cursor" when set — confirm against the templates.
	Type  string
	// Field maps each BSON response key to its field definition. BuildMethod uses the
	// title-cased key as the Go field name.
	Field map[string]ResponseField
}
651
// ResponseField is an individual field of a response.
type ResponseField struct {
	// Type is the field's type name. BuildMethod supports "boolean", "int32", "int64",
	// "string", "value", and "document".
	Type          string
	// Documentation is the field's doc text.
	Documentation string
}
657
658// DeclarationType returns the field's type for use in a struct type declaration.
659func (rf ResponseField) DeclarationType() string {
660	switch rf.Type {
661	case "boolean":
662		return "bool"
663	case "value":
664		return "bsoncore.Value"
665	default:
666		return rf.Type
667	}
668}
669
// BuiltinResponseType is the type used to define built in response types.
// NOTE(review): presumably compared against Response.Type by the templates — confirm there.
type BuiltinResponseType string

// These constants represent the different built in response types.
const (
	BatchCursor BuiltinResponseType = "batch cursor"
)
677
678// BuildMethod handles creating the body of a method to create a response from a BSON response
679// document.
680//
681// TODO(GODRIVER-1094): This method is hacky because we're not using nested templates like we should
682// be. Each template should be registered and we should be calling the template to create it.
683func (r Response) BuildMethod() (string, error) {
684	var buf bytes.Buffer
685	names := make([]string, 0, len(r.Field))
686	for name := range r.Field {
687		names = append(names, name)
688	}
689	sort.Strings(names)
690	for _, name := range names {
691		field := r.Field[name]
692		var tmpl *template.Template
693		switch field.Type {
694		case "boolean":
695			tmpl = responseFieldBooleanTmpl
696		case "int32":
697			tmpl = responseFieldInt32Tmpl
698		case "int64":
699			tmpl = responseFieldInt64Tmpl
700		case "string":
701			tmpl = responseFieldStringTmpl
702		case "value":
703			tmpl = responseFieldValueTmpl
704		case "document":
705			tmpl = responseFieldDocumentTmpl
706		default:
707			return "", fmt.Errorf("unknown response field type %s", field.Type)
708		}
709		var rf struct {
710			ResponseName      string // Key of the BSON response.
711			ResponseShortName string // Receiver for the type being built.
712			Field             string // Name of the Go response type field.
713		}
714		rf.ResponseShortName = r.ShortName()
715		rf.ResponseName = name
716		rf.Field = strings.Title(name)
717		err := tmpl.Execute(&buf, rf)
718		if err != nil {
719			return "", err
720		}
721
722	}
723	return buf.String(), nil
724}
725
726// ShortName returns the short name used when constructing a response.
727func (r Response) ShortName() string {
728	name := r.Name
729	if len(name) == 0 {
730		return ""
731	}
732	short := strings.ToLower(string(name[0]))
733	idx := 1
734	for {
735		i := strings.IndexFunc(name[idx:], unicode.IsUpper)
736		if i == -1 {
737			break
738		}
739		idx += i
740		short += strings.ToLower(string(name[idx]))
741		idx++
742	}
743	return short
744}
745
746// ParseFile will construct an Operation using the TOML in filename. The Operation will have the
747// package name set to packagename.
748func ParseFile(filename, packagename string) (Operation, error) {
749	tree, err := toml.LoadFile(filename)
750	if err != nil {
751		return Operation{}, err
752	}
753	var op Operation
754	err = tree.Unmarshal(&op)
755	op.pkg = packagename
756	return op, err
757}
758