From b87fae646a624507f5b2c1270cae1d3585f589ac Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 28 Nov 2017 15:02:56 -0600 Subject: PLT-5458: If someone posts a channel link to channel_A that you don't belong to, it doesn't render properly (#7833) * add channel link hints to post props * optimization * update regex, add unit test * fix rebase issue --- app/channel.go | 12 +++++++ app/post.go | 48 ++++++++++++++++++++++++++++ app/post_test.go | 38 ++++++++++++++++++++++ model/post.go | 17 ++++++++++ model/post_test.go | 7 +++++ store/sqlstore/channel_store.go | 59 +++++++++++++++++++++++++++++++++++ store/store.go | 1 + store/storetest/channel_store.go | 54 ++++++++++++++++++++++++++++++++ store/storetest/mocks/ChannelStore.go | 16 ++++++++++ 9 files changed, 252 insertions(+) diff --git a/app/channel.go b/app/channel.go index 50067d42d..16c5dd084 100644 --- a/app/channel.go +++ b/app/channel.go @@ -757,6 +757,18 @@ func (a *App) GetChannelByName(channelName, teamId string) (*model.Channel, *mod } } +func (a *App) GetChannelsByNames(channelNames []string, teamId string) ([]*model.Channel, *model.AppError) { + if result := <-a.Srv.Store.Channel().GetByNames(teamId, channelNames, true); result.Err != nil && result.Err.Id == "store.sql_channel.get_by_name.missing.app_error" { + result.Err.StatusCode = http.StatusNotFound + return nil, result.Err + } else if result.Err != nil { + result.Err.StatusCode = http.StatusBadRequest + return nil, result.Err + } else { + return result.Data.([]*model.Channel), nil + } +} + func (a *App) GetChannelByNameForTeamName(channelName, teamName string) (*model.Channel, *model.AppError) { var team *model.Team diff --git a/app/post.go b/app/post.go index 00944ee3b..0bd3b654f 100644 --- a/app/post.go +++ b/app/post.go @@ -152,6 +152,10 @@ func (a *App) CreatePost(post *model.Post, channel *model.Channel, triggerWebhoo post.Hashtags, _ = model.ParseHashtags(post.Message) + if err := a.FillInPostProps(post, channel); err != nil { + return nil, err + } + var rpost *model.Post if result := <-a.Srv.Store.Post().Save(post); result.Err != nil { return nil, result.Err @@ -192,6 +196,46 @@ func (a *App) CreatePost(post *model.Post, channel *model.Channel, triggerWebhoo return rpost, nil } +// FillInPostProps should be invoked before saving posts to fill in properties such as +// channel_mentions. +// +// If channel is nil, FillInPostProps will look up the channel corresponding to the post. +func (a *App) FillInPostProps(post *model.Post, channel *model.Channel) *model.AppError { + channelMentions := post.ChannelMentions() + channelMentionsProp := make(map[string]interface{}) + + if len(channelMentions) > 0 { + if channel == nil { + result := <-a.Srv.Store.Channel().GetForPost(post.Id) + if result.Err == nil { + return model.NewAppError("FillInPostProps", "api.context.invalid_param.app_error", map[string]interface{}{"Name": "post.channel_id"}, result.Err.Error(), http.StatusBadRequest) + } + channel = result.Data.(*model.Channel) + } + + mentionedChannels, err := a.GetChannelsByNames(channelMentions, channel.TeamId) + if err != nil { + return err + } + + for _, mentioned := range mentionedChannels { + if mentioned.Type == model.CHANNEL_OPEN { + channelMentionsProp[mentioned.Name] = map[string]interface{}{ + "display_name": mentioned.DisplayName, + } + } + } + } + + if len(channelMentionsProp) > 0 { + post.AddProp("channel_mentions", channelMentionsProp) + } else if post.Props != nil { + delete(post.Props, "channel_mentions") + } + + return nil +} + func (a *App) handlePostEvents(post *model.Post, user *model.User, channel *model.Channel, triggerWebhooks bool, parentPostList *model.PostList) *model.AppError { var tchan store.StoreChannel if len(channel.TeamId) > 0 { @@ -329,6 +373,10 @@ func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model newPost.Props = post.Props } + if err := a.FillInPostProps(post, nil); err != nil { + return nil, err + } + if result := <-a.Srv.Store.Post().Update(newPost, oldPost); result.Err != nil { return nil, result.Err } else { diff --git a/app/post_test.go b/app/post_test.go index e2e9a7261..3b7e8d039 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -138,3 +138,41 @@ func TestPostAction(t *testing.T) { err = th.App.DoPostAction(post.Id, attachments[0].Actions[0].Id, th.BasicUser.Id) require.Nil(t, err) } + +func TestPostChannelMentions(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + + channel := th.BasicChannel + user := th.BasicUser + + channelToMention, err := th.App.CreateChannel(&model.Channel{ + DisplayName: "Mention Test", + Name: "mention-test", + Type: model.CHANNEL_OPEN, + TeamId: th.BasicTeam.Id, + }, false) + if err != nil { + t.Fatal(err.Error()) + } + defer th.App.PermanentDeleteChannel(channelToMention) + + _, err = th.App.AddUserToChannel(user, channel) + require.Nil(t, err) + + post := &model.Post{ + Message: fmt.Sprintf("hello, ~%v!", channelToMention.Name), + ChannelId: channel.Id, + PendingPostId: model.NewId() + ":" + fmt.Sprint(model.GetMillis()), + UserId: user.Id, + CreateAt: 0, + } + + result, err := th.App.CreatePostAsUser(post) + require.Nil(t, err) + assert.Equal(t, map[string]interface{}{ + "mention-test": map[string]interface{}{ + "display_name": "Mention Test", + }, + }, result.Props["channel_mentions"]) +} diff --git a/model/post.go b/model/post.go index 8e4689eb7..b7b38e7ad 100644 --- a/model/post.go +++ b/model/post.go @@ -7,6 +7,7 @@ import ( "encoding/json" "io" "net/http" + "regexp" "strings" "unicode/utf8" ) @@ -294,6 +295,22 @@ func PostPatchFromJson(data io.Reader) *PostPatch { return &post } +var channelMentionRegexp = regexp.MustCompile(`\B~[a-zA-Z0-9\-_]+`) + +func (o *Post) ChannelMentions() (names []string) { + if strings.Contains(o.Message, "~") { + alreadyMentioned := make(map[string]bool) + for _, match := range channelMentionRegexp.FindAllString(o.Message, -1) { + name := match[1:] + if !alreadyMentioned[name] { + names = append(names, name) + alreadyMentioned[name] = true + } + } + } + return +} + func (r *PostActionIntegrationRequest) ToJson() string { b, err := json.Marshal(r) if err != nil { diff --git a/model/post_test.go b/model/post_test.go index 846c8c775..6a908887d 100644 --- a/model/post_test.go +++ b/model/post_test.go @@ -6,6 +6,8 @@ package model import ( "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestPostJson(t *testing.T) { @@ -124,6 +126,11 @@ func TestPostIsSystemMessage(t *testing.T) { } } +func TestPostChannelMentions(t *testing.T) { + post := Post{Message: "~a ~b ~b ~c/~d."} + assert.Equal(t, []string{"a", "b", "c", "d"}, post.ChannelMentions()) +} + func TestPostSanitizeProps(t *testing.T) { post1 := &Post{ Message: "test", diff --git a/store/sqlstore/channel_store.go b/store/sqlstore/channel_store.go index 7328e2017..9869b3720 100644 --- a/store/sqlstore/channel_store.go +++ b/store/sqlstore/channel_store.go @@ -586,6 +586,65 @@ func (s SqlChannelStore) GetByName(teamId string, name string, allowFromCache bo return s.getByName(teamId, name, false, allowFromCache) } +func (s SqlChannelStore) GetByNames(teamId string, names []string, allowFromCache bool) store.StoreChannel { + return store.Do(func(result *store.StoreResult) { + var channels []*model.Channel + + if allowFromCache { + var misses []string + visited := make(map[string]struct{}) + for _, name := range names { + if _, ok := visited[name]; ok { + continue + } + visited[name] = struct{}{} + if cacheItem, ok := channelByNameCache.Get(teamId + name); ok { + if s.metrics != nil { + s.metrics.IncrementMemCacheHitCounter("Channel By Name") + } + channels = append(channels, cacheItem.(*model.Channel)) + } else { + if s.metrics != nil { + s.metrics.IncrementMemCacheMissCounter("Channel By Name") + } + misses = append(misses, name) + } + } + names = misses + } + + if len(names) > 0 { + props := map[string]interface{}{} + var namePlaceholders []string + for _, name := range names { + key := fmt.Sprintf("Name%v", len(namePlaceholders)) + props[key] = name + namePlaceholders = append(namePlaceholders, ":"+key) + } + + var query string + if teamId == "" { + query = `SELECT * FROM Channels WHERE Name IN (` + strings.Join(namePlaceholders, ", ") + `) AND DeleteAt = 0` + } else { + props["TeamId"] = teamId + query = `SELECT * FROM Channels WHERE Name IN (` + strings.Join(namePlaceholders, ", ") + `) AND TeamId = :TeamId AND DeleteAt = 0` + } + + var dbChannels []*model.Channel + if _, err := s.GetReplica().Select(&dbChannels, query, props); err != nil && err != sql.ErrNoRows { + result.Err = model.NewAppError("SqlChannelStore.GetByName", "store.sql_channel.get_by_name.existing.app_error", nil, "teamId="+teamId+", "+err.Error(), http.StatusInternalServerError) + return + } + for _, channel := range dbChannels { + channelByNameCache.AddWithExpiresInSecs(teamId+channel.Name, channel, CHANNEL_CACHE_SEC) + channels = append(channels, channel) + } + } + + result.Data = channels + }) +} + func (s SqlChannelStore) GetByNameIncludeDeleted(teamId string, name string, allowFromCache bool) store.StoreChannel { return s.getByName(teamId, name, true, allowFromCache) } diff --git a/store/store.go b/store/store.go index 7997000ec..3c950495d 100644 --- a/store/store.go +++ b/store/store.go @@ -119,6 +119,7 @@ type ChannelStore interface { PermanentDeleteByTeam(teamId string) StoreChannel PermanentDelete(channelId string) StoreChannel GetByName(team_id string, name string, allowFromCache bool) StoreChannel + GetByNames(team_id string, names []string, allowFromCache bool) StoreChannel GetByNameIncludeDeleted(team_id string, name string, allowFromCache bool) StoreChannel GetDeletedByName(team_id string, name string) StoreChannel GetDeleted(team_id string, offset int, limit int) StoreChannel diff --git a/store/storetest/channel_store.go b/store/storetest/channel_store.go index 1189fd976..853de67d8 100644 --- a/store/storetest/channel_store.go +++ b/store/storetest/channel_store.go @@ -4,10 +4,12 @@ package storetest import ( + "sort" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/store" @@ -24,6 +26,7 @@ func TestChannelStore(t *testing.T, ss store.Store) { t.Run("Restore", func(t *testing.T) { testChannelStoreRestore(t, ss) }) t.Run("Delete", func(t *testing.T) { testChannelStoreDelete(t, ss) }) t.Run("GetByName", func(t *testing.T) { testChannelStoreGetByName(t, ss) }) + t.Run("GetByNames", func(t *testing.T) { testChannelStoreGetByNames(t, ss) }) t.Run("GetDeletedByName", func(t *testing.T) { testChannelStoreGetDeletedByName(t, ss) }) t.Run("GetDeleted", func(t *testing.T) { testChannelStoreGetDeleted(t, ss) }) t.Run("ChannelMemberStore", func(t *testing.T) { testChannelMemberStore(t, ss) }) @@ -550,6 +553,57 @@ func testChannelStoreGetByName(t *testing.T, ss store.Store) { } } +func testChannelStoreGetByNames(t *testing.T, ss store.Store) { + o1 := model.Channel{ + TeamId: model.NewId(), + DisplayName: "Name", + Name: "zz" + model.NewId() + "b", + Type: model.CHANNEL_OPEN, + } + store.Must(ss.Channel().Save(&o1, -1)) + + o2 := model.Channel{ + TeamId: o1.TeamId, + DisplayName: "Name", + Name: "zz" + model.NewId() + "b", + Type: model.CHANNEL_OPEN, + } + store.Must(ss.Channel().Save(&o2, -1)) + + for index, tc := range []struct { + TeamId string + Names []string + ExpectedIds []string + }{ + {o1.TeamId, []string{o1.Name}, []string{o1.Id}}, + {o1.TeamId, []string{o1.Name, o2.Name}, []string{o1.Id, o2.Id}}, + {o1.TeamId, nil, nil}, + {o1.TeamId, []string{"foo"}, nil}, + {o1.TeamId, []string{o1.Name, "foo", o2.Name, o2.Name}, []string{o1.Id, o2.Id}}, + {"", []string{o1.Name, "foo", o2.Name, o2.Name}, []string{o1.Id, o2.Id}}, + {"asd", []string{o1.Name, "foo", o2.Name, o2.Name}, nil}, + } { + r := <-ss.Channel().GetByNames(tc.TeamId, tc.Names, true) + require.Nil(t, r.Err) + channels := r.Data.([]*model.Channel) + var ids []string + for _, channel := range channels { + ids = append(ids, channel.Id) + } + sort.Strings(ids) + sort.Strings(tc.ExpectedIds) + assert.Equal(t, tc.ExpectedIds, ids, "tc %v", index) + } + + store.Must(ss.Channel().Delete(o1.Id, model.GetMillis())) + store.Must(ss.Channel().Delete(o2.Id, model.GetMillis())) + + r := <-ss.Channel().GetByNames(o1.TeamId, []string{o1.Name}, false) + require.Nil(t, r.Err) + channels := r.Data.([]*model.Channel) + assert.Len(t, channels, 0) +} + func testChannelStoreGetDeletedByName(t *testing.T, ss store.Store) { o1 := model.Channel{} o1.TeamId = model.NewId() diff --git a/store/storetest/mocks/ChannelStore.go b/store/storetest/mocks/ChannelStore.go index 2ea5256d7..c878d602b 100644 --- a/store/storetest/mocks/ChannelStore.go +++ b/store/storetest/mocks/ChannelStore.go @@ -189,6 +189,22 @@ func (_m *ChannelStore) GetByNameIncludeDeleted(team_id string, name string, all return r0 } +// GetByNames provides a mock function with given fields: team_id, names, allowFromCache +func (_m *ChannelStore) GetByNames(team_id string, names []string, allowFromCache bool) store.StoreChannel { + ret := _m.Called(team_id, names, allowFromCache) + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func(string, []string, bool) store.StoreChannel); ok { + r0 = rf(team_id, names, allowFromCache) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} + // GetChannelCounts provides a mock function with given fields: teamId, userId func (_m *ChannelStore) GetChannelCounts(teamId string, userId string) store.StoreChannel { ret := _m.Called(teamId, userId) -- cgit v1.2.3-1-g7c22