Add batch consumption of messages (#67)
nguyenuy authored and manub committed Feb 3, 2017
1 parent e122a2c commit d6b6a74
Showing 2 changed files with 124 additions and 4 deletions.
@@ -16,6 +16,7 @@ import org.scalatest.Suite

import scala.collection.JavaConversions.mapAsJavaMap
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, TimeoutException}
import scala.language.{higherKinds, postfixOps}
@@ -210,10 +211,23 @@ sealed trait EmbeddedKafkaSupport {
ProducerConfig.RETRY_BACKOFF_MS_CONFIG -> 1000.toString
)

private def baseConsumerConfig(implicit config: EmbeddedKafkaConfig): Properties = {
val props = new Properties()
props.put("group.id", "embedded-kafka-spec")
props.put("bootstrap.servers", s"localhost:${config.kafkaPort}")
props.put("auto.offset.reset", "earliest")
props.put("enable.auto.commit", "false")
props
}

def consumeFirstStringMessageFrom(topic: String, autoCommit: Boolean = false)(
implicit config: EmbeddedKafkaConfig): String =
consumeFirstMessageFrom(topic, autoCommit)(config, new StringDeserializer())

def consumeNumberStringMessagesFrom(topic: String, number: Int, autoCommit: Boolean = false)(
implicit config: EmbeddedKafkaConfig): List[String] =
consumeNumberMessagesFrom(topic, number, autoCommit)(config, new StringDeserializer())

/**
* Consumes the first message available in a given topic, deserializing it as a String.
*
@@ -238,10 +252,7 @@ sealed trait EmbeddedKafkaSupport {

import scala.collection.JavaConversions._

- val props = new Properties()
- props.put("group.id", s"embedded-kafka-spec")
- props.put("bootstrap.servers", s"localhost:${config.kafkaPort}")
- props.put("auto.offset.reset", "earliest")
+ val props = baseConsumerConfig
props.put("enable.auto.commit", autoCommit.toString)

val consumer =
@@ -271,6 +282,69 @@ sealed trait EmbeddedKafkaSupport {
}.get
}

/**
* Consumes the first n messages available in a given topic, deserializing each one
* with the given deserializer, and returns the n messages as a List.
*
* Only the messages that are returned are committed if autoCommit is false.
* If autoCommit is true then all messages that were polled will be committed.
*
* @param topic the topic to consume messages from
* @param number the number of messages to consume in a batch
* @param autoCommit if false, only the offsets of the consumed messages will be committed.
* if true, the offset for the last polled message will be committed instead.
* Defaulted to false.
* @param config an implicit [[EmbeddedKafkaConfig]]
* @param deserializer an implicit [[org.apache.kafka.common.serialization.Deserializer]] for the type [[T]]
* @return the first n messages consumed from the given topic, as a List of type [[T]]
* @throws TimeoutException if unable to consume a message within 5 seconds
* @throws KafkaUnavailableException if unable to connect to Kafka
*/
def consumeNumberMessagesFrom[T](topic: String, number: Int, autoCommit: Boolean = false)(
implicit config: EmbeddedKafkaConfig,
deserializer: Deserializer[T]): List[T] = {

import scala.collection.JavaConverters._

val props = baseConsumerConfig
props.put("enable.auto.commit", autoCommit.toString)

val consumer =
new KafkaConsumer[String, T](props, new StringDeserializer, deserializer)

val messages = Try {
val messagesBuffer = ListBuffer.empty[T]
var messagesRead = 0
consumer.subscribe(List(topic).asJava)
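// eagerly fetch partition metadata so connection problems surface before the first poll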
consumer.partitionsFor(topic)

while (messagesRead < number) {
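// each poll waits up to 5 seconds; an empty result is treated as a timeout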
val records = consumer.poll(5000)
if (records.isEmpty) {
throw new TimeoutException(
"Unable to retrieve a message from Kafka in 5000ms")
}

val recordIter = records.iterator()
while (recordIter.hasNext && messagesRead < number) {
val record = recordIter.next()
messagesBuffer += record.value()
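// commit each record's offset (offset + 1 = next position to read) as it is consumed,
// so only the messages actually returned are marked as consumed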
val tp = new TopicPartition(record.topic(), record.partition())
val om = new OffsetAndMetadata(record.offset() + 1)
consumer.commitSync(Map(tp -> om).asJava)
messagesRead += 1
}
}
messagesBuffer.toList
}

consumer.close()
messages.recover {
case ex: KafkaException => throw new KafkaUnavailableException(ex)
}.get
}
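For reference, a minimal usage sketch of the new batch API (the topic name and payloads are illustrative, and an implicit EmbeddedKafkaConfig is assumed to be in scope, as elsewhere in the suite):

withRunningKafka {
publishStringMessageToKafka("batch_topic", "message 1")
publishStringMessageToKafka("batch_topic", "message 2")
// consumes exactly two messages; with the default autoCommit = false,
// only the offsets of these two messages are committed
val batch: List[String] = consumeNumberStringMessagesFrom("batch_topic", 2)
}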


object aKafkaProducer {
private[this] var producers = Vector.empty[KafkaProducer[_, _]]

@@ -195,6 +195,52 @@ class EmbeddedKafkaMethodsSpec extends EmbeddedKafkaSpecSupport with EmbeddedKafka
}
}

"the consumeNumberStringMessagesFrom method" should {
"consume set number of messages when multiple messages have been published to a topic" in {
val messages = Set("message 1", "message 2", "message 3")
val topic = "consume_test_topic"
val producer = new KafkaProducer[String, String](Map(
ProducerConfig.BOOTSTRAP_SERVERS_CONFIG -> "localhost:6001",
ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[StringSerializer].getName,
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[StringSerializer].getName
))

messages.foreach { message =>
producer.send(new ProducerRecord[String, String](topic, message))
}

producer.flush()

val consumedMessages = consumeNumberStringMessagesFrom(topic, messages.size)

consumedMessages.toSet shouldEqual messages

producer.close()
}

"timeout and throw a TimeoutException when n messages are not received in time" in {
val messages = Set("message 1", "message 2", "message 3")
val topic = "consume_test_topic"
val producer = new KafkaProducer[String, String](Map(
ProducerConfig.BOOTSTRAP_SERVERS_CONFIG -> "localhost:6001",
ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[StringSerializer].getName,
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[StringSerializer].getName
))

messages.foreach { message =>
producer.send(new ProducerRecord[String, String](topic, message))
}

producer.flush()

a[TimeoutException] shouldBe thrownBy {
consumeNumberStringMessagesFrom(topic, messages.size + 1)
}

producer.close()
}
}

"the aKafkaProducerThat method" should {
"return a producer that encodes messages for the given encoder" in {
val producer = aKafkaProducer thatSerializesValuesWith classOf[ByteArraySerializer]
