Skip to content

Commit 19a0871

Browse files
committed
fix tests
1 parent 2d51526 commit 19a0871

2 files changed

Lines changed: 32 additions & 12 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
2222
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
2323
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
2424
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
25+
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
2526
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
2627
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
2728
import org.apache.spark.sql.execution.streaming.state._
@@ -138,6 +139,10 @@ abstract class StatePartitionReaderBase(
138139
getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds)
139140
}
140141

142+
// True when this reader was given a join column family (joinColFamilyOpt) whose name is one of
// the names returned by SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).
// The readers in this file use the flag to enable useMultipleValuesPerKey on the provider and to
// read the column family through iteratorWithMultiValues instead of iterator.
protected val isJoinV4MultiValuedCF: Boolean = joinColFamilyOpt.exists { cfName =>
143+
SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).contains(cfName)
144+
}
145+
141146
protected lazy val provider: StateStoreProvider = {
142147
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
143148
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -146,7 +151,7 @@ abstract class StatePartitionReaderBase(
146151
val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined
147152

148153
val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt,
149-
StateVariableType.ListState)
154+
StateVariableType.ListState) || isJoinV4MultiValuedCF
150155

151156
val provider = StateStoreProvider.createAndInit(
152157
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
@@ -249,6 +254,12 @@ class StatePartitionReader(
249254
val stateVarType = stateVariableInfo.stateVariableType
250255
SchemaUtil.processStateEntries(stateVarType, colFamilyName, store,
251256
keySchema, partition.partition, partition.sourceOptions)
257+
} else if (isJoinV4MultiValuedCF) {
258+
store
259+
.iteratorWithMultiValues(colFamilyName)
260+
.map { pair =>
261+
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
262+
}
252263
} else {
253264
store
254265
.iterator(colFamilyName)
@@ -333,6 +344,13 @@ class StatePartitionAllColumnFamiliesReader(
333344
StateVariableType.ListState)
334345
}
335346

347+
// Names returned by SymmetricHashJoinStateManager.allStateStoreNamesV4 for both join sides,
// materialized once as a Set so that isMultiValuedCF can do membership checks against it.
private val v4JoinCFNames: Set[String] =
348+
SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).toSet
349+
350+
// Returns true when isListType(colFamilyName) holds or the name is one of v4JoinCFNames; callers
// use the result to decide whether a column family is read/created with multiple values per key.
// NOTE(review): the name check is not scoped to join operators here — confirm that a column
// family of another operator can never share one of the V4 join store names.
private def isMultiValuedCF(colFamilyName: String): Boolean = {
351+
isListType(colFamilyName) || v4JoinCFNames.contains(colFamilyName)
352+
}
353+
336354
override protected lazy val provider: StateStoreProvider = {
337355
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
338356
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -386,7 +404,7 @@ class StatePartitionAllColumnFamiliesReader(
386404
case _ =>
387405
val isInternal =
388406
StateStoreColumnFamilySchemaUtils.isInternalColFamily(cfSchema.colFamilyName)
389-
val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName)
407+
val useMultipleValuesPerKey = isMultiValuedCF(cfSchema.colFamilyName)
390408
require(cfSchema.keyStateEncoderSpec.isDefined,
391409
s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
392410
stateStore.createColFamilyIfAbsent(
@@ -410,15 +428,11 @@ class StatePartitionAllColumnFamiliesReader(
410428
.filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
411429
.flatMap { cfSchema =>
412430
val extractor = cfPartitionKeyExtractors(cfSchema.colFamilyName)
413-
if (isListType(cfSchema.colFamilyName)) {
414-
store.iterator(cfSchema.colFamilyName).flatMap(
415-
pair =>
416-
store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
417-
value =>
418-
SchemaUtil.unifyStateRowPairAsRawBytes(
419-
(pair.key, value), cfSchema.colFamilyName, extractor)
420-
}
421-
)
431+
if (isMultiValuedCF(cfSchema.colFamilyName)) {
432+
store.iteratorWithMultiValues(cfSchema.colFamilyName).map { pair =>
433+
SchemaUtil.unifyStateRowPairAsRawBytes(
434+
(pair.key, pair.value), cfSchema.colFamilyName, extractor)
435+
}
422436
} else {
423437
store.iterator(cfSchema.colFamilyName).map { pair =>
424438
SchemaUtil.unifyStateRowPairAsRawBytes(

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,13 @@ class SymmetricHashJoinStateManagerV4(
249249
// V4 uses a single store with VCFs (not separate keyToNumValues/keyWithIndexToValue stores).
250250
// Use the keyToNumValues checkpoint ID for loading the correct committed version.
251251
private val stateStoreCkptId: Option[String] = keyToNumValuesStateStoreCkptId
252-
private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
252+
// Translates the caller-supplied snapshotOptions (when present) into HandlerSnapshotOptions.
// snapshotVersion and endVersion are copied as-is; the start/end checkpoint IDs are taken from
// the keyToNumValues fields, consistent with stateStoreCkptId using the keyToNumValues
// checkpoint ID. Yields None when snapshotOptions is empty.
private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = snapshotOptions.map { opts =>
253+
HandlerSnapshotOptions(
254+
snapshotVersion = opts.snapshotVersion,
255+
endVersion = opts.endVersion,
256+
startStateStoreCkptId = opts.startKeyToNumValuesStateStoreCkptId,
257+
endStateStoreCkptId = opts.endKeyToNumValuesStateStoreCkptId)
258+
}
253259

254260
private var stateStoreProvider: StateStoreProvider = _
255261

0 commit comments

Comments
 (0)