Fix config module for unit test

Signed-off-by: Asai Neko <sugar@sne.moe>
This commit is contained in:
2025-12-23 20:31:25 +08:00
parent 2d92d5fba7
commit b933522123
5 changed files with 43 additions and 83 deletions

11
config.default.yaml Normal file
View File

@@ -0,0 +1,11 @@
server:
address: :8000
debug_mode: false
file_logger: false
jwt_secret: someting
database:
type: postgres
host: 127.0.0.1
name: postgres
username: postgres
password: postgres

View File

@@ -7,28 +7,35 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
) )
func Init() { func ConfigDir() string {
// Set config path by env env := os.Getenv("CONFIG_PATH")
confPath := os.Getenv("CONFIG_PATH") if env != "" {
if confPath == "" { return env
confPath = "config.yaml"
} }
return "."
}
func Init() {
// Read global config // Read global config
viper.SetConfigFile(confPath) viper.SetConfigName("config")
viper.SetDefault("Server", serverDef) viper.SetConfigType("yaml")
viper.SetDefault("Database", databaseDef) viper.AddConfigPath(ConfigDir())
// Bind ENV
viper.BindEnv("server.address", "SERVER_ADDRESS")
viper.BindEnv("server.debug_mode", "SERVER_DEBUG_MODE")
viper.BindEnv("server.file_logger", "SERVER_FILE_LOGGER")
viper.BindEnv("server.jwt_secret", "SERVER_JWT_SECRET")
viper.BindEnv("database.type", "DATABASE_TYPE")
viper.BindEnv("database.host", "DATABASE_HOST")
viper.BindEnv("database.name", "DATABASE_NAME")
viper.BindEnv("database.username", "DATABASE_USERNAME")
viper.BindEnv("database.password", "DATABASE_PASSWORD")
conf := &config{} conf := &config{}
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
// Dont generate config when using dev mode // Dont generate config when using dev mode
if os.Getenv("GO_ENV") == "test" || os.Getenv("CONFIG_PATH") != "" { log.Fatalln("Can't read config!")
log.Fatalf("[Config] failed to read config %s: %v", confPath, err)
}
log.Println("Can't read config, trying to modify!")
if err := viper.WriteConfig(); err != nil {
log.Fatal("[Config] Error writing config: ", err)
}
} }
if err := viper.Unmarshal(conf); err != nil { if err := viper.Unmarshal(conf); err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -1,71 +1,12 @@
package config package config
import ( import (
"log"
"os" "os"
"strconv"
"strings"
"github.com/joho/godotenv"
) )
func GetEnv(key string) string {
_ = godotenv.Load()
upperKey := strings.ToUpper(key)
return os.Getenv(upperKey)
}
func SetEnvConf(key string, sub string) {
envJoin := strings.Join([]string{key, sub}, "_")
env := GetEnv(envJoin)
confJoin := strings.Join([]string{key, sub}, ".")
orig := Get(confJoin)
if env != "" {
switch orig.(type) {
case string:
Set(confJoin, env)
case int:
conv, err := strconv.Atoi(env)
if err != nil {
log.Panic("[Config] Error converting string to int: ", err)
}
Set(confJoin, conv)
case bool:
switch env {
case "true":
Set(confJoin, true)
case "false":
Set(confJoin, false)
}
case []string:
trim := strings.TrimSpace(env)
trim = strings.TrimPrefix(trim, "[")
trim = strings.TrimSuffix(trim, "]")
var envArray []string
for _, v := range strings.Split(trim, ",") {
trimSub := strings.TrimPrefix(v, "\"")
trimSub = strings.TrimSuffix(trimSub, "\"")
envArray = append(envArray, trimSub)
}
Set(confJoin, envArray)
}
}
}
func EnvInit() {
var dict = map[string][]string{
"server": {"address", "debug_mode", "file_logger", "jwt_secret"},
"database": {"type", "host", "name", "username", "password"},
}
for key, value := range dict {
for _, sub := range value {
SetEnvConf(key, sub)
}
}
}
func TZ() string { func TZ() string {
tz := GetEnv("TZ") tz := os.Getenv("TZ")
if tz == "" { if tz == "" {
return "Asia/Shanghai" return "Asia/Shanghai"
} }

View File

@@ -1,7 +1,8 @@
project_name := "nixcn-cms" project_name := "nixcn-cms"
project_dir := justfile_directory()
server_enrty := "main.go" server_enrty := "main.go"
output_dir := join(justfile_directory(), ".outputs") output_dir := join(project_dir, ".outputs")
client_dir := join(justfile_directory(), "client") client_dir := join(project_dir, "client")
exec_path := join(output_dir, project_name) exec_path := join(output_dir, project_name)
go_cmd := `realpath $(which go)` go_cmd := `realpath $(which go)`
bun_cmd := `realpath $(which bun)` bun_cmd := `realpath $(which bun)`
@@ -9,13 +10,14 @@ bun_cmd := `realpath $(which bun)`
default: clean build run default: clean build run
clean: clean:
rm -rf {{output_dir}}/* find .outputs -mindepth 1 ! -name config.yaml -exec rm -rf {} +
build: build:
mkdir -p {{ output_dir }} mkdir -p {{ output_dir }}
{{ go_cmd }} build -o {{ exec_path }}{{ if os() == "windows" { ".exe" } else { "" } }} {{ server_enrty }} {{ go_cmd }} build -o {{ exec_path }}{{ if os() == "windows" { ".exe" } else { "" } }} {{ server_enrty }}
run: run:
cd {{ output_dir }} && {{ exec_path }}{{ if os() == "windows" { ".exe" } else { "" } }} cd {{ output_dir }} && CONFIG_PATH={{ output_dir }} {{ exec_path }}{{ if os() == "windows" { ".exe" } else { "" } }}
test: test:
cd {{output_dir}} && CONFIG_PATH={{output_dir}}/config.yaml GO_ENV=test go test -C .. ./... cd {{ output_dir }} && CONFIG_PATH={{ output_dir }} GO_ENV=test go test -C .. ./...

View File

@@ -9,7 +9,6 @@ import (
func main() { func main() {
config.Init() config.Init()
config.EnvInit()
logger.Init() logger.Init()
data.Init() data.Init()
server.Start() server.Start()