diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 672e8e3786181..bcf9c5db367a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -95,7 +95,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging stateStoreReaderInfo.stateStoreColFamilySchemaOpt, stateStoreReaderInfo.stateSchemaProviderOpt, stateStoreReaderInfo.joinColFamilyOpt, - stateStoreReaderInfo.allColumnFamiliesReaderInfo) + stateStoreReaderInfo.allColumnFamiliesReaderInfo, + stateStoreReaderInfo.joinStateFormatVersion) } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { @@ -110,11 +111,13 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val (keySchema, valueSchema) = sourceOptions.joinSide match { case JoinSideValues.left => StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString, - sourceOptions.operatorId, LeftSide, oldSchemaFilePaths) + sourceOptions.operatorId, LeftSide, oldSchemaFilePaths, + joinStateFormatVersion = stateStoreReaderInfo.joinStateFormatVersion) case JoinSideValues.right => StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString, - sourceOptions.operatorId, RightSide, oldSchemaFilePaths) + sourceOptions.operatorId, RightSide, oldSchemaFilePaths, + joinStateFormatVersion = stateStoreReaderInfo.joinStateFormatVersion) case JoinSideValues.none => // we should have the schema for the state store if joinSide is none @@ -162,9 +165,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging /** * Returns true if this is a read-all-column-families request for a stream-stream join - * that uses virtual column families (state format version 3). + * that uses virtual column families (state format version >= 3). */ - private def isReadAllColFamiliesOnJoinV3( + private def isReadAllColFamiliesOnJoinWithVCF( sourceOptions: StateSourceOptions, storeMetadata: Array[StateMetadataTableEntry]): Boolean = { sourceOptions.internalOnlyReadAllColumnFamilies && @@ -243,9 +246,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging opMetadata.operatorName match { case opName: String if opName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME => - // Verify that the storename is valid val possibleStoreNames = SymmetricHashJoinStateManager.allStateStoreNames( - LeftSide, RightSide) + LeftSide, RightSide) ++ + SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide) if (!possibleStoreNames.contains(name)) { val errorMsg = s"Store name $name not allowed for join operator. Allowed names are " + s"$possibleStoreNames. " + @@ -393,7 +396,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging // However, Join V3 does not have a "default" column family. Therefore, we pick the first // schema as resultSchema which will be used as placeholder schema for default schema // in StatePartitionAllColumnFamiliesReader - val resultSchema = if (isReadAllColFamiliesOnJoinV3(sourceOptions, storeMetadata)) { + val resultSchema = if (isReadAllColFamiliesOnJoinWithVCF(sourceOptions, storeMetadata)) { stateSchema.head } else { stateSchema.filter(_.colFamilyName == stateVarName).head @@ -408,17 +411,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging } } + val joinFormatVersion = getStateFormatVersion( + storeMetadata, + sourceOptions.resolvedCpLocation, + sourceOptions.batchId + ) + val allColFamilyReaderInfoOpt: Option[AllColumnFamiliesReaderInfo] = if (sourceOptions.internalOnlyReadAllColumnFamilies) { assert(storeMetadata.nonEmpty, "storeMetadata shouldn't be empty") val operatorName = storeMetadata.head.operatorName - val stateFormatVersion = getStateFormatVersion( - storeMetadata, - sourceOptions.resolvedCpLocation, - sourceOptions.batchId - ) Some(AllColumnFamiliesReaderInfo( - stateStoreColFamilySchemas, stateVariableInfos, operatorName, stateFormatVersion)) + stateStoreColFamilySchemas, stateVariableInfos, operatorName, joinFormatVersion)) } else { None } @@ -428,7 +432,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging transformWithStateVariableInfoOpt, stateSchemaProvider, joinColFamilyOpt, - allColFamilyReaderInfoOpt + allColFamilyReaderInfoOpt, + joinFormatVersion ) } @@ -819,7 +824,8 @@ case class StateStoreReaderInfo( stateSchemaProviderOpt: Option[StateSchemaProvider], joinColFamilyOpt: Option[String], // Only used for join op with state format v3 // List of all column family schemas - used when internalOnlyReadAllColumnFamilies=true - allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] + allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo], + joinStateFormatVersion: Option[Int] = None ) object StateDataSource { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index f03e3e26f6ee7..2f7d9ab2d8c7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory} +import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.state._ @@ -43,6 +44,18 @@ case class AllColumnFamiliesReaderInfo( operatorName: String, stateFormatVersion: Option[Int] = None) +private[state] object StatePartitionReaderUtils { + val v4JoinCFNames: Set[String] = + SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).toSet + + def isMultiValuedCF( + colFamilyNameOpt: Option[String], + stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = { + SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.ListState) || + colFamilyNameOpt.exists(v4JoinCFNames.contains) + } +} + /** * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support * general read from a state store instance, rather than specific to the operator. @@ -145,8 +158,8 @@ abstract class StatePartitionReaderBase( val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined - val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt, - StateVariableType.ListState) + val useMultipleValuesPerKey = StatePartitionReaderUtils.isMultiValuedCF( + joinColFamilyOpt, stateVariableInfoOpt) val provider = StateStoreProvider.createAndInit( stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, @@ -249,6 +262,12 @@ class StatePartitionReader( val stateVarType = stateVariableInfo.stateVariableType SchemaUtil.processStateEntries(stateVarType, colFamilyName, store, keySchema, partition.partition, partition.sourceOptions) + } else if (joinColFamilyOpt.exists(StatePartitionReaderUtils.v4JoinCFNames.contains)) { + store + .iteratorWithMultiValues(colFamilyName) + .map { pair => + SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) + } } else { store .iterator(colFamilyName) @@ -327,10 +346,10 @@ class StatePartitionAllColumnFamiliesReader( }.toMap } - private def isListType(colFamilyName: String): Boolean = { - SchemaUtil.checkVariableType( - stateVariableInfos.find(info => info.stateName == colFamilyName), - StateVariableType.ListState) + private def isMultiValuedCF(colFamilyName: String): Boolean = { + StatePartitionReaderUtils.isMultiValuedCF( + Some(colFamilyName), + stateVariableInfos.find(info => info.stateName == colFamilyName)) } override protected lazy val provider: StateStoreProvider = { @@ -386,7 +405,7 @@ class StatePartitionAllColumnFamiliesReader( case _ => val isInternal = StateStoreColumnFamilySchemaUtils.isInternalColFamily(cfSchema.colFamilyName) - val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName) + val useMultipleValuesPerKey = isMultiValuedCF(cfSchema.colFamilyName) require(cfSchema.keyStateEncoderSpec.isDefined, s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}") stateStore.createColFamilyIfAbsent( @@ -410,15 +429,11 @@ class StatePartitionAllColumnFamiliesReader( .filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName)) .flatMap { cfSchema => val extractor = cfPartitionKeyExtractors(cfSchema.colFamilyName) - if (isListType(cfSchema.colFamilyName)) { - store.iterator(cfSchema.colFamilyName).flatMap( - pair => - store.valuesIterator(pair.key, cfSchema.colFamilyName).map { - value => - SchemaUtil.unifyStateRowPairAsRawBytes( - (pair.key, value), cfSchema.colFamilyName, extractor) - } - ) + if (isMultiValuedCF(cfSchema.colFamilyName)) { + store.iteratorWithMultiValues(cfSchema.colFamilyName).map { pair => + SchemaUtil.unifyStateRowPairAsRawBytes( + (pair.key, pair.value), cfSchema.colFamilyName, extractor) + } } else { store.iterator(cfSchema.colFamilyName).map { pair => SchemaUtil.unifyStateRowPairAsRawBytes( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index c3056767ee4b9..1bdc4d26f169e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -48,11 +48,12 @@ class StateScanBuilder( stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], stateSchemaProviderOpt: Option[StateSchemaProvider], joinColFamilyOpt: Option[String], - allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder { + allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo], + joinStateFormatVersion: Option[Int] = None) extends ScanBuilder { override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf, batchNumPartitions, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, - joinColFamilyOpt, allColumnFamiliesReaderInfo) + joinColFamilyOpt, allColumnFamiliesReaderInfo, joinStateFormatVersion) } /** An implementation of [[InputPartition]] for State Store data source. */ @@ -73,7 +74,8 @@ class StateScan( stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], stateSchemaProviderOpt: Option[StateSchemaProvider], joinColFamilyOpt: Option[String], - allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) + allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo], + joinStateFormatVersion: Option[Int] = None) extends Scan with Batch { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -136,9 +138,10 @@ class StateScan( hadoopConfBroadcast.value.value) val stateSchema = StreamStreamJoinStateHelper.readSchema(session, sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, LeftSide, - oldSchemaFilePaths, excludeAuxColumns = false) + oldSchemaFilePaths, excludeAuxColumns = false, + joinStateFormatVersion = joinStateFormatVersion) new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf, - hadoopConfBroadcast.value, userFacingSchema, stateSchema) + hadoopConfBroadcast.value, userFacingSchema, stateSchema, joinStateFormatVersion) case JoinSideValues.right => val userFacingSchema = schema @@ -146,9 +149,10 @@ class StateScan( hadoopConfBroadcast.value.value) val stateSchema = StreamStreamJoinStateHelper.readSchema(session, sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, RightSide, - oldSchemaFilePaths, excludeAuxColumns = false) + oldSchemaFilePaths, excludeAuxColumns = false, + joinStateFormatVersion = joinStateFormatVersion) new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf, - hadoopConfBroadcast.value, userFacingSchema, stateSchema) + hadoopConfBroadcast.value, userFacingSchema, stateSchema, joinStateFormatVersion) case JoinSideValues.none => new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 43f2f28c6b954..de92ed8d1fefd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -47,7 +47,8 @@ class StateTable( stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], stateSchemaProviderOpt: Option[StateSchemaProvider], joinColFamilyOpt: Option[String], - allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None) + allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None, + joinStateFormatVersion: Option[Int] = None) extends Table with SupportsRead with SupportsMetadataColumns { import StateTable._ @@ -90,7 +91,7 @@ class StateTable( new StateScanBuilder(session, schema, sourceOptions, stateConf, batchNumPartitions, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, - joinColFamilyOpt, allColumnFamiliesReaderInfo) + joinColFamilyOpt, allColumnFamiliesReaderInfo, joinStateFormatVersion) override def properties(): util.Map[String, String] = Map.empty[String, String].asJava diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala index 78648a01a6f6f..60af21a044439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala @@ -41,22 +41,23 @@ object StreamStreamJoinStateHelper { operatorId: Int, side: JoinSide, oldSchemaFilePaths: List[Path], - excludeAuxColumns: Boolean = true): StructType = { + excludeAuxColumns: Boolean = true, + joinStateFormatVersion: Option[Int] = None): StructType = { val (keySchema, valueSchema) = readKeyValueSchema(session, stateCheckpointLocation, - operatorId, side, oldSchemaFilePaths, excludeAuxColumns) + operatorId, side, oldSchemaFilePaths, excludeAuxColumns, joinStateFormatVersion) new StructType() .add("key", keySchema) .add("value", valueSchema) } - // Returns whether the checkpoint uses stateFormatVersion 3 which uses VCF for the join. + // Returns whether the checkpoint uses VCF for the join (stateFormatVersion >= 3). def usesVirtualColumnFamilies( hadoopConf: Configuration, stateCheckpointLocation: String, operatorId: Int): Boolean = { // If the schema exists for operatorId/partitionId/left-keyToNumValues, it is not - // stateFormatVersion 3. + // stateFormatVersion >= 3 (which uses VCF). val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA val storeId = new StateStoreId(stateCheckpointLocation, operatorId, partitionId, SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).toList.head) @@ -72,16 +73,17 @@ object StreamStreamJoinStateHelper { operatorId: Int, side: JoinSide, oldSchemaFilePaths: List[Path], - excludeAuxColumns: Boolean = true): (StructType, StructType) = { + excludeAuxColumns: Boolean = true, + joinStateFormatVersion: Option[Int] = None): (StructType, StructType) = { val newHadoopConf = session.sessionState.newHadoopConf() val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA - // KeyToNumValuesType, KeyWithIndexToValueType - val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList - val (keySchema, valueSchema) = - if (!usesVirtualColumnFamilies( - newHadoopConf, stateCheckpointLocation, operatorId)) { + val (keySchema, valueSchema) = joinStateFormatVersion match { + case Some(1) | Some(2) | None => + // v1/v2: separate state stores per store type. + // None handles old checkpoints without operator metadata (always v1/v2). + val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId, partitionId, storeNames(0)) val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues, @@ -92,42 +94,64 @@ object StreamStreamJoinStateHelper { val providerIdForKeyWithIndexToValue = new StateStoreProviderId( storeIdForKeyWithIndexToValue, UUID.randomUUID()) - // read the key schema from the keyToNumValues store for the join keys val manager = new StateSchemaCompatibilityChecker( providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false) val kSchema = manager.readSchemaFile().head.keySchema - // read the value schema from the keyWithIndexToValue store for the values val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false) val vSchema = manager2.readSchemaFile().head.valueSchema (kSchema, vSchema) - } else { + + case Some(3) => + // v3: single state store with virtual column families + val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList val storeId = new StateStoreId(stateCheckpointLocation, operatorId, partitionId, StateStoreId.DEFAULT_STORE_NAME) val providerId = new StateStoreProviderId(storeId, UUID.randomUUID()) val manager = new StateSchemaCompatibilityChecker( providerId, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false) - val kSchema = manager.readSchemaFile().find { schema => - schema.colFamilyName == storeNames(0) - }.map(_.keySchema).get + val schemas = manager.readSchemaFile() - val vSchema = manager.readSchemaFile().find { schema => - schema.colFamilyName == storeNames(1) - }.map(_.valueSchema).get + val kSchema = schemas.find(_.colFamilyName == storeNames(0)).map(_.keySchema).get + val vSchema = schemas.find(_.colFamilyName == storeNames(1)).map(_.valueSchema).get (kSchema, vSchema) - } + + case Some(4) => + // v4: single state store with virtual column families, timestamp-based keys + val v4Names = SymmetricHashJoinStateManager.allStateStoreNamesV4(side).toList + val storeId = new StateStoreId(stateCheckpointLocation, operatorId, + partitionId, StateStoreId.DEFAULT_STORE_NAME) + val providerId = new StateStoreProviderId(storeId, UUID.randomUUID()) + + val manager = new StateSchemaCompatibilityChecker( + providerId, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false) + val schemas = manager.readSchemaFile() + + // In v4, the primary CF (keyWithTsToValues) stores both the key and value schemas. + // This differs from v3 where keyToNumValues has the key schema and + // keyWithIndexToValue has the value schema. + val primaryCF = v4Names(0) + val kSchema = schemas.find(_.colFamilyName == primaryCF).map(_.keySchema).get + val vSchema = schemas.find(_.colFamilyName == primaryCF).map(_.valueSchema).get + + (kSchema, vSchema) + + case Some(v) => + throw new IllegalArgumentException( + s"Unsupported join state format version: $v") + } val maybeMatchedColumn = valueSchema.last + // remove internal column `matched` for format version >= 2 if (excludeAuxColumns && maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) { - // remove internal column `matched` for format version 2 (keySchema, StructType(valueSchema.dropRight(1))) } else { (keySchema, valueSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 4946977e72b7e..471ebc6fa8a6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -38,10 +38,12 @@ class StreamStreamJoinStatePartitionReaderFactory( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, userFacingSchema: StructType, - stateSchema: StructType) extends PartitionReaderFactory { + stateSchema: StructType, + joinStateFormatVersion: Option[Int] = None) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { new StreamStreamJoinStatePartitionReader(storeConf, hadoopConf, - partition.asInstanceOf[StateStoreInputPartition], userFacingSchema, stateSchema) + partition.asInstanceOf[StateStoreInputPartition], userFacingSchema, stateSchema, + joinStateFormatVersion) } } @@ -54,7 +56,9 @@ class StreamStreamJoinStatePartitionReader( hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, userFacingSchema: StructType, - stateSchema: StructType) extends PartitionReader[InternalRow] with Logging { + stateSchema: StructType, + joinStateFormatVersion: Option[Int] = None) + extends PartitionReader[InternalRow] with Logging { private val keySchema = SchemaUtil.getSchemaAsDataType(stateSchema, "key") .asInstanceOf[StructType] @@ -112,28 +116,27 @@ class StreamStreamJoinStatePartitionReader( endStateStoreCheckpointIds.right.keyWithIndexToValue } - /* - * This is to handle the difference of schema across state format versions. The major difference - * is whether we have added new field(s) in addition to the fields from input schema. - * - * - version 1: no additional field - * - version 2: the field "matched" is added to the last - */ private val (inputAttributes, formatVersion) = { val maybeMatchedColumn = valueSchema.last - val (fields, version) = { - // If there is a matched column, version is either 2 or 3. We need to drop the matched - // column from the value schema to get the actual fields. - if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) { - // If checkpoint is using one store and virtual column families, version is 3 - if (usesVirtualColumnFamilies) { - (valueSchema.dropRight(1), 3) + // If there is a matched column, version is higher than 1. We need to drop the matched + // column from the value schema to get the actual fields. + val (fields, version) = joinStateFormatVersion match { + // Use explicit format version when available from offset log + case Some(v) if v >= 2 => + (valueSchema.dropRight(1), v) + case Some(1) => + (valueSchema, 1) + // Fall back to heuristic-based detection for old checkpoints + case _ => + if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) { + if (usesVirtualColumnFamilies) { + (valueSchema.dropRight(1), 3) + } else { + (valueSchema.dropRight(1), 2) + } } else { - (valueSchema.dropRight(1), 2) + (valueSchema, 1) } - } else { - (valueSchema, 1) - } } assert(fields.toArray.sameElements(userFacingValueSchema.fields), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index a70a5da49a441..ef03a42a00707 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -276,7 +276,13 @@ class SymmetricHashJoinStateManagerV4( // V4 uses a single store with VCFs (not separate keyToNumValues/keyWithIndexToValue stores). // Use the keyToNumValues checkpoint ID for loading the correct committed version. private val stateStoreCkptId: Option[String] = keyToNumValuesStateStoreCkptId - private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None + private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = snapshotOptions.map { opts => + HandlerSnapshotOptions( + snapshotVersion = opts.snapshotVersion, + endVersion = opts.endVersion, + startStateStoreCkptId = opts.startKeyToNumValuesStateStoreCkptId, + endStateStoreCkptId = opts.endKeyToNumValuesStateStoreCkptId) + } private var stateStoreProvider: StateStoreProvider = _ @@ -1913,6 +1919,13 @@ object SymmetricHashJoinStateManager { } } + def allStateStoreNamesV4(joinSides: JoinSide*): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyWithTsToValuesType, TsWithKeyType) + for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { + getStateStoreName(joinSide, stateStoreType) + } + } + def getSchemaForStateStores( joinSide: JoinSide, inputValueAttributes: Seq[Attribute], @@ -1964,7 +1977,6 @@ object SymmetricHashJoinStateManager { inputValueAttributes: Seq[Attribute], joinKeys: Seq[Expression], stateFormatVersion: Int): Map[String, StateStoreColFamilySchema] = { - // Convert the original schemas for state stores into StateStoreColFamilySchema objects val schemas = getSchemaForStateStores(joinSide, inputValueAttributes, joinKeys, stateFormatVersion) @@ -2141,9 +2153,13 @@ object SymmetricHashJoinStateManager { } else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType) || storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) { KeyWithIndexToValueType + } else if (storeName == getStateStoreName(LeftSide, KeyWithTsToValuesType) || + storeName == getStateStoreName(RightSide, KeyWithTsToValuesType)) { + KeyWithTsToValuesType + } else if (storeName == getStateStoreName(LeftSide, TsWithKeyType) || + storeName == getStateStoreName(RightSide, TsWithKeyType)) { + TsWithKeyType } else { - // TODO: [SPARK-55628] Add support of KeyWithTsToValuesType and TsWithKeyType during - // integration. throw new IllegalArgumentException(s"Unsupported join store name: $storeName") } } @@ -2158,17 +2174,17 @@ object SymmetricHashJoinStateManager { stateFormatVersion: Int): StatePartitionKeyExtractor = { assert(stateFormatVersion <= 4, "State format version must be less than or equal to 4") val name = if (stateFormatVersion >= 3) colFamilyName else storeName - if (getStoreType(name) == KeyWithIndexToValueType) { - // For KeyWithIndex, the index is added to the join (i.e. partition) key. - // Drop the last field (index) to get the partition key - new DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, numLastColsToDrop = 1) - } else if (getStoreType(name) == KeyToNumValuesType) { - // State key is the partition key - new NoopStatePartitionKeyExtractor(stateKeySchema) - } else { - // TODO: [SPARK-55628] Add support of KeyWithTsToValuesType and TsWithKeyType during - // integration. - throw new IllegalArgumentException(s"Unsupported join store name: $storeName") + getStoreType(name) match { + case KeyWithIndexToValueType => + // For KeyWithIndex, the index is added to the join (i.e. partition) key. + // Drop the last field (index) to get the partition key + new DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, numLastColsToDrop = 1) + case KeyToNumValuesType => + new NoopStatePartitionKeyExtractor(stateKeySchema) + case KeyWithTsToValuesType | TsWithKeyType => + // For v4 stores, the logical key schema in the schema file is just the join key + // (timestamp is managed by the encoder), so the state key IS the partition key. + new NoopStatePartitionKeyExtractor(stateKeySchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 7da8c5a6bd3ca..a0b73f7d02a1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -170,7 +170,9 @@ object SchemaHelper { version match { case 1 if Utils.isTesting => new SchemaV1Writer case 2 => new SchemaV2Writer - case 3 => new SchemaV3Writer + case 3 | 4 => new SchemaV3Writer + case _ => throw new IllegalArgumentException( + s"Unsupported schema writer version: $version") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 166ec450bfbdd..0d0cdff6c1093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -132,8 +132,8 @@ class StateSchemaCompatibilityChecker( stateStoreColFamilySchema: List[StateStoreColFamilySchema], stateSchemaVersion: Int): Unit = { // Ensure that schema file path is passed explicitly for schema version 3 - if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFilePath.isEmpty) { - throw new IllegalStateException("Schema file path is required for schema version 3") + if (stateSchemaVersion >= SCHEMA_FORMAT_V3 && newSchemaFilePath.isEmpty) { + throw new IllegalStateException("Schema file path is required for schema version 3+") } val schemaWriter = SchemaWriter.createSchemaWriter(stateSchemaVersion) @@ -302,7 +302,7 @@ class StateSchemaCompatibilityChecker( newColFamilies.diff(oldColFamilies).toList, oldColFamilies.diff(newColFamilies).toList) } - if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFileWritten) { + if (stateSchemaVersion >= SCHEMA_FORMAT_V3 && newSchemaFileWritten) { createSchemaFile(evolvedSchemas.sortBy(_.colFamilyName), stateSchemaVersion) } @@ -410,10 +410,10 @@ object StateSchemaCompatibilityChecker extends Logging { throw result.get } val schemaFileLocation = if (evolvedSchema) { - // if we are using the state schema v3, and we have + // if we are using the state schema v3+, and we have // evolved schema, this newSchemaFilePath should be defined // and we want to populate the metadata with this file - if (stateSchemaVersion == SCHEMA_FORMAT_V3) { + if (stateSchemaVersion >= SCHEMA_FORMAT_V3) { newSchemaFilePath.get.toString } else { // if we are using any version less than v3, we have written @@ -422,10 +422,10 @@ object StateSchemaCompatibilityChecker extends Logging { } } else { // if we have not evolved schema (there has been a previous schema) - // and we are using state schema v3, this file path would be defined + // and we are using state schema v3+, this file path would be defined // so we would just populate the next run's metadata file with this // file path - if (stateSchemaVersion == SCHEMA_FORMAT_V3) { + if (stateSchemaVersion >= SCHEMA_FORMAT_V3) { oldSchemaFilePaths.last.toString } else { // if we are using any version less than v3, we have written diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index a1be83627f319..c26d68f000425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -48,6 +48,106 @@ class RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite extends spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", "true") } + + test("read stream-stream join v4 state change feed") { + withSQLConf( + SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "4", + SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "true") { + withTempDir { tempDir => + runStreamStreamJoinQuery(tempDir.getAbsolutePath) + + // event time in microseconds, timestamp in milliseconds + def evtTime(sec: Int): Long = sec * 1000000L + def ts(sec: Int): Timestamp = new Timestamp(sec * 1000L) + + val keyWithTsToValuesDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "left-keyWithTsToValues") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + // Schema: (batch_id, change_type, + // key: (field0: Int, __event_time: Long), + // value: (leftId: Int, leftTime: Timestamp, matched: Boolean), + // partition_id) + // Each entry first appears as "append" with matched=false, then an "update" + // with matched=true once the join finds a match on the right side. + checkAnswer(keyWithTsToValuesDf, Seq( + Row(0L, "append", Row(2, evtTime(2)), Row(2, ts(2), false), 4), + Row(0L, "update", Row(2, evtTime(2)), Row(2, ts(2), true), 4), + Row(0L, "append", Row(4, evtTime(4)), Row(4, ts(4), false), 2), + Row(0L, "update", Row(4, evtTime(4)), Row(4, ts(4), true), 2), + Row(1L, "append", Row(6, evtTime(6)), Row(6, ts(6), false), 4), + Row(1L, "update", Row(6, evtTime(6)), Row(6, ts(6), true), 4), + Row(1L, "append", Row(8, evtTime(8)), Row(8, ts(8), false), 3), + Row(1L, "update", Row(8, evtTime(8)), Row(8, ts(8), true), 3), + Row(1L, "append", Row(10, evtTime(10)), Row(10, ts(10), false), 2), + Row(1L, "update", Row(10, evtTime(10)), Row(10, ts(10), true), 2) + )) + + val tsWithKeyDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "left-tsWithKey") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + // Schema: (batch_id, change_type, + // key: (field0: Int, __event_time: Long), + // value: (__dummy__: Void), + // partition_id) + checkAnswer(tsWithKeyDf, Seq( + Row(0L, "append", Row(2, evtTime(2)), Row(null), 4), + Row(0L, "append", Row(4, evtTime(4)), Row(null), 2), + Row(1L, "append", Row(6, evtTime(6)), Row(null), 4), + Row(1L, "append", Row(8, evtTime(8)), Row(null), 3), + Row(1L, "append", Row(10, evtTime(10)), Row(null), 2) + )) + + val rightKeyWithTsToValuesDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "right-keyWithTsToValues") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + // Schema: (batch_id, change_type, + // key: (field0: Int, __event_time: Long), + // value: (rightId: Int, rightTime: Timestamp, matched: Boolean), + // partition_id) + // Right side values are appended with matched=true directly because the + // right side is processed after the left side within the same batch, + // so the match is found immediately. + checkAnswer(rightKeyWithTsToValuesDf, Seq( + Row(0L, "append", Row(2, evtTime(2)), Row(2, ts(2), true), 4), + Row(0L, "append", Row(4, evtTime(4)), Row(4, ts(4), true), 2), + Row(1L, "append", Row(6, evtTime(6)), Row(6, ts(6), true), 4), + Row(1L, "append", Row(8, evtTime(8)), Row(8, ts(8), true), 3), + Row(1L, "append", Row(10, evtTime(10)), Row(10, ts(10), true), 2) + )) + + val rightTsWithKeyDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "right-tsWithKey") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + // Schema: (batch_id, change_type, + // key: (field0: Int, __event_time: Long), + // value: (__dummy__: Void), + // partition_id) + checkAnswer(rightTsWithKeyDf, Seq( + Row(0L, "append", Row(2, evtTime(2)), Row(null), 4), + Row(0L, "append", Row(4, evtTime(4)), Row(null), 2), + Row(1L, "append", Row(6, evtTime(6)), Row(null), 4), + Row(1L, "append", Row(8, evtTime(8)), Row(null), 3), + Row(1L, "append", Row(10, evtTime(10)), Row(null), 2) + )) + } + } + } } class RocksDBWithCheckpointV2StateDataSourceChangeDataReaderSuite extends diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 05916c4816a9d..2def79828fac1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -579,6 +579,10 @@ StateDataSourceReadSuite { test("snapshotStartBatchId on join state v3") { testSnapshotOnJoinStateV3() } + + test("snapshotStartBatchId on join state v4") { + testSnapshotOnJoinStateV4() + } } class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceReadSuite { @@ -678,6 +682,10 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSnapshotSuite extends StateDat test("snapshotStartBatchId on join state v3") { testSnapshotOnJoinStateV3() } + + test("snapshotStartBatchId on join state v4") { + testSnapshotOnJoinStateV4() + } } abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions { @@ -983,6 +991,10 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass testStreamStreamJoin(3) } + test("stream-stream join, state ver 4") { + testStreamStreamJoin(4) + } + private def testStreamStreamJoin(stateVersion: Int): Unit = { def assertInternalColumnIsNotExposed(df: DataFrame): Unit = { val valueSchema = SchemaUtil.getSchemaAsDataType(df.schema, "value") @@ -993,8 +1005,8 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass } } - // We should only test state version 3 with RocksDBStateStoreProvider - if (stateVersion == 3 + // State version >= 3 requires RocksDBStateStoreProvider + if (stateVersion >= 3 && SQLConf.get.stateStoreProviderClass != classOf[RocksDBStateStoreProvider].getName) { return } @@ -1036,48 +1048,86 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass Seq(Row(6, 6, 6L), Row(8, 8, 8L), Row(10, 10, 10L)) ) - val stateReaderForRightKeyToNumValues = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) - .option(StateSourceOptions.STORE_NAME, - "right-keyToNumValues") - - val stateReadDfForRightKeyToNumValues = stateReaderForRightKeyToNumValues.load() - val resultDf3 = stateReadDfForRightKeyToNumValues - .selectExpr("key.field0 AS key_0", "value.value") + if (stateVersion <= 3) { + // v1-v3: test reading specific stores by name + val stateReaderForRightKeyToNumValues = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STORE_NAME, + "right-keyToNumValues") - checkAnswer( - resultDf3, - Seq(Row(6, 1L), Row(8, 1L), Row(10, 1L)) - ) + val stateReadDfForRightKeyToNumValues = stateReaderForRightKeyToNumValues.load() + val resultDf3 = stateReadDfForRightKeyToNumValues + .selectExpr("key.field0 AS key_0", "value.value") - val stateReaderForRightKeyWithIndexToValue = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) - .option(StateSourceOptions.STORE_NAME, - "right-keyWithIndexToValue") + checkAnswer( + resultDf3, + Seq(Row(6, 1L), Row(8, 1L), Row(10, 1L)) + ) - val stateReadDfForRightKeyWithIndexToValue = stateReaderForRightKeyWithIndexToValue.load() + val stateReaderForRightKeyWithIndexToValue = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STORE_NAME, + "right-keyWithIndexToValue") + + val stateReadDfForRightKeyWithIndexToValue = + stateReaderForRightKeyWithIndexToValue.load() + + if (stateVersion >= 2) { + val resultDf4 = stateReadDfForRightKeyWithIndexToValue + .selectExpr("key.field0 AS key_0", "key.index AS key_index", + "value.rightId AS rightId", "CAST(value.rightTime AS integer) AS rightTime", + "value.matched As matched") + + checkAnswer( + resultDf4, + Seq(Row(6, 0, 6, 6L, true), Row(8, 0, 8, 8L, true), Row(10, 0, 10, 10L, true)) + ) + } else { + // stateVersion == 1 + val resultDf4 = stateReadDfForRightKeyWithIndexToValue + .selectExpr("key.field0 AS key_0", "key.index AS key_index", + "value.rightId AS rightId", "CAST(value.rightTime AS integer) AS rightTime") + + checkAnswer( + resultDf4, + Seq(Row(6, 0, 6, 6L), Row(8, 0, 8, 8L), Row(10, 0, 10, 10L)) + ) + } + } else { + // v4: test reading specific stores by name + val stateReaderForRightKeyWithTsToValues = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STORE_NAME, + "right-keyWithTsToValues") - if (stateVersion >= 2) { - val resultDf4 = stateReadDfForRightKeyWithIndexToValue - .selectExpr("key.field0 AS key_0", "key.index AS key_index", + val stateReadDfForRightKeyWithTsToValues = + stateReaderForRightKeyWithTsToValues.load() + val resultDf3 = stateReadDfForRightKeyWithTsToValues + .selectExpr("key.field0 AS key_0", "value.rightId AS rightId", "CAST(value.rightTime AS integer) AS rightTime", - "value.matched As matched") + "value.matched AS matched") checkAnswer( - resultDf4, - Seq(Row(6, 0, 6, 6L, true), Row(8, 0, 8, 8L, true), Row(10, 0, 10, 10L, true)) + resultDf3, + Seq(Row(6, 6, 6L, true), Row(8, 8, 8L, true), Row(10, 10, 10L, true)) ) - } else { - // stateVersion == 1 - val resultDf4 = stateReadDfForRightKeyWithIndexToValue - .selectExpr("key.field0 AS key_0", "key.index AS key_index", - "value.rightId AS rightId", "CAST(value.rightTime AS integer) AS rightTime") + + val stateReaderForRightTsWithKey = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STORE_NAME, + "right-tsWithKey") + + val stateReadDfForRightTsWithKey = stateReaderForRightTsWithKey.load() + val resultDf4 = stateReadDfForRightTsWithKey + .selectExpr("key.field0 AS key_0") checkAnswer( resultDf4, - Seq(Row(6, 0, 6, 6L), Row(8, 0, 8, 8L), Row(10, 0, 10, 10L)) + Seq(Row(6), Row(8), Row(10)) ) } } @@ -1465,9 +1515,17 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass } protected def testSnapshotOnJoinStateV3(): Unit = { + testSnapshotOnJoinStateWithVCF(3) + } + + protected def testSnapshotOnJoinStateV4(): Unit = { + testSnapshotOnJoinStateWithVCF(4) + } + + private def testSnapshotOnJoinStateWithVCF(stateFormatVersion: Int): Unit = { withTempDir { tmpDir => withSQLConf( - SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3", + SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateFormatVersion.toString, SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100" ) { val inputData = MemoryStream[(Int, Long)] @@ -1486,18 +1544,31 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass StopStream ) - val stateSnapshotDf = spark.read.format("statestore") + val stateSnapshotDfLeft = spark.read.format("statestore") .option("snapshotPartitionId", 2) .option("snapshotStartBatchId", 0) .option("joinSide", "left") .load(tmpDir.getCanonicalPath) - val stateDf = spark.read.format("statestore") + val stateDfLeft = spark.read.format("statestore") .option("joinSide", "left") .load(tmpDir.getCanonicalPath) .filter(col("partition_id") === 2) - checkAnswer(stateSnapshotDf, stateDf) + checkAnswer(stateSnapshotDfLeft, stateDfLeft) + + val stateSnapshotDfRight = spark.read.format("statestore") + .option("snapshotPartitionId", 2) + .option("snapshotStartBatchId", 0) + .option("joinSide", "right") + .load(tmpDir.getCanonicalPath) + + val stateDfRight = spark.read.format("statestore") + .option("joinSide", "right") + .load(tmpDir.getCanonicalPath) + .filter(col("partition_id") === 2) + + checkAnswer(stateSnapshotDfRight, stateDfRight) } } }