From 55261a2b155c4e3d699fc6bd4255f05fb8258157 Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 23 Feb 2018 12:56:32 -0600 Subject: fix scheduled task race (#8355) --- app/email_batching.go | 18 +++-- model/scheduled_task.go | 97 +++++++++---------------- model/scheduled_task_test.go | 163 +++++++------------------------------------ 3 files changed, 70 insertions(+), 208 deletions(-) diff --git a/app/email_batching.go b/app/email_batching.go index 2a33d7d3e..07adda674 100644 --- a/app/email_batching.go +++ b/app/email_batching.go @@ -7,6 +7,7 @@ import ( "fmt" "html/template" "strconv" + "sync" "time" "github.com/mattermost/mattermost-server/model" @@ -57,6 +58,8 @@ type EmailBatchingJob struct { app *App newNotifications chan *batchedNotification pendingNotifications map[string][]*batchedNotification + task *model.ScheduledTask + taskMutex sync.Mutex } func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob { @@ -68,12 +71,17 @@ func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob { } func (job *EmailBatchingJob) Start() { - if task := model.GetTaskByName(EMAIL_BATCHING_TASK_NAME); task != nil { - task.Cancel() - } - l4g.Debug(utils.T("api.email_batching.start.starting"), *job.app.Config().EmailSettings.EmailBatchingInterval) - model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second) + newTask := model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second) + + job.taskMutex.Lock() + oldTask := job.task + job.task = newTask + job.taskMutex.Unlock() + + if oldTask != nil { + oldTask.Cancel() + } } func (job *EmailBatchingJob) Add(user *model.User, post *model.Post, team *model.Team) bool { diff --git a/model/scheduled_task.go b/model/scheduled_task.go index 453828bd2..f3529dedb 100644 --- a/model/scheduled_task.go +++ b/model/scheduled_task.go @@ -5,7 +5,6 @@ package model import ( "fmt" - "sync" "time" ) @@ -15,89 +14,57 @@ type ScheduledTask struct { Name string `json:"name"` Interval time.Duration `json:"interval"` Recurring bool `json:"recurring"` - function TaskFunc - timer *time.Timer -} - -var taskMutex = sync.Mutex{} -var tasks = make(map[string]*ScheduledTask) - -func addTask(task *ScheduledTask) { - taskMutex.Lock() - defer taskMutex.Unlock() - tasks[task.Name] = task -} - -func removeTaskByName(name string) { - taskMutex.Lock() - defer taskMutex.Unlock() - delete(tasks, name) -} - -func GetTaskByName(name string) *ScheduledTask { - taskMutex.Lock() - defer taskMutex.Unlock() - if task, ok := tasks[name]; ok { - return task - } - return nil -} - -func GetAllTasks() *map[string]*ScheduledTask { - taskMutex.Lock() - defer taskMutex.Unlock() - return &tasks + function func() + cancel chan struct{} + cancelled chan struct{} } func CreateTask(name string, function TaskFunc, timeToExecution time.Duration) *ScheduledTask { - task := &ScheduledTask{ - Name: name, - Interval: timeToExecution, - Recurring: false, - function: function, - } - - taskRunner := func() { - go task.function() - removeTaskByName(task.Name) - } - - task.timer = time.AfterFunc(timeToExecution, taskRunner) - - addTask(task) - - return task + return createTask(name, function, timeToExecution, false) } func CreateRecurringTask(name string, function TaskFunc, interval time.Duration) *ScheduledTask { + return createTask(name, function, interval, true) +} + +func createTask(name string, function TaskFunc, interval time.Duration, recurring bool) *ScheduledTask { task := &ScheduledTask{ Name: name, Interval: interval, - Recurring: true, + Recurring: recurring, function: function, + cancel: make(chan struct{}), + cancelled: make(chan struct{}), } - taskRecurer := func() { - go task.function() - task.timer.Reset(task.Interval) - } + go func() { + defer close(task.cancelled) - task.timer = time.AfterFunc(interval, taskRecurer) + ticker := time.NewTicker(interval) + defer func() { + ticker.Stop() + }() - addTask(task) + for { + select { + case <-ticker.C: + function() + case <-task.cancel: + return + } + + if !task.Recurring { + break + } + } + }() return task } func (task *ScheduledTask) Cancel() { - task.timer.Stop() - removeTaskByName(task.Name) -} - -// Executes the task immediatly. A recurring task will be run regularally after interval. -func (task *ScheduledTask) Execute() { - task.function() - task.timer.Reset(task.Interval) + close(task.cancel) + <-task.cancelled } func (task *ScheduledTask) String() string { diff --git a/model/scheduled_task_test.go b/model/scheduled_task_test.go index 5af43b1ef..9537a662a 100644 --- a/model/scheduled_task_test.go +++ b/model/scheduled_task_test.go @@ -4,185 +4,72 @@ package model import ( + "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestCreateTask(t *testing.T) { TASK_NAME := "Test Task" - TASK_TIME := time.Second * 3 + TASK_TIME := time.Second * 2 - testValue := 0 + executionCount := new(int32) testFunc := func() { - testValue = 1 + atomic.AddInt32(executionCount, 1) } task := CreateTask(TASK_NAME, testFunc, TASK_TIME) - if testValue != 0 { - t.Fatal("Unexpected execuition of task") - } + assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) time.Sleep(TASK_TIME + time.Second) - if testValue != 1 { - t.Fatal("Task did not execute") - } - - if task.Name != TASK_NAME { - t.Fatal("Bad name") - } - - if task.Interval != TASK_TIME { - t.Fatal("Bad interval") - } - - if task.Recurring { - t.Fatal("should not reccur") - } + assert.EqualValues(t, 1, atomic.LoadInt32(executionCount)) + assert.Equal(t, TASK_NAME, task.Name) + assert.Equal(t, TASK_TIME, task.Interval) + assert.False(t, task.Recurring) } func TestCreateRecurringTask(t *testing.T) { TASK_NAME := "Test Recurring Task" - TASK_TIME := time.Second * 3 + TASK_TIME := time.Second * 2 - testValue := 0 + executionCount := new(int32) testFunc := func() { - testValue += 1 + atomic.AddInt32(executionCount, 1) } task := CreateRecurringTask(TASK_NAME, testFunc, TASK_TIME) - if testValue != 0 { - t.Fatal("Unexpected execuition of task") - } + assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) time.Sleep(TASK_TIME + time.Second) - if testValue != 1 { - t.Fatal("Task did not execute") - } + assert.EqualValues(t, 1, atomic.LoadInt32(executionCount)) time.Sleep(TASK_TIME) - if testValue != 2 { - t.Fatal("Task did not re-execute") - } - - if task.Name != TASK_NAME { - t.Fatal("Bad name") - } - - if task.Interval != TASK_TIME { - t.Fatal("Bad interval") - } - - if !task.Recurring { - t.Fatal("should reccur") - } + assert.EqualValues(t, 2, atomic.LoadInt32(executionCount)) + assert.Equal(t, TASK_NAME, task.Name) + assert.Equal(t, TASK_TIME, task.Interval) + assert.True(t, task.Recurring) task.Cancel() } func TestCancelTask(t *testing.T) { TASK_NAME := "Test Task" - TASK_TIME := time.Second * 3 + TASK_TIME := time.Second - testValue := 0 + executionCount := new(int32) testFunc := func() { - testValue = 1 + atomic.AddInt32(executionCount, 1) } task := CreateTask(TASK_NAME, testFunc, TASK_TIME) - if testValue != 0 { - t.Fatal("Unexpected execuition of task") - } + assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) task.Cancel() time.Sleep(TASK_TIME + time.Second) - - if testValue != 0 { - t.Fatal("Unexpected execuition of task") - } -} - -func TestGetAllTasks(t *testing.T) { - doNothing := func() {} - - CreateTask("Task1", doNothing, time.Hour) - CreateTask("Task2", doNothing, time.Second) - CreateRecurringTask("Task3", doNothing, time.Second) - task4 := CreateRecurringTask("Task4", doNothing, time.Second) - - task4.Cancel() - - time.Sleep(time.Second * 3) - - tasks := *GetAllTasks() - if len(tasks) != 2 { - t.Fatal("Wrong number of tasks got: ", len(tasks)) - } - for _, task := range tasks { - if task.Name != "Task1" && task.Name != "Task3" { - t.Fatal("Wrong tasks") - } - } -} - -func TestExecuteTask(t *testing.T) { - TASK_NAME := "Test Task" - TASK_TIME := time.Second * 5 - - testValue := 0 - testFunc := func() { - testValue += 1 - } - - task := CreateTask(TASK_NAME, testFunc, TASK_TIME) - if testValue != 0 { - t.Fatal("Unexpected execuition of task") - } - - task.Execute() - - if testValue != 1 { - t.Fatal("Task did not execute") - } - - time.Sleep(TASK_TIME + time.Second) - - if testValue != 2 { - t.Fatal("Task re-executed") - } -} - -func TestExecuteTaskRecurring(t *testing.T) { - TASK_NAME := "Test Recurring Task" - TASK_TIME := time.Second * 5 - - testValue := 0 - testFunc := func() { - testValue += 1 - } - - task := CreateRecurringTask(TASK_NAME, testFunc, TASK_TIME) - if testValue != 0 { - t.Fatal("Unexpected execuition of task") - } - - time.Sleep(time.Second * 3) - - task.Execute() - if testValue != 1 { - t.Fatal("Task did not execute") - } - - time.Sleep(time.Second * 3) - if testValue != 1 { - t.Fatal("Task should not have executed before 5 seconds") - } - - time.Sleep(time.Second * 3) - - if testValue != 2 { - t.Fatal("Task did not re-execute after forced execution") - } + assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) } -- cgit v1.2.3-1-g7c22