Skip to content
This repository was archived by the owner on Feb 16, 2024. It is now read-only.

Commit 526a8b7

Browse files
committed
Added buffer logic and rate multiplier factor
1 parent 2670355 commit 526a8b7

3 files changed

Lines changed: 96 additions & 16 deletions

File tree

streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubInputDStream.scala

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818
package org.apache.spark.streaming.pubsub
1919

2020
import java.io.{Externalizable, ObjectInput, ObjectOutput}
21-
import java.lang
22-
import java.lang.Runtime
23-
import java.util.concurrent.{Executors, TimeUnit}
2421

2522
import scala.collection.JavaConverters._
23+
import scala.collection.mutable.ArrayBuffer
2624
import scala.util.control.NonFatal
2725

2826
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport
@@ -33,7 +31,7 @@ import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage,
3331
import com.google.cloud.hadoop.util.RetryHttpInitializer
3432
import com.google.common.util.concurrent.RateLimiter
3533

36-
import org.apache.spark.SparkConf
34+
import org.apache.spark.{SparkConf, SparkException}
3735
import org.apache.spark.storage.StorageLevel
3836
import org.apache.spark.streaming.StreamingContext
3937
import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -57,13 +55,14 @@ class PubsubInputDStream(
5755
val _storageLevel: StorageLevel,
5856
val autoAcknowledge: Boolean,
5957
val maxNoOfMessageInRequest: Int,
58+
val rateMultiplierFactor: Double,
6059
conf: SparkConf
6160
) extends ReceiverInputDStream[SparkPubsubMessage](_ssc) {
6261

6362
override def getReceiver(): Receiver[SparkPubsubMessage] = {
6463
new PubsubReceiver(
6564
project, topic, subscription, credential, _storageLevel, autoAcknowledge,
66-
maxNoOfMessageInRequest, conf
65+
maxNoOfMessageInRequest, rateMultiplierFactor, conf
6766
)
6867
}
6968
}
@@ -236,13 +235,19 @@ object ConnectionUtils {
236235
* See Spark streaming configurations doc
237236
* <a href="https://spark.apache.org/docs/latest/configuration.html#spark-streaming">Spark streaming configurations</a>
238237
*
238+
* NOTE: For the given subscription, we assume ackDeadlineSeconds is sufficient,
239+
* so that messages will not expire if they are buffered for the given blockIntervalMs
240+
*
239241
* @param project Google cloud project id
240242
* @param topic Topic name for creating subscription if need
241243
* @param subscription Pub/Sub subscription name
242244
* @param credential Google cloud project credential to access Pub/Sub service
243245
* @param storageLevel Storage level to be used
244246
* @param autoAcknowledge Acknowledge pubsub message or not
245247
* @param maxNoOfMessageInRequest Maximum number of message in a Pubsub pull request
248+
* @param rateMultiplierFactor Increase the proposed rate estimated by the PIDEstimator to take
249+
* advantage of dynamic allocation of executors.
250+
* The default should be 1.0 if dynamic allocation is not enabled
246251
* @param conf Spark config
247252
*/
248253
private[pubsub]
@@ -254,6 +259,7 @@ class PubsubReceiver(
254259
storageLevel: StorageLevel,
255260
autoAcknowledge: Boolean,
256261
maxNoOfMessageInRequest: Int,
262+
rateMultiplierFactor: Double,
257263
conf: SparkConf)
258264
extends Receiver[SparkPubsubMessage](storageLevel) {
259265

@@ -267,6 +273,11 @@ class PubsubReceiver(
267273

268274
val blockSize: Int = conf.getInt("spark.streaming.blockQueueSize", maxNoOfMessageInRequest)
269275

276+
val blockIntervalMs: Long = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")
277+
278+
var buffer: ArrayBuffer[ReceivedMessage] = createBufferArray()
279+
280+
var latestAttemptToPushInStoreTime: Long = -1
270281

271282
lazy val rateLimiter: RateLimiter = RateLimiter.create(getInitialRateLimit.toDouble)
272283

@@ -285,6 +296,7 @@ class PubsubReceiver(
285296
case Some(t) =>
286297
val sub: Subscription = new Subscription
287298
sub.setTopic(s"$projectFullName/topics/$t")
299+
sub.setAckDeadlineSeconds(30)
288300
try {
289301
client.projects().subscriptions().create(subscriptionFullName, sub).execute()
290302
} catch {
@@ -311,18 +323,28 @@ class PubsubReceiver(
311323
val pullRequest = new PullRequest()
312324
.setMaxMessages(maxNoOfMessageInRequest).setReturnImmediately(false)
313325
var backoff = INIT_BACKOFF
326+
327+
// To avoid the edge case when buffer is not full and no message pushed to store
328+
latestAttemptToPushInStoreTime = System.currentTimeMillis()
329+
314330
while (!isStopped()) {
315331
try {
332+
316333
val pullResponse =
317334
client.projects().subscriptions().pull(subscriptionFullName, pullRequest).execute()
318335
val receivedMessages = pullResponse.getReceivedMessages
319336

320337
// update rate limit if required
321338
updateRateLimit()
322339

340+
// Put data into buffer
323341
if (receivedMessages != null) {
324-
pushToStoreAndAck(receivedMessages.asScala.toList)
342+
buffer.appendAll(receivedMessages.asScala)
325343
}
344+
345+
// Push data from buffer to store
346+
push()
347+
326348
backoff = INIT_BACKOFF
327349
} catch {
328350
case e: GoogleJsonResponseException =>
@@ -349,29 +371,80 @@ class PubsubReceiver(
349371
* and update the rate limiter with new rate
350372
*/
351373
def updateRateLimit(): Unit = {
352-
val newRateLimit = supervisor.getCurrentRateLimit.min(maxRateLimit)
374+
val newRateLimit = rateMultiplierFactor * supervisor.getCurrentRateLimit.min(maxRateLimit)
353375
if (rateLimiter.getRate != newRateLimit) {
354376
rateLimiter.setRate(newRateLimit)
355377
}
356378
}
357379

380+
/**
381+
* Push data into store if
382+
* 1. buffer size is greater than or equal to blockSize, or
383+
* 2. blockInterval time is passed and buffer size is less than blockSize
384+
*
385+
* Before pushing the messages, first create iterator of complete block(s) and partial blocks
386+
* and assign a new array to the buffer.
387+
*
388+
* So, while pushing data into the store, if any {@link org.apache.spark.SparkException} occurs,
389+
* then all un-pushed or un-acked messages will be lost.
390+
*
391+
* To recover lost messages we are relying on pubsub
392+
* (i.e. after the ack deadline passes, Pub/Sub will redeliver those messages)
393+
*/
394+
def push(): Unit = {
395+
396+
val diff = System.currentTimeMillis() - latestAttemptToPushInStoreTime
397+
if (buffer.length >= blockSize || (buffer.length < blockSize && diff >= blockIntervalMs)) {
398+
399+
// grouping messages into complete and partial blocks (if any)
400+
val (completeBlocks, partialBlock) = buffer.grouped(blockSize)
401+
.partition(block => block.length == blockSize)
402+
403+
// If completeBlocks is empty it means within block interval time
404+
// messages in buffer is less than blockSize. So will push partial block
405+
val iterator = if (completeBlocks.nonEmpty) completeBlocks else partialBlock
406+
407+
// Will push partial block messages back to buffer if complete blocks formed
408+
val partial = if (completeBlocks.nonEmpty && partialBlock.nonEmpty) {
409+
partialBlock.next()
410+
} else null
411+
412+
while (iterator.hasNext) {
413+
try {
414+
pushToStoreAndAck(iterator.next().toList)
415+
} catch {
416+
case e: SparkException => reportError(
417+
"Failed to write messages into reliable store", e)
418+
case NonFatal(e) => reportError(
419+
"Failed to write messages in reliable store", e)
420+
} finally {
421+
latestAttemptToPushInStoreTime = System.currentTimeMillis()
422+
}
423+
}
424+
425+
// clear existing buffer messages
426+
buffer.clear()
427+
428+
// Pushing partial block messages back to buffer if complete blocks formed
429+
if (partial != null) buffer.appendAll(partial)
430+
}
431+
}
432+
358433
/**
359434
* Push the list of received message into store and ack messages if auto ack is true
360435
* @param receivedMessages
361436
*/
362437
def pushToStoreAndAck(receivedMessages: List[ReceivedMessage]): Unit = {
363-
receivedMessages
438+
val messages = receivedMessages
364439
.map(x => {
365440
val sm = new SparkPubsubMessage
366441
sm.message = x.getMessage
367442
sm.ackId = x.getAckId
368443
sm})
369-
.grouped(blockSize)
370-
.foreach(messages => {
371-
rateLimiter.acquire(messages.size)
372-
store(messages.toIterator)
373-
if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
374-
})
444+
445+
rateLimiter.acquire(messages.size)
446+
store(messages.toIterator)
447+
if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
375448
}
376449

377450
/**
@@ -385,5 +458,9 @@ class PubsubReceiver(
385458
.acknowledge(subscriptionFullName, ackRequest).execute()
386459
}
387460

461+
private def createBufferArray(): ArrayBuffer[ReceivedMessage] = {
462+
new ArrayBuffer[ReceivedMessage](2 * math.max(maxNoOfMessageInRequest, blockSize))
463+
}
464+
388465
override def onStop(): Unit = {}
389466
}

streaming-pubsub/src/main/scala/org/apache/spark/streaming/pubsub/PubsubUtils.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ object PubsubUtils {
5151
credentials: SparkGCPCredentials,
5252
storageLevel: StorageLevel,
5353
autoAcknowledge: Boolean = true,
54-
maxNoOfMessageInRequest: Int = 1000): ReceiverInputDStream[SparkPubsubMessage] = {
54+
maxNoOfMessageInRequest: Int = 1000,
55+
rateMultiplierFactor: Double = 1.0): ReceiverInputDStream[SparkPubsubMessage] = {
5556
ssc.withNamedScope("pubsub stream") {
5657

5758
new PubsubInputDStream(
@@ -63,6 +64,7 @@ object PubsubUtils {
6364
storageLevel,
6465
autoAcknowledge,
6566
maxNoOfMessageInRequest,
67+
rateMultiplierFactor,
6668
ssc.conf
6769
)
6870
}

streaming-pubsub/src/test/scala/org/apache/spark/streaming/pubsub/PubsubStreamSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
3535

3636
val batchDuration = Seconds(1)
3737

38-
val blockSize = 10
38+
val blockSize = 15
3939

4040
private val master: String = "local[2]"
4141

@@ -79,6 +79,7 @@ class PubsubStreamSuite extends ConditionalSparkFunSuite with Eventually with Be
7979
conf.set("spark.streaming.receiver.maxRate", "100")
8080
conf.set("spark.streaming.backpressure.pid.minRate", "10")
8181
conf.set("spark.streaming.blockQueueSize", blockSize.toString)
82+
conf.set("spark.streaming.blockInterval", "1000ms")
8283
}
8384

8485

0 commit comments

Comments
 (0)