From b122381e87577ddfc12b792a3de9121ea830d50e Mon Sep 17 00:00:00 2001 From: Chris Date: Wed, 16 Aug 2017 07:17:57 -0500 Subject: PLT-1649: add response_url support for custom slash commands (#6739) * add response_url support for custom slash commands * pr suggestions * pr update / suggestion * test fix --- api/webhook.go | 3 +- api4/api.go | 6 +- api4/webhook.go | 19 +++++ api4/webhook_test.go | 40 ++++++++++ app/command.go | 14 ++-- app/webhook.go | 51 +++++++++++++ cmd/platform/server.go | 10 +++ i18n/en.json | 60 +++++++++++++++ model/command_response.go | 17 +++++ model/command_response_test.go | 24 ++++++ model/command_webhook.go | 65 +++++++++++++++++ model/command_webhook_test.go | 54 ++++++++++++++ store/layered_store.go | 4 + store/sql_command_webhook_store.go | 125 ++++++++++++++++++++++++++++++++ store/sql_command_webhook_store_test.go | 65 +++++++++++++++++ store/sql_store.go | 1 + store/sql_supplier.go | 7 ++ store/store.go | 8 ++ 18 files changed, 562 insertions(+), 11 deletions(-) create mode 100644 model/command_webhook.go create mode 100644 model/command_webhook_test.go create mode 100644 store/sql_command_webhook_store.go create mode 100644 store/sql_command_webhook_store_test.go diff --git a/api/webhook.go b/api/webhook.go index c17b5bc56..9750b71a0 100644 --- a/api/webhook.go +++ b/api/webhook.go @@ -32,8 +32,7 @@ func InitWebhook() { BaseRoutes.Hooks.Handle("/{id:[A-Za-z0-9]+}", ApiAppHandler(incomingWebhook)).Methods("POST") // Old route. Remove eventually. - mr := app.Srv.Router - mr.Handle("/hooks/{id:[A-Za-z0-9]+}", ApiAppHandler(incomingWebhook)).Methods("POST") + BaseRoutes.Root.Handle("/hooks/{id:[A-Za-z0-9]+}", ApiAppHandler(incomingWebhook)).Methods("POST") } func createIncomingHook(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/api4/api.go b/api4/api.go index be957d63b..6e9534d40 100644 --- a/api4/api.go +++ b/api4/api.go @@ -55,9 +55,8 @@ type Routes struct { PublicFile *mux.Router // 'files/{file_id:[A-Za-z0-9]+}/public' - Commands *mux.Router // 'api/v4/commands' - Command *mux.Router // 'api/v4/commands/{command_id:[A-Za-z0-9]+}' - CommandsForTeam *mux.Router // 'api/v4/teams/{team_id:[A-Za-z0-9]+}/commands' + Commands *mux.Router // 'api/v4/commands' + Command *mux.Router // 'api/v4/commands/{command_id:[A-Za-z0-9]+}' Hooks *mux.Router // 'api/v4/hooks' IncomingHooks *mux.Router // 'api/v4/hooks/incoming' @@ -149,7 +148,6 @@ func InitApi(full bool) { BaseRoutes.Commands = BaseRoutes.ApiRoot.PathPrefix("/commands").Subrouter() BaseRoutes.Command = BaseRoutes.Commands.PathPrefix("/{command_id:[A-Za-z0-9]+}").Subrouter() - BaseRoutes.CommandsForTeam = BaseRoutes.Team.PathPrefix("/commands").Subrouter() BaseRoutes.Hooks = BaseRoutes.ApiRoot.PathPrefix("/hooks").Subrouter() BaseRoutes.IncomingHooks = BaseRoutes.Hooks.PathPrefix("/incoming").Subrouter() diff --git a/api4/webhook.go b/api4/webhook.go index 668636932..52576c773 100644 --- a/api4/webhook.go +++ b/api4/webhook.go @@ -7,6 +7,7 @@ import ( "net/http" l4g "github.com/alecthomas/log4go" + "github.com/gorilla/mux" "github.com/mattermost/platform/app" "github.com/mattermost/platform/model" "github.com/mattermost/platform/utils" @@ -27,6 +28,8 @@ func InitWebhook() { BaseRoutes.OutgoingHook.Handle("", ApiSessionRequired(updateOutgoingHook)).Methods("PUT") BaseRoutes.OutgoingHook.Handle("", ApiSessionRequired(deleteOutgoingHook)).Methods("DELETE") BaseRoutes.OutgoingHook.Handle("/regen_token", ApiSessionRequired(regenOutgoingHookToken)).Methods("POST") + + BaseRoutes.Root.Handle("/hooks/commands/{id:[A-Za-z0-9]+}", ApiHandler(commandWebhook)).Methods("POST") } func createIncomingHook(c *Context, w http.ResponseWriter, r *http.Request) { @@ -435,3 +438,19 @@ func deleteOutgoingHook(c *Context, w http.ResponseWriter, r *http.Request) { c.LogAudit("success") ReturnStatusOK(w) } + +func commandWebhook(c *Context, w http.ResponseWriter, r *http.Request) { + params := mux.Vars(r) + id := params["id"] + + response := model.CommandResponseFromHTTPBody(r.Header.Get("Content-Type"), r.Body) + + err := app.HandleCommandWebhook(id, response) + if err != nil { + c.Err = err + return + } + + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("ok")) +} diff --git a/api4/webhook_test.go b/api4/webhook_test.go index 96451f8a7..80328e373 100644 --- a/api4/webhook_test.go +++ b/api4/webhook_test.go @@ -4,8 +4,11 @@ package api4 import ( + "bytes" + "net/http" "testing" + "github.com/mattermost/platform/app" "github.com/mattermost/platform/model" "github.com/mattermost/platform/utils" ) @@ -893,3 +896,40 @@ func TestDeleteOutgoingHook(t *testing.T) { CheckForbiddenStatus(t, resp) }) } + +func TestCommandWebhooks(t *testing.T) { + th := Setup().InitBasic().InitSystemAdmin() + Client := th.SystemAdminClient + + cmd := &model.Command{ + CreatorId: th.BasicUser.Id, + TeamId: th.BasicTeam.Id, + URL: "http://nowhere.com", + Method: model.COMMAND_METHOD_POST, + Trigger: "delayed"} + + cmd, _ = Client.CreateCommand(cmd) + args := &model.CommandArgs{ + TeamId: th.BasicTeam.Id, + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + } + hook, err := app.CreateCommandWebhook(cmd.Id, args) + if err != nil { + t.Fatal(err) + } + + if resp, _ := http.Post(Client.Url+"/hooks/commands/123123123123", "application/json", bytes.NewBufferString("{\"text\":\"this is a test\"}")); resp.StatusCode != http.StatusNotFound { + t.Fatal("expected not-found for non-existent hook") + } + + for i := 0; i < 5; i++ { + if _, err := http.Post(Client.Url+"/hooks/commands/"+hook.Id, "application/json", bytes.NewBufferString("{\"text\":\"this is a test\"}")); err != nil { + t.Fatal(err) + } + } + + if resp, _ := http.Post(Client.Url+"/hooks/commands/"+hook.Id, "application/json", bytes.NewBufferString("{\"text\":\"this is a test\"}")); resp.StatusCode != http.StatusBadRequest { + t.Fatal("expected error for sixth usage") + } +} diff --git a/app/command.go b/app/command.go index 7fe11fffc..83500cc1f 100644 --- a/app/command.go +++ b/app/command.go @@ -45,10 +45,9 @@ func CreateCommandPost(post *model.Post, teamId string, response *model.CommandR parseSlackAttachment(post, response.Attachments) } - switch response.ResponseType { - case model.COMMAND_RESPONSE_TYPE_IN_CHANNEL: + if response.ResponseType == model.COMMAND_RESPONSE_TYPE_IN_CHANNEL { return CreatePost(post, teamId, true) - case model.COMMAND_RESPONSE_TYPE_EPHEMERAL: + } else if response.ResponseType == "" || response.ResponseType == model.COMMAND_RESPONSE_TYPE_EPHEMERAL { if response.Text == "" { return post, nil } @@ -196,7 +195,12 @@ func ExecuteCommand(args *model.CommandArgs) (*model.CommandResponse, *model.App p.Set("command", "/"+trigger) p.Set("text", message) - p.Set("response_url", "not supported yet") + + if hook, err := CreateCommandWebhook(cmd.Id, args); err != nil { + return nil, model.NewAppError("command", "api.command.execute_command.failed.app_error", map[string]interface{}{"Trigger": trigger}, err.Error(), http.StatusInternalServerError) + } else { + p.Set("response_url", args.SiteURL+"/hooks/commands/"+hook.Id) + } method := "POST" if cmd.Method == model.COMMAND_METHOD_GET { @@ -213,7 +217,7 @@ func ExecuteCommand(args *model.CommandArgs) (*model.CommandResponse, *model.App return nil, model.NewAppError("command", "api.command.execute_command.failed.app_error", map[string]interface{}{"Trigger": trigger}, err.Error(), http.StatusInternalServerError) } else { if resp.StatusCode == http.StatusOK { - response := model.CommandResponseFromJson(resp.Body) + response := model.CommandResponseFromHTTPBody(resp.Header.Get("Content-Type"), resp.Body) if response == nil { return nil, model.NewAppError("command", "api.command.execute_command.failed_empty.app_error", map[string]interface{}{"Trigger": trigger}, "", http.StatusInternalServerError) } else { diff --git a/app/webhook.go b/app/webhook.go index 4606c207f..f84086d94 100644 --- a/app/webhook.go +++ b/app/webhook.go @@ -533,3 +533,54 @@ func HandleIncomingWebhook(hookId string, req *model.IncomingWebhookRequest) *mo return nil } + +func CreateCommandWebhook(commandId string, args *model.CommandArgs) (*model.CommandWebhook, *model.AppError) { + hook := &model.CommandWebhook{ + CommandId: commandId, + UserId: args.UserId, + ChannelId: args.ChannelId, + RootId: args.RootId, + ParentId: args.ParentId, + } + + if result := <-Srv.Store.CommandWebhook().Save(hook); result.Err != nil { + return nil, result.Err + } else { + return result.Data.(*model.CommandWebhook), nil + } +} + +func HandleCommandWebhook(hookId string, response *model.CommandResponse) *model.AppError { + if response == nil { + return model.NewAppError("HandleCommandWebhook", "web.command_webhook.parse.app_error", nil, "", http.StatusBadRequest) + } + + var hook *model.CommandWebhook + if result := <-Srv.Store.CommandWebhook().Get(hookId); result.Err != nil { + return model.NewAppError("HandleCommandWebhook", "web.command_webhook.invalid.app_error", nil, "err="+result.Err.Message, result.Err.StatusCode) + } else { + hook = result.Data.(*model.CommandWebhook) + } + + var cmd *model.Command + if result := <-Srv.Store.Command().Get(hook.CommandId); result.Err != nil { + return model.NewAppError("HandleCommandWebhook", "web.command_webhook.command.app_error", nil, "err="+result.Err.Message, http.StatusBadRequest) + } else { + cmd = result.Data.(*model.Command) + } + + args := &model.CommandArgs{ + UserId: hook.UserId, + ChannelId: hook.ChannelId, + TeamId: cmd.TeamId, + RootId: hook.RootId, + ParentId: hook.ParentId, + } + + if result := <-Srv.Store.CommandWebhook().TryUse(hook.Id, 5); result.Err != nil { + return model.NewAppError("HandleCommandWebhook", "web.command_webhook.invalid.app_error", nil, "err="+result.Err.Message, result.Err.StatusCode) + } + + _, err := HandleCommandResponse(cmd, args, response, false) + return err +} diff --git a/cmd/platform/server.go b/cmd/platform/server.go index 8695129b7..79193cd0f 100644 --- a/cmd/platform/server.go +++ b/cmd/platform/server.go @@ -107,6 +107,7 @@ func runServer(configFileLocation string) { go runDiagnosticsJob() go runTokenCleanupJob() + go runCommandWebhookCleanupJob() if complianceI := einterfaces.GetComplianceInterface(); complianceI != nil { complianceI.StartComplianceDailyJob() @@ -170,6 +171,11 @@ func runTokenCleanupJob() { model.CreateRecurringTask("Token Cleanup", doTokenCleanup, time.Hour*1) } +func runCommandWebhookCleanupJob() { + doCommandWebhookCleanup() + model.CreateRecurringTask("Command Hook Cleanup", doCommandWebhookCleanup, time.Hour*1) +} + func resetStatuses() { if result := <-app.Srv.Store.Status().ResetAll(); result.Err != nil { l4g.Error(utils.T("mattermost.reset_status.error"), result.Err.Error()) @@ -204,3 +210,7 @@ func doDiagnostics() { func doTokenCleanup() { app.Srv.Store.Token().Cleanup() } + +func doCommandWebhookCleanup() { + app.Srv.Store.CommandWebhook().Cleanup() +} diff --git a/i18n/en.json b/i18n/en.json index 70a243e56..8a2d0d770 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -4783,6 +4783,34 @@ "id": "model.job.is_valid.type.app_error", "translation": "Invalid job type" }, + { + "id": "model.command_hook.id.app_error", + "translation": "Invalid command hook id" + }, + { + "id": "model.command_hook.create_at.app_error", + "translation": "Create at must be a valid time" + }, + { + "id": "model.command_hook.command_id.app_error", + "translation": "Invalid command id" + }, + { + "id": "model.command_hook.user_id.app_error", + "translation": "Invalid user id" + }, + { + "id": "model.command_hook.channel_id.app_error", + "translation": "Invalid channel id" + }, + { + "id": "model.command_hook.root_id.app_error", + "translation": "Invalid root id" + }, + { + "id": "model.command_hook.parent_id.app_error", + "translation": "Invalid parent id" + }, { "id": "model.oauth.is_valid.app_id.app_error", "translation": "Invalid app id" @@ -6503,6 +6531,26 @@ "id": "store.sql_webhooks.update_outgoing.app_error", "translation": "We couldn't update the webhook" }, + { + "id": "store.sql_command_webhooks.save.existing.app_error", + "translation": "You cannot update an existing CommandWebhook" + }, + { + "id": "store.sql_command_webhooks.save.app_error", + "translation": "We couldn't save the CommandWebhook" + }, + { + "id": "store.sql_command_webhooks.get.app_error", + "translation": "We couldn't get the webhook" + }, + { + "id": "store.sql_command_webhooks.try_use.app_error", + "translation": "Unable to use the webhook" + }, + { + "id": "store.sql_command_webhooks.try_use.invalid.app_error", + "translation": "Invalid webhook" + }, { "id": "system.message.name", "translation": "System" @@ -6727,6 +6775,18 @@ "id": "web.incoming_webhook.user.app_error", "translation": "Couldn't find the user" }, + { + "id": "web.command_webhook.parse.app_error", + "translation": "Unable to parse incoming data" + }, + { + "id": "web.command_webhook.invalid.app_error", + "translation": "Invalid webhook" + }, + { + "id": "web.command_webhook.command.app_error", + "translation": "Couldn't find the command" + }, { "id": "web.init.debug", "translation": "Initializing web routes" diff --git a/model/command_response.go b/model/command_response.go index 27d39e173..0b80b297b 100644 --- a/model/command_response.go +++ b/model/command_response.go @@ -6,6 +6,7 @@ package model import ( "encoding/json" "io" + "io/ioutil" ) const ( @@ -31,6 +32,22 @@ func (o *CommandResponse) ToJson() string { } } +func CommandResponseFromHTTPBody(contentType string, body io.Reader) *CommandResponse { + if contentType == "application/json" { + return CommandResponseFromJson(body) + } + if b, err := ioutil.ReadAll(body); err == nil { + return CommandResponseFromPlainText(string(b)) + } + return nil +} + +func CommandResponseFromPlainText(text string) *CommandResponse { + return &CommandResponse{ + Text: text, + } +} + func CommandResponseFromJson(data io.Reader) *CommandResponse { decoder := json.NewDecoder(data) var o CommandResponse diff --git a/model/command_response_test.go b/model/command_response_test.go index df478ff2c..19be796b8 100644 --- a/model/command_response_test.go +++ b/model/command_response_test.go @@ -18,6 +18,30 @@ func TestCommandResponseJson(t *testing.T) { } } +func TestCommandResponseFromHTTPBody(t *testing.T) { + for _, test := range []struct { + ContentType string + Body string + ExpectedText string + }{ + {"", "foo", "foo"}, + {"text/plain", "foo", "foo"}, + {"application/json", `{"text": "foo"}`, "foo"}, + } { + response := CommandResponseFromHTTPBody(test.ContentType, strings.NewReader(test.Body)) + if response.Text != test.ExpectedText { + t.Fatal() + } + } +} + +func TestCommandResponseFromPlainText(t *testing.T) { + response := CommandResponseFromPlainText("foo") + if response.Text != "foo" { + t.Fatal("text should be foo") + } +} + func TestCommandResponseFromJson(t *testing.T) { json := `{ "response_type": "ephemeral", diff --git a/model/command_webhook.go b/model/command_webhook.go new file mode 100644 index 000000000..0b00e00b6 --- /dev/null +++ b/model/command_webhook.go @@ -0,0 +1,65 @@ +// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "net/http" +) + +type CommandWebhook struct { + Id string + CreateAt int64 + CommandId string + UserId string + ChannelId string + RootId string + ParentId string + UseCount int +} + +const ( + COMMAND_WEBHOOK_LIFETIME = 1000 * 60 * 30 +) + +func (o *CommandWebhook) PreSave() { + if o.Id == "" { + o.Id = NewId() + } + + if o.CreateAt == 0 { + o.CreateAt = GetMillis() + } +} + +func (o *CommandWebhook) IsValid() *AppError { + if len(o.Id) != 26 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.id.app_error", nil, "", http.StatusBadRequest) + } + + if o.CreateAt == 0 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.create_at.app_error", nil, "id="+o.Id, http.StatusBadRequest) + } + + if len(o.CommandId) != 26 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.command_id.app_error", nil, "", http.StatusBadRequest) + } + + if len(o.UserId) != 26 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.user_id.app_error", nil, "", http.StatusBadRequest) + } + + if len(o.ChannelId) != 26 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.channel_id.app_error", nil, "", http.StatusBadRequest) + } + + if len(o.RootId) != 0 && len(o.RootId) != 26 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.root_id.app_error", nil, "", http.StatusBadRequest) + } + + if len(o.ParentId) != 0 && len(o.ParentId) != 26 { + return NewAppError("CommandWebhook.IsValid", "model.command_hook.parent_id.app_error", nil, "", http.StatusBadRequest) + } + + return nil +} diff --git a/model/command_webhook_test.go b/model/command_webhook_test.go new file mode 100644 index 000000000..629bbdaa7 --- /dev/null +++ b/model/command_webhook_test.go @@ -0,0 +1,54 @@ +// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "testing" +) + +func TestCommandWebhookPreSave(t *testing.T) { + h := CommandWebhook{} + h.PreSave() + if len(h.Id) != 26 { + t.Fatal("Id should be generated") + } + if h.CreateAt == 0 { + t.Fatal("CreateAt should be set") + } +} + +func TestCommandWebhookIsValid(t *testing.T) { + h := CommandWebhook{} + h.Id = NewId() + h.CreateAt = GetMillis() + h.CommandId = NewId() + h.UserId = NewId() + h.ChannelId = NewId() + + for _, test := range []struct { + Transform func() + ExpectedError string + }{ + {func() {}, ""}, + {func() { h.Id = "asd" }, "model.command_hook.id.app_error"}, + {func() { h.CreateAt = 0 }, "model.command_hook.create_at.app_error"}, + {func() { h.CommandId = "asd" }, "model.command_hook.command_id.app_error"}, + {func() { h.UserId = "asd" }, "model.command_hook.user_id.app_error"}, + {func() { h.ChannelId = "asd" }, "model.command_hook.channel_id.app_error"}, + {func() { h.RootId = "asd" }, "model.command_hook.root_id.app_error"}, + {func() { h.RootId = NewId() }, ""}, + {func() { h.ParentId = "asd" }, "model.command_hook.parent_id.app_error"}, + {func() { h.ParentId = NewId() }, ""}, + } { + tmp := h + test.Transform() + err := h.IsValid() + if test.ExpectedError == "" && err != nil { + t.Fatal("hook should be valid") + } else if test.ExpectedError != "" && test.ExpectedError != err.Id { + t.Fatal("expected " + test.ExpectedError + " error") + } + h = tmp + } +} diff --git a/store/layered_store.go b/store/layered_store.go index 4eb908659..d215cb2f0 100644 --- a/store/layered_store.go +++ b/store/layered_store.go @@ -107,6 +107,10 @@ func (s *LayeredStore) Command() CommandStore { return s.DatabaseLayer.Command() } +func (s *LayeredStore) CommandWebhook() CommandWebhookStore { + return s.DatabaseLayer.CommandWebhook() +} + func (s *LayeredStore) Preference() PreferenceStore { return s.DatabaseLayer.Preference() } diff --git a/store/sql_command_webhook_store.go b/store/sql_command_webhook_store.go new file mode 100644 index 000000000..af4b298b1 --- /dev/null +++ b/store/sql_command_webhook_store.go @@ -0,0 +1,125 @@ +// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package store + +import ( + "net/http" + + "database/sql" + + l4g "github.com/alecthomas/log4go" + + "github.com/mattermost/platform/model" +) + +type SqlCommandWebhookStore struct { + SqlStore +} + +func NewSqlCommandWebhookStore(sqlStore SqlStore) CommandWebhookStore { + s := &SqlCommandWebhookStore{sqlStore} + + for _, db := range sqlStore.GetAllConns() { + tablec := db.AddTableWithName(model.CommandWebhook{}, "CommandWebhooks").SetKeys(false, "Id") + tablec.ColMap("Id").SetMaxSize(26) + tablec.ColMap("CommandId").SetMaxSize(26) + tablec.ColMap("UserId").SetMaxSize(26) + tablec.ColMap("ChannelId").SetMaxSize(26) + tablec.ColMap("RootId").SetMaxSize(26) + tablec.ColMap("ParentId").SetMaxSize(26) + } + + return s +} + +func (s SqlCommandWebhookStore) CreateIndexesIfNotExists() { + s.CreateIndexIfNotExists("idx_command_webhook_create_at", "CommandWebhooks", "CreateAt") +} + +func (s SqlCommandWebhookStore) Save(webhook *model.CommandWebhook) StoreChannel { + storeChannel := make(StoreChannel, 1) + + go func() { + result := StoreResult{} + + if len(webhook.Id) > 0 { + result.Err = model.NewLocAppError("SqlCommandWebhookStore.Save", "store.sql_command_webhooks.save.existing.app_error", nil, "id="+webhook.Id) + storeChannel <- result + close(storeChannel) + return + } + + webhook.PreSave() + if result.Err = webhook.IsValid(); result.Err != nil { + storeChannel <- result + close(storeChannel) + return + } + + if err := s.GetMaster().Insert(webhook); err != nil { + result.Err = model.NewLocAppError("SqlCommandWebhookStore.Save", "store.sql_command_webhooks.save.app_error", nil, "id="+webhook.Id+", "+err.Error()) + } else { + result.Data = webhook + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (s SqlCommandWebhookStore) Get(id string) StoreChannel { + storeChannel := make(StoreChannel, 1) + + go func() { + result := StoreResult{} + + var webhook model.CommandWebhook + + exptime := model.GetMillis() - model.COMMAND_WEBHOOK_LIFETIME + if err := s.GetReplica().SelectOne(&webhook, "SELECT * FROM CommandWebhooks WHERE Id = :Id AND CreateAt > :ExpTime", map[string]interface{}{"Id": id, "ExpTime": exptime}); err != nil { + result.Err = model.NewLocAppError("SqlCommandWebhookStore.Get", "store.sql_command_webhooks.get.app_error", nil, "id="+id+", err="+err.Error()) + if err == sql.ErrNoRows { + result.Err.StatusCode = http.StatusNotFound + } + } + + result.Data = &webhook + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (s SqlCommandWebhookStore) TryUse(id string, limit int) StoreChannel { + storeChannel := make(StoreChannel, 1) + + go func() { + result := StoreResult{} + + if sqlResult, err := s.GetMaster().Exec("UPDATE CommandWebhooks SET UseCount = UseCount + 1 WHERE Id = :Id AND UseCount < :UseLimit", map[string]interface{}{"Id": id, "UseLimit": limit}); err != nil { + result.Err = model.NewLocAppError("SqlCommandWebhookStore.TryUse", "store.sql_command_webhooks.try_use.app_error", nil, "id="+id+", err="+err.Error()) + } else if rows, _ := sqlResult.RowsAffected(); rows == 0 { + result.Err = model.NewAppError("SqlCommandWebhookStore.TryUse", "store.sql_command_webhooks.try_use.invalid.app_error", nil, "id="+id, http.StatusBadRequest) + } + + result.Data = id + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (s SqlCommandWebhookStore) Cleanup() { + l4g.Debug("Cleaning up command webhook store.") + exptime := model.GetMillis() - model.COMMAND_WEBHOOK_LIFETIME + if _, err := s.GetMaster().Exec("DELETE FROM CommandWebhooks WHERE CreateAt < :ExpTime", map[string]interface{}{"ExpTime": exptime}); err != nil { + l4g.Error("Unable to cleanup command webhook store.") + } +} diff --git a/store/sql_command_webhook_store_test.go b/store/sql_command_webhook_store_test.go new file mode 100644 index 000000000..2215a4263 --- /dev/null +++ b/store/sql_command_webhook_store_test.go @@ -0,0 +1,65 @@ +// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package store + +import ( + "testing" + + "net/http" + + "github.com/mattermost/platform/model" +) + +func TestCommandWebhookStore(t *testing.T) { + Setup() + + cws := store.CommandWebhook() + + h1 := &model.CommandWebhook{} + h1.CommandId = model.NewId() + h1.UserId = model.NewId() + h1.ChannelId = model.NewId() + h1 = (<-cws.Save(h1)).Data.(*model.CommandWebhook) + + if r1 := <-cws.Get(h1.Id); r1.Err != nil { + t.Fatal(r1.Err) + } else { + if *r1.Data.(*model.CommandWebhook) != *h1 { + t.Fatal("invalid returned webhook") + } + } + + if err := (<-cws.Get("123")).Err; err.StatusCode != http.StatusNotFound { + t.Fatal("Should have set the status as not found for missing id") + } + + h2 := &model.CommandWebhook{} + h2.CreateAt = model.GetMillis() - 2*model.COMMAND_WEBHOOK_LIFETIME + h2.CommandId = model.NewId() + h2.UserId = model.NewId() + h2.ChannelId = model.NewId() + h2 = (<-cws.Save(h2)).Data.(*model.CommandWebhook) + + if err := (<-cws.Get(h2.Id)).Err; err == nil || err.StatusCode != http.StatusNotFound { + t.Fatal("Should have set the status as not found for expired webhook") + } + + cws.Cleanup() + + if err := (<-cws.Get(h1.Id)).Err; err != nil { + t.Fatal("Should have no error getting unexpired webhook") + } + + if err := (<-cws.Get(h2.Id)).Err; err.StatusCode != http.StatusNotFound { + t.Fatal("Should have set the status as not found for expired webhook") + } + + if err := (<-cws.TryUse(h1.Id, 1)).Err; err != nil { + t.Fatal("Should be able to use webhook once") + } + + if err := (<-cws.TryUse(h1.Id, 1)).Err; err == nil || err.StatusCode != http.StatusBadRequest { + t.Fatal("Should be able to use webhook once") + } +} diff --git a/store/sql_store.go b/store/sql_store.go index 817f3fb0f..488b44522 100644 --- a/store/sql_store.go +++ b/store/sql_store.go @@ -72,6 +72,7 @@ type SqlStore interface { System() SystemStore Webhook() WebhookStore Command() CommandStore + CommandWebhook() CommandWebhookStore Preference() PreferenceStore License() LicenseStore Token() TokenStore diff --git a/store/sql_supplier.go b/store/sql_supplier.go index 5997a1339..5b9c268bb 100644 --- a/store/sql_supplier.go +++ b/store/sql_supplier.go @@ -70,6 +70,7 @@ type SqlSupplierOldStores struct { system SystemStore webhook WebhookStore command CommandStore + commandWebhook CommandWebhookStore preference PreferenceStore license LicenseStore token TokenStore @@ -111,6 +112,7 @@ func NewSqlSupplier() *SqlSupplier { supplier.oldStores.system = NewSqlSystemStore(supplier) supplier.oldStores.webhook = NewSqlWebhookStore(supplier) supplier.oldStores.command = NewSqlCommandStore(supplier) + supplier.oldStores.commandWebhook = NewSqlCommandWebhookStore(supplier) supplier.oldStores.preference = NewSqlPreferenceStore(supplier) supplier.oldStores.license = NewSqlLicenseStore(supplier) supplier.oldStores.token = NewSqlTokenStore(supplier) @@ -142,6 +144,7 @@ func NewSqlSupplier() *SqlSupplier { supplier.oldStores.system.(*SqlSystemStore).CreateIndexesIfNotExists() supplier.oldStores.webhook.(*SqlWebhookStore).CreateIndexesIfNotExists() supplier.oldStores.command.(*SqlCommandStore).CreateIndexesIfNotExists() + supplier.oldStores.commandWebhook.(*SqlCommandWebhookStore).CreateIndexesIfNotExists() supplier.oldStores.preference.(*SqlPreferenceStore).CreateIndexesIfNotExists() supplier.oldStores.license.(*SqlLicenseStore).CreateIndexesIfNotExists() supplier.oldStores.token.(*SqlTokenStore).CreateIndexesIfNotExists() @@ -732,6 +735,10 @@ func (ss *SqlSupplier) Command() CommandStore { return ss.oldStores.command } +func (ss *SqlSupplier) CommandWebhook() CommandWebhookStore { + return ss.oldStores.commandWebhook +} + func (ss *SqlSupplier) Preference() PreferenceStore { return ss.oldStores.preference } diff --git a/store/store.go b/store/store.go index d883ea5a2..e86b5f116 100644 --- a/store/store.go +++ b/store/store.go @@ -41,6 +41,7 @@ type Store interface { System() SystemStore Webhook() WebhookStore Command() CommandStore + CommandWebhook() CommandWebhookStore Preference() PreferenceStore License() LicenseStore Token() TokenStore @@ -326,6 +327,13 @@ type CommandStore interface { AnalyticsCommandCount(teamId string) StoreChannel } +type CommandWebhookStore interface { + Save(webhook *model.CommandWebhook) StoreChannel + Get(id string) StoreChannel + TryUse(id string, limit int) StoreChannel + Cleanup() +} + type PreferenceStore interface { Save(preferences *model.Preferences) StoreChannel Get(userId string, category string, name string) StoreChannel -- cgit v1.2.3-1-g7c22