summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--model/config.go2
-rw-r--r--model/config_test.go46
-rw-r--r--model/utils.go59
-rw-r--r--model/utils_test.go183
4 files changed, 289 insertions, 1 deletions
diff --git a/model/config.go b/model/config.go
index 8d1a61926..93533b8aa 100644
--- a/model/config.go
+++ b/model/config.go
@@ -1715,8 +1715,8 @@ func (s *MessageExportSettings) SetDefaults() {
if s.GlobalRelaySettings == nil {
s.GlobalRelaySettings = &GlobalRelayMessageExportSettings{}
- s.GlobalRelaySettings.SetDefaults()
}
+ s.GlobalRelaySettings.SetDefaults()
}
type DisplaySettings struct {
diff --git a/model/config_test.go b/model/config_test.go
index 1f917af27..b7533145b 100644
--- a/model/config_test.go
+++ b/model/config_test.go
@@ -4,11 +4,57 @@
package model
import (
+ "fmt"
+ "reflect"
"testing"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
+func TestConfigDefaults(t *testing.T) {
+ t.Parallel()
+
+ t.Run("somewhere nil when uninitialized", func(t *testing.T) {
+ c := Config{}
+ require.False(t, checkNowhereNil(t, "config", c))
+ })
+
+ t.Run("nowhere nil when initialized", func(t *testing.T) {
+ c := Config{}
+ c.SetDefaults()
+ require.True(t, checkNowhereNil(t, "config", c))
+ })
+
+ t.Run("nowhere nil when partially initialized", func(t *testing.T) {
+ var recursivelyUninitialize func(*Config, string, reflect.Value)
+ recursivelyUninitialize = func(config *Config, name string, v reflect.Value) {
+ if v.Type().Kind() == reflect.Ptr {
+ // Set every pointer we find in the tree to nil
+ v.Set(reflect.Zero(v.Type()))
+ require.True(t, v.IsNil())
+
+ // SetDefaults on the root config should make it non-nil, otherwise
+ // it means that SetDefaults isn't being called recursively in
+ // all cases.
+ config.SetDefaults()
+ if assert.False(t, v.IsNil(), "%s should be non-nil after SetDefaults()", name) {
+ recursivelyUninitialize(config, fmt.Sprintf("(*%s)", name), v.Elem())
+ }
+
+ } else if v.Type().Kind() == reflect.Struct {
+ for i := 0; i < v.NumField(); i++ {
+ recursivelyUninitialize(config, fmt.Sprintf("%s.%s", name, v.Type().Field(i).Name), v.Field(i))
+ }
+ }
+ }
+
+ c := Config{}
+ c.SetDefaults()
+ recursivelyUninitialize(&c, "config", reflect.ValueOf(&c).Elem())
+ })
+}
+
func TestConfigDefaultFileSettingsDirectory(t *testing.T) {
c1 := Config{}
c1.SetDefaults()
diff --git a/model/utils.go b/model/utils.go
index 72369852b..2d61b49f6 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -15,9 +15,11 @@ import (
"net/http"
"net/mail"
"net/url"
+ "reflect"
"regexp"
"strconv"
"strings"
+ "testing"
"time"
"unicode"
@@ -469,3 +471,60 @@ func IsValidId(value string) bool {
return true
}
+
+// checkNowhereNil checks that the given interface value is not nil, and if a struct, that all of
+// its public fields are also nowhere nil
+func checkNowhereNil(t *testing.T, name string, value interface{}) bool {
+ if value == nil {
+ return false
+ }
+
+ v := reflect.ValueOf(value)
+ switch v.Type().Kind() {
+ case reflect.Ptr:
+ if v.IsNil() {
+ t.Logf("%s was nil", name)
+ return false
+ }
+
+ return checkNowhereNil(t, fmt.Sprintf("(*%s)", name), v.Elem().Interface())
+
+ case reflect.Map:
+ if v.IsNil() {
+ t.Logf("%s was nil", name)
+ return false
+ }
+
+ // Don't check map values
+ return true
+
+ case reflect.Struct:
+ nowhereNil := true
+ for i := 0; i < v.NumField(); i++ {
+ f := v.Field(i)
+ // Ignore unexported fields
+ if v.Type().Field(i).PkgPath != "" {
+ continue
+ }
+
+ nowhereNil = nowhereNil && checkNowhereNil(t, fmt.Sprintf("%s.%s", name, v.Type().Field(i).Name), f.Interface())
+ }
+
+ return nowhereNil
+
+ case reflect.Array:
+ fallthrough
+ case reflect.Chan:
+ fallthrough
+ case reflect.Func:
+ fallthrough
+ case reflect.Interface:
+ fallthrough
+ case reflect.UnsafePointer:
+ t.Logf("unhandled field %s, type: %s", name, v.Type().Kind())
+ return false
+
+ default:
+ return true
+ }
+}
diff --git a/model/utils_test.go b/model/utils_test.go
index 411d7bf50..92354c0a1 100644
--- a/model/utils_test.go
+++ b/model/utils_test.go
@@ -7,6 +7,8 @@ import (
"net/http"
"strings"
"testing"
+
+ "github.com/stretchr/testify/require"
)
func TestNewId(t *testing.T) {
@@ -367,3 +369,184 @@ func TestIsValidId(t *testing.T) {
}
}
}
+
+func TestNowhereNil(t *testing.T) {
+ t.Parallel()
+
+ var nilStringPtr *string
+ var nonNilStringPtr *string = new(string)
+ var nilSlice []string
+ var nilStruct *struct{}
+ var nilMap map[bool]bool
+
+ var nowhereNilStruct = struct {
+ X *string
+ Y *string
+ }{
+ nonNilStringPtr,
+ nonNilStringPtr,
+ }
+ var somewhereNilStruct = struct {
+ X *string
+ Y *string
+ }{
+ nonNilStringPtr,
+ nilStringPtr,
+ }
+
+ var privateSomewhereNilStruct = struct {
+ X *string
+ y *string
+ }{
+ nonNilStringPtr,
+ nilStringPtr,
+ }
+
+ testCases := []struct {
+ Description string
+ Value interface{}
+ Expected bool
+ }{
+ {
+ "nil",
+ nil,
+ false,
+ },
+ {
+ "empty string",
+ "",
+ true,
+ },
+ {
+ "non-empty string",
+ "not empty!",
+ true,
+ },
+ {
+ "nil string pointer",
+ nilStringPtr,
+ false,
+ },
+ {
+ "non-nil string pointer",
+ nonNilStringPtr,
+ true,
+ },
+ {
+ "0",
+ 0,
+ true,
+ },
+ {
+ "1",
+ 1,
+ true,
+ },
+ {
+ "0 (int64)",
+ int64(0),
+ true,
+ },
+ {
+ "1 (int64)",
+ int64(1),
+ true,
+ },
+ {
+ "true",
+ true,
+ true,
+ },
+ {
+ "false",
+ false,
+ true,
+ },
+ {
+ "nil slice",
+ nilSlice,
+ // A nil slice is observably the same as an empty slice, so allow it.
+ true,
+ },
+ {
+ "empty slice",
+ []string{},
+ true,
+ },
+ {
+ "slice containing nils",
+ []*string{nil, nil},
+ true,
+ },
+ {
+ "nil map",
+ nilMap,
+ false,
+ },
+ {
+ "non-nil map",
+ make(map[bool]bool),
+ true,
+ },
+ {
+ "non-nil map containing nil",
+ map[bool]*string{true: nilStringPtr, false: nonNilStringPtr},
+ // Map values are not checked
+ true,
+ },
+ {
+ "nil struct",
+ nilStruct,
+ false,
+ },
+ {
+ "empty struct",
+ struct{}{},
+ true,
+ },
+ {
+ "struct containing no nil",
+ nowhereNilStruct,
+ true,
+ },
+ {
+ "struct containing nil",
+ somewhereNilStruct,
+ false,
+ },
+ {
+ "struct pointer containing no nil",
+ &nowhereNilStruct,
+ true,
+ },
+ {
+ "struct pointer containing nil",
+ &somewhereNilStruct,
+ false,
+ },
+ {
+ "struct containing private nil",
+ privateSomewhereNilStruct,
+ true,
+ },
+ {
+ "struct pointer containing private nil",
+ &privateSomewhereNilStruct,
+ true,
+ },
+ }
+
+ for _, testCase := range testCases {
+ testCase := testCase
+ t.Run(testCase.Description, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r != nil {
+ t.Errorf("panic: %v", r)
+ }
+ }()
+
+ t.Parallel()
+ require.Equal(t, testCase.Expected, checkNowhereNil(t, "value", testCase.Value))
+ })
+ }
+}