package environ
import (
"errors"
"strings"
)
// Sentinel errors carried in EnvError.Err to classify what went wrong.
var (
	// ErrInvalidInput reports that the provided config is not a pointer to a struct.
	ErrInvalidInput = errors.New("must be a pointer to a struct")

	// ErrUnsettableParam reports a param that cannot be set, such as an
	// unexported field encountered in a struct.
	ErrUnsettableParam = errors.New("must be a settable parameter")

	// ErrUnsupportedType reports a param whose type is not supported.
	ErrUnsupportedType = errors.New("has unsupported type")

	// ErrInvalidFormat reports a tag or param value that is not in the correct format.
	ErrInvalidFormat = errors.New("has invalid format")

	// ErrRequiredNotFound reports a required variable that was not
	// successfully loaded from the env.
	ErrRequiredNotFound = errors.New("is required but failed to load value")

	// ErrLoading reports a source that is not loading properly.
	ErrLoading = errors.New("encountered error loading value")
)
// EnvError implements the error interface with key information and some
// helpful text for fixing the issues with loading a config.
type EnvError struct {
	// Err is the underlying cause, normally one of the package's Err* sentinels.
	Err error
	// Key names what the error is about (within this package: a struct field
	// name, or "config" for the top-level argument).
	Key string
	// Extra is optional hint text; it is appended to the message when non-empty.
	Extra string
}

// Error returns a user friendly error message in the format below
//
// env: <key> <err message> | extra: <extra>
//
// The "<err message>" part is omitted when Err is nil, and the
// " | extra: <extra>" part is omitted when Extra is empty.
func (e *EnvError) Error() string {
	var sb strings.Builder
	sb.WriteString("env: ")
	sb.WriteString(e.Key)
	// Guard against a hand-built EnvError without a cause; calling Error on a
	// nil Err would panic.
	if e.Err != nil {
		sb.WriteString(" ")
		sb.WriteString(e.Err.Error())
	}
	if e.Extra != "" {
		sb.WriteString(" | extra: ")
		sb.WriteString(e.Extra)
	}
	return sb.String()
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// match an EnvError against the sentinel stored in Err.
func (e *EnvError) Unwrap() error {
	return e.Err
}
// newError builds an *EnvError carrying the cause err, the key it relates to,
// and optional extra hint text.
func newError(err error, key, extra string) *EnvError {
	envErr := new(EnvError)
	envErr.Err, envErr.Key, envErr.Extra = err, key, extra
	return envErr
}
package environ
import (
"os"
"reflect"
"strconv"
"strings"
"time"
)
// Struct tag names that control where a value is loaded from.
const (
	defaultTag  = "default"  // sets a default value, any
	envTag      = "env"      // gets the value from the env, string
	ssmTag      = "ssm"      // gets the value from AWS Parameter store, string
	asmTag      = "asm"      // gets the value from AWS Secrets Manager, string
	gsmTag      = "gsm"      // gets the value from GCP Secrets, string
	swiftTag    = "swift"    // gets the value from Swift based storage
	requiredTag = "required" // sets requirements for env params, bool: causes errors when not loaded
)

// Struct tag names that control how a loaded value is split.
const (
	separatorTag   = "separator"    // selects a custom separator for slices and map items
	kvSeparatorTag = "kv_separator" // selects a custom separator for key value pairs in maps
)

// Fallbacks used when no separator tag is present.
const (
	defaultSeparator   = ","
	defaultKvSeparator = ":"
)

// durationUnits lists the characters whose presence marks a value as a
// stringified time.Duration rather than a plain integer.
const durationUnits = "smh"
// Load fills the config with values based on the tags provided on the struct.
//
// config must be a pointer to a struct; any other input, or any field that
// fails to load, results in a non-nil error.
func Load(config any) error {
	target, err := validateConfig(config)
	if err != nil {
		return err
	}
	return handleStruct(target)
}
// validateConfig checks that config is a pointer to a struct and returns the
// reflected struct it points at. Any other input yields an ErrInvalidInput
// error keyed on "config".
func validateConfig(config any) (reflect.Value, error) {
	const hint = "must be provided a pointer to a struct"

	ref := reflect.ValueOf(config)
	if ref.Kind() != reflect.Ptr {
		return reflect.Value{}, newError(ErrInvalidInput, "config", hint)
	}

	target := ref.Elem()
	if target.Kind() == reflect.Struct {
		return target, nil
	}
	return target, newError(ErrInvalidInput, "config", hint)
}
// handleStruct walks every field of the struct input in declaration order.
// Struct-kind fields are descended into recursively; every other field is
// loaded via handleField. The walk stops at the first error, and a field that
// cannot be set is reported as ErrUnsettableParam.
func handleStruct(input reflect.Value) error {
	typ := input.Type()
	for idx, count := 0, input.NumField(); idx < count; idx++ {
		fieldValue, fieldInfo := input.Field(idx), typ.Field(idx)
		if !fieldValue.CanSet() {
			return newError(ErrUnsettableParam, fieldInfo.Name, "")
		}

		var err error
		if fieldValue.Kind() == reflect.Struct {
			err = handleStruct(fieldValue)
		} else {
			err = handleField(fieldValue, fieldInfo)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// handleField resolves the raw value for one field and, when that value is
// non-empty, parses it into the field. An empty value leaves the field as is.
func handleField(input reflect.Value, structField reflect.StructField) error {
	raw, err := getValue(structField)
	switch {
	case err != nil:
		return err
	case raw == "":
		// nothing was configured for this field
		return nil
	}
	return setValue(structField, input, raw)
}
// getValue resolves the raw string for a field from its tags.
//
// The result starts as the "default" tag value and is replaced by the
// environment variable named in the "env" tag when that variable is non-empty.
// A "required" tag that is not a valid boolean yields ErrInvalidFormat; a
// required field whose value did not come from the environment yields
// ErrRequiredNotFound.
func getValue(structField reflect.StructField) (string, error) {
	tags := structField.Tag
	result := tags.Get(defaultTag)

	mustLoad := false
	if raw, ok := tags.Lookup(requiredTag); ok {
		parsed, err := strconv.ParseBool(raw)
		if err != nil {
			return result, newError(ErrInvalidFormat, structField.Name, "required tag value is not a valid boolean representation")
		}
		mustLoad = parsed
	}

	// an empty environment variable counts as not loaded
	fromEnv := false
	if name, ok := tags.Lookup(envTag); ok {
		if envValue := os.Getenv(name); envValue != "" {
			result, fromEnv = envValue, true
		}
	}

	// a default alone does not satisfy a required field
	if mustLoad && !fromEnv {
		return result, newError(ErrRequiredNotFound, structField.Name, "required field not loaded")
	}
	return result, nil
}
// setValue parses value according to the kind of param and stores the result
// in param, or returns an error.
//
// Supported kinds are bool, signed and unsigned integers, floats, strings,
// and maps and slices whose keys/elements are themselves supported. A param of
// type time.Duration additionally accepts a duration string such as "1m30s".
// Slice items are split on the "separator" tag (default ","), and map
// key/value pairs on the "kv_separator" tag (default ":").
//
// structField supplies the field name used in error messages and the tags
// consulted for separators. Parse failures are reported as ErrInvalidFormat
// and unhandled kinds as ErrUnsupportedType. When a map or slice fails to
// parse, param is left untouched.
func setValue(structField reflect.StructField, param reflect.Value, value string) error {
	switch param.Kind() {
	case reflect.Bool:
		v, err := strconv.ParseBool(value)
		if err != nil {
			return newError(ErrInvalidFormat, structField.Name, "value is not a valid boolean representation")
		}
		param.SetBool(v)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		var (
			v   int64
			err error
		)
		// handle parsing a stringified time.Duration env value
		if param.Type() == reflect.TypeOf(time.Duration(0)) && strings.ContainsAny(value, durationUnits) {
			var dur time.Duration
			dur, err = time.ParseDuration(value)
			v = dur.Nanoseconds()
		} else {
			v, err = strconv.ParseInt(value, 0, param.Type().Bits())
		}
		if err != nil {
			return newError(ErrInvalidFormat, structField.Name, "value is not a valid integer representation")
		}
		// Always store the parsed value, zero included, so an explicit "0" can
		// override a pre-populated field just like the other numeric kinds.
		param.SetInt(v)
	case reflect.Float32, reflect.Float64:
		v, err := strconv.ParseFloat(value, param.Type().Bits())
		if err != nil {
			return newError(ErrInvalidFormat, structField.Name, "value is not a valid float representation")
		}
		param.SetFloat(v)
	case reflect.Map:
		var (
			items       = strings.Split(value, getSeparator(structField.Tag))
			kvSeparator = getKvSeparator(structField.Tag)
			mapType     = param.Type()
			result      = reflect.MakeMapWithSize(mapType, len(items))
		)
		for _, item := range items {
			kv := strings.Split(item, kvSeparator)
			// each item must split into exactly one key and one value
			if len(kv) != 2 {
				return newError(ErrInvalidFormat, structField.Name, "a map item has more than one kv_separator")
			}
			// fresh, settable zero values to parse the key and value into
			key := reflect.New(mapType.Key()).Elem()
			if err := setValue(structField, key, kv[0]); err != nil {
				return err
			}
			val := reflect.New(mapType.Elem()).Elem()
			if err := setValue(structField, val, kv[1]); err != nil {
				return err
			}
			result.SetMapIndex(key, val)
		}
		param.Set(result)
	case reflect.Slice:
		items := strings.Split(value, getSeparator(structField.Tag))
		// Parse into a fresh slice and assign it only on success. Resizing
		// param in place would panic in SetCap when param already holds more
		// elements than were loaded, and would leave it half-written on error.
		result := reflect.MakeSlice(param.Type(), len(items), len(items))
		for i, item := range items {
			if err := setValue(structField, result.Index(i), item); err != nil {
				return err
			}
		}
		param.Set(result)
	case reflect.String:
		param.SetString(value)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v, err := strconv.ParseUint(value, 0, param.Type().Bits())
		if err != nil {
			return newError(ErrInvalidFormat, structField.Name, "value is not a valid uint representation")
		}
		param.SetUint(v)
	default:
		return newError(ErrUnsupportedType, structField.Name, "provided type is not supported in this version")
	}
	return nil
}
func getSeparator(structTag reflect.StructTag) string {
separator := defaultSeparator
// get the separator from the tags
if s, ok := structTag.Lookup(separatorTag); ok {
separator = s
}
return separator
}
func getKvSeparator(structTag reflect.StructTag) string {
separator := defaultKvSeparator
// get the separator from the tags
if s, ok := structTag.Lookup(kvSeparatorTag); ok {
separator = s
}
return separator
}