1package main 2 3import ( 4 "go/ast" 5 "go/token" 6 "strings" 7) 8 9type method struct { 10 name *ast.Ident 11 params []arg 12 results []arg 13 structsResolved bool 14} 15 16func (m method) definition(ifc iface) ast.Decl { 17 notImpl := fetchFuncDecl("ExampleEndpoint") 18 19 notImpl.Name = m.name 20 notImpl.Recv = fieldList(ifc.receiver()) 21 scope := scopeWith(notImpl.Recv.List[0].Names[0].Name) 22 notImpl.Type.Params = m.funcParams(scope) 23 notImpl.Type.Results = m.funcResults() 24 25 return notImpl 26} 27 28func (m method) endpointMaker(ifc iface) ast.Decl { 29 endpointFn := fetchFuncDecl("makeExampleEndpoint") 30 scope := scopeWith("ctx", "req", ifc.receiverName().Name) 31 32 anonFunc := endpointFn.Body.List[0].(*ast.ReturnStmt).Results[0].(*ast.FuncLit) 33 if !m.hasContext() { 34 // strip context param from endpoint function 35 anonFunc.Type.Params.List = anonFunc.Type.Params.List[1:] 36 } 37 38 anonFunc = replaceIdent(anonFunc, "ExampleRequest", m.requestStructName()).(*ast.FuncLit) 39 callMethod := m.called(ifc, scope, "ctx", "req") 40 anonFunc.Body.List[1] = callMethod 41 anonFunc.Body.List[2].(*ast.ReturnStmt).Results[0] = m.wrapResult(callMethod.Lhs) 42 43 endpointFn.Body.List[0].(*ast.ReturnStmt).Results[0] = anonFunc 44 endpointFn.Name = m.endpointMakerName() 45 endpointFn.Type.Params = fieldList(ifc.receiver()) 46 endpointFn.Type.Results = fieldList(typeField(sel(id("endpoint"), id("Endpoint")))) 47 return endpointFn 48} 49 50func (m method) pathName() string { 51 return "/" + strings.ToLower(m.name.Name) 52} 53 54func (m method) encodeFuncName() *ast.Ident { 55 return id("Encode" + m.name.Name + "Response") 56} 57 58func (m method) decodeFuncName() *ast.Ident { 59 return id("Decode" + m.name.Name + "Request") 60} 61 62func (m method) resultNames(scope *ast.Scope) []*ast.Ident { 63 ids := []*ast.Ident{} 64 for _, rz := range m.results { 65 ids = append(ids, rz.chooseName(scope)) 66 } 67 return ids 68} 69 70func (m method) called(ifc iface, scope *ast.Scope, ctxName, spreadStruct string) *ast.AssignStmt { 71 m.resolveStructNames() 72 73 resNamesExpr := []ast.Expr{} 74 for _, r := range m.resultNames(scope) { 75 resNamesExpr = append(resNamesExpr, ast.Expr(r)) 76 } 77 78 arglist := []ast.Expr{} 79 if m.hasContext() { 80 arglist = append(arglist, id(ctxName)) 81 } 82 ssid := id(spreadStruct) 83 for _, f := range m.requestStructFields().List { 84 arglist = append(arglist, sel(ssid, f.Names[0])) 85 } 86 87 return &ast.AssignStmt{ 88 Lhs: resNamesExpr, 89 Tok: token.DEFINE, 90 Rhs: []ast.Expr{ 91 &ast.CallExpr{ 92 Fun: sel(ifc.receiverName(), m.name), 93 Args: arglist, 94 }, 95 }, 96 } 97} 98 99func (m method) wrapResult(results []ast.Expr) ast.Expr { 100 kvs := []ast.Expr{} 101 m.resolveStructNames() 102 103 for i, a := range m.results { 104 kvs = append(kvs, &ast.KeyValueExpr{ 105 Key: ast.NewIdent(export(a.asField.Name)), 106 Value: results[i], 107 }) 108 } 109 return &ast.CompositeLit{ 110 Type: m.responseStructName(), 111 Elts: kvs, 112 } 113} 114 115func (m method) resolveStructNames() { 116 if m.structsResolved { 117 return 118 } 119 m.structsResolved = true 120 scope := ast.NewScope(nil) 121 for i, p := range m.params { 122 p.asField = p.chooseName(scope) 123 m.params[i] = p 124 } 125 scope = ast.NewScope(nil) 126 for i, r := range m.results { 127 r.asField = r.chooseName(scope) 128 m.results[i] = r 129 } 130} 131 132func (m method) decoderFunc() ast.Decl { 133 fn := fetchFuncDecl("DecodeExampleRequest") 134 fn.Name = m.decodeFuncName() 135 fn = replaceIdent(fn, "ExampleRequest", m.requestStructName()).(*ast.FuncDecl) 136 return fn 137} 138 139func (m method) encoderFunc() ast.Decl { 140 fn := fetchFuncDecl("EncodeExampleResponse") 141 fn.Name = m.encodeFuncName() 142 return fn 143} 144 145func (m method) endpointMakerName() *ast.Ident { 146 return id("Make" + m.name.Name + "Endpoint") 147} 148 149func (m method) requestStruct() ast.Decl { 150 m.resolveStructNames() 151 return structDecl(m.requestStructName(), m.requestStructFields()) 152} 153 154func (m method) responseStruct() ast.Decl { 155 m.resolveStructNames() 156 return structDecl(m.responseStructName(), m.responseStructFields()) 157} 158 159func (m method) hasContext() bool { 160 if len(m.params) < 1 { 161 return false 162 } 163 carg := m.params[0].typ 164 // ugh. this is maybe okay for the one-off, but a general case for matching 165 // types would be helpful 166 if sel, is := carg.(*ast.SelectorExpr); is && sel.Sel.Name == "Context" { 167 if id, is := sel.X.(*ast.Ident); is && id.Name == "context" { 168 return true 169 } 170 } 171 return false 172} 173 174func (m method) nonContextParams() []arg { 175 if m.hasContext() { 176 return m.params[1:] 177 } 178 return m.params 179} 180 181func (m method) funcParams(scope *ast.Scope) *ast.FieldList { 182 parms := &ast.FieldList{} 183 if m.hasContext() { 184 parms.List = []*ast.Field{{ 185 Names: []*ast.Ident{ast.NewIdent("ctx")}, 186 Type: sel(id("context"), id("Context")), 187 }} 188 scope.Insert(ast.NewObj(ast.Var, "ctx")) 189 } 190 parms.List = append(parms.List, mappedFieldList(func(a arg) *ast.Field { 191 return a.field(scope) 192 }, m.nonContextParams()...).List...) 193 return parms 194} 195 196func (m method) funcResults() *ast.FieldList { 197 return mappedFieldList(func(a arg) *ast.Field { 198 return a.result() 199 }, m.results...) 200} 201 202func (m method) requestStructName() *ast.Ident { 203 return id(export(m.name.Name) + "Request") 204} 205 206func (m method) requestStructFields() *ast.FieldList { 207 return mappedFieldList(func(a arg) *ast.Field { 208 return a.exported() 209 }, m.nonContextParams()...) 210} 211 212func (m method) responseStructName() *ast.Ident { 213 return id(export(m.name.Name) + "Response") 214} 215 216func (m method) responseStructFields() *ast.FieldList { 217 return mappedFieldList(func(a arg) *ast.Field { 218 return a.exported() 219 }, m.results...) 220} 221