Skip to content

Commit ce10418

Browse files
authored
Recursively construct toml key (#7)
1 parent 43633a3 commit ce10418

File tree

2 files changed

+90
-7
lines changed

2 files changed

+90
-7
lines changed

config.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func Load(filepath string, dst interface{}) error {
3131
return err
3232
}
3333

34-
return bindFlags(dst, metadata)
34+
return bindFlags(dst, metadata, "")
3535
}
3636

3737
// bindEnvVariables will bind CLI flags to their respective elements in dst, defined by the struct-tag "env".
@@ -65,7 +65,7 @@ func bindEnvVariables(dst interface{}) error {
6565
}
6666

6767
// bindFlags will bind CLI flags to their respective elements in dst, defined by the struct-tag "flag".
68-
func bindFlags(dst interface{}, metadata toml.MetaData) error {
68+
func bindFlags(dst interface{}, metadata toml.MetaData, fieldPath string) error {
6969
fields := structs.Fields(dst)
7070
for _, field := range fields {
7171
tag := field.Tag(flagTag)
@@ -75,7 +75,18 @@ func bindFlags(dst interface{}, metadata toml.MetaData) error {
7575
continue
7676
}
7777

78-
if err := bindFlags(dstElem.Addr().Interface(), metadata); err != nil {
78+
var path string
79+
if fieldPath != "" {
80+
path = fmt.Sprintf("%s.", fieldPath)
81+
}
82+
83+
if field.Tag(tomlTag) != "" {
84+
path += field.Tag(tomlTag)
85+
} else {
86+
path += field.Name()
87+
}
88+
89+
if err := bindFlags(dstElem.Addr().Interface(), metadata, path); err != nil {
7990
return err
8091
}
8192

@@ -91,7 +102,15 @@ func bindFlags(dst interface{}, metadata toml.MetaData) error {
91102
useFlagDefaultValue := false
92103
if !isFlagSet(tag) {
93104
_, envHasKey := os.LookupEnv(field.Tag(envTag))
94-
if envHasKey || tomlHasKey(metadata, field.Tag(tomlTag)) {
105+
106+
var tomlKey string
107+
if fieldPath == "" {
108+
tomlKey = field.Tag(tomlTag)
109+
} else {
110+
tomlKey = fmt.Sprintf("%s.%s", fieldPath, field.Tag(tomlTag))
111+
}
112+
113+
if envHasKey || tomlHasKey(metadata, tomlKey) {
95114
continue
96115
} else {
97116
useFlagDefaultValue = true
@@ -190,10 +209,10 @@ func isFlagSet(tag string) bool {
190209
return flagSet
191210
}
192211

193-
// tomlHasKey will check if the tag presents in toml metadata
194-
func tomlHasKey(metadata toml.MetaData, tag string) bool {
212+
// tomlHasKey will check if the toml key presents in toml metadata
213+
func tomlHasKey(metadata toml.MetaData, tomlKey string) bool {
195214
for _, key := range metadata.Keys() {
196-
if strings.ToLower(key.String()) == strings.ToLower(tag) {
215+
if strings.EqualFold(key.String(), tomlKey) {
197216
return true
198217
}
199218
}

config_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,70 @@ LogLevel = "debug"
428428
}
429429
}
430430

431+
func TestLoad_TomlNested_FlagSetAndNotGiven(t *testing.T) {
432+
var cfg struct {
433+
DB struct {
434+
Account string `toml:"account" flag:"db-account"`
435+
Username string `toml:"username" flag:"db-user"`
436+
Credentials struct {
437+
Secret string `toml:"secret" flag:"db-secret"`
438+
Password string `toml:"password" flag:"db-password"`
439+
} `toml:"credentials"`
440+
Options *struct {
441+
Port int `toml:"port" flag:"db-port"`
442+
}
443+
} `toml:"database"`
444+
}
445+
tmp, _ := ioutil.TempFile("", "")
446+
defer os.Remove(tmp.Name())
447+
448+
_, err := tmp.WriteString(`
449+
[database]
450+
account = "test_account"
451+
username = "test_user"
452+
[database.credentials]
453+
secret = "wowowow"
454+
password = "12345"
455+
[database.options]
456+
port = 3306
457+
`)
458+
if err != nil {
459+
t.Fatalf("unexpected error: %v", err)
460+
}
461+
462+
fs := flag.NewFlagSet("tmp", flag.ExitOnError)
463+
_ = fs.String("db-account", "default", "")
464+
_ = fs.String("db-user", "default", "")
465+
_ = fs.String("db-secret", "default", "")
466+
_ = fs.String("db-password", "default", "")
467+
_ = fs.Int("db-port", 0, "")
468+
flag.CommandLine = fs
469+
470+
if err := Load(tmp.Name(), &cfg); err != nil {
471+
t.Fatalf("unexpected error %v", err)
472+
}
473+
474+
if cfg.DB.Account != "test_account" {
475+
t.Errorf("got: %v, expected: %v", cfg.DB.Account, "test_account")
476+
}
477+
478+
if cfg.DB.Username != "test_user" {
479+
t.Errorf("got: %v, expected: %v", cfg.DB.Username, "test_user")
480+
}
481+
482+
if cfg.DB.Credentials.Secret != "wowowow" {
483+
t.Errorf("got: %v, expected: %v", cfg.DB.Credentials.Secret, "wowowow")
484+
}
485+
486+
if cfg.DB.Credentials.Password != "12345" {
487+
t.Errorf("got: %v, expected: %v", cfg.DB.Credentials.Password, "12345")
488+
}
489+
490+
if cfg.DB.Options.Port != 3306 {
491+
t.Errorf("got: %v, expected: %v", cfg.DB.Options.Port, 3306)
492+
}
493+
}
494+
431495
func TestLoad_EnvGivenWithNested(t *testing.T) {
432496
os.Clearenv()
433497
var cfg struct {

0 commit comments

Comments
 (0)