Skip to content

Commit 289689d

Browse files
[ADD] per-task mutex for thread safety in TaskScheduler
1 parent 9c509c7 commit 289689d

File tree

1 file changed

+142
-34
lines changed

1 file changed

+142
-34
lines changed

queue/task.go

Lines changed: 142 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ type Task struct {
6262
Enabled bool `json:"enabled"`
6363
CreatedAt time.Time `json:"created_at"`
6464
UpdatedAt time.Time `json:"updated_at"`
65+
mutex sync.RWMutex `json:"-"`
6566
}
6667

6768
// TaskScheduler manages background tasks
@@ -300,7 +301,11 @@ func (s *TaskScheduler) RegisterOrRescheduleCronTaskWithOptions(name, cronSpec s
300301
existingTask, exists := s.tasks[name]
301302
if exists {
302303
// Task exists, reschedule it
303-
if existingTask.IsRunning {
304+
existingTask.mutex.RLock()
305+
isRunning := existingTask.IsRunning
306+
existingTask.mutex.RUnlock()
307+
308+
if isRunning {
304309
return apperror.NewError(fmt.Sprintf("cannot reschedule running task '%s'", name))
305310
}
306311

@@ -309,10 +314,10 @@ func (s *TaskScheduler) RegisterOrRescheduleCronTaskWithOptions(name, cronSpec s
309314
return apperror.NewError(fmt.Sprintf("failed to calculate next run time: %v", err))
310315
}
311316

312-
// Update existing task
317+
existingTask.mutex.Lock()
313318
existingTask.Type = TaskTypeCron
314319
existingTask.CronSpec = cronSpec
315-
existingTask.Interval = 0 // Clear interval for cron tasks
320+
existingTask.Interval = 0
316321
existingTask.Function = fn
317322
existingTask.NextRun = nextRun
318323
existingTask.UpdatedAt = time.Now()
@@ -329,11 +334,13 @@ func (s *TaskScheduler) RegisterOrRescheduleCronTaskWithOptions(name, cronSpec s
329334
}
330335
existingTask.AllowConcurrent = options.Concurrent
331336
existingTask.Quiet = options.Quiet
337+
nextRunForLog := existingTask.NextRun
338+
existingTask.mutex.Unlock()
332339

333340
logger.Trace().
334341
Field("task_name", name).
335342
Field("cron_spec", cronSpec).
336-
Field("next_run", nextRun).
343+
Field("next_run", nextRunForLog).
337344
Msg("existing cron task rescheduled")
338345

339346
return nil
@@ -402,19 +409,22 @@ func (s *TaskScheduler) RegisterOrRescheduleIntervalTaskWithOptions(name string,
402409

403410
existingTask, exists := s.tasks[name]
404411
if exists {
405-
// Task exists, reschedule it
406-
if existingTask.IsRunning {
412+
existingTask.mutex.RLock()
413+
isRunning := existingTask.IsRunning
414+
existingTask.mutex.RUnlock()
415+
416+
if isRunning {
407417
return apperror.NewError(fmt.Sprintf("cannot reschedule running task '%s'", name))
408418
}
409419

410-
// Update existing task
420+
existingTask.mutex.Lock()
411421
existingTask.Type = TaskTypeInterval
412422
existingTask.CronSpec = "" // Clear cron spec for interval tasks
413423
existingTask.Interval = interval
414424
existingTask.Function = fn
415425
existingTask.NextRun = time.Now().Add(interval)
416426
existingTask.UpdatedAt = time.Now()
417-
// Update options if provided
427+
418428
if options.MaxRetries >= 0 { // Allow explicit 0 to disable retries
419429
existingTask.MaxRetries = options.MaxRetries
420430
}
@@ -426,11 +436,13 @@ func (s *TaskScheduler) RegisterOrRescheduleIntervalTaskWithOptions(name string,
426436
}
427437
existingTask.AllowConcurrent = options.Concurrent
428438
existingTask.Quiet = options.Quiet
439+
nextRunForLog := existingTask.NextRun
440+
existingTask.mutex.Unlock()
429441

430442
logger.Trace().
431443
Field("task_name", name).
432444
Field("interval", interval).
433-
Field("next_run", existingTask.NextRun).
445+
Field("next_run", nextRunForLog).
434446
Msg("existing interval task rescheduled")
435447

436448
return nil
@@ -536,8 +548,15 @@ func (s *TaskScheduler) checkAndRunTasks(ctx context.Context) {
536548
now := time.Now()
537549

538550
for _, task := range s.tasks {
551+
task.mutex.RLock()
552+
enabled := task.Enabled
553+
nextRun := task.NextRun
554+
isRunning := task.IsRunning
555+
allowConcurrent := task.AllowConcurrent
556+
task.mutex.RUnlock()
557+
539558
// Run task if it's enabled, scheduled to run, and either not running or concurrent execution is allowed
540-
if task.Enabled && now.After(task.NextRun) && (!task.IsRunning || task.AllowConcurrent) {
559+
if enabled && now.After(nextRun) && (!isRunning || allowConcurrent) {
541560
tasksToRun = append(tasksToRun, task)
542561
}
543562
}
@@ -553,7 +572,6 @@ func (s *TaskScheduler) checkAndRunTasks(ctx context.Context) {
553572
func (s *TaskScheduler) runTask(ctx context.Context, task *Task) {
554573
defer s.workerWg.Done()
555574

556-
s.tasksMutex.Lock()
557575
// For concurrent tasks, update next run time immediately so next instance can be scheduled
558576
// For non-concurrent tasks, set running state to prevent overlapping executions
559577
if task.AllowConcurrent {
@@ -565,10 +583,17 @@ func (s *TaskScheduler) runTask(ctx context.Context, task *Task) {
565583
Msg("failed to update next run time before execution")
566584
}
567585
} else {
586+
task.mutex.Lock()
568587
task.IsRunning = true
588+
task.UpdatedAt = time.Now()
589+
task.mutex.Unlock()
590+
}
591+
592+
if task.AllowConcurrent {
593+
task.mutex.Lock()
594+
task.UpdatedAt = time.Now()
595+
task.mutex.Unlock()
569596
}
570-
task.UpdatedAt = time.Now()
571-
s.tasksMutex.Unlock()
572597

573598
taskCtx, cancel := context.WithTimeout(ctx, task.Timeout)
574599
defer cancel()
@@ -590,7 +615,7 @@ func (s *TaskScheduler) runTask(ctx context.Context, task *Task) {
590615
err := task.Function(taskCtx)
591616

592617
if err == nil {
593-
s.tasksMutex.Lock()
618+
task.mutex.Lock()
594619
// Only mark as not running for non-concurrent tasks
595620
if !task.AllowConcurrent {
596621
task.IsRunning = false
@@ -599,23 +624,36 @@ func (s *TaskScheduler) runTask(ctx context.Context, task *Task) {
599624
task.RunCount++
600625
task.ConsecutiveFailures = 0 // Reset consecutive failures on success
601626
task.LastError = ""
602-
task.UpdatedAt = time.Now() // For non-concurrent tasks, update next run time after completion
627+
task.UpdatedAt = time.Now()
628+
629+
// For non-concurrent tasks, update next run time after completion
603630
// For concurrent tasks, this was already done at the start
631+
var nextRunTime time.Time
632+
var runCount int64
604633
if !task.AllowConcurrent {
634+
task.mutex.Unlock()
605635
err = s.updateNextRun(task)
606636
if err != nil {
607637
logger.Error().
608638
Err(err).
609639
Field("task_name", task.Name).
610640
Msg("failed to update next run time")
611641
}
642+
// Read the values for logging after update
643+
task.mutex.RLock()
644+
nextRunTime = task.NextRun
645+
runCount = task.RunCount
646+
task.mutex.RUnlock()
647+
} else {
648+
nextRunTime = task.NextRun
649+
runCount = task.RunCount
650+
task.mutex.Unlock()
612651
}
613-
s.tasksMutex.Unlock()
614652

615653
logger.Trace().
616654
Field("task_name", task.Name).
617-
Field("run_count", task.RunCount).
618-
Field("next_run", task.NextRun).
655+
Field("run_count", runCount).
656+
Field("next_run", nextRunTime).
619657
Msg("task executed successfully")
620658
return
621659
}
@@ -638,38 +676,57 @@ func (s *TaskScheduler) runTask(ctx context.Context, task *Task) {
638676
}
639677
}
640678

641-
s.tasksMutex.Lock()
679+
// Handle failure case after all retries exhausted
680+
task.mutex.Lock()
642681

643682
// For non-concurrent tasks, update next run time after completion
644683
// For concurrent tasks, this was already done at the start
645684
if !task.AllowConcurrent {
646685
task.IsRunning = false
686+
}
687+
task.ErrorCount++
688+
task.LastRun = time.Now()
689+
task.ConsecutiveFailures++ // Increment consecutive failures
690+
task.LastError = lastError.Error()
691+
task.UpdatedAt = time.Now()
692+
693+
// Capture values for logging before updating next run
694+
errorCount := task.ErrorCount
695+
consecutiveFailures := task.ConsecutiveFailures
696+
var nextRunTime time.Time
697+
698+
if !task.AllowConcurrent {
699+
task.mutex.Unlock()
647700
err := s.updateNextRun(task)
648701
if err != nil {
649702
logger.Error().
650703
Err(err).
651704
Field("task_name", task.Name).
652705
Msg("failed to update next run time after retries")
653706
}
707+
// Read next run time for logging
708+
task.mutex.RLock()
709+
nextRunTime = task.NextRun
710+
task.mutex.RUnlock()
711+
} else {
712+
nextRunTime = task.NextRun
713+
task.mutex.Unlock()
654714
}
655-
task.ErrorCount++
656-
if !task.Quiet || task.ConsecutiveFailures == 0 {
715+
716+
if !task.Quiet || consecutiveFailures == 1 {
657717
logger.Error().
658718
Err(lastError).
659719
Field("task_name", task.Name).
660-
Field("error_count", task.ErrorCount).
661-
Field("next_run", task.NextRun).
720+
Field("error_count", errorCount).
721+
Field("next_run", nextRunTime).
662722
Msg("task execution failed")
663723
}
664-
665-
task.LastRun = time.Now()
666-
task.ConsecutiveFailures++ // Increment consecutive failures
667-
task.LastError = lastError.Error()
668-
task.UpdatedAt = time.Now()
669-
s.tasksMutex.Unlock()
670724
}
671725

672726
func (s *TaskScheduler) updateNextRun(task *Task) error {
727+
task.mutex.Lock()
728+
defer task.mutex.Unlock()
729+
673730
switch task.Type {
674731
case TaskTypeCron:
675732
nextRun, err := s.calculateNextCronRun(task.CronSpec, time.Now())
@@ -686,15 +743,39 @@ func (s *TaskScheduler) updateNextRun(task *Task) error {
686743
// GetTask returns a task by name
687744
func (s *TaskScheduler) GetTask(name string) (*Task, error) {
688745
s.tasksMutex.RLock()
689-
defer s.tasksMutex.RUnlock()
690-
691746
task, exists := s.tasks[name]
747+
s.tasksMutex.RUnlock()
748+
692749
if !exists {
693750
return nil, apperror.NewError(fmt.Sprintf("task '%s' not found", name))
694751
}
695752

696-
taskCopy := *task
697-
return &taskCopy, nil
753+
task.mutex.RLock()
754+
defer task.mutex.RUnlock()
755+
return &Task{
756+
ID: task.ID,
757+
Name: task.Name,
758+
Type: task.Type,
759+
CronSpec: task.CronSpec,
760+
Interval: task.Interval,
761+
Function: task.Function,
762+
NextRun: task.NextRun,
763+
LastRun: task.LastRun,
764+
RunCount: task.RunCount,
765+
ErrorCount: task.ErrorCount,
766+
ConsecutiveFailures: task.ConsecutiveFailures,
767+
LastError: task.LastError,
768+
IsRunning: task.IsRunning,
769+
Quiet: task.Quiet,
770+
AllowConcurrent: task.AllowConcurrent,
771+
MaxRetries: task.MaxRetries,
772+
RetryDelay: task.RetryDelay,
773+
Timeout: task.Timeout,
774+
Enabled: task.Enabled,
775+
CreatedAt: task.CreatedAt,
776+
UpdatedAt: task.UpdatedAt,
777+
// Note: mutex is intentionally not copied
778+
}, nil
698779
}
699780

700781
// GetTasks returns all registered tasks
@@ -704,7 +785,34 @@ func (s *TaskScheduler) GetTasks() map[string]*Task {
704785

705786
tasks := make(map[string]*Task, len(s.tasks))
706787
for name, task := range s.tasks {
707-
taskCopy := *task
788+
// Create a safe copy of each task with proper locking
789+
task.mutex.RLock()
790+
taskCopy := Task{
791+
ID: task.ID,
792+
Name: task.Name,
793+
Type: task.Type,
794+
CronSpec: task.CronSpec,
795+
Interval: task.Interval,
796+
Function: task.Function,
797+
NextRun: task.NextRun,
798+
LastRun: task.LastRun,
799+
RunCount: task.RunCount,
800+
ErrorCount: task.ErrorCount,
801+
ConsecutiveFailures: task.ConsecutiveFailures,
802+
LastError: task.LastError,
803+
IsRunning: task.IsRunning,
804+
Quiet: task.Quiet,
805+
AllowConcurrent: task.AllowConcurrent,
806+
MaxRetries: task.MaxRetries,
807+
RetryDelay: task.RetryDelay,
808+
Timeout: task.Timeout,
809+
Enabled: task.Enabled,
810+
CreatedAt: task.CreatedAt,
811+
UpdatedAt: task.UpdatedAt,
812+
// Note: mutex is intentionally not copied
813+
}
814+
task.mutex.RUnlock()
815+
708816
tasks[name] = &taskCopy
709817
}
710818

0 commit comments

Comments
 (0)