1818package org .apache .spark .streaming .pubsub
1919
2020import java .io .{Externalizable , ObjectInput , ObjectOutput }
21- import java .lang
22- import java .lang .Runtime
23- import java .util .concurrent .{Executors , TimeUnit }
2421
2522import scala .collection .JavaConverters ._
23+ import scala .collection .mutable .ArrayBuffer
2624import scala .util .control .NonFatal
2725
2826import com .google .api .client .googleapis .javanet .GoogleNetHttpTransport
@@ -33,7 +31,7 @@ import com.google.api.services.pubsub.model.{AcknowledgeRequest, PubsubMessage,
3331import com .google .cloud .hadoop .util .RetryHttpInitializer
3432import com .google .common .util .concurrent .RateLimiter
3533
36- import org .apache .spark .SparkConf
34+ import org .apache .spark .{ SparkConf , SparkException }
3735import org .apache .spark .storage .StorageLevel
3836import org .apache .spark .streaming .StreamingContext
3937import org .apache .spark .streaming .dstream .ReceiverInputDStream
@@ -57,13 +55,14 @@ class PubsubInputDStream(
5755 val _storageLevel : StorageLevel ,
5856 val autoAcknowledge : Boolean ,
5957 val maxNoOfMessageInRequest : Int ,
58+ val rateMultiplierFactor : Double ,
6059 conf : SparkConf
6160) extends ReceiverInputDStream [SparkPubsubMessage ](_ssc) {
6261
6362 override def getReceiver (): Receiver [SparkPubsubMessage ] = {
6463 new PubsubReceiver (
6564 project, topic, subscription, credential, _storageLevel, autoAcknowledge,
66- maxNoOfMessageInRequest, conf
65+ maxNoOfMessageInRequest, rateMultiplierFactor, conf
6766 )
6867 }
6968}
@@ -236,13 +235,19 @@ object ConnectionUtils {
236235 * See Spark streaming configurations doc
 * <a href="https://spark.apache.org/docs/latest/configuration.html#spark-streaming">Spark Streaming configuration</a>
238237 *
 * NOTE: For the given subscription we assume ackDeadlineSeconds is sufficient,
 * so that messages will not expire while buffered for blockIntervalMs.
240+ *
239241 * @param project Google cloud project id
240242 * @param topic Topic name for creating subscription if need
241243 * @param subscription Pub/Sub subscription name
242244 * @param credential Google cloud project credential to access Pub/Sub service
243245 * @param storageLevel Storage level to be used
244246 * @param autoAcknowledge Acknowledge pubsub message or not
245247 * @param maxNoOfMessageInRequest Maximum number of message in a Pubsub pull request
 * @param rateMultiplierFactor Multiplier applied to the rate proposed by the PIDEstimator,
 *                             to take advantage of dynamic executor allocation.
 *                             Should be 1 when dynamic allocation is not enabled.
246251 * @param conf Spark config
247252 */
248253private [pubsub]
@@ -254,6 +259,7 @@ class PubsubReceiver(
254259 storageLevel : StorageLevel ,
255260 autoAcknowledge : Boolean ,
256261 maxNoOfMessageInRequest : Int ,
262+ rateMultiplierFactor : Double ,
257263 conf : SparkConf )
258264 extends Receiver [SparkPubsubMessage ](storageLevel) {
259265
@@ -267,6 +273,11 @@ class PubsubReceiver(
267273
268274 val blockSize : Int = conf.getInt(" spark.streaming.blockQueueSize" , maxNoOfMessageInRequest)
269275
276+ val blockIntervalMs : Long = conf.getTimeAsMs(" spark.streaming.blockInterval" , " 200ms" )
277+
278+ var buffer : ArrayBuffer [ReceivedMessage ] = createBufferArray()
279+
280+ var latestAttemptToPushInStoreTime : Long = - 1
270281
271282 lazy val rateLimiter : RateLimiter = RateLimiter .create(getInitialRateLimit.toDouble)
272283
@@ -285,6 +296,7 @@ class PubsubReceiver(
285296 case Some (t) =>
286297 val sub : Subscription = new Subscription
287298 sub.setTopic(s " $projectFullName/topics/ $t" )
299+ sub.setAckDeadlineSeconds(30 )
288300 try {
289301 client.projects().subscriptions().create(subscriptionFullName, sub).execute()
290302 } catch {
@@ -311,18 +323,28 @@ class PubsubReceiver(
311323 val pullRequest = new PullRequest ()
312324 .setMaxMessages(maxNoOfMessageInRequest).setReturnImmediately(false )
313325 var backoff = INIT_BACKOFF
326+
327+ // To avoid the edge case when buffer is not full and no message pushed to store
328+ latestAttemptToPushInStoreTime = System .currentTimeMillis()
329+
314330 while (! isStopped()) {
315331 try {
332+
316333 val pullResponse =
317334 client.projects().subscriptions().pull(subscriptionFullName, pullRequest).execute()
318335 val receivedMessages = pullResponse.getReceivedMessages
319336
320337 // update rate limit if required
321338 updateRateLimit()
322339
340+ // Put data into buffer
323341 if (receivedMessages != null ) {
324- pushToStoreAndAck (receivedMessages.asScala.toList )
342+ buffer.appendAll (receivedMessages.asScala)
325343 }
344+
345+ // Push data from buffer to store
346+ push()
347+
326348 backoff = INIT_BACKOFF
327349 } catch {
328350 case e : GoogleJsonResponseException =>
@@ -349,29 +371,80 @@ class PubsubReceiver(
349371 * and update the rate limiter with new rate
350372 */
351373 def updateRateLimit (): Unit = {
352- val newRateLimit = supervisor.getCurrentRateLimit.min(maxRateLimit)
374+ val newRateLimit = rateMultiplierFactor * supervisor.getCurrentRateLimit.min(maxRateLimit)
353375 if (rateLimiter.getRate != newRateLimit) {
354376 rateLimiter.setRate(newRateLimit)
355377 }
356378 }
357379
380+ /**
381+ * Push data into store if
382+ * 1. buffer size greater than equal to blockSize, or
383+ * 2. blockInterval time is passed and buffer size is less than blockSize
384+ *
385+ * Before pushing the messages, first create iterator of complete block(s) and partial blocks
386+ * and assigning new array to buffer.
387+ *
388+ * So during pushing data into store if any {@link org.apache.spark.SparkException} occur
389+ * then all un-push messages or un-ack will be lost.
390+ *
391+ * To recover lost messages we are relying on pubsub
392+ * (i.e after ack deadline passed then pubsub will again give that messages)
393+ */
394+ def push (): Unit = {
395+
396+ val diff = System .currentTimeMillis() - latestAttemptToPushInStoreTime
397+ if (buffer.length >= blockSize || (buffer.length < blockSize && diff >= blockIntervalMs)) {
398+
399+ // grouping messages into complete and partial blocks (if any)
400+ val (completeBlocks, partialBlock) = buffer.grouped(blockSize)
401+ .partition(block => block.length == blockSize)
402+
403+ // If completeBlocks is empty it means within block interval time
404+ // messages in buffer is less than blockSize. So will push partial block
405+ val iterator = if (completeBlocks.nonEmpty) completeBlocks else partialBlock
406+
407+ // Will push partial block messages back to buffer if complete blocks formed
408+ val partial = if (completeBlocks.nonEmpty && partialBlock.nonEmpty) {
409+ partialBlock.next()
410+ } else null
411+
412+ while (iterator.hasNext) {
413+ try {
414+ pushToStoreAndAck(iterator.next().toList)
415+ } catch {
416+ case e : SparkException => reportError(
417+ " Failed to write messages into reliable store" , e)
418+ case NonFatal (e) => reportError(
419+ " Failed to write messages in reliable store" , e)
420+ } finally {
421+ latestAttemptToPushInStoreTime = System .currentTimeMillis()
422+ }
423+ }
424+
425+ // clear existing buffer messages
426+ buffer.clear()
427+
428+ // Pushing partial block messages back to buffer if complete blocks formed
429+ if (partial != null ) buffer.appendAll(partial)
430+ }
431+ }
432+
358433 /**
359434 * Push the list of received message into store and ack messages if auto ack is true
360435 * @param receivedMessages
361436 */
362437 def pushToStoreAndAck (receivedMessages : List [ReceivedMessage ]): Unit = {
363- receivedMessages
438+ val messages = receivedMessages
364439 .map(x => {
365440 val sm = new SparkPubsubMessage
366441 sm.message = x.getMessage
367442 sm.ackId = x.getAckId
368443 sm})
369- .grouped(blockSize)
370- .foreach(messages => {
371- rateLimiter.acquire(messages.size)
372- store(messages.toIterator)
373- if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
374- })
444+
445+ rateLimiter.acquire(messages.size)
446+ store(messages.toIterator)
447+ if (autoAcknowledge) acknowledgeIds(messages.map(_.ackId))
375448 }
376449
377450 /**
@@ -385,5 +458,9 @@ class PubsubReceiver(
385458 .acknowledge(subscriptionFullName, ackRequest).execute()
386459 }
387460
461+ private def createBufferArray (): ArrayBuffer [ReceivedMessage ] = {
462+ new ArrayBuffer [ReceivedMessage ](2 * math.max(maxNoOfMessageInRequest, blockSize))
463+ }
464+
388465 override def onStop (): Unit = {}
389466}
0 commit comments