Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
stateStoreReaderInfo.stateSchemaProviderOpt,
stateStoreReaderInfo.joinColFamilyOpt,
stateStoreReaderInfo.allColumnFamiliesReaderInfo)
stateStoreReaderInfo.allColumnFamiliesReaderInfo,
stateStoreReaderInfo.joinStateFormatVersion)
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
Expand All @@ -110,11 +111,13 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
val (keySchema, valueSchema) = sourceOptions.joinSide match {
case JoinSideValues.left =>
StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
sourceOptions.operatorId, LeftSide, oldSchemaFilePaths)
sourceOptions.operatorId, LeftSide, oldSchemaFilePaths,
joinStateFormatVersion = stateStoreReaderInfo.joinStateFormatVersion)

case JoinSideValues.right =>
StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
sourceOptions.operatorId, RightSide, oldSchemaFilePaths)
sourceOptions.operatorId, RightSide, oldSchemaFilePaths,
joinStateFormatVersion = stateStoreReaderInfo.joinStateFormatVersion)

case JoinSideValues.none =>
// we should have the schema for the state store if joinSide is none
Expand Down Expand Up @@ -162,9 +165,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

/**
* Returns true if this is a read-all-column-families request for a stream-stream join
* that uses virtual column families (state format version 3).
* that uses virtual column families (state format version >= 3).
*/
private def isReadAllColFamiliesOnJoinV3(
private def isReadAllColFamiliesOnJoinWithVCF(
sourceOptions: StateSourceOptions,
storeMetadata: Array[StateMetadataTableEntry]): Boolean = {
sourceOptions.internalOnlyReadAllColumnFamilies &&
Expand Down Expand Up @@ -243,9 +246,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
opMetadata.operatorName match {
case opName: String if opName ==
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME =>
// Verify that the store name is valid
val possibleStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(
LeftSide, RightSide)
LeftSide, RightSide) ++
SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide)
if (!possibleStoreNames.contains(name)) {
val errorMsg = s"Store name $name not allowed for join operator. Allowed names are " +
s"$possibleStoreNames. " +
Expand Down Expand Up @@ -393,7 +396,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
// However, Join V3 does not have a "default" column family. Therefore, we pick the first
// schema as resultSchema which will be used as placeholder schema for default schema
// in StatePartitionAllColumnFamiliesReader
val resultSchema = if (isReadAllColFamiliesOnJoinV3(sourceOptions, storeMetadata)) {
val resultSchema = if (isReadAllColFamiliesOnJoinWithVCF(sourceOptions, storeMetadata)) {
stateSchema.head
} else {
stateSchema.filter(_.colFamilyName == stateVarName).head
Expand All @@ -408,17 +411,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
}
}

val joinFormatVersion = getStateFormatVersion(
storeMetadata,
sourceOptions.resolvedCpLocation,
sourceOptions.batchId
)

val allColFamilyReaderInfoOpt: Option[AllColumnFamiliesReaderInfo] =
if (sourceOptions.internalOnlyReadAllColumnFamilies) {
assert(storeMetadata.nonEmpty, "storeMetadata shouldn't be empty")
val operatorName = storeMetadata.head.operatorName
val stateFormatVersion = getStateFormatVersion(
storeMetadata,
sourceOptions.resolvedCpLocation,
sourceOptions.batchId
)
Some(AllColumnFamiliesReaderInfo(
stateStoreColFamilySchemas, stateVariableInfos, operatorName, stateFormatVersion))
stateStoreColFamilySchemas, stateVariableInfos, operatorName, joinFormatVersion))
} else {
None
}
Expand All @@ -428,7 +432,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
transformWithStateVariableInfoOpt,
stateSchemaProvider,
joinColFamilyOpt,
allColFamilyReaderInfoOpt
allColFamilyReaderInfoOpt,
joinFormatVersion
)
}

Expand Down Expand Up @@ -819,7 +824,8 @@ case class StateStoreReaderInfo(
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String], // Only used for join op with state format v3
// List of all column family schemas - used when internalOnlyReadAllColumnFamilies=true
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
joinStateFormatVersion: Option[Int] = None
)

object StateDataSource {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
Expand All @@ -43,6 +44,18 @@ case class AllColumnFamiliesReaderInfo(
operatorName: String,
stateFormatVersion: Option[Int] = None)

private[state] object StatePartitionReaderUtils {
  // Column family names used by the v4 stream-stream join layout
  // (both sides), precomputed once for fast membership checks.
  val v4JoinCFNames: Set[String] =
    SymmetricHashJoinStateManager.allStateStoreNamesV4(LeftSide, RightSide).toSet

  /**
   * Returns true when the given column family stores multiple values per key:
   * either a ListState variable (transformWithState) or a v4 join column family.
   *
   * @param colFamilyNameOpt     column family name to test, if known
   * @param stateVariableInfoOpt state variable metadata for the column family, if any
   */
  def isMultiValuedCF(
      colFamilyNameOpt: Option[String],
      stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = {
    // Local defs keep the original short-circuit evaluation of `||`.
    def isListState: Boolean =
      SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.ListState)
    // Set[String] extends String => Boolean, so it can be passed to `exists` directly.
    def isV4JoinFamily: Boolean = colFamilyNameOpt.exists(v4JoinCFNames)
    isListState || isV4JoinFamily
  }
}

/**
* An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
* general read from a state store instance, rather than specific to the operator.
Expand Down Expand Up @@ -145,8 +158,8 @@ abstract class StatePartitionReaderBase(

val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined

val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt,
StateVariableType.ListState)
val useMultipleValuesPerKey = StatePartitionReaderUtils.isMultiValuedCF(
joinColFamilyOpt, stateVariableInfoOpt)

val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
Expand Down Expand Up @@ -249,6 +262,12 @@ class StatePartitionReader(
val stateVarType = stateVariableInfo.stateVariableType
SchemaUtil.processStateEntries(stateVarType, colFamilyName, store,
keySchema, partition.partition, partition.sourceOptions)
} else if (joinColFamilyOpt.exists(StatePartitionReaderUtils.v4JoinCFNames.contains)) {
store
.iteratorWithMultiValues(colFamilyName)
.map { pair =>
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
}
} else {
store
.iterator(colFamilyName)
Expand Down Expand Up @@ -327,10 +346,10 @@ class StatePartitionAllColumnFamiliesReader(
}.toMap
}

private def isListType(colFamilyName: String): Boolean = {
SchemaUtil.checkVariableType(
stateVariableInfos.find(info => info.stateName == colFamilyName),
StateVariableType.ListState)
private def isMultiValuedCF(colFamilyName: String): Boolean = {
  // Look up the state variable metadata (if any) registered under this
  // column family name, then delegate to the shared multi-value check.
  val variableInfoOpt = stateVariableInfos.find(_.stateName == colFamilyName)
  StatePartitionReaderUtils.isMultiValuedCF(Some(colFamilyName), variableInfoOpt)
}

override protected lazy val provider: StateStoreProvider = {
Expand Down Expand Up @@ -386,7 +405,7 @@ class StatePartitionAllColumnFamiliesReader(
case _ =>
val isInternal =
StateStoreColumnFamilySchemaUtils.isInternalColFamily(cfSchema.colFamilyName)
val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName)
val useMultipleValuesPerKey = isMultiValuedCF(cfSchema.colFamilyName)
require(cfSchema.keyStateEncoderSpec.isDefined,
s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
stateStore.createColFamilyIfAbsent(
Expand All @@ -410,15 +429,11 @@ class StatePartitionAllColumnFamiliesReader(
.filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
.flatMap { cfSchema =>
val extractor = cfPartitionKeyExtractors(cfSchema.colFamilyName)
if (isListType(cfSchema.colFamilyName)) {
store.iterator(cfSchema.colFamilyName).flatMap(
pair =>
store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
value =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, value), cfSchema.colFamilyName, extractor)
}
)
if (isMultiValuedCF(cfSchema.colFamilyName)) {
store.iteratorWithMultiValues(cfSchema.colFamilyName).map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, pair.value), cfSchema.colFamilyName, extractor)
}
} else {
store.iterator(cfSchema.colFamilyName).map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ class StateScanBuilder(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
joinStateFormatVersion: Option[Int] = None) extends ScanBuilder {
override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
joinColFamilyOpt, allColumnFamiliesReaderInfo, joinStateFormatVersion)
}

/** An implementation of [[InputPartition]] for State Store data source. */
Expand All @@ -73,7 +74,8 @@ class StateScan(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
joinStateFormatVersion: Option[Int] = None)
extends Scan with Batch {

// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
Expand Down Expand Up @@ -136,19 +138,21 @@ class StateScan(
hadoopConfBroadcast.value.value)
val stateSchema = StreamStreamJoinStateHelper.readSchema(session,
sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, LeftSide,
oldSchemaFilePaths, excludeAuxColumns = false)
oldSchemaFilePaths, excludeAuxColumns = false,
joinStateFormatVersion = joinStateFormatVersion)
new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
hadoopConfBroadcast.value, userFacingSchema, stateSchema)
hadoopConfBroadcast.value, userFacingSchema, stateSchema, joinStateFormatVersion)

case JoinSideValues.right =>
val userFacingSchema = schema
val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions,
hadoopConfBroadcast.value.value)
val stateSchema = StreamStreamJoinStateHelper.readSchema(session,
sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, RightSide,
oldSchemaFilePaths, excludeAuxColumns = false)
oldSchemaFilePaths, excludeAuxColumns = false,
joinStateFormatVersion = joinStateFormatVersion)
new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
hadoopConfBroadcast.value, userFacingSchema, stateSchema)
hadoopConfBroadcast.value, userFacingSchema, stateSchema, joinStateFormatVersion)

case JoinSideValues.none =>
new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class StateTable(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None)
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None,
joinStateFormatVersion: Option[Int] = None)
extends Table with SupportsRead with SupportsMetadataColumns {

import StateTable._
Expand Down Expand Up @@ -90,7 +91,7 @@ class StateTable(
new StateScanBuilder(session, schema, sourceOptions, stateConf,
batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
joinColFamilyOpt, allColumnFamiliesReaderInfo, joinStateFormatVersion)

override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,23 @@ object StreamStreamJoinStateHelper {
operatorId: Int,
side: JoinSide,
oldSchemaFilePaths: List[Path],
excludeAuxColumns: Boolean = true): StructType = {
excludeAuxColumns: Boolean = true,
joinStateFormatVersion: Option[Int] = None): StructType = {
val (keySchema, valueSchema) = readKeyValueSchema(session, stateCheckpointLocation,
operatorId, side, oldSchemaFilePaths, excludeAuxColumns)
operatorId, side, oldSchemaFilePaths, excludeAuxColumns, joinStateFormatVersion)

new StructType()
.add("key", keySchema)
.add("value", valueSchema)
}

// Returns whether the checkpoint uses stateFormatVersion 3 which uses VCF for the join.
// Returns whether the checkpoint uses VCF for the join (stateFormatVersion >= 3).
def usesVirtualColumnFamilies(
hadoopConf: Configuration,
stateCheckpointLocation: String,
operatorId: Int): Boolean = {
// If the schema exists for operatorId/partitionId/left-keyToNumValues, it is not
// stateFormatVersion 3.
// stateFormatVersion >= 3 (which uses VCF).
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).toList.head)
Expand All @@ -72,16 +73,17 @@ object StreamStreamJoinStateHelper {
operatorId: Int,
side: JoinSide,
oldSchemaFilePaths: List[Path],
excludeAuxColumns: Boolean = true): (StructType, StructType) = {
excludeAuxColumns: Boolean = true,
joinStateFormatVersion: Option[Int] = None): (StructType, StructType) = {

val newHadoopConf = session.sessionState.newHadoopConf()
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
// KeyToNumValuesType, KeyWithIndexToValueType
val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList

val (keySchema, valueSchema) =
if (!usesVirtualColumnFamilies(
newHadoopConf, stateCheckpointLocation, operatorId)) {
val (keySchema, valueSchema) = joinStateFormatVersion match {
case Some(1) | Some(2) | None =>
// v1/v2: separate state stores per store type.
// None handles old checkpoints without operator metadata (always v1/v2).
val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList
val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, storeNames(0))
val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues,
Expand All @@ -92,42 +94,64 @@ object StreamStreamJoinStateHelper {
val providerIdForKeyWithIndexToValue = new StateStoreProviderId(
storeIdForKeyWithIndexToValue, UUID.randomUUID())

// read the key schema from the keyToNumValues store for the join keys
val manager = new StateSchemaCompatibilityChecker(
providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths,
createSchemaDir = false)
val kSchema = manager.readSchemaFile().head.keySchema

// read the value schema from the keyWithIndexToValue store for the values
val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue,
newHadoopConf, oldSchemaFilePaths, createSchemaDir = false)
val vSchema = manager2.readSchemaFile().head.valueSchema

(kSchema, vSchema)
} else {

case Some(3) =>
// v3: single state store with virtual column families
val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, StateStoreId.DEFAULT_STORE_NAME)
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())

val manager = new StateSchemaCompatibilityChecker(
providerId, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false)
val kSchema = manager.readSchemaFile().find { schema =>
schema.colFamilyName == storeNames(0)
}.map(_.keySchema).get
val schemas = manager.readSchemaFile()

val vSchema = manager.readSchemaFile().find { schema =>
schema.colFamilyName == storeNames(1)
}.map(_.valueSchema).get
val kSchema = schemas.find(_.colFamilyName == storeNames(0)).map(_.keySchema).get
val vSchema = schemas.find(_.colFamilyName == storeNames(1)).map(_.valueSchema).get

(kSchema, vSchema)
}

case Some(4) =>
// v4: single state store with virtual column families, timestamp-based keys
val v4Names = SymmetricHashJoinStateManager.allStateStoreNamesV4(side).toList
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, StateStoreId.DEFAULT_STORE_NAME)
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())

val manager = new StateSchemaCompatibilityChecker(
providerId, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false)
val schemas = manager.readSchemaFile()

// In v4, the primary CF (keyWithTsToValues) stores both the key and value schemas.
// This differs from v3 where keyToNumValues has the key schema and
// keyWithIndexToValue has the value schema.
val primaryCF = v4Names(0)
val kSchema = schemas.find(_.colFamilyName == primaryCF).map(_.keySchema).get
val vSchema = schemas.find(_.colFamilyName == primaryCF).map(_.valueSchema).get

(kSchema, vSchema)

case Some(v) =>
throw new IllegalArgumentException(
s"Unsupported join state format version: $v")
}

val maybeMatchedColumn = valueSchema.last

// remove internal column `matched` for format version >= 2
if (excludeAuxColumns
&& maybeMatchedColumn.name == "matched"
&& maybeMatchedColumn.dataType == BooleanType) {
// remove internal column `matched` for format version 2
(keySchema, StructType(valueSchema.dropRight(1)))
} else {
(keySchema, valueSchema)
Expand Down
Loading