From 6d0738974fc3451508112dfdaeee2e174666e1cc Mon Sep 17 00:00:00 2001 From: Oleh Dokuka <5380167+OlegDokuka@users.noreply.github.com> Date: Mon, 24 Apr 2023 12:27:07 +0300 Subject: [PATCH] adds class check for discarded values (#1091) --- .../io/rsocket/core/DefaultRSocketClient.java | 17 +++++----- .../main/java/io/rsocket/core/SendUtils.java | 12 ++++--- .../core/DefaultRSocketClientTests.java | 22 +++++++++++++ .../java/io/rsocket/core/SendUtilsTest.java | 31 +++++++++++++++++++ 4 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java index 9cd89c0b1..82a02268d 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java @@ -46,13 +46,16 @@ */ class DefaultRSocketClient extends ResolvingOperator implements CoreSubscriber, CorePublisher, RSocketClient { - static final Consumer DISCARD_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored + static final Consumer DISCARD_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + ReferenceCounted referenceCounted = ((ReferenceCounted) data); + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } } } }; diff --git a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java index 53d222605..568dada2e 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java @@ -40,11 +40,13 @@ final class SendUtils { private static final Consumer DROPPED_ELEMENTS_CONSUMER = data -> { - try { - ReferenceCounted referenceCounted = (ReferenceCounted) data; - referenceCounted.release(); - } catch (Throwable e) { - // ignored + if (data instanceof ReferenceCounted) { + try { + ReferenceCounted referenceCounted = (ReferenceCounted) data; + referenceCounted.release(); + } catch (Throwable e) { + // ignored + } } }; diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index a8a5f2e58..84576e6ce 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -39,6 +39,7 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; import org.assertj.core.api.Assertions; @@ -49,6 +50,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.publisher.Flux; @@ -79,6 +81,26 @@ public void setUp() throws Throwable { public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.allocator.assertHasNoLeaks(); + } + + @Test + @SuppressWarnings("unchecked") + void discardElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = Mockito.mock(ReferenceCounted.class); + Mockito.when(referenceCounted.release()).thenReturn(true); + Mockito.when(referenceCounted.refCnt()).thenReturn(1); + + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(referenceCounted); + + Mockito.verify(referenceCounted).release(); } static Stream interactions() { diff --git a/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java new file mode 100644 index 000000000..9a51b9419 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java @@ -0,0 +1,31 @@ +package io.rsocket.core; + +import static org.mockito.Mockito.*; + +import io.netty.util.ReferenceCounted; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; + +public class SendUtilsTest { + + @Test + void droppedElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer value = extractDroppedElementConsumer(); + value.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = mock(ReferenceCounted.class); + when(referenceCounted.release()).thenReturn(true); + + Consumer value = extractDroppedElementConsumer(); + value.accept(referenceCounted); + + verify(referenceCounted).release(); + } + + private static Consumer extractDroppedElementConsumer() { + return (Consumer) SendUtils.DISCARD_CONTEXT.stream().findAny().get().getValue(); + } +}