1package env 2 3import ( 4 "encoding" 5 "errors" 6 "fmt" 7 "net/url" 8 "os" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13) 14 15// nolint: gochecknoglobals 16var ( 17 // ErrNotAStructPtr is returned if you pass something that is not a pointer to a 18 // Struct to Parse 19 ErrNotAStructPtr = errors.New("env: expected a pointer to a Struct") 20 21 defaultBuiltInParsers = map[reflect.Kind]ParserFunc{ 22 reflect.Bool: func(v string) (interface{}, error) { 23 return strconv.ParseBool(v) 24 }, 25 reflect.String: func(v string) (interface{}, error) { 26 return v, nil 27 }, 28 reflect.Int: func(v string) (interface{}, error) { 29 i, err := strconv.ParseInt(v, 10, 32) 30 return int(i), err 31 }, 32 reflect.Int16: func(v string) (interface{}, error) { 33 i, err := strconv.ParseInt(v, 10, 16) 34 return int16(i), err 35 }, 36 reflect.Int32: func(v string) (interface{}, error) { 37 i, err := strconv.ParseInt(v, 10, 32) 38 return int32(i), err 39 }, 40 reflect.Int64: func(v string) (interface{}, error) { 41 return strconv.ParseInt(v, 10, 64) 42 }, 43 reflect.Int8: func(v string) (interface{}, error) { 44 i, err := strconv.ParseInt(v, 10, 8) 45 return int8(i), err 46 }, 47 reflect.Uint: func(v string) (interface{}, error) { 48 i, err := strconv.ParseUint(v, 10, 32) 49 return uint(i), err 50 }, 51 reflect.Uint16: func(v string) (interface{}, error) { 52 i, err := strconv.ParseUint(v, 10, 16) 53 return uint16(i), err 54 }, 55 reflect.Uint32: func(v string) (interface{}, error) { 56 i, err := strconv.ParseUint(v, 10, 32) 57 return uint32(i), err 58 }, 59 reflect.Uint64: func(v string) (interface{}, error) { 60 i, err := strconv.ParseUint(v, 10, 64) 61 return i, err 62 }, 63 reflect.Uint8: func(v string) (interface{}, error) { 64 i, err := strconv.ParseUint(v, 10, 8) 65 return uint8(i), err 66 }, 67 reflect.Float64: func(v string) (interface{}, error) { 68 return strconv.ParseFloat(v, 64) 69 }, 70 reflect.Float32: func(v string) (interface{}, error) { 71 f, err := strconv.ParseFloat(v, 32) 72 return float32(f), err 73 }, 74 } 75 76 defaultTypeParsers = map[reflect.Type]ParserFunc{ 77 reflect.TypeOf(url.URL{}): func(v string) (interface{}, error) { 78 u, err := url.Parse(v) 79 if err != nil { 80 return nil, fmt.Errorf("unable parse URL: %v", err) 81 } 82 return *u, nil 83 }, 84 reflect.TypeOf(time.Nanosecond): func(v string) (interface{}, error) { 85 s, err := time.ParseDuration(v) 86 if err != nil { 87 return nil, fmt.Errorf("unable to parser duration: %v", err) 88 } 89 return s, err 90 }, 91 } 92) 93 94// ParserFunc defines the signature of a function that can be used within `CustomParsers` 95type ParserFunc func(v string) (interface{}, error) 96 97// Parse parses a struct containing `env` tags and loads its values from 98// environment variables. 99func Parse(v interface{}) error { 100 return ParseWithFuncs(v, map[reflect.Type]ParserFunc{}) 101} 102 103// ParseWithFuncs is the same as `Parse` except it also allows the user to pass 104// in custom parsers. 105func ParseWithFuncs(v interface{}, funcMap map[reflect.Type]ParserFunc) error { 106 ptrRef := reflect.ValueOf(v) 107 if ptrRef.Kind() != reflect.Ptr { 108 return ErrNotAStructPtr 109 } 110 ref := ptrRef.Elem() 111 if ref.Kind() != reflect.Struct { 112 return ErrNotAStructPtr 113 } 114 var parsers = defaultTypeParsers 115 for k, v := range funcMap { 116 parsers[k] = v 117 } 118 return doParse(ref, parsers) 119} 120 121func doParse(ref reflect.Value, funcMap map[reflect.Type]ParserFunc) error { 122 var refType = ref.Type() 123 124 for i := 0; i < refType.NumField(); i++ { 125 refField := ref.Field(i) 126 if !refField.CanSet() { 127 continue 128 } 129 if reflect.Ptr == refField.Kind() && !refField.IsNil() { 130 err := ParseWithFuncs(refField.Interface(), funcMap) 131 if err != nil { 132 return err 133 } 134 continue 135 } 136 refTypeField := refType.Field(i) 137 value, err := get(refTypeField) 138 if err != nil { 139 return err 140 } 141 if value == "" { 142 if reflect.Struct == refField.Kind() { 143 if err := doParse(refField, funcMap); err != nil { 144 return err 145 } 146 } 147 continue 148 } 149 if err := set(refField, refTypeField, value, funcMap); err != nil { 150 return err 151 } 152 } 153 return nil 154} 155 156func get(field reflect.StructField) (string, error) { 157 var ( 158 val string 159 err error 160 ) 161 162 key, opts := parseKeyForOption(field.Tag.Get("env")) 163 164 defaultValue := field.Tag.Get("envDefault") 165 val = getOr(key, defaultValue) 166 167 expandVar := field.Tag.Get("envExpand") 168 if strings.ToLower(expandVar) == "true" { 169 val = os.ExpandEnv(val) 170 } 171 172 if len(opts) > 0 { 173 for _, opt := range opts { 174 // The only option supported is "required". 175 switch opt { 176 case "": 177 break 178 case "required": 179 val, err = getRequired(key) 180 default: 181 err = fmt.Errorf("env: tag option %q not supported", opt) 182 } 183 } 184 } 185 186 return val, err 187} 188 189// split the env tag's key into the expected key and desired option, if any. 190func parseKeyForOption(key string) (string, []string) { 191 opts := strings.Split(key, ",") 192 return opts[0], opts[1:] 193} 194 195func getRequired(key string) (string, error) { 196 if value, ok := os.LookupEnv(key); ok { 197 return value, nil 198 } 199 return "", fmt.Errorf(`env: required environment variable %q is not set`, key) 200} 201 202func getOr(key, defaultValue string) string { 203 value, ok := os.LookupEnv(key) 204 if ok { 205 return value 206 } 207 return defaultValue 208} 209 210func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[reflect.Type]ParserFunc) error { 211 if field.Kind() == reflect.Slice { 212 return handleSlice(field, value, sf, funcMap) 213 } 214 215 var tm = asTextUnmarshaler(field) 216 if tm != nil { 217 var err = tm.UnmarshalText([]byte(value)) 218 return newParseError(sf, err) 219 } 220 221 var typee = sf.Type 222 var fieldee = field 223 if typee.Kind() == reflect.Ptr { 224 typee = typee.Elem() 225 fieldee = field.Elem() 226 } 227 228 parserFunc, ok := funcMap[typee] 229 if ok { 230 val, err := parserFunc(value) 231 if err != nil { 232 return newParseError(sf, err) 233 } 234 235 fieldee.Set(reflect.ValueOf(val)) 236 return nil 237 } 238 239 parserFunc, ok = defaultBuiltInParsers[typee.Kind()] 240 if ok { 241 val, err := parserFunc(value) 242 if err != nil { 243 return newParseError(sf, err) 244 } 245 246 fieldee.Set(reflect.ValueOf(val).Convert(typee)) 247 return nil 248 } 249 250 return newNoParserError(sf) 251} 252 253func handleSlice(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error { 254 var separator = sf.Tag.Get("envSeparator") 255 if separator == "" { 256 separator = "," 257 } 258 var parts = strings.Split(value, separator) 259 260 var typee = sf.Type.Elem() 261 if typee.Kind() == reflect.Ptr { 262 typee = typee.Elem() 263 } 264 265 if _, ok := reflect.New(typee).Interface().(encoding.TextUnmarshaler); ok { 266 return parseTextUnmarshalers(field, parts, sf) 267 } 268 269 parserFunc, ok := funcMap[typee] 270 if !ok { 271 parserFunc, ok = defaultBuiltInParsers[typee.Kind()] 272 if !ok { 273 return newNoParserError(sf) 274 } 275 } 276 277 var result = reflect.MakeSlice(sf.Type, 0, len(parts)) 278 for _, part := range parts { 279 r, err := parserFunc(part) 280 if err != nil { 281 return newParseError(sf, err) 282 } 283 var v = reflect.ValueOf(r).Convert(typee) 284 if sf.Type.Elem().Kind() == reflect.Ptr { 285 v = reflect.New(typee) 286 v.Elem().Set(reflect.ValueOf(r).Convert(typee)) 287 } 288 result = reflect.Append(result, v) 289 } 290 field.Set(result) 291 return nil 292} 293 294func asTextUnmarshaler(field reflect.Value) encoding.TextUnmarshaler { 295 if reflect.Ptr == field.Kind() { 296 if field.IsNil() { 297 field.Set(reflect.New(field.Type().Elem())) 298 } 299 } else if field.CanAddr() { 300 field = field.Addr() 301 } 302 303 tm, ok := field.Interface().(encoding.TextUnmarshaler) 304 if !ok { 305 return nil 306 } 307 return tm 308} 309 310func parseTextUnmarshalers(field reflect.Value, data []string, sf reflect.StructField) error { 311 s := len(data) 312 elemType := field.Type().Elem() 313 slice := reflect.MakeSlice(reflect.SliceOf(elemType), s, s) 314 for i, v := range data { 315 sv := slice.Index(i) 316 kind := sv.Kind() 317 if kind == reflect.Ptr { 318 sv = reflect.New(elemType.Elem()) 319 } else { 320 sv = sv.Addr() 321 } 322 tm := sv.Interface().(encoding.TextUnmarshaler) 323 if err := tm.UnmarshalText([]byte(v)); err != nil { 324 return newParseError(sf, err) 325 } 326 if kind == reflect.Ptr { 327 slice.Index(i).Set(sv) 328 } 329 } 330 331 field.Set(slice) 332 333 return nil 334} 335 336func newParseError(sf reflect.StructField, err error) error { 337 if err == nil { 338 return nil 339 } 340 return parseError{ 341 sf: sf, 342 err: err, 343 } 344} 345 346type parseError struct { 347 sf reflect.StructField 348 err error 349} 350 351func (e parseError) Error() string { 352 return fmt.Sprintf(`env: parse error on field "%s" of type "%s": %v`, e.sf.Name, e.sf.Type, e.err) 353} 354 355func newNoParserError(sf reflect.StructField) error { 356 return fmt.Errorf(`env: no parser found for field "%s" of type "%s"`, sf.Name, sf.Type) 357} 358