Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions java/org/apache/tomcat/websocket/AsyncChannelWrapperSecure.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.util.threads.VirtualThreadExecutor;

/**
* Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot more testing before it can be considered
Expand All @@ -57,14 +58,23 @@ public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {
private final SSLEngine sslEngine;
private final ByteBuffer socketReadBuffer;
private final ByteBuffer socketWriteBuffer;
// One thread for read, one for write
private final ExecutorService executor = Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
private final ExecutorService executor;
private final AtomicBoolean writing = new AtomicBoolean(false);
private final AtomicBoolean reading = new AtomicBoolean(false);

public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine) {
// One thread for read, one for write
this(socketChannel, sslEngine, Executors.newFixedThreadPool(2, new SecureIOThreadFactory()));
}

public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine, VirtualThreadExecutor executor) {
this(socketChannel, sslEngine, (ExecutorService) executor);
}

private AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine, ExecutorService executor) {
this.socketChannel = socketChannel;
this.sslEngine = sslEngine;
this.executor = executor;

int socketBufferSize = sslEngine.getSession().getPacketBufferSize();
socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize);
Expand Down Expand Up @@ -142,7 +152,10 @@ public void close() {
log.info(sm.getString("asyncChannelWrapperSecure.closeFail"));
}
}
executor.shutdownNow();

if (!(executor instanceof VirtualThreadExecutor)) {
executor.shutdownNow();
}
}

@Override
Expand Down
25 changes: 24 additions & 1 deletion java/org/apache/tomcat/websocket/WsWebSocketContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import org.apache.tomcat.util.buf.StringUtils;
import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.util.threads.VirtualThreadExecutor;

public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess {

Expand Down Expand Up @@ -102,6 +103,8 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce

private InstanceManager instanceManager;

private VirtualThreadExecutor virtualThreadExecutor;

protected InstanceManager getInstanceManager(ClassLoader classLoader) {
if (instanceManager != null) {
return instanceManager;
Expand Down Expand Up @@ -302,7 +305,11 @@ private Session connectToServerRecursive(ClientEndpointHolder clientEndpointHold
// proxy CONNECT, need to use TLS from this point on so wrap the
// original AsynchronousSocketChannel
SSLEngine sslEngine = createSSLEngine(clientEndpointConfiguration, host, port);
channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
if (useVirtualThreads()) {
channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine, virtualThreadExecutor);
} else {
channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
}
} else if (channel == null) {
// Only need to wrap as this point if it wasn't wrapped to process a
// proxy CONNECT
Expand Down Expand Up @@ -1010,6 +1017,10 @@ public void destroy() {
}
}
}

if (useVirtualThreads()) {
virtualThreadExecutor.close();
}
}


Expand All @@ -1028,6 +1039,18 @@ private AsynchronousChannelGroup getAsynchronousChannelGroup() {
return result;
}

public void setUseVirtualThreads(boolean useVirtualThreads) {
if (useVirtualThreads) {
virtualThreadExecutor = new VirtualThreadExecutor("WebSocketClient-IO-");
} else {
virtualThreadExecutor = null;
}
}

public boolean useVirtualThreads() {
return virtualThreadExecutor != null;
}


// ----------------------------------------------- BackgroundProcess methods

Expand Down