diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/config.go | 178 | ||||
-rw-r--r-- | utils/config_test.go | 162 | ||||
-rw-r--r-- | utils/mail_test.go | 6 |
3 files changed, 274 insertions, 72 deletions
diff --git a/utils/config.go b/utils/config.go index 6026f43f9..fa436f70d 100644 --- a/utils/config.go +++ b/utils/config.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "os" "path/filepath" + "reflect" "strconv" "strings" @@ -188,7 +189,7 @@ func NewConfigWatcher(cfgFileName string, f func()) (*ConfigWatcher, error) { if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create { l4g.Info(fmt.Sprintf("Config file watcher detected a change reloading %v", cfgFileName)) - if _, configReadErr := ReadConfigFile(cfgFileName, true); configReadErr == nil { + if _, _, configReadErr := ReadConfigFile(cfgFileName, true); configReadErr == nil { f() } else { l4g.Error(fmt.Sprintf("Failed to read while watching config file at %v with err=%v", cfgFileName, configReadErr.Error())) @@ -212,18 +213,11 @@ func (w *ConfigWatcher) Close() { } // ReadConfig reads and parses the given configuration. -func ReadConfig(r io.Reader, allowEnvironmentOverrides bool) (*model.Config, error) { - v := viper.New() +func ReadConfig(r io.Reader, allowEnvironmentOverrides bool) (*model.Config, map[string]interface{}, error) { + v := newViper(allowEnvironmentOverrides) - if allowEnvironmentOverrides { - v.SetEnvPrefix("mm") - v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - v.AutomaticEnv() - } - - v.SetConfigType("json") if err := v.ReadConfig(r); err != nil { - return nil, err + return nil, nil, err } var config model.Config @@ -234,14 +228,151 @@ func ReadConfig(r io.Reader, allowEnvironmentOverrides bool) (*model.Config, err config.PluginSettings = model.PluginSettings{} unmarshalErr = v.UnmarshalKey("pluginsettings", &config.PluginSettings) } - return &config, unmarshalErr + + envConfig := v.EnvSettings() + + var envErr error + if envConfig, envErr = fixEnvSettingsCase(envConfig); envErr != nil { + return nil, nil, envErr + } + + return &config, envConfig, unmarshalErr +} + +func newViper(allowEnvironmentOverrides bool) *viper.Viper { + v := viper.New() + + v.SetConfigType("json") + + if allowEnvironmentOverrides { + v.SetEnvPrefix("mm") + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.AutomaticEnv() + } + + // Set zeroed defaults for all the config settings so that Viper knows what environment variables + // it needs to be looking for. The correct defaults will later be applied using Config.SetDefaults. + defaults := flattenStructToMap(structToMap(reflect.TypeOf(model.Config{}))) + + for key, value := range defaults { + v.SetDefault(key, value) + } + + return v +} + +// Converts a struct type into a nested map with keys matching the struct's fields and values +// matching the zeroed value of the corresponding field. +func structToMap(t reflect.Type) (out map[string]interface{}) { + defer func() { + if r := recover(); r != nil { + l4g.Error("Panicked in structToMap. This should never happen. %v", r) + } + }() + + if t.Kind() != reflect.Struct { + // Should never hit this, but this will prevent a panic if that does happen somehow + return nil + } + + out = map[string]interface{}{} + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + var value interface{} + + switch field.Type.Kind() { + case reflect.Struct: + value = structToMap(field.Type) + case reflect.Ptr: + value = nil + default: + value = reflect.Zero(field.Type).Interface() + } + + out[field.Name] = value + } + + return +} + +// Flattens a nested map so that the result is a single map with keys corresponding to the +// path through the original map. For example, +// { +// "a": { +// "b": 1 +// }, +// "c": "sea" +// } +// would flatten to +// { +// "a.b": 1, +// "c": "sea" +// } +func flattenStructToMap(in map[string]interface{}) map[string]interface{} { + out := make(map[string]interface{}) + + for key, value := range in { + if valueAsMap, ok := value.(map[string]interface{}); ok { + sub := flattenStructToMap(valueAsMap) + + for subKey, subValue := range sub { + out[key+"."+subKey] = subValue + } + } else { + out[key] = value + } + } + + return out +} + +// Fixes the case of the environment variables sent back from Viper since Viper stores +// everything as lower case. +func fixEnvSettingsCase(in map[string]interface{}) (out map[string]interface{}, err error) { + defer func() { + if r := recover(); r != nil { + l4g.Error("Panicked in fixEnvSettingsCase. This should never happen. %v", r) + out = in + } + }() + + var fixCase func(map[string]interface{}, reflect.Type) map[string]interface{} + fixCase = func(in map[string]interface{}, t reflect.Type) map[string]interface{} { + if t.Kind() != reflect.Struct { + // Should never hit this, but this will prevent a panic if that does happen somehow + return nil + } + + out := make(map[string]interface{}, len(in)) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + key := field.Name + if value, ok := in[strings.ToLower(key)]; ok { + if valueAsMap, ok := value.(map[string]interface{}); ok { + out[key] = fixCase(valueAsMap, field.Type) + } else { + out[key] = value + } + } + } + + return out + } + + out = fixCase(in, reflect.TypeOf(model.Config{})) + + return } // ReadConfigFile reads and parses the configuration at the given file path. -func ReadConfigFile(path string, allowEnvironmentOverrides bool) (*model.Config, error) { +func ReadConfigFile(path string, allowEnvironmentOverrides bool) (*model.Config, map[string]interface{}, error) { f, err := os.Open(path) if err != nil { - return nil, err + return nil, nil, err } defer f.Close() return ReadConfig(f, allowEnvironmentOverrides) @@ -276,22 +407,24 @@ func EnsureConfigFile(fileName string) (string, error) { // LoadConfig will try to search around for the corresponding config file. It will search // /tmp/fileName then attempt ./config/fileName, then ../config/fileName and last it will look at // fileName. -func LoadConfig(fileName string) (config *model.Config, configPath string, appErr *model.AppError) { +func LoadConfig(fileName string) (*model.Config, string, map[string]interface{}, *model.AppError) { + var configPath string + if fileName != filepath.Base(fileName) { configPath = fileName } else { if path, err := EnsureConfigFile(fileName); err != nil { - appErr = model.NewAppError("LoadConfig", "utils.config.load_config.opening.panic", map[string]interface{}{"Filename": fileName, "Error": err.Error()}, "", 0) - return + appErr := model.NewAppError("LoadConfig", "utils.config.load_config.opening.panic", map[string]interface{}{"Filename": fileName, "Error": err.Error()}, "", 0) + return nil, "", nil, appErr } else { configPath = path } } - config, err := ReadConfigFile(configPath, true) + config, envConfig, err := ReadConfigFile(configPath, true) if err != nil { - appErr = model.NewAppError("LoadConfig", "utils.config.load_config.decoding.panic", map[string]interface{}{"Filename": fileName, "Error": err.Error()}, "", 0) - return + appErr := model.NewAppError("LoadConfig", "utils.config.load_config.decoding.panic", map[string]interface{}{"Filename": fileName, "Error": err.Error()}, "", 0) + return nil, "", nil, appErr } needSave := len(config.SqlSettings.AtRestEncryptKey) == 0 || len(*config.FileSettings.PublicLinkSalt) == 0 || @@ -300,7 +433,7 @@ func LoadConfig(fileName string) (config *model.Config, configPath string, appEr config.SetDefaults() if err := config.IsValid(); err != nil { - return nil, "", err + return nil, "", nil, err } if needSave { @@ -322,7 +455,7 @@ func LoadConfig(fileName string) (config *model.Config, configPath string, appEr } } - return config, configPath, nil + return config, configPath, envConfig, nil } func GenerateClientConfig(c *model.Config, diagnosticId string, license *model.License) map[string]string { @@ -383,6 +516,7 @@ func GenerateClientConfig(c *model.Config, diagnosticId string, license *model.L props["EnableTutorial"] = strconv.FormatBool(*c.ServiceSettings.EnableTutorial) props["ExperimentalEnableDefaultChannelLeaveJoinMessages"] = strconv.FormatBool(*c.ServiceSettings.ExperimentalEnableDefaultChannelLeaveJoinMessages) props["ExperimentalGroupUnreadChannels"] = *c.ServiceSettings.ExperimentalGroupUnreadChannels + props["ExperimentalEnableAutomaticReplies"] = strconv.FormatBool(*c.TeamSettings.ExperimentalEnableAutomaticReplies) props["ExperimentalTimezone"] = strconv.FormatBool(*c.DisplaySettings.ExperimentalTimezone) props["SendEmailNotifications"] = strconv.FormatBool(c.EmailSettings.SendEmailNotifications) diff --git a/utils/config_test.go b/utils/config_test.go index 84e7291b0..fbac577ee 100644 --- a/utils/config_test.go +++ b/utils/config_test.go @@ -18,7 +18,7 @@ import ( func TestConfig(t *testing.T) { TranslationsPreInit() - cfg, _, err := LoadConfig("config.json") + cfg, _, _, err := LoadConfig("config.json") require.Nil(t, err) InitTranslations(cfg.LocalizationSettings) } @@ -50,53 +50,133 @@ func TestFindConfigFile(t *testing.T) { } func TestConfigFromEnviroVars(t *testing.T) { - os.Setenv("MM_TEAMSETTINGS_SITENAME", "From Environment") - os.Setenv("MM_TEAMSETTINGS_CUSTOMBRANDTEXT", "Custom Brand") - os.Setenv("MM_SERVICESETTINGS_ENABLECOMMANDS", "false") - os.Setenv("MM_SERVICESETTINGS_READTIMEOUT", "400") - TranslationsPreInit() - cfg, cfgPath, err := LoadConfig("config.json") - require.Nil(t, err) - if cfg.TeamSettings.SiteName != "From Environment" { - t.Fatal("Couldn't read config from environment var") - } + config := `{ + "ServiceSettings": { + "EnableCommands": true, + "ReadTimeout": 100 + }, + "TeamSettings": { + "SiteName": "Mattermost", + "CustomBrandText": "" + } + }` - if *cfg.TeamSettings.CustomBrandText != "Custom Brand" { - t.Fatal("Couldn't read config from environment var") - } + t.Run("string settings", func(t *testing.T) { + os.Setenv("MM_TEAMSETTINGS_SITENAME", "From Environment") + os.Setenv("MM_TEAMSETTINGS_CUSTOMBRANDTEXT", "Custom Brand") - if *cfg.ServiceSettings.EnableCommands { - t.Fatal("Couldn't read config from environment var") - } + cfg, envCfg, err := ReadConfig(strings.NewReader(config), true) + require.Nil(t, err) - if *cfg.ServiceSettings.ReadTimeout != 400 { - t.Fatal("Couldn't read config from environment var") - } + if cfg.TeamSettings.SiteName != "From Environment" { + t.Fatal("Couldn't read config from environment var") + } - os.Unsetenv("MM_TEAMSETTINGS_SITENAME") - os.Unsetenv("MM_TEAMSETTINGS_CUSTOMBRANDTEXT") - os.Unsetenv("MM_SERVICESETTINGS_ENABLECOMMANDS") - os.Unsetenv("MM_SERVICESETTINGS_READTIMEOUT") + if *cfg.TeamSettings.CustomBrandText != "Custom Brand" { + t.Fatal("Couldn't read config from environment var") + } - cfg.TeamSettings.SiteName = "Mattermost" - *cfg.ServiceSettings.SiteURL = "" - *cfg.ServiceSettings.EnableCommands = true - *cfg.ServiceSettings.ReadTimeout = 300 - SaveConfig(cfgPath, cfg) + if teamSettings, ok := envCfg["TeamSettings"]; !ok { + t.Fatal("TeamSettings is missing from envConfig") + } else if teamSettingsAsMap, ok := teamSettings.(map[string]interface{}); !ok { + t.Fatal("TeamSettings is not a map in envConfig") + } else { + if siteNameInEnv, ok := teamSettingsAsMap["SiteName"].(bool); !ok || !siteNameInEnv { + t.Fatal("SiteName should be in envConfig") + } - cfg, _, err = LoadConfig("config.json") - require.Nil(t, err) + if customBrandTextInEnv, ok := teamSettingsAsMap["CustomBrandText"].(bool); !ok || !customBrandTextInEnv { + t.Fatal("SiteName should be in envConfig") + } + } - if cfg.TeamSettings.SiteName != "Mattermost" { - t.Fatal("should have been reset") - } + os.Unsetenv("MM_TEAMSETTINGS_SITENAME") + os.Unsetenv("MM_TEAMSETTINGS_CUSTOMBRANDTEXT") + + cfg, envCfg, err = ReadConfig(strings.NewReader(config), true) + require.Nil(t, err) + + if cfg.TeamSettings.SiteName != "Mattermost" { + t.Fatal("should have been reset") + } + + if _, ok := envCfg["TeamSettings"]; ok { + t.Fatal("TeamSettings should be missing from envConfig") + } + }) + + t.Run("boolean setting", func(t *testing.T) { + os.Setenv("MM_SERVICESETTINGS_ENABLECOMMANDS", "false") + defer os.Unsetenv("MM_SERVICESETTINGS_ENABLECOMMANDS") + + cfg, envCfg, err := ReadConfig(strings.NewReader(config), true) + require.Nil(t, err) + + if *cfg.ServiceSettings.EnableCommands { + t.Fatal("Couldn't read config from environment var") + } + + if serviceSettings, ok := envCfg["ServiceSettings"]; !ok { + t.Fatal("ServiceSettings is missing from envConfig") + } else if serviceSettingsAsMap, ok := serviceSettings.(map[string]interface{}); !ok { + t.Fatal("ServiceSettings is not a map in envConfig") + } else { + if enableCommandsInEnv, ok := serviceSettingsAsMap["EnableCommands"].(bool); !ok || !enableCommandsInEnv { + t.Fatal("EnableCommands should be in envConfig") + } + } + }) + + t.Run("integer setting", func(t *testing.T) { + os.Setenv("MM_SERVICESETTINGS_READTIMEOUT", "400") + defer os.Unsetenv("MM_SERVICESETTINGS_READTIMEOUT") + + cfg, envCfg, err := ReadConfig(strings.NewReader(config), true) + require.Nil(t, err) + + if *cfg.ServiceSettings.ReadTimeout != 400 { + t.Fatal("Couldn't read config from environment var") + } + + if serviceSettings, ok := envCfg["ServiceSettings"]; !ok { + t.Fatal("ServiceSettings is missing from envConfig") + } else if serviceSettingsAsMap, ok := serviceSettings.(map[string]interface{}); !ok { + t.Fatal("ServiceSettings is not a map in envConfig") + } else { + if readTimeoutInEnv, ok := serviceSettingsAsMap["ReadTimeout"].(bool); !ok || !readTimeoutInEnv { + t.Fatal("ReadTimeout should be in envConfig") + } + } + }) + + t.Run("setting missing from config.json", func(t *testing.T) { + os.Setenv("MM_SERVICESETTINGS_SITEURL", "https://example.com") + defer os.Unsetenv("MM_SERVICESETTINGS_SITEURL") + + cfg, envCfg, err := ReadConfig(strings.NewReader(config), true) + require.Nil(t, err) + + if *cfg.ServiceSettings.SiteURL != "https://example.com" { + t.Fatal("Couldn't read config from environment var") + } + + if serviceSettings, ok := envCfg["ServiceSettings"]; !ok { + t.Fatal("ServiceSettings is missing from envConfig") + } else if serviceSettingsAsMap, ok := serviceSettings.(map[string]interface{}); !ok { + t.Fatal("ServiceSettings is not a map in envConfig") + } else { + if siteURLInEnv, ok := serviceSettingsAsMap["SiteURL"].(bool); !ok || !siteURLInEnv { + t.Fatal("SiteURL should be in envConfig") + } + } + }) } func TestValidateLocales(t *testing.T) { TranslationsPreInit() - cfg, _, err := LoadConfig("config.json") + cfg, _, _, err := LoadConfig("config.json") require.Nil(t, err) *cfg.LocalizationSettings.DefaultServerLocale = "en" @@ -294,18 +374,6 @@ func TestGetClientConfig(t *testing.T) { } }) } - -} - -func TestReadConfig(t *testing.T) { - config, err := ReadConfig(strings.NewReader(`{ - "ServiceSettings": { - "SiteURL": "http://foo.bar" - } - }`), false) - require.NoError(t, err) - - assert.Equal(t, "http://foo.bar", *config.ServiceSettings.SiteURL) } func sToP(s string) *string { diff --git a/utils/mail_test.go b/utils/mail_test.go index 65b89c240..6bd8e7044 100644 --- a/utils/mail_test.go +++ b/utils/mail_test.go @@ -13,7 +13,7 @@ import ( ) func TestMailConnectionFromConfig(t *testing.T) { - cfg, _, err := LoadConfig("config.json") + cfg, _, _, err := LoadConfig("config.json") require.Nil(t, err) if conn, err := ConnectToSMTPServer(cfg); err != nil { @@ -36,7 +36,7 @@ func TestMailConnectionFromConfig(t *testing.T) { } func TestMailConnectionAdvanced(t *testing.T) { - cfg, _, err := LoadConfig("config.json") + cfg, _, _, err := LoadConfig("config.json") require.Nil(t, err) if conn, err := ConnectToSMTPServerAdvanced( @@ -86,7 +86,7 @@ func TestMailConnectionAdvanced(t *testing.T) { } func TestSendMailUsingConfig(t *testing.T) { - cfg, _, err := LoadConfig("config.json") + cfg, _, _, err := LoadConfig("config.json") require.Nil(t, err) T = GetUserTranslations("en") |