From 16a7785c489d81ac1bda8a7203a7dfb0dccb9c95 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Mon, 12 Jan 2026 14:09:20 +0000 Subject: [PATCH] [runner] Rework and fix user processing * Drop --home-dir option, use process user's home dir instead * Fix ownership of Git credentials, consider Git credentials errors non-fatal Closes: https://github.com/dstackai/dstack/issues/3419 --- runner/cmd/runner/main.go | 47 +- runner/consts/consts.go | 4 - runner/internal/common/utils.go | 4 +- runner/internal/common/utils_test.go | 8 +- runner/internal/executor/executor.go | 426 ++++++------------ runner/internal/executor/executor_test.go | 20 +- runner/internal/executor/files.go | 25 +- runner/internal/executor/repo.go | 20 +- runner/internal/executor/user.go | 184 ++++++++ runner/internal/executor/user_test.go | 232 ++++++++++ runner/internal/linux/user/user.go | 96 ++++ runner/internal/schemas/schemas.go | 16 - runner/internal/shim/docker.go | 3 - .../_internal/core/backends/base/compute.py | 4 - .../core/backends/kubernetes/compute.py | 1 - .../_internal/server/services/proxy/repo.py | 2 +- src/dstack/_internal/server/services/ssh.py | 2 +- 17 files changed, 715 insertions(+), 379 deletions(-) create mode 100644 runner/internal/executor/user.go create mode 100644 runner/internal/executor/user_test.go create mode 100644 runner/internal/linux/user/user.go diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 27e529417..c2ed94f0e 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -16,6 +16,7 @@ import ( "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/executor" + linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/runner/api" "github.com/dstackai/dstack/runner/internal/ssh" @@ -30,7 +31,6 @@ func main() { func mainInner() int { var tempDir string - var homeDir string var httpPort int var 
sshPort int var sshAuthorizedKeys []string @@ -61,13 +61,6 @@ func mainInner() int { Destination: &tempDir, TakesFile: true, }, - &cli.StringFlag{ - Name: "home-dir", - Usage: "HomeDir directory for credentials and $HOME", - Value: consts.RunnerHomeDir, - Destination: &homeDir, - TakesFile: true, - }, &cli.IntFlag{ Name: "http-port", Usage: "Set a http port", @@ -87,7 +80,7 @@ func mainInner() int { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - return start(ctx, tempDir, homeDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version) + return start(ctx, tempDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version) }, }, }, @@ -104,7 +97,7 @@ func mainInner() int { return 0 } -func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error { +func start(ctx context.Context, tempDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error { if err := os.MkdirAll(tempDir, 0o755); err != nil { return fmt.Errorf("create temp directory: %w", err) } @@ -114,15 +107,39 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss return fmt.Errorf("create default log file: %w", err) } defer func() { - closeErr := defaultLogFile.Close() - if closeErr != nil { - log.Error(ctx, "Failed to close default log file", "err", closeErr) + if err := defaultLogFile.Close(); err != nil { + log.Error(ctx, "Failed to close default log file", "err", err) } }() - log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) + currentUser, err := linuxuser.FromCurrentProcess() + if err != nil { + return fmt.Errorf("get current process user: %w", err) + } + if !currentUser.IsRoot() { + return fmt.Errorf("must be root: %s", currentUser) + } + if currentUser.HomeDir == "" { + log.Warning(ctx, "Current user does not have home dir, using /root 
as a fallback", "user", currentUser) + currentUser.HomeDir = "/root" + } + // Fix the current process HOME, just in case some internals require it (e.g., they use os.UserHomeDir() or + // spawn a child process which uses that variable) + envHome, envHomeIsSet := os.LookupEnv("HOME") + if envHome != currentUser.HomeDir { + if !envHomeIsSet { + log.Warning(ctx, "HOME is not set, setting the value", "home", currentUser.HomeDir) + } else { + log.Warning(ctx, "HOME is incorrect, fixing the value", "current", envHome, "home", currentUser.HomeDir) + } + if err := os.Setenv("HOME", currentUser.HomeDir); err != nil { + return fmt.Errorf("set HOME: %w", err) + } + } + log.Trace(ctx, "Running as", "user", currentUser) + // NB: The Mkdir/Chown/Chmod code below relies on the fact that RunnerDstackDir path is _not_ nested (/dstack). // Adjust it if the path is changed to, e.g., /opt/dstack const dstackDir = consts.RunnerDstackDir @@ -163,7 +180,7 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss } }() - ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd) + ex, err := executor.NewRunExecutor(tempDir, dstackDir, *currentUser, sshd) if err != nil { return fmt.Errorf("create executor: %w", err) } diff --git a/runner/consts/consts.go b/runner/consts/consts.go index 4da4a139f..99f405c29 100644 --- a/runner/consts/consts.go +++ b/runner/consts/consts.go @@ -26,10 +26,6 @@ const ( // NOTE: RunnerRuntimeDir would be a more appropriate name, but it's called tempDir // throughout runner's codebase RunnerTempDir = "/tmp/runner" - // Currently, it's a directory where authorized_keys, git credentials, etc. are placed - // The current user's homedir (as of 2024-12-28, it's always root) should be used - // instead of the hardcoded value - RunnerHomeDir = "/root" // A directory for: // 1. Files used by the runner and related components (e.g., sshd stores its config and log inside /dstack/ssh) // 2. 
Files shared between users (e.g., sshd authorized_keys, MPI hostfile) diff --git a/runner/internal/common/utils.go b/runner/internal/common/utils.go index 258279970..5be68edf7 100644 --- a/runner/internal/common/utils.go +++ b/runner/internal/common/utils.go @@ -49,7 +49,7 @@ func ExpandPath(pth string, base string, home string) (string, error) { return pth, nil } -func MkdirAll(ctx context.Context, pth string, uid int, gid int) error { +func MkdirAll(ctx context.Context, pth string, uid int, gid int, perm os.FileMode) error { paths := []string{pth} for { pth = path.Dir(pth) @@ -60,7 +60,7 @@ func MkdirAll(ctx context.Context, pth string, uid int, gid int) error { } for _, p := range slices.Backward(paths) { if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { - if err := os.Mkdir(p, 0o755); err != nil { + if err := os.Mkdir(p, perm); err != nil { return err } if uid != -1 || gid != -1 { diff --git a/runner/internal/common/utils_test.go b/runner/internal/common/utils_test.go index a49d080a2..5fe780d50 100644 --- a/runner/internal/common/utils_test.go +++ b/runner/internal/common/utils_test.go @@ -120,7 +120,7 @@ func TestExpandtPath_ErrorTildeUsernameNotSupported_TildeUsernameWithPath(t *tes func TestMkdirAll_AbsPath_NotExists(t *testing.T) { absPath := path.Join(t.TempDir(), "a/b/c") require.NoDirExists(t, absPath) - err := MkdirAll(context.Background(), absPath, -1, -1) + err := MkdirAll(context.Background(), absPath, -1, -1, 0o755) require.NoError(t, err) require.DirExists(t, absPath) } @@ -128,7 +128,7 @@ func TestMkdirAll_AbsPath_NotExists(t *testing.T) { func TestMkdirAll_AbsPath_Exists(t *testing.T) { absPath, err := os.Getwd() require.NoError(t, err) - err = MkdirAll(context.Background(), absPath, -1, -1) + err = MkdirAll(context.Background(), absPath, -1, -1, 0o755) require.NoError(t, err) require.DirExists(t, absPath) } @@ -139,7 +139,7 @@ func TestMkdirAll_RelPath_NotExists(t *testing.T) { relPath := "a/b/c" absPath := path.Join(cwd, relPath) 
require.NoDirExists(t, absPath) - err := MkdirAll(context.Background(), relPath, -1, -1) + err := MkdirAll(context.Background(), relPath, -1, -1, 0o755) require.NoError(t, err) require.DirExists(t, absPath) } @@ -151,7 +151,7 @@ func TestMkdirAll_RelPath_Exists(t *testing.T) { absPath := path.Join(cwd, relPath) err := os.MkdirAll(absPath, 0o755) require.NoError(t, err) - err = MkdirAll(context.Background(), relPath, -1, -1) + err = MkdirAll(context.Background(), relPath, -1, -1, 0o755) require.NoError(t, err) require.DirExists(t, absPath) } diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index fc4039cf9..cd3bd1be9 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -9,7 +9,6 @@ import ( "net/url" "os" "os/exec" - osuser "os/user" "path" "path/filepath" "runtime" @@ -27,6 +26,7 @@ import ( "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/common" "github.com/dstackai/dstack/runner/internal/connections" + linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/schemas" "github.com/dstackai/dstack/runner/internal/ssh" @@ -52,14 +52,13 @@ type ConnectionTracker interface { } type RunExecutor struct { - tempDir string - homeDir string - dstackDir string + tempDir string + dstackDir string + currentUser linuxuser.User + sshd ssh.SshdManager + fileArchiveDir string repoBlobDir string - sshd ssh.SshdManager - - currentUid uint32 run schemas.Run jobSpec schemas.JobSpec @@ -69,10 +68,9 @@ type RunExecutor struct { repoCredentials *schemas.RepoCredentials repoDir string repoBlobPath string - jobUid int - jobGid int - jobHomeDir string - jobWorkingDir string + // If the user is not specified in the JobSpec, jobUser should point to currentUser + jobUser *linuxuser.User + jobWorkingDir string mu *sync.RWMutex state string @@ -93,17 +91,9 @@ func (s 
*stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 } func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {} func (s *stubConnectionTracker) Stop() {} -func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager) (*RunExecutor, error) { +func NewRunExecutor(tempDir string, dstackDir string, currentUser linuxuser.User, sshd ssh.SshdManager) (*RunExecutor, error) { mu := &sync.RWMutex{} timestamp := NewMonotonicTimestamp() - user, err := osuser.Current() - if err != nil { - return nil, fmt.Errorf("failed to get current user: %w", err) - } - uid, err := parseStringId(user.Uid) - if err != nil { - return nil, fmt.Errorf("failed to parse current user uid: %w", err) - } // Try to initialize procfs, but don't fail if it's not available (e.g., on macOS) var connectionTracker ConnectionTracker @@ -124,15 +114,13 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S } return &RunExecutor{ - tempDir: tempDir, - homeDir: homeDir, - dstackDir: dstackDir, + tempDir: tempDir, + dstackDir: dstackDir, + currentUser: currentUser, + sshd: sshd, + fileArchiveDir: filepath.Join(tempDir, "file_archives"), repoBlobDir: filepath.Join(tempDir, "repo_blobs"), - sshd: sshd, - currentUid: uid, - jobUid: -1, - jobGid: -1, mu: mu, state: WaitSubmit, @@ -188,29 +176,41 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String()) - if ex.jobSpec.User == nil { - ex.jobSpec.User = &schemas.User{Uid: &ex.currentUid} - } - if err := fillUser(ex.jobSpec.User); err != nil { + if err := ex.setJobUser(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, types.JobStateFailed, types.TerminationReasonExecutorError, - fmt.Sprintf("Failed to fill in the job user fields (%s)", err), + fmt.Sprintf("Failed to set job user (%s)", err), ) - 
return fmt.Errorf("fill user: %w", err) + return fmt.Errorf("set job user: %w", err) } - ex.setJobCredentials(ctx) + // setJobUser sets User.HomeDir to "/" if the original home dir is not set or not accessible, + // in that case we skip home dir provisioning + if ex.jobUser.HomeDir == "/" { + log.Info(ctx, "Skipping home dir provisioning") + } else { + // All home dir-related errors are considered non-fatal + cleanupGitCredentials, err := ex.setupGitCredentials(ctx) + if err != nil { + log.Error(ctx, "Failed to set up Git credentials", "err", err) + } else { + defer cleanupGitCredentials() + } + if err := ex.setupClusterSsh(ctx); err != nil { + log.Error(ctx, "Failed to set up cluster SSH", "err", err) + } + } if err := ex.setJobWorkingDir(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, types.JobStateFailed, types.TerminationReasonExecutorError, - fmt.Sprintf("Failed to set up the working dir (%s)", err), + fmt.Sprintf("Failed to set job working dir (%s)", err), ) - return fmt.Errorf("prepare job working dir: %w", err) + return fmt.Errorf("set job working dir: %w", err) } if err := ex.setupRepo(ctx); err != nil { @@ -233,13 +233,6 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { return fmt.Errorf("setup files: %w", err) } - cleanupCredentials, err := ex.setupCredentials(ctx) - if err != nil { - ex.SetJobState(ctx, types.JobStateFailed) - return fmt.Errorf("setup credentials: %w", err) - } - defer cleanupCredentials() - connectionTrackerTicker := time.NewTicker(2500 * time.Millisecond) go ex.connectionTracker.Track(connectionTrackerTicker.C) defer ex.connectionTracker.Stop() @@ -339,21 +332,7 @@ func (ex *RunExecutor) SetRunnerState(state string) { ex.state = state } -func (ex *RunExecutor) setJobCredentials(ctx context.Context) { - if ex.jobSpec.User.Uid != nil { - ex.jobUid = int(*ex.jobSpec.User.Uid) - } - if ex.jobSpec.User.Gid != nil { - ex.jobGid = int(*ex.jobSpec.User.Gid) - } - if ex.jobSpec.User.HomeDir != "" { - ex.jobHomeDir = 
ex.jobSpec.User.HomeDir - } else { - ex.jobHomeDir = "/" - } - log.Trace(ctx, "Job credentials", "uid", ex.jobUid, "gid", ex.jobGid, "home", ex.jobHomeDir) -} - +// setJobWorkingDir must be called from Run after setJobUser func (ex *RunExecutor) setJobWorkingDir(ctx context.Context) error { var err error if ex.jobSpec.WorkingDir == nil { @@ -362,18 +341,73 @@ func (ex *RunExecutor) setJobWorkingDir(ctx context.Context) error { return fmt.Errorf("get working directory: %w", err) } } else { - ex.jobWorkingDir, err = common.ExpandPath(*ex.jobSpec.WorkingDir, "", ex.jobHomeDir) + ex.jobWorkingDir, err = common.ExpandPath(*ex.jobSpec.WorkingDir, "", ex.jobUser.HomeDir) if err != nil { return fmt.Errorf("expand working dir path: %w", err) } if !path.IsAbs(ex.jobWorkingDir) { - return fmt.Errorf("working_dir must be absolute: %s", ex.jobWorkingDir) + return fmt.Errorf("working dir must be absolute: %s", ex.jobWorkingDir) } } log.Trace(ctx, "Job working dir", "path", ex.jobWorkingDir) return nil } +// setupClusterSsh provisions the job user's ~/.ssh for inter-node SSH (private key, per-node Host entries, authorized key); it is a no-op unless an SSH key is set and there are at least two job IPs. It must be called from Run after setJobUser +func (ex *RunExecutor) setupClusterSsh(ctx context.Context) error { + if ex.jobSpec.SSHKey == nil || len(ex.clusterInfo.JobIPs) < 2 { + return nil + } + + sshDir, err := prepareUserSshDir(ex.jobUser) + if err != nil { + return fmt.Errorf("prepare user ssh dir: %w", err) + } + + privatePath := filepath.Join(sshDir, "dstack_job") + privateFile, err := os.OpenFile(privatePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600) + if err != nil { + return fmt.Errorf("open private key file: %w", err) + } + defer privateFile.Close() + if err := os.Chown(privatePath, ex.jobUser.Uid, ex.jobUser.Gid); err != nil { + return fmt.Errorf("chown private key: %w", err) + } + if _, err := privateFile.WriteString(ex.jobSpec.SSHKey.Private); err != nil { + return fmt.Errorf("write private key: %w", err) + } + + // TODO: move job hosts config to ~/.dstack/ssh/config.d/current_job.conf + // and add "Include ~/.dstack/ssh/config.d/*.conf" directive
to ~/.ssh/config if not present + // instead of appending job hosts config directly (don't bloat user's ssh_config) + configPath := filepath.Join(sshDir, "config") + configFile, err := os.OpenFile(configPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600) + if err != nil { + return fmt.Errorf("open SSH config: %w", err) + } + defer configFile.Close() + if err := os.Chown(configPath, ex.jobUser.Uid, ex.jobUser.Gid); err != nil { + return fmt.Errorf("chown SSH config: %w", err) + } + configBuffer := new(bytes.Buffer) + for _, ip := range ex.clusterInfo.JobIPs { + fmt.Fprintf(configBuffer, "\nHost %s\n", ip) + fmt.Fprintf(configBuffer, " Port %d\n", ex.sshd.Port()) + configBuffer.WriteString(" StrictHostKeyChecking no\n") + configBuffer.WriteString(" UserKnownHostsFile /dev/null\n") + fmt.Fprintf(configBuffer, " IdentityFile %s\n", privatePath) + } + if _, err := configFile.Write(configBuffer.Bytes()); err != nil { + return fmt.Errorf("write SSH config: %w", err) + } + + if err := ex.sshd.AddAuthorizedKeys(ctx, ex.jobSpec.SSHKey.Public); err != nil { + return fmt.Errorf("add authorized key: %w", err) + } + + return nil +} + func (ex *RunExecutor) getRepoData() schemas.RepoData { if ex.jobSpec.RepoData == nil { // jobs submitted before 0.19.17 do not have jobSpec.RepoData @@ -425,33 +459,26 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error } cmd.WaitDelay = ex.killDelay // kills the process if it doesn't exit in time - if err := common.MkdirAll(ctx, ex.jobWorkingDir, ex.jobUid, ex.jobGid); err != nil { + if err := common.MkdirAll(ctx, ex.jobWorkingDir, ex.jobUser.Uid, ex.jobUser.Gid, 0o755); err != nil { return fmt.Errorf("create working directory: %w", err) } cmd.Dir = ex.jobWorkingDir - // User must be already set - user := ex.jobSpec.User // Strictly speaking, we need CAP_SETUID and CAP_GUID (for Cmd.Start()-> // Cmd.SysProcAttr.Credential) and CAP_CHOWN (for startCommand()->os.Chown()), // but for the sake of simplicity we instead 
check if we are root or not - if ex.currentUid == 0 { - log.Trace( - ctx, "Using credentials", - "uid", *user.Uid, "gid", *user.Gid, "groups", user.GroupIds, - "username", user.GetUsername(), "groupname", user.GetGroupname(), - "home", user.HomeDir, - ) + if ex.currentUser.IsRoot() { + log.Trace(ctx, "Using credentials", "user", ex.jobUser) if cmd.SysProcAttr == nil { cmd.SysProcAttr = &syscall.SysProcAttr{} } - cmd.SysProcAttr.Credential = &syscall.Credential{ - Uid: *user.Uid, - Gid: *user.Gid, - Groups: user.GroupIds, + creds, err := ex.jobUser.ProcessCredentials() + if err != nil { + return fmt.Errorf("prepare process credentials: %w", err) } + cmd.SysProcAttr.Credential = creds } else { - log.Info(ctx, "Current user is not root, cannot set process credentials", "uid", ex.currentUid) + log.Info(ctx, "Current user is not root, cannot set process credentials", "user", ex.currentUser) } envMap := NewEnvMap(ParseEnvList(os.Environ()), jobEnvs, ex.secrets) @@ -466,54 +493,11 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error log.Warning(ctx, "failed to include dstack_profile", "path", profilePath, "err", err) } - // As of 2024-11-29, ex.homeDir is always set to /root - if _, err := prepareSSHDir(-1, -1, ex.homeDir); err != nil { - log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err) - } - userSSHDir := "" - uid := -1 - gid := -1 - if user != nil && *user.Uid != 0 { - // non-root user - uid = int(*user.Uid) - gid = int(*user.Gid) - homeDir, isHomeDirAccessible := prepareHomeDir(ctx, uid, gid, user.HomeDir) - envMap["HOME"] = homeDir - if isHomeDirAccessible { - log.Trace(ctx, "provisioning homeDir", "path", homeDir) - userSSHDir, err = prepareSSHDir(uid, gid, homeDir) - if err != nil { - log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err) - } - } else { - log.Trace(ctx, "homeDir is not accessible, skipping provisioning", "path", homeDir) - } - } else { - // root user - envMap["HOME"] = 
ex.homeDir - userSSHDir = filepath.Join(ex.homeDir, ".ssh") - } - - if ex.jobSpec.SSHKey != nil && userSSHDir != "" { - err := configureSSH( - ex.jobSpec.SSHKey.Private, ex.clusterInfo.JobIPs, ex.sshd.Port(), - uid, gid, userSSHDir, - ) - if err == nil { - err = ex.sshd.AddAuthorizedKeys(ctx, ex.jobSpec.SSHKey.Public) - } - if err != nil { - log.Warning(ctx, "failed to configure SSH", "err", err) - } - } - err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpusPerNodeNum, mpiHostfilePath) if err != nil { return fmt.Errorf("write MPI hostfile: %w", err) } - cmd.Env = envMap.Render() - // Configure process resource limits // TODO: Make rlimits customizable in the run configuration. Currently, we only set max locked memory // to unlimited to fix the issue with InfiniBand/RDMA: "Cannot allocate memory". @@ -529,6 +513,10 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error log.Error(ctx, "Failed to set resource limits", "err", err) } + // HOME must be added after writeDstackProfile to avoid overriding the correct per-user value set by sshd + envMap["HOME"] = ex.jobUser.HomeDir + cmd.Env = envMap.Render() + log.Trace(ctx, "Starting exec", "cmd", cmd.String(), "working_dir", cmd.Dir, "env", cmd.Env) ptm, err := startCommand(cmd) @@ -551,26 +539,32 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error return nil } -func (ex *RunExecutor) setupCredentials(ctx context.Context) (func(), error) { +// setupGitCredentials must be called from Run after setJobUser +func (ex *RunExecutor) setupGitCredentials(ctx context.Context) (func(), error) { if ex.repoCredentials == nil { return func() {}, nil } + switch ex.repoCredentials.GetProtocol() { case "ssh": if ex.repoCredentials.PrivateKey == nil { return nil, fmt.Errorf("private key is missing") } - keyPath := filepath.Join(ex.homeDir, ".ssh/id_rsa") + sshDir, err := prepareUserSshDir(ex.jobUser) + if err != nil { + return nil, fmt.Errorf("prepare user ssh dir: %w", err) + 
} + keyPath := filepath.Join(sshDir, "id_rsa") if _, err := os.Stat(keyPath); err == nil { return nil, fmt.Errorf("private key already exists") } - if err := os.MkdirAll(filepath.Dir(keyPath), 0o700); err != nil { - return nil, fmt.Errorf("create ssh directory: %w", err) - } log.Info(ctx, "Writing private key", "path", keyPath) if err := os.WriteFile(keyPath, []byte(*ex.repoCredentials.PrivateKey), 0o600); err != nil { return nil, fmt.Errorf("write private key: %w", err) } + if err := os.Chown(keyPath, ex.jobUser.Uid, ex.jobUser.Gid); err != nil { + return nil, fmt.Errorf("chown private key: %w", err) + } return func() { log.Info(ctx, "Removing private key", "path", keyPath) _ = os.Remove(keyPath) @@ -579,11 +573,11 @@ func (ex *RunExecutor) setupCredentials(ctx context.Context) (func(), error) { if ex.repoCredentials.OAuthToken == nil { return func() {}, nil } - hostsPath := filepath.Join(ex.homeDir, ".config/gh/hosts.yml") + hostsPath := filepath.Join(ex.jobUser.HomeDir, ".config/gh/hosts.yml") if _, err := os.Stat(hostsPath); err == nil { return nil, fmt.Errorf("hosts.yml file already exists") } - if err := os.MkdirAll(filepath.Dir(hostsPath), 0o700); err != nil { + if err := common.MkdirAll(ctx, filepath.Dir(hostsPath), ex.jobUser.Uid, ex.jobUser.Gid, 0o700); err != nil { return nil, fmt.Errorf("create gh config directory: %w", err) } log.Info(ctx, "Writing OAuth token", "path", hostsPath) @@ -595,6 +589,9 @@ func (ex *RunExecutor) setupCredentials(ctx context.Context) (func(), error) { if err := os.WriteFile(hostsPath, []byte(ghHost), 0o600); err != nil { return nil, fmt.Errorf("write OAuth token: %w", err) } + if err := os.Chown(hostsPath, ex.jobUser.Uid, ex.jobUser.Gid); err != nil { + return nil, fmt.Errorf("chown OAuth token: %w", err) + } return func() { log.Info(ctx, "Removing OAuth token", "path", hostsPath) _ = os.Remove(hostsPath) @@ -643,104 +640,6 @@ func buildLDLibraryPathEnv(ctx context.Context) (string, error) { return currentLDPath, nil } -// 
fillUser fills missing User fields -// Since normally only one kind of identifier is set (either id or name), we don't check -// (id, name) pair consistency -- id has higher priority and overwites name with a real -// name, ignoring the already set name value (if any) -// HomeDir and SupplementaryGroupIds are always set unconditionally, as they are not -// provided by the dstack server -func fillUser(user *schemas.User) error { - if user.Uid == nil && user.Username == nil { - return errors.New("neither Uid nor Username is set") - } - - if user.Gid == nil && user.Groupname != nil { - osGroup, err := osuser.LookupGroup(*user.Groupname) - if err != nil { - return fmt.Errorf("failed to look up group by Groupname: %w", err) - } - gid, err := parseStringId(osGroup.Gid) - if err != nil { - return fmt.Errorf("failed to parse group Gid: %w", err) - } - user.Gid = &gid - } - - var osUser *osuser.User - - if user.Uid == nil { - var err error - osUser, err = osuser.Lookup(*user.Username) - if err != nil { - return fmt.Errorf("failed to look up user by Username: %w", err) - } - uid, err := parseStringId(osUser.Uid) - if err != nil { - return fmt.Errorf("failed to parse Uid: %w", err) - } - user.Uid = &uid - } else { - var err error - osUser, err = osuser.LookupId(strconv.Itoa(int(*user.Uid))) - if err != nil { - var notFoundErr osuser.UnknownUserIdError - if !errors.As(err, ¬FoundErr) { - return fmt.Errorf("failed to look up user by Uid: %w", err) - } - } - } - - if osUser != nil { - user.Username = &osUser.Username - user.HomeDir = osUser.HomeDir - } else { - user.Username = nil - user.HomeDir = "" - } - - // If Gid is not set, either directly or via Groupname, use user's primary group - // and supplementary groups, see https://docs.docker.com/reference/dockerfile/#user - // If user doesn't exist, set Gid to 0 and supplementary groups to an empty list - if user.Gid == nil { - if osUser != nil { - gid, err := parseStringId(osUser.Gid) - if err != nil { - return 
fmt.Errorf("failed to parse primary Gid: %w", err) - } - user.Gid = &gid - groupStringIds, err := osUser.GroupIds() - if err != nil { - return fmt.Errorf("failed to get supplementary groups: %w", err) - } - var groupIds []uint32 - for _, groupStringId := range groupStringIds { - groupId, err := parseStringId(groupStringId) - if err != nil { - return fmt.Errorf("failed to parse supplementary group id: %w", err) - } - groupIds = append(groupIds, groupId) - } - user.GroupIds = groupIds - } else { - var fallbackGid uint32 = 0 - user.Gid = &fallbackGid - user.GroupIds = []uint32{} - } - } - return nil -} - -func parseStringId(stringId string) (uint32, error) { - id, err := strconv.ParseInt(stringId, 10, 32) - if err != nil { - return 0, err - } - if id < 0 { - return 0, fmt.Errorf("negative id value: %d", id) - } - return uint32(id), nil -} - // A simplified copypasta of creack/pty Start->StartWithSize->StartWithAttrs // with two additions: // * controlling terminal is properly set (cmd.Extrafiles, Cmd.SysProcAttr.Ctty) @@ -784,55 +683,24 @@ func startCommand(cmd *exec.Cmd) (*os.File, error) { return ptm, nil } -func prepareHomeDir(ctx context.Context, uid int, gid int, homeDir string) (string, bool) { - if homeDir == "" { - // user does not exist - return "/", false - } - if info, err := os.Stat(homeDir); errors.Is(err, os.ErrNotExist) { - if strings.Contains(homeDir, "nonexistent") { - // let `/nonexistent` stay non-existent - return homeDir, false - } - if err = os.MkdirAll(homeDir, 0o755); err != nil { - log.Warning(ctx, "failed to create homeDir", "err", err) - return homeDir, false - } - if err = os.Chmod(homeDir, 0o750); err != nil { - log.Warning(ctx, "failed to chmod homeDir", "err", err) - } - if err = os.Chown(homeDir, uid, gid); err != nil { - log.Warning(ctx, "failed to chown homeDir", "err", err) - } - return homeDir, true - } else if err != nil { - log.Warning(ctx, "homeDir is not accessible", "err", err) - return homeDir, false - } else if !info.IsDir() 
{ - log.Warning(ctx, "HomeDir is not a dir", "path", homeDir) - return homeDir, false - } - return homeDir, true -} - -func prepareSSHDir(uid int, gid int, homeDir string) (string, error) { - sshDir := filepath.Join(homeDir, ".ssh") +func prepareUserSshDir(user *linuxuser.User) (string, error) { + sshDir := filepath.Join(user.HomeDir, ".ssh") info, err := os.Stat(sshDir) if err == nil { if !info.IsDir() { return "", fmt.Errorf("not a directory: %s", sshDir) } - if err = os.Chmod(sshDir, 0o700); err != nil { + if err := os.Chmod(sshDir, 0o700); err != nil { return "", fmt.Errorf("chmod ssh dir: %w", err) } } else if errors.Is(err, os.ErrNotExist) { - if err = os.MkdirAll(sshDir, 0o700); err != nil { + if err := os.MkdirAll(sshDir, 0o700); err != nil { return "", fmt.Errorf("create ssh dir: %w", err) } } else { return "", err } - if err = os.Chown(sshDir, uid, gid); err != nil { + if err := os.Chown(sshDir, user.Uid, user.Gid); err != nil { return "", fmt.Errorf("chown ssh dir: %w", err) } return sshDir, nil @@ -915,43 +783,3 @@ func includeDstackProfile(profilePath string, dstackProfilePath string) error { } return nil } - -func configureSSH(private string, ips []string, port int, uid int, gid int, sshDir string) error { - privatePath := filepath.Join(sshDir, "dstack_job") - privateFile, err := os.OpenFile(privatePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600) - if err != nil { - return fmt.Errorf("open private key file: %w", err) - } - defer privateFile.Close() - if err := os.Chown(privatePath, uid, gid); err != nil { - return fmt.Errorf("chown private key: %w", err) - } - if _, err := privateFile.WriteString(private); err != nil { - return fmt.Errorf("write private key: %w", err) - } - - // TODO: move job hosts config to ~/.dstack/ssh/config.d/current_job - // and add "Include ~/.dstack/ssh/config.d/*" directive to ~/.ssh/config if not present - // instead of appending job hosts config directly (don't bloat user's ssh_config) - configPath := filepath.Join(sshDir, 
"config") - configFile, err := os.OpenFile(configPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600) - if err != nil { - return fmt.Errorf("open SSH config: %w", err) - } - defer configFile.Close() - if err := os.Chown(configPath, uid, gid); err != nil { - return fmt.Errorf("chown SSH config: %w", err) - } - var configBuffer bytes.Buffer - for _, ip := range ips { - configBuffer.WriteString(fmt.Sprintf("\nHost %s\n", ip)) - configBuffer.WriteString(fmt.Sprintf(" Port %d\n", port)) - configBuffer.WriteString(" StrictHostKeyChecking no\n") - configBuffer.WriteString(" UserKnownHostsFile /dev/null\n") - configBuffer.WriteString(fmt.Sprintf(" IdentityFile %s\n", privatePath)) - } - if _, err := configFile.Write(configBuffer.Bytes()); err != nil { - return fmt.Errorf("write SSH config: %w", err) - } - return nil -} diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index 0d935dd64..105493e30 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" "github.com/dstackai/dstack/runner/internal/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -63,7 +64,7 @@ func TestExecutor_HomeDir(t *testing.T) { err := ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) + assert.Equal(t, ex.currentUser.HomeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_NonZeroExit(t *testing.T) { @@ -90,7 +91,7 @@ func TestExecutor_SSHCredentials(t *testing.T) { PrivateKey: &key, } - clean, err := ex.setupCredentials(t.Context()) + clean, err := ex.setupGitCredentials(t.Context()) defer clean() require.NoError(t, err) @@ -206,14 +207,23 @@ func makeTestExecutor(t *testing.T) *RunExecutor { tempDir := filepath.Join(baseDir, "temp") require.NoError(t, 
os.Mkdir(tempDir, 0o700)) - homeDir := filepath.Join(baseDir, "home") - require.NoError(t, os.Mkdir(homeDir, 0o700)) + dstackDir := filepath.Join(baseDir, "dstack") require.NoError(t, os.Mkdir(dstackDir, 0o755)) - ex, err := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock)) + + currentUser, err := linuxuser.FromCurrentProcess() + require.NoError(t, err) + homeDir := filepath.Join(baseDir, "home") + require.NoError(t, os.Mkdir(homeDir, 0o700)) + currentUser.HomeDir = homeDir + + ex, err := NewRunExecutor(tempDir, dstackDir, *currentUser, new(sshdMock)) require.NoError(t, err) + ex.SetJob(body) + require.NoError(t, ex.setJobUser(t.Context())) require.NoError(t, ex.setJobWorkingDir(t.Context())) + return ex } diff --git a/runner/internal/executor/files.go b/runner/internal/executor/files.go index ee1170c41..6b992ce2c 100644 --- a/runner/internal/executor/files.go +++ b/runner/internal/executor/files.go @@ -34,19 +34,22 @@ func (ex *RunExecutor) WriteFileArchive(id string, src io.Reader) error { return nil } -// setupFiles must be called from Run -// Must be called after setJobWorkingDir and setJobCredentials +// setupFiles must be called from Run after setJobUser and setJobWorkingDir func (ex *RunExecutor) setupFiles(ctx context.Context) error { log.Trace(ctx, "Setting up files") if ex.jobWorkingDir == "" { - return errors.New("setup files: working dir is not set") + return errors.New("working dir is not set") } if !filepath.IsAbs(ex.jobWorkingDir) { - return fmt.Errorf("setup files: working dir must be absolute: %s", ex.jobWorkingDir) + return fmt.Errorf("working dir must be absolute: %s", ex.jobWorkingDir) } for _, fa := range ex.jobSpec.FileArchives { archivePath := path.Join(ex.fileArchiveDir, fa.Id) - if err := extractFileArchive(ctx, archivePath, fa.Path, ex.jobWorkingDir, ex.jobUid, ex.jobGid, ex.jobHomeDir); err != nil { + err := extractFileArchive( + ctx, archivePath, fa.Path, ex.jobWorkingDir, ex.jobUser.HomeDir, + ex.jobUser.Uid, ex.jobUser.Gid, + 
) + if err != nil { return fmt.Errorf("extract file archive %s: %w", fa.Id, err) } } @@ -56,7 +59,7 @@ func (ex *RunExecutor) setupFiles(ctx context.Context) error { return nil } -func extractFileArchive(ctx context.Context, archivePath string, destPath string, baseDir string, uid int, gid int, homeDir string) error { +func extractFileArchive(ctx context.Context, archivePath string, destPath string, baseDir string, homeDir string, uid int, gid int) error { log.Trace(ctx, "Extracting file archive", "archive", archivePath, "dest", destPath, "base", baseDir, "home", homeDir) destPath, err := common.ExpandPath(destPath, baseDir, homeDir) @@ -64,7 +67,7 @@ func extractFileArchive(ctx context.Context, archivePath string, destPath string return fmt.Errorf("expand destination path: %w", err) } destBase, destName := path.Split(destPath) - if err := common.MkdirAll(ctx, destBase, uid, gid); err != nil { + if err := common.MkdirAll(ctx, destBase, uid, gid, 0o755); err != nil { return fmt.Errorf("create destination directory: %w", err) } if err := os.RemoveAll(destPath); err != nil { @@ -88,11 +91,9 @@ func extractFileArchive(ctx context.Context, archivePath string, destPath string return fmt.Errorf("extract tar archive: %w", err) } - if uid != -1 || gid != -1 { - for _, p := range paths { - if err := os.Chown(path.Join(destBase, p), uid, gid); err != nil { - log.Warning(ctx, "Failed to chown", "path", p, "err", err) - } + for _, p := range paths { + if err := os.Chown(path.Join(destBase, p), uid, gid); err != nil { + log.Warning(ctx, "Failed to chown", "path", p, "err", err) } } diff --git a/runner/internal/executor/repo.go b/runner/internal/executor/repo.go index 2f757f63c..467c783a8 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/executor/repo.go @@ -36,22 +36,21 @@ func (ex *RunExecutor) WriteRepoBlob(src io.Reader) error { return nil } -// setupRepo must be called from Run -// Must be called after setJobWorkingDir and setJobCredentials +// setupRepo 
must be called from Run after setJobUser and setJobWorkingDir func (ex *RunExecutor) setupRepo(ctx context.Context) error { log.Trace(ctx, "Setting up repo") if ex.jobWorkingDir == "" { - return errors.New("setup repo: working dir is not set") + return errors.New("working dir is not set") } if !filepath.IsAbs(ex.jobWorkingDir) { - return fmt.Errorf("setup repo: working dir must be absolute: %s", ex.jobWorkingDir) + return fmt.Errorf("working dir must be absolute: %s", ex.jobWorkingDir) } if ex.jobSpec.RepoDir == nil { - return errors.New("repo_dir is not set") + return errors.New("repo dir is not set") } var err error - ex.repoDir, err = common.ExpandPath(*ex.jobSpec.RepoDir, ex.jobWorkingDir, ex.jobHomeDir) + ex.repoDir, err = common.ExpandPath(*ex.jobSpec.RepoDir, ex.jobWorkingDir, ex.jobUser.HomeDir) if err != nil { return fmt.Errorf("expand repo dir path: %w", err) } @@ -71,12 +70,12 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { } switch repoExistsAction { case schemas.RepoExistsActionError: - return fmt.Errorf("setup repo: repo dir is not empty: %s", ex.repoDir) + return fmt.Errorf("repo dir is not empty: %s", ex.repoDir) case schemas.RepoExistsActionSkip: log.Info(ctx, "Skipping repo checkout: repo dir is not empty", "path", ex.repoDir) return nil default: - return fmt.Errorf("setup repo: unsupported action: %s", repoExistsAction) + return fmt.Errorf("unsupported action: %s", repoExistsAction) } } @@ -237,9 +236,6 @@ func (ex *RunExecutor) restoreRepoDir(ctx context.Context, tmpDir string) error func (ex *RunExecutor) chownRepoDir(ctx context.Context) error { log.Trace(ctx, "Chowning repo dir") - if ex.jobUid == -1 && ex.jobGid == -1 { - return nil - } return filepath.WalkDir( ex.repoDir, func(p string, d fs.DirEntry, err error) error { @@ -248,7 +244,7 @@ func (ex *RunExecutor) chownRepoDir(ctx context.Context) error { log.Debug(ctx, "Error while walking repo dir", "path", p, "err", err) return nil } - if err := os.Chown(p, ex.jobUid, 
ex.jobGid); err != nil { + if err := os.Chown(p, ex.jobUser.Uid, ex.jobUser.Gid); err != nil { log.Debug(ctx, "Error while chowning repo dir", "path", p, "err", err) } return nil diff --git a/runner/internal/executor/user.go b/runner/internal/executor/user.go new file mode 100644 index 000000000..30affda61 --- /dev/null +++ b/runner/internal/executor/user.go @@ -0,0 +1,184 @@ +package executor + +import ( + "context" + "errors" + "fmt" + "os" + osuser "os/user" + "path" + "strconv" + "strings" + + linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" + "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/schemas" +) + +func (ex *RunExecutor) setJobUser(ctx context.Context) error { + if ex.jobSpec.User == nil { + // JobSpec.User is nil if the user is not specified either in the dstack configuration + // (the `user` property) or in the image (the `USER` Dockerfile instruction). + // In such cases, the root user should be used as a fallback, and we use the current user, + // assuming that the runner is started by root. 
+ ex.jobUser = &ex.currentUser + } else { + jobUser, err := jobUserFromJobSpecUser( + ex.jobSpec.User, + osuser.LookupId, osuser.Lookup, + osuser.LookupGroup, (*osuser.User).GroupIds, + ) + if err != nil { + return fmt.Errorf("job user from job spec: %w", err) + } + ex.jobUser = jobUser + } + + if err := checkHomeDir(ex.jobUser.HomeDir); err != nil { + log.Warning(ctx, "Error while checking job user home dir, using / instead", "err", err) + ex.jobUser.HomeDir = "/" + } + + log.Trace(ctx, "Job user", "user", ex.jobUser) + return nil +} + +func jobUserFromJobSpecUser( + jobSpecUser *schemas.User, + userLookupIdFunc func(string) (*osuser.User, error), + userLookupNameFunc func(string) (*osuser.User, error), + groupLookupNameFunc func(string) (*osuser.Group, error), + userGroupIdsFunc func(*osuser.User) ([]string, error), +) (*linuxuser.User, error) { + if jobSpecUser.Uid == nil && jobSpecUser.Username == nil { + return nil, errors.New("neither uid nor username is set") + } + + var err error + var osUser *osuser.User + + // -1 is a placeholder value, the actual value must be >= 0 + //nolint:ineffassign + uid := -1 + if jobSpecUser.Uid != nil { + uid = int(*jobSpecUser.Uid) + osUser, err = userLookupIdFunc(strconv.Itoa(uid)) + if err != nil { + var notFoundErr osuser.UnknownUserIdError + if !errors.As(err, &notFoundErr) { + return nil, fmt.Errorf("lookup user by id: %w", err) + } + } + } else { + osUser, err = userLookupNameFunc(*jobSpecUser.Username) + if err != nil { + return nil, fmt.Errorf("lookup user by name: %w", err) + } + uid, err = parseStringId(osUser.Uid) + if err != nil { + return nil, fmt.Errorf("parse user id: %w", err) + } + } + if uid == -1 { + // Assertion, should never occur + return nil, errors.New("failed to infer user id") + } + + // -1 is a placeholder value, the actual value must be >= 0 + //nolint:ineffassign + gid := -1 + // Must include at least one gid, see len(gids) == 0 assertion below + var gids []int + if jobSpecUser.Gid != nil { + gid =
int(*jobSpecUser.Gid) + // Here and below: + // > Note that when specifying a group for the user, the user will have + // > only the specified group membership. + // > Any other configured group memberships will be ignored. + // See: https://docs.docker.com/reference/dockerfile/#user + gids = []int{gid} + } else if jobSpecUser.Groupname != nil { + osGroup, err := groupLookupNameFunc(*jobSpecUser.Groupname) + if err != nil { + return nil, fmt.Errorf("lookup group by name: %w", err) + } + gid, err = parseStringId(osGroup.Gid) + if err != nil { + return nil, fmt.Errorf("parse group id: %w", err) + } + gids = []int{gid} + } else if osUser != nil { + gid, err = parseStringId(osUser.Gid) + if err != nil { + return nil, fmt.Errorf("parse group id: %w", err) + } + rawGids, err := userGroupIdsFunc(osUser) + if err != nil { + return nil, fmt.Errorf("get user supplementary group ids: %w", err) + } + // [main_gid, supplementary_gid_1, supplementary_gid_2, ...] + gids = make([]int, len(rawGids)+1) + gids[0] = gid + for index, rawGid := range rawGids { + supplementaryGid, err := parseStringId(rawGid) + if err != nil { + return nil, fmt.Errorf("parse supplementary group id: %w", err) + } + gids[index+1] = supplementaryGid + } + } else { + // > When the user doesn't have a primary group then the image + // > (or the next instructions) will be run with the root group. 
+ // See: https://docs.docker.com/reference/dockerfile/#user + gid = 0 + gids = []int{gid} + } + if gid == -1 { + // Assertion, should never occur + return nil, errors.New("failed to infer group id") + } + if len(gids) == 0 { + // Assertion, should never occur + return nil, errors.New("failed to infer supplementary group ids") + } + + username := "" + homeDir := "" + if osUser != nil { + username = osUser.Username + homeDir = osUser.HomeDir + } + + return linuxuser.NewUser(uid, gid, gids, username, homeDir), nil +} + +func parseStringId(stringId string) (int, error) { + id, err := strconv.Atoi(stringId) + if err != nil { + return 0, err + } + if id < 0 { + return 0, fmt.Errorf("negative id value: %d", id) + } + return id, nil +} + +func checkHomeDir(homeDir string) error { + if homeDir == "" { + return errors.New("not set") + } + if !path.IsAbs(homeDir) { + return fmt.Errorf("must be absolute: %s", homeDir) + } + if info, err := os.Stat(homeDir); errors.Is(err, os.ErrNotExist) { + if strings.Contains(homeDir, "nonexistent") { + // let `/nonexistent` stay non-existent + return fmt.Errorf("non-existent: %s", homeDir) + } + } else if err != nil { + return err + } else if !info.IsDir() { + return fmt.Errorf("not a directory: %s", homeDir) + } + return nil +} diff --git a/runner/internal/executor/user_test.go b/runner/internal/executor/user_test.go new file mode 100644 index 000000000..2bc6a19d8 --- /dev/null +++ b/runner/internal/executor/user_test.go @@ -0,0 +1,232 @@ +package executor + +import ( + "errors" + osuser "os/user" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" + "github.com/dstackai/dstack/runner/internal/schemas" +) + +var shouldNotBeCalledErr = errors.New("this function should not be called") + +func unknownUserIdError(t *testing.T, strUid string) osuser.UnknownUserIdError { + t.Helper() + uid, err := strconv.Atoi(strUid) + require.NoError(t, err) + return 
osuser.UnknownUserIdError(uid) +} + +func TestJobUserFromJobSpecUser_Uid_UserDoesNotExist(t *testing.T) { + specUid := uint32(2000) + specUser := schemas.User{Uid: &specUid} + expectedUser := linuxuser.User{ + Uid: 2000, + Gid: 0, + Gids: []int{0}, + Username: "", + HomeDir: "", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return nil, unknownUserIdError(t, id) }, + func(name string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.Group, error) { return nil, shouldNotBeCalledErr }, + func(*osuser.User) ([]string, error) { return nil, shouldNotBeCalledErr }, + ) + + require.NoError(t, err) + require.Equal(t, expectedUser, *user) +} + +func TestJobUserFromJobSpecUser_Uid_Gid_UserDoesNotExist(t *testing.T) { + specUid := uint32(2000) + specGid := uint32(200) + specUser := schemas.User{Uid: &specUid, Gid: &specGid} + expectedUser := linuxuser.User{ + Uid: 2000, + Gid: 200, + Gids: []int{200}, + Username: "", + HomeDir: "", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return nil, unknownUserIdError(t, id) }, + func(name string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.Group, error) { return nil, shouldNotBeCalledErr }, + func(*osuser.User) ([]string, error) { return nil, shouldNotBeCalledErr }, + ) + + require.NoError(t, err) + require.Equal(t, expectedUser, *user) +} + +func TestJobUserFromJobSpecUser_Uid_UserExists(t *testing.T) { + specUid := uint32(2000) + specUser := schemas.User{Uid: &specUid} + osUser := osuser.User{ + Uid: "2000", + Gid: "300", + Username: "testuser", + HomeDir: "/home/testuser", + } + osUserGids := []string{"300", "400", "500"} + expectedUser := linuxuser.User{ + Uid: 2000, + Gid: 300, + Gids: []int{300, 400, 500}, + Username: "testuser", + HomeDir: "/home/testuser", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(uid string) 
(*osuser.User, error) { return &osUser, nil }, + func(name string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(gid string) (*osuser.Group, error) { return nil, shouldNotBeCalledErr }, + func(*osuser.User) ([]string, error) { return osUserGids, nil }, + ) + + require.NoError(t, err) + require.Equal(t, expectedUser, *user) +} + +func TestJobUserFromJobSpecUser_Uid_Gid_UserExists(t *testing.T) { + specUid := uint32(2000) + specGid := uint32(200) + specUser := schemas.User{Uid: &specUid, Gid: &specGid} + osUser := osuser.User{ + Uid: "2000", + Gid: "300", + Username: "testuser", + HomeDir: "/home/testuser", + } + expectedUser := linuxuser.User{ + Uid: 2000, + Gid: 200, + Gids: []int{200}, + Username: "testuser", + HomeDir: "/home/testuser", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return &osUser, nil }, + func(name string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.Group, error) { return nil, shouldNotBeCalledErr }, + func(*osuser.User) ([]string, error) { return nil, shouldNotBeCalledErr }, + ) + + require.NoError(t, err) + require.Equal(t, expectedUser, *user) +} + +func TestJobUserFromJobSpecUser_Username_UserDoesNotExist(t *testing.T) { + specUsername := "unknownuser" + specUser := schemas.User{Username: &specUsername} + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.User, error) { return nil, osuser.UnknownUserError(name) }, + func(name string) (*osuser.Group, error) { return nil, shouldNotBeCalledErr }, + func(*osuser.User) ([]string, error) { return nil, shouldNotBeCalledErr }, + ) + + require.ErrorContains(t, err, "lookup user by name") + require.Nil(t, user) +} + +func TestJobUserFromJobSpecUser_Username_UserExists(t *testing.T) { + specUsername := "testnuser" + specUser := schemas.User{Username: &specUsername} + osUser := 
osuser.User{ + Uid: "2000", + Gid: "300", + Username: "testuser", + HomeDir: "/home/testuser", + } + osUserGids := []string{"300", "400", "500"} + expectedUser := linuxuser.User{ + Uid: 2000, + Gid: 300, + Gids: []int{300, 400, 500}, + Username: "testuser", + HomeDir: "/home/testuser", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.User, error) { return &osUser, nil }, + func(name string) (*osuser.Group, error) { return nil, shouldNotBeCalledErr }, + func(*osuser.User) ([]string, error) { return osUserGids, nil }, + ) + + require.NoError(t, err) + require.Equal(t, expectedUser, *user) +} + +func TestJobUserFromJobSpecUser_Username_Groupname_UserExists_GroupExists(t *testing.T) { + specUsername := "testnuser" + specGroupname := "testgroup" + specUser := schemas.User{Username: &specUsername, Groupname: &specGroupname} + osUser := osuser.User{ + Uid: "2000", + Gid: "300", + Username: "testuser", + HomeDir: "/home/testuser", + } + osGroup := osuser.Group{ + Gid: "200", + Name: specGroupname, + } + expectedUser := linuxuser.User{ + Uid: 2000, + Gid: 200, + Gids: []int{200}, + Username: "testuser", + HomeDir: "/home/testuser", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.User, error) { return &osUser, nil }, + func(name string) (*osuser.Group, error) { return &osGroup, nil }, + func(*osuser.User) ([]string, error) { return nil, shouldNotBeCalledErr }, + ) + + require.NoError(t, err) + require.Equal(t, expectedUser, *user) +} + +func TestJobUserFromJobSpecUser_Username_Groupname_UserExists_GroupDoesNotExist(t *testing.T) { + specUsername := "testnuser" + specGroupname := "testgroup" + specUser := schemas.User{Username: &specUsername, Groupname: &specGroupname} + osUser := osuser.User{ + Uid: "2000", + Gid: "300", + Username: "testuser", + 
HomeDir: "/home/testuser", + } + + user, err := jobUserFromJobSpecUser( + &specUser, + func(id string) (*osuser.User, error) { return nil, shouldNotBeCalledErr }, + func(name string) (*osuser.User, error) { return &osUser, nil }, + func(name string) (*osuser.Group, error) { return nil, osuser.UnknownGroupError(name) }, + func(*osuser.User) ([]string, error) { return nil, shouldNotBeCalledErr }, + ) + + require.ErrorContains(t, err, "lookup group by name") + require.Nil(t, user) +} diff --git a/runner/internal/linux/user/user.go b/runner/internal/linux/user/user.go new file mode 100644 index 000000000..caecc1324 --- /dev/null +++ b/runner/internal/linux/user/user.go @@ -0,0 +1,96 @@ +// Although this package is located inside the linux package, it should work on any Unix-like system. +package user + +import ( + "fmt" + osuser "os/user" + "slices" + "strconv" + "syscall" +) + +// User represents the user part of process `credentials(7)` +// (real user ID, real group ID, supplementary group IDs) enriched with +// some info from the user database `passwd(5)` (login name, home dir). +// Note, unlike the User struct from os/user, User does not necessarily +// correspond to any existing user account, for example, any of IDs may not exist +// in passwd(5) or group(5) databases at all or the user may not belong to +// the primary group or any of the specified supplementary groups. +type User struct { + // Real user ID + Uid int + // Real group ID + Gid int + // Supplementary group IDs.
The primary group should always be included and + // the resulting list should be sorted in ascending order with duplicates removed; + // NewUser() performs such normalization + Gids []int + // May be empty, e.g., if the user does not exist + Username string + // May be empty, e.g., if the user does not exist + HomeDir string +} + +func (u *User) String() string { + // The format is inspired by `id(1)` + formattedUsername := "" + if u.Username != "" { + formattedUsername = fmt.Sprintf("(%s)", u.Username) + } + return fmt.Sprintf("uid=%d%s gid=%d groups=%v home=%s", u.Uid, formattedUsername, u.Gid, u.Gids, u.HomeDir) +} + +func (u *User) ProcessCredentials() (*syscall.Credential, error) { + if u.Uid < 0 { + return nil, fmt.Errorf("negative user id: %d", u.Uid) + } + if u.Gid < 0 { + return nil, fmt.Errorf("negative group id: %d", u.Gid) + } + groups := make([]uint32, len(u.Gids)) + for index, gid := range u.Gids { + if gid < 0 { + return nil, fmt.Errorf("negative supplementary group id: %d", gid) + } + groups[index] = uint32(gid) + } + creds := syscall.Credential{ + Uid: uint32(u.Uid), + Gid: uint32(u.Gid), + Groups: groups, + } + return &creds, nil +} + +func (u *User) IsRoot() bool { + return u.Uid == 0 +} + +func NewUser(uid int, gid int, gids []int, username string, homeDir string) *User { + normalizedGids := append([]int{gid}, gids...)
+ slices.Sort(normalizedGids) + normalizedGids = slices.Compact(normalizedGids) + return &User{ + Uid: uid, + Gid: gid, + Gids: normalizedGids, + Username: username, + HomeDir: homeDir, + } +} + +func FromCurrentProcess() (*User, error) { + uid := syscall.Getuid() + gid := syscall.Getgid() + gids, err := syscall.Getgroups() + if err != nil { + return nil, fmt.Errorf("get supplementary groups: %w", err) + } + username := "" + homeDir := "" + if osUser, err := osuser.LookupId(strconv.Itoa(uid)); err == nil { + username = osUser.Username + homeDir = osUser.HomeDir + } + return NewUser(uid, gid, gids, username, homeDir), nil +} diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 106bc61f8..152637dec 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -124,22 +124,6 @@ type User struct { Username *string `json:"username"` Gid *uint32 `json:"gid"` Groupname *string `json:"groupname"` - GroupIds []uint32 - HomeDir string -} - -func (u *User) GetUsername() string { - if u.Username == nil { - return "" - } - return *u.Username -} - -func (u *User) GetGroupname() string { - if u.Groupname == nil { - return "" - } - return *u.Groupname } type HealthcheckResponse struct { diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 7e29e92dd..1fd8d959a 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -927,8 +927,6 @@ func getSSHShellCommands() []string { `unset LD_LIBRARY_PATH && unset LD_PRELOAD`, // common functions `exists() { command -v "$1" > /dev/null 2>&1; }`, - // TODO(#1535): support non-root images properly - "mkdir -p /root && chown root:root /root && export HOME=/root", // package manager detection/abstraction `install_pkg() { NAME=Distribution; test -f /etc/os-release && . 
/etc/os-release; echo $NAME not supported; exit 11; }`, `if exists apt-get; then install_pkg() { apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y "$1"; }; fi`, @@ -1190,7 +1188,6 @@ func (c *CLIArgs) DockerShellCommands(publicKeys []string) []string { consts.RunnerBinaryPath, "--log-level", strconv.Itoa(c.Runner.LogLevel), "start", - "--home-dir", consts.RunnerHomeDir, "--temp-dir", consts.RunnerTempDir, "--http-port", strconv.Itoa(c.Runner.HTTPPort), "--ssh-port", strconv.Itoa(c.Runner.SSHPort), diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 13cba1eb5..75a68e77f 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -944,8 +944,6 @@ def get_docker_commands( "unset LD_LIBRARY_PATH && unset LD_PRELOAD", # common functions 'exists() { command -v "$1" > /dev/null 2>&1; }', - # TODO(#1535): support non-root images properly - "mkdir -p /root && chown root:root /root && export HOME=/root", # package manager detection/abstraction "install_pkg() { NAME=Distribution; test -f /etc/os-release && . 
/etc/os-release; echo $NAME not supported; exit 11; }", 'if exists apt-get; then install_pkg() { apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y "$1"; }; fi', @@ -963,8 +961,6 @@ def get_docker_commands( "--log-level", "6", "start", - "--home-dir", - "/root", "--temp-dir", "/tmp/runner", "--http-port", diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 53feb9cda..4f6379b17 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -249,7 +249,6 @@ def run_job( ) ], security_context=client.V1SecurityContext( - # TODO(#1535): support non-root images properly run_as_user=0, run_as_group=0, privileged=job.job_spec.privileged, diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index ae7ea19f8..f8c8d882c 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -81,7 +81,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_port = jpd.ssh_port ssh_proxy = jpd.ssh_proxy else: - ssh_destination = "root@localhost" # TODO(#1535): support non-root images properly + ssh_destination = "root@localhost" ssh_port = DSTACK_RUNNER_SSH_PORT job_submission = jobs_services.job_model_to_job_submission(job) jrd = job_submission.job_runtime_data diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py index a7967d803..d1ba8ffc8 100644 --- a/src/dstack/_internal/server/services/ssh.py +++ b/src/dstack/_internal/server/services/ssh.py @@ -30,7 +30,7 @@ def container_ssh_tunnel( ssh_port = jpd.ssh_port ssh_proxy = jpd.ssh_proxy else: - ssh_destination = "root@localhost" # TODO(#1535): support non-root images properly + ssh_destination = "root@localhost" ssh_port = DSTACK_RUNNER_SSH_PORT job_submission = 
jobs_services.job_model_to_job_submission(job) jrd = job_submission.job_runtime_data