summaryrefslogtreecommitdiffstats
path: root/store/sqlstore/store_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'store/sqlstore/store_test.go')
-rw-r--r--store/sqlstore/store_test.go104
1 files changed, 96 insertions, 8 deletions
diff --git a/store/sqlstore/store_test.go b/store/sqlstore/store_test.go
index 605c73b6a..d99c7e441 100644
--- a/store/sqlstore/store_test.go
+++ b/store/sqlstore/store_test.go
@@ -4,22 +4,110 @@
package sqlstore
import (
+ "flag"
+ "os"
+ "sync"
"testing"
+ "github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/store"
+ "github.com/mattermost/mattermost-server/store/storetest"
"github.com/mattermost/mattermost-server/utils"
)
-var sqlStore store.Store
+var storeTypes = []*struct {
+ Name string
+ Func func() (*storetest.RunningContainer, *model.SqlSettings, error)
+ Container *storetest.RunningContainer
+ Store store.Store
+}{
+ {
+ Name: "MySQL",
+ Func: storetest.NewMySQLContainer,
+ },
+ {
+ Name: "PostgreSQL",
+ Func: storetest.NewPostgreSQLContainer,
+ },
+}
func StoreTest(t *testing.T, f func(*testing.T, store.Store)) {
- if sqlStore == nil {
- utils.TranslationsPreInit()
- utils.LoadConfig("config.json")
- utils.InitTranslations(utils.Cfg.LocalizationSettings)
- sqlStore = store.NewLayeredStore(NewSqlSupplier(nil), nil, nil)
+ defer func() {
+ if err := recover(); err != nil {
+ tearDownStores()
+ panic(err)
+ }
+ }()
+ for _, st := range storeTypes {
+ st := st
+ t.Run(st.Name, func(t *testing.T) { f(t, st.Store) })
+ }
+}
- sqlStore.MarkSystemRanUnitTests()
+func initStores() {
+ defer func() {
+ if err := recover(); err != nil {
+ tearDownStores()
+ panic(err)
+ }
+ }()
+ var wg sync.WaitGroup
+ errCh := make(chan error, len(storeTypes))
+ wg.Add(len(storeTypes))
+ for _, st := range storeTypes {
+ st := st
+ go func() {
+ defer wg.Done()
+ container, settings, err := st.Func()
+ if err != nil {
+ errCh <- err
+ return
+ }
+ st.Container = container
+ st.Store = store.NewLayeredStore(NewSqlSupplier(*settings, nil), nil, nil)
+ st.Store.MarkSystemRanUnitTests()
+ }()
+ }
+ wg.Wait()
+ select {
+ case err := <-errCh:
+ panic(err)
+ default:
}
- f(t, sqlStore)
+}
+
+var tearDownStoresOnce sync.Once
+
+func tearDownStores() {
+ tearDownStoresOnce.Do(func() {
+ var wg sync.WaitGroup
+ wg.Add(len(storeTypes))
+ for _, st := range storeTypes {
+ st := st
+ go func() {
+ st.Store.Close()
+ st.Container.Stop()
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ })
+}
+
+func TestMain(m *testing.M) {
+ flag.Parse()
+
+ utils.TranslationsPreInit()
+ utils.LoadConfig("config.json")
+ utils.InitTranslations(utils.Cfg.LocalizationSettings)
+
+ status := 0
+
+ initStores()
+ defer func() {
+ tearDownStores()
+ os.Exit(status)
+ }()
+
+ status = m.Run()
}