1package env
2
3import (
4	"encoding"
5	"errors"
6	"fmt"
7	"net/url"
8	"os"
9	"reflect"
10	"strconv"
11	"strings"
12	"time"
13)
14
// nolint: gochecknoglobals
var (
	// ErrNotAStructPtr is returned if you pass something that is not a pointer to a
	// struct to Parse or ParseWithFuncs.
	ErrNotAStructPtr = errors.New("env: expected a pointer to a Struct")

	// defaultBuiltInParsers parses a raw string by the target's reflect.Kind.
	// Each parser returns the kind's basic Go type (int, uint16, float32, ...).
	defaultBuiltInParsers = map[reflect.Kind]ParserFunc{
		reflect.Bool: func(v string) (interface{}, error) {
			return strconv.ParseBool(v)
		},
		reflect.String: func(v string) (interface{}, error) {
			return v, nil
		},
		reflect.Int: func(v string) (interface{}, error) {
			// bitSize 0 selects the platform's int size; a fixed 32 would
			// reject valid values above 2^31-1 on 64-bit platforms.
			i, err := strconv.ParseInt(v, 10, 0)
			return int(i), err
		},
		reflect.Int16: func(v string) (interface{}, error) {
			i, err := strconv.ParseInt(v, 10, 16)
			return int16(i), err
		},
		reflect.Int32: func(v string) (interface{}, error) {
			i, err := strconv.ParseInt(v, 10, 32)
			return int32(i), err
		},
		reflect.Int64: func(v string) (interface{}, error) {
			return strconv.ParseInt(v, 10, 64)
		},
		reflect.Int8: func(v string) (interface{}, error) {
			i, err := strconv.ParseInt(v, 10, 8)
			return int8(i), err
		},
		reflect.Uint: func(v string) (interface{}, error) {
			// bitSize 0 selects the platform's uint size (see reflect.Int).
			i, err := strconv.ParseUint(v, 10, 0)
			return uint(i), err
		},
		reflect.Uint16: func(v string) (interface{}, error) {
			i, err := strconv.ParseUint(v, 10, 16)
			return uint16(i), err
		},
		reflect.Uint32: func(v string) (interface{}, error) {
			i, err := strconv.ParseUint(v, 10, 32)
			return uint32(i), err
		},
		reflect.Uint64: func(v string) (interface{}, error) {
			return strconv.ParseUint(v, 10, 64)
		},
		reflect.Uint8: func(v string) (interface{}, error) {
			i, err := strconv.ParseUint(v, 10, 8)
			return uint8(i), err
		},
		reflect.Float64: func(v string) (interface{}, error) {
			return strconv.ParseFloat(v, 64)
		},
		reflect.Float32: func(v string) (interface{}, error) {
			f, err := strconv.ParseFloat(v, 32)
			return float32(f), err
		},
	}

	// defaultTypeParsers holds parsers for specific types that the kind-based
	// table cannot express: url.URL is a struct, and time.Duration is an
	// int64 that should accept strings such as "1m30s".
	defaultTypeParsers = map[reflect.Type]ParserFunc{
		reflect.TypeOf(url.URL{}): func(v string) (interface{}, error) {
			u, err := url.Parse(v)
			if err != nil {
				return nil, fmt.Errorf("unable to parse URL: %v", err)
			}
			return *u, nil
		},
		reflect.TypeOf(time.Nanosecond): func(v string) (interface{}, error) {
			d, err := time.ParseDuration(v)
			if err != nil {
				return nil, fmt.Errorf("unable to parse duration: %v", err)
			}
			return d, nil
		},
	}
)

// ParserFunc parses the raw string value of an environment variable into a
// value for a struct field. Custom parsers, keyed by field type, can be passed
// to ParseWithFuncs.
type ParserFunc func(v string) (interface{}, error)
96
97// Parse parses a struct containing `env` tags and loads its values from
98// environment variables.
99func Parse(v interface{}) error {
100	return ParseWithFuncs(v, map[reflect.Type]ParserFunc{})
101}
102
103// ParseWithFuncs is the same as `Parse` except it also allows the user to pass
104// in custom parsers.
105func ParseWithFuncs(v interface{}, funcMap map[reflect.Type]ParserFunc) error {
106	ptrRef := reflect.ValueOf(v)
107	if ptrRef.Kind() != reflect.Ptr {
108		return ErrNotAStructPtr
109	}
110	ref := ptrRef.Elem()
111	if ref.Kind() != reflect.Struct {
112		return ErrNotAStructPtr
113	}
114	var parsers = defaultTypeParsers
115	for k, v := range funcMap {
116		parsers[k] = v
117	}
118	return doParse(ref, parsers)
119}
120
121func doParse(ref reflect.Value, funcMap map[reflect.Type]ParserFunc) error {
122	var refType = ref.Type()
123
124	for i := 0; i < refType.NumField(); i++ {
125		refField := ref.Field(i)
126		if !refField.CanSet() {
127			continue
128		}
129		if reflect.Ptr == refField.Kind() && !refField.IsNil() {
130			err := ParseWithFuncs(refField.Interface(), funcMap)
131			if err != nil {
132				return err
133			}
134			continue
135		}
136		refTypeField := refType.Field(i)
137		value, err := get(refTypeField)
138		if err != nil {
139			return err
140		}
141		if value == "" {
142			if reflect.Struct == refField.Kind() {
143				if err := doParse(refField, funcMap); err != nil {
144					return err
145				}
146			}
147			continue
148		}
149		if err := set(refField, refTypeField, value, funcMap); err != nil {
150			return err
151		}
152	}
153	return nil
154}
155
156func get(field reflect.StructField) (string, error) {
157	var (
158		val string
159		err error
160	)
161
162	key, opts := parseKeyForOption(field.Tag.Get("env"))
163
164	defaultValue := field.Tag.Get("envDefault")
165	val = getOr(key, defaultValue)
166
167	expandVar := field.Tag.Get("envExpand")
168	if strings.ToLower(expandVar) == "true" {
169		val = os.ExpandEnv(val)
170	}
171
172	if len(opts) > 0 {
173		for _, opt := range opts {
174			// The only option supported is "required".
175			switch opt {
176			case "":
177				break
178			case "required":
179				val, err = getRequired(key)
180			default:
181				err = fmt.Errorf("env: tag option %q not supported", opt)
182			}
183		}
184	}
185
186	return val, err
187}
188
// parseKeyForOption splits an `env` tag into the variable name (everything
// before the first comma) and the comma-separated options that follow it.
// When the tag has no options the returned slice is empty but non-nil.
func parseKeyForOption(key string) (string, []string) {
	comma := strings.IndexByte(key, ',')
	if comma < 0 {
		return key, []string{}
	}
	return key[:comma], strings.Split(key[comma+1:], ",")
}
194
// getRequired returns the value of the environment variable key. A variable
// that is set to the empty string counts as present; an unset variable is an
// error.
func getRequired(key string) (string, error) {
	value, found := os.LookupEnv(key)
	if !found {
		return "", fmt.Errorf(`env: required environment variable %q is not set`, key)
	}
	return value, nil
}
201
// getOr returns the value of the environment variable key, or defaultValue
// when the variable is not set at all. A variable set to "" returns "".
func getOr(key, defaultValue string) string {
	if value, present := os.LookupEnv(key); present {
		return value
	}
	return defaultValue
}
209
210func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[reflect.Type]ParserFunc) error {
211	if field.Kind() == reflect.Slice {
212		return handleSlice(field, value, sf, funcMap)
213	}
214
215	var tm = asTextUnmarshaler(field)
216	if tm != nil {
217		var err = tm.UnmarshalText([]byte(value))
218		return newParseError(sf, err)
219	}
220
221	var typee = sf.Type
222	var fieldee = field
223	if typee.Kind() == reflect.Ptr {
224		typee = typee.Elem()
225		fieldee = field.Elem()
226	}
227
228	parserFunc, ok := funcMap[typee]
229	if ok {
230		val, err := parserFunc(value)
231		if err != nil {
232			return newParseError(sf, err)
233		}
234
235		fieldee.Set(reflect.ValueOf(val))
236		return nil
237	}
238
239	parserFunc, ok = defaultBuiltInParsers[typee.Kind()]
240	if ok {
241		val, err := parserFunc(value)
242		if err != nil {
243			return newParseError(sf, err)
244		}
245
246		fieldee.Set(reflect.ValueOf(val).Convert(typee))
247		return nil
248	}
249
250	return newNoParserError(sf)
251}
252
253func handleSlice(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error {
254	var separator = sf.Tag.Get("envSeparator")
255	if separator == "" {
256		separator = ","
257	}
258	var parts = strings.Split(value, separator)
259
260	var typee = sf.Type.Elem()
261	if typee.Kind() == reflect.Ptr {
262		typee = typee.Elem()
263	}
264
265	if _, ok := reflect.New(typee).Interface().(encoding.TextUnmarshaler); ok {
266		return parseTextUnmarshalers(field, parts, sf)
267	}
268
269	parserFunc, ok := funcMap[typee]
270	if !ok {
271		parserFunc, ok = defaultBuiltInParsers[typee.Kind()]
272		if !ok {
273			return newNoParserError(sf)
274		}
275	}
276
277	var result = reflect.MakeSlice(sf.Type, 0, len(parts))
278	for _, part := range parts {
279		r, err := parserFunc(part)
280		if err != nil {
281			return newParseError(sf, err)
282		}
283		var v = reflect.ValueOf(r).Convert(typee)
284		if sf.Type.Elem().Kind() == reflect.Ptr {
285			v = reflect.New(typee)
286			v.Elem().Set(reflect.ValueOf(r).Convert(typee))
287		}
288		result = reflect.Append(result, v)
289	}
290	field.Set(result)
291	return nil
292}
293
// asTextUnmarshaler returns the field as an encoding.TextUnmarshaler, or nil
// if it does not implement one:
//   - a pointer field is used as-is, after allocating a zero value into it if
//     it is nil (this allocation happens even when nil is returned);
//   - an addressable non-pointer field is checked through its address;
//   - any other value is checked directly.
func asTextUnmarshaler(field reflect.Value) encoding.TextUnmarshaler {
	candidate := field
	switch {
	case field.Kind() == reflect.Ptr:
		if field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
	case field.CanAddr():
		candidate = field.Addr()
	}

	if tm, ok := candidate.Interface().(encoding.TextUnmarshaler); ok {
		return tm
	}
	return nil
}
309
310func parseTextUnmarshalers(field reflect.Value, data []string, sf reflect.StructField) error {
311	s := len(data)
312	elemType := field.Type().Elem()
313	slice := reflect.MakeSlice(reflect.SliceOf(elemType), s, s)
314	for i, v := range data {
315		sv := slice.Index(i)
316		kind := sv.Kind()
317		if kind == reflect.Ptr {
318			sv = reflect.New(elemType.Elem())
319		} else {
320			sv = sv.Addr()
321		}
322		tm := sv.Interface().(encoding.TextUnmarshaler)
323		if err := tm.UnmarshalText([]byte(v)); err != nil {
324			return newParseError(sf, err)
325		}
326		if kind == reflect.Ptr {
327			slice.Index(i).Set(sv)
328		}
329	}
330
331	field.Set(slice)
332
333	return nil
334}
335
// newParseError wraps err with the struct field it was raised for. It returns
// a literal nil when err is nil, so callers can pass a fallible call's result
// straight through.
func newParseError(sf reflect.StructField, err error) error {
	if err == nil {
		return nil
	}
	return parseError{
		sf:  sf,
		err: err,
	}
}

// parseError reports that the value for struct field sf could not be parsed.
type parseError struct {
	sf  reflect.StructField
	err error
}

func (e parseError) Error() string {
	return fmt.Sprintf(`env: parse error on field "%s" of type "%s": %v`, e.sf.Name, e.sf.Type, e.err)
}

// Unwrap exposes the underlying parser error so callers can inspect it with
// errors.Is / errors.As (e.g. to detect strconv.ErrRange).
func (e parseError) Unwrap() error {
	return e.err
}
354
355func newNoParserError(sf reflect.StructField) error {
356	return fmt.Errorf(`env: no parser found for field "%s" of type "%s"`, sf.Name, sf.Type)
357}
358