1package main
2
3import (
4	"go/ast"
5	"go/token"
6	"strings"
7)
8
// method captures the name and signature of one interface method and carries
// the helpers (below) that emit the generated Go declarations for it.
type method struct {
	// name is reused verbatim as the generated FuncDecl name and as the
	// selector in the generated call (see definition and called).
	name            *ast.Ident
	// params are the method's parameters; params[0] may be a
	// context.Context, which hasContext detects and nonContextParams drops.
	params          []arg
	// results are the method's results, in declaration order.
	results         []arg
	// structsResolved, when already true on the caller's value, makes
	// resolveStructNames a no-op. resolveStructNames has a value receiver,
	// so it cannot set this flag on the caller's copy.
	structsResolved bool
}
15
16func (m method) definition(ifc iface) ast.Decl {
17	notImpl := fetchFuncDecl("ExampleEndpoint")
18
19	notImpl.Name = m.name
20	notImpl.Recv = fieldList(ifc.receiver())
21	scope := scopeWith(notImpl.Recv.List[0].Names[0].Name)
22	notImpl.Type.Params = m.funcParams(scope)
23	notImpl.Type.Results = m.funcResults()
24
25	return notImpl
26}
27
28func (m method) endpointMaker(ifc iface) ast.Decl {
29	endpointFn := fetchFuncDecl("makeExampleEndpoint")
30	scope := scopeWith("ctx", "req", ifc.receiverName().Name)
31
32	anonFunc := endpointFn.Body.List[0].(*ast.ReturnStmt).Results[0].(*ast.FuncLit)
33	if !m.hasContext() {
34		// strip context param from endpoint function
35		anonFunc.Type.Params.List = anonFunc.Type.Params.List[1:]
36	}
37
38	anonFunc = replaceIdent(anonFunc, "ExampleRequest", m.requestStructName()).(*ast.FuncLit)
39	callMethod := m.called(ifc, scope, "ctx", "req")
40	anonFunc.Body.List[1] = callMethod
41	anonFunc.Body.List[2].(*ast.ReturnStmt).Results[0] = m.wrapResult(callMethod.Lhs)
42
43	endpointFn.Body.List[0].(*ast.ReturnStmt).Results[0] = anonFunc
44	endpointFn.Name = m.endpointMakerName()
45	endpointFn.Type.Params = fieldList(ifc.receiver())
46	endpointFn.Type.Results = fieldList(typeField(sel(id("endpoint"), id("Endpoint"))))
47	return endpointFn
48}
49
50func (m method) pathName() string {
51	return "/" + strings.ToLower(m.name.Name)
52}
53
54func (m method) encodeFuncName() *ast.Ident {
55	return id("Encode" + m.name.Name + "Response")
56}
57
58func (m method) decodeFuncName() *ast.Ident {
59	return id("Decode" + m.name.Name + "Request")
60}
61
62func (m method) resultNames(scope *ast.Scope) []*ast.Ident {
63	ids := []*ast.Ident{}
64	for _, rz := range m.results {
65		ids = append(ids, rz.chooseName(scope))
66	}
67	return ids
68}
69
70func (m method) called(ifc iface, scope *ast.Scope, ctxName, spreadStruct string) *ast.AssignStmt {
71	m.resolveStructNames()
72
73	resNamesExpr := []ast.Expr{}
74	for _, r := range m.resultNames(scope) {
75		resNamesExpr = append(resNamesExpr, ast.Expr(r))
76	}
77
78	arglist := []ast.Expr{}
79	if m.hasContext() {
80		arglist = append(arglist, id(ctxName))
81	}
82	ssid := id(spreadStruct)
83	for _, f := range m.requestStructFields().List {
84		arglist = append(arglist, sel(ssid, f.Names[0]))
85	}
86
87	return &ast.AssignStmt{
88		Lhs: resNamesExpr,
89		Tok: token.DEFINE,
90		Rhs: []ast.Expr{
91			&ast.CallExpr{
92				Fun:  sel(ifc.receiverName(), m.name),
93				Args: arglist,
94			},
95		},
96	}
97}
98
99func (m method) wrapResult(results []ast.Expr) ast.Expr {
100	kvs := []ast.Expr{}
101	m.resolveStructNames()
102
103	for i, a := range m.results {
104		kvs = append(kvs, &ast.KeyValueExpr{
105			Key:   ast.NewIdent(export(a.asField.Name)),
106			Value: results[i],
107		})
108	}
109	return &ast.CompositeLit{
110		Type: m.responseStructName(),
111		Elts: kvs,
112	}
113}
114
115func (m method) resolveStructNames() {
116	if m.structsResolved {
117		return
118	}
119	m.structsResolved = true
120	scope := ast.NewScope(nil)
121	for i, p := range m.params {
122		p.asField = p.chooseName(scope)
123		m.params[i] = p
124	}
125	scope = ast.NewScope(nil)
126	for i, r := range m.results {
127		r.asField = r.chooseName(scope)
128		m.results[i] = r
129	}
130}
131
132func (m method) decoderFunc() ast.Decl {
133	fn := fetchFuncDecl("DecodeExampleRequest")
134	fn.Name = m.decodeFuncName()
135	fn = replaceIdent(fn, "ExampleRequest", m.requestStructName()).(*ast.FuncDecl)
136	return fn
137}
138
139func (m method) encoderFunc() ast.Decl {
140	fn := fetchFuncDecl("EncodeExampleResponse")
141	fn.Name = m.encodeFuncName()
142	return fn
143}
144
145func (m method) endpointMakerName() *ast.Ident {
146	return id("Make" + m.name.Name + "Endpoint")
147}
148
149func (m method) requestStruct() ast.Decl {
150	m.resolveStructNames()
151	return structDecl(m.requestStructName(), m.requestStructFields())
152}
153
154func (m method) responseStruct() ast.Decl {
155	m.resolveStructNames()
156	return structDecl(m.responseStructName(), m.responseStructFields())
157}
158
159func (m method) hasContext() bool {
160	if len(m.params) < 1 {
161		return false
162	}
163	carg := m.params[0].typ
164	// ugh. this is maybe okay for the one-off, but a general case for matching
165	// types would be helpful
166	if sel, is := carg.(*ast.SelectorExpr); is && sel.Sel.Name == "Context" {
167		if id, is := sel.X.(*ast.Ident); is && id.Name == "context" {
168			return true
169		}
170	}
171	return false
172}
173
174func (m method) nonContextParams() []arg {
175	if m.hasContext() {
176		return m.params[1:]
177	}
178	return m.params
179}
180
181func (m method) funcParams(scope *ast.Scope) *ast.FieldList {
182	parms := &ast.FieldList{}
183	if m.hasContext() {
184		parms.List = []*ast.Field{{
185			Names: []*ast.Ident{ast.NewIdent("ctx")},
186			Type:  sel(id("context"), id("Context")),
187		}}
188		scope.Insert(ast.NewObj(ast.Var, "ctx"))
189	}
190	parms.List = append(parms.List, mappedFieldList(func(a arg) *ast.Field {
191		return a.field(scope)
192	}, m.nonContextParams()...).List...)
193	return parms
194}
195
196func (m method) funcResults() *ast.FieldList {
197	return mappedFieldList(func(a arg) *ast.Field {
198		return a.result()
199	}, m.results...)
200}
201
202func (m method) requestStructName() *ast.Ident {
203	return id(export(m.name.Name) + "Request")
204}
205
206func (m method) requestStructFields() *ast.FieldList {
207	return mappedFieldList(func(a arg) *ast.Field {
208		return a.exported()
209	}, m.nonContextParams()...)
210}
211
212func (m method) responseStructName() *ast.Ident {
213	return id(export(m.name.Name) + "Response")
214}
215
216func (m method) responseStructFields() *ast.FieldList {
217	return mappedFieldList(func(a arg) *ast.Field {
218		return a.exported()
219	}, m.results...)
220}
221