From 8cfb237fdc3c0af7a863ab665ed56fb80a89ba57 Mon Sep 17 00:00:00 2001 From: Carter Kozak Date: Mon, 2 Oct 2023 16:52:54 -0400 Subject: [PATCH] fix #1170: Use ReentrantLock instead of Object monitors --- .../java/org/conscrypt/ConscryptEngine.java | 136 ++++++++++++++---- .../org/conscrypt/ConscryptEngineSocket.java | 105 ++++++++++---- 2 files changed, 191 insertions(+), 50 deletions(-) diff --git a/common/src/main/java/org/conscrypt/ConscryptEngine.java b/common/src/main/java/org/conscrypt/ConscryptEngine.java index a58aa73cb..c56945b83 100644 --- a/common/src/main/java/org/conscrypt/ConscryptEngine.java +++ b/common/src/main/java/org/conscrypt/ConscryptEngine.java @@ -77,6 +77,7 @@ import java.security.interfaces.ECKey; import java.security.spec.ECParameterSpec; import java.util.Arrays; +import java.util.concurrent.locks.ReentrantLock; import javax.crypto.SecretKey; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; @@ -137,6 +138,11 @@ final class ConscryptEngine extends AbstractConscryptEngine implements NativeCry */ private final NativeSsl ssl; + /** + * Lock used for {@link #ssl} access. + */ + private final ReentrantLock sslLock = new ReentrantLock(); + /** * The BIO used for reading/writing encrypted bytes. */ @@ -227,12 +233,15 @@ static BufferAllocator getDefaultBufferAllocator() { @Override void setBufferAllocator(BufferAllocator bufferAllocator) { - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalStateException( "Could not set buffer allocator after the initial handshake has begun."); } this.bufferAllocator = bufferAllocator; + } finally { + sslLock.unlock(); } } @@ -254,7 +263,8 @@ int maxSealOverhead() { */ @Override void setChannelIdEnabled(boolean enabled) { - synchronized (ssl) { + sslLock.lock(); + try { if (getUseClientMode()) { throw new IllegalStateException("Not allowed in client mode"); } @@ -263,6 +273,8 @@ void setChannelIdEnabled(boolean enabled) { "Could not enable/disable Channel ID after the initial handshake has begun."); } sslParameters.channelIdEnabled = enabled; + } finally { + sslLock.unlock(); } } @@ -278,7 +290,8 @@ void setChannelIdEnabled(boolean enabled) { */ @Override byte[] getChannelId() throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { if (getUseClientMode()) { throw new IllegalStateException("Not allowed in client mode"); } @@ -288,6 +301,8 @@ byte[] getChannelId() throws SSLException { "Channel ID is only available after handshake completes"); } return ssl.getTlsChannelId(); + } finally { + sslLock.unlock(); } } @@ -309,7 +324,8 @@ void setChannelIdPrivateKey(PrivateKey privateKey) { throw new IllegalStateException("Not allowed in server mode"); } - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalStateException("Could not change Channel ID private key " + "after the initial handshake has begun."); @@ -337,6 +353,8 @@ void setChannelIdPrivateKey(PrivateKey privateKey) { } catch (InvalidKeyException e) { // Will have error in startHandshake } + } finally { + sslLock.unlock(); } } @@ -345,12 +363,15 @@ void setChannelIdPrivateKey(PrivateKey privateKey) { */ @Override void setHandshakeListener(HandshakeListener handshakeListener) { - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalStateException( "Handshake listener must be set before starting the handshake."); } this.handshakeListener = handshakeListener; + } finally { + sslLock.unlock(); } } @@ -397,8 +418,11 @@ public int getPeerPort() { @Override public void beginHandshake() throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { beginHandshakeInternal(); + } finally { + sslLock.unlock(); } } @@ -452,7 +476,8 @@ private void beginHandshakeInternal() throws SSLException { @Override public void closeInbound() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED || state == STATE_CLOSED_INBOUND) { return; } @@ -467,12 +492,15 @@ public void closeInbound() { // Never started the handshake. Just close now. closeAndFreeResources(); } + } finally { + sslLock.unlock(); } } @Override public void closeOutbound() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED || state == STATE_CLOSED_OUTBOUND) { return; } @@ -488,6 +516,8 @@ public void closeOutbound() { // Never started the handshake. Just close now. closeAndFreeResources(); } + } finally { + sslLock.unlock(); } } @@ -527,8 +557,11 @@ public void setSSLParameters(SSLParameters p) { @Override public HandshakeStatus getHandshakeStatus() { - synchronized (ssl) { + sslLock.lock(); + try { return getHandshakeStatusInternal(); + } finally { + sslLock.unlock(); } } @@ -578,7 +611,8 @@ public boolean getNeedClientAuth() { */ @Override SSLSession handshakeSession() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_HANDSHAKE_STARTED) { return Platform.wrapSSLSession(new ExternalSession(new ExternalSession.Provider() { @Override @@ -588,6 +622,8 @@ public ConscryptSession provideSession() { })); } return null; + } finally { + sslLock.unlock(); } } @@ -597,7 +633,8 @@ public SSLSession getSession() { } private ConscryptSession provideSession() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { return closedSession != null ? closedSession : SSLNullSession.getNullSession(); } @@ -606,13 +643,18 @@ private ConscryptSession provideSession() { return SSLNullSession.getNullSession(); } return activeSession; + } finally { + sslLock.unlock(); } } private ConscryptSession provideHandshakeSession() { - synchronized (ssl) { + sslLock.lock(); + try { return state == STATE_HANDSHAKE_STARTED ? activeSession : SSLNullSession.getNullSession(); + } finally { + sslLock.unlock(); } } @@ -646,21 +688,27 @@ public boolean getWantClientAuth() { @Override public boolean isInboundDone() { - synchronized (ssl) { + sslLock.lock(); + try { return (state == STATE_CLOSED || state == STATE_CLOSED_INBOUND || ssl.wasShutdownReceived()) && (pendingInboundCleartextBytes() == 0); + } finally { + sslLock.unlock(); } } @Override public boolean isOutboundDone() { - synchronized (ssl) { + sslLock.lock(); + try { return (state == STATE_CLOSED || state == STATE_CLOSED_OUTBOUND || ssl.wasShutdownSent()) && (pendingOutboundEncryptedBytes() == 0); + } finally { + sslLock.unlock(); } } @@ -686,13 +734,16 @@ public void setNeedClientAuth(boolean need) { @Override public void setUseClientMode(boolean mode) { - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalArgumentException( "Can not change mode after handshake: state == " + state); } transitionTo(STATE_MODE_SET); sslParameters.setUseClientMode(mode); + } finally { + sslLock.unlock(); } } @@ -703,36 +754,45 @@ public void setWantClientAuth(boolean want) { @Override public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return unwrap(singleSrcBuffer(src), singleDstBuffer(dst)); } finally { resetSingleSrcBuffer(); resetSingleDstBuffer(); } + } finally { + sslLock.unlock(); } } @Override public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return unwrap(singleSrcBuffer(src), dsts); } finally { resetSingleSrcBuffer(); } + } finally { + sslLock.unlock(); } } @Override public SSLEngineResult unwrap(final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return unwrap(singleSrcBuffer(src), 0, 1, dsts, offset, length); } finally { resetSingleSrcBuffer(); } + } finally { + sslLock.unlock(); } } @@ -759,7 +819,8 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe final int srcsEndOffset = srcsOffset + srcsLength; final long srcLength = calcSrcsLength(srcs, srcsOffset, srcsEndOffset); - synchronized (ssl) { + sslLock.lock(); + try { switch (state) { case STATE_MODE_SET: // Begin the handshake implicitly. @@ -930,6 +991,8 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe } return newResult(bytesConsumed, bytesProduced, handshakeStatus); + } finally { + sslLock.unlock(); } } @@ -1366,12 +1429,15 @@ private SSLEngineResult newResult(int bytesConsumed, int bytesProduced, @Override public SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return wrap(singleSrcBuffer(src), dst); } finally { resetSingleSrcBuffer(); } + } finally { + sslLock.unlock(); } } @@ -1390,7 +1456,8 @@ public SSLEngineResult wrap(ByteBuffer[] srcs, int srcsOffset, int srcsLength, B } BufferUtils.checkNotNull(srcs); - synchronized (ssl) { + sslLock.lock(); + try { switch (state) { case STATE_MODE_SET: // Begin the handshake implicitly. @@ -1542,6 +1609,8 @@ public SSLEngineResult wrap(ByteBuffer[] srcs, int srcsOffset, int srcsLength, B } } return newResult(bytesConsumed, bytesProduced, handshakeStatus); + } finally { + sslLock.unlock(); } } @@ -1557,7 +1626,8 @@ public int serverPSKKeyRequested(String identityHint, String identity, byte[] ke @Override public void onSSLStateChange(int type, int val) { - synchronized (ssl) { + sslLock.lock(); + try { switch (type) { case SSL_CB_HANDSHAKE_START: { // For clients, this will allow the NEED_UNWRAP status to be @@ -1577,13 +1647,18 @@ public void onSSLStateChange(int type, int val) { default: // Ignore } + } finally { + sslLock.unlock(); } } @Override public void serverCertificateRequested() throws IOException { - synchronized (ssl) { + sslLock.lock(); + try { ssl.configureServerCertificate(); + } finally { + sslLock.unlock(); } } @@ -1677,8 +1752,11 @@ protected void finalize() throws Throwable { // If ssl is null, object must not be fully constructed so nothing for us to do here. if (ssl != null) { // Otherwise closeAndFreeResources() and callees expect to synchronize on ssl. - synchronized (ssl) { + sslLock.lock(); + try { closeAndFreeResources(); + } finally { + sslLock.unlock(); } } } finally { @@ -1758,10 +1836,13 @@ byte[] getTlsUnique() { @Override byte[] exportKeyingMaterial(String label, byte[] context, int length) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { if (state < STATE_HANDSHAKE_COMPLETED || state == STATE_CLOSED) { return null; } + } finally { + sslLock.unlock(); } return ssl.exportKeyingMaterial(label, context, length); } @@ -1786,8 +1867,11 @@ public String getApplicationProtocol() { @Override public String getHandshakeApplicationProtocol() { - synchronized (ssl) { + sslLock.lock(); + try { return state >= STATE_HANDSHAKE_STARTED ? getApplicationProtocol() : null; + } finally { + sslLock.unlock(); } } diff --git a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java index f05fe25aa..62657f6ee 100644 --- a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java +++ b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java @@ -36,6 +36,8 @@ import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; @@ -54,8 +56,9 @@ class ConscryptEngineSocket extends OpenSSLSocketImpl implements SSLParametersIm private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private final ConscryptEngine engine; - private final Object stateLock = new Object(); - private final Object handshakeLock = new Object(); + private final ReentrantLock stateLock = new ReentrantLock(); + private final Condition stateLockCondition = stateLock.newCondition(); + private final ReentrantLock handshakeLock = new ReentrantLock(); private SSLOutputStream out; private SSLInputStream in; @@ -188,10 +191,12 @@ public final void startHandshake() throws IOException { checkOpen(); try { - synchronized (handshakeLock) { + handshakeLock.lock(); + try { // Only lock stateLock when we begin the handshake. This is done so that we don't // hold the stateLock when we invoke the handshake completion listeners. - synchronized (stateLock) { + stateLock.lock(); + try { // Initialize the handshake if we haven't already. if (state == STATE_NEW) { transitionTo(STATE_HANDSHAKE_STARTED); @@ -206,8 +211,12 @@ public final void startHandshake() throws IOException { // ignore addition handshake calls. return; } + } finally { + stateLock.unlock(); } doHandshake(); + } finally { + handshakeLock.unlock(); } } catch (IOException e) { close(); @@ -277,13 +286,17 @@ private void doHandshake() throws IOException { } private boolean isState(int desiredState) { - synchronized (stateLock) { + stateLock.lock(); + try { return state == desiredState; + } finally { + stateLock.unlock(); } } private int transitionTo(int newState) { - synchronized (stateLock) { + stateLock.lock(); + try { if (state == newState) { return state; } @@ -328,9 +341,11 @@ private int transitionTo(int newState) { state = newState; if (notify) { - stateLock.notifyAll(); + stateLockCondition.signalAll(); } return previousState; + } finally { + stateLock.unlock(); } } @@ -341,10 +356,13 @@ public final InputStream getInputStream() throws IOException { } private SSLInputStream createInputStream() { - synchronized (stateLock) { + stateLock.lock(); + try { if (in == null) { in = new SSLInputStream(); } + } finally { + stateLock.unlock(); } return in; } @@ -356,10 +374,13 @@ public final OutputStream getOutputStream() throws IOException { } private SSLOutputStream createOutputStream() { - synchronized (stateLock) { + stateLock.lock(); + try { if (out == null) { out = new SSLOutputStream(); } + } finally { + stateLock.unlock(); } return out; } @@ -594,13 +615,14 @@ private void onEngineHandshakeFinished() { private void waitForHandshake() throws IOException { startHandshake(); - synchronized (stateLock) { + stateLock.lock(); + try { while (state != STATE_READY // Waiting threads are allowed to compete with handshake listeners for access. && state != STATE_READY_HANDSHAKE_CUT_THROUGH && state != STATE_CLOSED) { try { - stateLock.wait(); + stateLockCondition.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted waiting for handshake", e); @@ -610,6 +632,8 @@ private void waitForHandshake() throws IOException { if (state == STATE_CLOSED) { throw new SocketException("Socket is closed"); } + } finally { + stateLock.unlock(); } } @@ -648,7 +672,7 @@ public final String chooseClientAlias(X509KeyManager keyManager, X500Principal[] * Wrap bytes written to the underlying socket. */ private final class SSLOutputStream extends OutputStream { - private final Object writeLock = new Object(); + private final ReentrantLock writeLock = new ReentrantLock(); private final ByteBuffer target; private final int targetArrayOffset; private OutputStream socketOutputStream; @@ -666,24 +690,33 @@ public void close() throws IOException { @Override public void write(int b) throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { write(new byte[] {(byte) b}); + } finally { + writeLock.unlock(); } } @Override public void write(byte[] b) throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { writeInternal(ByteBuffer.wrap(b)); + } finally { + writeLock.unlock(); } } @Override public void write(byte[] b, int off, int len) throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { writeInternal(ByteBuffer.wrap(b, off, len)); + } finally { + writeLock.unlock(); } } @@ -727,8 +760,11 @@ private void writeInternal(ByteBuffer buffer) throws IOException { @Override public void flush() throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { flushInternal(); + } finally { + writeLock.unlock(); } } @@ -754,7 +790,7 @@ private void writeToSocket() throws IOException { * Unwrap bytes read from the underlying socket. */ private final class SSLInputStream extends InputStream { - private final Object readLock = new Object(); + private final ReentrantLock readLock = new ReentrantLock(); private final byte[] singleByte = new byte[1]; private final ByteBuffer fromEngine; private final ByteBuffer fromSocket; @@ -783,17 +819,21 @@ public void close() throws IOException { } void release() { - synchronized (readLock) { + readLock.lock(); + try { if (allocatedBuffer != null) { allocatedBuffer.release(); } + } finally { + readLock.unlock(); } } @Override public int read() throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { // Handle returning of -1 if EOF is reached. int count = read(singleByte, 0, 1); if (count == -1) { @@ -804,31 +844,42 @@ public int read() throws IOException { throw new SSLException("read incorrect number of bytes " + count); } return singleByte[0] & 0xff; + } finally { + readLock.unlock(); } } @Override public int read(byte[] b) throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { return read(b, 0, b.length); + } finally { + readLock.unlock(); } } @Override public int read(byte[] b, int off, int len) throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { return readUntilDataAvailable(b, off, len); + } finally { + readLock.unlock(); } } @Override public int available() throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { init(); return fromEngine.remaining(); + } finally { + readLock.unlock(); } } @@ -941,8 +992,11 @@ && isHandshakeFinished()) { } private boolean isHandshakeFinished() { - synchronized (stateLock) { + stateLock.lock(); + try { return state > STATE_HANDSHAKE_STARTED; + } finally { + stateLock.unlock(); } } @@ -950,8 +1004,11 @@ private boolean isHandshakeFinished() { * Processes a renegotiation received from the remote peer. */ private void renegotiate() throws IOException { - synchronized (handshakeLock) { + handshakeLock.lock(); + try { doHandshake(); + } finally { + handshakeLock.unlock(); } }