From 07777f5ff9e0bde26abd0288164e5f73b6da992a Mon Sep 17 00:00:00 2001 From: Chris Date: Wed, 4 Oct 2017 13:09:41 -0700 Subject: Fix races / finally remove global app for good (#7570) * finally remove global app for good * test compilation fixes * fix races * fix deadlock * wake up write pump so it doesn't take forever to clean up --- api/admin_test.go | 5 ---- api4/system_test.go | 5 ---- app/app.go | 56 ++++++++++++++------------------------ app/authorization_test.go | 2 -- app/channel_test.go | 2 ++ app/cluster_discovery_test.go | 1 + app/command_channel_rename_test.go | 1 + app/command_test.go | 1 + app/diagnostics_test.go | 1 + app/email_batching_test.go | 10 +++++++ app/file_test.go | 1 + app/import_test.go | 10 +++++-- app/job_test.go | 2 ++ app/license_test.go | 3 ++ app/notification_test.go | 29 ++++++++++++++++++-- app/oauth_test.go | 3 ++ app/post_test.go | 3 ++ app/session_test.go | 2 ++ app/team_test.go | 6 ++++ app/user_test.go | 8 ++++++ app/web_conn.go | 34 +++++++++++++++-------- app/web_hub.go | 10 +++++-- store/store.go | 11 ++++++++ 23 files changed, 139 insertions(+), 67 deletions(-) diff --git a/api/admin_test.go b/api/admin_test.go index dadc96c7d..e4ff1c202 100644 --- a/api/admin_test.go +++ b/api/admin_test.go @@ -8,16 +8,11 @@ import ( "strings" "testing" - "github.com/mattermost/mattermost-server/app" "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/store" "github.com/mattermost/mattermost-server/utils" ) -func init() { - app.UseGlobalApp = false -} - func TestGetLogs(t *testing.T) { th := Setup().InitSystemAdmin().InitBasic() defer th.TearDown() diff --git a/api4/system_test.go b/api4/system_test.go index 76dac5316..2855e5840 100644 --- a/api4/system_test.go +++ b/api4/system_test.go @@ -6,15 +6,10 @@ import ( "testing" l4g "github.com/alecthomas/log4go" - "github.com/mattermost/mattermost-server/app" "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/utils" ) -func init() { - app.UseGlobalApp = false -} - func TestGetPing(t *testing.T) { th := Setup().InitBasic().InitSystemAdmin() defer th.TearDown() diff --git a/app/app.go b/app/app.go index d0d5bb4e0..a250efe5c 100644 --- a/app/app.go +++ b/app/app.go @@ -6,7 +6,6 @@ package app import ( "io/ioutil" "net/http" - "sync" "sync/atomic" l4g "github.com/alecthomas/log4go" @@ -47,56 +46,41 @@ type App struct { Saml einterfaces.SamlInterface } -var globalApp App = App{ - goroutineExitSignal: make(chan struct{}, 1), - Jobs: &jobs.JobServer{}, -} - var appCount = 0 -var initEnterprise sync.Once - -var UseGlobalApp = true // New creates a new App. You must call Shutdown when you're done with it. -// XXX: Doesn't necessarily create a new App yet. +// XXX: For now, only one at a time is allowed as some resources are still shared. func New() *App { appCount++ - - if !UseGlobalApp { - if appCount > 1 { - panic("Only one App should exist at a time. Did you forget to call Shutdown()?") - } - app := &App{ - goroutineExitSignal: make(chan struct{}, 1), - Jobs: &jobs.JobServer{}, - } - app.initEnterprise() - return app + if appCount > 1 { + panic("Only one App should exist at a time. Did you forget to call Shutdown()?") } - initEnterprise.Do(func() { - globalApp.initEnterprise() - }) - return &globalApp + app := &App{ + goroutineExitSignal: make(chan struct{}, 1), + Jobs: &jobs.JobServer{}, + } + app.initEnterprise() + return app } func (a *App) Shutdown() { appCount-- - if appCount == 0 { - if a.Srv != nil { - l4g.Info(utils.T("api.server.stop_server.stopping.info")) - a.Srv.GracefulServer.Stop(TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN) - a.Srv.Store.Close() - a.HubStop() + if a.Srv != nil { + l4g.Info(utils.T("api.server.stop_server.stopping.info")) - a.ShutDownPlugins() - a.WaitForGoroutines() + a.Srv.GracefulServer.Stop(TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN) + <-a.Srv.GracefulServer.StopChan() + a.HubStop() - a.Srv = nil + a.ShutDownPlugins() + a.WaitForGoroutines() - l4g.Info(utils.T("api.server.stop_server.stopped.info")) - } + a.Srv.Store.Close() + a.Srv = nil + + l4g.Info(utils.T("api.server.stop_server.stopped.info")) } } diff --git a/app/authorization_test.go b/app/authorization_test.go index 279bf17dc..375b279dc 100644 --- a/app/authorization_test.go +++ b/app/authorization_test.go @@ -10,8 +10,6 @@ import ( ) func TestCheckIfRolesGrantPermission(t *testing.T) { - Setup() - cases := []struct { roles []string permissionId string diff --git a/app/channel_test.go b/app/channel_test.go index 34a9d8150..b1d403896 100644 --- a/app/channel_test.go +++ b/app/channel_test.go @@ -9,6 +9,7 @@ import ( func TestPermanentDeleteChannel(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() incomingWasEnabled := utils.Cfg.ServiceSettings.EnableIncomingWebhooks outgoingWasEnabled := utils.Cfg.ServiceSettings.EnableOutgoingWebhooks @@ -67,6 +68,7 @@ func TestPermanentDeleteChannel(t *testing.T) { func TestMoveChannel(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() sourceTeam := th.CreateTeam() targetTeam := th.CreateTeam() diff --git a/app/cluster_discovery_test.go b/app/cluster_discovery_test.go index cd61c0f03..c9d7e814d 100644 --- a/app/cluster_discovery_test.go +++ b/app/cluster_discovery_test.go @@ -13,6 +13,7 @@ import ( func TestClusterDiscoveryService(t *testing.T) { th := Setup() + defer th.TearDown() ds := th.App.NewClusterDiscoveryService() ds.Type = model.CDS_TYPE_APP diff --git a/app/command_channel_rename_test.go b/app/command_channel_rename_test.go index ed4186c7f..372e366b9 100644 --- a/app/command_channel_rename_test.go +++ b/app/command_channel_rename_test.go @@ -9,6 +9,7 @@ import ( func TestRenameProviderDoCommand(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() rp := RenameProvider{} args := &model.CommandArgs{ diff --git a/app/command_test.go b/app/command_test.go index be1da3ac7..b37e78ea9 100644 --- a/app/command_test.go +++ b/app/command_test.go @@ -13,6 +13,7 @@ import ( func TestMoveCommand(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() sourceTeam := th.CreateTeam() targetTeam := th.CreateTeam() diff --git a/app/diagnostics_test.go b/app/diagnostics_test.go index 034b4cc9d..a7d879a7f 100644 --- a/app/diagnostics_test.go +++ b/app/diagnostics_test.go @@ -48,6 +48,7 @@ func TestPluginSetting(t *testing.T) { func TestDiagnostics(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() if testing.Short() { t.SkipNow() diff --git a/app/email_batching_test.go b/app/email_batching_test.go index b69eeec2d..4ebccf334 100644 --- a/app/email_batching_test.go +++ b/app/email_batching_test.go @@ -14,6 +14,7 @@ import ( func TestHandleNewNotifications(t *testing.T) { th := Setup() + defer th.TearDown() id1 := model.NewId() id2 := model.NewId() @@ -94,6 +95,7 @@ func TestHandleNewNotifications(t *testing.T) { func TestCheckPendingNotifications(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() job := NewEmailBatchingJob(th.App, 128) job.pendingNotifications[th.BasicUser.Id] = []*batchedNotification{ @@ -201,6 +203,8 @@ func TestCheckPendingNotifications(t *testing.T) { */ func TestCheckPendingNotificationsDefaultInterval(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() + job := NewEmailBatchingJob(th.App, 128) // bypasses recent user activity check @@ -237,6 +241,8 @@ func TestCheckPendingNotificationsDefaultInterval(t *testing.T) { */ func TestCheckPendingNotificationsCantParseInterval(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() + job := NewEmailBatchingJob(th.App, 128) // bypasses recent user activity check @@ -281,6 +287,8 @@ func TestCheckPendingNotificationsCantParseInterval(t *testing.T) { */ func TestRenderBatchedPostGeneric(t *testing.T) { th := Setup() + defer th.TearDown() + var post = &model.Post{} post.Message = "This is the message" var notification = &batchedNotification{} @@ -306,6 +314,8 @@ func TestRenderBatchedPostGeneric(t *testing.T) { */ func TestRenderBatchedPostFull(t *testing.T) { th := Setup() + defer th.TearDown() + var post = &model.Post{} post.Message = "This is the message" var notification = &batchedNotification{} diff --git a/app/file_test.go b/app/file_test.go index f3141fa18..d86272063 100644 --- a/app/file_test.go +++ b/app/file_test.go @@ -37,6 +37,7 @@ func TestGeneratePublicLinkHash(t *testing.T) { func TestDoUploadFile(t *testing.T) { th := Setup() + defer th.TearDown() teamId := model.NewId() channelId := model.NewId() diff --git a/app/import_test.go b/app/import_test.go index 86485900d..ccd49608e 100644 --- a/app/import_test.go +++ b/app/import_test.go @@ -966,6 +966,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { func TestImportImportTeam(t *testing.T) { th := Setup() + defer th.TearDown() // Check how many teams are in the database. var teamsCount int64 @@ -1074,6 +1075,7 @@ func TestImportImportTeam(t *testing.T) { func TestImportImportChannel(t *testing.T) { th := Setup() + defer th.TearDown() // Import a Team. teamName := model.NewId() @@ -1233,6 +1235,7 @@ func TestImportImportChannel(t *testing.T) { func TestImportImportUser(t *testing.T) { th := Setup() + defer th.TearDown() // Check how many users are in the database. var userCount int64 @@ -1910,6 +1913,7 @@ func AssertAllPostsCount(t *testing.T, a *App, initialCount int64, change int64, func TestImportImportPost(t *testing.T) { th := Setup() + defer th.TearDown() // Create a Team. teamName := model.NewId() @@ -2190,6 +2194,7 @@ func TestImportImportPost(t *testing.T) { func TestImportImportDirectChannel(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() // Check how many channels are in the database. var directChannelCount int64 @@ -2400,6 +2405,7 @@ func AssertChannelCount(t *testing.T, a *App, channelType string, expectedCount func TestImportImportDirectPost(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() // Create the DIRECT channel. channelData := DirectChannelImportData{ @@ -2798,6 +2804,7 @@ func TestImportImportDirectPost(t *testing.T) { func TestImportImportLine(t *testing.T) { th := Setup() + defer th.TearDown() // Try import line with an invalid type. line := LineImportData{ @@ -2847,6 +2854,7 @@ func TestImportImportLine(t *testing.T) { func TestImportBulkImport(t *testing.T) { th := Setup() + defer th.TearDown() teamName := model.NewId() channelName := model.NewId() @@ -2888,8 +2896,6 @@ func TestImportBulkImport(t *testing.T) { } func TestImportProcessImportDataFileVersionLine(t *testing.T) { - Setup() - data := LineImportData{ Type: "version", Version: ptrInt(1), diff --git a/app/job_test.go b/app/job_test.go index 18186cc47..35c4a6bc8 100644 --- a/app/job_test.go +++ b/app/job_test.go @@ -12,6 +12,7 @@ import ( func TestGetJob(t *testing.T) { th := Setup() + defer th.TearDown() status := &model.Job{ Id: model.NewId(), @@ -32,6 +33,7 @@ func TestGetJob(t *testing.T) { func TestGetJobByType(t *testing.T) { th := Setup() + defer th.TearDown() jobType := model.NewId() diff --git a/app/license_test.go b/app/license_test.go index 376972b2b..632034b11 100644 --- a/app/license_test.go +++ b/app/license_test.go @@ -12,6 +12,7 @@ import ( func TestLoadLicense(t *testing.T) { th := Setup() + defer th.TearDown() th.App.LoadLicense() if utils.IsLicensed() { @@ -21,6 +22,7 @@ func TestLoadLicense(t *testing.T) { func TestSaveLicense(t *testing.T) { th := Setup() + defer th.TearDown() b1 := []byte("junk") @@ -31,6 +33,7 @@ func TestSaveLicense(t *testing.T) { func TestRemoveLicense(t *testing.T) { th := Setup() + defer th.TearDown() if err := th.App.RemoveLicense(); err != nil { t.Fatal("should have removed license") diff --git a/app/notification_test.go b/app/notification_test.go index 28c931d1f..f5224e84e 100644 --- a/app/notification_test.go +++ b/app/notification_test.go @@ -13,6 +13,7 @@ import ( func TestSendNotifications(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() th.App.AddUserToChannel(th.BasicUser2, th.BasicChannel) @@ -408,7 +409,9 @@ func TestRemoveCodeFromMessage(t *testing.T) { } func TestGetMentionKeywords(t *testing.T) { - Setup() + th := Setup() + defer th.TearDown() + // user with username or custom mentions enabled user1 := &model.User{ Id: model.NewId(), @@ -856,7 +859,9 @@ func TestDoesStatusAllowPushNotification(t *testing.T) { } func TestGetDirectMessageNotificationEmailSubject(t *testing.T) { - Setup() + th := Setup() + defer th.TearDown() + expectedPrefix := "[http://localhost:8065] New Direct Message from sender on" post := &model.Post{ CreateAt: 1501804801000, @@ -869,7 +874,9 @@ func TestGetDirectMessageNotificationEmailSubject(t *testing.T) { } func TestGetNotificationEmailSubject(t *testing.T) { - Setup() + th := Setup() + defer th.TearDown() + expectedPrefix := "[http://localhost:8065] Notification in team on" post := &model.Post{ CreateAt: 1501804801000, @@ -883,6 +890,8 @@ func TestGetNotificationEmailSubject(t *testing.T) { func TestGetNotificationEmailBodyFullNotificationPublicChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -917,6 +926,8 @@ func TestGetNotificationEmailBodyFullNotificationPublicChannel(t *testing.T) { func TestGetNotificationEmailBodyFullNotificationGroupChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -951,6 +962,8 @@ func TestGetNotificationEmailBodyFullNotificationGroupChannel(t *testing.T) { func TestGetNotificationEmailBodyFullNotificationPrivateChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -985,6 +998,8 @@ func TestGetNotificationEmailBodyFullNotificationPrivateChannel(t *testing.T) { func TestGetNotificationEmailBodyFullNotificationDirectChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -1017,6 +1032,8 @@ func TestGetNotificationEmailBodyFullNotificationDirectChannel(t *testing.T) { // from here func TestGetNotificationEmailBodyGenericNotificationPublicChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -1048,6 +1065,8 @@ func TestGetNotificationEmailBodyGenericNotificationPublicChannel(t *testing.T) func TestGetNotificationEmailBodyGenericNotificationGroupChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -1079,6 +1098,8 @@ func TestGetNotificationEmailBodyGenericNotificationGroupChannel(t *testing.T) { func TestGetNotificationEmailBodyGenericNotificationPrivateChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", @@ -1110,6 +1131,8 @@ func TestGetNotificationEmailBodyGenericNotificationPrivateChannel(t *testing.T) func TestGetNotificationEmailBodyGenericNotificationDirectChannel(t *testing.T) { th := Setup() + defer th.TearDown() + recipient := &model.User{} post := &model.Post{ Message: "This is the message", diff --git a/app/oauth_test.go b/app/oauth_test.go index d756c0abe..81f331657 100644 --- a/app/oauth_test.go +++ b/app/oauth_test.go @@ -12,6 +12,8 @@ import ( func TestOAuthRevokeAccessToken(t *testing.T) { th := Setup() + defer th.TearDown() + if err := th.App.RevokeAccessToken(model.NewRandomString(16)); err == nil { t.Fatal("Should have failed bad token") } @@ -46,6 +48,7 @@ func TestOAuthRevokeAccessToken(t *testing.T) { func TestOAuthDeleteApp(t *testing.T) { th := Setup() + defer th.TearDown() oldSetting := utils.Cfg.ServiceSettings.EnableOAuthServiceProvider defer func() { diff --git a/app/post_test.go b/app/post_test.go index 92eb8857e..5fa3d50d6 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -20,6 +20,7 @@ import ( func TestUpdatePostEditAt(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() post := &model.Post{} *post = *th.BasicPost @@ -47,6 +48,7 @@ func TestPostReplyToPostWhereRootPosterLeftChannel(t *testing.T) { // This test ensures that when replying to a root post made by a user who has since left the channel, the reply // post completes successfully. This is a regression test for PLT-6523. th := Setup().InitBasic() + defer th.TearDown() channel := th.BasicChannel userInChannel := th.BasicUser2 @@ -78,6 +80,7 @@ func TestPostReplyToPostWhereRootPosterLeftChannel(t *testing.T) { func TestPostAction(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() allowedInternalConnections := *utils.Cfg.ServiceSettings.AllowedUntrustedInternalConnections defer func() { diff --git a/app/session_test.go b/app/session_test.go index c001655db..5915b932d 100644 --- a/app/session_test.go +++ b/app/session_test.go @@ -15,6 +15,7 @@ import ( func TestCache(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() session := &model.Session{ Id: model.NewId(), @@ -39,6 +40,7 @@ func TestCache(t *testing.T) { func TestGetSessionIdleTimeoutInMinutes(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() session := &model.Session{ UserId: model.NewId(), diff --git a/app/team_test.go b/app/team_test.go index b074ed14f..7992dd0c3 100644 --- a/app/team_test.go +++ b/app/team_test.go @@ -12,6 +12,7 @@ import ( func TestCreateTeam(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() id := model.NewId() team := &model.Team{ @@ -33,6 +34,7 @@ func TestCreateTeam(t *testing.T) { func TestCreateTeamWithUser(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() id := model.NewId() team := &model.Team{ @@ -76,6 +78,7 @@ func TestCreateTeamWithUser(t *testing.T) { func TestUpdateTeam(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() th.BasicTeam.DisplayName = "Testing 123" @@ -91,6 +94,7 @@ func TestUpdateTeam(t *testing.T) { func TestAddUserToTeam(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() user := model.User{Email: strings.ToLower(model.NewId()) + "success+test@example.com", Nickname: "Darth Vader", Username: "vader" + model.NewId(), Password: "passwd1", AuthService: ""} ruser, _ := th.App.CreateUser(&user) @@ -103,6 +107,7 @@ func TestAddUserToTeam(t *testing.T) { func TestAddUserToTeamByTeamId(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() user := model.User{Email: strings.ToLower(model.NewId()) + "success+test@example.com", Nickname: "Darth Vader", Username: "vader" + model.NewId(), Password: "passwd1", AuthService: ""} ruser, _ := th.App.CreateUser(&user) @@ -115,6 +120,7 @@ func TestAddUserToTeamByTeamId(t *testing.T) { func TestPermanentDeleteTeam(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() team, err := th.App.CreateTeam(&model.Team{ DisplayName: "deletion-test", diff --git a/app/user_test.go b/app/user_test.go index 63d2aafd5..39be7eafa 100644 --- a/app/user_test.go +++ b/app/user_test.go @@ -21,6 +21,8 @@ import ( func TestIsUsernameTaken(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() + user := th.BasicUser taken := th.App.IsUsernameTaken(user.Username) @@ -40,6 +42,8 @@ func TestIsUsernameTaken(t *testing.T) { func TestCheckUserDomain(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() + user := th.BasicUser cases := []struct { @@ -66,6 +70,8 @@ func TestCheckUserDomain(t *testing.T) { func TestCreateOAuthUser(t *testing.T) { th := Setup().InitBasic() + defer th.TearDown() + r := rand.New(rand.NewSource(time.Now().UnixNano())) glUser := oauthgitlab.GitLabUser{Id: int64(r.Intn(1000)) + 1, Username: "o" + model.NewId(), Email: model.NewId() + "@simulator.amazonses.com", Name: "Joram Wilander"} @@ -116,6 +122,8 @@ func TestCreateProfileImage(t *testing.T) { func TestUpdateOAuthUserAttrs(t *testing.T) { th := Setup() + defer th.TearDown() + id := model.NewId() id2 := model.NewId() gitlabProvider := einterfaces.GetOauthProvider("gitlab") diff --git a/app/web_conn.go b/app/web_conn.go index 5f66d9a51..92b54723a 100644 --- a/app/web_conn.go +++ b/app/web_conn.go @@ -40,6 +40,8 @@ type WebConn struct { AllChannelMembers map[string]string LastAllChannelMembersTime int64 Sequence int64 + endWritePump chan struct{} + pumpFinished chan struct{} } func (a *App) NewWebConn(ws *websocket.Conn, session model.Session, t goi18n.TranslateFunc, locale string) *WebConn { @@ -51,12 +53,14 @@ func (a *App) NewWebConn(ws *websocket.Conn, session model.Session, t goi18n.Tra } wc := &WebConn{ - App: a, - Send: make(chan model.WebSocketMessage, SEND_QUEUE_SIZE), - WebSocket: ws, - UserId: session.UserId, - T: t, - Locale: locale, + App: a, + Send: make(chan model.WebSocketMessage, SEND_QUEUE_SIZE), + WebSocket: ws, + UserId: session.UserId, + T: t, + Locale: locale, + endWritePump: make(chan struct{}, 1), + pumpFinished: make(chan struct{}, 1), } wc.SetSession(&session) @@ -66,6 +70,12 @@ func (a *App) NewWebConn(ws *websocket.Conn, session model.Session, t goi18n.Tra return wc } +func (wc *WebConn) Close() { + wc.WebSocket.Close() + wc.endWritePump <- struct{}{} + <-wc.pumpFinished +} + func (c *WebConn) GetSessionExpiresAt() int64 { return atomic.LoadInt64(&c.sessionExpiresAt) } @@ -97,14 +107,15 @@ func (c *WebConn) SetSession(v *model.Session) { func (c *WebConn) Pump() { ch := make(chan struct{}, 1) go func() { - c.WritePump() + c.writePump() ch <- struct{}{} }() - c.ReadPump() + c.readPump() <-ch + c.pumpFinished <- struct{}{} } -func (c *WebConn) ReadPump() { +func (c *WebConn) readPump() { defer func() { c.App.HubUnregister(c) c.WebSocket.Close() @@ -138,7 +149,7 @@ func (c *WebConn) ReadPump() { } } -func (c *WebConn) WritePump() { +func (c *WebConn) writePump() { ticker := time.NewTicker(PING_PERIOD) authTicker := time.NewTicker(AUTH_TIMEOUT) @@ -221,7 +232,8 @@ func (c *WebConn) WritePump() { return } - + case <-c.endWritePump: + return case <-authTicker.C: if c.GetSessionToken() == "" { l4g.Debug(fmt.Sprintf("websocket.authTicker: did not authenticate ip=%v", c.WebSocket.RemoteAddr())) diff --git a/app/web_hub.go b/app/web_hub.go index 0a70cb6d1..1525dfbba 100644 --- a/app/web_hub.go +++ b/app/web_hub.go @@ -36,6 +36,7 @@ type Hub struct { unregister chan *WebConn broadcast chan *model.WebSocketEvent stop chan string + didStop chan struct{} invalidateUser chan string ExplicitStop bool goroutineId int @@ -44,11 +45,12 @@ type Hub struct { func (a *App) NewWebHub() *Hub { return &Hub{ app: a, - register: make(chan *WebConn), - unregister: make(chan *WebConn), + register: make(chan *WebConn, 1), + unregister: make(chan *WebConn, 1), connections: make([]*WebConn, 0, model.SESSION_CACHE_SIZE), broadcast: make(chan *model.WebSocketEvent, BROADCAST_QUEUE_SIZE), stop: make(chan string), + didStop: make(chan struct{}, 1), invalidateUser: make(chan string), ExplicitStop: false, } @@ -348,6 +350,7 @@ func getGoroutineId() int { func (h *Hub) Stop() { h.stop <- "all" + <-h.didStop } func (h *Hub) Start() { @@ -428,9 +431,10 @@ func (h *Hub) Start() { case <-h.stop: for _, webCon := range h.connections { - webCon.WebSocket.Close() + webCon.Close() } h.ExplicitStop = true + h.didStop <- struct{}{} return } diff --git a/store/store.go b/store/store.go index bc9aa8f1a..120778e84 100644 --- a/store/store.go +++ b/store/store.go @@ -18,6 +18,17 @@ type StoreResult struct { type StoreChannel chan StoreResult +func Do(f func(result *StoreResult)) StoreChannel { + storeChannel := make(StoreChannel, 1) + go func() { + result := StoreResult{} + f(&result) + storeChannel <- result + close(storeChannel) + }() + return storeChannel +} + func Must(sc StoreChannel) interface{} { r := <-sc if r.Err != nil { -- cgit v1.2.3-1-g7c22