diff options
-rw-r--r-- | api4/system.go | 2 | ||||
-rw-r--r-- | app/app.go | 2 | ||||
-rw-r--r-- | app/config.go | 8 | ||||
-rw-r--r-- | app/config_test.go | 7 | ||||
-rw-r--r-- | app/import.go | 29 | ||||
-rw-r--r-- | app/import_test.go | 69 | ||||
-rw-r--r-- | app/post.go | 11 | ||||
-rw-r--r-- | app/post_test.go | 59 | ||||
-rw-r--r-- | app/webhook.go | 8 | ||||
-rw-r--r-- | app/webhook_test.go | 16 | ||||
-rw-r--r-- | i18n/en.json | 12 | ||||
-rw-r--r-- | model/post.go | 8 | ||||
-rw-r--r-- | model/post_test.go | 31 | ||||
-rw-r--r-- | store/sqlstore/post_store.go | 205 | ||||
-rw-r--r-- | store/store.go | 1 | ||||
-rw-r--r-- | store/storetest/mocks/PostStore.go | 15 | ||||
-rw-r--r-- | store/storetest/post_store.go | 6 |
17 files changed, 349 insertions, 140 deletions
diff --git a/api4/system.go b/api4/system.go index 4ae8ee7b9..b34f2af6b 100644 --- a/api4/system.go +++ b/api4/system.go @@ -248,7 +248,7 @@ func getClientConfig(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte(model.MapToJson(c.App.ClientConfigWithNoAccounts()))) + w.Write([]byte(model.MapToJson(c.App.ClientConfigWithComputed()))) } func getClientLicense(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/app/app.go b/app/app.go index 6329a80d3..cd9fdaa66 100644 --- a/app/app.go +++ b/app/app.go @@ -139,7 +139,7 @@ func New(options ...Option) (outApp *App, outErr error) { message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_CONFIG_CHANGED, "", "", "", nil) - message.Add("config", app.ClientConfigWithNoAccounts()) + message.Add("config", app.ClientConfigWithComputed()) app.Go(func() { app.Publish(message) }) diff --git a/app/config.go b/app/config.go index ccd7236a0..761fe3ec9 100644 --- a/app/config.go +++ b/app/config.go @@ -273,15 +273,17 @@ func (a *App) GetSiteURL() string { return a.siteURL } -// ClientConfigWithNoAccounts gets the configuration in a format suitable for sending to the client. -func (a *App) ClientConfigWithNoAccounts() map[string]string { +// ClientConfigWithComputed gets the configuration in a format suitable for sending to the client. +func (a *App) ClientConfigWithComputed() map[string]string { respCfg := map[string]string{} for k, v := range a.ClientConfig() { respCfg[k] = v } - // NoAccounts is not actually part of the configuration, but is expected by the client. + // These properties are not configurable, but nevertheless represent configuration expected + // by the client. respCfg["NoAccounts"] = strconv.FormatBool(a.IsFirstUserAccount()) + respCfg["MaxPostSize"] = strconv.Itoa(a.MaxPostSize()) return respCfg } diff --git a/app/config_test.go b/app/config_test.go index 051fa8fd8..4fc7df5e2 100644 --- a/app/config_test.go +++ b/app/config_test.go @@ -64,12 +64,15 @@ func TestAsymmetricSigningKey(t *testing.T) { assert.NotEmpty(t, th.App.ClientConfig()["AsymmetricSigningPublicKey"]) } -func TestClientConfigWithNoAccounts(t *testing.T) { +func TestClientConfigWithComputed(t *testing.T) { th := Setup().InitBasic() defer th.TearDown() - config := th.App.ClientConfigWithNoAccounts() + config := th.App.ClientConfigWithComputed() if _, ok := config["NoAccounts"]; !ok { t.Fatal("expected NoAccounts in returned config") } + if _, ok := config["MaxPostSize"]; !ok { + t.Fatal("expected MaxPostSize in returned config") + } } diff --git a/app/import.go b/app/import.go index e2e3aa1b7..23a315be7 100644 --- a/app/import.go +++ b/app/import.go @@ -1086,7 +1086,7 @@ func (a *App) ImportReaction(data *ReactionImportData, post *model.Post, dryRun } func (a *App) ImportReply(data *ReplyImportData, post *model.Post, dryRun bool) *model.AppError { - if err := validateReplyImportData(data, post.CreateAt); err != nil { + if err := validateReplyImportData(data, post.CreateAt, a.MaxPostSize()); err != nil { return err } @@ -1136,7 +1136,7 @@ func (a *App) ImportReply(data *ReplyImportData, post *model.Post, dryRun bool) } func (a *App) ImportPost(data *PostImportData, dryRun bool) *model.AppError { - if err := validatePostImportData(data); err != nil { + if err := validatePostImportData(data, a.MaxPostSize()); err != nil { return err } @@ -1271,14 +1271,14 @@ func validateReactionImportData(data *ReactionImportData, parentCreateAt int64) return nil } -func validateReplyImportData(data *ReplyImportData, parentCreateAt int64) *model.AppError { +func validateReplyImportData(data *ReplyImportData, parentCreateAt int64, maxPostSize int) *model.AppError { if data.User == nil { return model.NewAppError("BulkImport", "app.import.validate_reply_import_data.user_missing.error", nil, "", http.StatusBadRequest) } if data.Message == nil { return model.NewAppError("BulkImport", "app.import.validate_reply_import_data.message_missing.error", nil, "", http.StatusBadRequest) - } else if utf8.RuneCountInString(*data.Message) > model.POST_MESSAGE_MAX_RUNES { + } else if utf8.RuneCountInString(*data.Message) > maxPostSize { return model.NewAppError("BulkImport", "app.import.validate_reply_import_data.message_length.error", nil, "", http.StatusBadRequest) } @@ -1293,7 +1293,7 @@ func validateReplyImportData(data *ReplyImportData, parentCreateAt int64) *model return nil } -func validatePostImportData(data *PostImportData) *model.AppError { +func validatePostImportData(data *PostImportData, maxPostSize int) *model.AppError { if data.Team == nil { return model.NewAppError("BulkImport", "app.import.validate_post_import_data.team_missing.error", nil, "", http.StatusBadRequest) } @@ -1308,7 +1308,7 @@ func validatePostImportData(data *PostImportData) *model.AppError { if data.Message == nil { return model.NewAppError("BulkImport", "app.import.validate_post_import_data.message_missing.error", nil, "", http.StatusBadRequest) - } else if utf8.RuneCountInString(*data.Message) > model.POST_MESSAGE_MAX_RUNES { + } else if utf8.RuneCountInString(*data.Message) > maxPostSize { return model.NewAppError("BulkImport", "app.import.validate_post_import_data.message_length.error", nil, "", http.StatusBadRequest) } @@ -1326,7 +1326,7 @@ func validatePostImportData(data *PostImportData) *model.AppError { if data.Replies != nil { for _, reply := range *data.Replies { - validateReplyImportData(&reply, *data.CreateAt) + validateReplyImportData(&reply, *data.CreateAt, maxPostSize) } } @@ -1446,7 +1446,7 @@ func validateDirectChannelImportData(data *DirectChannelImportData) *model.AppEr } func (a *App) ImportDirectPost(data *DirectPostImportData, dryRun bool) *model.AppError { - if err := validateDirectPostImportData(data); err != nil { + if err := validateDirectPostImportData(data, a.MaxPostSize()); err != nil { return err } @@ -1572,7 +1572,7 @@ func (a *App) ImportDirectPost(data *DirectPostImportData, dryRun bool) *model.A return nil } -func validateDirectPostImportData(data *DirectPostImportData) *model.AppError { +func validateDirectPostImportData(data *DirectPostImportData, maxPostSize int) *model.AppError { if data.ChannelMembers == nil { return model.NewAppError("BulkImport", "app.import.validate_direct_post_import_data.channel_members_required.error", nil, "", http.StatusBadRequest) } @@ -1591,7 +1591,7 @@ func validateDirectPostImportData(data *DirectPostImportData) *model.AppError { if data.Message == nil { return model.NewAppError("BulkImport", "app.import.validate_direct_post_import_data.message_missing.error", nil, "", http.StatusBadRequest) - } else if utf8.RuneCountInString(*data.Message) > model.POST_MESSAGE_MAX_RUNES { + } else if utf8.RuneCountInString(*data.Message) > maxPostSize { return model.NewAppError("BulkImport", "app.import.validate_direct_post_import_data.message_length.error", nil, "", http.StatusBadRequest) } @@ -1624,7 +1624,7 @@ func validateDirectPostImportData(data *DirectPostImportData) *model.AppError { if data.Replies != nil { for _, reply := range *data.Replies { - validateReplyImportData(&reply, *data.CreateAt) + validateReplyImportData(&reply, *data.CreateAt, maxPostSize) } } @@ -1640,12 +1640,13 @@ func validateDirectPostImportData(data *DirectPostImportData) *model.AppError { func (a *App) OldImportPost(post *model.Post) { // Workaround for empty messages, which may be the case if they are webhook posts. firstIteration := true + maxPostSize := a.MaxPostSize() for messageRuneCount := utf8.RuneCountInString(post.Message); messageRuneCount > 0 || firstIteration; messageRuneCount = utf8.RuneCountInString(post.Message) { firstIteration = false var remainder string - if messageRuneCount > model.POST_MESSAGE_MAX_RUNES { - remainder = string(([]rune(post.Message))[model.POST_MESSAGE_MAX_RUNES:]) - post.Message = truncateRunes(post.Message, model.POST_MESSAGE_MAX_RUNES) + if messageRuneCount > maxPostSize { + remainder = string(([]rune(post.Message))[maxPostSize:]) + post.Message = truncateRunes(post.Message, maxPostSize) } else { remainder = "" } diff --git a/app/import_test.go b/app/import_test.go index f294c8731..23213d81b 100644 --- a/app/import_test.go +++ b/app/import_test.go @@ -644,12 +644,13 @@ func TestImportValidateReactionImportData(t *testing.T) { func TestImportValidateReplyImportData(t *testing.T) { // Test with minimum required valid properties. parentCreateAt := model.GetMillis() - 100 + maxPostSize := 10000 data := ReplyImportData{ User: ptrStr("username"), Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateReplyImportData(&data, parentCreateAt); err != nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err != nil { t.Fatal("Validation failed but should have been valid.") } @@ -658,7 +659,7 @@ func TestImportValidateReplyImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateReplyImportData(&data, parentCreateAt); err == nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -666,7 +667,7 @@ func TestImportValidateReplyImportData(t *testing.T) { User: ptrStr("username"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateReplyImportData(&data, parentCreateAt); err == nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -674,17 +675,17 @@ func TestImportValidateReplyImportData(t *testing.T) { User: ptrStr("username"), Message: ptrStr("message"), } - if err := validateReplyImportData(&data, parentCreateAt); err == nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } // Test with invalid message. data = ReplyImportData{ User: ptrStr("username"), - Message: ptrStr(strings.Repeat("1234567890", 500)), + Message: ptrStr(strings.Repeat("0", maxPostSize+1)), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateReplyImportData(&data, parentCreateAt); err == nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err == nil { t.Fatal("Should have failed due to too long message.") } @@ -694,7 +695,7 @@ func TestImportValidateReplyImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(0), } - if err := validateReplyImportData(&data, parentCreateAt); err == nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err == nil { t.Fatal("Should have failed due to 0 create-at value.") } @@ -703,12 +704,13 @@ func TestImportValidateReplyImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(parentCreateAt - 100), } - if err := validateReplyImportData(&data, parentCreateAt); err == nil { + if err := validateReplyImportData(&data, parentCreateAt, maxPostSize); err == nil { t.Fatal("Should have failed due parent with newer create-at value.") } } func TestImportValidatePostImportData(t *testing.T) { + maxPostSize := 10000 // Test with minimum required valid properties. data := PostImportData{ @@ -718,7 +720,7 @@ func TestImportValidatePostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validatePostImportData(&data); err != nil { + if err := validatePostImportData(&data, maxPostSize); err != nil { t.Fatal("Validation failed but should have been valid.") } @@ -729,7 +731,7 @@ func TestImportValidatePostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -739,7 +741,7 @@ func TestImportValidatePostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -749,7 +751,7 @@ func TestImportValidatePostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -759,7 +761,7 @@ func TestImportValidatePostImportData(t *testing.T) { User: ptrStr("username"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -769,7 +771,7 @@ func TestImportValidatePostImportData(t *testing.T) { User: ptrStr("username"), Message: ptrStr("message"), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -778,10 +780,10 @@ func TestImportValidatePostImportData(t *testing.T) { Team: ptrStr("teamname"), Channel: ptrStr("channelname"), User: ptrStr("username"), - Message: ptrStr(strings.Repeat("1234567890", 500)), + Message: ptrStr(strings.Repeat("0", maxPostSize+1)), CreateAt: ptrInt64(model.GetMillis()), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to too long message.") } @@ -793,7 +795,7 @@ func TestImportValidatePostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(0), } - if err := validatePostImportData(&data); err == nil { + if err := validatePostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to 0 create-at value.") } @@ -817,7 +819,7 @@ func TestImportValidatePostImportData(t *testing.T) { Reactions: &reactions, Replies: &replies, } - if err := validatePostImportData(&data); err != nil { + if err := validatePostImportData(&data, maxPostSize); err != nil { t.Fatal("Should have succeeded.") } } @@ -933,6 +935,7 @@ func TestImportValidateDirectChannelImportData(t *testing.T) { } func TestImportValidateDirectPostImportData(t *testing.T) { + maxPostSize := 10000 // Test with minimum required valid properties. data := DirectPostImportData{ @@ -944,7 +947,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err != nil { + if err := validateDirectPostImportData(&data, maxPostSize); err != nil { t.Fatal("Validation failed but should have been valid.") } @@ -954,7 +957,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -966,7 +969,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -978,7 +981,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { User: ptrStr("username"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -990,7 +993,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { User: ptrStr("username"), Message: ptrStr("message"), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to missing required property.") } @@ -1001,7 +1004,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to unsuitable number of members.") } @@ -1013,7 +1016,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to unsuitable number of members.") } @@ -1034,7 +1037,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to unsuitable number of members.") } @@ -1049,7 +1052,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err != nil { + if err := validateDirectPostImportData(&data, maxPostSize); err != nil { t.Fatal("Validation failed but should have been valid.") } @@ -1060,10 +1063,10 @@ func TestImportValidateDirectPostImportData(t *testing.T) { model.NewId(), }, User: ptrStr("username"), - Message: ptrStr(strings.Repeat("1234567890", 500)), + Message: ptrStr(strings.Repeat("0", maxPostSize+1)), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to too long message.") } @@ -1077,7 +1080,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(0), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Should have failed due to 0 create-at value.") } @@ -1097,7 +1100,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err == nil { + if err := validateDirectPostImportData(&data, maxPostSize); err == nil { t.Fatal("Validation should have failed due to non-member flagged.") } @@ -1115,7 +1118,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Message: ptrStr("message"), CreateAt: ptrInt64(model.GetMillis()), } - if err := validateDirectPostImportData(&data); err != nil { + if err := validateDirectPostImportData(&data, maxPostSize); err != nil { t.Fatal(err) } @@ -1146,7 +1149,7 @@ func TestImportValidateDirectPostImportData(t *testing.T) { Replies: &replies, } - if err := validateDirectPostImportData(&data); err != nil { + if err := validateDirectPostImportData(&data, maxPostSize); err != nil { t.Fatal(err) } } diff --git a/app/post.go b/app/post.go index 5067777ab..d9445155b 100644 --- a/app/post.go +++ b/app/post.go @@ -965,3 +965,14 @@ func (a *App) ImageProxyRemover() (f func(string) string) { return url } } + +func (a *App) MaxPostSize() int { + maxPostSize := model.POST_MESSAGE_MAX_RUNES_V1 + if result := <-a.Srv.Store.Post().GetMaxPostSize(); result.Err != nil { + l4g.Error(result.Err) + } else { + maxPostSize = result.Data.(int) + } + + return maxPostSize +} diff --git a/app/post_test.go b/app/post_test.go index 2472e40c6..8455656d7 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" @@ -17,6 +18,8 @@ import ( "github.com/stretchr/testify/require" "github.com/mattermost/mattermost-server/model" + "github.com/mattermost/mattermost-server/store" + "github.com/mattermost/mattermost-server/store/storetest" ) func TestUpdatePostEditAt(t *testing.T) { @@ -346,3 +349,59 @@ func TestMakeOpenGraphURLsAbsolute(t *testing.T) { }) } } + +func TestMaxPostSize(t *testing.T) { + t.Parallel() + + testCases := []struct { + Description string + StoreMaxPostSize int + ExpectedMaxPostSize int + ExpectedError *model.AppError + }{ + { + "error fetching max post size", + 0, + model.POST_MESSAGE_MAX_RUNES_V1, + model.NewAppError("TestMaxPostSize", "this is an error", nil, "", http.StatusBadRequest), + }, + { + "4000 rune limit", + 4000, + 4000, + nil, + }, + { + "16383 rune limit", + 16383, + 16383, + nil, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.Description, func(t *testing.T) { + t.Parallel() + + mockStore := &storetest.Store{} + defer mockStore.AssertExpectations(t) + + mockStore.PostStore.On("GetMaxPostSize").Return( + storetest.NewStoreChannel(store.StoreResult{ + Data: testCase.StoreMaxPostSize, + Err: testCase.ExpectedError, + }), + ) + + app := App{ + Srv: &Server{ + Store: mockStore, + }, + config: atomic.Value{}, + } + + assert.Equal(t, testCase.ExpectedMaxPostSize, app.MaxPostSize()) + }) + } +} diff --git a/app/webhook.go b/app/webhook.go index abfc388b5..5c3e963ce 100644 --- a/app/webhook.go +++ b/app/webhook.go @@ -143,7 +143,7 @@ func (a *App) TriggerWebhook(payload *model.OutgoingWebhookPayload, hook *model. } } -func SplitWebhookPost(post *model.Post) ([]*model.Post, *model.AppError) { +func SplitWebhookPost(post *model.Post, maxPostSize int) ([]*model.Post, *model.AppError) { splits := make([]*model.Post, 0) remainingText := post.Message @@ -159,12 +159,12 @@ func SplitWebhookPost(post *model.Post) ([]*model.Post, *model.AppError) { return nil, model.NewAppError("SplitWebhookPost", "web.incoming_webhook.split_props_length.app_error", map[string]interface{}{"Max": model.POST_PROPS_MAX_USER_RUNES}, "", http.StatusBadRequest) } - for utf8.RuneCountInString(remainingText) > model.POST_MESSAGE_MAX_RUNES { + for utf8.RuneCountInString(remainingText) > maxPostSize { split := base x := 0 for index := range remainingText { x++ - if x > model.POST_MESSAGE_MAX_RUNES { + if x > maxPostSize { split.Message = remainingText[:index] remainingText = remainingText[index:] break @@ -266,7 +266,7 @@ func (a *App) CreateWebhookPost(userId string, channel *model.Channel, text, ove } } - splits, err := SplitWebhookPost(post) + splits, err := SplitWebhookPost(post, a.MaxPostSize()) if err != nil { return nil, err } diff --git a/app/webhook_test.go b/app/webhook_test.go index 4d2bc58fa..8931100ac 100644 --- a/app/webhook_test.go +++ b/app/webhook_test.go @@ -383,23 +383,25 @@ func TestSplitWebhookPost(t *testing.T) { Expected []*model.Post } + maxPostSize := 10000 + for name, tc := range map[string]TestCase{ "LongPost": { Post: &model.Post{ - Message: strings.Repeat("本", model.POST_MESSAGE_MAX_RUNES*3/2), + Message: strings.Repeat("本", maxPostSize*3/2), }, Expected: []*model.Post{ { - Message: strings.Repeat("本", model.POST_MESSAGE_MAX_RUNES), + Message: strings.Repeat("本", maxPostSize), }, { - Message: strings.Repeat("本", model.POST_MESSAGE_MAX_RUNES/2), + Message: strings.Repeat("本", maxPostSize/2), }, }, }, "LongPostAndMultipleAttachments": { Post: &model.Post{ - Message: strings.Repeat("本", model.POST_MESSAGE_MAX_RUNES*3/2), + Message: strings.Repeat("本", maxPostSize*3/2), Props: map[string]interface{}{ "attachments": []*model.SlackAttachment{ &model.SlackAttachment{ @@ -416,10 +418,10 @@ func TestSplitWebhookPost(t *testing.T) { }, Expected: []*model.Post{ { - Message: strings.Repeat("本", model.POST_MESSAGE_MAX_RUNES), + Message: strings.Repeat("本", maxPostSize), }, { - Message: strings.Repeat("本", model.POST_MESSAGE_MAX_RUNES/2), + Message: strings.Repeat("本", maxPostSize/2), Props: map[string]interface{}{ "attachments": []*model.SlackAttachment{ &model.SlackAttachment{ @@ -452,7 +454,7 @@ func TestSplitWebhookPost(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - splits, err := SplitWebhookPost(tc.Post) + splits, err := SplitWebhookPost(tc.Post, maxPostSize) if tc.Expected == nil { require.NotNil(t, err) } else { diff --git a/i18n/en.json b/i18n/en.json index d623ce864..0ce0a47e6 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -6447,6 +6447,18 @@ "translation": "We couldn't select the posts to delete for the user (too many), please re-run" }, { + "id": "store.sql_post.query_max_post_size.error", + "translation": "We couldn't determine the maximum supported post size" + }, + { + "id": "store.sql_post.query_max_post_size.unrecognized_driver", + "translation": "No implementation found to determine the maximum supported post size" + }, + { + "id": "store.sql_post.query_max_post_size.max_post_size_bytes", + "translation": "Post.Message supports at most %d bytes" + }, + { "id": "store.sql_post.save.app_error", "translation": "We couldn't save the Post" }, diff --git a/model/post.go b/model/post.go index 4a774b5d4..ae0627a03 100644 --- a/model/post.go +++ b/model/post.go @@ -40,7 +40,9 @@ const ( POST_FILEIDS_MAX_RUNES = 150 POST_FILENAMES_MAX_RUNES = 4000 POST_HASHTAGS_MAX_RUNES = 1000 - POST_MESSAGE_MAX_RUNES = 4000 + POST_MESSAGE_MAX_RUNES_V1 = 4000 + POST_MESSAGE_MAX_BYTES_V2 = 65535 // Maximum size of a TEXT column in MySQL + POST_MESSAGE_MAX_RUNES_V2 = POST_MESSAGE_MAX_BYTES_V2 / 4 // Assume a worst-case representation POST_PROPS_MAX_RUNES = 8000 POST_PROPS_MAX_USER_RUNES = POST_PROPS_MAX_RUNES - 400 // Leave some room for system / pre-save modifications POST_CUSTOM_TYPE_PREFIX = "custom_" @@ -141,7 +143,7 @@ func (o *Post) Etag() string { return Etag(o.Id, o.UpdateAt) } -func (o *Post) IsValid() *AppError { +func (o *Post) IsValid(maxPostSize int) *AppError { if len(o.Id) != 26 { return NewAppError("Post.IsValid", "model.post.is_valid.id.app_error", nil, "", http.StatusBadRequest) @@ -179,7 +181,7 @@ func (o *Post) IsValid() *AppError { return NewAppError("Post.IsValid", "model.post.is_valid.original_id.app_error", nil, "", http.StatusBadRequest) } - if utf8.RuneCountInString(o.Message) > POST_MESSAGE_MAX_RUNES { + if utf8.RuneCountInString(o.Message) > maxPostSize { return NewAppError("Post.IsValid", "model.post.is_valid.msg.app_error", nil, "id="+o.Id, http.StatusBadRequest) } diff --git a/model/post_test.go b/model/post_test.go index 5d5e7c9ec..af350d76e 100644 --- a/model/post_test.go +++ b/model/post_test.go @@ -23,72 +23,73 @@ func TestPostJson(t *testing.T) { func TestPostIsValid(t *testing.T) { o := Post{} + maxPostSize := 10000 - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.Id = NewId() - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.CreateAt = GetMillis() - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.UpdateAt = GetMillis() - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.UserId = NewId() - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.ChannelId = NewId() o.RootId = "123" - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.RootId = "" o.ParentId = "123" - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.ParentId = NewId() o.RootId = "" - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.ParentId = "" - o.Message = strings.Repeat("0", 4001) - if err := o.IsValid(); err == nil { + o.Message = strings.Repeat("0", maxPostSize+1) + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } - o.Message = strings.Repeat("0", 4000) - if err := o.IsValid(); err != nil { + o.Message = strings.Repeat("0", maxPostSize) + if err := o.IsValid(maxPostSize); err != nil { t.Fatal(err) } o.Message = "test" - if err := o.IsValid(); err != nil { + if err := o.IsValid(maxPostSize); err != nil { t.Fatal(err) } o.Type = "junk" - if err := o.IsValid(); err == nil { + if err := o.IsValid(maxPostSize); err == nil { t.Fatal("should be invalid") } o.Type = POST_CUSTOM_TYPE_PREFIX + "type" - if err := o.IsValid(); err != nil { + if err := o.IsValid(maxPostSize); err != nil { t.Fatal(err) } } diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index 3ff9a3e1b..182cf4891 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -4,13 +4,13 @@ package sqlstore import ( + "bytes" "fmt" "net/http" "regexp" "strconv" "strings" - - "bytes" + "sync" l4g "github.com/alecthomas/log4go" "github.com/mattermost/mattermost-server/einterfaces" @@ -21,7 +21,11 @@ import ( type SqlPostStore struct { SqlStore - metrics einterfaces.MetricsInterface + metrics einterfaces.MetricsInterface + lastPostTimeCache *utils.Cache + lastPostsCache *utils.Cache + maxPostSizeOnce sync.Once + maxPostSizeCached int } const ( @@ -32,12 +36,9 @@ const ( LAST_POSTS_CACHE_SEC = 900 // 15 minutes ) -var lastPostTimeCache = utils.NewLru(LAST_POST_TIME_CACHE_SIZE) -var lastPostsCache = utils.NewLru(LAST_POSTS_CACHE_SIZE) - -func (s SqlPostStore) ClearCaches() { - lastPostTimeCache.Purge() - lastPostsCache.Purge() +func (s *SqlPostStore) ClearCaches() { + s.lastPostTimeCache.Purge() + s.lastPostsCache.Purge() if s.metrics != nil { s.metrics.IncrementMemCacheInvalidationCounter("Last Post Time - Purge") @@ -47,8 +48,11 @@ func (s SqlPostStore) ClearCaches() { func NewSqlPostStore(sqlStore SqlStore, metrics einterfaces.MetricsInterface) store.PostStore { s := &SqlPostStore{ - SqlStore: sqlStore, - metrics: metrics, + SqlStore: sqlStore, + metrics: metrics, + lastPostTimeCache: utils.NewLru(LAST_POST_TIME_CACHE_SIZE), + lastPostsCache: utils.NewLru(LAST_POSTS_CACHE_SIZE), + maxPostSizeCached: model.POST_MESSAGE_MAX_RUNES_V1, } for _, db := range sqlStore.GetAllConns() { @@ -59,18 +63,18 @@ func NewSqlPostStore(sqlStore SqlStore, metrics einterfaces.MetricsInterface) st table.ColMap("RootId").SetMaxSize(26) table.ColMap("ParentId").SetMaxSize(26) table.ColMap("OriginalId").SetMaxSize(26) - table.ColMap("Message").SetMaxSize(4000) + table.ColMap("Message").SetMaxSize(model.POST_MESSAGE_MAX_BYTES_V2) table.ColMap("Type").SetMaxSize(26) table.ColMap("Hashtags").SetMaxSize(1000) table.ColMap("Props").SetMaxSize(8000) - table.ColMap("Filenames").SetMaxSize(4000) + table.ColMap("Filenames").SetMaxSize(model.POST_FILENAMES_MAX_RUNES) table.ColMap("FileIds").SetMaxSize(150) } return s } -func (s SqlPostStore) CreateIndexesIfNotExists() { +func (s *SqlPostStore) CreateIndexesIfNotExists() { s.CreateIndexIfNotExists("idx_posts_update_at", "Posts", "UpdateAt") s.CreateIndexIfNotExists("idx_posts_create_at", "Posts", "CreateAt") s.CreateIndexIfNotExists("idx_posts_delete_at", "Posts", "DeleteAt") @@ -86,15 +90,23 @@ func (s SqlPostStore) CreateIndexesIfNotExists() { s.CreateFullTextIndexIfNotExists("idx_posts_hashtags_txt", "Posts", "Hashtags") } -func (s SqlPostStore) Save(post *model.Post) store.StoreChannel { +func (s *SqlPostStore) Save(post *model.Post) store.StoreChannel { return store.Do(func(result *store.StoreResult) { if len(post.Id) > 0 { result.Err = model.NewAppError("SqlPostStore.Save", "store.sql_post.save.existing.app_error", nil, "id="+post.Id, http.StatusBadRequest) return } + var maxPostSize int + if result := <-s.GetMaxPostSize(); result.Err != nil { + result.Err = model.NewAppError("SqlPostStore.Save", "store.sql_post.save.app_error", nil, "id="+post.Id+", "+result.Err.Error(), http.StatusInternalServerError) + return + } else { + maxPostSize = result.Data.(int) + } + post.PreSave() - if result.Err = post.IsValid(); result.Err != nil { + if result.Err = post.IsValid(maxPostSize); result.Err != nil { return } @@ -122,7 +134,7 @@ func (s SqlPostStore) Save(post *model.Post) store.StoreChannel { }) } -func (s SqlPostStore) Update(newPost *model.Post, oldPost *model.Post) store.StoreChannel { +func (s *SqlPostStore) Update(newPost *model.Post, oldPost *model.Post) store.StoreChannel { return store.Do(func(result *store.StoreResult) { newPost.UpdateAt = model.GetMillis() newPost.PreCommit() @@ -133,7 +145,15 @@ func (s SqlPostStore) Update(newPost *model.Post, oldPost *model.Post) store.Sto oldPost.Id = model.NewId() oldPost.PreCommit() - if result.Err = newPost.IsValid(); result.Err != nil { + var maxPostSize int + if result := <-s.GetMaxPostSize(); result.Err != nil { + result.Err = model.NewAppError("SqlPostStore.Save", "store.sql_post.update.app_error", nil, "id="+newPost.Id+", "+result.Err.Error(), http.StatusInternalServerError) + return + } else { + maxPostSize = result.Data.(int) + } + + if result.Err = newPost.IsValid(maxPostSize); result.Err != nil { return } @@ -155,11 +175,19 @@ func (s SqlPostStore) Update(newPost *model.Post, oldPost *model.Post) store.Sto }) } -func (s SqlPostStore) Overwrite(post *model.Post) store.StoreChannel { +func (s *SqlPostStore) Overwrite(post *model.Post) store.StoreChannel { return store.Do(func(result *store.StoreResult) { post.UpdateAt = model.GetMillis() - if result.Err = post.IsValid(); result.Err != nil { + var maxPostSize int + if result := <-s.GetMaxPostSize(); result.Err != nil { + result.Err = model.NewAppError("SqlPostStore.Save", "store.sql_post.overwrite.app_error", nil, "id="+post.Id+", "+result.Err.Error(), http.StatusInternalServerError) + return + } else { + maxPostSize = result.Data.(int) + } + + if result.Err = post.IsValid(maxPostSize); result.Err != nil { return } @@ -171,7 +199,7 @@ func (s SqlPostStore) Overwrite(post *model.Post) store.StoreChannel { }) } -func (s SqlPostStore) GetFlaggedPosts(userId string, offset int, limit int) store.StoreChannel { +func (s *SqlPostStore) GetFlaggedPosts(userId string, offset int, limit int) store.StoreChannel { return store.Do(func(result *store.StoreResult) { pl := model.NewPostList() @@ -189,7 +217,7 @@ func (s SqlPostStore) GetFlaggedPosts(userId string, offset int, limit int) stor }) } -func (s SqlPostStore) GetFlaggedPostsForTeam(userId, teamId string, offset int, limit int) store.StoreChannel { +func (s *SqlPostStore) GetFlaggedPostsForTeam(userId, teamId string, offset int, limit int) store.StoreChannel { return store.Do(func(result *store.StoreResult) { pl := model.NewPostList() @@ -234,7 +262,7 @@ func (s SqlPostStore) GetFlaggedPostsForTeam(userId, teamId string, offset int, }) } -func (s SqlPostStore) GetFlaggedPostsForChannel(userId, channelId string, offset int, limit int) store.StoreChannel { +func (s *SqlPostStore) GetFlaggedPostsForChannel(userId, channelId string, offset int, limit int) store.StoreChannel { return store.Do(func(result *store.StoreResult) { pl := model.NewPostList() @@ -263,7 +291,7 @@ func (s SqlPostStore) GetFlaggedPostsForChannel(userId, channelId string, offset }) } -func (s SqlPostStore) Get(id string) store.StoreChannel { +func (s *SqlPostStore) Get(id string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { pl := model.NewPostList() @@ -308,7 +336,7 @@ func (s SqlPostStore) Get(id string) store.StoreChannel { }) } -func (s SqlPostStore) GetSingle(id string) store.StoreChannel { +func (s *SqlPostStore) GetSingle(id string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var post model.Post err := s.GetReplica().SelectOne(&post, "SELECT * FROM Posts WHERE Id = :Id AND DeleteAt = 0", map[string]interface{}{"Id": id}) @@ -325,12 +353,12 @@ type etagPosts struct { UpdateAt int64 } -func (s SqlPostStore) InvalidateLastPostTimeCache(channelId string) { - lastPostTimeCache.Remove(channelId) +func (s *SqlPostStore) InvalidateLastPostTimeCache(channelId string) { + s.lastPostTimeCache.Remove(channelId) // Keys are "{channelid}{limit}" and caching only occurs on limits of 30 and 60 - lastPostsCache.Remove(channelId + "30") - lastPostsCache.Remove(channelId + "60") + s.lastPostsCache.Remove(channelId + "30") + s.lastPostsCache.Remove(channelId + "60") if s.metrics != nil { s.metrics.IncrementMemCacheInvalidationCounter("Last Post Time - Remove by Channel Id") @@ -338,10 +366,10 @@ func (s SqlPostStore) InvalidateLastPostTimeCache(channelId string) { } } -func (s SqlPostStore) GetEtag(channelId string, allowFromCache bool) store.StoreChannel { +func (s *SqlPostStore) GetEtag(channelId string, allowFromCache bool) store.StoreChannel { return store.Do(func(result *store.StoreResult) { if allowFromCache { - if cacheItem, ok := lastPostTimeCache.Get(channelId); ok { + if cacheItem, ok := s.lastPostTimeCache.Get(channelId); ok { if s.metrics != nil { s.metrics.IncrementMemCacheHitCounter("Last Post Time") } @@ -366,11 +394,11 @@ func (s SqlPostStore) GetEtag(channelId string, allowFromCache bool) store.Store result.Data = fmt.Sprintf("%v.%v", model.CurrentVersion, et.UpdateAt) } - lastPostTimeCache.AddWithExpiresInSecs(channelId, et.UpdateAt, LAST_POST_TIME_CACHE_SEC) + s.lastPostTimeCache.AddWithExpiresInSecs(channelId, et.UpdateAt, LAST_POST_TIME_CACHE_SEC) }) } -func (s SqlPostStore) Delete(postId string, time int64) store.StoreChannel { +func (s *SqlPostStore) Delete(postId string, time int64) store.StoreChannel { return store.Do(func(result *store.StoreResult) { _, err := s.GetMaster().Exec("Update Posts SET DeleteAt = :DeleteAt, UpdateAt = :UpdateAt WHERE Id = :Id OR RootId = :RootId", map[string]interface{}{"DeleteAt": time, "UpdateAt": time, "Id": postId, "RootId": postId}) if err != nil { @@ -379,7 +407,7 @@ func (s SqlPostStore) Delete(postId string, time int64) store.StoreChannel { }) } -func (s SqlPostStore) permanentDelete(postId string) store.StoreChannel { +func (s *SqlPostStore) permanentDelete(postId string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { _, err := s.GetMaster().Exec("DELETE FROM Posts WHERE Id = :Id OR RootId = :RootId", map[string]interface{}{"Id": postId, "RootId": postId}) if err != nil { @@ -388,7 +416,7 @@ func (s SqlPostStore) permanentDelete(postId string) store.StoreChannel { }) } -func (s SqlPostStore) permanentDeleteAllCommentByUser(userId string) store.StoreChannel { +func (s *SqlPostStore) permanentDeleteAllCommentByUser(userId string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { _, err := s.GetMaster().Exec("DELETE FROM Posts WHERE UserId = :UserId AND RootId != ''", map[string]interface{}{"UserId": userId}) if err != nil { @@ -397,7 +425,7 @@ func (s SqlPostStore) permanentDeleteAllCommentByUser(userId string) store.Store }) } -func (s SqlPostStore) PermanentDeleteByUser(userId string) store.StoreChannel { +func (s *SqlPostStore) PermanentDeleteByUser(userId string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { // First attempt to delete all the comments for a user if r := <-s.permanentDeleteAllCommentByUser(userId); r.Err != nil { @@ -437,7 +465,7 @@ func (s SqlPostStore) PermanentDeleteByUser(userId string) store.StoreChannel { }) } -func (s SqlPostStore) PermanentDeleteByChannel(channelId string) store.StoreChannel { +func (s *SqlPostStore) PermanentDeleteByChannel(channelId string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { if _, err := s.GetMaster().Exec("DELETE FROM Posts WHERE ChannelId = :ChannelId", map[string]interface{}{"ChannelId": channelId}); err != nil { result.Err = model.NewAppError("SqlPostStore.PermanentDeleteByChannel", "store.sql_post.permanent_delete_by_channel.app_error", nil, "channel_id="+channelId+", "+err.Error(), http.StatusInternalServerError) @@ -445,7 +473,7 @@ func (s SqlPostStore) PermanentDeleteByChannel(channelId string) store.StoreChan }) } -func (s SqlPostStore) GetPosts(channelId string, offset int, limit int, allowFromCache bool) store.StoreChannel { +func (s *SqlPostStore) GetPosts(channelId string, offset int, limit int, allowFromCache bool) store.StoreChannel { return store.Do(func(result *store.StoreResult) { if limit > 1000 { result.Err = model.NewAppError("SqlPostStore.GetLinearPosts", "store.sql_post.get_posts.app_error", nil, "channelId="+channelId, http.StatusBadRequest) @@ -454,7 +482,7 @@ func (s SqlPostStore) GetPosts(channelId string, offset int, limit int, allowFro // Caching only occurs on limits of 30 and 60, the common limits requested by MM clients if allowFromCache && offset == 0 && (limit == 60 || limit == 30) { - if cacheItem, ok := lastPostsCache.Get(fmt.Sprintf("%s%v", channelId, limit)); ok { + if cacheItem, ok := s.lastPostsCache.Get(fmt.Sprintf("%s%v", channelId, limit)); ok { if s.metrics != nil { s.metrics.IncrementMemCacheHitCounter("Last Posts Cache") } @@ -498,7 +526,7 @@ func (s SqlPostStore) GetPosts(channelId string, offset int, limit int, allowFro // Caching only occurs on limits of 30 and 60, the common limits requested by MM clients if offset == 0 && (limit == 60 || limit == 30) { - lastPostsCache.AddWithExpiresInSecs(fmt.Sprintf("%s%v", channelId, limit), list, LAST_POSTS_CACHE_SEC) + s.lastPostsCache.AddWithExpiresInSecs(fmt.Sprintf("%s%v", channelId, limit), list, LAST_POSTS_CACHE_SEC) } result.Data = list @@ -506,12 +534,12 @@ func (s SqlPostStore) GetPosts(channelId string, offset int, limit int, allowFro }) } -func (s SqlPostStore) GetPostsSince(channelId string, time int64, allowFromCache bool) store.StoreChannel { +func (s *SqlPostStore) GetPostsSince(channelId string, time int64, allowFromCache bool) store.StoreChannel { return store.Do(func(result *store.StoreResult) { if allowFromCache { // If the last post in the channel's time is less than or equal to the time we are getting posts since, // we can safely return no posts. - if cacheItem, ok := lastPostTimeCache.Get(channelId); ok && cacheItem.(int64) <= time { + if cacheItem, ok := s.lastPostTimeCache.Get(channelId); ok && cacheItem.(int64) <= time { if s.metrics != nil { s.metrics.IncrementMemCacheHitCounter("Last Post Time") } @@ -576,22 +604,22 @@ func (s SqlPostStore) GetPostsSince(channelId string, time int64, allowFromCache } } - lastPostTimeCache.AddWithExpiresInSecs(channelId, latestUpdate, LAST_POST_TIME_CACHE_SEC) + s.lastPostTimeCache.AddWithExpiresInSecs(channelId, latestUpdate, LAST_POST_TIME_CACHE_SEC) result.Data = list } }) } -func (s SqlPostStore) GetPostsBefore(channelId string, postId string, numPosts int, offset int) store.StoreChannel { +func (s *SqlPostStore) GetPostsBefore(channelId string, postId string, numPosts int, offset int) store.StoreChannel { return s.getPostsAround(channelId, postId, numPosts, offset, true) } -func (s SqlPostStore) GetPostsAfter(channelId string, postId string, numPosts int, offset int) store.StoreChannel { +func (s *SqlPostStore) GetPostsAfter(channelId string, postId string, numPosts int, offset int) store.StoreChannel { return s.getPostsAround(channelId, postId, numPosts, offset, false) } -func (s SqlPostStore) getPostsAround(channelId string, postId string, numPosts int, offset int, before bool) store.StoreChannel { +func (s *SqlPostStore) getPostsAround(channelId string, postId string, numPosts int, offset int, before bool) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var direction string var sort string @@ -672,7 +700,7 @@ func (s SqlPostStore) getPostsAround(channelId string, postId string, numPosts i }) } -func (s SqlPostStore) getRootPosts(channelId string, offset int, limit int) store.StoreChannel { +func (s *SqlPostStore) getRootPosts(channelId string, offset int, limit int) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var posts []*model.Post _, err := s.GetReplica().Select(&posts, "SELECT * FROM Posts WHERE ChannelId = :ChannelId AND DeleteAt = 0 ORDER BY CreateAt DESC LIMIT :Limit OFFSET :Offset", map[string]interface{}{"ChannelId": channelId, "Offset": offset, "Limit": limit}) @@ -684,7 +712,7 @@ func (s SqlPostStore) getRootPosts(channelId string, offset int, limit int) stor }) } -func (s SqlPostStore) getParentsPosts(channelId string, offset int, limit int) store.StoreChannel { +func (s *SqlPostStore) getParentsPosts(channelId string, offset int, limit int) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var posts []*model.Post _, err := s.GetReplica().Select(&posts, ` @@ -771,7 +799,7 @@ var specialSearchChar = []string{ ":", } -func (s SqlPostStore) Search(teamId string, userId string, params *model.SearchParams) store.StoreChannel { +func (s *SqlPostStore) Search(teamId string, userId string, params *model.SearchParams) store.StoreChannel { return store.Do(func(result *store.StoreResult) { queryParams := map[string]interface{}{ "TeamId": teamId, @@ -945,7 +973,7 @@ func (s SqlPostStore) Search(teamId string, userId string, params *model.SearchP }) } -func (s SqlPostStore) AnalyticsUserCountsWithPostsByDay(teamId string) store.StoreChannel { +func (s *SqlPostStore) AnalyticsUserCountsWithPostsByDay(teamId string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { query := `SELECT DISTINCT @@ -998,7 +1026,7 @@ func (s SqlPostStore) AnalyticsUserCountsWithPostsByDay(teamId string) store.Sto }) } -func (s SqlPostStore) AnalyticsPostCountsByDay(teamId string) store.StoreChannel { +func (s *SqlPostStore) AnalyticsPostCountsByDay(teamId string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { query := `SELECT @@ -1053,7 +1081,7 @@ func (s SqlPostStore) AnalyticsPostCountsByDay(teamId string) store.StoreChannel }) } -func (s SqlPostStore) AnalyticsPostCount(teamId string, mustHaveFile bool, mustHaveHashtag bool) store.StoreChannel { +func (s *SqlPostStore) AnalyticsPostCount(teamId string, mustHaveFile bool, mustHaveHashtag bool) store.StoreChannel { return store.Do(func(result *store.StoreResult) { query := `SELECT @@ -1084,7 +1112,7 @@ func (s SqlPostStore) AnalyticsPostCount(teamId string, mustHaveFile bool, mustH }) } -func (s SqlPostStore) GetPostsCreatedAt(channelId string, time int64) store.StoreChannel { +func (s *SqlPostStore) GetPostsCreatedAt(channelId string, time int64) store.StoreChannel { return store.Do(func(result *store.StoreResult) { query := `SELECT * FROM Posts WHERE CreateAt = :CreateAt AND ChannelId = :ChannelId` @@ -1099,7 +1127,7 @@ func (s SqlPostStore) GetPostsCreatedAt(channelId string, time int64) store.Stor }) } -func (s SqlPostStore) GetPostsByIds(postIds []string) store.StoreChannel { +func (s *SqlPostStore) GetPostsByIds(postIds []string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { keys := bytes.Buffer{} params := make(map[string]interface{}) @@ -1127,7 +1155,7 @@ func (s SqlPostStore) GetPostsByIds(postIds []string) store.StoreChannel { }) } -func (s SqlPostStore) GetPostsBatchForIndexing(startTime int64, endTime int64, limit int) store.StoreChannel { +func (s *SqlPostStore) GetPostsBatchForIndexing(startTime int64, endTime int64, limit int) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var posts []*model.PostForIndexing _, err1 := s.GetSearchReplica().Select(&posts, @@ -1167,7 +1195,7 @@ func (s SqlPostStore) GetPostsBatchForIndexing(startTime int64, endTime int64, l }) } -func (s SqlPostStore) PermanentDeleteBatch(endTime int64, limit int64) store.StoreChannel { +func (s *SqlPostStore) PermanentDeleteBatch(endTime int64, limit int64) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var query string if s.DriverName() == "postgres" { @@ -1191,7 +1219,7 @@ func (s SqlPostStore) PermanentDeleteBatch(endTime int64, limit int64) store.Sto }) } -func (s SqlPostStore) GetOldest() store.StoreChannel { +func (s *SqlPostStore) GetOldest() store.StoreChannel { return store.Do(func(result *store.StoreResult) { var post model.Post err := s.GetReplica().SelectOne(&post, "SELECT * FROM Posts ORDER BY CreateAt LIMIT 1") @@ -1202,3 +1230,66 @@ func (s SqlPostStore) GetOldest() store.StoreChannel { result.Data = &post }) } + +func (s *SqlPostStore) determineMaxPostSize() int { + var maxPostSize int = model.POST_MESSAGE_MAX_RUNES_V1 + var maxPostSizeBytes int32 + + if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { + // The Post.Message column in Postgres has historically been VARCHAR(4000), but + // may be manually enlarged to support longer posts. + if err := s.GetReplica().SelectOne(&maxPostSizeBytes, ` + SELECT + COALESCE(character_maximum_length, 0) + FROM + information_schema.columns + WHERE + table_name = 'posts' + AND column_name = 'message' + `); err != nil { + l4g.Error(utils.T("store.sql_post.query_max_post_size.error") + err.Error()) + } + } else if s.DriverName() == model.DATABASE_DRIVER_MYSQL { + // The Post.Message column in MySQL has historically been TEXT, with a maximum + // limit of 65535. + if err := s.GetReplica().SelectOne(&maxPostSizeBytes, ` + SELECT + COALESCE(CHARACTER_MAXIMUM_LENGTH, 0) + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + table_schema = DATABASE() + AND table_name = 'Posts' + AND column_name = 'Message' + LIMIT 0, 1 + `); err != nil { + l4g.Error(utils.T("store.sql_post.query_max_post_size.error") + err.Error()) + } + } else { + l4g.Warn(utils.T("store.sql_post.query_max_post_size.unrecognized_driver")) + } + + l4g.Trace(utils.T("store.sql_post.query_max_post_size.max_post_size_bytes"), maxPostSizeBytes) + + // Assume a worst-case representation of four bytes per rune. + maxPostSize = int(maxPostSizeBytes) / 4 + + // To maintain backwards compatibility, don't yield a maximum post + // size smaller than the previous limit, even though it wasn't + // actually possible to store 4000 runes in all cases. + if maxPostSize < model.POST_MESSAGE_MAX_RUNES_V1 { + maxPostSize = model.POST_MESSAGE_MAX_RUNES_V1 + } + + return maxPostSize +} + +// GetMaxPostSize returns the maximum number of runes that may be stored in a post. +func (s *SqlPostStore) GetMaxPostSize() store.StoreChannel { + return store.Do(func(result *store.StoreResult) { + s.maxPostSizeOnce.Do(func() { + s.maxPostSizeCached = s.determineMaxPostSize() + }) + result.Data = s.maxPostSizeCached + }) +} diff --git a/store/store.go b/store/store.go index f070a45db..773dfff02 100644 --- a/store/store.go +++ b/store/store.go @@ -198,6 +198,7 @@ type PostStore interface { GetPostsBatchForIndexing(startTime int64, endTime int64, limit int) StoreChannel PermanentDeleteBatch(endTime int64, limit int64) StoreChannel GetOldest() StoreChannel + GetMaxPostSize() StoreChannel } type UserStore interface { diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index c405d5030..bdd0d1d16 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -422,3 +422,18 @@ func (_m *PostStore) Update(newPost *model.Post, oldPost *model.Post) store.Stor return r0 } + +func (_m *PostStore) GetMaxPostSize() store.StoreChannel { + ret := _m.Called() + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func() store.StoreChannel); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index 91fc40213..44ce47d9d 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -43,6 +43,7 @@ func TestPostStore(t *testing.T, ss store.Store) { t.Run("GetPostsBatchForIndexing", func(t *testing.T) { testPostStoreGetPostsBatchForIndexing(t, ss) }) t.Run("PermanentDeleteBatch", func(t *testing.T) { testPostStorePermanentDeleteBatch(t, ss) }) t.Run("GetOldest", func(t *testing.T) { testPostStoreGetOldest(t, ss) }) + t.Run("TestGetMaxPostSize", func(t *testing.T) { testGetMaxPostSize(t, ss) }) } func testPostStoreSave(t *testing.T, ss store.Store) { @@ -1783,3 +1784,8 @@ func testPostStoreGetOldest(t *testing.T, ss store.Store) { assert.EqualValues(t, o2.Id, r1.Id) } + +func testGetMaxPostSize(t *testing.T, ss store.Store) { + assert.Equal(t, model.POST_MESSAGE_MAX_RUNES_V2, (<-ss.Post().GetMaxPostSize()).Data.(int)) + assert.Equal(t, model.POST_MESSAGE_MAX_RUNES_V2, (<-ss.Post().GetMaxPostSize()).Data.(int)) +} |