summaryrefslogtreecommitdiffstats
path: root/utils
diff options
context:
space:
mode:
authorChris <ccbrown112@gmail.com>2017-10-31 09:39:31 -0500
committerJoram Wilander <jwawilander@gmail.com>2017-10-31 10:39:31 -0400
commitce2b2be5de578bd9eb44b26e04db75ca61d67ca5 (patch)
tree54203a18ecfb167dcf1d7e0742cea0ed9aab220a /utils
parentb446d0aa0aa2bd3d87028b0543752eb539507481 (diff)
downloadchat-ce2b2be5de578bd9eb44b26e04db75ca61d67ca5.tar.gz
chat-ce2b2be5de578bd9eb44b26e04db75ca61d67ca5.tar.bz2
chat-ce2b2be5de578bd9eb44b26e04db75ca61d67ca5.zip
Refactoring cfg refs and load / save functions (#7749)
* refactoring cfg refs and load / save functions * improve error output
Diffstat (limited to 'utils')
-rw-r--r--utils/config.go171
-rw-r--r--utils/config_test.go33
-rw-r--r--utils/file_test.go2
-rw-r--r--utils/mail_test.go4
4 files changed, 119 insertions, 91 deletions
diff --git a/utils/config.go b/utils/config.go
index 2c3a9d291..4b377cce2 100644
--- a/utils/config.go
+++ b/utils/config.go
@@ -70,16 +70,14 @@ func RemoveConfigListener(id string) {
delete(cfgListeners, id)
}
-func FindConfigFile(fileName string) string {
- if _, err := os.Stat("./config/" + fileName); err == nil {
- fileName, _ = filepath.Abs("./config/" + fileName)
- } else if _, err := os.Stat("../config/" + fileName); err == nil {
- fileName, _ = filepath.Abs("../config/" + fileName)
- } else if _, err := os.Stat(fileName); err == nil {
- fileName, _ = filepath.Abs(fileName)
+func FindConfigFile(fileName string) (path string) {
+ for _, dir := range []string{"./config", "../config", "../../config", "."} {
+ path, _ := filepath.Abs(filepath.Join(dir, fileName))
+ if _, err := os.Stat(path); err == nil {
+ return path
+ }
}
-
- return fileName
+ return ""
}
func FindDir(dir string) (string, bool) {
@@ -197,12 +195,6 @@ func SaveConfig(fileName string, config *model.Config) *model.AppError {
return nil
}
-func EnableConfigFromEnviromentVars() {
- viper.SetEnvPrefix("mm")
- viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
- viper.AutomaticEnv()
-}
-
func InitializeConfigWatch() {
cfgMutex.Lock()
defer cfgMutex.Unlock()
@@ -230,7 +222,7 @@ func InitializeConfigWatch() {
l4g.Info(fmt.Sprintf("Config file watcher detected a change reloading %v", CfgFileName))
if configReadErr := viper.ReadInConfig(); configReadErr == nil {
- LoadConfig(CfgFileName)
+ LoadGlobalConfig(CfgFileName)
} else {
l4g.Error(fmt.Sprintf("Failed to read while watching config file at %v with err=%v", CfgFileName, configReadErr.Error()))
}
@@ -274,83 +266,108 @@ func InitAndLoadConfig(filename string) error {
return err
}
- EnableConfigFromEnviromentVars()
- LoadConfig(filename)
+ LoadGlobalConfig(filename)
InitializeConfigWatch()
EnableConfigWatch()
return nil
}
-// LoadConfig will try to search around for the corresponding config file.
-// It will search /tmp/fileName then attempt ./config/fileName,
-// then ../config/fileName and last it will look at fileName
-func LoadConfig(fileName string) *model.Config {
- cfgMutex.Lock()
- defer cfgMutex.Unlock()
-
- // Cfg should never be null
- oldConfig := *Cfg
+// ReadConfig reads and parses the given configuration.
+func ReadConfig(r io.Reader, allowEnvironmentOverrides bool) (*model.Config, error) {
+ v := viper.New()
- fileNameWithExtension := filepath.Base(fileName)
- fileExtension := filepath.Ext(fileNameWithExtension)
- fileDir := filepath.Dir(fileName)
+ if allowEnvironmentOverrides {
+ v.SetEnvPrefix("mm")
+ v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
+ v.AutomaticEnv()
+ }
- if len(fileNameWithExtension) > 0 {
- fileNameOnly := fileNameWithExtension[:len(fileNameWithExtension)-len(fileExtension)]
- viper.SetConfigName(fileNameOnly)
- } else {
- viper.SetConfigName("config")
+ v.SetConfigType("json")
+ if err := v.ReadConfig(r); err != nil {
+ return nil, err
}
- if len(fileDir) > 0 {
- viper.AddConfigPath(fileDir)
+ var config model.Config
+ unmarshalErr := v.Unmarshal(&config)
+ if unmarshalErr == nil {
+ // https://github.com/spf13/viper/issues/324
+ // https://github.com/spf13/viper/issues/348
+ config.PluginSettings = model.PluginSettings{}
+ unmarshalErr = v.UnmarshalKey("pluginsettings", &config.PluginSettings)
}
+ return &config, unmarshalErr
+}
- viper.SetConfigType("json")
- viper.AddConfigPath("./config")
- viper.AddConfigPath("../config")
- viper.AddConfigPath("../../config")
- viper.AddConfigPath(".")
+// ReadConfigFile reads and parses the configuration at the given file path.
+func ReadConfigFile(path string, allowEnvironmentOverrides bool) (*model.Config, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ return ReadConfig(f, allowEnvironmentOverrides)
+}
- configReadErr := viper.ReadInConfig()
- if configReadErr != nil {
- if _, ok := configReadErr.(viper.ConfigFileNotFoundError); ok {
- // In case of a file-not-found error, try to copy default.json if it's present.
- defaultPath := FindConfigFile("default.json")
- if src, err := os.Open(defaultPath); err == nil {
- if dest, err := os.OpenFile(filepath.Join(filepath.Dir(defaultPath), "config.json"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600); err == nil {
- if _, err := io.Copy(dest, src); err == nil {
- configReadErr = viper.ReadInConfig()
- }
- dest.Close()
- }
- src.Close()
- }
+// EnsureConfigFile will attempt to locate a config file with the given name. If it does not exist,
+// it will attempt to locate a default config file, and copy it. In either case, the config file
+// path is returned.
+func EnsureConfigFile(fileName string) (string, error) {
+ if configFile := FindConfigFile(fileName); configFile != "" {
+ return configFile, nil
+ }
+ if defaultPath := FindConfigFile("default.json"); defaultPath != "" {
+ destPath := filepath.Join(filepath.Dir(defaultPath), fileName)
+ src, err := os.Open(defaultPath)
+ if err != nil {
+ return "", err
+ }
+ defer src.Close()
+ dest, err := os.OpenFile(destPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
+ if err != nil {
+ return "", err
+ }
+ defer dest.Close()
+ if _, err := io.Copy(dest, src); err == nil {
+ return destPath, nil
}
}
+ return "", fmt.Errorf("no config file found")
+}
- if configReadErr != nil {
- errMsg := T("utils.config.load_config.opening.panic", map[string]interface{}{"Filename": fileName, "Error": configReadErr.Error()})
- fmt.Fprintln(os.Stderr, errMsg)
- os.Exit(1)
- }
+// LoadGlobalConfig will try to search around for the corresponding config file. It will search
+// /tmp/fileName then attempt ./config/fileName, then ../config/fileName and last it will look at
+// fileName
+//
+// XXX: This is deprecated.
+func LoadGlobalConfig(fileName string) *model.Config {
+ cfgMutex.Lock()
+ defer cfgMutex.Unlock()
- var config model.Config
- unmarshalErr := viper.Unmarshal(&config)
- if unmarshalErr == nil {
- // https://github.com/spf13/viper/issues/324
- // https://github.com/spf13/viper/issues/348
- config.PluginSettings = model.PluginSettings{}
- unmarshalErr = viper.UnmarshalKey("pluginsettings", &config.PluginSettings)
+ // Cfg should never be null
+ oldConfig := *Cfg
+
+ var configPath string
+ if fileName != filepath.Base(fileName) {
+ configPath = fileName
+ } else {
+ if path, err := EnsureConfigFile(fileName); err != nil {
+ errMsg := T("utils.config.load_config.opening.panic", map[string]interface{}{"Filename": fileName, "Error": err.Error()})
+ fmt.Fprintln(os.Stderr, errMsg)
+ os.Exit(1)
+ } else {
+ configPath = path
+ }
}
- if unmarshalErr != nil {
- errMsg := T("utils.config.load_config.decoding.panic", map[string]interface{}{"Filename": fileName, "Error": unmarshalErr.Error()})
+
+ config, err := ReadConfigFile(configPath, true)
+ if err != nil {
+ errMsg := T("utils.config.load_config.decoding.panic", map[string]interface{}{"Filename": fileName, "Error": err.Error()})
fmt.Fprintln(os.Stderr, errMsg)
os.Exit(1)
}
- CfgFileName = viper.ConfigFileUsed()
+ CfgFileName = configPath
needSave := len(config.SqlSettings.AtRestEncryptKey) == 0 || len(*config.FileSettings.PublicLinkSalt) == 0 ||
len(config.EmailSettings.InviteSalt) == 0
@@ -363,16 +380,16 @@ func LoadConfig(fileName string) *model.Config {
if needSave {
cfgMutex.Unlock()
- if err := SaveConfig(CfgFileName, &config); err != nil {
+ if err := SaveConfig(CfgFileName, config); err != nil {
err.Translate(T)
l4g.Warn(err.Error())
}
cfgMutex.Lock()
}
- if err := ValidateLocales(&config); err != nil {
+ if err := ValidateLocales(config); err != nil {
cfgMutex.Unlock()
- if err := SaveConfig(CfgFileName, &config); err != nil {
+ if err := SaveConfig(CfgFileName, config); err != nil {
err.Translate(T)
l4g.Warn(err.Error())
}
@@ -388,7 +405,7 @@ func LoadConfig(fileName string) *model.Config {
}
}
- Cfg = &config
+ Cfg = config
CfgHash = fmt.Sprintf("%x", md5.Sum([]byte(Cfg.ToJson())))
ClientCfg = getClientConfig(Cfg)
clientCfgJson, _ := json.Marshal(ClientCfg)
@@ -398,10 +415,10 @@ func LoadConfig(fileName string) *model.Config {
SetSiteURL(*Cfg.ServiceSettings.SiteURL)
for _, listener := range cfgListeners {
- listener(&oldConfig, &config)
+ listener(&oldConfig, config)
}
- return &config
+ return config
}
func RegenerateClientConfig() {
diff --git a/utils/config_test.go b/utils/config_test.go
index 527718bbb..92d3c6fd4 100644
--- a/utils/config_test.go
+++ b/utils/config_test.go
@@ -9,25 +9,26 @@ import (
"testing"
"time"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
"github.com/mattermost/mattermost-server/model"
)
func TestConfig(t *testing.T) {
TranslationsPreInit()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
InitTranslations(Cfg.LocalizationSettings)
}
func TestConfigFromEnviroVars(t *testing.T) {
-
os.Setenv("MM_TEAMSETTINGS_SITENAME", "From Enviroment")
os.Setenv("MM_TEAMSETTINGS_CUSTOMBRANDTEXT", "Custom Brand")
os.Setenv("MM_SERVICESETTINGS_ENABLECOMMANDS", "false")
os.Setenv("MM_SERVICESETTINGS_READTIMEOUT", "400")
TranslationsPreInit()
- EnableConfigFromEnviromentVars()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
if Cfg.TeamSettings.SiteName != "From Enviroment" {
t.Fatal("Couldn't read config from enviroment var")
@@ -56,7 +57,7 @@ func TestConfigFromEnviroVars(t *testing.T) {
*Cfg.ServiceSettings.ReadTimeout = 300
SaveConfig(CfgFileName, Cfg)
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
if Cfg.TeamSettings.SiteName != "Mattermost" {
t.Fatal("should have been reset")
@@ -65,7 +66,7 @@ func TestConfigFromEnviroVars(t *testing.T) {
func TestRedirectStdLog(t *testing.T) {
TranslationsPreInit()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
InitTranslations(Cfg.LocalizationSettings)
log := NewRedirectStdLog("test", false)
@@ -110,8 +111,7 @@ func TestAddRemoveConfigListener(t *testing.T) {
func TestConfigListener(t *testing.T) {
TranslationsPreInit()
- EnableConfigFromEnviromentVars()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
SiteName := Cfg.TeamSettings.SiteName
defer func() {
@@ -148,7 +148,7 @@ func TestConfigListener(t *testing.T) {
listener2Id := AddConfigListener(listener2)
defer RemoveConfigListener(listener2Id)
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
if !listenerCalled {
t.Fatal("listener should've been called")
@@ -159,7 +159,7 @@ func TestConfigListener(t *testing.T) {
func TestValidateLocales(t *testing.T) {
TranslationsPreInit()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
defaultServerLocale := *Cfg.LocalizationSettings.DefaultServerLocale
defaultClientLocale := *Cfg.LocalizationSettings.DefaultClientLocale
@@ -278,10 +278,21 @@ func TestValidateLocales(t *testing.T) {
func TestGetClientConfig(t *testing.T) {
TranslationsPreInit()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
configMap := getClientConfig(Cfg)
if configMap["EmailNotificationContentsType"] != *Cfg.EmailSettings.EmailNotificationContentsType {
t.Fatal("EmailSettings.EmailNotificationContentsType not exposed to client config")
}
}
+
+func TestReadConfig(t *testing.T) {
+ config, err := ReadConfig(strings.NewReader(`{
+ "ServiceSettings": {
+ "SiteURL": "http://foo.bar"
+ }
+ }`), false)
+ require.NoError(t, err)
+
+ assert.Equal(t, "http://foo.bar", *config.ServiceSettings.SiteURL)
+}
diff --git a/utils/file_test.go b/utils/file_test.go
index ed2e8683b..5c7162450 100644
--- a/utils/file_test.go
+++ b/utils/file_test.go
@@ -43,7 +43,7 @@ func TestFileMinioTestSuite(t *testing.T) {
func (s *FileTestSuite) SetupTest() {
TranslationsPreInit()
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
InitTranslations(Cfg.LocalizationSettings)
// Save state to restore after the test has run.
diff --git a/utils/mail_test.go b/utils/mail_test.go
index 774ecbf5b..1d4643429 100644
--- a/utils/mail_test.go
+++ b/utils/mail_test.go
@@ -9,7 +9,7 @@ import (
)
func TestMailConnection(t *testing.T) {
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
if conn, err := connectToSMTPServer(Cfg); err != nil {
t.Log(err)
@@ -32,7 +32,7 @@ func TestMailConnection(t *testing.T) {
}
func TestSendMail(t *testing.T) {
- LoadConfig("config.json")
+ LoadGlobalConfig("config.json")
T = GetUserTranslations("en")
var emailTo string = "test@example.com"