1package env 2 3import ( 4 "encoding" 5 "errors" 6 "fmt" 7 "os" 8 "reflect" 9 "strconv" 10 "strings" 11 "time" 12) 13 14var ( 15 // ErrNotAStructPtr is returned if you pass something that is not a pointer to a 16 // Struct to Parse 17 ErrNotAStructPtr = errors.New("Expected a pointer to a Struct") 18 // ErrUnsupportedType if the struct field type is not supported by env 19 ErrUnsupportedType = errors.New("Type is not supported") 20 // ErrUnsupportedSliceType if the slice element type is not supported by env 21 ErrUnsupportedSliceType = errors.New("Unsupported slice type") 22 // OnEnvVarSet is an optional convenience callback, such as for logging purposes. 23 // If not nil, it's called after successfully setting the given field from the given value. 24 OnEnvVarSet func(reflect.StructField, string) 25 // Friendly names for reflect types 26 sliceOfInts = reflect.TypeOf([]int(nil)) 27 sliceOfInt64s = reflect.TypeOf([]int64(nil)) 28 sliceOfUint64s = reflect.TypeOf([]uint64(nil)) 29 sliceOfStrings = reflect.TypeOf([]string(nil)) 30 sliceOfBools = reflect.TypeOf([]bool(nil)) 31 sliceOfFloat32s = reflect.TypeOf([]float32(nil)) 32 sliceOfFloat64s = reflect.TypeOf([]float64(nil)) 33 sliceOfDurations = reflect.TypeOf([]time.Duration(nil)) 34) 35 36// CustomParsers is a friendly name for the type that `ParseWithFuncs()` accepts 37type CustomParsers map[reflect.Type]ParserFunc 38 39// ParserFunc defines the signature of a function that can be used within `CustomParsers` 40type ParserFunc func(v string) (interface{}, error) 41 42// Parse parses a struct containing `env` tags and loads its values from 43// environment variables. 44func Parse(v interface{}) error { 45 ptrRef := reflect.ValueOf(v) 46 if ptrRef.Kind() != reflect.Ptr { 47 return ErrNotAStructPtr 48 } 49 ref := ptrRef.Elem() 50 if ref.Kind() != reflect.Struct { 51 return ErrNotAStructPtr 52 } 53 return doParse(ref, make(map[reflect.Type]ParserFunc, 0)) 54} 55 56// ParseWithFuncs is the same as `Parse` except it also allows the user to pass 57// in custom parsers. 58func ParseWithFuncs(v interface{}, funcMap CustomParsers) error { 59 ptrRef := reflect.ValueOf(v) 60 if ptrRef.Kind() != reflect.Ptr { 61 return ErrNotAStructPtr 62 } 63 ref := ptrRef.Elem() 64 if ref.Kind() != reflect.Struct { 65 return ErrNotAStructPtr 66 } 67 return doParse(ref, funcMap) 68} 69 70func doParse(ref reflect.Value, funcMap CustomParsers) error { 71 refType := ref.Type() 72 var errorList []string 73 74 for i := 0; i < refType.NumField(); i++ { 75 refField := ref.Field(i) 76 if reflect.Ptr == refField.Kind() && !refField.IsNil() && refField.CanSet() { 77 err := Parse(refField.Interface()) 78 if nil != err { 79 return err 80 } 81 continue 82 } 83 refTypeField := refType.Field(i) 84 value, err := get(refTypeField) 85 if err != nil { 86 errorList = append(errorList, err.Error()) 87 continue 88 } 89 if value == "" { 90 continue 91 } 92 if err := set(refField, refTypeField, value, funcMap); err != nil { 93 errorList = append(errorList, err.Error()) 94 continue 95 } 96 if OnEnvVarSet != nil { 97 OnEnvVarSet(refTypeField, value) 98 } 99 } 100 if len(errorList) == 0 { 101 return nil 102 } 103 return errors.New(strings.Join(errorList, ". ")) 104} 105 106func get(field reflect.StructField) (string, error) { 107 var ( 108 val string 109 err error 110 ) 111 112 key, opts := parseKeyForOption(field.Tag.Get("env")) 113 114 defaultValue := field.Tag.Get("envDefault") 115 val = getOr(key, defaultValue) 116 117 expandVar := field.Tag.Get("envExpand") 118 if strings.ToLower(expandVar) == "true" { 119 val = os.ExpandEnv(val) 120 } 121 122 if len(opts) > 0 { 123 for _, opt := range opts { 124 // The only option supported is "required". 125 switch opt { 126 case "": 127 break 128 case "required": 129 val, err = getRequired(key) 130 default: 131 err = fmt.Errorf("env tag option %q not supported", opt) 132 } 133 } 134 } 135 136 return val, err 137} 138 139// split the env tag's key into the expected key and desired option, if any. 140func parseKeyForOption(key string) (string, []string) { 141 opts := strings.Split(key, ",") 142 return opts[0], opts[1:] 143} 144 145func getRequired(key string) (string, error) { 146 if value, ok := os.LookupEnv(key); ok { 147 return value, nil 148 } 149 return "", fmt.Errorf("required environment variable %q is not set", key) 150} 151 152func getOr(key, defaultValue string) string { 153 value, ok := os.LookupEnv(key) 154 if ok { 155 return value 156 } 157 return defaultValue 158} 159 160func set(field reflect.Value, refType reflect.StructField, value string, funcMap CustomParsers) error { 161 // use custom parser if configured for this type 162 parserFunc, ok := funcMap[refType.Type] 163 if ok { 164 val, err := parserFunc(value) 165 if err != nil { 166 return fmt.Errorf("Custom parser error: %v", err) 167 } 168 field.Set(reflect.ValueOf(val)) 169 return nil 170 } 171 172 // fall back to built-in parsers 173 switch field.Kind() { 174 case reflect.Slice: 175 separator := refType.Tag.Get("envSeparator") 176 return handleSlice(field, value, separator) 177 case reflect.String: 178 field.SetString(value) 179 case reflect.Bool: 180 bvalue, err := strconv.ParseBool(value) 181 if err != nil { 182 return err 183 } 184 field.SetBool(bvalue) 185 case reflect.Int: 186 intValue, err := strconv.ParseInt(value, 10, 32) 187 if err != nil { 188 return err 189 } 190 field.SetInt(intValue) 191 case reflect.Uint: 192 uintValue, err := strconv.ParseUint(value, 10, 32) 193 if err != nil { 194 return err 195 } 196 field.SetUint(uintValue) 197 case reflect.Float32: 198 v, err := strconv.ParseFloat(value, 32) 199 if err != nil { 200 return err 201 } 202 field.SetFloat(v) 203 case reflect.Float64: 204 v, err := strconv.ParseFloat(value, 64) 205 if err != nil { 206 return err 207 } 208 field.Set(reflect.ValueOf(v)) 209 case reflect.Int64: 210 if refType.Type.String() == "time.Duration" { 211 dValue, err := time.ParseDuration(value) 212 if err != nil { 213 return err 214 } 215 field.Set(reflect.ValueOf(dValue)) 216 } else { 217 intValue, err := strconv.ParseInt(value, 10, 64) 218 if err != nil { 219 return err 220 } 221 field.SetInt(intValue) 222 } 223 case reflect.Uint64: 224 uintValue, err := strconv.ParseUint(value, 10, 64) 225 if err != nil { 226 return err 227 } 228 field.SetUint(uintValue) 229 default: 230 return handleTextUnmarshaler(field, value) 231 } 232 return nil 233} 234 235func handleSlice(field reflect.Value, value, separator string) error { 236 if separator == "" { 237 separator = "," 238 } 239 240 splitData := strings.Split(value, separator) 241 242 switch field.Type() { 243 case sliceOfStrings: 244 field.Set(reflect.ValueOf(splitData)) 245 case sliceOfInts: 246 intData, err := parseInts(splitData) 247 if err != nil { 248 return err 249 } 250 field.Set(reflect.ValueOf(intData)) 251 case sliceOfInt64s: 252 int64Data, err := parseInt64s(splitData) 253 if err != nil { 254 return err 255 } 256 field.Set(reflect.ValueOf(int64Data)) 257 case sliceOfUint64s: 258 uint64Data, err := parseUint64s(splitData) 259 if err != nil { 260 return err 261 } 262 field.Set(reflect.ValueOf(uint64Data)) 263 case sliceOfFloat32s: 264 data, err := parseFloat32s(splitData) 265 if err != nil { 266 return err 267 } 268 field.Set(reflect.ValueOf(data)) 269 case sliceOfFloat64s: 270 data, err := parseFloat64s(splitData) 271 if err != nil { 272 return err 273 } 274 field.Set(reflect.ValueOf(data)) 275 case sliceOfBools: 276 boolData, err := parseBools(splitData) 277 if err != nil { 278 return err 279 } 280 field.Set(reflect.ValueOf(boolData)) 281 case sliceOfDurations: 282 durationData, err := parseDurations(splitData) 283 if err != nil { 284 return err 285 } 286 field.Set(reflect.ValueOf(durationData)) 287 default: 288 elemType := field.Type().Elem() 289 // Ensure we test *type as we can always address elements in a slice. 290 if elemType.Kind() == reflect.Ptr { 291 elemType = elemType.Elem() 292 } 293 if _, ok := reflect.New(elemType).Interface().(encoding.TextUnmarshaler); !ok { 294 return ErrUnsupportedSliceType 295 } 296 return parseTextUnmarshalers(field, splitData) 297 298 } 299 return nil 300} 301 302func handleTextUnmarshaler(field reflect.Value, value string) error { 303 if reflect.Ptr == field.Kind() { 304 if field.IsNil() { 305 field.Set(reflect.New(field.Type().Elem())) 306 } 307 } else if field.CanAddr() { 308 field = field.Addr() 309 } 310 311 tm, ok := field.Interface().(encoding.TextUnmarshaler) 312 if !ok { 313 return ErrUnsupportedType 314 } 315 316 return tm.UnmarshalText([]byte(value)) 317} 318 319func parseInts(data []string) ([]int, error) { 320 intSlice := make([]int, 0, len(data)) 321 322 for _, v := range data { 323 intValue, err := strconv.ParseInt(v, 10, 32) 324 if err != nil { 325 return nil, err 326 } 327 intSlice = append(intSlice, int(intValue)) 328 } 329 return intSlice, nil 330} 331 332func parseInt64s(data []string) ([]int64, error) { 333 intSlice := make([]int64, 0, len(data)) 334 335 for _, v := range data { 336 intValue, err := strconv.ParseInt(v, 10, 64) 337 if err != nil { 338 return nil, err 339 } 340 intSlice = append(intSlice, int64(intValue)) 341 } 342 return intSlice, nil 343} 344 345func parseUint64s(data []string) ([]uint64, error) { 346 var uintSlice []uint64 347 348 for _, v := range data { 349 uintValue, err := strconv.ParseUint(v, 10, 64) 350 if err != nil { 351 return nil, err 352 } 353 uintSlice = append(uintSlice, uint64(uintValue)) 354 } 355 return uintSlice, nil 356} 357 358func parseFloat32s(data []string) ([]float32, error) { 359 float32Slice := make([]float32, 0, len(data)) 360 361 for _, v := range data { 362 data, err := strconv.ParseFloat(v, 32) 363 if err != nil { 364 return nil, err 365 } 366 float32Slice = append(float32Slice, float32(data)) 367 } 368 return float32Slice, nil 369} 370 371func parseFloat64s(data []string) ([]float64, error) { 372 float64Slice := make([]float64, 0, len(data)) 373 374 for _, v := range data { 375 data, err := strconv.ParseFloat(v, 64) 376 if err != nil { 377 return nil, err 378 } 379 float64Slice = append(float64Slice, float64(data)) 380 } 381 return float64Slice, nil 382} 383 384func parseBools(data []string) ([]bool, error) { 385 boolSlice := make([]bool, 0, len(data)) 386 387 for _, v := range data { 388 bvalue, err := strconv.ParseBool(v) 389 if err != nil { 390 return nil, err 391 } 392 393 boolSlice = append(boolSlice, bvalue) 394 } 395 return boolSlice, nil 396} 397 398func parseDurations(data []string) ([]time.Duration, error) { 399 durationSlice := make([]time.Duration, 0, len(data)) 400 401 for _, v := range data { 402 dvalue, err := time.ParseDuration(v) 403 if err != nil { 404 return nil, err 405 } 406 407 durationSlice = append(durationSlice, dvalue) 408 } 409 return durationSlice, nil 410} 411 412func parseTextUnmarshalers(field reflect.Value, data []string) error { 413 s := len(data) 414 elemType := field.Type().Elem() 415 slice := reflect.MakeSlice(reflect.SliceOf(elemType), s, s) 416 for i, v := range data { 417 sv := slice.Index(i) 418 kind := sv.Kind() 419 if kind == reflect.Ptr { 420 sv = reflect.New(elemType.Elem()) 421 } else { 422 sv = sv.Addr() 423 } 424 tm := sv.Interface().(encoding.TextUnmarshaler) 425 if err := tm.UnmarshalText([]byte(v)); err != nil { 426 return err 427 } 428 if kind == reflect.Ptr { 429 slice.Index(i).Set(sv) 430 } 431 } 432 433 field.Set(slice) 434 435 return nil 436} 437