Skip to content

Commit

Permalink
Use semaphore to allow parallel creation of pool connections.
Browse files Browse the repository at this point in the history
  • Loading branch information
Harmandeep Singh committed Oct 1, 2020
1 parent e6ff797 commit a6cf9aa
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 115 deletions.
3 changes: 0 additions & 3 deletions rubix-spi/src/main/java/com/qubole/rubix/spi/CacheConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
*/
package com.qubole.rubix.spi;

import com.google.common.collect.ImmutableList;
import org.apache.hadoop.conf.Configuration;

import java.util.List;

import static com.qubole.rubix.spi.utils.DataSizeUnits.MEGABYTES;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
*/
public interface ObjectFactory<T>
{
T create(String host, int socketTimeout, int connectTimeout);
T create(String host, int socketTimeout, int connectTimeout)
throws Exception;

void destroy(T t);

Expand Down
15 changes: 8 additions & 7 deletions rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
import com.google.common.util.concurrent.AbstractScheduledService;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import static java.lang.Thread.currentThread;
Expand Down Expand Up @@ -78,8 +76,11 @@ public Poolable<T> borrowObject(String host)
}
log.debug(this.name + " : Borrowing object for partition: " + host);
for (int i = 0; i < 3; i++) { // try at most three times
Poolable<T> result = getObject(false, host);
if (factory.validate(result.getObject())) {
Poolable<T> result = getObject(host);
if (result == null) {
continue;
}
else if (factory.validate(result.getObject())) {
return result;
}
else {
Expand All @@ -89,10 +90,10 @@ public Poolable<T> borrowObject(String host)
throw new RuntimeException("Cannot find a valid object");
}

private Poolable<T> getObject(boolean blocking, String host)
private Poolable<T> getObject(String host)
{
ObjectPoolPartition<T> subPool = this.hostToPoolMap.get(host);
return subPool.getObject(blocking);
return subPool.getObject();
}

public void returnObject(Poolable<T> obj)
Expand All @@ -105,7 +106,7 @@ public int getSize()
{
int size = 0;
for (ObjectPoolPartition<T> subPool : hostToPoolMap.values()) {
size += subPool.getTotalCount();
size += subPool.getAliveObjectCount();
}
return size;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.apache.commons.logging.LogFactory;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static com.google.common.base.Preconditions.checkState;

Expand All @@ -35,10 +37,11 @@ public class ObjectPoolPartition<T>
private final PoolConfig config;
private final BlockingQueue<Poolable<T>> objectQueue;
private final ObjectFactory<T> objectFactory;
private int totalCount;
private final String host;
private final int socketTimeout;
private final int connectTimeout;
private final Semaphore takeSemaphore;
private final AtomicInteger aliveObjectCount;

public ObjectPoolPartition(ObjectPool<T> pool, PoolConfig config,
ObjectFactory<T> objectFactory, BlockingQueue<Poolable<T>> queue, String host, String name)
Expand All @@ -50,113 +53,41 @@ public ObjectPoolPartition(ObjectPool<T> pool, PoolConfig config,
this.host = host;
this.socketTimeout = config.getSocketTimeoutMilliseconds();
this.connectTimeout = config.getConnectTimeoutMilliseconds();
this.totalCount = 0;
this.aliveObjectCount = new AtomicInteger();
this.log = new CustomLogger(name, host);
for (int i = 0; i < config.getMinSize(); i++) {
T object = objectFactory.create(host, socketTimeout, connectTimeout);
if (object != null) {
objectQueue.add(new Poolable<>(object, pool, host));
totalCount++;
}
}
}

public void returnObject(Poolable<T> object)
{
if (!objectFactory.validate(object.getObject())) {
log.debug(String.format("Invalid object...removing: %s ", object));
decreaseObject(object);
return;
}

log.debug(String.format("Returning object: %s to queue. Queue size: %d", object, objectQueue.size()));
if (!objectQueue.offer(object)) {
log.warn("Created more objects than configured. Created=" + totalCount + " QueueSize=" + objectQueue.size());
decreaseObject(object);
}
}

public Poolable<T> getObject(boolean blocking)
{
if (objectQueue.size() == 0) {
// increase objects and return one, it will return null if pool reaches max size or if object creation fails
Poolable<T> object = increaseObjects(this.config.getDelta(), true);

if (object != null) {
return object;
}

if (totalCount == 0) {
// Could not create objects, this is mostly due to connection timeouts hence no point blocking as there is not other producer of sockets
throw new RuntimeException("Could not add connections to pool");
}
// else wait for a connection to get free
}

Poolable<T> freeObject;
this.takeSemaphore = new Semaphore(config.getMaxSize(), true);
try {
if (blocking) {
freeObject = objectQueue.take();
}
else {
freeObject = objectQueue.poll(config.getMaxWaitMilliseconds(), TimeUnit.MILLISECONDS);
if (freeObject == null) {
throw new RuntimeException("Cannot get a free object from the pool");
}
for (int i = 0; i < config.getMinSize(); i++) {
T object = objectFactory.create(host, socketTimeout, connectTimeout);
objectQueue.add(new Poolable<>(object, pool, host));
aliveObjectCount.incrementAndGet();
}
}
catch (InterruptedException e) {
throw new RuntimeException(e); // will never happen
catch (Exception e) {
// skipping logging the exception as factories are already logging.
}

freeObject.setLastAccessTs(System.currentTimeMillis());
return freeObject;
}

private Poolable<T> increaseObjects(int delta, boolean returnObject)
public void returnObject(Poolable<T> object)
{
int oldCount = totalCount;
if (delta + totalCount > config.getMaxSize()) {
delta = config.getMaxSize() - totalCount;
}

Poolable<T> objectToReturn = null;
try {
for (int i = 0; i < delta; i++) {
T object = objectFactory.create(host, socketTimeout, connectTimeout);
if (object != null) {
// Do not put the first object on queue
// it will be returned to the caller to ensure it's request is satisfied first if object is requested
Poolable<T> poolable = new Poolable<>(object, pool, host);
if (objectToReturn == null && returnObject) {
objectToReturn = poolable;
}
else {
objectQueue.put(poolable);
}
totalCount++;
}
if (!objectFactory.validate(object.getObject())) {
log.debug(String.format("Invalid object...removing: %s ", object));
decreaseObject(object);
return;
}

if (delta > 0 && (totalCount - oldCount) == 0) {
log.warn(String.format("Could not increase pool size. Pool state: totalCount=%d queueSize=%d delta=%d", totalCount, objectQueue.size(), delta));
}
else {
log.debug(String.format("Increased pool size by %d, to new size: %d, current queue size: %d, delta: %d",
totalCount - oldCount, totalCount, objectQueue.size(), delta));
log.debug(String.format("Returning object: %s to queue. Queue size: %d", object, objectQueue.size()));
if (!objectQueue.offer(object)) {
String errorLog = "Created more objects than configured. Created=" + aliveObjectCount + " QueueSize=" + objectQueue.size();
log.warn(errorLog);
decreaseObject(object);
throw new RuntimeException(errorLog);
}
}
catch (Exception e) {
log.warn(String.format("Unable to increase pool size. Pool state: totalCount=%d queueSize=%d delta=%d", totalCount, objectQueue.size(), delta), e);
// objectToReturn is not on the queue hence untracked, clean it up before forwarding exception
if (objectToReturn != null) {
objectFactory.destroy(objectToReturn.getObject());
objectToReturn.destroy();
}
throw new RuntimeException(e);
finally {
takeSemaphore.release();
}

return objectToReturn;
}

public boolean decreaseObject(Poolable<T> obj)
Expand All @@ -165,27 +96,69 @@ public boolean decreaseObject(Poolable<T> obj)
checkState(obj.getHost().equals(this.host),
"Call to free object of wrong partition, current partition=%s requested partition = %s",
this.host, obj.getHost());
objectRemoved();
log.debug("Decreasing pool size object: " + obj);
objectFactory.destroy(obj.getObject());
aliveObjectCount.decrementAndGet();
obj.destroy();
return true;
}

private synchronized void objectRemoved()
public Poolable<T> getObject()
{
Poolable<T> object;
try {
takeSemaphore.tryAcquire(config.getMaxWaitMilliseconds(), TimeUnit.MILLISECONDS);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}

try {
object = tryGetObject();
object.setLastAccessTs(System.currentTimeMillis());
}
catch (Exception e) {
takeSemaphore.release();
throw new RuntimeException("Cannot get a free object from the pool", e);
}
return object;
}

private Poolable<T> tryGetObject() throws Exception
{
totalCount--;
Poolable<T> poolable = objectQueue.poll();
if (poolable == null)
{
try {
T object = objectFactory.create(host, socketTimeout, connectTimeout);
poolable = new Poolable<>(object, pool, host);
aliveObjectCount.incrementAndGet();
log.debug(String.format("Added a connection, Pool state: totalCount: %s, queueSize: %d", aliveObjectCount,
objectQueue.size()));
}
catch (Exception e) {
log.warn(String.format("Unable create a connection. Pool state: totalCount=%s queueSize=%d", aliveObjectCount,
objectQueue.size()), e);
if (poolable != null) {
objectFactory.destroy(poolable.getObject());
poolable.destroy();
}
throw e;
}
}
return poolable;
}

public synchronized int getTotalCount()
public int getAliveObjectCount()
{
return totalCount;
return aliveObjectCount.get();
}

// set the scavenge interval carefully
public void scavenge() throws InterruptedException
{
int delta = this.totalCount - config.getMinSize();
int delta = this.aliveObjectCount.get() - config.getMinSize();
if (delta <= 0) {
log.debug("Scavenge for delta <= 0, Skipping !!!");
return;
Expand Down Expand Up @@ -223,7 +196,7 @@ public void scavenge() throws InterruptedException
public synchronized int shutdown()
{
int removed = 0;
while (this.totalCount > 0) {
while (this.aliveObjectCount.get() > 0) {
Poolable<T> obj = objectQueue.poll();
if (obj != null) {
decreaseObject(obj);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ public SocketChannelObjectFactory(int port)

@Override
public SocketChannel create(String host, int socketTimeout, int connectTimeout)
throws IOException
{
SocketAddress sad = new InetSocketAddress(host, this.port);
SocketChannel socket = null;
SocketChannel socket;
try {
socket = SocketChannel.open();
socket.socket().setSoTimeout(socketTimeout);
Expand All @@ -49,6 +50,7 @@ public SocketChannel create(String host, int socketTimeout, int connectTimeout)
}
catch (IOException e) {
log.warn(LDS_POOL + " : Unable to open connection to host " + host, e);
throw e;
}
return socket;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ public SocketObjectFactory(int port)

@Override
public TSocket create(String host, int socketTimeout, int connectTimeout)
throws TTransportException
{
log.debug(BKS_POOL + " : Opening connection to host: " + host);
TSocket socket = null;
TSocket socket;
try {
socket = new TSocket(host, port, socketTimeout, connectTimeout);
socket.open();
}
catch (TTransportException e) {
socket = null;
log.warn("Unable to open connection to host " + host, e);
throw e;
}
return socket;
}
Expand Down

0 comments on commit a6cf9aa

Please sign in to comment.