Skip to content

Commit

Permalink
Merge pull request #685 from traPtitech/feat/config-from-env
Browse files Browse the repository at this point in the history
  • Loading branch information
ras0q authored Jun 19, 2024
2 parents 4556ac6 + f73f06e commit f202eba
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 71 deletions.
66 changes: 38 additions & 28 deletions util/config/config.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
//nolint:errcheck
package config

import (
"flag"
"fmt"
"log"
"os"
"strings"
"time"

"github.com/go-sql-driver/mysql"
Expand All @@ -15,26 +17,6 @@ import (
"gorm.io/gorm/logger"
)

var (
flagKeys = []struct{ path, flag string }{
{"production", "production"},
{"port", "port"},
{"onlyMigrate", "only-migrate"},
{"insertMockData", "insert-mock-data"},
{"db.user", "db-user"},
{"db.pass", "db-pass"},
{"db.host", "db-host"},
{"db.name", "db-name"},
{"db.port", "db-port"},
{"db.verbose", "db-verbose"},
{"traq.accessToken", "traq-access-token"},
{"knoq.cookie", "knoq-cookie"},
{"knoq.apiEndpoint", "knoq-api-endpoint"},
{"portal.cookie", "portal-cookie"},
{"portal.apiEndpoint", "portal-api-endpoint"},
}
)

type (
Config struct {
IsProduction bool `mapstructure:"production"`
Expand Down Expand Up @@ -69,22 +51,52 @@ type (

func init() {
pflag.Bool("production", false, "whether production or development")
viper.BindPFlag("production", pflag.Lookup("production"))

pflag.Int("port", 1323, "api port")
viper.BindPFlag("port", pflag.Lookup("port"))

pflag.Bool("only-migrate", false, "only migrate db (not start server)")
viper.BindPFlag("onlyMigrate", pflag.Lookup("only-migrate"))

pflag.Bool("insert-mock-data", false, "insert sample mock data(for dev)")
viper.BindPFlag("insertMockData", pflag.Lookup("insert-mock-data"))

pflag.String("db-user", "root", "db user name")
viper.BindPFlag("db.user", pflag.Lookup("db-user"))

pflag.String("db-pass", "password", "db password")
viper.BindPFlag("db.pass", pflag.Lookup("db-pass"))

pflag.String("db-host", "localhost", "db host")
viper.BindPFlag("db.host", pflag.Lookup("db-host"))

pflag.String("db-name", "portfolio", "db name")
viper.BindPFlag("db.name", pflag.Lookup("db-name"))

pflag.Int("db-port", 3306, "db port")
viper.BindPFlag("db.port", pflag.Lookup("db-port"))

pflag.Bool("db-verbose", false, "db verbose mode")
viper.BindPFlag("db.verbose", pflag.Lookup("db-verbose"))

pflag.String("traq-access-token", "", "traq access token")
viper.BindPFlag("traq.accessToken", pflag.Lookup("traq-access-token"))

pflag.String("knoq-cookie", "", "knoq cookie")
viper.BindPFlag("knoq.cookie", pflag.Lookup("knoq-cookie"))

pflag.String("knoq-api-endpoint", "", "knoq api endpoint")
viper.BindPFlag("knoq.apiEndpoint", pflag.Lookup("knoq-api-endpoint"))

pflag.String("portal-cookie", "", "portal cookie")
viper.BindPFlag("portal.cookie", pflag.Lookup("portal-cookie"))

pflag.String("portal-api-endpoint", "", "portal api endpoint")
viper.BindPFlag("portal.apiEndpoint", pflag.Lookup("portal-api-endpoint"))

pflag.StringP("config", "c", "", "config file path")

pflag.CommandLine.AddGoFlagSet(flag.CommandLine)
}

Expand All @@ -93,19 +105,16 @@ type LoadOpts struct {
}

func Load(opts LoadOpts) (*Config, error) {
var c Config

pflag.Parse()
for _, key := range flagKeys {
if err := viper.BindPFlag(key.path, pflag.Lookup(key.flag)); err != nil {
return nil, fmt.Errorf("bind flag %s: %w", key.flag, err)
}
}

if err := viper.BindPFlags(pflag.CommandLine); err != nil {
return nil, fmt.Errorf("bind flags: %w", err)
}

viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.SetEnvPrefix("TPF")
viper.AutomaticEnv()

if !opts.SkipReadFromFiles {
configPath := viper.GetString("config")
if len(configPath) > 0 {
Expand All @@ -123,7 +132,7 @@ func Load(opts LoadOpts) (*Config, error) {
return nil, fmt.Errorf("read config from %s: %w", configPath, err)
}

log.Printf("config file does not found: %v\n", err)
log.Printf("config file did not used: %v\n", err)
} else {
return nil, fmt.Errorf("read config: %w", err)
}
Expand All @@ -132,6 +141,7 @@ func Load(opts LoadOpts) (*Config, error) {
}
}

var c Config
if err := viper.Unmarshal(&c); err != nil {
return nil, fmt.Errorf("unmarshal config: %w", err)
}
Expand Down
130 changes: 106 additions & 24 deletions util/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,124 @@
package config
//nolint:errcheck
package config_test

import (
"os"
"path/filepath"
"testing"

"github.com/spf13/viper"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/traPtitech/traPortfolio/util/config"
)

func TestParse(t *testing.T) {
expected := Config{
IsProduction: true,
Port: 3000,
OnlyMigrate: true,
InsertMockData: true,
DB: SQLConfig{
func TestLoad(t *testing.T) {
defaultConfig := config.Config{
IsProduction: false,
Port: 1323,
OnlyMigrate: false,
InsertMockData: false,
DB: config.SQLConfig{
User: "root",
Pass: "password",
Host: "mysql",
Host: "localhost",
Name: "portfolio",
Port: 3001,
Verbose: true,
Port: 3306,
Verbose: false,
},
Traq: TraqConfig{
AccessToken: "traq token",
Traq: config.TraqConfig{
AccessToken: "",
},
Knoq: APIConfig{
Cookie: "knoq cookie",
APIEndpoint: "knoq endpoint",
Knoq: config.APIConfig{
Cookie: "",
APIEndpoint: "",
},
Portal: APIConfig{
Cookie: "portal cookie",
APIEndpoint: "portal endpoint",
Portal: config.APIConfig{
Cookie: "",
APIEndpoint: "",
},
}

viper.AddConfigPath("./testdata")
t.Run("default", func(t *testing.T) {
got, err := config.Load(config.LoadOpts{})
assert.NoError(t, err)
assert.Equal(t, &defaultConfig, got)
})

got, err := Load(LoadOpts{})
assert.NoError(t, err)
assert.Equal(t, &expected, got)
t.Run("from file", func(t *testing.T) {
yaml := `
production: true
port: 8000`
configPath := filepath.Join(t.TempDir(), "config.yaml")
os.Create(configPath)
os.WriteFile(configPath, []byte(yaml), 0644)
t.Setenv("TPF_CONFIG", configPath)

expected := defaultConfig
expected.IsProduction = true
expected.Port = 8000

got, err := config.Load(config.LoadOpts{})
assert.NoError(t, err)
assert.Equal(t, &expected, got)
})

t.Run("from env", func(t *testing.T) {
t.Setenv("TPF_PRODUCTION", "true")
t.Setenv("TPF_PORT", "8000")

expected := defaultConfig
expected.IsProduction = true
expected.Port = 8000

got, err := config.Load(config.LoadOpts{})
assert.NoError(t, err)
assert.Equal(t, &expected, got)
})

t.Run("from flag", func(t *testing.T) {
pflag.CommandLine.Set("production", "true")
pflag.CommandLine.Set("port", "8000")
t.Cleanup(func() {
pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
})

expected := defaultConfig
expected.IsProduction = true
expected.Port = 8000

got, err := config.Load(config.LoadOpts{})
assert.NoError(t, err)
assert.Equal(t, &expected, got)
})

t.Run("priority order is flag, env, file, then default", func(t *testing.T) {
t.Skip("It fails if flag is set twice")

yaml := `
db:
user: file
pass: file
host: file`
configPath := filepath.Join(t.TempDir(), "config.yaml")
os.Create(configPath)
os.WriteFile(configPath, []byte(yaml), 0644)
t.Setenv("TPF_CONFIG", configPath)

t.Setenv("TPF_DB_USER", "env")
t.Setenv("TPF_DB_PASS", "env")

pflag.CommandLine.Set("db-user", "flag")
t.Cleanup(func() {
pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
})

expected := defaultConfig
expected.DB.User = "flag"
expected.DB.Pass = "env"
expected.DB.Host = "file"

got, err := config.Load(config.LoadOpts{})
assert.NoError(t, err)
assert.Equal(t, &expected, got)
})
}
19 changes: 0 additions & 19 deletions util/config/testdata/config.yaml

This file was deleted.

0 comments on commit f202eba

Please sign in to comment.