Skip to content

Commit c7180fe

Browse files
authored
Put envd, ptys, socats, and commands into their own cgroups (#1580)
1 parent 03f977c commit c7180fe

File tree

13 files changed

+474
-25
lines changed

13 files changed

+474
-25
lines changed

packages/envd/internal/port/forward.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"syscall"
1515

1616
"github.com/rs/zerolog"
17+
18+
"github.com/e2b-dev/infra/packages/envd/internal/services/cgroups"
1719
)
1820

1921
type PortState string
@@ -36,7 +38,8 @@ type PortToForward struct {
3638
}
3739

3840
type Forwarder struct {
39-
logger *zerolog.Logger
41+
logger *zerolog.Logger
42+
cgroupManager cgroups.Manager
4043
// Map of ports that are being currently forwarded.
4144
ports map[string]*PortToForward
4245
scannerSubscriber *ScannerSubscriber
@@ -46,6 +49,7 @@ type Forwarder struct {
4649
func NewForwarder(
4750
logger *zerolog.Logger,
4851
scanner *Scanner,
52+
cgroupManager cgroups.Manager,
4953
) *Forwarder {
5054
scannerSub := scanner.AddSubscriber(
5155
logger,
@@ -62,6 +66,7 @@ func NewForwarder(
6266
sourceIP: defaultGatewayIP,
6367
ports: make(map[string]*PortToForward),
6468
scannerSubscriber: scannerSub,
69+
cgroupManager: cgroupManager,
6570
}
6671
}
6772

@@ -135,8 +140,13 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
135140
fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
136141
fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
137142
)
143+
144+
cgroupFD, ok := f.cgroupManager.GetFileDescriptor(cgroups.ProcessTypeSocat)
145+
138146
cmd.SysProcAttr = &syscall.SysProcAttr{
139-
Setpgid: true,
147+
Setpgid: true,
148+
CgroupFD: cgroupFD,
149+
UseCgroupFD: ok,
140150
}
141151

142152
f.logger.Debug().
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package cgroups
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
9+
"golang.org/x/sys/unix"
10+
)
11+
12+
type Cgroup2Manager struct {
13+
cgroupFDs map[ProcessType]int
14+
}
15+
16+
var _ Manager = (*Cgroup2Manager)(nil)
17+
18+
type cgroup2Config struct {
19+
rootPath string
20+
processTypes map[ProcessType]Cgroup2Config
21+
}
22+
23+
type Cgroup2ManagerOption func(*cgroup2Config)
24+
25+
func WithCgroup2RootSysFSPath(path string) Cgroup2ManagerOption {
26+
return func(config *cgroup2Config) {
27+
config.rootPath = path
28+
}
29+
}
30+
31+
func WithCgroup2ProcessType(processType ProcessType, path string, properties map[string]string) Cgroup2ManagerOption {
32+
return func(config *cgroup2Config) {
33+
if config.processTypes == nil {
34+
config.processTypes = make(map[ProcessType]Cgroup2Config)
35+
}
36+
config.processTypes[processType] = Cgroup2Config{Path: path, Properties: properties}
37+
}
38+
}
39+
40+
type Cgroup2Config struct {
41+
Path string
42+
Properties map[string]string
43+
}
44+
45+
func NewCgroup2Manager(opts ...Cgroup2ManagerOption) (*Cgroup2Manager, error) {
46+
config := cgroup2Config{
47+
rootPath: "/sys/fs/cgroup",
48+
}
49+
50+
for _, opt := range opts {
51+
opt(&config)
52+
}
53+
54+
cgroupFDs, err := createCgroups(config)
55+
if err != nil {
56+
return nil, fmt.Errorf("failed to create cgroups: %w", err)
57+
}
58+
59+
return &Cgroup2Manager{cgroupFDs: cgroupFDs}, nil
60+
}
61+
62+
func createCgroups(configs cgroup2Config) (map[ProcessType]int, error) {
63+
var (
64+
results = make(map[ProcessType]int)
65+
errs []error
66+
)
67+
68+
for procType, config := range configs.processTypes {
69+
fullPath := filepath.Join(configs.rootPath, config.Path)
70+
fd, err := createCgroup(fullPath, config.Properties)
71+
if err != nil {
72+
errs = append(errs, fmt.Errorf("failed to create %s cgroup: %w", procType, err))
73+
74+
continue
75+
}
76+
results[procType] = fd
77+
}
78+
79+
if len(errs) > 0 {
80+
for procType, fd := range results {
81+
err := unix.Close(fd)
82+
if err != nil {
83+
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
84+
}
85+
}
86+
87+
return nil, errors.Join(errs...)
88+
}
89+
90+
return results, nil
91+
}
92+
93+
func createCgroup(fullPath string, properties map[string]string) (int, error) {
94+
if err := os.MkdirAll(fullPath, 0o755); err != nil {
95+
return -1, fmt.Errorf("failed to create cgroup root: %w", err)
96+
}
97+
98+
var errs []error
99+
for name, value := range properties {
100+
if err := os.WriteFile(filepath.Join(fullPath, name), []byte(value), 0o644); err != nil {
101+
errs = append(errs, fmt.Errorf("failed to write cgroup property: %w", err))
102+
}
103+
}
104+
if len(errs) > 0 {
105+
return -1, errors.Join(errs...)
106+
}
107+
108+
return unix.Open(fullPath, unix.O_RDONLY, 0)
109+
}
110+
111+
func (c Cgroup2Manager) GetFileDescriptor(procType ProcessType) (int, bool) {
112+
fd, ok := c.cgroupFDs[procType]
113+
114+
return fd, ok
115+
}
116+
117+
func (c Cgroup2Manager) Close() error {
118+
var errs []error
119+
for procType, fd := range c.cgroupFDs {
120+
if err := unix.Close(fd); err != nil {
121+
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
122+
}
123+
delete(c.cgroupFDs, procType)
124+
}
125+
126+
return errors.Join(errs...)
127+
}
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package cgroups
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"math/rand"
7+
"os"
8+
"os/exec"
9+
"strconv"
10+
"syscall"
11+
"testing"
12+
"time"
13+
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
const (
19+
oneByte = 1
20+
kilobyte = 1024 * oneByte
21+
megabyte = 1024 * kilobyte
22+
)
23+
24+
func TestCgroupRoundTrip(t *testing.T) {
25+
t.Parallel()
26+
27+
if os.Geteuid() != 0 {
28+
t.Skip("must run as root")
29+
30+
return
31+
}
32+
33+
maxTimeout := time.Second * 5
34+
35+
t.Run("process does not die without cgroups", func(t *testing.T) {
36+
t.Parallel()
37+
38+
// create manager
39+
m, err := NewCgroup2Manager()
40+
require.NoError(t, err)
41+
42+
// create new child process
43+
cmd := startProcess(t, m, "not-a-real-one")
44+
45+
// wait for child process to die
46+
err = waitForProcess(t, cmd, maxTimeout)
47+
48+
require.ErrorIs(t, err, context.DeadlineExceeded)
49+
})
50+
51+
t.Run("process dies with cgroups", func(t *testing.T) {
52+
t.Parallel()
53+
54+
cgroupPath := createCgroupPath(t, "real-one")
55+
56+
// create manager
57+
m, err := NewCgroup2Manager(
58+
WithCgroup2ProcessType(ProcessTypePTY, cgroupPath, map[string]string{
59+
"memory.max": strconv.Itoa(1 * megabyte),
60+
}),
61+
)
62+
require.NoError(t, err)
63+
64+
t.Cleanup(func() {
65+
err := m.Close()
66+
assert.NoError(t, err)
67+
})
68+
69+
// create new child process
70+
cmd := startProcess(t, m, ProcessTypePTY)
71+
72+
// wait for child process to die
73+
err = waitForProcess(t, cmd, maxTimeout)
74+
75+
// verify process exited correctly
76+
var exitErr *exec.ExitError
77+
require.ErrorAs(t, err, &exitErr)
78+
assert.Equal(t, "signal: killed", exitErr.Error())
79+
assert.False(t, exitErr.Exited())
80+
assert.False(t, exitErr.Success())
81+
assert.Equal(t, -1, exitErr.ExitCode())
82+
83+
// dig a little deeper
84+
ws, ok := exitErr.Sys().(syscall.WaitStatus)
85+
require.True(t, ok)
86+
assert.Equal(t, syscall.SIGKILL, ws.Signal())
87+
assert.True(t, ws.Signaled())
88+
assert.False(t, ws.Stopped())
89+
assert.False(t, ws.Continued())
90+
assert.False(t, ws.CoreDump())
91+
assert.False(t, ws.Exited())
92+
assert.Equal(t, -1, ws.ExitStatus())
93+
})
94+
95+
t.Run("process cannot be spawned because memory limit is too low", func(t *testing.T) {
96+
t.Parallel()
97+
98+
cgroupPath := createCgroupPath(t, "real-one")
99+
100+
// create manager
101+
m, err := NewCgroup2Manager(
102+
WithCgroup2ProcessType(ProcessTypeSocat, cgroupPath, map[string]string{
103+
"memory.max": strconv.Itoa(1 * kilobyte),
104+
}),
105+
)
106+
require.NoError(t, err)
107+
108+
t.Cleanup(func() {
109+
err := m.Close()
110+
assert.NoError(t, err)
111+
})
112+
113+
// create new child process
114+
cmd := startProcess(t, m, ProcessTypeSocat)
115+
116+
// wait for child process to die
117+
err = waitForProcess(t, cmd, maxTimeout)
118+
119+
// verify process exited correctly
120+
var exitErr *exec.ExitError
121+
require.ErrorAs(t, err, &exitErr)
122+
assert.Equal(t, "exit status 253", exitErr.Error())
123+
assert.True(t, exitErr.Exited())
124+
assert.False(t, exitErr.Success())
125+
assert.Equal(t, 253, exitErr.ExitCode())
126+
127+
// dig a little deeper
128+
ws, ok := exitErr.Sys().(syscall.WaitStatus)
129+
require.True(t, ok)
130+
assert.Equal(t, syscall.Signal(-1), ws.Signal())
131+
assert.False(t, ws.Signaled())
132+
assert.False(t, ws.Stopped())
133+
assert.False(t, ws.Continued())
134+
assert.False(t, ws.CoreDump())
135+
assert.True(t, ws.Exited())
136+
assert.Equal(t, 253, ws.ExitStatus())
137+
})
138+
}
139+
140+
func createCgroupPath(t *testing.T, s string) string {
141+
t.Helper()
142+
143+
randPart := rand.Int()
144+
145+
return fmt.Sprintf("envd-test-%s-%d", s, randPart)
146+
}
147+
148+
func startProcess(t *testing.T, m *Cgroup2Manager, pt ProcessType) *exec.Cmd {
149+
t.Helper()
150+
151+
cmdName, args := "bash", []string{"-c", `sleep 1 && tail /dev/zero`}
152+
cmd := exec.CommandContext(t.Context(), cmdName, args...)
153+
154+
fd, ok := m.GetFileDescriptor(pt)
155+
cmd.SysProcAttr = &syscall.SysProcAttr{
156+
UseCgroupFD: ok,
157+
CgroupFD: fd,
158+
}
159+
160+
err := cmd.Start()
161+
require.NoError(t, err)
162+
163+
return cmd
164+
}
165+
166+
func waitForProcess(t *testing.T, cmd *exec.Cmd, timeout time.Duration) error {
167+
t.Helper()
168+
169+
done := make(chan error, 1)
170+
171+
go func() {
172+
defer close(done)
173+
done <- cmd.Wait()
174+
}()
175+
176+
ctx, cancel := context.WithTimeout(t.Context(), timeout)
177+
t.Cleanup(cancel)
178+
179+
select {
180+
case <-ctx.Done():
181+
return ctx.Err()
182+
case err := <-done:
183+
return err
184+
}
185+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package cgroups
2+
3+
type ProcessType string
4+
5+
const (
6+
ProcessTypePTY ProcessType = "pty"
7+
ProcessTypeUser ProcessType = "user"
8+
ProcessTypeSocat ProcessType = "socat"
9+
)
10+
11+
type Manager interface {
12+
GetFileDescriptor(procType ProcessType) (int, bool)
13+
Close() error
14+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package cgroups
2+
3+
type NoopManager struct{}
4+
5+
var _ Manager = (*NoopManager)(nil)
6+
7+
func NewNoopManager() *NoopManager {
8+
return &NoopManager{}
9+
}
10+
11+
func (n NoopManager) GetFileDescriptor(ProcessType) (int, bool) {
12+
return 0, false
13+
}
14+
15+
func (n NoopManager) Close() error {
16+
return nil
17+
}

0 commit comments

Comments
 (0)