Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/internal/executor"
linuxuser "github.com/dstackai/dstack/runner/internal/linux/user"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/runner/api"
"github.com/dstackai/dstack/runner/internal/ssh"
Expand All @@ -30,7 +31,6 @@ func main() {

func mainInner() int {
var tempDir string
var homeDir string
var httpPort int
var sshPort int
var sshAuthorizedKeys []string
Expand Down Expand Up @@ -61,13 +61,6 @@ func mainInner() int {
Destination: &tempDir,
TakesFile: true,
},
&cli.StringFlag{
Name: "home-dir",
Usage: "HomeDir directory for credentials and $HOME",
Value: consts.RunnerHomeDir,
Destination: &homeDir,
TakesFile: true,
},
&cli.IntFlag{
Name: "http-port",
Usage: "Set a http port",
Expand All @@ -87,7 +80,7 @@ func mainInner() int {
},
},
Action: func(ctx context.Context, cmd *cli.Command) error {
return start(ctx, tempDir, homeDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version)
return start(ctx, tempDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version)
},
},
},
Expand All @@ -104,7 +97,7 @@ func mainInner() int {
return 0
}

func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error {
func start(ctx context.Context, tempDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error {
if err := os.MkdirAll(tempDir, 0o755); err != nil {
return fmt.Errorf("create temp directory: %w", err)
}
Expand All @@ -114,15 +107,39 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
return fmt.Errorf("create default log file: %w", err)
}
defer func() {
closeErr := defaultLogFile.Close()
if closeErr != nil {
log.Error(ctx, "Failed to close default log file", "err", closeErr)
if err := defaultLogFile.Close(); err != nil {
log.Error(ctx, "Failed to close default log file", "err", err)
}
}()

log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile))
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))

currentUser, err := linuxuser.FromCurrentProcess()
if err != nil {
return fmt.Errorf("get current process user: %w", err)
}
if !currentUser.IsRoot() {
return fmt.Errorf("must be root: %s", currentUser)
}
if currentUser.HomeDir == "" {
log.Warning(ctx, "Current user does not have home dir, using /root as a fallback", "user", currentUser)
currentUser.HomeDir = "/root"
}
// Fix the current process HOME, just in case some internals require it (e.g., they use os.UserHomeDir() or
// spawn a child process which uses that variable)
envHome, envHomeIsSet := os.LookupEnv("HOME")
if envHome != currentUser.HomeDir {
if !envHomeIsSet {
log.Warning(ctx, "HOME is not set, setting the value", "home", currentUser.HomeDir)
} else {
log.Warning(ctx, "HOME is incorrect, fixing the value", "current", envHome, "home", currentUser.HomeDir)
}
if err := os.Setenv("HOME", currentUser.HomeDir); err != nil {
return fmt.Errorf("set HOME: %w", err)
}
}
log.Trace(ctx, "Running as", "user", currentUser)

// NB: The Mkdir/Chown/Chmod code below relies on the fact that RunnerDstackDir path is _not_ nested (/dstack).
// Adjust it if the path is changed to, e.g., /opt/dstack
const dstackDir = consts.RunnerDstackDir
Expand Down Expand Up @@ -163,7 +180,7 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
}
}()

ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd)
ex, err := executor.NewRunExecutor(tempDir, dstackDir, *currentUser, sshd)
if err != nil {
return fmt.Errorf("create executor: %w", err)
}
Expand Down
4 changes: 0 additions & 4 deletions runner/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ const (
// NOTE: RunnerRuntimeDir would be a more appropriate name, but it's called tempDir
// throughout runner's codebase
RunnerTempDir = "/tmp/runner"
// Currently, it's a directory where authorized_keys, git credentials, etc. are placed
// The current user's homedir (as of 2024-12-28, it's always root) should be used
// instead of the hardcoded value
RunnerHomeDir = "/root"
// A directory for:
// 1. Files used by the runner and related components (e.g., sshd stores its config and log inside /dstack/ssh)
// 2. Files shared between users (e.g., sshd authorized_keys, MPI hostfile)
Expand Down
4 changes: 2 additions & 2 deletions runner/internal/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func ExpandPath(pth string, base string, home string) (string, error) {
return pth, nil
}

func MkdirAll(ctx context.Context, pth string, uid int, gid int) error {
func MkdirAll(ctx context.Context, pth string, uid int, gid int, perm os.FileMode) error {
paths := []string{pth}
for {
pth = path.Dir(pth)
Expand All @@ -60,7 +60,7 @@ func MkdirAll(ctx context.Context, pth string, uid int, gid int) error {
}
for _, p := range slices.Backward(paths) {
if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
if err := os.Mkdir(p, 0o755); err != nil {
if err := os.Mkdir(p, perm); err != nil {
return err
}
if uid != -1 || gid != -1 {
Expand Down
8 changes: 4 additions & 4 deletions runner/internal/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ func TestExpandtPath_ErrorTildeUsernameNotSupported_TildeUsernameWithPath(t *tes
func TestMkdirAll_AbsPath_NotExists(t *testing.T) {
absPath := path.Join(t.TempDir(), "a/b/c")
require.NoDirExists(t, absPath)
err := MkdirAll(context.Background(), absPath, -1, -1)
err := MkdirAll(context.Background(), absPath, -1, -1, 0o755)
require.NoError(t, err)
require.DirExists(t, absPath)
}

func TestMkdirAll_AbsPath_Exists(t *testing.T) {
absPath, err := os.Getwd()
require.NoError(t, err)
err = MkdirAll(context.Background(), absPath, -1, -1)
err = MkdirAll(context.Background(), absPath, -1, -1, 0o755)
require.NoError(t, err)
require.DirExists(t, absPath)
}
Expand All @@ -139,7 +139,7 @@ func TestMkdirAll_RelPath_NotExists(t *testing.T) {
relPath := "a/b/c"
absPath := path.Join(cwd, relPath)
require.NoDirExists(t, absPath)
err := MkdirAll(context.Background(), relPath, -1, -1)
err := MkdirAll(context.Background(), relPath, -1, -1, 0o755)
require.NoError(t, err)
require.DirExists(t, absPath)
}
Expand All @@ -151,7 +151,7 @@ func TestMkdirAll_RelPath_Exists(t *testing.T) {
absPath := path.Join(cwd, relPath)
err := os.MkdirAll(absPath, 0o755)
require.NoError(t, err)
err = MkdirAll(context.Background(), relPath, -1, -1)
err = MkdirAll(context.Background(), relPath, -1, -1, 0o755)
require.NoError(t, err)
require.DirExists(t, absPath)
}
Loading