@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
2222import org .apache .spark .sql .connector .read .{InputPartition , PartitionReader , PartitionReaderFactory }
2323import org .apache .spark .sql .execution .datasources .v2 .state .utils .SchemaUtil
2424import org .apache .spark .sql .execution .streaming .operators .stateful .{StatefulOperatorsUtils , StatePartitionKeyExtractorFactory }
25+ import org .apache .spark .sql .execution .streaming .operators .stateful .join .StreamingSymmetricHashJoinHelper .{LeftSide , RightSide }
2526import org .apache .spark .sql .execution .streaming .operators .stateful .join .SymmetricHashJoinStateManager
2627import org .apache .spark .sql .execution .streaming .operators .stateful .transformwithstate .{StateStoreColumnFamilySchemaUtils , StateVariableType , TransformWithStateVariableInfo }
2728import org .apache .spark .sql .execution .streaming .state ._
@@ -138,6 +139,10 @@ abstract class StatePartitionReaderBase(
138139 getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds)
139140 }
140141
// Whether the join column family selected for this reader (joinColFamilyOpt) is one of the
// names returned by SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).
// Evaluates to false when joinColFamilyOpt is empty or names any other column family.
// NOTE(review): the check is by column family name only; presumably those names are unique
// to the V4 join state layout - confirm against SymmetricHashJoinStateManager.
142+ protected val isJoinV4MultiValuedCF : Boolean = joinColFamilyOpt.exists { cfName =>
143+ SymmetricHashJoinStateManager .allStateStoreNamesV4(LeftSide , RightSide ).contains(cfName)
144+ }
145+
141146 protected lazy val provider : StateStoreProvider = {
142147 val stateStoreId = StateStoreId (partition.sourceOptions.stateCheckpointLocation.toString,
143148 partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -146,7 +151,7 @@ abstract class StatePartitionReaderBase(
146151 val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined
147152
148153 val useMultipleValuesPerKey = SchemaUtil .checkVariableType(stateVariableInfoOpt,
149- StateVariableType .ListState )
154+ StateVariableType .ListState ) || isJoinV4MultiValuedCF
150155
151156 val provider = StateStoreProvider .createAndInit(
152157 stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
@@ -249,6 +254,12 @@ class StatePartitionReader(
249254 val stateVarType = stateVariableInfo.stateVariableType
250255 SchemaUtil .processStateEntries(stateVarType, colFamilyName, store,
251256 keySchema, partition.partition, partition.sourceOptions)
257+ } else if (isJoinV4MultiValuedCF) {
258+ store
259+ .iteratorWithMultiValues(colFamilyName)
260+ .map { pair =>
261+ SchemaUtil .unifyStateRowPair((pair.key, pair.value), partition.partition)
262+ }
252263 } else {
253264 store
254265 .iterator(colFamilyName)
@@ -333,6 +344,13 @@ class StatePartitionAllColumnFamiliesReader(
333344 StateVariableType .ListState )
334345 }
335346
// Column family names returned by SymmetricHashJoinStateManager.allStateStoreNamesV4 for the
// left and right join sides, materialized once as a Set for membership lookups by name.
// NOTE(review): membership is decided by name alone, regardless of which operator owns the
// column family; presumably these names cannot collide with other operators' column
// families - confirm.
// NOTE(review): this is a strict val on the subclass; if anything in the superclass
// constructor were to force `provider`, this field would not be initialized yet - confirm
// that all access happens after construction.
347+ private val v4JoinCFNames : Set [String ] =
348+ SymmetricHashJoinStateManager .allStateStoreNamesV4(LeftSide , RightSide ).toSet
349+
// Returns true when the given column family is reported as a list type by isListType, or
// when its name is contained in v4JoinCFNames; returns false otherwise.
// @param colFamilyName name of the column family to classify
350+ private def isMultiValuedCF (colFamilyName : String ): Boolean = {
351+ isListType(colFamilyName) || v4JoinCFNames.contains(colFamilyName)
352+ }
353+
336354 override protected lazy val provider : StateStoreProvider = {
337355 val stateStoreId = StateStoreId (partition.sourceOptions.stateCheckpointLocation.toString,
338356 partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -386,7 +404,7 @@ class StatePartitionAllColumnFamiliesReader(
386404 case _ =>
387405 val isInternal =
388406 StateStoreColumnFamilySchemaUtils .isInternalColFamily(cfSchema.colFamilyName)
389- val useMultipleValuesPerKey = isListType (cfSchema.colFamilyName)
407+ val useMultipleValuesPerKey = isMultiValuedCF (cfSchema.colFamilyName)
390408 require(cfSchema.keyStateEncoderSpec.isDefined,
391409 s " keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}" )
392410 stateStore.createColFamilyIfAbsent(
@@ -410,15 +428,11 @@ class StatePartitionAllColumnFamiliesReader(
410428 .filter(schema => ! isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
411429 .flatMap { cfSchema =>
412430 val extractor = cfPartitionKeyExtractors(cfSchema.colFamilyName)
413- if (isListType(cfSchema.colFamilyName)) {
414- store.iterator(cfSchema.colFamilyName).flatMap(
415- pair =>
416- store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
417- value =>
418- SchemaUtil .unifyStateRowPairAsRawBytes(
419- (pair.key, value), cfSchema.colFamilyName, extractor)
420- }
421- )
431+ if (isMultiValuedCF(cfSchema.colFamilyName)) {
432+ store.iteratorWithMultiValues(cfSchema.colFamilyName).map { pair =>
433+ SchemaUtil .unifyStateRowPairAsRawBytes(
434+ (pair.key, pair.value), cfSchema.colFamilyName, extractor)
435+ }
422436 } else {
423437 store.iterator(cfSchema.colFamilyName).map { pair =>
424438 SchemaUtil .unifyStateRowPairAsRawBytes(
0 commit comments