Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
bin/
*.sw?
tmp
lab
47 changes: 40 additions & 7 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
)

// Client is a wrapper over the SSH connection/sessions.
Expand All @@ -29,6 +30,10 @@ type SSHClient struct {
running bool
env string //export FOO="bar"; export BAR="baz";
color string

ask bool // For interactive "ask:root@..."
password string // For config "password: ..."

}

type ErrConnect struct {
Expand All @@ -54,6 +59,15 @@ func (c *SSHClient) parseHost(host string) error {
if at := strings.LastIndex(c.host, "@"); at != -1 {
c.user = c.host[:at]
c.host = c.host[at+1:]

// Check if the username starts with "ask:"
c.ask = false
if strings.HasPrefix(c.user, "ask:") {
// Remove "ask:" from the username
c.user = strings.TrimPrefix(c.user, "ask:")
// Set the flag so ConnectWith knows to prompt
c.ask = true
}
}

// Add default user, if not set
Expand Down Expand Up @@ -135,11 +149,30 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
return err
}

// 1. Start with Public Keys (SSH Agent + ~/.ssh/id_rsa)
auths := []ssh.AuthMethod{
authMethod,
}

// 2. If a password was provided in Supfile, add it to the list
if c.password != "" {
auths = append(auths, ssh.Password(c.password))
}

// 3. If the user requested an interactive prompt (ask:user@host), ask now
if c.ask {
fmt.Printf("Enter Password for %s@%s: ", c.user, c.host)
pass, err := terminal.ReadPassword(int(os.Stdin.Fd()))
if err != nil {
return err
}
fmt.Println("") // Print newline after input
auths = append(auths, ssh.Password(string(pass)))
}

config := &ssh.ClientConfig{
User: c.user,
Auth: []ssh.AuthMethod{
authMethod,
},
User: c.user,
Auth: auths, // Use the combined list of auth methods
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

Expand All @@ -155,10 +188,10 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
// Run runs the task.Run command remotely on c.host.
func (c *SSHClient) Run(task *Task) error {
if c.running {
return fmt.Errorf("Session already running")
return fmt.Errorf("session already running")
}
if c.sessOpened {
return fmt.Errorf("Session already connected")
return fmt.Errorf("session already connected")
}

sess, err := c.conn.NewSession()
Expand Down Expand Up @@ -241,7 +274,7 @@ func (c *SSHClient) Close() error {
c.sessOpened = false
}
if !c.connOpened {
return fmt.Errorf("Trying to close the already closed connection")
return fmt.Errorf("trying to close the already closed connection")
}

err := c.conn.Close()
Expand Down
16 changes: 10 additions & 6 deletions sup.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"golang.org/x/crypto/ssh"
)

const VERSION = "0.5"
const VERSION = "0.5.1"

type Stackup struct {
conf *Supfile
Expand All @@ -29,7 +29,8 @@ func New(conf *Supfile) (*Stackup, error) {

// Run runs set of commands on multiple hosts defined by network sequentially.
// TODO: This megamoth method needs a big refactor and should be split
// to multiple smaller methods.
//
// to multiple smaller methods.
func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) error {
if len(commands) == 0 {
return errors.New("no commands to be run")
Expand All @@ -40,7 +41,9 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command)
// Create clients for every host (either SSH or Localhost).
var bastion *SSHClient
if network.Bastion != "" {
bastion = &SSHClient{}
bastion = &SSHClient{
password: network.Password,
}
if err := bastion.Connect(network.Bastion); err != nil {
return errors.Wrap(err, "connecting to bastion failed")
}
Expand Down Expand Up @@ -70,9 +73,10 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command)

// SSH client.
remote := &SSHClient{
env: env + `export SUP_HOST="` + host + `";`,
user: network.User,
color: Colors[i%len(Colors)],
env: env + `export SUP_HOST="` + host + `";`,
user: network.User,
color: Colors[i%len(Colors)],
password: network.Password,
}

if bastion != nil {
Expand Down
11 changes: 6 additions & 5 deletions supfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Network struct {
Inventory string `yaml:"inventory"`
Hosts []string `yaml:"hosts"`
Bastion string `yaml:"bastion"` // Jump host for the environment
Password string `yaml:"password"`

// Should these live on Hosts too? We'd have to change []string to struct, even in Supfile.
User string // `yaml:"user"`
Expand Down Expand Up @@ -146,9 +147,10 @@ func (t *Targets) Get(name string) ([]string, bool) {
// Upload represents file copy operation from localhost Src path to Dst
// path of every host in a given Network.
type Upload struct {
Src string `yaml:"src"`
Dst string `yaml:"dst"`
Exc string `yaml:"exclude"`
Src string `yaml:"src"`
Dst string `yaml:"dst"`
Exc string `yaml:"exclude"`
Flatten bool `yaml:"flatten"`
}

// EnvVar represents an environment variable
Expand Down Expand Up @@ -318,8 +320,7 @@ func NewSupfile(data []byte) (*Supfile, error) {

fallthrough

case "0.4", "0.5":

case "0.4", "0.5", "0.5.1":
default:
return nil, ErrUnsupportedSupfileVersion{"unsupported Supfile version " + conf.Version}
}
Expand Down
8 changes: 4 additions & 4 deletions tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func RemoteTarCommand(dir string) string {
return fmt.Sprintf("tar -C \"%s\" -xzf -", dir)
}

func LocalTarCmdArgs(path, exclude string) []string {
func LocalTarCmdArgs(basedir, path, exclude string) []string {
args := []string{}

// Added pattens to exclude from tar compress
Expand All @@ -31,14 +31,14 @@ func LocalTarCmdArgs(path, exclude string) []string {
}
}

args = append(args, "-C", ".", "-czf", "-", path)
args = append(args, "-C", basedir, "-czf", "-", path)
return args
}

// NewTarStreamReader creates a tar stream reader from a local path.
// TODO: Refactor. Use "archive/tar" instead.
func NewTarStreamReader(cwd, path, exclude string) (io.Reader, error) {
cmd := exec.Command("tar", LocalTarCmdArgs(path, exclude)...)
func NewTarStreamReader(cwd, basedir, path, exclude string) (io.Reader, error) {
cmd := exec.Command("tar", LocalTarCmdArgs(basedir, path, exclude)...)
cmd.Dir = cwd
stdout, err := cmd.StdoutPipe()
if err != nil {
Expand Down
15 changes: 14 additions & 1 deletion task.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"io/ioutil"
"os"
"path/filepath"

"github.com/pkg/errors"
)
Expand All @@ -31,7 +32,19 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*
if err != nil {
return nil, errors.Wrap(err, "upload: "+upload.Src)
}
uploadTarReader, err := NewTarStreamReader(cwd, uploadFile, upload.Exc)

tarBaseDir := "."
tarTarget := uploadFile

if upload.Flatten {
// If flatten is true, we change directory to the file's parent
// and archive only the filename.
// e.g. "./bin/squeed" -> BaseDir: "./bin", Target: "squeed"
tarBaseDir = filepath.Dir(uploadFile)
tarTarget = filepath.Base(uploadFile)
}

uploadTarReader, err := NewTarStreamReader(cwd, tarBaseDir, tarTarget, upload.Exc)
if err != nil {
return nil, errors.Wrap(err, "upload: "+upload.Src)
}
Expand Down