3 changes: 3 additions & 0 deletions pkg/engine/internal/scheduler/wire/wire_local.go
@@ -15,6 +15,9 @@ var (
// LocalWorker is the address of the local worker when using the
// [Local] listener.
LocalWorker net.Addr = localAddr("worker")

// LocalWorker2 is a second address of the local worker when using the
// [Local] listener.
LocalWorker2 net.Addr = localAddr("worker2")
)

type localAddr string
25 changes: 18 additions & 7 deletions pkg/engine/internal/worker/thread.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/go-kit/log"
@@ -42,11 +43,21 @@ type thread struct {
Logger log.Logger

Ready chan<- readyRequest

stopped chan struct{}
stopOnce sync.Once
}

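// Stop signals the thread to stop requesting new tasks; a task that is
// already running is allowed to finish. Stop is safe to call more than once.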
func (t *thread) Stop() {
t.stopOnce.Do(func() {
close(t.stopped)
})
}

// Run starts the thread. Run will request and run tasks in a loop until the
// context is canceled.
func (t *thread) Run(ctx context.Context) error {
// thread is stopped. Run does not return when a job fails; it logs the error
// and continues accepting other jobs.
func (t *thread) Run() {
NextTask:
for {
level.Debug(t.Logger).Log("msg", "requesting task")
@@ -59,20 +70,20 @@
// ensures that the context of tasks written to respCh is bound to the
// lifetime of the thread, but can also be canceled by the scheduler.
req := readyRequest{
Context: ctx,
Context: context.Background(),
Response: respCh,
}

// Send our request.
select {
case <-ctx.Done():
return nil
case <-t.stopped:
return
case t.Ready <- req:
}

// Wait for a task assignment.
select {
case <-ctx.Done():
case <-t.stopped:
// TODO(rfratto): This will silently drop tasks written to respCh.
// But since Run only exits when the worker is exiting, this should
// be handled gracefully by the scheduler (it will detect the
@@ -81,7 +92,7 @@ NextTask:
// If, in the future, we dynamically change the number of threads,
// we'll want a mechanism to gracefully handle this so the writer to
// respCh knows that the task was dropped.
return nil
return

case resp := <-respCh:
if resp.Error != nil {
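For context, the stop mechanism added to thread above is the standard close-once channel pattern: a struct{} channel that is closed exactly once through sync.Once, and that every select in the run loop can observe. A minimal self-contained sketch of the pattern (the stopper type and its names are illustrative, not part of this change):

package main

import (
	"fmt"
	"sync"
	"time"
)

type stopper struct {
	stopped  chan struct{}
	stopOnce sync.Once
}

func (s *stopper) Stop() {
	// Closing the channel broadcasts the stop signal to every receiver;
	// sync.Once makes repeated Stop calls safe.
	s.stopOnce.Do(func() { close(s.stopped) })
}

func (s *stopper) Run() {
	for {
		select {
		case <-s.stopped:
			return // stop requested: do not pick up new work
		case <-time.After(100 * time.Millisecond):
			fmt.Println("processed one unit of work")
		}
	}
}

func main() {
	s := &stopper{stopped: make(chan struct{})}
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		s.Run()
	}()
	time.Sleep(250 * time.Millisecond)
	s.Stop()
	s.Stop() // a second Stop is a no-op
	wg.Wait() // Run returns once it observes the closed channel
}

Closing a channel, rather than sending on it, is what lets a single Stop call release any number of goroutines blocked in select.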
76 changes: 58 additions & 18 deletions pkg/engine/internal/worker/worker.go
@@ -162,57 +162,97 @@ func (w *Worker) Service() services.Service {

// run starts the worker, running until the provided context is canceled.
func (w *Worker) run(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
threadsGroup := &sync.WaitGroup{}
threads := make([]*thread, w.numThreads)

// Spin up worker threads.
for i := range w.numThreads {
t := &thread{
threads[i] = &thread{
BatchSize: w.config.BatchSize,
Logger: log.With(w.logger, "thread", i),
Bucket: w.config.Bucket,

Ready: w.readyCh,
Ready: w.readyCh,
stopped: make(chan struct{}),
}

g.Go(func() error { return t.Run(ctx) })
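// sync.WaitGroup.Go (added in Go 1.25) wraps the Add/Done bookkeeping
// around the goroutine it starts.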
threadsGroup.Go(func() { threads[i].Run() })
}

g.Go(func() error { return w.runAcceptLoop(ctx) })
// Spin up the listener for peer connections
peerConnectionsCtx, peerConnectionsCancel := context.WithCancel(context.Background())
defer peerConnectionsCancel()
listenerCtx, listenerCancel := context.WithCancel(context.Background())
defer listenerCancel()

go func() {
w.runAcceptLoop(listenerCtx, peerConnectionsCtx)
}()

// Spin up the scheduler loop
schedulerCtx, schedulerCancel := context.WithCancel(context.Background())
defer schedulerCancel()

schedulerGroup := errgroup.Group{}
if w.config.SchedulerLookupAddress != "" {
disc, err := newSchedulerLookup(w.logger, w.config.SchedulerLookupAddress, w.config.SchedulerLookupInterval)
if err != nil {
return fmt.Errorf("creating scheduler lookup: %w", err)
}
g.Go(func() error {
return disc.Run(ctx, func(ctx context.Context, addr net.Addr) {
schedulerGroup.Go(func() error {
return disc.Run(schedulerCtx, func(ctx context.Context, addr net.Addr) {
_ = w.schedulerLoop(ctx, addr)
})
})
}

if w.config.SchedulerAddress != nil {
level.Info(w.logger).Log("msg", "directly connecting to scheduler", "scheduler_addr", w.config.SchedulerAddress)
g.Go(func() error { return w.schedulerLoop(ctx, w.config.SchedulerAddress) })
schedulerGroup.Go(func() error { return w.schedulerLoop(schedulerCtx, w.config.SchedulerAddress) })
}

// Wait for shutdown
<-ctx.Done()
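// Shut down in stages: stop accepting new peer connections, drain the worker
// threads, stop the scheduler loop, and only then close the remaining peer
// connections so that finishing jobs can still exchange task results.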

// Stop accepting new connections from peers.
listenerCancel()

// Signal all worker threads to stop. This prevents them from requesting new
// tasks but lets them finish the jobs they are already processing.
for _, t := range threads {
t.Stop()
}
// Wait for all worker threads to finish their current jobs.
threadsGroup.Wait()

return g.Wait()
// Stop scheduler loop
schedulerCancel()

// Wait for scheduler loop to finish
err := schedulerGroup.Wait()
if err != nil {
return err
}

// Close all peer connections
peerConnectionsCancel()

return nil
}

// runAcceptLoop handles incoming connections from peers. Incoming connections
// are exclusively used to receive task results from other workers, or between
// threads within this worker.
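// listenerCtx stops the accept loop itself, while peerConnectionsCtx bounds
// the handlers spawned for connections that were already accepted.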
func (w *Worker) runAcceptLoop(ctx context.Context) error {
func (w *Worker) runAcceptLoop(listenerCtx, peerConnectionsCtx context.Context) {
for {
conn, err := w.listener.Accept(ctx)
if err != nil && ctx.Err() != nil {
return nil
conn, err := w.listener.Accept(listenerCtx)
if err != nil && listenerCtx.Err() != nil {
return
} else if err != nil {
level.Warn(w.logger).Log("msg", "failed to accept connection", "err", err)
continue
}

go w.handleConn(ctx, conn)
go w.handleConn(peerConnectionsCtx, conn)
}
}

@@ -365,13 +405,13 @@ func (w *Worker) handleSchedulerConn(ctx context.Context, logger log.Logger, con
return handleAssignment(peer, msg)

case wire.TaskCancelMessage:
return w.handleCancelMessage(ctx, msg)
return w.handleCancelMessage(msg)

case wire.StreamBindMessage:
return w.handleBindMessage(ctx, msg)

case wire.StreamStatusMessage:
return w.handleStreamStatusMessage(ctx, msg)
return w.handleStreamStatusMessage(msg)

default:
level.Warn(logger).Log("msg", "unsupported message type", "type", reflect.TypeOf(msg).String())
@@ -526,7 +566,7 @@ func (w *Worker) newJob(ctx context.Context, scheduler *wire.Peer, logger log.Lo
return job, nil
}

func (w *Worker) handleCancelMessage(_ context.Context, msg wire.TaskCancelMessage) error {
func (w *Worker) handleCancelMessage(msg wire.TaskCancelMessage) error {
w.resourcesMut.RLock()
job, found := w.jobs[msg.ID]
w.resourcesMut.RUnlock()
@@ -550,7 +590,7 @@ func (w *Worker) handleBindMessage(ctx context.Context, msg wire.StreamBindMessa
return sink.Bind(ctx, msg.Receiver)
}

func (w *Worker) handleStreamStatusMessage(_ context.Context, msg wire.StreamStatusMessage) error {
func (w *Worker) handleStreamStatusMessage(msg wire.StreamStatusMessage) error {
w.resourcesMut.RLock()
source, found := w.sources[msg.StreamID]
w.resourcesMut.RUnlock()