summaryrefslogtreecommitdiffstats
path: root/store/sqlstore/store_test.go
blob: 58065d65d42bb7245c685ad20210a6b86c5ce6f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See License.txt for license information.

package sqlstore

import (
	"os"
	"sync"
	"testing"

	"github.com/mattermost/mattermost-server/mlog"
	"github.com/mattermost/mattermost-server/model"
	"github.com/mattermost/mattermost-server/store"
	"github.com/mattermost/mattermost-server/store/storetest"
	"github.com/mattermost/mattermost-server/utils"
)

// storeTypes lists the database backends that StoreTest runs every store test
// against.
//
// Name labels the subtest. Func is called by initStores to obtain a running
// container together with the SqlSettings used to build the store for it.
// Container and Store start out nil; initStores fills them in, and
// tearDownStores closes/stops whichever of them are non-nil.
var storeTypes = []*struct {
	Name      string
	Func      func() (*storetest.RunningContainer, *model.SqlSettings, error)
	Container *storetest.RunningContainer
	Store     store.Store
}{
	{
		Name: "MySQL",
		Func: storetest.NewMySQLContainer,
	},
	{
		Name: "PostgreSQL",
		Func: storetest.NewPostgreSQLContainer,
	},
}

// StoreTest runs f as a subtest against every backend in storeTypes. Each
// subtest is named after the backend and receives that backend's store.
//
// If f panics, the shared stores and containers are torn down before the panic
// is re-raised. The recover has to live inside the subtest closure: t.Run
// executes f on its own goroutine, and recover only observes panics raised on
// the goroutine it runs on. A deferred recover in StoreTest alone would never
// see them, and the containers would be left running.
func StoreTest(t *testing.T, f func(*testing.T, store.Store)) {
	// Guards panics raised on the caller's goroutine (outside the subtests).
	defer func() {
		if err := recover(); err != nil {
			tearDownStores()
			panic(err)
		}
	}()
	for _, st := range storeTypes {
		st := st
		t.Run(st.Name, func(t *testing.T) {
			// Guards panics raised by f on the subtest goroutine. A t.Fatal
			// exits via runtime.Goexit, where recover returns nil, so ordinary
			// test failures do not trigger a teardown.
			defer func() {
				if err := recover(); err != nil {
					tearDownStores()
					panic(err)
				}
			}()
			f(t, st.Store)
		})
	}
}

// initStores sets up every entry in storeTypes concurrently. Each entry's Func
// is called to obtain a container and its SqlSettings. A layered SQL store is
// built from those settings, and MarkSystemRanUnitTests is called on it.
//
// Any failure aborts setup. This covers an error returned by Func and a panic
// raised while building a store. Whatever was already started is torn down,
// and the failure is re-raised as a panic on the calling goroutine.
func initStores() {
	defer func() {
		if err := recover(); err != nil {
			tearDownStores()
			panic(err)
		}
	}()
	var wg sync.WaitGroup
	// failures receives either an error returned by st.Func or a value
	// recovered from a panic in a worker. Each worker sends at most once, so a
	// buffer of len(storeTypes) guarantees no worker ever blocks on send.
	failures := make(chan interface{}, len(storeTypes))
	wg.Add(len(storeTypes))
	for _, st := range storeTypes {
		st := st
		go func() {
			defer wg.Done()
			// recover only sees panics from its own goroutine, so the deferred
			// recover in initStores cannot catch panics raised here. Forward
			// them to be re-raised after wg.Wait. This defer runs before
			// wg.Done, so the value is queued before Wait returns.
			defer func() {
				if r := recover(); r != nil {
					failures <- r
				}
			}()
			container, settings, err := st.Func()
			if err != nil {
				failures <- err
				return
			}
			// Record the container before building the store so a panic in
			// the store setup still leaves it visible to tearDownStores.
			st.Container = container
			st.Store = store.NewLayeredStore(NewSqlSupplier(*settings, nil), nil, nil)
			st.Store.MarkSystemRanUnitTests()
		}()
	}
	wg.Wait()
	select {
	case failure := <-failures:
		panic(failure)
	default:
	}
}

// tearDownStoresOnce makes tearDownStores perform its cleanup only once. The
// function is reachable from several panic-recovery paths as well as from
// TestMain.
var tearDownStoresOnce sync.Once

// tearDownStores closes each initialized store and stops each started
// container. All entries of storeTypes are handled in parallel, and the call
// blocks until every one has finished.
//
// Nil Store or Container fields are skipped, so it is safe to call after a
// partial initStores. Calls after the first are no-ops.
func tearDownStores() {
	tearDownStoresOnce.Do(func() {
		var done sync.WaitGroup
		done.Add(len(storeTypes))
		for i := range storeTypes {
			go func(idx int) {
				defer done.Done()
				entry := storeTypes[idx]
				if entry.Store != nil {
					entry.Store.Close()
				}
				if entry.Container != nil {
					entry.Container.Stop()
				}
			}(i)
		}
		done.Wait()
	})
}

// TestMain prepares the package's test run and then executes it. It installs
// a global logger, preloads translations and starts the shared database
// stores. It then runs the tests, tears the stores down and exits with the
// status reported by m.Run.
//
// Teardown is deferred inside a closure rather than calling os.Exit from a
// deferred function. os.Exit inside a defer would run while m.Run is
// panicking and terminate the process immediately. The panic would be
// swallowed and the run would exit with status 0. In this shape a panic still
// tears the stores down and then propagates normally.
func TestMain(m *testing.M) {
	// Setup a global logger to catch tests logging outside of app context
	// The global logger will be stomped by apps initializing but that's fine for testing. Ideally this won't happen.
	mlog.InitGlobalLogger(mlog.NewLogger(&mlog.LoggerConfiguration{
		EnableConsole: true,
		ConsoleJson:   true,
		ConsoleLevel:  "error",
		EnableFile:    false,
	}))

	utils.TranslationsPreInit()

	// initStores tears down after itself on failure, so no defer is needed yet.
	initStores()

	status := func() int {
		defer tearDownStores()
		return m.Run()
	}()

	os.Exit(status)
}