diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java index 33297dd4f122..6f72bb0f953e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java @@ -345,14 +345,24 @@ public Receiptable acknowledge(StompHeaders headers, boolean consumed) { return receiptable; } - private void unsubscribe(String id, @Nullable StompHeaders headers) { - StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE); - if (headers != null) { - accessor.addNativeHeaders(headers); + private Receiptable unsubscribe(String id, @Nullable StompHeaders headers) { + Assert.hasText(id, "Subscription id is required"); + + if (headers == null){ + headers = new StompHeaders(); } + + String receiptId = checkOrAddReceipt(headers); + Receiptable receiptable = new ReceiptHandler(receiptId); + + StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE); + accessor.addNativeHeaders(headers); accessor.setSubscriptionId(id); + Message message = createMessage(accessor, EMPTY_PAYLOAD); execute(message); + + return receiptable; } @Override @@ -674,17 +684,19 @@ public StompFrameHandler getHandler() { } @Override - public void unsubscribe() { - unsubscribe(null); + public Receiptable unsubscribe() { + return unsubscribe(null); } @Override - public void unsubscribe(@Nullable StompHeaders headers) { + public Receiptable unsubscribe(@Nullable StompHeaders headers) { String id = this.headers.getId(); + Receiptable receiptable = new ReceiptHandler(null); if (id != null) { DefaultStompSession.this.subscriptions.remove(id); - DefaultStompSession.this.unsubscribe(id, headers); + receiptable = DefaultStompSession.this.unsubscribe(id, headers); } + return receiptable; } @Override diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java index d940cc74731b..bcb816c0f6a7 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java @@ -183,7 +183,7 @@ interface Subscription extends Receiptable { /** * Remove the subscription by sending an UNSUBSCRIBE frame. */ - void unsubscribe(); + Receiptable unsubscribe(); /** * Alternative to {@link #unsubscribe()} with additional custom headers @@ -192,7 +192,7 @@ interface Subscription extends Receiptable { * @param headers the custom headers, if any * @since 5.0 */ - void unsubscribe(@Nullable StompHeaders headers); + Receiptable unsubscribe(@Nullable StompHeaders headers); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java index 09b092f86be2..c310afb870f7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.BeforeEach; @@ -662,6 +663,75 @@ public void receiptNotReceived() { verifyNoMoreInteractions(future); } + @Test + void unsubscribeWithReceipt() { + this.session.afterConnected(this.connection); + assertThat(this.session.isConnected()).isTrue(); + Subscription subscription = this.session.subscribe("/topic/foo", mock()); + + Receiptable receipt = subscription.unsubscribe(); + assertThat(receipt).isNotNull(); + assertThat(receipt.getReceiptId()).isNull(); + + Message message = this.messageCaptor.getValue(); + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertThat(accessor.getCommand()).isEqualTo(StompCommand.UNSUBSCRIBE); + + StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(accessor.getNativeHeaders()); + assertThat(stompHeaders).hasSize(1); + assertThat(stompHeaders.getId()).isEqualTo(subscription.getSubscriptionId()); + } + + @Test + void unsubscribeWithCustomHeaderAndReceipt() { + this.session.afterConnected(this.connection); + this.session.setTaskScheduler(mock()); + this.session.setAutoReceipt(true); + + StompHeaders subHeaders = new StompHeaders(); + subHeaders.setDestination("/topic/foo"); + Subscription subscription = this.session.subscribe(subHeaders, mock()); + + StompHeaders custom = new StompHeaders(); + custom.set("x-cust", "value"); + + Receiptable receipt = subscription.unsubscribe(custom); + assertThat(receipt).isNotNull(); + assertThat(receipt.getReceiptId()).isNotNull(); + + Message message = this.messageCaptor.getValue(); + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertThat(accessor.getCommand()).isEqualTo(StompCommand.UNSUBSCRIBE); + + StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(accessor.getNativeHeaders()); + assertThat(stompHeaders.getId()).isEqualTo(subscription.getSubscriptionId()); + assertThat(stompHeaders.get("x-cust")).containsExactly("value"); + assertThat(stompHeaders.getReceipt()).isEqualTo(receipt.getReceiptId()); + } + + @Test + void receiptReceivedOnUnsubscribe() { + this.session.afterConnected(this.connection); + TaskScheduler scheduler = mock(); + this.session.setTaskScheduler(scheduler); + this.session.setAutoReceipt(true); + + Subscription subscription = this.session.subscribe("/topic/foo", mock()); + Receiptable receipt = subscription.unsubscribe(); + + StompHeaderAccessor ack = StompHeaderAccessor.create(StompCommand.RECEIPT); + ack.setReceiptId(receipt.getReceiptId()); + ack.setLeaveMutable(true); + Message receiptMessage = MessageBuilder.createMessage(new byte[0], ack.getMessageHeaders()); + + AtomicBoolean called = new AtomicBoolean(false); + receipt.addReceiptTask(() -> called.set(true)); + + this.session.handleMessage(receiptMessage); + + assertThat(called.get()).isTrue(); + } + @Test void disconnect() { this.session.afterConnected(this.connection);