diff --git a/pkg/engine/internal/scheduler/wire/wire_local.go b/pkg/engine/internal/scheduler/wire/wire_local.go index 966ce2c662226..c416fda116926 100644 --- a/pkg/engine/internal/scheduler/wire/wire_local.go +++ b/pkg/engine/internal/scheduler/wire/wire_local.go @@ -15,6 +15,9 @@ var ( // LocalWorker is the address of the local worker when using the // [Local] listener. LocalWorker net.Addr = localAddr("worker") + + // LocalWorker2 is a second address of the local worker when using the [Local] listener. + LocalWorker2 net.Addr = localAddr("worker2") ) type localAddr string diff --git a/pkg/engine/internal/worker/thread.go b/pkg/engine/internal/worker/thread.go index dcba20f7dfa85..c7a3bf30ee510 100644 --- a/pkg/engine/internal/worker/thread.go +++ b/pkg/engine/internal/worker/thread.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/go-kit/log" @@ -42,11 +43,21 @@ type thread struct { Logger log.Logger Ready chan<- readyRequest + + stopped chan struct{} + stopOnce sync.Once +} + +func (t *thread) Stop() { + t.stopOnce.Do(func() { + close(t.stopped) + }) } // Run starts the thread. Run will request and run tasks in a loop until the -// context is canceled. -func (t *thread) Run(ctx context.Context) error { +// thread is stopped. Run does not stop when a job fails; it logs the error and continues +// accepting other jobs. +func (t *thread) Run() { NextTask: for { level.Debug(t.Logger).Log("msg", "requesting task") @@ -59,20 +70,20 @@ NextTask: // ensures that the context of tasks written to respCh are bound to the // lifetime of the thread, but can also be canceled by the scheduler. req := readyRequest{ - Context: ctx, + Context: context.Background(), Response: respCh, } // Send our request. select { - case <-ctx.Done(): - return nil + case <-t.stopped: + return case t.Ready <- req: } // Wait for a task assignment. select { - case <-ctx.Done(): + case <-t.stopped: // TODO(rfratto): This will silently drop tasks written to respCh. // But since Run only exits when the worker is exiting, this should // be handled gracefully by the scheduler (it will detect the @@ -81,7 +92,7 @@ NextTask: // If, in the future, we dynamically change the number of threads, // we'll want a mechanism to gracefully handle this so the writer to // respCh knows that the task was dropped. - return nil + return case resp := <-respCh: if resp.Error != nil { diff --git a/pkg/engine/internal/worker/worker.go b/pkg/engine/internal/worker/worker.go index dccec403f743d..6ea7a6a5e8d66 100644 --- a/pkg/engine/internal/worker/worker.go +++ b/pkg/engine/internal/worker/worker.go @@ -162,30 +162,45 @@ func (w *Worker) run(ctx context.Context) error { // run starts the worker, running until the provided context is canceled. func (w *Worker) run(ctx context.Context) error { - g, ctx := errgroup.WithContext(ctx) + threadsGroup := &sync.WaitGroup{} + threads := make([]*thread, w.numThreads) // Spin up worker threads. 
for i := range w.numThreads { - t := &thread{ + threads[i] = &thread{ BatchSize: w.config.BatchSize, Logger: log.With(w.logger, "thread", i), Bucket: w.config.Bucket, - Ready: w.readyCh, + Ready: w.readyCh, + stopped: make(chan struct{}), } - g.Go(func() error { return t.Run(ctx) }) + threadsGroup.Go(func() { threads[i].Run() }) } - g.Go(func() error { return w.runAcceptLoop(ctx) }) + // Spin up the listener for peer connections + peerConnectionsCtx, peerConnectionsCancel := context.WithCancel(context.Background()) + defer peerConnectionsCancel() + listenerCtx, listenerCancel := context.WithCancel(context.Background()) + defer listenerCancel() + + go func() { + w.runAcceptLoop(listenerCtx, peerConnectionsCtx) + }() + + // Spin up the scheduler loop + schedulerCtx, schedulerCancel := context.WithCancel(context.Background()) + defer schedulerCancel() + schedulerGroup := errgroup.Group{} if w.config.SchedulerLookupAddress != "" { disc, err := newSchedulerLookup(w.logger, w.config.SchedulerLookupAddress, w.config.SchedulerLookupInterval) if err != nil { return fmt.Errorf("creating scheduler lookup: %w", err) } - g.Go(func() error { - return disc.Run(ctx, func(ctx context.Context, addr net.Addr) { + schedulerGroup.Go(func() error { + return disc.Run(schedulerCtx, func(ctx context.Context, addr net.Addr) { _ = w.schedulerLoop(ctx, addr) }) }) @@ -193,26 +208,51 @@ func (w *Worker) run(ctx context.Context) error { if w.config.SchedulerAddress != nil { level.Info(w.logger).Log("msg", "directly connecting to scheduler", "scheduler_addr", w.config.SchedulerAddress) - g.Go(func() error { return w.schedulerLoop(ctx, w.config.SchedulerAddress) }) + schedulerGroup.Go(func() error { return w.schedulerLoop(schedulerCtx, w.config.SchedulerAddress) }) + } + + // Wait for shutdown + <-ctx.Done() + + // Stop accepting new connections from peers. + listenerCancel() + + // Signal all worker threads to stop. They will stop requesting new tasks but keep processing their current jobs. + for _, t := range threads { + t.Stop() + } + // Wait for all worker threads to finish their current jobs. + threadsGroup.Wait() - return g.Wait() + // Stop scheduler loop + schedulerCancel() + + // Wait for scheduler loop to finish + err := schedulerGroup.Wait() + if err != nil { + return err + } + + // Close all peer connections + peerConnectionsCancel() + + return nil } // runAcceptLoop handles incoming connections from peers. Incoming connections // are exclusively used to receive task results from other workers, or between // threads within this worker. 
-func (w *Worker) runAcceptLoop(ctx context.Context) error { +func (w *Worker) runAcceptLoop(listenerCtx, peerConnectionsCtx context.Context) { for { - conn, err := w.listener.Accept(ctx) - if err != nil && ctx.Err() != nil { - return nil + conn, err := w.listener.Accept(listenerCtx) + if err != nil && listenerCtx.Err() != nil { + return } else if err != nil { level.Warn(w.logger).Log("msg", "failed to accept connection", "err", err) continue } - go w.handleConn(ctx, conn) + go w.handleConn(peerConnectionsCtx, conn) } } @@ -365,13 +405,13 @@ func (w *Worker) handleSchedulerConn(ctx context.Context, logger log.Logger, con return handleAssignment(peer, msg) case wire.TaskCancelMessage: - return w.handleCancelMessage(ctx, msg) + return w.handleCancelMessage(msg) case wire.StreamBindMessage: return w.handleBindMessage(ctx, msg) case wire.StreamStatusMessage: - return w.handleStreamStatusMessage(ctx, msg) + return w.handleStreamStatusMessage(msg) default: level.Warn(logger).Log("msg", "unsupported message type", "type", reflect.TypeOf(msg).String()) @@ -526,7 +566,7 @@ func (w *Worker) newJob(ctx context.Context, scheduler *wire.Peer, logger log.Lo return job, nil } -func (w *Worker) handleCancelMessage(_ context.Context, msg wire.TaskCancelMessage) error { +func (w *Worker) handleCancelMessage(msg wire.TaskCancelMessage) error { w.resourcesMut.RLock() job, found := w.jobs[msg.ID] w.resourcesMut.RUnlock() @@ -550,7 +590,7 @@ func (w *Worker) handleBindMessage(ctx context.Context, msg wire.StreamBindMessa return sink.Bind(ctx, msg.Receiver) } -func (w *Worker) handleStreamStatusMessage(_ context.Context, msg wire.StreamStatusMessage) error { +func (w *Worker) handleStreamStatusMessage(msg wire.StreamStatusMessage) error { w.resourcesMut.RLock() source, found := w.sources[msg.StreamID] w.resourcesMut.RUnlock() diff --git a/pkg/engine/internal/worker/worker_test.go b/pkg/engine/internal/worker/worker_test.go index 17b724523e6e0..b10f60b6ff822 100644 --- a/pkg/engine/internal/worker/worker_test.go +++ b/pkg/engine/internal/worker/worker_test.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "testing" + "testing/synctest" "time" "github.com/apache/arrow-go/v18/arrow" @@ -14,6 +15,7 @@ import ( "github.com/go-kit/log" "github.com/grafana/dskit/services" "github.com/grafana/dskit/user" + "github.com/oklog/ulid/v2" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "github.com/thanos-io/objstore" @@ -24,6 +26,8 @@ import ( "github.com/grafana/loki/v3/pkg/engine/internal/planner/physical" "github.com/grafana/loki/v3/pkg/engine/internal/scheduler" "github.com/grafana/loki/v3/pkg/engine/internal/scheduler/wire" + "github.com/grafana/loki/v3/pkg/engine/internal/semconv" + "github.com/grafana/loki/v3/pkg/engine/internal/util/dag" "github.com/grafana/loki/v3/pkg/engine/internal/util/objtest" "github.com/grafana/loki/v3/pkg/engine/internal/worker" "github.com/grafana/loki/v3/pkg/engine/internal/workflow" @@ -96,20 +100,212 @@ func Test(t *testing.T) { require.Equal(t, expected, actual) } +// TestWorkerGracefulShutdown tests that the worker gracefully shuts down by +// finishing execution of a task even after its context is canceled. The test +// creates a TopK node job that accepts a Stream, waits for the worker to start +// processing, cancels the worker's context, then sends data to the stream and +// verifies the job completes. 
+func TestWorkerGracefulShutdown(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + logger := log.NewNopLogger() + if testing.Verbose() { + logger = log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr)) + } + + net := newTestNetwork() + sched := newTestScheduler(t, logger, net) + + // Create a cancelable context for the worker's run() method + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _ = newTestWorkerWithContext(t, logger, objtest.Location{}, net, runCtx) + + ctx := user.InjectOrgID(context.Background(), objtest.Tenant) + + // Create a physical plan with a TopK node + topkNode := &physical.TopK{ + NodeID: ulid.Make(), + SortBy: &physical.ColumnExpr{Ref: semconv.ColumnIdentTimestamp.ColumnRef()}, + Ascending: false, + K: 10, + } + + planGraph := &dag.Graph[physical.Node]{} + planGraph.Add(topkNode) + plan := physical.FromGraph(*planGraph) + + // Create a stream that will feed data to the TopK node + inputStream := &workflow.Stream{ + ULID: ulid.Make(), + TenantID: objtest.Tenant, + } + + // Create a workflow task manually with the TopK node and stream source + task := &workflow.Task{ + ULID: ulid.Make(), + TenantID: objtest.Tenant, + Fragment: plan, + Sources: map[physical.Node][]*workflow.Stream{ + topkNode: {inputStream}, + }, + Sinks: make(map[physical.Node][]*workflow.Stream), + } + + // Create a results stream for the workflow output + resultsStream := &workflow.Stream{ + ULID: ulid.Make(), + TenantID: objtest.Tenant, + } + task.Sinks[topkNode] = []*workflow.Stream{resultsStream} + + // Create a workflow manifest with the task + manifest := &workflow.Manifest{ + Streams: []*workflow.Stream{inputStream, resultsStream}, + Tasks: []*workflow.Task{task}, + TaskEventHandler: func(_ context.Context, _ *workflow.Task, _ workflow.TaskStatus) { + // Empty + }, + StreamEventHandler: func(_ context.Context, _ *workflow.Stream, _ workflow.StreamState) { + // Empty + }, + } + require.NoError(t, sched.RegisterManifest(ctx, manifest)) + + // Create a simple record writer to receive results + resultsWriter := &testRecordWriter{records: make(chan arrow.RecordBatch, 10)} + require.NoError(t, sched.Listen(ctx, resultsWriter, resultsStream)) + + // Start the task - this will assign it to the worker + require.NoError(t, sched.Start(ctx, task)) + + // Wait for the task to be assigned to the worker and for it to start waiting for input + synctest.Wait() + + // Now send data to the input stream + // The worker should process this data even after its context is canceled + schema := arrow.NewSchema([]arrow.Field{ + semconv.FieldFromIdent(semconv.ColumnIdentTimestamp, false), + semconv.FieldFromIdent(semconv.ColumnIdentMessage, false), + }, nil) + + timestampBuilder := array.NewTimestampBuilder(memory.DefaultAllocator, arrow.FixedWidthTypes.Timestamp_ns.(*arrow.TimestampType)) + messageBuilder := array.NewStringBuilder(memory.DefaultAllocator) + + // Create test data + for i := 0; i < 5; i++ { + timestampBuilder.Append(arrow.Timestamp(time.Date(2025, time.January, 1, 0, 0, i, 0, time.UTC).UnixNano())) + messageBuilder.Append(fmt.Sprintf("Message %d", i)) + } + + timestampArr := timestampBuilder.NewArray() + messageArr := messageBuilder.NewArray() + record := array.NewRecordBatch(schema, []arrow.Array{timestampArr, messageArr}, 5) + + // Send data for the stream to the worker + // We need to connect to worker 1 on behalf of worker 2 and send a StreamDataMessage + workerConn, err := net.worker1Listener.DialFrom(ctx, wire.LocalWorker2) + require.NoError(t, err) + defer 
workerConn.Close() + + workerPeer := &wire.Peer{ + Logger: logger, + Conn: workerConn, + Handler: func(_ context.Context, _ *wire.Peer, _ wire.Message) error { + return nil + }, + } + go func() { _ = workerPeer.Serve(ctx) }() + + // Cancel the worker's context to trigger graceful shutdown + // This simulates the worker receiving a shutdown signal in Worker.run(). + // The earliest we can cancel is after worker1Listener is dialed to, otherwise + // the connection will not be accepted. + cancel() + + // Connect to the scheduler on behalf of worker 2 + schedulerConn, err := net.schedulerListener.DialFrom(ctx, wire.LocalWorker2) + require.NoError(t, err) + defer schedulerConn.Close() + + schedulerPeer := &wire.Peer{ + Logger: logger, + Conn: schedulerConn, + Handler: func(_ context.Context, _ *wire.Peer, _ wire.Message) error { + return nil + }, + } + go func() { _ = schedulerPeer.Serve(ctx) }() + + // Say hello to the scheduler on behalf of worker 2 + err = schedulerPeer.SendMessage(ctx, wire.WorkerHelloMessage{ + Threads: 1, + }) + require.NoError(t, err) + + synctest.Wait() + + // Send the data message + err = workerPeer.SendMessage(ctx, wire.StreamDataMessage{ + StreamID: inputStream.ULID, + Data: record, + }) + require.NoError(t, err) + + // Close the stream to signal EOF by sending a StreamStatusMessage + err = schedulerPeer.SendMessage(ctx, wire.StreamStatusMessage{ + StreamID: inputStream.ULID, + State: workflow.StreamStateClosed, + }) + require.NoError(t, err) + + // Wait for results - the worker should have processed the data + // even though its context was canceled + resultCtx, resultCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer resultCancel() + + select { + case resultRecord := <-resultsWriter.records: + require.Greater(t, resultRecord.NumRows(), int64(0), "should have at least some rows") + case <-resultCtx.Done(): + t.Fatal("did not receive result within timeout") + } + + synctest.Wait() + }) +} + +// testRecordWriter implements workflow.RecordWriter for testing +type testRecordWriter struct { + records chan arrow.RecordBatch +} + +func (w *testRecordWriter) Write(ctx context.Context, record arrow.RecordBatch) error { + select { + case <-ctx.Done(): + return ctx.Err() + case w.records <- record: + return nil + } +} + type testNetwork struct { schedulerListener *wire.Local - workerListener *wire.Local + worker1Listener *wire.Local + worker2Listener *wire.Local dialer wire.Dialer } func newTestNetwork() *testNetwork { schedulerListener := &wire.Local{Address: wire.LocalScheduler} - workerListener := &wire.Local{Address: wire.LocalWorker} + worker1Listener := &wire.Local{Address: wire.LocalWorker} + worker2Listener := &wire.Local{Address: wire.LocalWorker2} return &testNetwork{ schedulerListener: schedulerListener, - workerListener: workerListener, - dialer: wire.NewLocalDialer(schedulerListener, workerListener), + worker1Listener: worker1Listener, + worker2Listener: worker2Listener, + dialer: wire.NewLocalDialer(schedulerListener, worker1Listener, worker2Listener), } } @@ -134,6 +330,11 @@ func newTestScheduler(t *testing.T, logger log.Logger, net *testNetwork) *schedu } func newTestWorker(t *testing.T, logger log.Logger, loc objtest.Location, net *testNetwork) *worker.Worker { + return newTestWorkerWithContext(t, logger, loc, net, t.Context()) +} + +//nolint:revive +func newTestWorkerWithContext(t *testing.T, logger log.Logger, loc objtest.Location, net *testNetwork, runCtx context.Context) *worker.Worker { t.Helper() w, err := 
worker.New(worker.Config{ @@ -142,15 +343,15 @@ func newTestWorker(t *testing.T, logger log.Logger, loc objtest.Location, net *t BatchSize: 2048, Dialer: net.dialer, - Listener: net.workerListener, + Listener: net.worker1Listener, SchedulerAddress: wire.LocalScheduler, // Create enough threads to guarantee all tasks can be scheduled without // blocking. - NumThreads: 8, + NumThreads: 2, }) require.NoError(t, err, "expected to create worker") - require.NoError(t, services.StartAndAwaitRunning(t.Context(), w.Service())) + require.NoError(t, services.StartAndAwaitRunning(runCtx, w.Service())) t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)