1package awsspec
2
3import (
4	"errors"
5	"fmt"
6	"math/rand"
7	"reflect"
8	"strings"
9	"time"
10
11	"github.com/aws/aws-sdk-go/aws/awserr"
12	"github.com/wallix/awless/template/env"
13	"github.com/wallix/awless/template/params"
14
15	"github.com/fatih/color"
16	"github.com/wallix/awless/cloud"
17	"github.com/wallix/awless/logger"
18)
19
// dryRunOperation is an AWS error code string; presumably compared against
// awserr.Error codes elsewhere in the package — confirm at the call sites.
const dryRunOperation = "DryRunOperation"

// notFound is an AWS error code fragment; presumably matched against
// awserr.Error codes elsewhere in the package — confirm at the call sites.
const notFound = "NotFound"
24
25type BeforeRunner interface {
26	BeforeRun(env.Running) error
27}
28
29type AfterRunner interface {
30	AfterRun(env.Running, interface{}) error
31}
32
// ResultExtractor is implemented by values that can reduce an arbitrary
// value to a single string.
type ResultExtractor interface {
	ExtractResult(output interface{}) string
}
36
37type command interface {
38	ParamsSpec() params.Spec
39	inject(map[string]interface{}) error
40	Run(env.Running, map[string]interface{}) (interface{}, error)
41}
42
43func implementsBeforeRun(i interface{}) (BeforeRunner, bool) {
44	v, ok := i.(BeforeRunner)
45	return v, ok
46}
47
48func implementsAfterRun(i interface{}) (AfterRunner, bool) {
49	v, ok := i.(AfterRunner)
50	return v, ok
51}
52
53func implementsResultExtractor(i interface{}) (ResultExtractor, bool) {
54	v, ok := i.(ResultExtractor)
55	return v, ok
56}
57
58func fakeDryRunId(entity string) string {
59	suffix := rand.Intn(1e6)
60	switch entity {
61	case cloud.Instance:
62		return fmt.Sprintf("i-%d", suffix)
63	case cloud.Subnet:
64		return fmt.Sprintf("subnet-%d", suffix)
65	case cloud.Vpc:
66		return fmt.Sprintf("vpc-%d", suffix)
67	case cloud.Volume:
68		return fmt.Sprintf("vol-%d", suffix)
69	case cloud.SecurityGroup:
70		return fmt.Sprintf("sg-%d", suffix)
71	case cloud.InternetGateway:
72		return fmt.Sprintf("igw-%d", suffix)
73	case cloud.NatGateway:
74		return fmt.Sprintf("nat-%d", suffix)
75	case cloud.RouteTable:
76		return fmt.Sprintf("rtb-%d", suffix)
77	default:
78		return fmt.Sprintf("dryrunid-%d", suffix)
79	}
80}
81
82type awsCall struct {
83	fnName  string
84	fn      interface{}
85	logger  *logger.Logger
86	setters []setter
87}
88
89func (dc *awsCall) execute(input interface{}) (output interface{}, err error) {
90	defer func() {
91		if e := recover(); e != nil {
92			output = nil
93			err = fmt.Errorf("%s", e)
94		}
95	}()
96
97	for _, s := range dc.setters {
98		if err = s.set(input); err != nil {
99			return nil, err
100		}
101	}
102
103	fnVal := reflect.ValueOf(dc.fn)
104	values := []reflect.Value{reflect.ValueOf(input)}
105
106	start := time.Now()
107	results := fnVal.Call(values)
108
109	if err, ok := results[1].Interface().(error); ok && err != nil {
110		return nil, fmt.Errorf("%s", err)
111	}
112
113	dc.logger.ExtraVerbosef("%s call took %s", dc.fnName, time.Since(start))
114
115	output = results[0].Interface()
116
117	return
118}
119
120type checker struct {
121	description string
122	timeout     time.Duration
123	frequency   time.Duration
124	fetchFunc   func() (string, error)
125	expect      string
126	logger      *logger.Logger
127	checkName   string
128}
129
130func (c *checker) check() error {
131	now := time.Now().UTC()
132	timer := time.NewTimer(c.timeout)
133	if c.checkName == "" {
134		c.checkName = "status"
135	}
136	defer timer.Stop()
137	defer c.logger.Println()
138	for {
139		select {
140		case <-timer.C:
141			return fmt.Errorf("timeout of %s expired", c.timeout)
142		default:
143		}
144		got, err := c.fetchFunc()
145		if err != nil {
146			return fmt.Errorf("check %s: %s", c.description, err)
147		}
148		if strings.ToLower(got) == strings.ToLower(c.expect) {
149			c.logger.InteractiveInfof("check %s %s '%s' done", c.description, c.checkName, c.expect)
150			return nil
151		}
152		elapsed := time.Since(now)
153		c.logger.InteractiveInfof("%s %s '%s', expect '%s', timeout in %s (retry in %s)", c.description, c.checkName, got, c.expect, color.New(color.FgGreen).Sprint(c.timeout-elapsed.Round(time.Second)), c.frequency)
154		time.Sleep(c.frequency)
155	}
156}
157
158type enumValidator struct {
159	expected []string
160}
161
162func NewEnumValidator(expected ...string) *enumValidator {
163	return &enumValidator{expected: expected}
164}
165
166func (v *enumValidator) Validate(in *string) error {
167	val := strings.ToLower(StringValue(in))
168	for _, e := range v.expected {
169		if val == strings.ToLower(e) {
170			return nil
171		}
172	}
173	var expString string
174	switch len(v.expected) {
175	case 0:
176		return errors.New("empty enumeration")
177	case 1:
178		expString = fmt.Sprintf("'%s'", v.expected[0])
179	case 2:
180		expString = fmt.Sprintf("'%s' or '%s'", v.expected[0], v.expected[1])
181	default:
182		expString = fmt.Sprintf("'%s' or '%s'", strings.Join(v.expected[0:len(v.expected)-1], "', '"), v.expected[len(v.expected)-1])
183	}
184	return fmt.Errorf("invalid value '%s' expect %s", StringValue(in), expString)
185}
186
// String returns a pointer to a fresh copy of v.
func String(v string) *string {
	p := new(string)
	*p = v
	return p
}
190
// StringValue dereferences v, returning "" when v is nil.
func StringValue(v *string) string {
	if v == nil {
		return ""
	}
	return *v
}
197
// Int64 returns a pointer to a fresh copy of v.
func Int64(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
201
// Int64AsIntValue dereferences v and converts the value to int, returning 0
// when v is nil. Where int is 32 bits wide the conversion truncates values
// outside the int range.
func Int64AsIntValue(v *int64) int {
	if v == nil {
		return 0
	}
	return int(*v)
}
208
// Bool returns a pointer to a fresh copy of v.
func Bool(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
212
// BoolValue dereferences v, returning false when v is nil.
func BoolValue(v *bool) bool {
	return v != nil && *v
}
219
220func decorateAWSError(err error) error {
221	if aerr, ok := err.(awserr.Error); ok {
222		return fmt.Errorf("%s: %s", aerr.Code(), aerr.Message())
223	}
224	return err
225}
226