diff --git a/pkg/shim/v1/runsc/service.go b/pkg/shim/v1/runsc/service.go
index 9bb73b4772..4459f1bcfc 100644
--- a/pkg/shim/v1/runsc/service.go
+++ b/pkg/shim/v1/runsc/service.go
@@ -88,18 +88,8 @@ type oomPoller interface {
     run(ctx context.Context)
 }

-// runscService is the shim implementation of a remote shim over gRPC. It converts
-// shim calls into `runsc` commands. It runs in 2 different modes:
-// 1. Service: process runs for the life time of the container and receives
-//    calls described in shimapi.TaskService interface.
-// 2. Tool: process is short lived and runs only to perform the requested
-//    operations and then exits. It implements the direct functions in
-//    shim.Shim interface.
-//
-// When the service is running, it saves a json file with state information so
-// that commands sent to the tool can load the state and perform the operation
-// with the required context.
-type runscService struct {
+// runscContainer contains the state of a single container.
+type runscContainer struct {
     mu sync.Mutex

     // id is the container ID.
@@ -115,7 +105,29 @@ type runscService struct {
     // processes maps ExecId to processes running through exec.
     processes map[string]extension.Process

+    // events is used for container-related events.
     events chan any
+}
+
+// runscService is the shim implementation of a remote shim over gRPC. It converts
+// shim calls into `runsc` commands. It runs in two different modes:
+// 1. Service: the process runs for the lifetime of the container and receives
+//    calls described in the shimapi.TaskService interface.
+// 2. Tool: the process is short-lived and runs only to perform the requested
+//    operations and then exits. It implements the direct functions in the
+//    shim.Shim interface.
+//
+// When the service is running, it saves a JSON file with state information so
+// that commands sent to the tool can load the state and perform the operation
+// with the required context.
+type runscService struct {
+    mu sync.Mutex
+
+    // id is only used in the cleanup case.
+    id string
+
+    // containers maps container IDs to the containers in this sandbox.
+    containers map[string]*runscContainer

     // platform handles operations related to the console.
     platform stdio.Platform
@@ -129,12 +141,19 @@ type runscService struct {
     // oomPoller monitors the sandbox's cgroup for OOM notifications.
     oomPoller oomPoller
+
+    // cancel is a function that needs to be called before the shim stops. The
+    // function is provided by the caller to New().
+    cancel func()
+
+    // shimAddress is the location of the UDS used to communicate to containerd.
+    shimAddress string
 }

 var _ extension.TaskServiceExt = (*runscService)(nil)
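The core of this hunk is the split of per-container state out of runscService: the service now owns only the container map, while each runscContainer carries its own lock, process table, and event channel. The standalone Go sketch below (editorial, not part of the patch; all names are illustrative stand-ins) shows the resulting two-level locking scheme, where the service mutex guards only map membership and each container mutex guards that container's own state, so operations on different containers no longer serialize on one global lock.

package main

import (
	"fmt"
	"sync"
)

type container struct {
	mu        sync.Mutex
	id        string
	processes map[string]int // execID -> pid; stands in for extension.Process
}

type service struct {
	mu         sync.Mutex
	containers map[string]*container
}

// addContainer takes only the service lock; it never touches
// per-container state.
func (s *service) addContainer(id string) *container {
	s.mu.Lock()
	defer s.mu.Unlock()
	c := &container{id: id, processes: make(map[string]int)}
	s.containers[id] = c
	return c
}

// addProcess takes only the container lock, so two containers in the same
// sandbox can mutate their process tables concurrently.
func (c *container) addProcess(execID string, pid int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.processes[execID] = pid
}

func main() {
	s := &service{containers: make(map[string]*container)}
	c := s.addContainer("abc")
	c.addProcess("exec-1", 1234)
	fmt.Println(len(s.containers), len(c.processes))
}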
 // New returns a new shim service.
-func New(ctx context.Context, id string, publisher shim.Publisher) (extension.TaskServiceExt, error) {
+func New(ctx context.Context, id string, publisher shim.Publisher, cancel func(), shimAddress string) (extension.TaskServiceExt, error) {
     var (
         ep  oomPoller
         err error
@@ -149,11 +168,12 @@ func New(ctx context.Context, id string, publisher shim.Publisher) (extension.Ta
     }
     go ep.run(ctx)
     s := &runscService{
-        id:        id,
-        processes: make(map[string]extension.Process),
-        events:    make(chan any, 128),
-        ec:        proc.ExitCh,
-        oomPoller: ep,
+        ec:          proc.ExitCh,
+        oomPoller:   ep,
+        id:          id,
+        containers:  make(map[string]*runscContainer),
+        cancel:      cancel,
+        shimAddress: shimAddress,
     }
     go s.processExits(ctx)
     runsccmd.Monitor = &runsccmd.LogMonitor{Next: reaper.Default}
@@ -165,6 +185,18 @@ func New(ctx context.Context, id string, publisher shim.Publisher) (extension.Ta
     return s, nil
 }

+func (s *runscService) getContainer(id string) (*runscContainer, error) {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    c, ok := s.containers[id]
+    if !ok {
+        return nil, fmt.Errorf("failed to get container with id: %v", id)
+    }
+
+    return c, nil
+}
+
 // Cleanup is called from another process (need to reload state) to stop the
 // container and undo all operations done in Create().
 func (s *runscService) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) {
@@ -185,8 +217,9 @@ func (s *runscService) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, er
     if err := r.Delete(ctx, s.id, &runsccmd.DeleteOpts{
         Force: true,
     }); err != nil {
-        log.L.Infof("failed to remove runc container: %v", err)
+        log.L.Infof("failed to remove runsc container: %v", err)
     }
+
     if err := mount.UnmountAll(st.Rootfs, 0); err != nil {
         log.L.Infof("failed to cleanup rootfs mount: %v", err)
     }
@@ -202,10 +235,6 @@ func (s *runscService) Create(ctx context.Context, r *taskAPI.CreateTaskRequest)
     s.mu.Lock()
     defer s.mu.Unlock()

-    // Save the main task id and bundle to the shim for additional requests.
-    s.id = r.ID
-    s.bundle = r.Bundle
-
     ns, err := namespaces.NamespaceRequired(ctx)
     if err != nil {
         return nil, fmt.Errorf("create namespace: %w", err)
     }
@@ -395,7 +424,16 @@ func (s *runscService) Create(ctx context.Context, r *taskAPI.CreateTaskRequest)
     // Success
     cu.Release()
-    s.task = process
+
+    c := runscContainer{
+        id:        r.ID,
+        bundle:    r.Bundle,
+        processes: make(map[string]extension.Process),
+        events:    make(chan any, 128),
+        task:      process,
+    }
+    s.containers[r.ID] = &c
+
     return &taskAPI.CreateTaskResponse{
         Pid: uint32(process.Pid()),
     }, nil
@@ -403,7 +441,7 @@ func (s *runscService) Create(ctx context.Context, r *taskAPI.CreateTaskRequest)

 // Start starts the container.
 func (s *runscService) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) {
-    p, err := s.getProcess(r.ExecID)
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         return nil, err
     }
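Create() now registers a runscContainer in s.containers, and every later request resolves its target through the new getContainer helper, turning unknown IDs into an error instead of a nil dereference later on. A minimal self-contained sketch of that register-then-look-up round trip, using hypothetical types rather than the shim's real ones:

package main

import (
	"fmt"
	"sync"
)

type container struct{ bundle string }

type service struct {
	mu         sync.Mutex
	containers map[string]*container
}

// create mirrors how Create() above stores the new runscContainer in
// s.containers before returning.
func (s *service) create(id, bundle string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.containers[id] = &container{bundle: bundle}
}

// get mirrors getContainer: the map read happens entirely under the lock,
// and unknown IDs fail early with a descriptive error.
func (s *service) get(id string) (*container, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	c, ok := s.containers[id]
	if !ok {
		return nil, fmt.Errorf("failed to get container with id: %v", id)
	}
	return c, nil
}

func main() {
	s := &service{containers: make(map[string]*container)}
	s.create("c1", "/run/bundle/c1")
	if c, err := s.get("c1"); err == nil {
		fmt.Println(c.bundle)
	}
	if _, err := s.get("missing"); err != nil {
		fmt.Println(err)
	}
}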
@@ -419,19 +457,32 @@ func (s *runscService) Start(ctx context.Context, r *taskAPI.StartRequest) (*tas

 // Delete deletes the initial process and container.
 func (s *runscService) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) {
-    p, err := s.getProcess(r.ExecID)
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         return nil, err
     }
     if err := p.Delete(ctx); err != nil {
         return nil, err
     }
     if len(r.ExecID) != 0 {
-        s.mu.Lock()
-        delete(s.processes, r.ExecID)
-        s.mu.Unlock()
-    } else if s.platform != nil {
-        s.platform.Close()
+        c.mu.Lock()
+        delete(c.processes, r.ExecID)
+        c.mu.Unlock()
+    } else {
+        s.mu.Lock()
+        delete(s.containers, r.ID)
+        hasCont := len(s.containers) > 0
+        s.mu.Unlock()
+
+        if !hasCont && s.platform != nil {
+            s.platform.Close()
+        }
     }
     return &taskAPI.DeleteResponse{
         ExitStatus: uint32(p.ExitStatus()),
@@ -442,16 +493,20 @@ func (s *runscService) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*t

 // Exec spawns an additional process inside the container.
 func (s *runscService) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) {
-    s.mu.Lock()
-    p := s.processes[r.ExecID]
-    s.mu.Unlock()
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    p := c.processes[r.ExecID]
     if p != nil {
         return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID)
     }
-    if s.task == nil {
+    if c.task == nil {
         return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
     }
-    process, err := s.task.Exec(ctx, s.bundle, &proc.ExecConfig{
+    process, err := c.task.Exec(ctx, c.bundle, &proc.ExecConfig{
         ID:       r.ExecID,
         Terminal: r.Terminal,
         Stdin:    r.Stdin,
@@ -462,15 +517,13 @@ func (s *runscService) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest)
     if err != nil {
         return nil, err
     }
-    s.mu.Lock()
-    s.processes[r.ExecID] = process
-    s.mu.Unlock()
+    c.processes[r.ExecID] = process
     return empty, nil
 }

 // ResizePty resizes the terminal of a process.
 func (s *runscService) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) {
-    p, err := s.getProcess(r.ExecID)
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         return nil, err
     }
@@ -486,7 +539,7 @@ func (s *runscService) ResizePty(ctx context.Context, r *taskAPI.ResizePtyReques

 // State returns runtime state information for the container.
 func (s *runscService) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) {
-    p, err := s.getProcess(r.ExecID)
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         log.L.Debugf("State failed to find process: %v", err)
         return nil, err
     }
@@ -508,7 +561,7 @@ func (s *runscService) State(ctx context.Context, r *taskAPI.StateRequest) (*tas
     sio := p.Stdio()
     res := &taskAPI.StateResponse{
         ID:     p.ID(),
-        Bundle: s.bundle,
+        Bundle: s.containers[r.ID].bundle,
         Pid:    uint32(p.Pid()),
         Status: status,
         Stdin:  sio.Stdin,
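Delete() now distinguishes two paths: a non-empty exec ID removes one process from that container's table under the container lock, while an empty exec ID removes the whole container under the service lock and closes the shared console platform only once no containers remain. A stand-alone sketch of that branching (editorial; platformOpen is a hypothetical stand-in for s.platform):

package main

import (
	"fmt"
	"sync"
)

type container struct {
	mu        sync.Mutex
	processes map[string]struct{}
}

type service struct {
	mu           sync.Mutex
	containers   map[string]*container
	platformOpen bool
}

// remove mirrors the Delete() logic above: an exec ID deletes a single
// process from the container; an empty exec ID deletes the container itself
// and releases the shared platform only when the sandbox is empty.
func (s *service) remove(id, execID string) {
	if execID != "" {
		s.mu.Lock()
		c := s.containers[id]
		s.mu.Unlock()
		if c == nil {
			return
		}
		c.mu.Lock()
		delete(c.processes, execID)
		c.mu.Unlock()
		return
	}
	s.mu.Lock()
	delete(s.containers, id)
	empty := len(s.containers) == 0
	s.mu.Unlock()
	if empty {
		s.platformOpen = false // stands in for s.platform.Close()
	}
}

func main() {
	s := &service{
		containers:   map[string]*container{"c1": {processes: map[string]struct{}{"e1": {}}}},
		platformOpen: true,
	}
	s.remove("c1", "e1") // exec delete: container and platform survive
	s.remove("c1", "")   // task delete: last container gone, platform closed
	fmt.Println(len(s.containers), s.platformOpen)
}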
@@ -524,11 +577,18 @@ func (s *runscService) State(ctx context.Context, r *taskAPI.StateRequest) (*tas

 // Pause the container.
 func (s *runscService) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) {
-    if s.task == nil {
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    if c.task == nil {
         log.L.Debugf("Pause error, id: %s: container not created", r.ID)
         return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
     }
-    err := s.task.Runtime().Pause(ctx, r.ID)
+    err = c.task.Runtime().Pause(ctx, r.ID)
     if err != nil {
         return nil, err
     }
@@ -537,11 +597,18 @@ func (s *runscService) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*typ

 // Resume the container.
 func (s *runscService) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) {
-    if s.task == nil {
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+
+    c.mu.Lock()
+    defer c.mu.Unlock()
+    if c.task == nil {
         log.L.Debugf("Resume error, id: %s: container not created", r.ID)
         return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
     }
-    err := s.task.Runtime().Resume(ctx, r.ID)
+    err = c.task.Runtime().Resume(ctx, r.ID)
     if err != nil {
         return nil, err
     }
@@ -550,11 +617,11 @@ func (s *runscService) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*t

 // Kill the container with the provided signal.
 func (s *runscService) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) {
-    p, err := s.getProcess(r.ExecID)
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         return nil, err
     }
-    if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+    if err = p.Kill(ctx, r.Signal, r.All); err != nil {
         log.L.Debugf("Kill failed: %v", err)
         return nil, err
     }
@@ -564,16 +631,26 @@

 // Pids returns all pids inside the container.
 func (s *runscService) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) {
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+
     pids, err := s.getContainerPids(ctx, r.ID)
     if err != nil {
         return nil, err
     }
+
+    c.mu.Lock()
+    contProcesses := c.processes
+    c.mu.Unlock()
+
     var processes []*task.ProcessInfo
     for _, pid := range pids {
         pInfo := task.ProcessInfo{
             Pid: pid,
         }
-        for _, p := range s.processes {
+        for _, p := range contProcesses {
             if p.Pid() == int(pid) {
                 d := &runctypes.ProcessDetails{
                     ExecID: p.ID(),
@@ -595,7 +672,7 @@

 // CloseIO closes the I/O context of the container.
 func (s *runscService) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) {
-    p, err := s.getProcess(r.ExecID)
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         return nil, err
     }
@@ -614,7 +691,7 @@ func (s *runscService) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTask

 // Restore restores the container.
 func (s *runscService) Restore(ctx context.Context, r *extension.RestoreRequest) (*taskAPI.StartResponse, error) {
-    p, err := s.getProcess(r.Start.ExecID)
+    p, err := s.getContainerProcess(r.Start.ID, r.Start.ExecID)
     if err != nil {
         return nil, err
     }
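One detail worth noting in the Pids() hunk: `contProcesses := c.processes` copies only the Go map header, so the loop afterwards still reads the same underlying table that Exec() and Delete() mutate. A defensive variant, sketched here with stand-in types (editorial, not what the patch does), copies the entries into a slice while the lock is held:

package main

import (
	"fmt"
	"sync"
)

type proc struct {
	execID string
	pid    int
}

type container struct {
	mu        sync.Mutex
	processes map[string]proc
}

// snapshot copies the current processes into a slice under the lock.
// Assigning the map to a local variable only duplicates the map header, so a
// later range would still share storage with concurrent writers; copying the
// entries severs that sharing.
func (c *container) snapshot() []proc {
	c.mu.Lock()
	defer c.mu.Unlock()
	out := make([]proc, 0, len(c.processes))
	for _, p := range c.processes {
		out = append(out, p)
	}
	return out
}

func main() {
	c := &container{processes: map[string]proc{"e1": {"e1", 42}}}
	for _, p := range c.snapshot() {
		fmt.Println(p.execID, p.pid)
	}
}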
@@ -630,9 +707,18 @@ func (s *runscService) Restore(ctx context.Context, r *extension.RestoreRequest)

 // Connect returns shim information such as the shim's pid.
 func (s *runscService) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) {
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+
+    c.mu.Lock()
+    task := c.task
+    c.mu.Unlock()
+
     var pid int
-    if s.task != nil {
-        pid = s.task.Pid()
+    if task != nil {
+        pid = task.Pid()
     }
     return &taskAPI.ConnectResponse{
         ShimPid: uint32(os.Getpid()),
@@ -641,15 +727,34 @@ func (s *runscService) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (
 }

 func (s *runscService) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
-    return nil, nil
+    s.mu.Lock()
+    if len(s.containers) > 0 {
+        s.mu.Unlock()
+        return empty, nil
+    }
+    s.cancel()
+    if len(s.shimAddress) != 0 {
+        _ = shim.RemoveSocket(s.shimAddress)
+    }
+    os.Exit(0)
+    panic("should not get here")
 }

 func (s *runscService) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
-    if s.task == nil {
+    c, err := s.getContainer(r.ID)
+    if err != nil {
+        return nil, err
+    }
+
+    c.mu.Lock()
+    task := c.task
+    c.mu.Unlock()
+
+    if task == nil {
         log.L.Debugf("Stats error, id: %s: container not created", r.ID)
         return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
     }
-    stats, err := s.task.Stats(ctx, s.id)
+    stats, err := task.Stats(ctx, r.ID)
     if err != nil {
         log.L.Debugf("Stats error, id: %s: %v", r.ID, err)
         return nil, err
     }
@@ -773,7 +878,7 @@ func (s *runscService) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest)

 // Wait waits for the container to exit.
 func (s *runscService) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) {
-    p, err := s.getProcess(r.ExecID)
+    p, err := s.getContainerProcess(r.ID, r.ExecID)
     if err != nil {
         log.L.Debugf("Wait failed to find process: %v", err)
         return nil, err
     }
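Shutdown() is no longer a no-op, but it also does not exit unconditionally: while any container remains in the map the call returns immediately, so deleting one container cannot tear down a shim that is still serving others. A self-contained sketch of that guard (editorial; the exit path is stubbed instead of calling os.Exit(0) and shim.RemoveSocket):

package main

import (
	"fmt"
	"sync"
)

type service struct {
	mu         sync.Mutex
	containers map[string]struct{}
	cancel     func()
}

// shutdown mirrors the guard in Shutdown() above: a no-op while containers
// remain, and only on the last call does the shim run its cancel function,
// remove its socket, and exit.
func (s *service) shutdown() bool {
	s.mu.Lock()
	if len(s.containers) > 0 {
		s.mu.Unlock()
		return false
	}
	s.mu.Unlock()
	s.cancel() // the real code then removes the shim socket and exits
	return true
}

func main() {
	s := &service{
		containers: map[string]struct{}{"c1": {}},
		cancel:     func() { fmt.Println("cancel called") },
	}
	fmt.Println(s.shutdown()) // false: a container is still alive
	delete(s.containers, "c1")
	fmt.Println(s.shutdown()) // true: the shim may exit now
}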
+ log.L.Debugf("Container init process exited, killing all container processes") + ip.KillAll(ctx) + } + p.SetExited(e.Status) + + c.mu.Lock() + c.events <- &events.TaskExit{ + ContainerID: c.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + c.mu.Unlock() + return } - return } } } -func (s *runscService) allProcesses() (o []process.Process) { - s.mu.Lock() - defer s.mu.Unlock() - for _, p := range s.processes { +func (c *runscContainer) allProcesses() (o []process.Process) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, p := range c.processes { o = append(o, p) } - if s.task != nil { - o = append(o, s.task) + if c.task != nil { + o = append(o, c.task) } return o } func (s *runscService) getContainerPids(ctx context.Context, id string) ([]uint32, error) { - s.mu.Lock() - p := s.task - s.mu.Unlock() + c, err := s.getContainer(id) + if err != nil { + return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) + } + + c.mu.Lock() + p := c.task + c.mu.Unlock() + if p == nil { return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) } @@ -848,27 +969,40 @@ func (s *runscService) getContainerPids(ctx context.Context, id string) ([]uint3 } func (s *runscService) forward(ctx context.Context, publisher shim.Publisher) { - for e := range s.events { - err := publisher.Publish(ctx, getTopic(e), e) - if err != nil { - // Should not happen. - panic(fmt.Errorf("post event: %w", err)) + s.mu.Lock() + defer s.mu.Unlock() + + containers := s.containers + for _, c := range containers { + c.mu.Lock() + for e := range c.events { + err := publisher.Publish(ctx, getTopic(e), e) + if err != nil { + // Should not happen. + panic(fmt.Errorf("post event: %w", err)) + } } + c.mu.Unlock() } } -func (s *runscService) getProcess(execID string) (extension.Process, error) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *runscService) getContainerProcess(id string, execID string) (extension.Process, error) { + c, err := s.getContainer(id) + if err != nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + + c.mu.Lock() + defer c.mu.Unlock() if execID == "" { - if s.task == nil { + if c.task == nil { return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - return s.task, nil + return c.task, nil } - p := s.processes[execID] + p := c.processes[execID] if p == nil { return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) } diff --git a/pkg/shim/v1/service.go b/pkg/shim/v1/service.go index d06140697e..2682ae8fe3 100644 --- a/pkg/shim/v1/service.go +++ b/pkg/shim/v1/service.go @@ -51,21 +51,21 @@ func New(ctx context.Context, id string, publisher shim.Publisher, cancel func() opts = ctxOpts.(shim.Opts) } - runsc, err := runsc.New(ctx, id, publisher) + var shimAddress string + if address, err := shim.ReadAddress(shimAddressPath); err == nil { + shimAddress = address + } + + runsc, err := runsc.New(ctx, id, publisher, cancel, shimAddress) if err != nil { cancel() return nil, err } s := &service{ genericOptions: opts, - cancel: cancel, main: runsc, } - if address, err := shim.ReadAddress(shimAddressPath); err == nil { - s.shimAddress = address - } - return s, nil } @@ -87,13 +87,6 @@ type service struct { // to all shims. genericOptions shim.Opts - // cancel is a function that needs to be called before the shim stops. The - // function is provided by the caller to New(). 
diff --git a/pkg/shim/v1/service.go b/pkg/shim/v1/service.go
index d06140697e..2682ae8fe3 100644
--- a/pkg/shim/v1/service.go
+++ b/pkg/shim/v1/service.go
@@ -51,21 +51,21 @@ func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()
         opts = ctxOpts.(shim.Opts)
     }

-    runsc, err := runsc.New(ctx, id, publisher)
+    var shimAddress string
+    if address, err := shim.ReadAddress(shimAddressPath); err == nil {
+        shimAddress = address
+    }
+
+    runsc, err := runsc.New(ctx, id, publisher, cancel, shimAddress)
     if err != nil {
         cancel()
         return nil, err
     }
     s := &service{
         genericOptions: opts,
-        cancel:         cancel,
         main:           runsc,
     }

-    if address, err := shim.ReadAddress(shimAddressPath); err == nil {
-        s.shimAddress = address
-    }
-
     return s, nil
 }

@@ -87,13 +87,6 @@ type service struct {
     // to all shims.
     genericOptions shim.Opts

-    // cancel is a function that needs to be called before the shim stops. The
-    // function is provided by the caller to New().
-    cancel func()
-
-    // shimAddress is the location of the UDS used to communicate to containerd.
-    shimAddress string
-
     // main is the extension.TaskServiceExt that is used for all calls to the
     // container's shim, except for the cases where `ext` is set.
     //
@@ -345,12 +338,7 @@ func (s *service) Shutdown(ctx context.Context, r *taskapi.ShutdownRequest) (*ty
         return resp, errdefs.ToGRPC(err)
     }

-    s.cancel()
-    if len(s.shimAddress) != 0 {
-        _ = shim.RemoveSocket(s.shimAddress)
-    }
-    os.Exit(0)
-    panic("Should not get here")
+    return resp, errdefs.ToGRPC(err)
 }

 func (s *service) Stats(ctx context.Context, r *taskapi.StatsRequest) (*taskapi.StatsResponse, error) {
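With cancel and shimAddress moved into runscService, the generic service wrapper no longer exits the process itself: its Shutdown simply forwards the call and returns the inner response, leaving the exit decision to the layer that tracks live containers. A toy sketch of that delegation (hypothetical interfaces, not the shim's real types):

package main

import "fmt"

// taskService is a stand-in for extension.TaskServiceExt.
type taskService interface {
	Shutdown() string
}

// inner is the layer that knows whether containers are still alive; here it
// just reports that it handled the call.
type inner struct{}

func (inner) Shutdown() string { return "inner shutdown handled" }

// wrapper mirrors the new service.Shutdown above: it forwards the request
// and returns the inner result instead of calling os.Exit itself.
type wrapper struct{ main taskService }

func (w wrapper) Shutdown() string { return w.main.Shutdown() }

func main() {
	w := wrapper{main: inner{}}
	fmt.Println(w.Shutdown())
}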