From 9071553165cfc9f073f57aab96a3e6a7c771c8f3 Mon Sep 17 00:00:00 2001 From: Corey Hulen Date: Mon, 24 Oct 2016 17:04:11 -0700 Subject: PLT-4359 fixing push notification for more than 1 device (#4318) * PLT-4359 fixing push notification for more than 1 device * Addressing feedback --- api/post.go | 40 ++++++++++++++++++---------------------- store/sql_session_store.go | 22 ++++++++++++++++++++++ store/sql_session_store_test.go | 28 ++++++++++++++++++++++++++++ store/store.go | 1 + 4 files changed, 69 insertions(+), 22 deletions(-) diff --git a/api/post.go b/api/post.go index 293ee0af0..35159693a 100644 --- a/api/post.go +++ b/api/post.go @@ -901,9 +901,9 @@ func getMessageForNotification(post *model.Post, translateFunc i18n.TranslateFun } func sendPushNotification(post *model.Post, user *model.User, channel *model.Channel, senderName string, wasMentioned bool) { - session := getMobileAppSession(user.Id) + sessions := getMobileAppSessions(user.Id) - if session == nil { + if sessions == nil { return } @@ -928,8 +928,6 @@ func sendPushNotification(post *model.Post, user *model.User, channel *model.Cha msg.ChannelId = channel.Id msg.ChannelName = channel.Name - msg.SetDeviceIdAndPlatform(session.DeviceId) - if *utils.Cfg.EmailSettings.PushNotificationContents == model.FULL_NOTIFICATION { if channel.Type == model.CHANNEL_DIRECT { msg.Category = model.CATEGORY_DM @@ -949,12 +947,17 @@ func sendPushNotification(post *model.Post, user *model.User, channel *model.Cha } l4g.Debug(utils.T("api.post.send_notifications_and_forget.push_notification.debug"), msg.DeviceId, msg.Message) - sendToPushProxy(msg) + + for _, session := range sessions { + tmpMessage := *model.PushNotificationFromJson(strings.NewReader(msg.ToJson())) + tmpMessage.SetDeviceIdAndPlatform(session.DeviceId) + sendToPushProxy(tmpMessage) + } } func clearPushNotification(userId string, channelId string) { - session := getMobileAppSession(userId) - if session == nil { + sessions := getMobileAppSessions(userId) + if sessions == nil { return } @@ -969,10 +972,12 @@ func clearPushNotification(userId string, channelId string) { msg.Badge = int(badge.Data.(int64)) } - msg.SetDeviceIdAndPlatform(session.DeviceId) - l4g.Debug(utils.T("api.post.send_notifications_and_forget.clear_push_notification.debug"), msg.DeviceId, msg.ChannelId) - sendToPushProxy(msg) + for _, session := range sessions { + tmpMessage := *model.PushNotificationFromJson(strings.NewReader(msg.ToJson())) + tmpMessage.SetDeviceIdAndPlatform(session.DeviceId) + sendToPushProxy(tmpMessage) + } } func sendToPushProxy(msg model.PushNotification) { @@ -992,22 +997,13 @@ func sendToPushProxy(msg model.PushNotification) { } } -func getMobileAppSession(userId string) *model.Session { - var sessions []*model.Session - if result := <-Srv.Store.Session().GetSessions(userId); result.Err != nil { +func getMobileAppSessions(userId string) []*model.Session { + if result := <-Srv.Store.Session().GetSessionsWithActiveDeviceIds(userId); result.Err != nil { l4g.Error(utils.T("api.post.send_notifications_and_forget.sessions.error"), userId, result.Err) return nil } else { - sessions = result.Data.([]*model.Session) - } - - for _, session := range sessions { - if session.IsMobileApp() { - return session - } + return result.Data.([]*model.Session) } - - return nil } func sendOutOfChannelMentions(c *Context, post *model.Post, profiles map[string]*model.User) { diff --git a/store/sql_session_store.go b/store/sql_session_store.go index 4e1bea3cf..5892dab01 100644 --- a/store/sql_session_store.go +++ b/store/sql_session_store.go @@ -168,6 +168,28 @@ func (me SqlSessionStore) GetSessions(userId string) StoreChannel { return storeChannel } +func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) StoreChannel { + storeChannel := make(StoreChannel, 1) + + go func() { + + result := StoreResult{} + var sessions []*model.Session + + if _, err := me.GetReplica().Select(&sessions, "SELECT * FROM Sessions WHERE UserId = :UserId AND ExpiresAt != 0 AND :ExpiresAt <= ExpiresAt AND DeviceId != ''", map[string]interface{}{"UserId": userId, "ExpiresAt": model.GetMillis()}); err != nil { + result.Err = model.NewLocAppError("SqlSessionStore.GetActiveSessionsWithDeviceIds", "store.sql_session.get_sessions.app_error", nil, err.Error()) + } else { + + result.Data = sessions + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + func (me SqlSessionStore) Remove(sessionIdOrToken string) StoreChannel { storeChannel := make(StoreChannel, 1) diff --git a/store/sql_session_store_test.go b/store/sql_session_store_test.go index e64a350ba..24526a4a9 100644 --- a/store/sql_session_store_test.go +++ b/store/sql_session_store_test.go @@ -50,7 +50,35 @@ func TestSessionGet(t *testing.T) { t.Fatal("should match len") } } +} + +func TestSessionGetWithDeviceId(t *testing.T) { + Setup() + + s1 := model.Session{} + s1.UserId = model.NewId() + s1.ExpiresAt = model.GetMillis() + 10000 + Must(store.Session().Save(&s1)) + + s2 := model.Session{} + s2.UserId = s1.UserId + s2.DeviceId = model.NewId() + s2.ExpiresAt = model.GetMillis() + 10000 + Must(store.Session().Save(&s2)) + s3 := model.Session{} + s3.UserId = s1.UserId + s3.ExpiresAt = 1 + s3.DeviceId = model.NewId() + Must(store.Session().Save(&s3)) + + if rs1 := (<-store.Session().GetSessionsWithActiveDeviceIds(s1.UserId)); rs1.Err != nil { + t.Fatal(rs1.Err) + } else { + if len(rs1.Data.([]*model.Session)) != 1 { + t.Fatal("should match len") + } + } } func TestSessionRemove(t *testing.T) { diff --git a/store/store.go b/store/store.go index 900709f16..51aada920 100644 --- a/store/store.go +++ b/store/store.go @@ -172,6 +172,7 @@ type SessionStore interface { Save(session *model.Session) StoreChannel Get(sessionIdOrToken string) StoreChannel GetSessions(userId string) StoreChannel + GetSessionsWithActiveDeviceIds(userId string) StoreChannel Remove(sessionIdOrToken string) StoreChannel RemoveAllSessions() StoreChannel PermanentDeleteSessionsByUser(teamId string) StoreChannel -- cgit v1.2.3-1-g7c22