File tomcat-9.0.36-CVE-2024-23672.patch of Package tomcat.33319
Index: apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/Constants.java
===================================================================
--- apache-tomcat-9.0.36-src.orig/java/org/apache/tomcat/websocket/Constants.java
+++ apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/Constants.java
@@ -19,6 +19,7 @@ package org.apache.tomcat.websocket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.concurrent.TimeUnit;
import javax.websocket.Extension;
@@ -114,6 +115,11 @@ public class Constants {
// Milliseconds so this is 20 seconds
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
+ // Configuration for session close timeout
+ public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
+ // Default is 30 seconds - setting is in milliseconds
+ public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);
+
// Configuration for background processing checks intervals
static final int DEFAULT_PROCESS_PERIOD = Integer.getInteger(
"org.apache.tomcat.websocket.DEFAULT_PROCESS_PERIOD", 10)
Index: apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/WsSession.java
===================================================================
--- apache-tomcat-9.0.36-src.orig/java/org/apache/tomcat/websocket/WsSession.java
+++ apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/WsSession.java
@@ -28,7 +28,9 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCode;
@@ -91,7 +93,7 @@ public class WsSession implements Sessio
// Expected to handle message types of <ByteBuffer> only
private volatile MessageHandler binaryMessageHandler = null;
private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = null;
- private volatile State state = State.OPEN;
+ private AtomicReference<State> state = new AtomicReference<>(State.OPEN);
private final Object stateLock = new Object();
private final Map<String, Object> userProperties = new ConcurrentHashMap<>();
private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
@@ -99,6 +101,7 @@ public class WsSession implements Sessio
private volatile long maxIdleTimeout = 0;
private volatile long lastActive = System.currentTimeMillis();
private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>();
+ private volatile Long sessionCloseTimeoutExpiry;
/**
* Creates a new WebSocket session for communication between the two
@@ -368,9 +371,12 @@ public class WsSession implements Sessio
@Override
public boolean isOpen() {
- return state == State.OPEN;
+ return state.get() == State.OPEN;
}
+ public boolean isClosed() {
+ return state.get() == State.CLOSED;
+ }
@Override
public long getMaxIdleTimeout() {
@@ -470,37 +476,46 @@ public class WsSession implements Sessio
* @param closeSocket Should the socket be closed immediately rather than waiting
* for the server to respond
*/
- public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal,
- boolean closeSocket) {
- // Double-checked locking. OK because state is volatile
- if (state != State.OPEN) {
+ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal, boolean closeSocket) {
+
+ if (!state.compareAndSet(State.OPEN, State.OUTPUT_CLOSING)) {
+ // Close process has already been started. Don't start it again.
return;
}
- synchronized (stateLock) {
- if (state != State.OPEN) {
- return;
- }
-
- if (log.isDebugEnabled()) {
- log.debug(sm.getString("wsSession.doClose", id));
- }
- try {
- wsRemoteEndpoint.setBatchingAllowed(false);
- } catch (IOException e) {
- log.warn(sm.getString("wsSession.flushFailOnClose"), e);
- fireEndpointOnError(e);
- }
+ if (log.isDebugEnabled()) {
+ log.debug(sm.getString("wsSession.doClose", id));
+ }
- state = State.OUTPUT_CLOSED;
+ // Flush any batched messages not yet sent.
+ try {
+ wsRemoteEndpoint.setBatchingAllowed(false);
+ } catch (IOException e) {
+ log.warn(sm.getString("wsSession.flushFailOnClose"), e);
+ fireEndpointOnError(e);
+ }
- sendCloseMessage(closeReasonMessage);
- if (closeSocket) {
- wsRemoteEndpoint.close();
- }
- fireEndpointOnClose(closeReasonLocal);
+ // Send the close message to the remote endpoint.
+ sendCloseMessage(closeReasonMessage);
+ fireEndpointOnClose(closeReasonLocal);
+ if (!state.compareAndSet(State.OUTPUT_CLOSING, State.OUTPUT_CLOSED) || closeSocket) {
+ /*
+ * A close message was received in another thread or this is handling an error condition. Either way, no
+ * further close message is expected to be received. Mark the session as fully closed...
+ */
+ state.set(State.CLOSED);
+ // ... and close the network connection.
+ closeConnection();
+ } else {
+ /*
+ * Set close timeout. If the client fails to send a close message response within the timeout, the session
+ * and the connection will be closed when the timeout expires.
+ */
+ sessionCloseTimeoutExpiry =
+ Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
}
+ // Fail any uncompleted messages.
IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
SendResult sr = new SendResult(ioe);
for (FutureToSendHandler f2sh : futures.keySet()) {
@@ -518,28 +533,86 @@ public class WsSession implements Sessio
* message.
*/
public void onClose(CloseReason closeReason) {
+ if (state.compareAndSet(State.OPEN, State.CLOSING)) {
+ // Standard close.
- synchronized (stateLock) {
- if (state != State.CLOSED) {
- try {
- wsRemoteEndpoint.setBatchingAllowed(false);
- } catch (IOException e) {
- log.warn(sm.getString("wsSession.flushFailOnClose"), e);
- fireEndpointOnError(e);
- }
- if (state == State.OPEN) {
- state = State.OUTPUT_CLOSED;
- sendCloseMessage(closeReason);
- fireEndpointOnClose(closeReason);
- }
- state = State.CLOSED;
+ // Flush any batched messages not yet sent.
+ try {
+ wsRemoteEndpoint.setBatchingAllowed(false);
+ } catch (IOException e) {
+ log.warn(sm.getString("wsSession.flushFailOnClose"), e);
+ fireEndpointOnError(e);
+ }
+
+ // Send the close message response to the remote endpoint.
+ sendCloseMessage(closeReason);
+ fireEndpointOnClose(closeReason);
+
+ // Mark the session as fully closed.
+ state.set(State.CLOSED);
+
+ // Close the network connection.
+ closeConnection();
+ } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
+ /*
+ * The local endpoint sent a close message the the same time as the remote endpoint. The local close is
+ * still being processed. Update the state so the the local close process will also close the network
+ * connection once it has finished sending a close message.
+ */
+ } else if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
+ /*
+ * The local endpoint sent the first close message. The remote endpoint has now responded with its own close
+ * message so mark the session as fully closed and close the network connection.
+ */
+ closeConnection();
+ }
+ // CLOSING and CLOSED are NO-OPs
+ }
+
+
+ private void closeConnection() {
+ /*
+ * Close the network connection.
+ */
+ wsRemoteEndpoint.close();
+ /*
+ * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for
+ * tracking the session close timeout.
+ */
+ webSocketContainer.unregisterSession(getSessionMapKey(), this);
+ }
- // Close the socket
- wsRemoteEndpoint.close();
+
+ /*
+ * Returns the session close timeout in milliseconds
+ */
+ protected long getSessionCloseTimeout() {
+ long result = 0;
+ Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
+ if (obj instanceof Long) {
+ result = ((Long) obj).intValue();
+ }
+ if (result <= 0) {
+ result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
+ }
+ return result;
+ }
+
+
+ protected void checkCloseTimeout() {
+ // Skip the check if no session close timeout has been set.
+ if (sessionCloseTimeoutExpiry != null) {
+ // Check if the timeout has expired.
+ if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
+ // Check if the session has been closed in another thread while the timeout was being processed.
+ if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
+ closeConnection();
+ }
}
}
}
+
private void fireEndpointOnClose(CloseReason closeReason) {
// Fire the onClose event
@@ -616,7 +689,7 @@ public class WsSession implements Sessio
if (log.isDebugEnabled()) {
log.debug(sm.getString("wsSession.sendCloseFail", id), e);
}
- wsRemoteEndpoint.close();
+ closeConnection();
// Failure to send a close message is not unexpected in the case of
// an abnormal closure (usually triggered by a failure to read/write
// from/to the client. In this case do not trigger the endpoint's
@@ -687,13 +760,13 @@ public class WsSession implements Sessio
// Always register the future.
futures.put(f2sh, f2sh);
- if (state == State.OPEN) {
+ if (isOpen()) {
// The session is open. The future has been registered with the open
// session. Normal processing continues.
return;
}
- // The session is closed. The future may or may not have been registered
+ // The session is closing / closed. The future may or may not have been registered
// in time for it to be processed during session closure.
if (f2sh.isDone()) {
@@ -703,7 +776,7 @@ public class WsSession implements Sessio
return;
}
- // The session is closed. The Future had not completed when last checked.
+ // The session is closing / closed. The Future had not completed when last checked.
// There is a small timing window that means the Future may have been
// completed since the last check. There is also the possibility that
// the Future was not registered in time to be cleaned up during session
@@ -756,6 +829,11 @@ public class WsSession implements Sessio
@Override
public Principal getUserPrincipal() {
checkState();
+ return getUserPrincipalInternal();
+ }
+
+
+ public Principal getUserPrincipalInternal() {
return userPrincipal;
}
@@ -828,7 +906,7 @@ public class WsSession implements Sessio
private void checkState() {
- if (state == State.CLOSED) {
+ if (isClosed()) {
/*
* As per RFC 6455, a WebSocket connection is considered to be
* closed once a peer has sent and received a WebSocket close frame.
@@ -839,7 +917,9 @@ public class WsSession implements Sessio
private enum State {
OPEN,
+ OUTPUT_CLOSING,
OUTPUT_CLOSED,
+ CLOSING,
CLOSED
}
Index: apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
===================================================================
--- apache-tomcat-9.0.36-src.orig/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
+++ apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
@@ -621,7 +621,12 @@ public class WsWebSocketContainer implem
synchronized (endPointSessionMapLock) {
Set<WsSession> sessions = endpointSessionMap.get(key);
if (sessions != null) {
- result.addAll(sessions);
+ // Some sessions may be in the process of closing
+ for (WsSession session : sessions) {
+ if (session.isOpen()) {
+ result.add(session);
+ }
+ }
}
}
return result;
@@ -1075,8 +1080,10 @@ public class WsWebSocketContainer implem
if (backgroundProcessCount >= processPeriod) {
backgroundProcessCount = 0;
+ // Check all registered sessions.
for (WsSession wsSession : sessions.keySet()) {
wsSession.checkExpiration();
+ wsSession.checkCloseTimeout();
}
}
Index: apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/server/WsServerContainer.java
===================================================================
--- apache-tomcat-9.0.36-src.orig/java/org/apache/tomcat/websocket/server/WsServerContainer.java
+++ apache-tomcat-9.0.36-src/java/org/apache/tomcat/websocket/server/WsServerContainer.java
@@ -415,7 +415,7 @@ public class WsServerContainer extends W
*/
@Override
protected void unregisterSession(Object key, WsSession wsSession) {
- if (wsSession.getUserPrincipal() != null &&
+ if (wsSession.getUserPrincipalInternal() != null &&
wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession,
wsSession.getHttpSessionId());
Index: apache-tomcat-9.0.36-src/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
===================================================================
--- apache-tomcat-9.0.36-src.orig/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
+++ apache-tomcat-9.0.36-src/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
@@ -23,6 +23,8 @@ import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
+import javax.servlet.ServletContextEvent;
+import javax.servlet.ServletContextListener;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
@@ -39,7 +41,9 @@ import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
+import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
+import org.apache.tomcat.websocket.server.WsServerContainer;
public class TestWsSessionSuspendResume extends WebSocketBaseTest {
@@ -141,4 +145,99 @@ public class TestWsSessionSuspendResume
}
}
}
-}
\ No newline at end of file
+
+
+ @Test
+ public void testSuspendThenClose() throws Exception {
+ Tomcat tomcat = getTomcatInstance();
+
+ Context ctx = getProgrammaticRootContext();
+ ctx.addApplicationListener(SuspendCloseConfig.class.getName());
+ ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());
+
+ Tomcat.addServlet(ctx, "default", new DefaultServlet());
+ ctx.addServletMappingDecoded("/", "default");
+
+ tomcat.start();
+
+ WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
+
+ ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
+ Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
+ new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));
+
+ wsSession.getBasicRemote().sendText("start test");
+
+ // Wait for the client response to be received by the server
+ int count = 0;
+ while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
+ Thread.sleep(100);
+ count ++;
+ }
+ Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
+ }
+
+
+ public static final class SuspendCloseConfig extends TesterEndpointConfig {
+ private static final String PATH = "/echo";
+
+ @Override
+ protected Class<?> getEndpointClass() {
+ return SuspendCloseEndpoint.class;
+ }
+
+ @Override
+ protected ServerEndpointConfig getServerEndpointConfig() {
+ return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
+ }
+ }
+
+
+ public static final class SuspendCloseEndpoint extends Endpoint {
+
+ // Yes, a static variable is a hack.
+ private static WsSession serverSession;
+
+ @Override
+ public void onOpen(Session session, EndpointConfig epc) {
+ serverSession = (WsSession) session;
+ // Set a short session close timeout (milliseconds)
+ serverSession.getUserProperties().put(
+ org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
+ // Any message will trigger the suspend then close
+ serverSession.addMessageHandler(String.class, message -> {
+ try {
+ serverSession.getBasicRemote().sendText("server session open");
+ serverSession.getBasicRemote().sendText("suspending server session");
+ serverSession.suspend();
+ serverSession.getBasicRemote().sendText("closing server session");
+ serverSession.close();
+ } catch (IOException ioe) {
+ ioe.printStackTrace();
+ // Attempt to make the failure more obvious
+ throw new RuntimeException(ioe);
+ }
+ });
+ }
+
+ @Override
+ public void onError(Session session, Throwable t) {
+ t.printStackTrace();
+ }
+
+ public static boolean isServerSessionFullyClosed() {
+ return serverSession.isClosed();
+ }
+ }
+
+
+ public static class WebSocketFastServerTimeout implements ServletContextListener {
+
+ @Override
+ public void contextInitialized(ServletContextEvent sce) {
+ WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
+ Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
+ container.setProcessPeriod(0);
+ }
+ }
+}
Index: apache-tomcat-9.0.36-src/webapps/docs/changelog.xml
===================================================================
--- apache-tomcat-9.0.36-src.orig/webapps/docs/changelog.xml
+++ apache-tomcat-9.0.36-src/webapps/docs/changelog.xml
@@ -198,6 +198,11 @@
<bug>65033</bug>: Fix JNDI realm error handling when connecting to a
failed server when pooling was not enabled. (remm)
</fix>
+ <fix>
+ <bug>66574</bug>: Refactor WebSocket session close to remove the lock on
+ the <code>SocketWrapper</code> which was a potential cause of deadlocks
+ if the application code used simulated blocking. (markt)
+ </fix>
</changelog>
</subsection>
<subsection name="Other">
@@ -1470,6 +1475,11 @@
When running on Java 9 and above, don't attempt to instantiate WebSocket
Endpoints found in modules that are not exported. (markt)
</fix>
+ <fix>
+ Ensure that WebSocket connection closure completes if the connection is
+ closed when the server side has used the proprietary suspend/resume
+ feature to suspend the connection. (markt)
+ </fix>
</changelog>
</subsection>
<subsection name="Web Applications">
Index: apache-tomcat-9.0.36-src/webapps/docs/web-socket-howto.xml
===================================================================
--- apache-tomcat-9.0.36-src.orig/webapps/docs/web-socket-howto.xml
+++ apache-tomcat-9.0.36-src/webapps/docs/web-socket-howto.xml
@@ -63,6 +63,13 @@
the timeout to use in milliseconds. For an infinite timeout, use
<code>-1</code>.</p>
+<p>The session close timeout defaults to 30000 milliseconds (30 seconds). This
+ may be changed by setting the property
+ <code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code> in the user
+ properties collection attached to the WebSocket session. The value assigned
+ to this property should be a <code>Long</code> and represents the timeout to
+ use in milliseconds. Values less than or equal to zero will be ignored.</p>
+
<p>If the application does not define a <code>MessageHandler.Partial</code> for
incoming binary messages, any incoming binary messages must be buffered so
the entire message can be delivered in a single call to the registered