1package env
2
3import (
4	"encoding"
5	"errors"
6	"fmt"
7	"os"
8	"reflect"
9	"strconv"
10	"strings"
11	"time"
12)
13
// Sentinel errors reported by this package.
var (
	// ErrNotAStructPtr is returned when the value given to Parse or
	// ParseWithFuncs is not a pointer to a struct.
	ErrNotAStructPtr = errors.New("Expected a pointer to a Struct")
	// ErrUnsupportedType is returned when a field's type has no built-in
	// parser and does not implement encoding.TextUnmarshaler.
	ErrUnsupportedType = errors.New("Type is not supported")
	// ErrUnsupportedSliceType is returned when a slice field's element type
	// is not supported.
	ErrUnsupportedSliceType = errors.New("Unsupported slice type")
)

// OnEnvVarSet is an optional callback, for example for logging. When non-nil
// it is invoked after a field has been set successfully, with that field and
// the raw string value that was applied to it.
var OnEnvVarSet func(reflect.StructField, string)

// Reflect types of the slice fields that have built-in element parsers.
var (
	sliceOfInts      = reflect.TypeOf([]int{})
	sliceOfInt64s    = reflect.TypeOf([]int64{})
	sliceOfUint64s   = reflect.TypeOf([]uint64{})
	sliceOfStrings   = reflect.TypeOf([]string{})
	sliceOfBools     = reflect.TypeOf([]bool{})
	sliceOfFloat32s  = reflect.TypeOf([]float32{})
	sliceOfFloat64s  = reflect.TypeOf([]float64{})
	sliceOfDurations = reflect.TypeOf([]time.Duration{})
)
35
// Hooks for user-supplied parsing, accepted by ParseWithFuncs.
type (
	// CustomParsers maps a field type to the ParserFunc used for fields of
	// exactly that type. A registered parser takes precedence over the
	// built-in parsers. This is the type that `ParseWithFuncs()` accepts.
	CustomParsers map[reflect.Type]ParserFunc

	// ParserFunc converts the raw environment string v into a value for a
	// field of the type it is registered under in CustomParsers. The value
	// returned must be assignable to that type.
	ParserFunc func(v string) (interface{}, error)
)
41
42// Parse parses a struct containing `env` tags and loads its values from
43// environment variables.
44func Parse(v interface{}) error {
45	ptrRef := reflect.ValueOf(v)
46	if ptrRef.Kind() != reflect.Ptr {
47		return ErrNotAStructPtr
48	}
49	ref := ptrRef.Elem()
50	if ref.Kind() != reflect.Struct {
51		return ErrNotAStructPtr
52	}
53	return doParse(ref, make(map[reflect.Type]ParserFunc, 0))
54}
55
56// ParseWithFuncs is the same as `Parse` except it also allows the user to pass
57// in custom parsers.
58func ParseWithFuncs(v interface{}, funcMap CustomParsers) error {
59	ptrRef := reflect.ValueOf(v)
60	if ptrRef.Kind() != reflect.Ptr {
61		return ErrNotAStructPtr
62	}
63	ref := ptrRef.Elem()
64	if ref.Kind() != reflect.Struct {
65		return ErrNotAStructPtr
66	}
67	return doParse(ref, funcMap)
68}
69
70func doParse(ref reflect.Value, funcMap CustomParsers) error {
71	refType := ref.Type()
72	var errorList []string
73
74	for i := 0; i < refType.NumField(); i++ {
75		refField := ref.Field(i)
76		if reflect.Ptr == refField.Kind() && !refField.IsNil() && refField.CanSet() {
77			err := Parse(refField.Interface())
78			if nil != err {
79				return err
80			}
81			continue
82		}
83		refTypeField := refType.Field(i)
84		value, err := get(refTypeField)
85		if err != nil {
86			errorList = append(errorList, err.Error())
87			continue
88		}
89		if value == "" {
90			continue
91		}
92		if err := set(refField, refTypeField, value, funcMap); err != nil {
93			errorList = append(errorList, err.Error())
94			continue
95		}
96		if OnEnvVarSet != nil {
97			OnEnvVarSet(refTypeField, value)
98		}
99	}
100	if len(errorList) == 0 {
101		return nil
102	}
103	return errors.New(strings.Join(errorList, ". "))
104}
105
106func get(field reflect.StructField) (string, error) {
107	var (
108		val string
109		err error
110	)
111
112	key, opts := parseKeyForOption(field.Tag.Get("env"))
113
114	defaultValue := field.Tag.Get("envDefault")
115	val = getOr(key, defaultValue)
116
117	expandVar := field.Tag.Get("envExpand")
118	if strings.ToLower(expandVar) == "true" {
119		val = os.ExpandEnv(val)
120	}
121
122	if len(opts) > 0 {
123		for _, opt := range opts {
124			// The only option supported is "required".
125			switch opt {
126			case "":
127				break
128			case "required":
129				val, err = getRequired(key)
130			default:
131				err = fmt.Errorf("env tag option %q not supported", opt)
132			}
133		}
134	}
135
136	return val, err
137}
138
// parseKeyForOption splits an `env` tag value on commas. The first element is
// the environment variable name and the rest are options, which may be empty
// strings. An empty tag yields an empty name and no options.
func parseKeyForOption(key string) (string, []string) {
	fields := strings.Split(key, ",")
	name, options := fields[0], fields[1:]
	return name, options
}
144
// getRequired returns the value of the environment variable key. A variable
// that is present but empty counts as set. When key is absent it returns an
// error naming the variable.
func getRequired(key string) (string, error) {
	value, found := os.LookupEnv(key)
	if !found {
		return "", fmt.Errorf("required environment variable %q is not set", key)
	}
	return value, nil
}
151
// getOr returns the value of the environment variable key when it is present
// (even if empty), and defaultValue otherwise.
func getOr(key, defaultValue string) string {
	if value, present := os.LookupEnv(key); present {
		return value
	}
	return defaultValue
}
159
160func set(field reflect.Value, refType reflect.StructField, value string, funcMap CustomParsers) error {
161	// use custom parser if configured for this type
162	parserFunc, ok := funcMap[refType.Type]
163	if ok {
164		val, err := parserFunc(value)
165		if err != nil {
166			return fmt.Errorf("Custom parser error: %v", err)
167		}
168		field.Set(reflect.ValueOf(val))
169		return nil
170	}
171
172	// fall back to built-in parsers
173	switch field.Kind() {
174	case reflect.Slice:
175		separator := refType.Tag.Get("envSeparator")
176		return handleSlice(field, value, separator)
177	case reflect.String:
178		field.SetString(value)
179	case reflect.Bool:
180		bvalue, err := strconv.ParseBool(value)
181		if err != nil {
182			return err
183		}
184		field.SetBool(bvalue)
185	case reflect.Int:
186		intValue, err := strconv.ParseInt(value, 10, 32)
187		if err != nil {
188			return err
189		}
190		field.SetInt(intValue)
191	case reflect.Uint:
192		uintValue, err := strconv.ParseUint(value, 10, 32)
193		if err != nil {
194			return err
195		}
196		field.SetUint(uintValue)
197	case reflect.Float32:
198		v, err := strconv.ParseFloat(value, 32)
199		if err != nil {
200			return err
201		}
202		field.SetFloat(v)
203	case reflect.Float64:
204		v, err := strconv.ParseFloat(value, 64)
205		if err != nil {
206			return err
207		}
208		field.Set(reflect.ValueOf(v))
209	case reflect.Int64:
210		if refType.Type.String() == "time.Duration" {
211			dValue, err := time.ParseDuration(value)
212			if err != nil {
213				return err
214			}
215			field.Set(reflect.ValueOf(dValue))
216		} else {
217			intValue, err := strconv.ParseInt(value, 10, 64)
218			if err != nil {
219				return err
220			}
221			field.SetInt(intValue)
222		}
223	case reflect.Uint64:
224		uintValue, err := strconv.ParseUint(value, 10, 64)
225		if err != nil {
226			return err
227		}
228		field.SetUint(uintValue)
229	default:
230		return handleTextUnmarshaler(field, value)
231	}
232	return nil
233}
234
235func handleSlice(field reflect.Value, value, separator string) error {
236	if separator == "" {
237		separator = ","
238	}
239
240	splitData := strings.Split(value, separator)
241
242	switch field.Type() {
243	case sliceOfStrings:
244		field.Set(reflect.ValueOf(splitData))
245	case sliceOfInts:
246		intData, err := parseInts(splitData)
247		if err != nil {
248			return err
249		}
250		field.Set(reflect.ValueOf(intData))
251	case sliceOfInt64s:
252		int64Data, err := parseInt64s(splitData)
253		if err != nil {
254			return err
255		}
256		field.Set(reflect.ValueOf(int64Data))
257	case sliceOfUint64s:
258		uint64Data, err := parseUint64s(splitData)
259		if err != nil {
260			return err
261		}
262		field.Set(reflect.ValueOf(uint64Data))
263	case sliceOfFloat32s:
264		data, err := parseFloat32s(splitData)
265		if err != nil {
266			return err
267		}
268		field.Set(reflect.ValueOf(data))
269	case sliceOfFloat64s:
270		data, err := parseFloat64s(splitData)
271		if err != nil {
272			return err
273		}
274		field.Set(reflect.ValueOf(data))
275	case sliceOfBools:
276		boolData, err := parseBools(splitData)
277		if err != nil {
278			return err
279		}
280		field.Set(reflect.ValueOf(boolData))
281	case sliceOfDurations:
282		durationData, err := parseDurations(splitData)
283		if err != nil {
284			return err
285		}
286		field.Set(reflect.ValueOf(durationData))
287	default:
288		elemType := field.Type().Elem()
289		// Ensure we test *type as we can always address elements in a slice.
290		if elemType.Kind() == reflect.Ptr {
291			elemType = elemType.Elem()
292		}
293		if _, ok := reflect.New(elemType).Interface().(encoding.TextUnmarshaler); !ok {
294			return ErrUnsupportedSliceType
295		}
296		return parseTextUnmarshalers(field, splitData)
297
298	}
299	return nil
300}
301
302func handleTextUnmarshaler(field reflect.Value, value string) error {
303	if reflect.Ptr == field.Kind() {
304		if field.IsNil() {
305			field.Set(reflect.New(field.Type().Elem()))
306		}
307	} else if field.CanAddr() {
308		field = field.Addr()
309	}
310
311	tm, ok := field.Interface().(encoding.TextUnmarshaler)
312	if !ok {
313		return ErrUnsupportedType
314	}
315
316	return tm.UnmarshalText([]byte(value))
317}
318
// parseInts converts each base-10 string in data to an int. Values must fit
// in 32 bits, so every accepted value fits in int on any platform. The first
// invalid element aborts the conversion and its error is returned.
func parseInts(data []string) ([]int, error) {
	result := make([]int, len(data))
	for i, s := range data {
		n, err := strconv.ParseInt(s, 10, 32)
		if err != nil {
			return nil, err
		}
		result[i] = int(n)
	}
	return result, nil
}
331
// parseInt64s converts each base-10 string in data to an int64. The first
// invalid or out-of-range element aborts the conversion and its error is
// returned.
func parseInt64s(data []string) ([]int64, error) {
	out := make([]int64, len(data))
	for i := range data {
		var err error
		if out[i], err = strconv.ParseInt(data[i], 10, 64); err != nil {
			return nil, err
		}
	}
	return out, nil
}
344
// parseUint64s converts each base-10 string in data to a uint64. The first
// invalid or out-of-range element (including negative numbers) aborts the
// conversion and its error is returned.
//
// The result is pre-sized like the sibling parsers. It is therefore always a
// non-nil slice on success, even for empty input.
func parseUint64s(data []string) ([]uint64, error) {
	uintSlice := make([]uint64, 0, len(data))

	for _, v := range data {
		uintValue, err := strconv.ParseUint(v, 10, 64)
		if err != nil {
			return nil, err
		}
		uintSlice = append(uintSlice, uintValue)
	}
	return uintSlice, nil
}
357
// parseFloat32s converts each string in data to a float32, using 32-bit
// precision and range. The first invalid or out-of-range element aborts the
// conversion and its error is returned.
func parseFloat32s(data []string) ([]float32, error) {
	floats := make([]float32, len(data))
	for i, text := range data {
		f, err := strconv.ParseFloat(text, 32)
		if err != nil {
			return nil, err
		}
		floats[i] = float32(f)
	}
	return floats, nil
}
370
// parseFloat64s converts each string in data to a float64. The first invalid
// or out-of-range element aborts the conversion and its error is returned.
func parseFloat64s(data []string) ([]float64, error) {
	values := make([]float64, len(data))
	for i := range data {
		f, err := strconv.ParseFloat(data[i], 64)
		if err != nil {
			return nil, err
		}
		values[i] = f
	}
	return values, nil
}
383
// parseBools converts each string in data to a bool. Accepted forms are those
// of strconv.ParseBool ("1", "t", "true", "0", "f", "false", ...). The first
// invalid element aborts the conversion and its error is returned.
func parseBools(data []string) ([]bool, error) {
	flags := make([]bool, len(data))
	for i, s := range data {
		b, err := strconv.ParseBool(s)
		if err != nil {
			return nil, err
		}
		flags[i] = b
	}
	return flags, nil
}
397
// parseDurations converts each string in data with time.ParseDuration
// (e.g. "1s", "1500ms"). The first invalid element aborts the conversion and
// its error is returned.
func parseDurations(data []string) ([]time.Duration, error) {
	durations := make([]time.Duration, len(data))
	for i := range data {
		d, err := time.ParseDuration(data[i])
		if err != nil {
			return nil, err
		}
		durations[i] = d
	}
	return durations, nil
}
411
// parseTextUnmarshalers builds a new slice of the field's element type. It
// calls UnmarshalText for each item of data, then assigns the slice to field.
//
// Value elements T are decoded through their address. Pointer elements *T get
// a freshly allocated T per item. The caller must already have checked that
// *T implements encoding.TextUnmarshaler; the type assertion below panics
// otherwise. On the first decoding error, that error is returned and field is
// left unchanged.
func parseTextUnmarshalers(field reflect.Value, data []string) error {
	elemType := field.Type().Elem()
	isPtr := elemType.Kind() == reflect.Ptr
	out := reflect.MakeSlice(reflect.SliceOf(elemType), len(data), len(data))

	for i := range data {
		slot := out.Index(i)
		var target reflect.Value
		if isPtr {
			target = reflect.New(elemType.Elem())
		} else {
			target = slot.Addr()
		}
		if err := target.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data[i])); err != nil {
			return err
		}
		if isPtr {
			slot.Set(target)
		}
	}

	field.Set(out)
	return nil
}
437