diff --git a/cmd/step/main_test.go b/cmd/step/main_test.go index 82979e39d..33a48b74f 100644 --- a/cmd/step/main_test.go +++ b/cmd/step/main_test.go @@ -5,6 +5,7 @@ import ( "regexp" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -12,14 +13,15 @@ func TestAppHasAllCommands(t *testing.T) { app := newApp(&bytes.Buffer{}, &bytes.Buffer{}) require.NotNil(t, app) - require.Equal(t, "step", app.Name) - require.Equal(t, "step", app.HelpName) + assert.Equal(t, "step", app.Name) + assert.Equal(t, "step", app.HelpName) var names = make([]string, 0, len(app.Commands)) for _, c := range app.Commands { names = append(names, c.Name) } - require.Equal(t, []string{ + + assert.ElementsMatch(t, []string{ "help", "api", "path", "base64", "fileserver", "certificate", "completion", "context", "crl", "crypto", "oauth", "version", "ca", "beta", "ssh", @@ -42,5 +44,5 @@ func TestAppRuns(t *testing.T) { require.Empty(t, stderr.Bytes()) output := ansiRegex.ReplaceAllString(stdout.String(), "") - require.Contains(t, output, "step -- plumbing for distributed systems") + assert.Contains(t, output, "step -- plumbing for distributed systems") } diff --git a/command/api/token/create.go b/command/api/token/create.go index 3d5a2fd29..15fd087a7 100644 --- a/command/api/token/create.go +++ b/command/api/token/create.go @@ -2,19 +2,32 @@ package token import ( "bytes" + "context" + "crypto" "crypto/tls" "encoding/json" + "encoding/pem" "errors" "fmt" "net/http" "net/url" + "os" "path" "github.com/google/uuid" "github.com/urfave/cli" + "github.com/smallstep/certificates/ca" "github.com/smallstep/cli-utils/errs" "github.com/smallstep/cli-utils/ui" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" + "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/tss2" + + "github.com/smallstep/cli/flags" + "github.com/smallstep/cli/internal/cryptoutil" + "github.com/smallstep/cli/internal/httptransport" ) func createCommand() cli.Command { @@ -28,6 +41,8 @@ func createCommand() cli.Command { Flags: []cli.Flag{ apiURLFlag, audienceFlag, + flags.PasswordFile, + tpmDeviceFlag, }, Description: `**step ca api token create** creates a new token for connecting to the Smallstep API. @@ -43,14 +58,29 @@ func createCommand() cli.Command { : File to read the private key (PEM format). ## EXAMPLES -Use a certificate to get a new API token: +Use a certificate and team ID to get a new API token: ''' $ step api token create ff98be70-7cc3-4df5-a5db-37f5d3c96e23 internal.crt internal.key ''' Get a token using the team slug: ''' -$ step api token create teamfoo internal.crt internal.key +$ step api token create team-foo internal.crt internal.key +''' + +Use a certificate with a private key backed by a TPM to get a new API token: +''' +$ step api token create team-tpm ecdsa-chain.crt 'tpmkms:name=test-ecdsa' +''' + +Use a certificate with a private key backed by a TPM simulator to get a new API token: +''' +$ step api token create team-tpm-simulator ecdsa-chain.crt 'tpmkms:name=test-ecdsa;device=/path/to/tpmsimulator.sock' +''' + +Use a certificate and a TSS2 PEM encoded private key to get a new API token: +''' +$ step api token create team-tss2 ecdsa-chain.crt ecdsa.tss2.pem --tpm-device /dev/tpmrm0 ''' `, } @@ -73,54 +103,65 @@ func createAction(ctx *cli.Context) (err error) { return err } - args := ctx.Args() - - teamID := args.Get(0) - crtFile := args.Get(1) - keyFile := args.Get(2) + var ( + args = ctx.Args() + teamID = args.Get(0) + crtFile = args.Get(1) + keyFile = args.Get(2) + passwordFile = ctx.String("password-file") + apiURLFlag = ctx.String("api-url") + audience = ctx.String("audience") + tpmDevice = ctx.String("tpm-device") + ) - parsedURL, err := url.Parse(ctx.String("api-url")) + parsedURL, err := url.Parse(apiURLFlag) if err != nil { return err } parsedURL.Path = path.Join(parsedURL.Path, "api/auth") apiURL := parsedURL.String() - clientCert, err := tls.LoadX509KeyPair(crtFile, keyFile) + clientCert, err := createClientCertificate(crtFile, keyFile, passwordFile, tpmDevice) if err != nil { return err } - b := &bytes.Buffer{} - r := &createTokenReq{ + + b := new(bytes.Buffer) + r := createTokenReq{ Bundle: clientCert.Certificate, - Audience: ctx.String("audience"), + Audience: audience, } + if err := uuid.Validate(teamID); err != nil { r.TeamSlug = teamID } else { r.TeamID = teamID } - err = json.NewEncoder(b).Encode(r) - if err != nil { - return err - } - post, err := http.NewRequest("POST", apiURL, b) - if err != nil { + if err := json.NewEncoder(b).Encode(r); err != nil { return err } - post.Header.Set("Content-Type", "application/json") - transport := http.DefaultTransport.(*http.Transport).Clone() + + transport := httptransport.New() transport.TLSClientConfig = &tls.Config{ GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &clientCert, nil + return clientCert, nil }, MinVersion: tls.VersionTLS12, } client := http.Client{ Transport: transport, } - resp, err := client.Do(post) + + req, err := http.NewRequest("POST", apiURL, b) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", ca.UserAgent) // this is set to step.Version() during init; i.e. "Smallstep CLI/vX.X.X (os/arch)" + req.Header.Set(requestIDHeader, newRequestID()) + + resp, err := client.Do(req) if err != nil { return err } @@ -143,3 +184,103 @@ func createAction(ctx *cli.Context) (err error) { return nil } + +// requestIDHeader is the header name used for propagating request IDs from +// the client to the server and back again. +const requestIDHeader = "X-Request-Id" + +// newRequestID generates a new random UUIDv4 request ID. If it fails, +// the request ID will be the empty string. +func newRequestID() string { + requestID, err := randutil.UUIDv4() + if err != nil { + return "" + } + + return requestID +} + +func createClientCertificate(crtFile, keyFile, passwordFile, tpmDevice string) (*tls.Certificate, error) { + certs, err := pemutil.ReadCertificateBundle(crtFile) + if err != nil { + return nil, fmt.Errorf("failed reading %q: %w", crtFile, err) + } + + var certificates = make([][]byte, len(certs)) + for i, c := range certs { + certificates[i] = c.Raw + } + + pk, err := getPrivateKey(keyFile, passwordFile, tpmDevice) + if err != nil { + return nil, fmt.Errorf("failed reading key from %q: %w", keyFile, err) + } + + if _, ok := pk.(crypto.Signer); !ok { + return nil, fmt.Errorf("private key type %T read from %q cannot be used as a signer", pk, keyFile) + } + + return &tls.Certificate{ + Certificate: certificates, + Leaf: certs[0], + PrivateKey: pk, + }, nil +} + +func getPrivateKey(keyFile, passwordFile, tpmDevice string) (crypto.PrivateKey, error) { + if cryptoutil.IsKMS(keyFile) { + signer, err := cryptoutil.CreateSigner(keyFile, keyFile) + if err != nil { + return nil, fmt.Errorf("failed creating signer: %w", err) + } + + return signer, nil + } + + b, err := os.ReadFile(keyFile) + if err != nil { + return nil, err + } + + // detect the type of the PEM file. if it's a TSS2 PEM file, pemutil + // can't be used to create a private key, as it does not support this + // type. Support could be added, but it could require some additional + // options, such as specifying the TPM device that backs the TSS2 + // signer. + p, _ := pem.Decode(b) + if p.Type != "TSS2 PRIVATE KEY" { + var opts []pemutil.Options + if passwordFile != "" { + opts = append(opts, pemutil.WithPasswordFile(passwordFile)) + } + + pk, err := pemutil.Parse(b, opts...) + if err != nil { + return nil, fmt.Errorf("failed parsing PEM: %w", err) + } + + return pk, nil + } + + key, err := tss2.ParsePrivateKey(p.Bytes) + if err != nil { + return nil, fmt.Errorf("failed creating TSS2 private key: %w", err) + } + + var tpmOpts = []tpm.NewTPMOption{} + if tpmDevice != "" { + tpmOpts = append(tpmOpts, tpm.WithDeviceName(tpmDevice)) + } + + t, err := tpm.New(tpmOpts...) + if err != nil { + return nil, fmt.Errorf("failed initializing TPM: %w", err) + } + + signer, err := tpm.CreateTSS2Signer(context.Background(), t, key) + if err != nil { + return nil, fmt.Errorf("failed creating TSS2 signer: %w", err) + } + + return signer, nil +} diff --git a/command/api/token/token.go b/command/api/token/token.go index 1925e63cb..50610346e 100644 --- a/command/api/token/token.go +++ b/command/api/token/token.go @@ -30,4 +30,8 @@ var ( Name: "audience", Usage: "Request a token for an audience other than the API Gateway", } + tpmDeviceFlag = cli.StringFlag{ + Name: "tpm-device", + Usage: "(Optional) path to TPM device (e.g. /dev/tpmrm0)", + } ) diff --git a/internal/httptransport/httptransport.go b/internal/httptransport/httptransport.go new file mode 100644 index 000000000..b14862488 --- /dev/null +++ b/internal/httptransport/httptransport.go @@ -0,0 +1,26 @@ +// Package httptransport implements initialization of [http.Transport] instances and related +// functionality. +package httptransport + +import ( + "net" + "net/http" + "time" +) + +// New returns a reference to an [http.Transport] that's initialized just like the +// [http.DefaultTransport] is by the standard library. +func New() *http.Transport { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +}