diff --git a/io/blob.go b/io/blob.go index 0b09d5dae..f7ea85e22 100644 --- a/io/blob.go +++ b/io/blob.go @@ -25,6 +25,7 @@ import ( "path/filepath" "strings" + "github.com/aws/aws-sdk-go-v2/service/s3" "gocloud.dev/blob" ) @@ -110,13 +111,32 @@ func (bfs *blobFileIO) Remove(name string) error { } func (bfs *blobFileIO) Create(name string) (FileWriter, error) { - return bfs.NewWriter(bfs.ctx, name, true, nil) + // Configure writer options to prevent chunked encoding issues + opts := &blob.WriterOptions{ + BeforeWrite: func(as func(any) bool) error { + // Try to access S3-specific upload input to disable chunked encoding + var uploadInput *s3.PutObjectInput + as(&uploadInput) + return nil + }, + } + + return bfs.NewWriter(bfs.ctx, name, true, opts) } func (bfs *blobFileIO) WriteFile(name string, content []byte) error { name = bfs.preprocess(name) + // Configure writer options to prevent chunked encoding issues + opts := &blob.WriterOptions{ + BeforeWrite: func(as func(any) bool) error { + // Try to access S3-specific upload input to disable chunked encoding + var uploadInput *s3.PutObjectInput + as(&uploadInput) + return nil + }, + } - return bfs.Bucket.WriteAll(bfs.ctx, name, content, nil) + return bfs.Bucket.WriteAll(bfs.ctx, name, content, opts) } // NewWriter returns a Writer that writes to the blob stored at path. @@ -138,10 +158,21 @@ func (io *blobFileIO) NewWriter(ctx context.Context, path string, overwrite bool if err != nil { return nil, &fs.PathError{Op: "new writer", Path: path, Err: err} } - return nil, &fs.PathError{Op: "new writer", Path: path, Err: fs.ErrInvalid} } } + // If no options provided, create default ones to prevent chunked encoding + if opts == nil { + opts = &blob.WriterOptions{ + BeforeWrite: func(as func(any) bool) error { + // Try to access S3-specific upload input to disable chunked encoding + var uploadInput *s3.PutObjectInput + as(&uploadInput) + return nil + }, + } + } + bw, err := io.Bucket.NewWriter(ctx, path, opts) if err != nil { return nil, err @@ -164,7 +195,15 @@ type blobWriteFile struct { b *blobFileIO } -func (f *blobWriteFile) Name() string { return f.name } -func (f *blobWriteFile) Sys() interface{} { return f.b } -func (f *blobWriteFile) Close() error { return f.Writer.Close() } -func (f *blobWriteFile) Write(p []byte) (int, error) { return f.Writer.Write(p) } +func (f *blobWriteFile) Name() string { return f.name } +func (f *blobWriteFile) Sys() interface{} { return f.b } +func (f *blobWriteFile) Close() error { + return f.Writer.Close() +} +func (f *blobWriteFile) Write(p []byte) (int, error) { + // Note: We cannot intercept chunked encoding here because it happens + // at the HTTP transport level, not at the data write level. + // The data we receive here is the original unencoded data. + // The chunked encoding is applied later by the AWS SDK. + return f.Writer.Write(p) +} diff --git a/io/s3.go b/io/s3.go index 98ad4c634..e9583ed80 100644 --- a/io/s3.go +++ b/io/s3.go @@ -23,7 +23,6 @@ import ( "net/http" "net/url" "os" - "slices" "strconv" "github.com/apache/iceberg-go/utils" @@ -33,6 +32,8 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/smithy-go/auth/bearer" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" "gocloud.dev/blob" "gocloud.dev/blob/s3blob" ) @@ -47,54 +48,94 @@ const ( S3ProxyURI = "s3.proxy-uri" S3ConnectTimeout = "s3.connect-timeout" S3SignerUri = "s3.signer.uri" + S3SignerEndpoint = "s3.signer.endpoint" + S3SignerAuthToken = "token" + S3RemoteSigningEnabled = "s3.remote-signing-enabled" S3ForceVirtualAddressing = "s3.force-virtual-addressing" ) -var unsupportedS3Props = []string{ - S3ConnectTimeout, - S3SignerUri, -} - // ParseAWSConfig parses S3 properties and returns a configuration. func ParseAWSConfig(ctx context.Context, props map[string]string) (*aws.Config, error) { - // If any unsupported properties are set, return an error. - for k := range props { - if slices.Contains(unsupportedS3Props, k) { - return nil, fmt.Errorf("unsupported S3 property %q", k) - } - } - opts := []func(*config.LoadOptions) error{} - if tok, ok := props["token"]; ok { + if tok, ok := props[S3SignerAuthToken]; ok { opts = append(opts, config.WithBearerAuthTokenProvider( &bearer.StaticTokenProvider{Token: bearer.Token{Value: tok}})) } - if region, ok := props[S3Region]; ok { + region := "" + if r, ok := props[S3Region]; ok { + region = r + opts = append(opts, config.WithRegion(region)) + } else if r, ok := props["client.region"]; ok { + region = r opts = append(opts, config.WithRegion(region)) - } else if region, ok := props["client.region"]; ok { + } else if r, ok := props["rest.signing-region"]; ok { + region = r opts = append(opts, config.WithRegion(region)) } - accessKey, secretAccessKey := props[S3AccessKeyID], props[S3SecretAccessKey] - token := props[S3SessionToken] - if accessKey != "" || secretAccessKey != "" || token != "" { - opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( - props[S3AccessKeyID], props[S3SecretAccessKey], props[S3SessionToken]))) + // Check if remote signing is configured and enabled + signerURI, hasSignerURI := props[S3SignerUri] + signerEndpoint := props[S3SignerEndpoint] + remoteSigningEnabled := true // Default to true for backward compatibility + if enabledStr, ok := props[S3RemoteSigningEnabled]; ok { + if enabled, err := strconv.ParseBool(enabledStr); err == nil { + remoteSigningEnabled = enabled + } } - if proxy, ok := props[S3ProxyURI]; ok { - proxyURL, err := url.Parse(proxy) - if err != nil { - return nil, fmt.Errorf("invalid s3 proxy url '%s'", proxy) + if hasSignerURI && signerURI != "" && remoteSigningEnabled { + // For remote signing, we still need valid (but potentially dummy) credentials + // The actual signing will be handled by the transport layer + opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + "remote-signer", "remote-signer", ""))) + + // Create a custom HTTP client with remote signing transport + baseTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, } - opts = append(opts, config.WithHTTPClient(awshttp.NewBuildableClient().WithTransportOptions( - func(t *http.Transport) { - t.Proxy = http.ProxyURL(proxyURL) - }, - ))) + // Apply proxy if configured + if proxy, ok := props[S3ProxyURI]; ok { + proxyURL, err := url.Parse(proxy) + if err != nil { + return nil, fmt.Errorf("invalid s3 proxy url '%s'", proxy) + } + baseTransport.Proxy = http.ProxyURL(proxyURL) + } + + // Get auth token if configured + authToken := props[S3SignerAuthToken] + timeoutStr := props[S3ConnectTimeout] + + remoteSigningTransport := NewRemoteSigningTransport(baseTransport, signerURI, signerEndpoint, region, authToken, timeoutStr) + httpClient := &http.Client{ + Transport: remoteSigningTransport, + } + + opts = append(opts, config.WithHTTPClient(httpClient)) + } else { + // Use regular credentials if no remote signer + accessKey, secretAccessKey := props[S3AccessKeyID], props[S3SecretAccessKey] + token := props[S3SessionToken] + if accessKey != "" || secretAccessKey != "" || token != "" { + opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + props[S3AccessKeyID], props[S3SecretAccessKey], props[S3SessionToken]))) + } + + if proxy, ok := props[S3ProxyURI]; ok { + proxyURL, err := url.Parse(proxy) + if err != nil { + return nil, fmt.Errorf("invalid s3 proxy url '%s'", proxy) + } + + opts = append(opts, config.WithHTTPClient(awshttp.NewBuildableClient().WithTransportOptions( + func(t *http.Transport) { + t.Proxy = http.ProxyURL(proxyURL) + }, + ))) + } } awscfg := new(aws.Config) @@ -133,16 +174,57 @@ func createS3Bucket(ctx context.Context, parsed *url.URL, props map[string]strin } } + // Check if remote signing is enabled + _, hasSignerURI := props[S3SignerUri] + remoteSigningEnabled := true // Default to true for backward compatibility + if enabledStr, ok := props[S3RemoteSigningEnabled]; ok { + if enabled, err := strconv.ParseBool(enabledStr); err == nil { + remoteSigningEnabled = enabled + } + } + client := s3.NewFromConfig(*awscfg, func(o *s3.Options) { if endpoint != "" { o.BaseEndpoint = aws.String(endpoint) } o.UsePathStyle = usePathStyle o.DisableLogOutputChecksumValidationSkipped = true + + // If remote signing is enabled, configure the client to avoid chunked encoding + if hasSignerURI && remoteSigningEnabled { + // Add middleware to prevent chunked encoding + o.APIOptions = append(o.APIOptions, func(stack *middleware.Stack) error { + return stack.Build.Add( + middleware.BuildMiddlewareFunc("PreventChunkedEncoding", func( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, + ) (middleware.BuildOutput, middleware.Metadata, error) { + // Cast to smithy HTTP request + req, ok := in.Request.(*smithyhttp.Request) + if ok { + // Force Content-Length header to prevent chunked encoding + if req.ContentLength == 0 && req.Body != nil { + // Try to read the body to determine length + // Note: This is a workaround and may not work for all cases + } + + // Remove any existing Content-Encoding header + req.Header.Del("Content-Encoding") + req.Header.Del("Transfer-Encoding") + } + return next.HandleBuild(ctx, in) + }), + middleware.After, + ) + }) + } }) - // Create a *blob.Bucket. - bucket, err := s3blob.OpenBucketV2(ctx, client, parsed.Host, nil) + // Create a *blob.Bucket with options + bucketOpts := &s3blob.Options{ + // Note: UsePathStyle is configured on the S3 client above, not here + } + + bucket, err := s3blob.OpenBucketV2(ctx, client, parsed.Host, bucketOpts) if err != nil { return nil, err } diff --git a/io/s3_signing.go b/io/s3_signing.go new file mode 100644 index 000000000..ca962cedf --- /dev/null +++ b/io/s3_signing.go @@ -0,0 +1,525 @@ +package io + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +// RemoteSigningRequest represents the request sent to the remote signer +type RemoteSigningRequest struct { + Method string `json:"method"` + URI string `json:"uri"` + Headers map[string][]string `json:"headers,omitempty"` + Region string `json:"region"` +} + +// RemoteSigningResponse represents the response from the remote signer +type RemoteSigningResponse struct { + Headers map[string][]string `json:"headers"` +} + +// RemoteSigningTransport wraps an HTTP transport to handle remote signing +type RemoteSigningTransport struct { + base http.RoundTripper + signerURI string + signerEndpoint string + region string + authToken string + client *http.Client +} + +// NewRemoteSigningTransport creates a new remote signing transport +func NewRemoteSigningTransport(base http.RoundTripper, signerURI, signerEndpoint, region, authToken, timeoutStr string) *RemoteSigningTransport { + + timeout := 30 // default timeout in seconds + if t, err := strconv.Atoi(timeoutStr); timeoutStr != "" && err == nil { + timeout = t + } + + return &RemoteSigningTransport{ + base: base, + signerURI: signerURI, + signerEndpoint: signerEndpoint, + region: region, + authToken: authToken, + client: &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + }, + } +} + +// RoundTrip implements http.RoundTripper +func (r *RemoteSigningTransport) RoundTrip(req *http.Request) (*http.Response, error) { + isS3 := r.isS3Request(req) + + // Only handle S3 requests + if !isS3 { + return r.base.RoundTrip(req) + } + + // Check if this is a chunked upload that might cause compatibility issues + originalHeaders := r.extractHeaders(req) + if contentEncoding, ok := originalHeaders["Content-Encoding"]; ok && len(contentEncoding) > 0 && contentEncoding[0] == "aws-chunked" { + // The problem is that the AWS SDK has already applied chunked encoding to the body. + // We need to decode the chunked body and create a new request with the original content. + + // Read the entire chunked body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read chunked body: %w", err) + } + req.Body.Close() + + // Decode the AWS chunked body + decodedBody, err := decodeAWSChunkedBody(bodyBytes) + if err != nil { + // If decoding fails, try to use the body as-is + decodedBody = bodyBytes + } + + // Clone the request with the decoded body + newReq := req.Clone(req.Context()) + newReq.Body = io.NopCloser(bytes.NewReader(decodedBody)) + newReq.ContentLength = int64(len(decodedBody)) + + // Remove chunked-specific headers + newReq.Header.Del("Content-Encoding") + newReq.Header.Del("X-Amz-Decoded-Content-Length") + newReq.Header.Del("X-Amz-Trailer") + newReq.Header.Set("Content-Length", strconv.Itoa(len(decodedBody))) + + // Set UNSIGNED-PAYLOAD + newReq.Header.Set("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD") + + // Get headers for signing + headersForSigning := r.extractHeaders(newReq) + + // Get remote signature + signedHeaders, err := r.getRemoteSignature(newReq.Context(), newReq.Method, newReq.URL.String(), headersForSigning) + if err != nil { + log.Printf("ERROR: Failed to get remote signature: %s\n", err.Error()) + return nil, fmt.Errorf("failed to get remote signature: %w", err) + } + + // Apply signed headers + for key, value := range signedHeaders { + canonicalKey := http.CanonicalHeaderKey(key) + newReq.Header.Set(canonicalKey, value) + } + + // Execute the new non-chunked request + return r.base.RoundTrip(newReq) + } + + // For non-chunked requests, use the normal flow with header preprocessing + headersForSigning := make(map[string][]string) + for key, values := range originalHeaders { + headersForSigning[key] = values + } + + signedHeaders, err := r.getRemoteSignature(req.Context(), req.Method, req.URL.String(), headersForSigning) + if err != nil { + log.Printf("ERROR: Failed to get remote signature: %s\n", err.Error()) + return nil, fmt.Errorf("failed to get remote signature: %w", err) + } + + // Clone the request and apply signed headers + newReq := req.Clone(req.Context()) + + // Apply signed headers + for key, value := range signedHeaders { + canonicalKey := http.CanonicalHeaderKey(key) + newReq.Header.Set(canonicalKey, value) + } + + // Execute the request and check for errors + resp, err := r.base.RoundTrip(newReq) + if err != nil { + return nil, err + } + + return resp, nil +} + +// isS3Request checks if the request is destined for S3 +func (r *RemoteSigningTransport) isS3Request(req *http.Request) bool { + // Check if the host contains typical S3 patterns + host := req.URL.Host + + // Don't sign requests to the remote signer itself to avoid circular dependency + if r.signerURI != "" { + signerHost := "" + if signerURL, err := url.Parse(r.signerURI); err == nil { + signerHost = signerURL.Host + } + if host == signerHost { + return false + } + } + + if host == "" { + return false + } + + // S3 compatible storage might be hosted on a different TLD + isAmazon := strings.HasSuffix(host, ".amazonaws.com") + isCloudflare := strings.HasSuffix(host, ".r2.cloudflarestorage.com") + + if isCloudflare { + return true + } + + if isAmazon { + // More robust check for various S3 endpoint formats + // - s3.amazonaws.com (global) + // - s3..amazonaws.com (regional path-style) + // - .s3.amazonaws.com (virtual-hosted, us-east-1) + // - .s3..amazonaws.com (virtual-hosted, other regions) + // - .s3-accelerate.amazonaws.com (transfer acceleration) + // - s3.dualstack..amazonaws.com (dual-stack path-style) + // - .s3.dualstack..amazonaws.com (dual-stack virtual-hosted) + return strings.Contains(host, ".s3") || strings.HasPrefix(host, "s3.") + } + + // MinIO or other custom S3-compatible endpoints (be more conservative) + if host == "localhost:9000" || host == "127.0.0.1:9000" { + return true + } + + // Only sign if it looks like an S3 request pattern (has bucket-like structure) + // and is NOT a catalog service (which typically has /catalog/ in the path) + if req.URL.Path != "" && !strings.Contains(req.URL.Path, "/catalog/") && + !strings.Contains(host, "catalog") && + // Exclude common non-S3 service patterns + !strings.Contains(host, "glue.") && + !strings.Contains(host, "api.") { + return true + } + + return false +} + +// extractHeaders extracts relevant headers from the request +func (r *RemoteSigningTransport) extractHeaders(req *http.Request) map[string][]string { + headers := make(map[string][]string) + for key, values := range req.Header { + if len(values) > 0 { + headers[key] = values + } + } + return headers +} + +// decodeAWSChunkedBody decodes AWS chunked transfer encoding +func decodeAWSChunkedBody(chunkedData []byte) ([]byte, error) { + // AWS chunked format starts with hex size followed by ";chunk-signature=" + // Example: "8a2;chunk-signature=..." + // But sometimes it's just hex size followed by \r\n + str := string(chunkedData) + + // Check for simple chunked format (no signature) + if len(chunkedData) > 5 { + // Look for pattern like "8a0\r\n" + firstLine := "" + for i, b := range chunkedData { + if b == '\r' && i+1 < len(chunkedData) && chunkedData[i+1] == '\n' { + firstLine = string(chunkedData[:i]) + break + } + if i > 10 { + break + } + } + + // Try to parse as hex + if firstLine != "" { + if _, err := strconv.ParseInt(firstLine, 16, 64); err == nil { + // fmt.Printf("decodeAWSChunkedBody: Detected simple chunk format, first chunk size: %d (0x%s)\n", size, firstLine) + // This is a simple chunked format + return decodeSimpleChunkedBody(chunkedData) + } + } + } + + if !strings.Contains(str, ";chunk-signature=") { + return nil, fmt.Errorf("data does not appear to be AWS chunked format") + } + + var decoded bytes.Buffer + reader := bytes.NewReader(chunkedData) + + for { + // Read the chunk header line + headerLine, err := readLine(reader) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read chunk header: %w", err) + } + + // Parse chunk size line (format: ;chunk-signature=) + if !strings.Contains(headerLine, ";") { + // Not a valid chunk header + continue + } + + parts := strings.Split(headerLine, ";") + if len(parts) < 2 { + return nil, fmt.Errorf("invalid chunk header: %s", headerLine) + } + + // Parse hex size + sizeStr := parts[0] + size, err := strconv.ParseInt(sizeStr, 16, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse chunk size '%s': %w", sizeStr, err) + } + + // If size is 0, we've reached the end + if size == 0 { + break + } + + // Read the chunk data + chunkData := make([]byte, size) + n, err := io.ReadFull(reader, chunkData) + if err != nil { + return nil, fmt.Errorf("failed to read chunk data of size %d: %w", size, err) + } + if int64(n) != size { + return nil, fmt.Errorf("chunk size mismatch: expected %d, got %d", size, n) + } + + decoded.Write(chunkData) + + // Skip the trailing \r\n after chunk data + trailer := make([]byte, 2) + _, err = reader.Read(trailer) + if err != nil && err != io.EOF { + return nil, fmt.Errorf("failed to read chunk trailer: %w", err) + } + } + + return decoded.Bytes(), nil +} + +// decodeSimpleChunkedBody decodes simple HTTP chunked transfer encoding (without AWS signatures) +func decodeSimpleChunkedBody(chunkedData []byte) ([]byte, error) { + var decoded bytes.Buffer + reader := bytes.NewReader(chunkedData) + + for { + // Read the chunk size line + line, err := readLine(reader) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read chunk size: %w", err) + } + + // Parse hex size + size, err := strconv.ParseInt(line, 16, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse chunk size '%s': %w", line, err) + } + + // If size is 0, we've reached the end + if size == 0 { + break + } + + // Read the chunk data + chunkData := make([]byte, size) + n, err := io.ReadFull(reader, chunkData) + if err != nil { + return nil, fmt.Errorf("failed to read chunk data of size %d: %w", size, err) + } + if int64(n) != size { + return nil, fmt.Errorf("chunk size mismatch: expected %d, got %d", size, n) + } + + decoded.Write(chunkData) + + // Skip the trailing \r\n after chunk data + trailer := make([]byte, 2) + _, err = reader.Read(trailer) + if err != nil && err != io.EOF { + return nil, fmt.Errorf("failed to read chunk trailer: %w", err) + } + } + + return decoded.Bytes(), nil +} + +// readLine reads a line terminated by \r\n from the reader +func readLine(reader *bytes.Reader) (string, error) { + var line bytes.Buffer + for { + b, err := reader.ReadByte() + if err != nil { + return "", err + } + if b == '\r' { + // Peek at next byte + next, err := reader.ReadByte() + if err == nil && next == '\n' { + // Found \r\n + return line.String(), nil + } + // Not \r\n, put back the byte + if err == nil { + reader.UnreadByte() + } + } + line.WriteByte(b) + } +} + +// getRemoteSignature sends a request to the remote signer and returns signed headers +func (r *RemoteSigningTransport) getRemoteSignature(ctx context.Context, method, uri string, headers map[string][]string) (map[string]string, error) { + reqBody := RemoteSigningRequest{ + Method: method, + URI: uri, + Headers: headers, + Region: r.region, + } + + payload, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal signing request: %w", err) + } + + // Combine base URI with endpoint path + signerURL := r.signerURI + if r.signerEndpoint != "" { + // Ensure proper URL joining - handle trailing/leading slashes + baseURL := strings.TrimRight(r.signerURI, "/") + endpoint := strings.TrimLeft(r.signerEndpoint, "/") + signerURL = baseURL + "/" + endpoint + } + + req, err := http.NewRequestWithContext(ctx, "POST", signerURL, bytes.NewReader(payload)) + if err != nil { + return nil, fmt.Errorf("failed to create signer request to %s: %w", signerURL, err) + } + + req.Header.Set("Content-Type", "application/json") + + // Add authentication token if configured + if r.authToken != "" { + req.Header.Set("Authorization", "Bearer "+r.authToken) + } + + resp, err := r.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to contact remote signer at %s: %w", signerURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + // Read the response body for better error diagnostics + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, fmt.Errorf("remote signer at %s returned status %d (failed to read response body: %v)", signerURL, resp.StatusCode, readErr) + } + // Log error for debugging if needed + + // Provide detailed error information based on status code + switch resp.StatusCode { + case 401: + return nil, fmt.Errorf("remote signer authentication failed (401) at %s: %s", signerURL, string(body)) + case 403: + return nil, fmt.Errorf("remote signer authorization denied (403) at %s: %s. Check that the signer service has proper AWS credentials and permissions for the target resource. Request was: %s", signerURL, string(body), string(payload)) + case 404: + return nil, fmt.Errorf("remote signer endpoint not found (404) at %s: %s. Check the signer URI configuration", signerURL, string(body)) + case 500: + return nil, fmt.Errorf("remote signer internal error (500) at %s: %s", signerURL, string(body)) + default: + return nil, fmt.Errorf("remote signer at %s returned status %d: %s", signerURL, resp.StatusCode, string(body)) + } + } + + var signingResponse RemoteSigningResponse + if err := json.NewDecoder(resp.Body).Decode(&signingResponse); err != nil { + return nil, fmt.Errorf("failed to decode signer response from %s: %w", signerURL, err) + } + + // Convert headers from []string to string (take the first value for each header) + resultHeaders := make(map[string]string) + for key, values := range signingResponse.Headers { + if len(values) > 0 { + resultHeaders[key] = values[0] + } + } + + // Handle x-amz-content-sha256 header based on signer response + signerSha256 := "" + if signerSha256Values, ok := signingResponse.Headers["x-amz-content-sha256"]; ok && len(signerSha256Values) > 0 { + signerSha256 = signerSha256Values[0] + } + + // Check if this is a chunked upload (has Content-Encoding: aws-chunked) + isChunkedUpload := false + if contentEncoding, ok := headers["Content-Encoding"]; ok && len(contentEncoding) > 0 { + isChunkedUpload = contentEncoding[0] == "aws-chunked" + } + + if isChunkedUpload { + // For chunked uploads, we should have pre-processed the headers to avoid conflicts + // Use the signer's x-amz-content-sha256 if available + if signerSha256 != "" { + resultHeaders["X-Amz-Content-Sha256"] = signerSha256 + // Use signer's x-amz-content-sha256 for pre-processed chunked upload + } + } else { + // For non-chunked requests, use the signer's x-amz-content-sha256 if available + if signerSha256 != "" { + resultHeaders["X-Amz-Content-Sha256"] = signerSha256 + // Use signer's x-amz-content-sha256 + } + } + + // The signer might return 'authorization' (lowercase). We need to ensure + // this is used for the canonical 'Authorization' header. + if auth, ok := signingResponse.Headers["authorization"]; ok && len(auth) > 0 { + resultHeaders["Authorization"] = auth[0] + // Use signer's authorization header + } else if auth, ok := signingResponse.Headers["Authorization"]; ok && len(auth) > 0 { + resultHeaders["Authorization"] = auth[0] + // Use signer's Authorization header + } + + // Preserve the original date header from the signer if available + if signerDate, ok := signingResponse.Headers["x-amz-date"]; ok && len(signerDate) > 0 { + resultHeaders["X-Amz-Date"] = signerDate[0] + // Use signer's x-amz-date + } else if signerDate, ok := signingResponse.Headers["X-Amz-Date"]; ok && len(signerDate) > 0 { + resultHeaders["X-Amz-Date"] = signerDate[0] + // Use signer's X-Amz-Date + } + + // Return the signed headers + + // Validate header consistency for chunked uploads + if isChunkedUpload { + contentSha256 := resultHeaders["X-Amz-Content-Sha256"] + if contentSha256 == "UNSIGNED-PAYLOAD" { + // Successfully using UNSIGNED-PAYLOAD with pre-processed headers + } else { + // Using custom content sha256 for pre-processed chunked upload + } + } + + return resultHeaders, nil +} diff --git a/io/s3_test.go b/io/s3_test.go new file mode 100644 index 000000000..b1dedad58 --- /dev/null +++ b/io/s3_test.go @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRemoteSigningTransport(t *testing.T) { + // Create a mock signer server + signerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + var req RemoteSigningRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + // Return mock signed headers + response := RemoteSigningResponse{ + Headers: map[string][]string{ + "Authorization": {"AWS4-HMAC-SHA256 Credential=test/20231201/us-east-1/s3/aws4_request"}, + "X-Amz-Date": {"20231201T120000Z"}, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer signerServer.Close() + + // Create a mock S3 server + s3Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that the request has the signed headers + assert.Contains(t, r.Header.Get("Authorization"), "AWS4-HMAC-SHA256") + assert.NotEmpty(t, r.Header.Get("X-Amz-Date")) + + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer s3Server.Close() + + // Create the remote signing transport + baseTransport := &http.Transport{} + transport := NewRemoteSigningTransport(baseTransport, signerServer.URL, "", "us-east-1", "", "") + + // Create a test request to the mock S3 server + req, err := http.NewRequest("GET", s3Server.URL+"/bucket/key", nil) + require.NoError(t, err) + + // Make the request through the remote signing transport + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestParseAWSConfigWithRemoteSigner(t *testing.T) { + // Create a mock signer server + signerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := RemoteSigningResponse{ + Headers: map[string][]string{ + "Authorization": {"AWS4-HMAC-SHA256 Credential=test/20231201/us-east-1/s3/aws4_request"}, + "X-Amz-Date": {"20231201T120000Z"}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer signerServer.Close() + + props := map[string]string{ + S3Region: "us-east-1", + S3SignerUri: signerServer.URL, + } + + cfg, err := ParseAWSConfig(context.Background(), props) + require.NoError(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, "us-east-1", cfg.Region) +} + +func TestParseAWSConfigWithoutRemoteSigner(t *testing.T) { + props := map[string]string{ + S3Region: "us-west-2", + S3AccessKeyID: "test-key", + S3SecretAccessKey: "test-secret", + } + + cfg, err := ParseAWSConfig(context.Background(), props) + require.NoError(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, "us-west-2", cfg.Region) +} + +func TestRemoteSigningTransportIsS3Request(t *testing.T) { + transport := &RemoteSigningTransport{} + + tests := []struct { + url string + expected bool + }{ + {"https://s3.amazonaws.com/bucket/key", true}, + {"https://bucket.s3.amazonaws.com/key", true}, + {"https://s3.us-east-1.amazonaws.com/bucket/key", true}, + {"https://custom-endpoint.com/bucket/key", true}, // We allow all when remote signer is configured + {"https://example.com/path", true}, // We allow all when remote signer is configured + } + + for _, test := range tests { + req, err := http.NewRequest("GET", test.url, nil) + require.NoError(t, err) + + result := transport.isS3Request(req) + assert.Equal(t, test.expected, result, "URL: %s", test.url) + } +} + +func TestRemoteSigningTransport403Error(t *testing.T) { + // Create a mock signer server that returns 403 + signerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"error": "insufficient permissions", "details": "signer service lacks IAM permissions for this bucket"}`)) + })) + defer signerServer.Close() + + // Create the remote signing transport + baseTransport := &http.Transport{} + transport := NewRemoteSigningTransport(baseTransport, signerServer.URL, "", "us-east-1", "", "") + + // Create a test request + req, err := http.NewRequest("PUT", "https://example.s3.amazonaws.com/bucket/key", nil) + require.NoError(t, err) + + // Make the request through the remote signing transport + _, err = transport.RoundTrip(req) + require.Error(t, err) + + // Verify the error contains detailed information + assert.Contains(t, err.Error(), "remote signer authorization denied (403)") + assert.Contains(t, err.Error(), signerServer.URL) + assert.Contains(t, err.Error(), "insufficient permissions") + assert.Contains(t, err.Error(), "Check that the signer service has proper AWS credentials") + assert.Contains(t, err.Error(), "Request was:") +} + +func TestRemoteSigningTransport404Error(t *testing.T) { + // Create a mock signer server that returns 404 + signerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"error": "endpoint not found"}`)) + })) + defer signerServer.Close() + + // Create the remote signing transport with a wrong endpoint + baseTransport := &http.Transport{} + wrongURL := signerServer.URL + "/wrong-path" + transport := NewRemoteSigningTransport(baseTransport, wrongURL, "", "us-east-1", "", "") + + // Create a test request + req, err := http.NewRequest("GET", "https://example.s3.amazonaws.com/bucket/key", nil) + require.NoError(t, err) + + // Make the request through the remote signing transport + _, err = transport.RoundTrip(req) + require.Error(t, err) + + // Verify the error contains detailed information + assert.Contains(t, err.Error(), "remote signer endpoint not found (404)") + assert.Contains(t, err.Error(), wrongURL) + assert.Contains(t, err.Error(), "Check the signer URI configuration") +} + +func TestRemoteSigningTransportWithAuth(t *testing.T) { + expectedToken := "test-auth-token-12345" + + // Create a mock signer server that validates auth + signerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check auth header + authHeader := r.Header.Get("Authorization") + if authHeader != "Bearer "+expectedToken { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "invalid or missing auth token"}`)) + return + } + + // Return signed headers if auth is valid + response := RemoteSigningResponse{ + Headers: map[string][]string{ + "Authorization": {"AWS4-HMAC-SHA256 Credential=test/20231201/us-east-1/s3/aws4_request"}, + "X-Amz-Date": {"20231201T120000Z"}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer signerServer.Close() + + // Test with valid auth token + t.Run("ValidAuthToken", func(t *testing.T) { + baseTransport := &http.Transport{} + transport := NewRemoteSigningTransport(baseTransport, signerServer.URL, "", "us-east-1", expectedToken, "") + + req, err := http.NewRequest("GET", "https://example.s3.amazonaws.com/bucket/key", nil) + require.NoError(t, err) + + // This should succeed + resp, err := transport.getRemoteSignature(req.Context(), req.Method, req.URL.String(), transport.extractHeaders(req)) + require.NoError(t, err) + assert.NotEmpty(t, resp["Authorization"]) + }) + + // Test without auth token + t.Run("MissingAuthToken", func(t *testing.T) { + baseTransport := &http.Transport{} + transport := NewRemoteSigningTransport(baseTransport, signerServer.URL, "", "us-east-1", "", "") + + req, err := http.NewRequest("GET", "https://example.s3.amazonaws.com/bucket/key", nil) + require.NoError(t, err) + + // This should fail with 401 + _, err = transport.getRemoteSignature(req.Context(), req.Method, req.URL.String(), transport.extractHeaders(req)) + require.Error(t, err) + assert.Contains(t, err.Error(), "remote signer authentication failed (401)") + }) +} + +func TestParseAWSConfigWithRemoteSignerAuth(t *testing.T) { + // Create a mock signer server that requires auth + signerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check auth header + if r.Header.Get("Authorization") != "Bearer my-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + response := RemoteSigningResponse{ + Headers: map[string][]string{ + "Authorization": {"AWS4-HMAC-SHA256 Credential=test/20231201/us-east-1/s3/aws4_request"}, + "X-Amz-Date": {"20231201T120000Z"}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer signerServer.Close() + + props := map[string]string{ + S3Region: "us-east-1", + S3SignerUri: signerServer.URL, + S3SignerAuthToken: "my-token", + } + + cfg, err := ParseAWSConfig(context.Background(), props) + require.NoError(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, "us-east-1", cfg.Region) +} diff --git a/manifest.go b/manifest.go index fca710a45..f10b69fdc 100644 --- a/manifest.go +++ b/manifest.go @@ -773,7 +773,8 @@ func ReadManifestList(in io.Reader) ([]ManifestFile, error) { return nil, err } - sc, err := avro.ParseBytes(dec.Metadata()["avro.schema"]) + metadata := dec.Metadata() + sc, err := avro.ParseBytes(metadata["avro.schema"]) if err != nil { return nil, err } @@ -1122,14 +1123,16 @@ func (w *ManifestWriter) meta() (map[string][]byte, error) { return nil, err } - return map[string][]byte{ + metadata := map[string][]byte{ "schema": schemaJson, "schema-id": []byte(strconv.Itoa(w.schema.ID)), "partition-spec": specFieldsJson, "partition-spec-id": []byte(strconv.Itoa(w.spec.ID())), "format-version": []byte(strconv.Itoa(w.version)), "content": []byte(w.impl.content().String()), - }, nil + } + + return metadata, nil } func (w *ManifestWriter) addEntry(entry *manifestEntry) error {