diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java index d8f536e0be8f..d243af8ac641 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -24,10 +24,15 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.BodySubscribers; +import java.net.http.HttpResponse.ResponseInfo; import java.net.http.HttpTimeoutException; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collections; +import java.util.List; import java.util.Locale; import java.util.Set; import java.util.TreeSet; @@ -37,6 +42,8 @@ import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.TimeUnit; +import java.util.zip.GZIPInputStream; +import java.util.zip.InflaterInputStream; import org.jspecify.annotations.Nullable; @@ -59,6 +66,8 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { private static final Set DISALLOWED_HEADERS = disallowedHeaders(); + private static final List ALLOWED_ENCODINGS = List.of("gzip", "deflate"); + private final HttpClient httpClient; @@ -70,15 +79,18 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { private final @Nullable Duration timeout; + private final boolean compressionEnabled; + public JdkClientHttpRequest(HttpClient httpClient, URI uri, HttpMethod method, Executor executor, - @Nullable Duration readTimeout) { + @Nullable Duration readTimeout, boolean compressionEnabled) { this.httpClient = httpClient; this.uri = uri; this.method = method; this.executor = executor; this.timeout = readTimeout; + this.compressionEnabled = compressionEnabled; } @@ -98,7 +110,7 @@ protected ClientHttpResponse executeInternal(HttpHeaders headers, @Nullable Body CompletableFuture> responseFuture = null; try { HttpRequest request = buildRequest(headers, body); - responseFuture = this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()); + responseFuture = this.httpClient.sendAsync(request, new DecompressingBodyHandler()); if (this.timeout != null) { TimeoutHandler timeoutHandler = new TimeoutHandler(responseFuture, this.timeout); @@ -141,6 +153,15 @@ else if (cause instanceof IOException ioEx) { private HttpRequest buildRequest(HttpHeaders headers, @Nullable Body body) { HttpRequest.Builder builder = HttpRequest.newBuilder().uri(this.uri); + // When compression is enabled and valid encoding is absent, we add gzip as standard encoding + if (this.compressionEnabled) { + if (headers.containsHeader(HttpHeaders.ACCEPT_ENCODING) && + !ALLOWED_ENCODINGS.contains(headers.getFirst(HttpHeaders.ACCEPT_ENCODING))) { + headers.remove(HttpHeaders.ACCEPT_ENCODING); + } + headers.add(HttpHeaders.ACCEPT_ENCODING, "gzip"); + } + headers.forEach((headerName, headerValues) -> { if (!DISALLOWED_HEADERS.contains(headerName.toLowerCase(Locale.ROOT))) { for (String headerValue : headerValues) { @@ -226,7 +247,7 @@ public ByteBuffer map(byte[] b, int off, int len) { /** * Temporary workaround to use instead of {@link HttpRequest.Builder#timeout(Duration)} * until JDK-8258397 - * is fixed. Essentially, create a future wiht a timeout handler, and use it + * is fixed. Essentially, create a future with a timeout handler, and use it * to close the response. * @see OpenJDK discussion thread */ @@ -269,4 +290,39 @@ public void close() throws IOException { } } + /** + * Custom BodyHandler that checks the Content-Encoding header and applies the appropriate decompression algorithm. + * Supports Gzip and Deflate encoded responses. + */ + public static final class DecompressingBodyHandler implements BodyHandler { + + @Override + public BodySubscriber apply(ResponseInfo responseInfo) { + String contentEncoding = responseInfo.headers().firstValue(HttpHeaders.CONTENT_ENCODING).orElse(""); + if (contentEncoding.equalsIgnoreCase("gzip")) { + // If the content is gzipped, wrap the InputStream with a GZIPInputStream + return BodySubscribers.mapping( + BodySubscribers.ofInputStream(), + (InputStream is) -> { + try { + return new GZIPInputStream(is); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); // Propagate IOExceptions + } + }); + } + else if (contentEncoding.equalsIgnoreCase("deflate")) { + // If the content is encoded using deflate, wrap the InputStream with a InflaterInputStream + return BodySubscribers.mapping( + BodySubscribers.ofInputStream(), + InflaterInputStream::new); + } + else { + // Otherwise, return a standard InputStream BodySubscriber + return BodySubscribers.ofInputStream(); + } + } + } + } diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java index 886a64e2a773..01cce828239b 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java @@ -43,6 +43,8 @@ public class JdkClientHttpRequestFactory implements ClientHttpRequestFactory { private @Nullable Duration readTimeout; + private boolean compressionEnabled; + /** * Create a new instance of the {@code JdkClientHttpRequestFactory} @@ -96,10 +98,18 @@ public void setReadTimeout(Duration readTimeout) { this.readTimeout = readTimeout; } + /** + * Sets custom {@link BodyHandler} that can handle gzip encoded {@link HttpClient}'s response. + * @param compressionEnabled to enable compression by default for all {@link HttpClient}'s requests. + */ + public void setCompressionEnabled(boolean compressionEnabled) { + this.compressionEnabled = compressionEnabled; + } + @Override public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { - return new JdkClientHttpRequest(this.httpClient, uri, httpMethod, this.executor, this.readTimeout); + return new JdkClientHttpRequest(this.httpClient, uri, httpMethod, this.executor, this.readTimeout, this.compressionEnabled); } } diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java index 28e83978e15a..922b8c00ba7e 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java @@ -23,8 +23,14 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.springframework.http.HttpHeaders; import org.springframework.util.StringUtils; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.GZIPOutputStream; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -106,6 +112,26 @@ else if(request.getTarget().startsWith("/header/")) { String headerName = request.getTarget().replace("/header/",""); return new MockResponse.Builder().body(headerName + ":" + request.getHeaders().get(headerName)).code(200).build(); } + else if(request.getTarget().startsWith("/compress/")) { + String encoding = request.getTarget().replace("/compress/",""); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + if (encoding.equals("gzip")) { + try(GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) { + gzipOutputStream.write("Test Payload".getBytes()); + gzipOutputStream.flush(); + } + } + else if(encoding.equals("deflate")) { + try(DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(byteArrayOutputStream)) { + deflaterOutputStream.write("Test Payload".getBytes()); + deflaterOutputStream.flush(); + } + } else { + byteArrayOutputStream.write("Test Payload".getBytes()); + } + return new MockResponse.Builder().body(byteArrayOutputStream.toString(StandardCharsets.ISO_8859_1)) + .code(200).setHeader(HttpHeaders.CONTENT_ENCODING, encoding).build(); + } return new MockResponse.Builder().code(404).build(); } catch (Throwable ex) { diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java index 025f47e0c44f..8f380754c9ac 100644 --- a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -108,6 +108,44 @@ void deleteRequestWithBody() throws Exception { } } + @Test + void compressionDisabled() throws IOException { + URI uri = URI.create(baseUrl + "/compress/"); + ClientHttpRequest request = this.factory.createRequest(uri, HttpMethod.GET); + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid response status").isEqualTo(HttpStatus.OK); + assertThat(StreamUtils.copyToString(response.getBody(), StandardCharsets.ISO_8859_1)) + .as("Invalid request body").isEqualTo("Test Payload"); + } + } + + @Test + void compressionGzip() throws IOException { + URI uri = URI.create(baseUrl + "/compress/gzip"); + JdkClientHttpRequestFactory requestFactory = (JdkClientHttpRequestFactory) this.factory; + requestFactory.setCompressionEnabled(true); + ClientHttpRequest request = requestFactory.createRequest(uri, HttpMethod.GET); + + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid response status").isEqualTo(HttpStatus.OK); + assertThat(StreamUtils.copyToString(response.getBody(), StandardCharsets.ISO_8859_1)) + .as("Invalid request body").isEqualTo("Test Payload"); + } + } + + @Test + void compressionDeflate() throws IOException { + URI uri = URI.create(baseUrl + "/compress/deflate"); + JdkClientHttpRequestFactory requestFactory = (JdkClientHttpRequestFactory) this.factory; + requestFactory.setCompressionEnabled(true); + ClientHttpRequest request = requestFactory.createRequest(uri, HttpMethod.GET); + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid response status").isEqualTo(HttpStatus.OK); + assertThat(StreamUtils.copyToString(response.getBody(), StandardCharsets.ISO_8859_1)) + .as("Invalid request body").isEqualTo("Test Payload"); + } + } + @Test // gh-34971 @EnabledForJreRange(min = JRE.JAVA_19) // behavior fixed in Java 19 void requestContentLengthHeaderWhenNoBody() throws Exception {