diff --git a/crates/codegen/src/unrealcpp.rs b/crates/codegen/src/unrealcpp.rs index bab5d0bfed5..2e980843cb1 100644 --- a/crates/codegen/src/unrealcpp.rs +++ b/crates/codegen/src/unrealcpp.rs @@ -47,6 +47,7 @@ impl Lang for UnrealCpp<'_> { &[ "Types/Builtins.h", &format!("ModuleBindings/Types/{struct_name}Type.g.h"), + "DBCache/ClientCache.h", "Tables/RemoteTable.h", "DBCache/WithBsatn.h", "DBCache/TableHandle.h", @@ -63,6 +64,35 @@ impl Lang for UnrealCpp<'_> { // Generate unique index classes first let product_type = module.typespace_for_generate()[table.product_type_ref].as_product(); + if let (false, Some(pk)) = (table.is_event, schema.pk()) { + let (pk_name, pk_ty) = &product_type.unwrap().elements[pk.col_pos.idx()]; + let pk_type_str = cpp_ty_fmt_with_module(self.module_prefix, module, pk_ty, self.module_name).to_string(); + if pk_type_str == "uint64" { + let pk_field_name = pk_name.deref().to_case(Case::Pascal); + let pk_index_name = pk.col_name.deref(); + writeln!(output, "namespace UE::SpacetimeDB"); + writeln!(output, "{{"); + writeln!(output, "template<>"); + writeln!(output, "struct TCompactPrimaryKeyTraits<::{row_struct}>"); + writeln!(output, "{{"); + writeln!(output, " static constexpr bool bEnabled = true;"); + writeln!(output, " using KeyType = uint64;"); + writeln!(output); + writeln!(output, " static KeyType GetKey(const ::{row_struct}& Row)"); + writeln!(output, " {{"); + writeln!(output, " return Row.{pk_field_name};"); + writeln!(output, " }}"); + writeln!(output); + writeln!(output, " static const TCHAR* GetUniqueIndexName()"); + writeln!(output, " {{"); + writeln!(output, " return TEXT(\"{pk_index_name}\");"); + writeln!(output, " }}"); + writeln!(output, "}};"); + writeln!(output, "}}"); + writeln!(output); + } + } + let mut unique_indexes = Vec::new(); let mut multi_key_indexes = Vec::new(); @@ -308,6 +338,17 @@ impl Lang for UnrealCpp<'_> { writeln!(output, " void PostInitialize();"); writeln!(output); + writeln!( + output, + " void SetCacheApplyMode(UE::SpacetimeDB::ETableCacheApplyMode InApplyMode);" + ); + writeln!( + output, + " UE::SpacetimeDB::ETableCacheApplyMode GetCacheApplyMode() const;" + ); + writeln!(output, " void SetRuntimeBTreeIndexApplyMode(const FString& IndexName, UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode InApplyMode);"); + writeln!(output, " void ClearRuntimeBTreeIndexApplyModes();"); + writeln!(output); writeln!(output, " /** Update function for {table_name} table*/"); writeln!( @@ -1205,7 +1246,7 @@ fn generate_table_cpp( writeln!( output, " {table_pascal}Table->AddUniqueConstraint<{field_type}>(\"{}\", [](const {row_struct}& Row) -> const {field_type}& {{", - field_name.to_lowercase() + field_name ); writeln!(output, " return Row.{}; }});", field_name.to_case(Case::Pascal)); } @@ -1255,6 +1296,52 @@ fn generate_table_cpp( writeln!(output, "}}"); writeln!(output); + writeln!( + output, + "void U{table_pascal}Table::SetCacheApplyMode(UE::SpacetimeDB::ETableCacheApplyMode InApplyMode)" + ); + writeln!(output, "{{"); + writeln!( + output, + " checkf(Data.IsValid(), TEXT(\"{table_pascal} table cache policy set before PostInitialize.\"));" + ); + writeln!(output, " Data->SetApplyMode(InApplyMode);"); + writeln!(output, "}}"); + writeln!(output); + + writeln!( + output, + "UE::SpacetimeDB::ETableCacheApplyMode U{table_pascal}Table::GetCacheApplyMode() const" + ); + writeln!(output, "{{"); + writeln!( + output, + " checkf(Data.IsValid(), TEXT(\"{table_pascal} table cache policy read before PostInitialize.\"));" + ); + writeln!(output, " return Data->GetApplyMode();"); + writeln!(output, "}}"); + writeln!(output); + + writeln!(output, "void U{table_pascal}Table::SetRuntimeBTreeIndexApplyMode(const FString& IndexName, UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode InApplyMode)"); + writeln!(output, "{{"); + writeln!( + output, + " checkf(Data.IsValid(), TEXT(\"{table_pascal} runtime B-Tree index policy set before PostInitialize.\"));" + ); + writeln!( + output, + " Data->SetRuntimeBTreeIndexApplyMode(IndexName, InApplyMode);" + ); + writeln!(output, "}}"); + writeln!(output); + + writeln!(output, "void U{table_pascal}Table::ClearRuntimeBTreeIndexApplyModes()"); + writeln!(output, "{{"); + writeln!(output, " checkf(Data.IsValid(), TEXT(\"{table_pascal} runtime B-Tree index policy reset before PostInitialize.\"));"); + writeln!(output, " Data->ClearRuntimeBTreeIndexApplyModes();"); + writeln!(output, "}}"); + writeln!(output); + // Generate Update implementation writeln!( output, @@ -1267,9 +1354,12 @@ fn generate_table_cpp( " // Event tables are callback-only: do not persist rows in the local cache." ); writeln!(output, " FTableAppliedDiff<{row_struct}> Diff;"); - writeln!(output, " for (const FWithBsatn<{row_struct}>& Insert : InsertsRef)"); + writeln!(output, " for (FWithBsatn<{row_struct}>& Insert : InsertsRef)"); writeln!(output, " {{"); - writeln!(output, " Diff.Inserts.Add(Insert.Bsatn, Insert.Row);"); + writeln!( + output, + " Diff.Inserts.Add(MakeShared<{row_struct}>(MoveTemp(Insert.Row)));" + ); writeln!(output, " }}"); } else { writeln!( diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp index 669f8079789..04a8dc9fd7f 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp @@ -1,13 +1,16 @@ -#include "Connection/DbConnectionBase.h" +#include "Connection/DbConnectionBase.h" #include "Connection/DbConnectionBuilder.h" #include "Connection/Credentials.h" #include "Connection/LogCategory.h" -#include "Containers/Ticker.h" #include "ModuleBindings/Types/ClientMessageType.g.h" #include "ModuleBindings/Types/SubscriptionErrorType.g.h" #include "Misc/Compression.h" #include "Misc/ScopeLock.h" -#include "Async/Async.h" +#include "HAL/Event.h" +#include "HAL/PlatformProcess.h" +#include "HAL/Runnable.h" +#include "HAL/RunnableThread.h" +#include "ProfilingDebugging/CpuProfilerTrace.h" #include "BSATN/UEBSATNHelpers.h" #include "Connection/ProcedureFlags.h" @@ -20,7 +23,21 @@ enum class EWsCompressionTag : uint8 Gzip = 2, }; -static FDatabaseUpdateType QueryRowsToDatabaseUpdate(const FQueryRowsType& Rows, bool bAsDeletes) +constexpr int32 MaxQueuedInboundRawMessages = 8192; +constexpr int64 MaxQueuedInboundRawBytes = 128ll * 1024ll * 1024ll; +constexpr int32 MaxPendingInboundParsedMessages = 8192; +constexpr int64 MaxInboundDecodedPayloadBytes = 128ll * 1024ll * 1024ll; +constexpr int64 MaxPendingInboundParsedEstimatedBytes = 256ll * 1024ll * 1024ll; +constexpr int32 PendingInboundCompactionMinConsumedMessages = 512; +constexpr int32 InboundRawCompactionMinConsumedMessages = 512; +constexpr uint32 InboundWorkerStackSizeBytes = 0; +constexpr EThreadPriority InboundWorkerThreadPriority = TPri_Normal; +constexpr const TCHAR* InboundWorkerThreadName = TEXT("SpacetimeDBInboundWorker"); +constexpr int32 SpacetimeDbCompressionTagBytes = 1; +constexpr int32 GzipFooterUncompressedSizeBytes = 4; +constexpr int32 MaxInboundApplyLogTableContributors = 6; + +static FDatabaseUpdateType QueryRowsToDatabaseUpdate(const FQueryRowsType& Rows, UE::SpacetimeDB::EQueryRowsApplyMode Mode) { FDatabaseUpdateType Update; for (const FSingleTableRowsType& TableRows : Rows.Tables) @@ -29,13 +46,17 @@ static FDatabaseUpdateType QueryRowsToDatabaseUpdate(const FQueryRowsType& Rows, TableUpdate.TableName = TableRows.Table; FPersistentTableRowsType PersistentRows; - if (bAsDeletes) + switch (Mode) { + case UE::SpacetimeDB::EQueryRowsApplyMode::Deletes: PersistentRows.Deletes = TableRows.Rows; - } - else - { + break; + case UE::SpacetimeDB::EQueryRowsApplyMode::Inserts: PersistentRows.Inserts = TableRows.Rows; + break; + default: + checkf(false, TEXT("Unsupported query-row apply mode for table %s"), *TableRows.Table); + continue; } TableUpdate.Rows.Add(FTableUpdateRowsType::PersistentTable(PersistentRows)); Update.Tables.Add(TableUpdate); @@ -64,99 +85,221 @@ static FString DecodeReducerErrorMessage(const TArray& ErrorBytes) } return UE::SpacetimeDB::Deserialize(ErrorBytes); } + +static const TCHAR* DescribeServerMessageTag(EServerMessageTag Tag) +{ + switch (Tag) + { + case EServerMessageTag::InitialConnection: + return TEXT("InitialConnection"); + case EServerMessageTag::TransactionUpdate: + return TEXT("TransactionUpdate"); + case EServerMessageTag::OneOffQueryResult: + return TEXT("OneOffQueryResult"); + case EServerMessageTag::SubscribeApplied: + return TEXT("SubscribeApplied"); + case EServerMessageTag::UnsubscribeApplied: + return TEXT("UnsubscribeApplied"); + case EServerMessageTag::SubscriptionError: + return TEXT("SubscriptionError"); + case EServerMessageTag::ReducerResult: + return TEXT("ReducerResult"); + case EServerMessageTag::ProcedureResult: + return TEXT("ProcedureResult"); + default: + return TEXT("Unknown"); + } } - -UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) - : Super(ObjectInitializer) + +static FString FormatInboundTableApplyStats(const FSpacetimeDBTableApplyStats& Stats) { - NextRequestId = 1; - NextSubscriptionId = 1; - ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); + return FString::Printf( + TEXT("%s rows=%d ins=%d del=%d bytes=%lld cache=%.2fus broadcast=%.2fus diff=%d"), + *Stats.TableName, + Stats.RowSetCount, + Stats.InsertRowCount, + Stats.DeleteRowCount, + Stats.InsertRowBytes + Stats.DeleteRowBytes, + Stats.CacheMicros, + Stats.BroadcastMicros, + Stats.bProducedDiff ? 1 : 0); } -void UDbConnectionBase::SetAutoTicking(bool bAutoTick) +static int64 EstimatePreprocessedTableDataBytes(const FPreprocessedTableDataMap& PreprocessedTableData) { - if (bIsAutoTicking == bAutoTick) + int64 EstimatedBytes = PreprocessedTableData.GetAllocatedSize(); + for (const TPair>>& TablePair : PreprocessedTableData) { - return; + EstimatedBytes += TablePair.Key.TableName.GetAllocatedSize(); + EstimatedBytes += TablePair.Value.GetAllocatedSize(); + for (const TSharedPtr& Data : TablePair.Value) + { + if (Data.IsValid()) + { + EstimatedBytes += Data->EstimateMemoryBytes(); + } + } } + return EstimatedBytes; +} - bIsAutoTicking = bAutoTick; +static int64 EstimateInboundParsedMessageBytes(const FInboundParsedMessage& Message) +{ + return sizeof(FInboundParsedMessage) + + static_cast(Message.DecodedPayloadSizeBytes) + + Message.ProtocolError.GetAllocatedSize() + + EstimatePreprocessedTableDataBytes(Message.PreprocessedTableData); +} +} - if (bIsAutoTicking) +class FSpacetimeDbInboundWorker final : public FRunnable +{ +public: + explicit FSpacetimeDbInboundWorker(UDbConnectionBase& InConnection) + : Connection(&InConnection) + { + WorkAvailableEvent = FPlatformProcess::GetSynchEventFromPool(false); + checkf(WorkAvailableEvent != nullptr, TEXT("Failed to allocate SpacetimeDB inbound worker event.")); + + Thread = FRunnableThread::Create( + this, + InboundWorkerThreadName, + InboundWorkerStackSizeBytes, + InboundWorkerThreadPriority); + checkf(Thread != nullptr, TEXT("Failed to create SpacetimeDB inbound worker thread.")); + } + + virtual ~FSpacetimeDbInboundWorker() override + { + StopAndJoin(); + } + + void Notify() { - if (!TickerHandle.IsValid()) + if (WorkAvailableEvent) { - TickerHandle = FTSTicker::GetCoreTicker().AddTicker(FTickerDelegate::CreateUObject(this, &UDbConnectionBase::OnTickerTick)); + WorkAvailableEvent->Trigger(); } } - else if (TickerHandle.IsValid()) + + virtual uint32 Run() override + { + while (!bStopRequested) + { + checkf(WorkAvailableEvent != nullptr, TEXT("SpacetimeDB inbound worker event was not initialized.")); + WorkAvailableEvent->Wait(); + + if (bStopRequested) + { + break; + } + + if (Connection) + { + Connection->DrainInboundRawMessagesOnWorker(); + } + } + + return 0; + } + + virtual void Stop() override + { + bStopRequested = true; + Notify(); + } + + void StopAndJoin() { - FTSTicker::GetCoreTicker().RemoveTicker(TickerHandle); - TickerHandle.Reset(); + Stop(); + + if (Thread) + { + Thread->WaitForCompletion(); + delete Thread; + Thread = nullptr; + } + + if (WorkAvailableEvent) + { + FPlatformProcess::ReturnSynchEventToPool(WorkAvailableEvent); + WorkAvailableEvent = nullptr; + } + + Connection = nullptr; } + +private: + UDbConnectionBase* Connection = nullptr; + FEvent* WorkAvailableEvent = nullptr; + FRunnableThread* Thread = nullptr; + FThreadSafeBool bStopRequested = false; +}; + +UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) + : Super(ObjectInitializer) +{ + NextRequestId = 1; + NextSubscriptionId = 1; + ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); } void UDbConnectionBase::BeginDestroy() { - if (TickerHandle.IsValid()) + StopInboundMessageWorker(); + Super::BeginDestroy(); +} + +void UDbConnectionBase::Disconnect() +{ + StopInboundMessageWorker(); + if (WebSocket) { - FTSTicker::GetCoreTicker().RemoveTicker(TickerHandle); - TickerHandle.Reset(); + WebSocket->Disconnect(); } - bIsAutoTicking = false; +} - Super::BeginDestroy(); +bool UDbConnectionBase::IsActive() const +{ + return WebSocket && WebSocket->IsConnected(); +} + + +bool UDbConnectionBase::TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const +{ + if (bIsIdentitySet) + { + OutIdentity = Identity; + return true; + } + + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("TryGetIdentity called before identity was set")); + return false; +} + +FSpacetimeDBConnectionId UDbConnectionBase::GetConnectionId() const +{ + return ConnectionId; +} + +bool UDbConnectionBase::SendRawMessage(const FString& Message) +{ + return WebSocket && WebSocket->SendMessage(Message); +} + +bool UDbConnectionBase::SendRawMessage(const TArray& Message) +{ + return WebSocket && WebSocket->SendMessage(Message); +} + +USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() +{ + return NewObject(); } - -void UDbConnectionBase::Disconnect() -{ - if (WebSocket) - { - WebSocket->Disconnect(); - } -} - -bool UDbConnectionBase::IsActive() const -{ - return WebSocket && WebSocket->IsConnected(); -} - - -bool UDbConnectionBase::TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const -{ - if (bIsIdentitySet) - { - OutIdentity = Identity; - return true; - } - - UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("TryGetIdentity called before identity was set")); - return false; -} - -FSpacetimeDBConnectionId UDbConnectionBase::GetConnectionId() const -{ - return ConnectionId; -} - -bool UDbConnectionBase::SendRawMessage(const FString& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} - -bool UDbConnectionBase::SendRawMessage(const TArray& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} - -USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() -{ - return NewObject(); -} - + void UDbConnectionBase::HandleWSError(const FString& Error) { + StopInboundMessageWorker(); bProtocolViolationHandled = false; ClearPendingOperations(Error); if (OnConnectErrorDelegate.IsBound()) @@ -167,6 +310,7 @@ void UDbConnectionBase::HandleWSError(const FString& Error) void UDbConnectionBase::HandleWSClosed(int32 /*StatusCode*/, const FString& Reason, bool /*bWasClean*/) { + StopInboundMessageWorker(); bProtocolViolationHandled = false; ClearPendingOperations(Reason); if (OnDisconnectBaseDelegate.IsBound()) @@ -185,6 +329,7 @@ void UDbConnectionBase::HandleProtocolViolation(const FString& ErrorMessage) UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("%s"), *ErrorMessage); TriggerError(ErrorMessage); + StopInboundMessageWorker(); ClearPendingOperations(ErrorMessage); // Match Rust/C# behavior: parse/protocol violations are fatal for the connection. @@ -197,103 +342,653 @@ void UDbConnectionBase::HandleProtocolViolation(const FString& ErrorMessage) OnConnectErrorDelegate.Execute(ErrorMessage); } } - -void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) -{ - //tag for arrival order - const int32 Id = NextPreprocessId.GetValue(); - NextPreprocessId.Increment(); - - //do expensive work off-thread - TWeakObjectPtr WeakThis(this); - Async(EAsyncExecution::Thread, [WeakThis, Message, Id]() - { - if (!WeakThis.IsValid()) - { - return; + +void UDbConnectionBase::StartInboundMessageWorker() +{ + FScopeLock Lock(&InboundWorkerMutex); + if (InboundWorker) + { + return; + } + + { + FScopeLock RawLock(&InboundRawMessagesMutex); + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + InboundQueuedRawBytes = 0; + ++InboundConnectionEpoch; + NextInboundSequenceId = 0; + bInboundAcceptingMessages = true; + bInboundProtocolErrorQueued = false; + } + + { + FScopeLock PendingLock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedEstimatedBytes = 0; + } + + ActivePreprocessedTableData = nullptr; + InboundWorker = new FSpacetimeDbInboundWorker(*this); +} + +void UDbConnectionBase::StopInboundMessageWorker() +{ + FSpacetimeDbInboundWorker* WorkerToStop = nullptr; + { + FScopeLock Lock(&InboundWorkerMutex); + { + FScopeLock RawLock(&InboundRawMessagesMutex); + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + InboundQueuedRawBytes = 0; + ++InboundConnectionEpoch; + bInboundAcceptingMessages = false; + bInboundProtocolErrorQueued = false; + } + + { + FScopeLock PendingLock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedEstimatedBytes = 0; + } + + ActivePreprocessedTableData = nullptr; + WorkerToStop = InboundWorker; + InboundWorker = nullptr; + } + + if (WorkerToStop) + { + delete WorkerToStop; + } + + ClearInboundMessageQueues(); +} + +void UDbConnectionBase::ClearInboundMessageQueues() +{ + { + FScopeLock Lock(&InboundRawMessagesMutex); + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + InboundQueuedRawBytes = 0; + } + + { + FScopeLock Lock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedEstimatedBytes = 0; + } + + ActivePreprocessedTableData = nullptr; +} + +void UDbConnectionBase::NotifyInboundWorkerIfNeeded() +{ + FScopeLock WorkerLock(&InboundWorkerMutex); + if (InboundWorker == nullptr) + { + return; + } + + bool bShouldNotify = false; + { + FScopeLock RawLock(&InboundRawMessagesMutex); + checkf(InboundRawMessageReadIndex <= InboundRawMessages.Num(), + TEXT("SpacetimeDB inbound raw queue read index exceeded queued messages while notifying worker.")); + bShouldNotify = InboundRawMessages.Num() > InboundRawMessageReadIndex && bInboundAcceptingMessages && !bInboundProtocolErrorQueued; + } + + if (bShouldNotify) + { + InboundWorker->Notify(); + } +} + +bool UDbConnectionBase::IsInboundProtocolErrorQueued() const +{ + FScopeLock Lock(&InboundRawMessagesMutex); + return bInboundProtocolErrorQueued; +} + +bool UDbConnectionBase::IsInboundEpochCurrentAndAccepting(uint64 ConnectionEpoch) const +{ + FScopeLock Lock(&InboundRawMessagesMutex); + return bInboundAcceptingMessages && !bInboundProtocolErrorQueued && InboundConnectionEpoch == ConnectionEpoch; +} + +void UDbConnectionBase::MarkInboundProtocolErrorQueued() +{ + FScopeLock Lock(&InboundRawMessagesMutex); + bInboundProtocolErrorQueued = true; + bInboundAcceptingMessages = false; + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + InboundQueuedRawBytes = 0; +} + +void UDbConnectionBase::EnqueueInboundProtocolError(uint64 SequenceId, int32 PayloadSizeBytes, uint8 CompressionTag, int32 QueueDepthAtEnqueue, int64 QueuedBytesAtEnqueue, const FString& ErrorMessage) +{ + UE_LOG( + LogSpacetimeDb_Connection, + Error, + TEXT("SpacetimeDB inbound protocol error: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld detail=%s"), + SequenceId, + PayloadSizeBytes, + static_cast(CompressionTag), + QueueDepthAtEnqueue, + QueuedBytesAtEnqueue, + *ErrorMessage); + + FInboundParsedMessage Parsed; + Parsed.SequenceId = SequenceId; + Parsed.PayloadSizeBytes = PayloadSizeBytes; + Parsed.CompressionTag = CompressionTag; + Parsed.QueueDepthAtEnqueue = QueueDepthAtEnqueue; + Parsed.QueuedBytesAtEnqueue = QueuedBytesAtEnqueue; + Parsed.bProtocolError = true; + Parsed.ProtocolError = ErrorMessage; + Parsed.EstimatedMemoryBytes = EstimateInboundParsedMessageBytes(Parsed); + + FScopeLock Lock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedEstimatedBytes = 0; + PendingMessages.Add(MoveTemp(Parsed)); + PendingParsedEstimatedBytes += PendingMessages[0].EstimatedMemoryBytes; +} + +void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) +{ + TArray OwnedMessage = Message; + HandleWSBinaryMessageOwned(MoveTemp(OwnedMessage)); +} + +void UDbConnectionBase::HandleWSBinaryMessageOwned(TArray&& Message) +{ + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_InboundEnqueue); + + const int32 PayloadSizeBytes = Message.Num(); + const uint8 CompressionTag = PayloadSizeBytes > 0 ? Message[0] : 0; + uint64 SequenceId = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + uint64 ConnectionEpoch = 0; + bool bQueueOverloaded = false; + FString QueueOverloadError; + + { + FScopeLock WorkerLock(&InboundWorkerMutex); + FScopeLock RawLock(&InboundRawMessagesMutex); + if (!bInboundAcceptingMessages || bInboundProtocolErrorQueued) + { + return; + } + checkf(InboundWorker != nullptr, TEXT("SpacetimeDB inbound worker missing while inbound connection epoch %llu is accepting messages."), InboundConnectionEpoch); + + ConnectionEpoch = InboundConnectionEpoch; + SequenceId = NextInboundSequenceId++; + checkf(InboundRawMessageReadIndex <= InboundRawMessages.Num(), + TEXT("SpacetimeDB inbound raw queue read index exceeded queued messages while enqueuing payload.")); + const int64 NewQueuedRawBytes = InboundQueuedRawBytes + static_cast(PayloadSizeBytes); + const int32 LiveQueuedRawMessageCount = InboundRawMessages.Num() - InboundRawMessageReadIndex; + const int32 NewQueuedRawMessageCount = LiveQueuedRawMessageCount + 1; + QueueDepthAtEnqueue = NewQueuedRawMessageCount; + QueuedBytesAtEnqueue = NewQueuedRawBytes; + if (NewQueuedRawMessageCount > MaxQueuedInboundRawMessages || NewQueuedRawBytes > MaxQueuedInboundRawBytes) + { + bInboundProtocolErrorQueued = true; + bInboundAcceptingMessages = false; + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + InboundQueuedRawBytes = 0; + bQueueOverloaded = true; + QueueOverloadError = FString::Printf( + TEXT("SpacetimeDB inbound queue overload: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld max_messages=%d max_bytes=%lld"), + SequenceId, + PayloadSizeBytes, + static_cast(CompressionTag), + NewQueuedRawMessageCount, + NewQueuedRawBytes, + MaxQueuedInboundRawMessages, + MaxQueuedInboundRawBytes); + } + else + { + FInboundRawMessage RawMessage; + RawMessage.ConnectionEpoch = ConnectionEpoch; + RawMessage.SequenceId = SequenceId; + RawMessage.QueueDepthAtEnqueue = QueueDepthAtEnqueue; + RawMessage.QueuedBytesAtEnqueue = QueuedBytesAtEnqueue; + RawMessage.Payload = MoveTemp(Message); + InboundRawMessages.Add(MoveTemp(RawMessage)); + InboundQueuedRawBytes = NewQueuedRawBytes; + InboundWorker->Notify(); + } + } + + if (bQueueOverloaded) + { + EnqueueInboundProtocolError(SequenceId, PayloadSizeBytes, CompressionTag, QueueDepthAtEnqueue, QueuedBytesAtEnqueue, QueueOverloadError); + return; + } +} + +void UDbConnectionBase::FrameTick() +{ + int32 MessagesProcessed = 0; + int64 PayloadBytesProcessed = 0; + const uint64 FrameStartCycles = FPlatformTime::Cycles64(); + const bool bDrainAllPendingMessages = InboundApplyBudget.bDrainAllPendingMessages; + { + FScopeLock Lock(&PendingMessagesMutex); + const int32 PendingCount = PendingMessages.Num() - PendingMessageReadIndex; + if (PendingCount <= 0) + { + //nothing to process, return early + return; + } + + } + + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_GameThreadApplyInbound); + + while (true) + { + FInboundParsedMessage Msg; + { + FScopeLock Lock(&PendingMessagesMutex); + if (PendingMessageReadIndex >= PendingMessages.Num()) + { + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedEstimatedBytes = 0; + break; + } + + if (!bDrainAllPendingMessages && MessagesProcessed >= InboundApplyBudget.MaxMessagesPerFrame) + { + break; + } + + FInboundParsedMessage& PendingMessage = PendingMessages[PendingMessageReadIndex]; + const int64 PendingPayloadBytes = static_cast(PendingMessage.PayloadSizeBytes); + const int64 PendingEstimatedBytes = PendingMessage.EstimatedMemoryBytes; + if (!bDrainAllPendingMessages && MessagesProcessed > 0 && PayloadBytesProcessed + PendingPayloadBytes > InboundApplyBudget.MaxPayloadBytesPerFrame) + { + break; + } + + PayloadBytesProcessed += PendingPayloadBytes; + PendingParsedEstimatedBytes = FMath::Max(0, PendingParsedEstimatedBytes - PendingEstimatedBytes); + Msg = MoveTemp(PendingMessage); + ++PendingMessageReadIndex; + + if (PendingMessageReadIndex == PendingMessages.Num()) + { + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedEstimatedBytes = 0; + } + else if (PendingMessageReadIndex >= PendingInboundCompactionMinConsumedMessages) + { + PendingMessages.RemoveAt(0, PendingMessageReadIndex, EAllowShrinking::No); + PendingMessageReadIndex = 0; + } + } + + if (Msg.bProtocolError) + { + HandleProtocolViolation(Msg.ProtocolError); + break; + } + + const uint64 MessageStartCycles = FPlatformTime::Cycles64(); + FSpacetimeDBInboundMessageApplyStats ApplyStats; + ApplyStats.MessageKind = DescribeServerMessageTag(Msg.Message.Tag); + ApplyStats.SequenceId = Msg.SequenceId; + ApplyStats.PayloadSizeBytes = Msg.PayloadSizeBytes; + ApplyStats.QueueDepthAtEnqueue = Msg.QueueDepthAtEnqueue; + ApplyStats.QueuedBytesAtEnqueue = Msg.QueuedBytesAtEnqueue; + if (Msg.Message.Tag == EServerMessageTag::ReducerResult) + { + const FReducerResultType& Payload = Msg.Message.MessageData.Get(); + ApplyStats.RequestId = Payload.RequestId; + if (const FReducerCallInfoType* FoundReducerCall = PendingReducerCalls.Find(Payload.RequestId)) + { + ApplyStats.ReducerName = FoundReducerCall->ReducerName; + } } - UDbConnectionBase* This = WeakThis.Get(); - //parse the message, decompress if needed - FServerMessageType Parsed; - if (!This->PreProcessMessage(Message, Parsed)) + ProcessInboundServerMessage(Msg, ApplyStats); + const double MessageElapsedMicros = + FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - MessageStartCycles) * 1000.0; + if (!bDrainAllPendingMessages && + InboundApplyBudget.SoftTimeBudgetMicros > 0 && + MessageElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) { - AsyncTask(ENamedThreads::GameThread, [WeakThis]() + TArray SortedStats = ApplyStats.TableStats; + SortedStats.Sort([](const FSpacetimeDBTableApplyStats& A, const FSpacetimeDBTableApplyStats& B) { - if (!WeakThis.IsValid()) + return (A.CacheMicros + A.BroadcastMicros) > (B.CacheMicros + B.BroadcastMicros); + }); + TArray TopTableSummaries; + const int32 LoggedTableCount = FMath::Min(SortedStats.Num(), MaxInboundApplyLogTableContributors); + TopTableSummaries.Reserve(LoggedTableCount); + for (int32 TableIndex = 0; TableIndex < LoggedTableCount; ++TableIndex) + { + TopTableSummaries.Add(FormatInboundTableApplyStats(SortedStats[TableIndex])); + } + const FString TopTablesText = TopTableSummaries.IsEmpty() + ? TEXT("") + : FString::Join(TopTableSummaries, TEXT(" | ")); + UE_LOG(LogSpacetimeDb_Connection, + Warning, + TEXT("SpacetimeDB inbound single-message apply exceeded soft budget: %.2fus >= %lldus kind=%s sequence=%llu request_id=%u reducer=%s payload_bytes=%d queued_messages=%d queued_bytes=%lld messages_processed_before=%d tables=%d top_tables=%s"), + MessageElapsedMicros, + InboundApplyBudget.SoftTimeBudgetMicros, + *ApplyStats.MessageKind, + ApplyStats.SequenceId, + ApplyStats.RequestId, + ApplyStats.ReducerName.IsEmpty() ? TEXT("") : *ApplyStats.ReducerName, + ApplyStats.PayloadSizeBytes, + ApplyStats.QueueDepthAtEnqueue, + ApplyStats.QueuedBytesAtEnqueue, + MessagesProcessed, + ApplyStats.TableStats.Num(), + *TopTablesText); + } + ++MessagesProcessed; + + if (!bDrainAllPendingMessages && + MessagesProcessed >= InboundApplyBudget.MinMessagesPerFrame && + InboundApplyBudget.SoftTimeBudgetMicros > 0) + { + const double ElapsedMicros = FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - FrameStartCycles) * 1000.0; + if (ElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + break; + } + } + } + + if (MessagesProcessed > 0) + { + NotifyInboundWorkerIfNeeded(); + } +} + +void UDbConnectionBase::DrainInboundRawMessagesOnWorker() +{ + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_InboundWorkerDrain); + + while (!IsInboundProtocolErrorQueued()) + { + int32 ParsedMessageCapacity = 0; + int64 ParsedEstimatedByteCapacity = 0; + { + FScopeLock Lock(&PendingMessagesMutex); + const int32 LivePendingMessages = PendingMessages.Num() - PendingMessageReadIndex; + ParsedMessageCapacity = MaxPendingInboundParsedMessages - LivePendingMessages; + ParsedEstimatedByteCapacity = MaxPendingInboundParsedEstimatedBytes - PendingParsedEstimatedBytes; + } + + if (ParsedMessageCapacity <= 0 || ParsedEstimatedByteCapacity <= 0) + { + return; + } + + TArray LocalRawMessages; + int64 DrainedRawBytes = 0; + { + FScopeLock Lock(&InboundRawMessagesMutex); + checkf(InboundRawMessageReadIndex <= InboundRawMessages.Num(), + TEXT("SpacetimeDB inbound raw queue read index exceeded queued messages while draining worker queue.")); + const int32 LiveRawMessageCount = InboundRawMessages.Num() - InboundRawMessageReadIndex; + if (LiveRawMessageCount <= 0 || !bInboundAcceptingMessages || bInboundProtocolErrorQueued) + { + return; + } + + int32 DrainCount = 0; + for (; DrainCount < LiveRawMessageCount && DrainCount < ParsedMessageCapacity; ++DrainCount) + { + const FInboundRawMessage& Candidate = InboundRawMessages[InboundRawMessageReadIndex + DrainCount]; + const int64 NextPayloadBytes = static_cast(Candidate.Payload.Num()); + if (DrainCount > 0 && DrainedRawBytes + NextPayloadBytes > ParsedEstimatedByteCapacity) + { + break; + } + if (DrainCount == 0 && NextPayloadBytes > ParsedEstimatedByteCapacity) { return; } - UDbConnectionBase* Conn = WeakThis.Get(); - Conn->HandleProtocolViolation(TEXT("Failed to parse/decompress incoming WebSocket message")); - }); + + DrainedRawBytes += NextPayloadBytes; + } + + if (DrainCount == 0) + { + return; + } + + LocalRawMessages.Reserve(DrainCount); + for (int32 Index = 0; Index < DrainCount; ++Index) + { + LocalRawMessages.Add(MoveTemp(InboundRawMessages[InboundRawMessageReadIndex + Index])); + } + + InboundRawMessageReadIndex += DrainCount; + if (InboundRawMessageReadIndex == InboundRawMessages.Num()) + { + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + } + else if (InboundRawMessageReadIndex >= InboundRawCompactionMinConsumedMessages) + { + InboundRawMessages.RemoveAt(0, InboundRawMessageReadIndex, EAllowShrinking::No); + InboundRawMessageReadIndex = 0; + } + InboundQueuedRawBytes = FMath::Max(0, InboundQueuedRawBytes - DrainedRawBytes); + } + + TArray LocalParsedMessages; + LocalParsedMessages.Reserve(LocalRawMessages.Num()); + + for (FInboundRawMessage& RawMessage : LocalRawMessages) + { + if (!IsInboundEpochCurrentAndAccepting(RawMessage.ConnectionEpoch)) + { + return; + } + + FInboundParsedMessage ParsedMessage; + if (!BuildInboundParsedMessage(RawMessage, ParsedMessage)) + { + if (!IsInboundEpochCurrentAndAccepting(RawMessage.ConnectionEpoch)) + { + return; + } + MarkInboundProtocolErrorQueued(); + LocalParsedMessages.Add(MoveTemp(ParsedMessage)); + break; + } + + if (!IsInboundEpochCurrentAndAccepting(RawMessage.ConnectionEpoch)) + { + return; + } + LocalParsedMessages.Add(MoveTemp(ParsedMessage)); + } + + if (LocalParsedMessages.Num() == 0) + { + continue; + } + + const bool bBatchEndsWithProtocolError = LocalParsedMessages.Last().bProtocolError; + if (IsInboundProtocolErrorQueued() && !bBatchEndsWithProtocolError) + { + return; + } + if (!bBatchEndsWithProtocolError && !IsInboundEpochCurrentAndAccepting(LocalParsedMessages[0].ConnectionEpoch)) + { + return; + } + + uint64 OverloadSequenceId = LocalParsedMessages[0].SequenceId; + int32 OverloadPayloadSizeBytes = LocalParsedMessages[0].PayloadSizeBytes; + uint8 OverloadCompressionTag = LocalParsedMessages[0].CompressionTag; + bool bParsedQueueOverloaded = false; + FString ParsedQueueOverloadError; + { + FScopeLock Lock(&PendingMessagesMutex); + int64 AddedEstimatedBytes = 0; + for (const FInboundParsedMessage& ParsedMessage : LocalParsedMessages) + { + AddedEstimatedBytes += ParsedMessage.EstimatedMemoryBytes; + } + + const int32 LivePendingMessages = PendingMessages.Num() - PendingMessageReadIndex; + const int32 NewPendingMessageCount = LivePendingMessages + LocalParsedMessages.Num(); + const int64 NewPendingEstimatedBytes = PendingParsedEstimatedBytes + AddedEstimatedBytes; + bParsedQueueOverloaded = !bBatchEndsWithProtocolError && + (NewPendingMessageCount > MaxPendingInboundParsedMessages || + NewPendingEstimatedBytes > MaxPendingInboundParsedEstimatedBytes); + if (bParsedQueueOverloaded) + { + ParsedQueueOverloadError = FString::Printf( + TEXT("SpacetimeDB parsed inbound queue overload: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d estimated_bytes=%lld max_messages=%d max_estimated_bytes=%lld"), + OverloadSequenceId, + OverloadPayloadSizeBytes, + static_cast(OverloadCompressionTag), + NewPendingMessageCount, + NewPendingEstimatedBytes, + MaxPendingInboundParsedMessages, + MaxPendingInboundParsedEstimatedBytes); + } + else + { + PendingMessages.Append(MoveTemp(LocalParsedMessages)); + PendingParsedEstimatedBytes = NewPendingEstimatedBytes; + } + } + + if (bParsedQueueOverloaded) + { + MarkInboundProtocolErrorQueued(); + EnqueueInboundProtocolError( + OverloadSequenceId, + OverloadPayloadSizeBytes, + OverloadCompressionTag, + LocalParsedMessages[0].QueueDepthAtEnqueue, + LocalParsedMessages[0].QueuedBytesAtEnqueue, + ParsedQueueOverloadError); return; } - - //queue: re-order buffer - TArray Ready; - { - FScopeLock Lock(&This->PreprocessMutex); - // Move the parsed message into the map to avoid copying - This->PreprocessedMessages.Add(Id, MoveTemp(Parsed)); - //check if we can release any messages in order - while (This->PreprocessedMessages.Contains(This->NextReleaseId)) - { - Ready.Add(This->PreprocessedMessages.FindAndRemoveChecked(This->NextReleaseId)); - ++This->NextReleaseId; - } - } - //if we have any ready messages, append them to the pending messages list that is processed in Tick - if (Ready.Num() > 0) - { - FScopeLock Lock(&This->PendingMessagesMutex); - This->PendingMessages.Append(MoveTemp(Ready)); - } - }); -} - -void UDbConnectionBase::FrameTick() -{ - TArray Local; - { - FScopeLock Lock(&PendingMessagesMutex); - if (PendingMessages.Num() == 0) - { - //nothing to process, return early - return; - } - //move pending messages to local array for processing - Local = MoveTemp(PendingMessages); - PendingMessages.Empty(); - } - - //process all messages in the local array - for (const FServerMessageType& Msg : Local) - { - //process the message, this will call DbUpdate or trigger subscription events as needed - ProcessServerMessage(Msg); } } -bool UDbConnectionBase::OnTickerTick(float DeltaTime) +bool UDbConnectionBase::BuildInboundParsedMessage(const FInboundRawMessage& RawMessage, FInboundParsedMessage& OutMessage) { - if (HasAnyFlags(RF_BeginDestroyed | RF_FinishDestroyed) || !bIsAutoTicking) + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_InboundPreprocess); + + OutMessage.ConnectionEpoch = RawMessage.ConnectionEpoch; + OutMessage.SequenceId = RawMessage.SequenceId; + OutMessage.PayloadSizeBytes = RawMessage.Payload.Num(); + OutMessage.CompressionTag = RawMessage.Payload.Num() > 0 ? RawMessage.Payload[0] : 0; + OutMessage.QueueDepthAtEnqueue = RawMessage.QueueDepthAtEnqueue; + OutMessage.QueuedBytesAtEnqueue = RawMessage.QueuedBytesAtEnqueue; + + if (!PreProcessMessage(RawMessage.Payload, OutMessage)) { + OutMessage.bProtocolError = true; + OutMessage.ProtocolError = FString::Printf( + TEXT("Failed to parse/decompress incoming WebSocket message: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld"), + OutMessage.SequenceId, + OutMessage.PayloadSizeBytes, + static_cast(OutMessage.CompressionTag), + OutMessage.QueueDepthAtEnqueue, + OutMessage.QueuedBytesAtEnqueue); + OutMessage.EstimatedMemoryBytes = EstimateInboundParsedMessageBytes(OutMessage); return false; } - FrameTick(); return true; } - - + +void UDbConnectionBase::Tick(float DeltaTime) +{ + if (bIsAutoTicking) + { + FrameTick(); + } +} + +TStatId UDbConnectionBase::GetStatId() const +{ + // This is used by the engine to track tickables, we return a unique stat ID for this class + RETURN_QUICK_DECLARE_CYCLE_STAT(UMyTickableObject, STATGROUP_Tickables); +} + +bool UDbConnectionBase::IsTickable() const +{ + return bIsAutoTicking; +} + +bool UDbConnectionBase::IsTickableInEditor() const +{ + return bIsAutoTicking; +} + + +void UDbConnectionBase::ProcessInboundServerMessage(FInboundParsedMessage& InboundMessage, FSpacetimeDBInboundMessageApplyStats& ApplyStats) +{ + struct FActivePreprocessedDataGuard + { + FPreprocessedTableDataMap*& Target; + FSpacetimeDBInboundMessageApplyStats*& StatsTarget; + + FActivePreprocessedDataGuard( + FPreprocessedTableDataMap*& InTarget, + FPreprocessedTableDataMap* InValue, + FSpacetimeDBInboundMessageApplyStats*& InStatsTarget, + FSpacetimeDBInboundMessageApplyStats* InStatsValue) + : Target(InTarget) + , StatsTarget(InStatsTarget) + { + checkf(Target == nullptr, TEXT("Nested SpacetimeDB inbound table preprocessing scope detected.")); + checkf(StatsTarget == nullptr, TEXT("Nested SpacetimeDB inbound apply stats scope detected.")); + Target = InValue; + StatsTarget = InStatsValue; + } + + ~FActivePreprocessedDataGuard() + { + Target = nullptr; + StatsTarget = nullptr; + } + }; + + FActivePreprocessedDataGuard Guard( + ActivePreprocessedTableData, + &InboundMessage.PreprocessedTableData, + ActiveInboundMessageApplyStats, + &ApplyStats); + ProcessServerMessage(InboundMessage.Message); +} + void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) { switch (Message.Tag) { case EServerMessageTag::InitialConnection: { - const FInitialConnectionType Payload = Message.GetAsInitialConnection(); + const FInitialConnectionType& Payload = Message.MessageData.Get(); Token = Payload.Token; UCredentials::SaveToken(Token); Identity = Payload.Identity; @@ -307,7 +1002,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::TransactionUpdate: { - const FTransactionUpdateType Payload = Message.GetAsTransactionUpdate(); + const FTransactionUpdateType& Payload = Message.MessageData.Get(); const FDatabaseUpdateType Update = TransactionUpdateToDatabaseUpdate(Payload); DbUpdate(Update, FSpacetimeDBEvent::Transaction(FSpacetimeDBUnit())); break; @@ -319,8 +1014,8 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::SubscribeApplied: { - const FSubscribeAppliedType Payload = Message.GetAsSubscribeApplied(); - const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows, false); + const FSubscribeAppliedType& Payload = Message.MessageData.Get(); + const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows, UE::SpacetimeDB::EQueryRowsApplyMode::Inserts); DbUpdate(Update, FSpacetimeDBEvent::SubscribeApplied(FSpacetimeDBUnit())); if (TObjectPtr* HandlePtr = ActiveSubscriptions.Find(Payload.QuerySetId.Id)) @@ -339,10 +1034,10 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::UnsubscribeApplied: { - const FUnsubscribeAppliedType Payload = Message.GetAsUnsubscribeApplied(); + const FUnsubscribeAppliedType& Payload = Message.MessageData.Get(); if (Payload.Rows.IsSet()) { - const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows.Value, true); + const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows.Value, UE::SpacetimeDB::EQueryRowsApplyMode::Deletes); DbUpdate(Update, FSpacetimeDBEvent::UnsubscribeApplied(FSpacetimeDBUnit())); } @@ -369,7 +1064,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::SubscriptionError: { - const FSubscriptionErrorType Payload = Message.GetAsSubscriptionError(); + const FSubscriptionErrorType& Payload = Message.MessageData.Get(); UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("SubscriptionError received for QuerySetId=%u Error=%s"), Payload.QuerySetId.Id, *Payload.Error); @@ -391,7 +1086,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::ReducerResult: { - const FReducerResultType Payload = Message.GetAsReducerResult(); + const FReducerResultType& Payload = Message.MessageData.Get(); const FReducerCallInfoType* FoundReducerCall = PendingReducerCalls.Find(Payload.RequestId); if (!FoundReducerCall) { @@ -415,7 +1110,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) if (Payload.Result.IsOk()) { RedEvent.Status = FSpacetimeDBStatus::Committed(FSpacetimeDBUnit()); - const FReducerOkType Ok = Payload.Result.GetAsOk(); + const FReducerOkType& Ok = Payload.Result.MessageData.Get(); const FDatabaseUpdateType Update = TransactionUpdateToDatabaseUpdate(Ok.TransactionUpdate); DbUpdate(Update, FSpacetimeDBEvent::Reducer(RedEvent)); ReducerEvent(RedEvent); @@ -430,11 +1125,11 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) FString ErrorMessage; if (Payload.Result.IsErr()) { - ErrorMessage = DecodeReducerErrorMessage(Payload.Result.GetAsErr()); + ErrorMessage = DecodeReducerErrorMessage(Payload.Result.MessageData.Get>()); } else { - ErrorMessage = Payload.Result.GetAsInternalError(); + ErrorMessage = Payload.Result.MessageData.Get(); } RedEvent.Status = FSpacetimeDBStatus::Failed(ErrorMessage); ReducerEvent(RedEvent); @@ -444,7 +1139,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::ProcedureResult: { - const FProcedureResultType Payload = Message.GetAsProcedureResult(); + const FProcedureResultType& Payload = Message.MessageData.Get(); FProcedureEvent ProcEvent; ProcEvent.Status = Payload.Status; ProcEvent.Timestamp = Payload.Timestamp; @@ -488,54 +1183,52 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) break; } } - -bool UDbConnectionBase::DecompressBrotli(const TArray& InData, TArray& OutData) -{ - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Brotli decompression unavilable")); - return false; -} - -bool UDbConnectionBase::DecompressGzip(const TArray& InData, TArray& OutData) -{ - if (InData.Num() < 4) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip data too small")); - return false; - } - - // Gzip data ends with 4 bytes indicating the uncompressed size - const uint8* SizePtr = InData.GetData() + InData.Num() - 4; - uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24); - - // Validate the output size - OutData.SetNumUninitialized(OutSize); - // Attempt to decompress the Gzip data - if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData.GetData(), InData.Num())) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip decompression failed")); - return false; - } - - OutData.SetNum(OutSize); - return true; -} - -bool UDbConnectionBase::DecompressPayload(uint8 Variant, const TArray& In, TArray& Out) -{ - switch (static_cast(Variant)) + +bool UDbConnectionBase::DecompressBrotli(const uint8* InData, int32 InSize, TArray& OutData) +{ + (void)InData; + (void)InSize; + (void)OutData; + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Brotli decompression unavilable")); + return false; +} + +bool UDbConnectionBase::DecompressGzip(const uint8* InData, int32 InSize, TArray& OutData) +{ + if (InData == nullptr || InSize < GzipFooterUncompressedSizeBytes) { - case EWsCompressionTag::Uncompressed: - // No compression, just copy the data - Out = In; - return true; - case EWsCompressionTag::Brotli: - return DecompressBrotli(In, Out); - case EWsCompressionTag::Gzip: - return DecompressGzip(In, Out); - default: - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant")); + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip data too small")); return false; } + + // Gzip data ends with 4 bytes indicating the uncompressed size + const uint8* SizePtr = InData + InSize - GzipFooterUncompressedSizeBytes; + const uint32 OutSize = + static_cast(SizePtr[0]) | + (static_cast(SizePtr[1]) << 8) | + (static_cast(SizePtr[2]) << 16) | + (static_cast(SizePtr[3]) << 24); + if (static_cast(OutSize) > MaxInboundDecodedPayloadBytes) + { + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("Gzip payload declares %u decoded bytes, exceeding max decoded bytes %lld"), + OutSize, + MaxInboundDecodedPayloadBytes); + return false; + } + + // Validate the output size + OutData.SetNumUninitialized(OutSize); + // Attempt to decompress the Gzip data + if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData, InSize)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip decompression failed")); + return false; + } + + OutData.SetNum(OutSize); + return true; } void UDbConnectionBase::ClearPendingOperations(const FString& Reason) @@ -550,107 +1243,181 @@ void UDbConnectionBase::ClearPendingOperations(const FString& Reason) UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("Cleared pending operations due to connection issue: %s"), *Reason); } } - -void UDbConnectionBase::PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update) + +void UDbConnectionBase::PreProcessTableUpdateRows( + const FString& TableName, + const TArray& RowSets, + FPreprocessedTableDataMap& OutPreprocessedTableData) { - for (const FTableUpdateType& TableUpdate : Update.Tables) + TSharedPtr Deserializer = FindTableDeserializerForPreprocess(TableName); + if (!Deserializer.IsValid()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableName); + return; + } + + StorePreprocessedTableData(TableName, Deserializer->PreProcess(RowSets, TableName), OutPreprocessedTableData); +} + +void UDbConnectionBase::PreProcessQueryRows( + const FQueryRowsType& Rows, + UE::SpacetimeDB::EQueryRowsApplyMode Mode, + FPreprocessedTableDataMap& OutPreprocessedTableData) +{ + for (const FSingleTableRowsType& TableRows : Rows.Tables) { - // Attempt to deserialize rows after payload decode. - TSharedPtr Deserializer; + TSharedPtr Deserializer = FindTableDeserializerForPreprocess(TableRows.Table); + if (!Deserializer.IsValid()) { - // Find the deserializer for this table - FScopeLock Lock(&TableDeserializersMutex); - if (TSharedPtr* Found = TableDeserializers.Find(TableUpdate.TableName)) - { - // If found, use the deserializer - Deserializer = *Found; - } - else - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("No deserializer found for table %s"), *TableUpdate.TableName); - } + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s query rows due to missing deserializer"), *TableRows.Table); + continue; } - if (Deserializer) + + StorePreprocessedTableData(TableRows.Table, Deserializer->PreProcessQueryRows(TableRows.Rows, Mode, TableRows.Table), OutPreprocessedTableData); + } +} + +void UDbConnectionBase::PreProcessTransactionUpdate( + const FTransactionUpdateType& Update, + FPreprocessedTableDataMap& OutPreprocessedTableData) +{ + for (const FQuerySetUpdateType& QuerySet : Update.QuerySets) + { + for (const FTableUpdateType& TableUpdate : QuerySet.Tables) { - TSharedPtr Data = Deserializer->PreProcess(TableUpdate.Rows, TableUpdate.TableName); - if (Data.IsValid()) - { - FScopeLock Lock(&PreprocessedDataMutex); - FPreprocessedTableKey Key(TableUpdate.TableName); - TArray>& Queue = PreprocessedTableData.FindOrAdd(Key); - Queue.Add(Data); - } - } - else - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableUpdate.TableName); - } - } -} - -bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FServerMessageType& OutMessage) + PreProcessTableUpdateRows(TableUpdate.TableName, TableUpdate.Rows, OutPreprocessedTableData); + } + } +} + +TSharedPtr UDbConnectionBase::FindTableDeserializerForPreprocess(const FString& TableName) +{ + FScopeLock Lock(&TableDeserializersMutex); + if (TSharedPtr* Found = TableDeserializers.Find(TableName)) + { + return *Found; + } + return nullptr; +} + +void UDbConnectionBase::StorePreprocessedTableData( + const FString& TableName, + TSharedPtr Data, + FPreprocessedTableDataMap& OutPreprocessedTableData) +{ + checkf(Data.IsValid(), TEXT("Invalid message-scoped preprocessed data generated for table '%s'."), *TableName); + + FPreprocessedTableKey Key(TableName); + TArray>& Queue = OutPreprocessedTableData.FindOrAdd(Key); + Queue.Add(MoveTemp(Data)); +} + +bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FInboundParsedMessage& OutMessage) { - if (Message.Num() == 0) + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_PreProcessMessage); + + if (Message.Num() <= SpacetimeDbCompressionTagBytes) { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Empty message recived from server, ignored")); + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Empty message received from server, ignored")); return false; } // The first byte indicates compression format for the payload. const uint8 Compression = Message[0]; - TArray CompressedPayload; - CompressedPayload.Append(Message.GetData() + 1, Message.Num() - 1); - - // Decompress the payload based on the compression tag - TArray Decompressed; - if (!DecompressPayload(Compression, CompressedPayload, Decompressed)) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress incoming message")); + + const uint8* DecodedPayload = Message.GetData() + SpacetimeDbCompressionTagBytes; + int32 DecodedPayloadSize = Message.Num() - SpacetimeDbCompressionTagBytes; + TArray DecompressedStorage; + switch (static_cast(Compression)) + { + case EWsCompressionTag::Uncompressed: + break; + case EWsCompressionTag::Brotli: + if (!DecompressBrotli(DecodedPayload, DecodedPayloadSize, DecompressedStorage)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress Brotli incoming message")); + return false; + } + DecodedPayload = DecompressedStorage.GetData(); + DecodedPayloadSize = DecompressedStorage.Num(); + break; + case EWsCompressionTag::Gzip: + if (!DecompressGzip(DecodedPayload, DecodedPayloadSize, DecompressedStorage)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress Gzip incoming message")); + return false; + } + DecodedPayload = DecompressedStorage.GetData(); + DecodedPayloadSize = DecompressedStorage.Num(); + break; + default: + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant")); + return false; + } + if (static_cast(DecodedPayloadSize) > MaxInboundDecodedPayloadBytes) + { + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("Decoded server message payload has %d bytes, exceeding max decoded bytes %lld after compression tag %u."), + DecodedPayloadSize, + MaxInboundDecodedPayloadBytes, + static_cast(Compression)); + return false; + } + if (DecodedPayloadSize <= 0 || DecodedPayload == nullptr) + { + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("SpacetimeDB decoded server message payload is empty after compression tag %u."), + static_cast(Compression)); return false; } + OutMessage.DecodedPayloadSizeBytes = DecodedPayloadSize; // Deserialize the decompressed data into a UServerMessageType object - OutMessage = UE::SpacetimeDB::Deserialize(Decompressed); + OutMessage.Message = UE::SpacetimeDB::DeserializeView(DecodedPayload, DecodedPayloadSize); // Preprocess row-bearing payloads for table deserializers. - switch (OutMessage.Tag) + switch (OutMessage.Message.Tag) { case EServerMessageTag::SubscribeApplied: { - const FSubscribeAppliedType Payload = OutMessage.GetAsSubscribeApplied(); - PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows, false)); + const FSubscribeAppliedType& Payload = OutMessage.Message.MessageData.Get(); + PreProcessQueryRows(Payload.Rows, UE::SpacetimeDB::EQueryRowsApplyMode::Inserts, OutMessage.PreprocessedTableData); break; } case EServerMessageTag::UnsubscribeApplied: { - const FUnsubscribeAppliedType Payload = OutMessage.GetAsUnsubscribeApplied(); + const FUnsubscribeAppliedType& Payload = OutMessage.Message.MessageData.Get(); if (Payload.Rows.IsSet()) { - PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows.Value, true)); + PreProcessQueryRows(Payload.Rows.Value, UE::SpacetimeDB::EQueryRowsApplyMode::Deletes, OutMessage.PreprocessedTableData); } break; } case EServerMessageTag::TransactionUpdate: { - const FTransactionUpdateType Payload = OutMessage.GetAsTransactionUpdate(); - PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload)); + const FTransactionUpdateType& Payload = OutMessage.Message.MessageData.Get(); + PreProcessTransactionUpdate(Payload, OutMessage.PreprocessedTableData); break; } case EServerMessageTag::ReducerResult: { - const FReducerResultType Payload = OutMessage.GetAsReducerResult(); + const FReducerResultType& Payload = OutMessage.Message.MessageData.Get(); if (Payload.Result.IsOk()) { - PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload.Result.GetAsOk().TransactionUpdate)); + const FReducerOkType& Ok = Payload.Result.MessageData.Get(); + PreProcessTransactionUpdate(Ok.TransactionUpdate, OutMessage.PreprocessedTableData); } break; } default: break; } + OutMessage.EstimatedMemoryBytes = EstimateInboundParsedMessageBytes(OutMessage); return true; } - - + + uint32 UDbConnectionBase::GetNextRequestId() { return NextRequestId++; @@ -660,21 +1427,21 @@ uint32 UDbConnectionBase::GetNextSubscriptionId() { return NextSubscriptionId++; } - + void UDbConnectionBase::StartSubscription(USubscriptionHandleBase* Handle) -{ - if (!Handle) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with null handle")); - return; - } - - if (Handle->QuerySqls.Num() == 0) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with empty query list")); - return; - } - +{ + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with null handle")); + return; + } + + if (Handle->QuerySqls.Num() == 0) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with empty query list")); + return; + } + const uint32 QuerySetId = GetNextSubscriptionId(); Handle->QuerySetId = QuerySetId; Handle->ConnInternal = this; @@ -689,14 +1456,14 @@ void UDbConnectionBase::StartSubscription(USubscriptionHandleBase* Handle) TArray Data = UE::SpacetimeDB::Serialize(Msg); SendRawMessage(Data); } - -void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) -{ - if (!Handle || Handle->bEnded) - { - return; - } - + +void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) +{ + if (!Handle || Handle->bEnded) + { + return; + } + const uint32 QuerySetId = Handle->QuerySetId; FUnsubscribeType MsgData; MsgData.RequestId = GetNextRequestId(); @@ -707,7 +1474,7 @@ void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) TArray Data = UE::SpacetimeDB::Serialize(Msg); SendRawMessage(Data); } - + uint32 UDbConnectionBase::InternalCallReducer(const FString& Reducer, TArray Args) { if (!WebSocket || !WebSocket->IsConnected()) @@ -732,51 +1499,124 @@ uint32 UDbConnectionBase::InternalCallReducer(const FString& Reducer, TArray Args, const FOnProcedureCompleteDelegate& Callback) -{ - if (!WebSocket || !WebSocket->IsConnected()) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Cannot call proceduer, not connected to server!")); - return; - } - FCallProcedureType MsgData; - MsgData.Procedure = ProcedureName; - MsgData.Args = Args; - MsgData.RequestId = ProcedureCallbacks->RegisterCallback(Callback); - MsgData.Flags = static_cast(EProcedureFlags::Default); - - FClientMessageType Msg = FClientMessageType::CallProcedure(MsgData); - TArray Data = UE::SpacetimeDB::Serialize(Msg); - SendRawMessage(Data); -} - -void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context) -{ - // Ensure we have a valid context for the update - TArray> Handlers; - for (const FTableUpdateType& TableUpdate : Update.Tables) - { - TSharedPtr Handler; - { - // Find the handler for this table update - FScopeLock Lock(&RegisteredTablesMutex); - if (TSharedPtr* Found = RegisteredTables.Find(TableUpdate.TableName)) - { - Handler = *Found; - } - } - if (Handler.IsValid()) - { - // Update the cache for the handler with the table update and context - Handler->UpdateCache(this, TableUpdate, Context); - Handlers.Add(Handler); - } - } - - for (TSharedPtr& Handler : Handlers) - { - // Broadcast the diff for each handler - Handler->BroadcastDiff(this, Context); - } + +void UDbConnectionBase::InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback) +{ + if (!WebSocket || !WebSocket->IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Cannot call proceduer, not connected to server!")); + return; + } + FCallProcedureType MsgData; + MsgData.Procedure = ProcedureName; + MsgData.Args = Args; + MsgData.RequestId = ProcedureCallbacks->RegisterCallback(Callback); + MsgData.Flags = static_cast(EProcedureFlags::Default); + + FClientMessageType Msg = FClientMessageType::CallProcedure(MsgData); + TArray Data = UE::SpacetimeDB::Serialize(Msg); + SendRawMessage(Data); +} + +void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context) +{ + checkf(ActivePreprocessedTableData != nullptr, TEXT("ApplyRegisteredTableUpdates requires active message-scoped preprocessed data.")); + + TSharedPtr>> RegisteredHandlers; + { + FScopeLock Lock(&RegisteredTablesMutex); + RegisteredHandlers = RegisteredTablesSnapshot; + } + if (!RegisteredHandlers.IsValid()) + { + return; + } + + TableUpdateHandlersScratch.Reset(); + TableUpdateHandlersScratch.Reserve(Update.Tables.Num()); + if (ActiveInboundMessageApplyStats != nullptr) + { + ActiveInboundMessageApplyStats->TableStats.Reserve( + ActiveInboundMessageApplyStats->TableStats.Num() + Update.Tables.Num()); + } + for (const FTableUpdateType& TableUpdate : Update.Tables) + { + if (TableUpdate.Rows.IsEmpty()) + { + continue; + } + + TSharedPtr Handler; + if (const TSharedPtr* Found = RegisteredHandlers->Find(TableUpdate.TableName)) + { + Handler = *Found; + } + if (Handler.IsValid()) + { + FSpacetimeDBTableApplyStats* TableStats = nullptr; + int32 TableStatsIndex = INDEX_NONE; + if (ActiveInboundMessageApplyStats != nullptr) + { + TableStatsIndex = ActiveInboundMessageApplyStats->TableStats.AddDefaulted(); + TableStats = &ActiveInboundMessageApplyStats->TableStats[TableStatsIndex]; + TableStats->TableName = Handler->GetTableName(); + } + + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_TableUpdateCache); + const uint64 TableCacheStartCycles = FPlatformTime::Cycles64(); + const bool bHasNonEmptyDiff = Handler->UpdateCache(this, TableUpdate, Context, TableStats); + const double TableCacheElapsedMicros = + FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - TableCacheStartCycles) * 1000.0; + if (TableStats != nullptr) + { + TableStats->CacheMicros = TableCacheElapsedMicros; + } + if (InboundApplyBudget.SoftTimeBudgetMicros > 0 && + TableCacheElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + UE_LOG(LogSpacetimeDb_Connection, + Warning, + TEXT("SpacetimeDB table cache apply exceeded soft budget: table=%s elapsed=%.2fus budget=%lldus row_ops=%d"), + *Handler->GetTableName(), + TableCacheElapsedMicros, + InboundApplyBudget.SoftTimeBudgetMicros, + TableUpdate.Rows.Num()); + } + if (bHasNonEmptyDiff) + { + FPendingTableBroadcast& PendingBroadcast = TableUpdateHandlersScratch.AddDefaulted_GetRef(); + PendingBroadcast.Handler = Handler; + PendingBroadcast.StatsIndex = TableStatsIndex; + } + } + } + + for (FPendingTableBroadcast& PendingBroadcast : TableUpdateHandlersScratch) + { + TSharedPtr& Handler = PendingBroadcast.Handler; + checkf(Handler.IsValid(), TEXT("Invalid pending SpacetimeDB table broadcast handler.")); + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_TableBroadcastDiff); + const uint64 BroadcastStartCycles = FPlatformTime::Cycles64(); + Handler->BroadcastDiff(this, Context); + const double BroadcastElapsedMicros = + FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - BroadcastStartCycles) * 1000.0; + if (ActiveInboundMessageApplyStats != nullptr && PendingBroadcast.StatsIndex != INDEX_NONE) + { + checkf(ActiveInboundMessageApplyStats->TableStats.IsValidIndex(PendingBroadcast.StatsIndex), + TEXT("Invalid SpacetimeDB inbound apply stats index %d."), + PendingBroadcast.StatsIndex); + ActiveInboundMessageApplyStats->TableStats[PendingBroadcast.StatsIndex].BroadcastMicros = BroadcastElapsedMicros; + } + if (InboundApplyBudget.SoftTimeBudgetMicros > 0 && + BroadcastElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + UE_LOG(LogSpacetimeDb_Connection, + Warning, + TEXT("SpacetimeDB table broadcast exceeded soft budget: table=%s elapsed=%.2fus budget=%lldus"), + *Handler->GetTableName(), + BroadcastElapsedMicros, + InboundApplyBudget.SoftTimeBudgetMicros); + } + } + TableUpdateHandlersScratch.Reset(); } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp index 64f2bbc66b0..e17fbb04150 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp @@ -136,13 +136,14 @@ UDbConnectionBase* UDbConnectionBuilderBase::BuildConnection(UDbConnectionBase* *ModuleName, *CompressionName); - Connection->WebSocket->OnConnectionError.AddDynamic(Connection, &UDbConnectionBase::HandleWSError); - Connection->WebSocket->OnClosed.AddDynamic(Connection, &UDbConnectionBase::HandleWSClosed); - Connection->WebSocket->OnBinaryMessageReceived.AddDynamic(Connection, &UDbConnectionBase::HandleWSBinaryMessage); - // Set the initialization token for the WebSocket connection - Connection->WebSocket->SetInitToken(Token); - // Connect the WebSocket to the constructed URL - Connection->WebSocket->Connect(WebSocketUrl); + Connection->WebSocket->OnConnectionError.AddDynamic(Connection, &UDbConnectionBase::HandleWSError); + Connection->WebSocket->OnClosed.AddDynamic(Connection, &UDbConnectionBase::HandleWSClosed); + Connection->WebSocket->SetNativeBinaryMessageTarget(Connection); + // Set the initialization token for the WebSocket connection + Connection->WebSocket->SetInitToken(Token); + Connection->StartInboundMessageWorker(); + // Connect the WebSocket to the constructed URL + Connection->WebSocket->Connect(WebSocketUrl); return Connection; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp index 7b4bbe53f40..5aacb0b2b9b 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp @@ -1,5 +1,6 @@ #include "Connection/Websocket.h" +#include "Connection/DbConnectionBase.h" #include "WebSocketsModule.h" // Required for FWebSocketsModule #include "SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h" #include "ModuleBindings/Types/ServerMessageType.g.h" @@ -16,13 +17,14 @@ UWebsocketManager::UWebsocketManager() void UWebsocketManager::BeginDestroy() { - UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::BeginDestroy: Cleaning up WebSocket.")); - if (!HasAnyFlags(RF_ClassDefaultObject)) - { - Disconnect(); - } - Super::BeginDestroy(); -} + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::BeginDestroy: Cleaning up WebSocket.")); + if (!HasAnyFlags(RF_ClassDefaultObject)) + { + Disconnect(); + } + NativeBinaryMessageTarget.Reset(); + Super::BeginDestroy(); +} void UWebsocketManager::Connect(const FString& ServerUrl) { @@ -128,15 +130,20 @@ bool UWebsocketManager::SendMessage(const TArray& Data) return true; } -bool UWebsocketManager::IsConnected() const -{ - return WebSocket.IsValid() && WebSocket->IsConnected(); -} - -void UWebsocketManager::SetInitToken(FString Token) -{ - InitToken = Token; -} +bool UWebsocketManager::IsConnected() const +{ + return WebSocket.IsValid() && WebSocket->IsConnected(); +} + +void UWebsocketManager::SetNativeBinaryMessageTarget(UDbConnectionBase* Target) +{ + NativeBinaryMessageTarget = Target; +} + +void UWebsocketManager::SetInitToken(FString Token) +{ + InitToken = Token; +} void UWebsocketManager::HandleConnected() { @@ -157,39 +164,56 @@ void UWebsocketManager::HandleMessageReceived(const FString& Message) OnMessageReceived.Broadcast(Message); } -void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment) -{ - if (Size == 0) - { +void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment) +{ + if (Size == 0) + { return; } - - // Handle binary messages, which may be fragmented - const uint8* Bytes = static_cast(Data); - - // Append this fragment to our buffer - IncompleteMessage.Append(Bytes, Size); - - // If this is the last fragment, we have the complete message - if (bIsLastFragment) - { - // We have the complete message - TArray MessageBytes = IncompleteMessage; - IncompleteMessage.Reset(); - bAwaitingBinaryFragments = false; - - // Forward the complete binary payload to listeners. - OnBinaryMessageReceived.Broadcast(MessageBytes); - } - else - { - // More fragments are coming - bAwaitingBinaryFragments = true; - } -} - -void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) -{ + + // Handle binary messages, which may be fragmented + const uint8* Bytes = static_cast(Data); + + if (!bAwaitingBinaryFragments && bIsLastFragment) + { + TArray MessageBytes; + MessageBytes.Append(Bytes, Size); + DispatchCompleteBinaryMessage(MoveTemp(MessageBytes)); + return; + } + + // Append this fragment to our buffer + IncompleteMessage.Append(Bytes, Size); + + // If this is the last fragment, we have the complete message + if (bIsLastFragment) + { + // We have the complete message + TArray MessageBytes = MoveTemp(IncompleteMessage); + bAwaitingBinaryFragments = false; + + DispatchCompleteBinaryMessage(MoveTemp(MessageBytes)); + } + else + { + // More fragments are coming + bAwaitingBinaryFragments = true; + } +} + +void UWebsocketManager::DispatchCompleteBinaryMessage(TArray&& Message) +{ + if (UDbConnectionBase* Target = NativeBinaryMessageTarget.Get()) + { + Target->HandleWSBinaryMessageOwned(MoveTemp(Message)); + return; + } + + OnBinaryMessageReceived.Broadcast(Message); +} + +void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) +{ UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Closed. Status: %d, Reason: %s, Clean: %s"), StatusCode, *Reason, bWasClean ? TEXT("true") : TEXT("false")); // Notify listeners about the closure diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h index 7d44d69c24f..8999bf9fd9c 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h @@ -9,6 +9,29 @@ namespace UE::SpacetimeDB { + enum class EQueryRowsApplyMode : uint8 + { + Inserts, + Deletes + }; + + namespace Private + { + template + static void AddParsedRowWithBsatn(const uint8* RowData, int32 RowLength, TArray>& OutRows) + { + checkf(RowLength >= 0, TEXT("Cannot parse a negative BSATN row length: %d"), RowLength); + checkf(RowData != nullptr || RowLength == 0, TEXT("Cannot parse null BSATN row data with length %d"), RowLength); + + RowType Row = DeserializeView(RowData, RowLength); + TArray Bsatn; + if (RowLength > 0) + { + Bsatn.Append(RowData, RowLength); + } + OutRows.Add(FWithBsatn(MoveTemp(Bsatn), MoveTemp(Row))); + } + } /** Parse a single row list based on its size hint and retain BSATN bytes */ template @@ -21,20 +44,19 @@ namespace UE::SpacetimeDB if (List.SizeHint.IsFixedSize()) { // Get the fixed size from the size hint - uint16 Size = List.SizeHint.GetAsFixedSize(); + const uint16 Size = List.SizeHint.GetAsFixedSize(); if (Size > 0) { + checkf(List.RowsData.Num() % Size == 0, + TEXT("Fixed-size BSATN row list has %d bytes, which is not divisible by row size %u"), + List.RowsData.Num(), + static_cast(Size)); // If the size is valid, parse the rows based on the fixed size - int32 Count = List.RowsData.Num() / Size; + const int32 Count = List.RowsData.Num() / Size; + OutRows.Reserve(OutRows.Num() + Count); for (int32 i = 0; i < Count; ++i) { - // Create a slice of the row data based on the fixed size - TArray Slice; - Slice.Append(List.RowsData.GetData() + i * Size, Size); - // Deserialize the row from the slice - RowType Row = UE::SpacetimeDB::Deserialize(Slice); - // Add the row with its BSATN bytes to the output array - OutRows.Add(FWithBsatn(Slice, Row)); + Private::AddParsedRowWithBsatn(List.RowsData.GetData() + i * Size, Size, OutRows); } return; } @@ -43,26 +65,29 @@ namespace UE::SpacetimeDB else if (List.SizeHint.IsRowOffsets()) { // Get the offsets from the size hint - TArray Offsets = List.SizeHint.GetAsRowOffsets(); + const TArray& Offsets = List.SizeHint.MessageData.Get>(); if (Offsets.Num() > 0) { // If the offsets are valid, parse the rows based on the offsets - UEReader Reader(List.RowsData); + OutRows.Reserve(OutRows.Num() + Offsets.Num()); for (int32 i = 0; i < Offsets.Num(); ++i) { // If this is the last offset, read until the end of the data - int64 Start = Offsets[i]; - int64 End = (i + 1 < Offsets.Num()) ? Offsets[i + 1] : List.RowsData.Num(); - int64 Length = End - Start; - TArray Slice; - Slice.Append(List.RowsData.GetData() + Start, Length); - - // Deserialize the row from the slice - UEReader SliceReader(Slice); - RowType Row = deserialize(SliceReader); - - // Add the row with its BSATN bytes to the output array - OutRows.Add(FWithBsatn(Slice, Row)); + const uint64 Start = Offsets[i]; + const uint64 End = (i + 1 < Offsets.Num()) ? Offsets[i + 1] : static_cast(List.RowsData.Num()); + checkf(Start <= End, + TEXT("BSATN row offsets are not sorted: start=%llu end=%llu row_index=%d"), + Start, + End, + i); + checkf(End <= static_cast(List.RowsData.Num()), + TEXT("BSATN row offset %llu exceeds row data size %d at row_index=%d"), + End, + List.RowsData.Num(), + i); + const int32 RowStart = static_cast(Start); + const int32 RowLength = static_cast(End - Start); + Private::AddParsedRowWithBsatn(List.RowsData.GetData() + RowStart, RowLength, OutRows); } } } @@ -79,14 +104,14 @@ namespace UE::SpacetimeDB { if (RowSet.IsPersistentTable()) { - const FPersistentTableRowsType Persistent = RowSet.GetAsPersistentTable(); + const FPersistentTableRowsType& Persistent = RowSet.MessageData.Get(); ParseRowListWithBsatn(Persistent.Inserts, Inserts); ParseRowListWithBsatn(Persistent.Deletes, Deletes); } // Event-table rows are callback-only inserts and should not create delete paths. else if (RowSet.IsEventTable()) { - const FEventTableRowsType EventRows = RowSet.GetAsEventTable(); + const FEventTableRowsType& EventRows = RowSet.MessageData.Get(); ParseRowListWithBsatn(EventRows.Events, Inserts); } else @@ -100,6 +125,15 @@ namespace UE::SpacetimeDB struct FPreprocessedTableDataBase { virtual ~FPreprocessedTableDataBase() {} + virtual int64 EstimateMemoryBytes() const + { + return sizeof(FPreprocessedTableDataBase); + } + int32 InsertRowCount = 0; + int32 DeleteRowCount = 0; + int32 RowSetCount = 0; + int64 InsertRowBytes = 0; + int64 DeleteRowBytes = 0; }; /** A wrapper for a row type that includes its bsatn value. Used to store rows with their bsatn values. */ @@ -109,6 +143,23 @@ namespace UE::SpacetimeDB // The type of the row being processed TArray> Inserts; TArray> Deletes; + + virtual int64 EstimateMemoryBytes() const override + { + auto EstimateRowsBytes = [](const TArray>& Rows) + { + int64 Bytes = Rows.GetAllocatedSize(); + for (const FWithBsatn& Row : Rows) + { + Bytes += Row.Bsatn.GetAllocatedSize(); + } + return Bytes; + }; + + return sizeof(TPreprocessedTableData) + + EstimateRowsBytes(Inserts) + + EstimateRowsBytes(Deletes); + } }; /** Interface for deserializing table rows from a database update. Allows for different row types to be processed in SDK. */ @@ -117,7 +168,8 @@ namespace UE::SpacetimeDB public: virtual ~ITableRowDeserializer() {} /** Preprocess the table update and return a shared pointer to preprocessed data. */ - virtual TSharedPtr PreProcess(const TArray& RowSets, const FString TableName) const = 0; + virtual TSharedPtr PreProcess(const TArray& RowSets, const FString& TableName) const = 0; + virtual TSharedPtr PreProcessQueryRows(const FBsatnRowListType& Rows, EQueryRowsApplyMode Mode, const FString& TableName) const = 0; }; /** Specialization of ITableRowDeserializer for a specific row type not defined in SDK. Used to deserialize rows of a specific type from a database update. */ @@ -125,24 +177,34 @@ namespace UE::SpacetimeDB class TTableRowDeserializer : public ITableRowDeserializer { public: - virtual TSharedPtr PreProcess(const TArray& RowSets, const FString TableName) const override + virtual TSharedPtr PreProcess(const TArray& RowSets, const FString& TableName) const override { // Create a new preprocessed table data object for the specific row type TSharedPtr> Result = MakeShared>(); + Result->RowSetCount = RowSets.Num(); // Process each row-set update in the table update for (const FTableUpdateRowsType& RowSet : RowSets) { if (RowSet.IsPersistentTable()) { - const FPersistentTableRowsType Persistent = RowSet.GetAsPersistentTable(); + const FPersistentTableRowsType& Persistent = RowSet.MessageData.Get(); + const int32 InsertCountBefore = Result->Inserts.Num(); + const int32 DeleteCountBefore = Result->Deletes.Num(); ParseRowListWithBsatn(Persistent.Inserts, Result->Inserts); ParseRowListWithBsatn(Persistent.Deletes, Result->Deletes); + Result->InsertRowCount += Result->Inserts.Num() - InsertCountBefore; + Result->DeleteRowCount += Result->Deletes.Num() - DeleteCountBefore; + Result->InsertRowBytes += Persistent.Inserts.RowsData.Num(); + Result->DeleteRowBytes += Persistent.Deletes.RowsData.Num(); } else if (RowSet.IsEventTable()) { // Event rows are insert-style callback payloads only. - const FEventTableRowsType Events = RowSet.GetAsEventTable(); + const FEventTableRowsType& Events = RowSet.MessageData.Get(); + const int32 InsertCountBefore = Result->Inserts.Num(); ParseRowListWithBsatn(Events.Events, Result->Inserts); + Result->InsertRowCount += Result->Inserts.Num() - InsertCountBefore; + Result->InsertRowBytes += Events.Events.RowsData.Num(); } else { @@ -151,5 +213,34 @@ namespace UE::SpacetimeDB } return Result; } + + virtual TSharedPtr PreProcessQueryRows(const FBsatnRowListType& Rows, EQueryRowsApplyMode Mode, const FString& TableName) const override + { + TSharedPtr> Result = MakeShared>(); + Result->RowSetCount = 1; + switch (Mode) + { + case EQueryRowsApplyMode::Inserts: + { + const int32 InsertCountBefore = Result->Inserts.Num(); + ParseRowListWithBsatn(Rows, Result->Inserts); + Result->InsertRowCount += Result->Inserts.Num() - InsertCountBefore; + Result->InsertRowBytes += Rows.RowsData.Num(); + break; + } + case EQueryRowsApplyMode::Deletes: + { + const int32 DeleteCountBefore = Result->Deletes.Num(); + ParseRowListWithBsatn(Rows, Result->Deletes); + Result->DeleteRowCount += Result->Deletes.Num() - DeleteCountBefore; + Result->DeleteRowBytes += Rows.RowsData.Num(); + break; + } + default: + checkf(false, TEXT("Unsupported query-row apply mode for table %s"), *TableName); + break; + } + return Result; + } }; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h index 8b25806e282..283d2261383 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h @@ -167,14 +167,33 @@ namespace UE::SpacetimeDB { * This class provides a UE-friendly interface for deserializing BSATN data, * handling conversions between standard C++ types and UE types. * - * The reader stores a copy of the input data to prevent lifetime issues - * that could occur if the original TArray is destroyed. + * The TArray and std::vector constructors retain copy semantics for callers + * that need owned reader storage. The pointer and array-view constructors are + * caller-owned views intended for immediate, scoped parsing. */ class UEReader { private: - std::vector stored_data; ///< Local copy of data to ensure lifetime + std::vector stored_data; ///< Local copy of data when ownership is requested ::SpacetimeDb::bsatn::Reader core_reader; ///< Underlying BSATN reader + static const uint8_t* ValidateReaderData(const uint8_t* data, int32 size) + { + checkf(size >= 0, TEXT("UEReader cannot read a negative byte count: %d"), size); + checkf(data != nullptr || size == 0, TEXT("UEReader received null data for %d bytes"), size); + static constexpr uint8_t EmptyReaderByte = 0; + if (data == nullptr) + { + return &EmptyReaderByte; + } + return data; + } + + static size_t ValidateReaderSize(int32 size) + { + checkf(size >= 0, TEXT("UEReader cannot read a negative byte count: %d"), size); + return static_cast(size); + } + public: /** * Construct a reader from a UE byte array @@ -198,6 +217,12 @@ namespace UE::SpacetimeDB { : stored_data(data), core_reader(stored_data) {} + explicit UEReader(const uint8_t* data, int32 size) + : core_reader(ValidateReaderData(data, size), ValidateReaderSize(size)) {} + + explicit UEReader(TConstArrayView data) + : UEReader(data.GetData(), data.Num()) {} + // ------------------------------------------------------------------------- // Primitive Type Readers // ------------------------------------------------------------------------- @@ -894,6 +919,27 @@ namespace UE::SpacetimeDB { } } + template + T DeserializeView(const uint8* data, int32 size) { + + UEReader reader(data, size); + + if constexpr (is_tarray_v) { + return DeserializeHelper::deserialize(reader); + } + else if constexpr (is_toptional_v) { + return DeserializeHelper::deserialize(reader); + } + else { + return deserialize(reader); + } + } + + template + T DeserializeView(TConstArrayView data) { + return DeserializeView(data.GetData(), data.Num()); + } + /** @} */ // end of HighLevelAPI group // ============================================================================= @@ -932,7 +978,7 @@ namespace UE::SpacetimeDB { * @brief Helper macro to generate deserialize specialization for TOptional * * Use this macro when you have structs with TOptional fields of custom types. - * + * * @Note: We are not using TOptional directly becouse it is not compatable wiht blueprints. * This macro is kept for future compatibility if things change on the Engine side. * @@ -1322,7 +1368,7 @@ namespace UE::SpacetimeDB { UE_SPACETIMEDB_ENABLE_TARRAY(float) UE_SPACETIMEDB_ENABLE_TARRAY(double) - + UE_SPACETIMEDB_ENABLE_TARRAY(bool) // Large integer type containers @@ -1330,8 +1376,8 @@ namespace UE::SpacetimeDB { UE_SPACETIMEDB_ENABLE_TARRAY(FSpacetimeDBUInt256) UE_SPACETIMEDB_ENABLE_TARRAY(FSpacetimeDBInt128) UE_SPACETIMEDB_ENABLE_TARRAY(FSpacetimeDBInt256) - + /** @} */ // end of CommonSpecializations group -} // namespace UE::SpacetimeDB \ No newline at end of file +} // namespace UE::SpacetimeDB diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h index b32221ce641..b2b130743bb 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h @@ -1,48 +1,47 @@ -#pragma once - +#pragma once + #include "CoreMinimal.h" -#include "Containers/Ticker.h" #include "UObject/NoExportTypes.h" -#include "Types/Builtins.h" -#include "Websocket.h" -#include "Subscription.h" -#include "ModuleBindings/Types/ServerMessageType.g.h" -#include "DBCache/TableAppliedDiff.h" -#include "HAL/CriticalSection.h" -#include "Containers/Queue.h" +#include "Types/Builtins.h" +#include "Websocket.h" +#include "Subscription.h" +#include "ModuleBindings/Types/ServerMessageType.g.h" +#include "DBCache/TableAppliedDiff.h" +#include "HAL/CriticalSection.h" #include "HAL/ThreadSafeBool.h" #include "BSATN/UEBSATNHelpers.h" #include "Connection/Callback.h" #include "LogCategory.h" #include - -#include "DbConnectionBase.generated.h" - -// Forward declarations -class UDbConnectionBuilder; -class UProcedureCallbacks; - -/** Macro for safae way to bind delegate without needing to write Function name as an FName. */ -#define BIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ - DelegateVar.BindUFunction(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) - -/** Macro for safe way to unbind delegate without needing to write Function name as an FName. */ -#define UNBIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ - DelegateVar.Remove(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) - + +#include "DbConnectionBase.generated.h" + +// Forward declarations +class UDbConnectionBuilder; +class UProcedureCallbacks; +class FSpacetimeDbInboundWorker; + +/** Macro for safae way to bind delegate without needing to write Function name as an FName. */ +#define BIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ + DelegateVar.BindUFunction(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) + +/** Macro for safe way to unbind delegate without needing to write Function name as an FName. */ +#define UNBIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ + DelegateVar.Remove(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) + /** Delegate called when the connection attempt fails. */ -DECLARE_DYNAMIC_DELEGATE_OneParam( - FOnConnectErrorDelegate, - const FString&, ErrorMessage); - -/** Called when a connection is established. */ -DECLARE_DYNAMIC_DELEGATE_ThreeParams( - FOnConnectBaseDelegate, - UDbConnectionBase*, Connection, - FSpacetimeDBIdentity, Identity, - const FString&, Token); - -/** Called when a connection closes. */ +DECLARE_DYNAMIC_DELEGATE_OneParam( + FOnConnectErrorDelegate, + const FString&, ErrorMessage); + +/** Called when a connection is established. */ +DECLARE_DYNAMIC_DELEGATE_ThreeParams( + FOnConnectBaseDelegate, + UDbConnectionBase*, Connection, + FSpacetimeDBIdentity, Identity, + const FString&, Token); + +/** Called when a connection closes. */ DECLARE_DYNAMIC_DELEGATE_TwoParams( FOnDisconnectBaseDelegate, UDbConnectionBase*, Connection, @@ -95,6 +94,94 @@ FORCEINLINE uint32 GetTypeHash(const FPreprocessedTableKey& Key) return GetTypeHash(Key.TableName); } +using FPreprocessedTableDataMap = TMap>>; + +struct FInboundRawMessage +{ + uint64 ConnectionEpoch = 0; + uint64 SequenceId = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + TArray Payload; +}; + +struct FInboundParsedMessage +{ + uint64 ConnectionEpoch = 0; + uint64 SequenceId = 0; + int32 PayloadSizeBytes = 0; + int32 DecodedPayloadSizeBytes = 0; + int64 EstimatedMemoryBytes = 0; + uint8 CompressionTag = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + bool bProtocolError = false; + FString ProtocolError; + FServerMessageType Message; + FPreprocessedTableDataMap PreprocessedTableData; +}; + +struct FSpacetimeDBTableApplyStats +{ + FString TableName; + int32 RowSetCount = 0; + int32 InsertRowCount = 0; + int32 DeleteRowCount = 0; + int64 InsertRowBytes = 0; + int64 DeleteRowBytes = 0; + double CacheMicros = 0.0; + double BroadcastMicros = 0.0; + bool bProducedDiff = false; +}; + +struct FSpacetimeDBInboundMessageApplyStats +{ + FString MessageKind; + FString ReducerName; + uint32 RequestId = 0; + uint64 SequenceId = 0; + int32 PayloadSizeBytes = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + TArray TableStats; +}; + +USTRUCT(BlueprintType) +struct SPACETIMEDBSDK_API FSpacetimeDBInboundApplyBudget +{ + GENERATED_BODY() + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int32 MaxMessagesPerFrame = 256; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int64 MaxPayloadBytesPerFrame = 4 * 1024 * 1024; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int32 MinMessagesPerFrame = 1; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int64 SoftTimeBudgetMicros = 0; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + bool bDrainAllPendingMessages = false; + + static FSpacetimeDBInboundApplyBudget MakeDrainAllPendingMessages() + { + FSpacetimeDBInboundApplyBudget Budget; + Budget.bDrainAllPendingMessages = true; + return Budget; + } + + void Sanitize() + { + MaxMessagesPerFrame = FMath::Max(1, MaxMessagesPerFrame); + MaxPayloadBytesPerFrame = FMath::Max(1, MaxPayloadBytesPerFrame); + MinMessagesPerFrame = FMath::Clamp(MinMessagesPerFrame, 1, MaxMessagesPerFrame); + SoftTimeBudgetMicros = FMath::Max(0, SoftTimeBudgetMicros); + } +}; + template struct THasOnDeleteDelegate : std::false_type { @@ -114,251 +201,596 @@ template struct THasOnUpdateDelegate> : std::true_type { }; - + +template +const void* GetNativeTableListenerTypeId() +{ + static const uint8 TypeId = 0; + return &TypeId; +} + +enum class ESpacetimeDBNativeListenerDispatchMode : uint8 +{ + NativeAndDynamic, + NativeOnly +}; + +struct FNativeTableListenerBinding +{ + using FInsertThunk = void(*)(void* Owner, const void* Context, const void* Row); + using FUpdateThunk = void(*)(void* Owner, const void* Context, const void* OldRow, const void* NewRow); + using FDeleteThunk = void(*)(void* Owner, const void* Context, const void* Row); + using FDiffThunk = void(*)(void* Owner, const void* Context, const void* Diff); + + TWeakObjectPtr Owner; + void* OwnerKey = nullptr; + const void* RowTypeId = nullptr; + const void* EventContextTypeId = nullptr; + FInsertThunk InsertThunk = nullptr; + FUpdateThunk UpdateThunk = nullptr; + FDeleteThunk DeleteThunk = nullptr; + FDiffThunk DiffThunk = nullptr; + ESpacetimeDBNativeListenerDispatchMode DispatchMode = ESpacetimeDBNativeListenerDispatchMode::NativeAndDynamic; + + bool IsComplete() const + { + return OwnerKey != nullptr && + RowTypeId != nullptr && + EventContextTypeId != nullptr && + (DiffThunk != nullptr || + (InsertThunk != nullptr && + UpdateThunk != nullptr && + DeleteThunk != nullptr)); + } +}; + UCLASS() -class SPACETIMEDBSDK_API UDbConnectionBase : public UObject -{ - GENERATED_BODY() - -public: - - /** The default constructor is private to prevent instantiation without using the builder. */ - explicit UDbConnectionBase(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get()); - - /** Disconnect from the server. */ - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void Disconnect(); - - /** Check if the underlying WebSocket is connected. */ - UFUNCTION(BlueprintPure, Category="SpacetimeDB") - bool IsActive() const; - - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void FrameTick(); - +class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGameObject +{ + GENERATED_BODY() + +public: + + /** The default constructor is private to prevent instantiation without using the builder. */ + explicit UDbConnectionBase(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get()); + + virtual void BeginDestroy() override; + + /** Disconnect from the server. */ + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void Disconnect(); + + /** Check if the underlying WebSocket is connected. */ + UFUNCTION(BlueprintPure, Category="SpacetimeDB") + bool IsActive() const; + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void SetAutoTicking(bool bAutoTick); - - /** Send a raw JSON message to the server. */ - bool SendRawMessage(const FString& Message); - /** Send a raw binary message to the server. */ - bool SendRawMessage(const TArray& Message); - - /** Get the current subscription builder. This is used to create subscriptions. */ - UFUNCTION() - USubscriptionBuilderBase* SubscriptionBuilderBase(); - - /** Get the current identity of the SpacetimeDB instance. This is used to identify the connection. */ - UFUNCTION(BlueprintPure, Category = "SpacetimeDB") - bool TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const; - - /** Get the current connection id. This is used to identify the connection. */ - UFUNCTION(BlueprintPure, Category = "SpacetimeDB") - FSpacetimeDBConnectionId GetConnectionId() const; - - // Typed reducer call helper: hides BSATN bytes from callers. + void FrameTick(); + + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void SetAutoTicking(bool bAutoTick) { bIsAutoTicking = bAutoTick; } + + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void SetInboundApplyBudget(FSpacetimeDBInboundApplyBudget InBudget) + { + InBudget.Sanitize(); + InboundApplyBudget = InBudget; + } + + /** Send a raw JSON message to the server. */ + bool SendRawMessage(const FString& Message); + /** Send a raw binary message to the server. */ + bool SendRawMessage(const TArray& Message); + + /** Get the current subscription builder. This is used to create subscriptions. */ + UFUNCTION() + USubscriptionBuilderBase* SubscriptionBuilderBase(); + + /** Get the current identity of the SpacetimeDB instance. This is used to identify the connection. */ + UFUNCTION(BlueprintPure, Category = "SpacetimeDB") + bool TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const; + + /** Get the current connection id. This is used to identify the connection. */ + UFUNCTION(BlueprintPure, Category = "SpacetimeDB") + FSpacetimeDBConnectionId GetConnectionId() const; + + // Typed reducer call helper: hides BSATN bytes from callers. template uint32 CallReducerTyped(const FString& Reducer, const ArgsStruct& Args) { TArray Bytes = UE::SpacetimeDB::Serialize(Args); return InternalCallReducer(Reducer, MoveTemp(Bytes)); } - - template - void CallProcedureTyped(const FString& ProcedureName, const ArgsStruct& Args, const FOnProcedureCompleteDelegate& Callback) - { - TArray Bytes = UE::SpacetimeDB::Serialize(Args); - InternalCallProcedure(ProcedureName, MoveTemp(Bytes), Callback); - } - - template - void RegisterTable(const FString& TableName) - { - FScopeLock Lock(&TableDeserializersMutex); - TableDeserializers.Add(TableName, MakeShared>()); - } - - /** Internal interface for applying table updates generically */ - class ITableUpdateHandler - { - public: - virtual ~ITableUpdateHandler() {} - - /** Update the in-memory cache for the table and store the diff */ - virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) = 0; - - /** Broadcast the previously stored diff */ - virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; - }; - - template - class TTableUpdateHandler : public ITableUpdateHandler - { - public: - explicit TTableUpdateHandler(TableClass* InTable) : Table(InTable) {} - - //** Update the in-memory cache for the table and store the diff */ - virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) override - { - // Attempt to take preprocessed data if available - TSharedPtr> Pre; - if (Conn->TakePreprocessedTableData(Update, Pre)) - { - // If preprocessed data is available, use it to update the table - LastDiff = Table->Update(Pre->Inserts, Pre->Deletes); - } - else - { - // If no preprocessed data, process the update directly. Backup - UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("No preprocessed data for table update. Processing directly.")); - TArray> Inserts, Deletes; - UE::SpacetimeDB::ProcessTableUpdateWithBsatn(Update, Inserts, Deletes); - LastDiff = Table->Update(Inserts, Deletes); - } - } - //** Broadcast the last stored diff to the table's delegates */ - virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) override - { - EventContext& Ctx = *reinterpret_cast(Context); - Conn->BroadcastDiff(Table, LastDiff, Ctx); - } - - private: - TableClass* Table; - FTableAppliedDiff LastDiff; - }; - //** Register a table with the connection. This will allow the connection to handle updates for the table. - template - void RegisterTable(const FString& TableName, TableClass* Table) - { - RegisterTable(TableName); - FScopeLock Lock(&RegisteredTablesMutex); - RegisteredTables.Add(TableName, MakeShared>(Table)); - } - //** Take preprocessed table row data. */ - template + + template + void CallProcedureTyped(const FString& ProcedureName, const ArgsStruct& Args, const FOnProcedureCompleteDelegate& Callback) + { + TArray Bytes = UE::SpacetimeDB::Serialize(Args); + InternalCallProcedure(ProcedureName, MoveTemp(Bytes), Callback); + } + + template + void RegisterTable(const FString& TableName) + { + FScopeLock Lock(&TableDeserializersMutex); + TableDeserializers.Add(TableName, MakeShared>()); + } + + /** Internal interface for applying table updates generically */ + class ITableUpdateHandler + { + public: + virtual ~ITableUpdateHandler() {} + + /** Update the in-memory cache for the table and store the diff */ + virtual bool UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context, FSpacetimeDBTableApplyStats* OutStats) = 0; + + /** Broadcast the previously stored diff */ + virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; + virtual const FString& GetTableName() const = 0; + virtual void RegisterNativeListener(const FNativeTableListenerBinding& Binding) = 0; + virtual void UnregisterNativeListener(void* OwnerKey) = 0; + }; + + template + class TTableUpdateHandler : public ITableUpdateHandler + { + public: + explicit TTableUpdateHandler(const FString& InTableName, TableClass* InTable) + : TableName(InTableName) + , Table(InTable) + { + } + + //** Update the in-memory cache for the table and store the diff */ + virtual bool UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context, FSpacetimeDBTableApplyStats* OutStats) override + { + if (PendingDiffReadIndex == PendingDiffs.Num()) + { + PendingDiffs.Reset(); + PendingDiffReadIndex = 0; + } + + TSharedPtr> Pre; + const bool bTookPreprocessedData = Conn->TakePreprocessedTableData(Update, Pre); + checkf(bTookPreprocessedData && Pre.IsValid(), TEXT("Missing message-scoped preprocessed data for table '%s'."), *Update.TableName); + if (OutStats != nullptr) + { + OutStats->TableName = TableName; + OutStats->RowSetCount = Pre->RowSetCount; + OutStats->InsertRowCount = Pre->InsertRowCount; + OutStats->DeleteRowCount = Pre->DeleteRowCount; + OutStats->InsertRowBytes = Pre->InsertRowBytes; + OutStats->DeleteRowBytes = Pre->DeleteRowBytes; + } + FTableAppliedDiff AppliedDiff = Table->Update(MoveTemp(Pre->Inserts), MoveTemp(Pre->Deletes)); + if (AppliedDiff.IsEmpty()) + { + if (OutStats != nullptr) + { + OutStats->bProducedDiff = false; + } + return false; + } + if (OutStats != nullptr) + { + OutStats->bProducedDiff = true; + } + PendingDiffs.Add(MoveTemp(AppliedDiff)); + return true; + } + //** Broadcast the last stored diff to the table's delegates */ + virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) override + { + checkf(PendingDiffReadIndex < PendingDiffs.Num(), TEXT("Missing pending SpacetimeDB table diff for broadcast.")); + EventContext& Ctx = *reinterpret_cast(Context); + const FTableAppliedDiff& Diff = PendingDiffs[PendingDiffReadIndex]; + const bool bSuppressDynamicDispatch = BroadcastNativeDiff(Diff, Ctx); + if (!bSuppressDynamicDispatch) + { + Conn->BroadcastDiff(Table, Diff, Ctx); + } + ++PendingDiffReadIndex; + if (PendingDiffReadIndex == PendingDiffs.Num()) + { + PendingDiffs.Reset(); + PendingDiffReadIndex = 0; + } + } + + virtual const FString& GetTableName() const override + { + return TableName; + } + + virtual void RegisterNativeListener(const FNativeTableListenerBinding& Binding) override + { + checkf(!bBroadcastingNativeListeners, + TEXT("Cannot register native SpacetimeDB table listener during broadcast for table '%s'."), + *TableName); + checkf(Binding.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); + checkf(Binding.Owner.IsValid(), + TEXT("Cannot register invalid native SpacetimeDB table listener owner for table '%s'."), + *TableName); + checkf(Binding.RowTypeId == GetNativeTableListenerTypeId(), + TEXT("Native SpacetimeDB table listener row type mismatch for table '%s'."), *TableName); + checkf(Binding.EventContextTypeId == GetNativeTableListenerTypeId(), + TEXT("Native SpacetimeDB table listener context type mismatch for table '%s'."), *TableName); + for (const FNativeTableListenerBinding& ExistingBinding : NativeListeners) + { + checkf(ExistingBinding.OwnerKey != Binding.OwnerKey, + TEXT("Duplicate native SpacetimeDB table listener owner for table '%s'."), + *TableName); + } + NativeListeners.Add(Binding); + } + + virtual void UnregisterNativeListener(void* OwnerKey) override + { + checkf(!bBroadcastingNativeListeners, + TEXT("Cannot unregister native SpacetimeDB table listener during broadcast for table '%s'."), + *TableName); + checkf(OwnerKey != nullptr, TEXT("Cannot unregister null native SpacetimeDB table listener owner for table '%s'."), *TableName); + const int32 ListenerIndex = NativeListeners.IndexOfByPredicate( + [OwnerKey](const FNativeTableListenerBinding& Binding) + { + return Binding.OwnerKey == OwnerKey; + }); + checkf(ListenerIndex != INDEX_NONE, + TEXT("Missing native SpacetimeDB table listener for table '%s'."), + *TableName); + NativeListeners.RemoveAtSwap(ListenerIndex, 1, EAllowShrinking::No); + } + + private: + bool BroadcastNativeDiff(const FTableAppliedDiff& Diff, const EventContext& Context) + { + if (NativeListeners.IsEmpty()) + { + return false; + } + + bool bSuppressDynamicDispatch = false; + TArray ExpiredOwnerKeys; + { + TGuardValue BroadcastingScope(bBroadcastingNativeListeners, true); + for (const FNativeTableListenerBinding& Listener : NativeListeners) + { + UObject* OwnerObject = Listener.Owner.Get(); + if (OwnerObject == nullptr) + { + ExpiredOwnerKeys.Add(Listener.OwnerKey); + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("Removing expired native SpacetimeDB table listener owner for table '%s'. Native listeners must be unregistered before owner destruction."), + *TableName); + continue; + } + + BroadcastNativeDiffToListener(Diff, Context, Listener, OwnerObject); + if (Listener.DispatchMode == ESpacetimeDBNativeListenerDispatchMode::NativeOnly) + { + bSuppressDynamicDispatch = true; + } + } + } + + for (void* ExpiredOwnerKey : ExpiredOwnerKeys) + { + NativeListeners.RemoveAllSwap( + [ExpiredOwnerKey](const FNativeTableListenerBinding& Listener) + { + return Listener.OwnerKey == ExpiredOwnerKey; + }, + EAllowShrinking::No); + } + return bSuppressDynamicDispatch; + } + + void BroadcastNativeDiffToListener( + const FTableAppliedDiff& Diff, + const EventContext& Context, + const FNativeTableListenerBinding& Listener, + UObject* OwnerObject) + { + checkf(Listener.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); + checkf(OwnerObject != nullptr, TEXT("Cannot dispatch native SpacetimeDB table listener to a null owner for table '%s'."), *TableName); + checkf(Listener.Owner.Get() == OwnerObject, + TEXT("Native SpacetimeDB table listener owner identity mismatch for table '%s'."), + *TableName); + if (Listener.DiffThunk != nullptr) + { + Listener.DiffThunk(Listener.OwnerKey, &Context, &Diff); + return; + } + + for (const TSharedPtr& Row : Diff.Inserts) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native insert diff row for table '%s'."), *TableName); + Listener.InsertThunk(Listener.OwnerKey, &Context, Row.Get()); + } + + for (const TSharedPtr& Row : Diff.Deletes) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native delete diff row for table '%s'."), *TableName); + Listener.DeleteThunk(Listener.OwnerKey, &Context, Row.Get()); + } + + checkf(Diff.UpdateDeletes.Num() == Diff.UpdateInserts.Num(), + TEXT("Mismatched SpacetimeDB native update diff counts for table '%s'."), *TableName); + for (int32 Index = 0; Index < Diff.UpdateInserts.Num(); ++Index) + { + const TSharedPtr& OldRow = Diff.UpdateDeletes[Index]; + const TSharedPtr& NewRow = Diff.UpdateInserts[Index]; + checkf(OldRow.IsValid() && NewRow.IsValid(), TEXT("Invalid SpacetimeDB native update diff row for table '%s'."), *TableName); + Listener.UpdateThunk(Listener.OwnerKey, &Context, OldRow.Get(), NewRow.Get()); + } + } + + FString TableName; + TableClass* Table; + TArray> PendingDiffs; + int32 PendingDiffReadIndex = 0; + TArray NativeListeners; + bool bBroadcastingNativeListeners = false; + }; + //** Register a table with the connection. This will allow the connection to handle updates for the table. + template + void RegisterTable(const FString& TableName, TableClass* Table) + { + RegisterTable(TableName); + FScopeLock Lock(&RegisteredTablesMutex); + RegisteredTables.Add(TableName, MakeShared>(TableName, Table)); + RegisteredTablesSnapshot = MakeShared>>(RegisteredTables); + } + + template + void RegisterNativeTableListener( + const FString& TableName, + OwnerType* Owner, + ESpacetimeDBNativeListenerDispatchMode DispatchMode = ESpacetimeDBNativeListenerDispatchMode::NativeAndDynamic) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot register null native SpacetimeDB table listener owner for table '%s'."), *TableName); + + FNativeTableListenerBinding Binding; + Binding.Owner = Owner; + Binding.OwnerKey = Owner; + Binding.RowTypeId = GetNativeTableListenerTypeId(); + Binding.EventContextTypeId = GetNativeTableListenerTypeId(); + Binding.DispatchMode = DispatchMode; + Binding.InsertThunk = [](void* RawOwner, const void* RawContext, const void* RawRow) + { + (static_cast(RawOwner)->*InsertFn)( + *static_cast(RawContext), + *static_cast(RawRow)); + }; + Binding.UpdateThunk = [](void* RawOwner, const void* RawContext, const void* RawOldRow, const void* RawNewRow) + { + (static_cast(RawOwner)->*UpdateFn)( + *static_cast(RawContext), + *static_cast(RawOldRow), + *static_cast(RawNewRow)); + }; + Binding.DeleteThunk = [](void* RawOwner, const void* RawContext, const void* RawRow) + { + (static_cast(RawOwner)->*DeleteFn)( + *static_cast(RawContext), + *static_cast(RawRow)); + }; + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while registering native listener for table '%s'."), *TableName); + (*Handler)->RegisterNativeListener(Binding); + } + + template&)> + void RegisterNativeTableDiffListener( + const FString& TableName, + OwnerType* Owner, + ESpacetimeDBNativeListenerDispatchMode DispatchMode = ESpacetimeDBNativeListenerDispatchMode::NativeAndDynamic) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table diff listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot register null native SpacetimeDB table diff listener owner for table '%s'."), *TableName); + + FNativeTableListenerBinding Binding; + Binding.Owner = Owner; + Binding.OwnerKey = Owner; + Binding.RowTypeId = GetNativeTableListenerTypeId(); + Binding.EventContextTypeId = GetNativeTableListenerTypeId(); + Binding.DispatchMode = DispatchMode; + Binding.DiffThunk = [](void* RawOwner, const void* RawContext, const void* RawDiff) + { + (static_cast(RawOwner)->*DiffFn)( + *static_cast(RawContext), + *static_cast*>(RawDiff)); + }; + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while registering native diff listener for table '%s'."), *TableName); + (*Handler)->RegisterNativeListener(Binding); + } + + template + void UnregisterNativeTableListener(const FString& TableName, OwnerType* Owner) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot unregister null native SpacetimeDB table listener owner for table '%s'."), *TableName); + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while unregistering native listener for table '%s'."), *TableName); + (*Handler)->UnregisterNativeListener(Owner); + } + + template&)> + void UnregisterNativeTableDiffListener(const FString& TableName, OwnerType* Owner) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table diff listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot unregister null native SpacetimeDB table diff listener owner for table '%s'."), *TableName); + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while unregistering native diff listener for table '%s'."), *TableName); + (*Handler)->UnregisterNativeListener(Owner); + } + //** Take preprocessed table row data. */ + template bool TakePreprocessedTableData(const FTableUpdateType& Update, TSharedPtr>& OutData) { - FScopeLock Lock(&PreprocessedDataMutex); + checkf(ActivePreprocessedTableData != nullptr, TEXT("No active inbound message while applying table update '%s'."), *Update.TableName); FPreprocessedTableKey Key(Update.TableName); - if (TArray>* Found = PreprocessedTableData.Find(Key)) + TArray>* Found = ActivePreprocessedTableData->Find(Key); + checkf(Found != nullptr && Found->Num() > 0, TEXT("Missing message-scoped preprocessed data for table '%s'."), *Update.TableName); + OutData = StaticCastSharedPtr>((*Found)[0]); + Found->RemoveAt(0, 1, EAllowShrinking::No); + if (Found->Num() == 0) { - if (Found->Num() > 0) - { - OutData = StaticCastSharedPtr>((*Found)[0]); - Found->RemoveAt(0); - if (Found->Num() == 0) - { - PreprocessedTableData.Remove(Key); - } - return OutData.IsValid(); - } - } - return false; - } - - + ActivePreprocessedTableData->Remove(Key); + } + checkf(OutData.IsValid(), TEXT("Invalid message-scoped preprocessed data for table '%s'."), *Update.TableName); + return true; + } + + protected: - - virtual void BeginDestroy() override; friend class UDbConnectionBuilderBase; - friend class UDbConnectionBuilder; - friend class USubscriptionHandleBase; - friend class USubscriptionBuilder; - friend class URemoteReducers; - - /** Allow derived classes to override the delegates used when connecting */ - void SetOnConnectDelegate(const FOnConnectBaseDelegate& Delegate) { OnConnectBaseDelegate = Delegate; } - void SetOnDisconnectDelegate(const FOnDisconnectBaseDelegate& Delegate) { OnDisconnectBaseDelegate = Delegate; } - - UFUNCTION() - void HandleWSError(const FString& Error); - UFUNCTION() - void HandleWSClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - UFUNCTION() - void HandleWSBinaryMessage(const TArray& Message); - - bool OnTickerTick(float DeltaTime); - + friend class UDbConnectionBuilder; + friend class FSpacetimeDbInboundWorker; + friend class UWebsocketManager; + friend class USubscriptionHandleBase; + friend class USubscriptionBuilder; + friend class URemoteReducers; + + /** Allow derived classes to override the delegates used when connecting */ + void SetOnConnectDelegate(const FOnConnectBaseDelegate& Delegate) { OnConnectBaseDelegate = Delegate; } + void SetOnDisconnectDelegate(const FOnDisconnectBaseDelegate& Delegate) { OnDisconnectBaseDelegate = Delegate; } + + UFUNCTION() + void HandleWSError(const FString& Error); + UFUNCTION() + void HandleWSClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + UFUNCTION() + void HandleWSBinaryMessage(const TArray& Message); + void HandleWSBinaryMessageOwned(TArray&& Message); + void StartInboundMessageWorker(); + void StopInboundMessageWorker(); + void ClearInboundMessageQueues(); + void NotifyInboundWorkerIfNeeded(); + void DrainInboundRawMessagesOnWorker(); + bool BuildInboundParsedMessage(const FInboundRawMessage& RawMessage, FInboundParsedMessage& OutMessage); + void EnqueueInboundProtocolError(uint64 SequenceId, int32 PayloadSizeBytes, uint8 CompressionTag, int32 QueueDepthAtEnqueue, int64 QueuedBytesAtEnqueue, const FString& ErrorMessage); + bool IsInboundProtocolErrorQueued() const; + bool IsInboundEpochCurrentAndAccepting(uint64 ConnectionEpoch) const; + void MarkInboundProtocolErrorQueued(); + + virtual void Tick(float DeltaTime) override; + + virtual TStatId GetStatId() const override; + + virtual bool IsTickable() const override; + + virtual bool IsTickableInEditor() const override; + /** Internal handler that processes a single server message. */ void ProcessServerMessage(const FServerMessageType& Message); - void PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update); + void ProcessInboundServerMessage(FInboundParsedMessage& InboundMessage, FSpacetimeDBInboundMessageApplyStats& ApplyStats); + void PreProcessTableUpdateRows(const FString& TableName, const TArray& RowSets, FPreprocessedTableDataMap& OutPreprocessedTableData); + void PreProcessQueryRows(const FQueryRowsType& Rows, UE::SpacetimeDB::EQueryRowsApplyMode Mode, FPreprocessedTableDataMap& OutPreprocessedTableData); + void PreProcessTransactionUpdate(const FTransactionUpdateType& Update, FPreprocessedTableDataMap& OutPreprocessedTableData); + TSharedPtr FindTableDeserializerForPreprocess(const FString& TableName); + void StorePreprocessedTableData(const FString& TableName, TSharedPtr Data, FPreprocessedTableDataMap& OutPreprocessedTableData); /** Decompress and parse a raw message. */ - bool PreProcessMessage(const TArray& Message, FServerMessageType& OutMessage); - bool DecompressPayload(uint8 Variant, const TArray& In, TArray& Out); - bool DecompressGzip(const TArray& InData, TArray& OutData); - bool DecompressBrotli(const TArray& InData, TArray& OutData); + bool PreProcessMessage(const TArray& Message, FInboundParsedMessage& OutMessage); + bool DecompressGzip(const uint8* InData, int32 InSize, TArray& OutData); + bool DecompressBrotli(const uint8* InData, int32 InSize, TArray& OutData); void ClearPendingOperations(const FString& Reason); void HandleProtocolViolation(const FString& ErrorMessage); - - /** Pending messages awaiting processing on the game thread. */ - TArray PendingMessages; - - /** Mutex protecting access to PendingMessages. */ - FCriticalSection PendingMessagesMutex; - - /** Map of preprocessed messages keyed by their sequential id. */ - TMap PreprocessedMessages; - - /** Protects PreprocessedMessages and PendingMessages ordering state. */ - FCriticalSection PreprocessMutex; - - /** Counter for assigning ids to incoming messages. */ - FThreadSafeCounter NextPreprocessId; - - /** Id of the next message expected to be released. */ - int32 NextReleaseId = 0; - - // Map of table name to row deserializer - TMap> TableDeserializers; - FCriticalSection TableDeserializersMutex; - - // Map from table update pointer to preprocessed data - TMap>> PreprocessedTableData; - FCriticalSection PreprocessedDataMutex; - - // Map of table name to generic table update handler - TMap> RegisteredTables; - FCriticalSection RegisteredTablesMutex; - - - /** Start a subscription. This will add the subscription to the active list and send a subscribe message to the server. */ - void StartSubscription(USubscriptionHandleBase* Handle); - /** Unsubscribe from a subscription. This will remove the subscription from the active list and send an unsubscribe message to the server. */ - void UnsubscribeInternal(USubscriptionHandleBase* Handle); - - /** Call a reducer on the connected SpacetimeDB instance. */ + + /** Parsed inbound messages awaiting processing on the game thread. */ + TArray PendingMessages; + + /** Mutex protecting access to PendingMessages. */ + FCriticalSection PendingMessagesMutex; + int32 PendingMessageReadIndex = 0; + int64 PendingParsedEstimatedBytes = 0; + + /** Raw inbound messages awaiting FIFO processing by the connection-owned worker. */ + TArray InboundRawMessages; + mutable FCriticalSection InboundRawMessagesMutex; + int32 InboundRawMessageReadIndex = 0; + int64 InboundQueuedRawBytes = 0; + uint64 InboundConnectionEpoch = 0; + uint64 NextInboundSequenceId = 0; + bool bInboundAcceptingMessages = false; + bool bInboundProtocolErrorQueued = false; + FSpacetimeDbInboundWorker* InboundWorker = nullptr; + FCriticalSection InboundWorkerMutex; + + // Map of table name to row deserializer + TMap> TableDeserializers; + FCriticalSection TableDeserializersMutex; + + // Message-scoped preprocessed table rows active only while applying one inbound server message. + FPreprocessedTableDataMap* ActivePreprocessedTableData = nullptr; + + // Map of table name to generic table update handler + TMap> RegisteredTables; + TSharedPtr>> RegisteredTablesSnapshot; + FCriticalSection RegisteredTablesMutex; + + + /** Start a subscription. This will add the subscription to the active list and send a subscribe message to the server. */ + void StartSubscription(USubscriptionHandleBase* Handle); + /** Unsubscribe from a subscription. This will remove the subscription from the active list and send an unsubscribe message to the server. */ + void UnsubscribeInternal(USubscriptionHandleBase* Handle); + + /** Call a reducer on the connected SpacetimeDB instance. */ uint32 InternalCallReducer(const FString& Reducer, TArray Args); - - /** Call a reducer on the connected SpacetimeDB instance. */ - void InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback); - - /** - * Update function to apply database changes. - * Must be implemented by child classes. - * @param Update - Struct containing update data. - */ - virtual void DbUpdate(const FDatabaseUpdateType& Update, const FSpacetimeDBEvent& Event) {}; - - /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ - virtual void ReducerEvent(const FReducerEvent& Event) {}; - - /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ - virtual void ReducerEventFailed(const FReducerEvent& Event, const FString ErrorMessage) {}; - - /** Event handler for procedure events. This can should overridden by child classes to handle specific procedure events. */ - virtual void ProcedureEventFailed(const FProcedureEvent& Event, const FString ErrorMessage) {}; - - /** Event handler for error events. This can should overridden by child classes to handle specific error events. */ - virtual void TriggerError(const FString& ErrorMessage) {}; - - /** Event handler for subscription events. This can should overridden by child classes to handle specific subscription events. */ - virtual void TriggerSubscription() {}; - - /** Apply updates for all registered tables using the provided context pointer */ - void ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context); - + + /** Call a reducer on the connected SpacetimeDB instance. */ + void InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback); + + /** + * Update function to apply database changes. + * Must be implemented by child classes. + * @param Update - Struct containing update data. + */ + virtual void DbUpdate(const FDatabaseUpdateType& Update, const FSpacetimeDBEvent& Event) {}; + + /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ + virtual void ReducerEvent(const FReducerEvent& Event) {}; + + /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ + virtual void ReducerEventFailed(const FReducerEvent& Event, const FString ErrorMessage) {}; + + /** Event handler for procedure events. This can should overridden by child classes to handle specific procedure events. */ + virtual void ProcedureEventFailed(const FProcedureEvent& Event, const FString ErrorMessage) {}; + + /** Event handler for error events. This can should overridden by child classes to handle specific error events. */ + virtual void TriggerError(const FString& ErrorMessage) {}; + + /** Event handler for subscription events. This can should overridden by child classes to handle specific subscription events. */ + virtual void TriggerSubscription() {}; + + /** Apply updates for all registered tables using the provided context pointer */ + void ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context); + /** Called when a subscription is updated. */ UPROPERTY() TMap> ActiveSubscriptions; @@ -366,9 +798,9 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject /** Pending reducer call metadata keyed by request id for ReducerResult correlation. */ UPROPERTY() TMap PendingReducerCalls; - - UPROPERTY() - TObjectPtr ProcedureCallbacks; + + UPROPERTY() + TObjectPtr ProcedureCallbacks; /** Get the next request id for a message. This is used to track requests and responses. */ uint32 NextRequestId; /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ @@ -377,67 +809,78 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject uint32 GetNextRequestId(); /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ uint32 GetNextSubscriptionId(); - - /** The WebSocket manager used to connect to the server. */ - UPROPERTY() - UWebsocketManager* WebSocket = nullptr; - - /** The URI of the SpacetimeDB server to connect to. */ - UPROPERTY() - FString Uri; - /** The module name to connect to. This is used to identify the SpacetimeDB instance. */ - UPROPERTY() - FString ModuleName; - /** The token used to authenticate the connection. */ - UPROPERTY() - FString Token; - - /** The identity of the SpacetimeDB instance. This is used to identify the connection. */ - UPROPERTY() - FSpacetimeDBIdentity Identity; - UPROPERTY() - /** Whether the identity has been set. This is used to prevent multiple identity sets. */ - bool bIsIdentitySet = false; - /** The connection id of the SpacetimeDB instance. This is used to identify the connection. */ - UPROPERTY() - FSpacetimeDBConnectionId ConnectionId; - + + /** The WebSocket manager used to connect to the server. */ + UPROPERTY() + UWebsocketManager* WebSocket = nullptr; + + /** The URI of the SpacetimeDB server to connect to. */ + UPROPERTY() + FString Uri; + /** The module name to connect to. This is used to identify the SpacetimeDB instance. */ + UPROPERTY() + FString ModuleName; + /** The token used to authenticate the connection. */ + UPROPERTY() + FString Token; + + /** The identity of the SpacetimeDB instance. This is used to identify the connection. */ + UPROPERTY() + FSpacetimeDBIdentity Identity; + UPROPERTY() + /** Whether the identity has been set. This is used to prevent multiple identity sets. */ + bool bIsIdentitySet = false; + /** The connection id of the SpacetimeDB instance. This is used to identify the connection. */ + UPROPERTY() + FSpacetimeDBConnectionId ConnectionId; + UPROPERTY() bool bIsAutoTicking = false; - FTSTicker::FDelegateHandle TickerHandle; + + FSpacetimeDBInboundApplyBudget InboundApplyBudget; + struct FPendingTableBroadcast + { + TSharedPtr Handler; + int32 StatsIndex = INDEX_NONE; + }; + TArray TableUpdateHandlersScratch; + FSpacetimeDBInboundMessageApplyStats* ActiveInboundMessageApplyStats = nullptr; + /** Guard to avoid repeatedly handling the same fatal protocol error. */ FThreadSafeBool bProtocolViolationHandled = false; - - UPROPERTY() - FOnConnectErrorDelegate OnConnectErrorDelegate; - UPROPERTY() - FOnDisconnectBaseDelegate OnDisconnectBaseDelegate; - UPROPERTY() - FOnConnectBaseDelegate OnConnectBaseDelegate; - - /** Called when the connection is established. */ - template - void BroadcastDiff(TableClass* Table, const FTableAppliedDiff& Diff, const EventContext& Context) - { - if (!Table) return; - - // Broadcast the diff to the table's delegates - if (Table->OnInsert.IsBound()) - { - for (const TPair, RowType>& Pair : Diff.Inserts) - { - Table->OnInsert.Broadcast(Context, Pair.Value); - } - } - + + UPROPERTY() + FOnConnectErrorDelegate OnConnectErrorDelegate; + UPROPERTY() + FOnDisconnectBaseDelegate OnDisconnectBaseDelegate; + UPROPERTY() + FOnConnectBaseDelegate OnConnectBaseDelegate; + + /** Called when the connection is established. */ + template + void BroadcastDiff(TableClass* Table, const FTableAppliedDiff& Diff, const EventContext& Context) + { + if (!Table) return; + + // Broadcast the diff to the table's delegates + if (Table->OnInsert.IsBound()) + { + for (const TSharedPtr& Row : Diff.Inserts) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB insert diff row.")); + Table->OnInsert.Broadcast(Context, *Row); + } + } + // Event tables intentionally omit delete/update delegates. if constexpr (THasOnDeleteDelegate::value) { if (Table->OnDelete.IsBound()) { - for (const TPair, RowType>& Pair : Diff.Deletes) + for (const TSharedPtr& Row : Diff.Deletes) { - Table->OnDelete.Broadcast(Context, Pair.Value); + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB delete diff row.")); + Table->OnDelete.Broadcast(Context, *Row); } } } @@ -446,12 +889,13 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject { if (Table->OnUpdate.IsBound()) { - int32 Count = FMath::Min(Diff.UpdateDeletes.Num(), Diff.UpdateInserts.Num()); - for (int32 Index = 0; Index < Count; ++Index) + checkf(Diff.UpdateDeletes.Num() == Diff.UpdateInserts.Num(), TEXT("Mismatched SpacetimeDB update diff counts.")); + for (int32 Index = 0; Index < Diff.UpdateInserts.Num(); ++Index) { - const RowType& OldRow = Diff.UpdateDeletes[Index]; - const RowType& NewRow = Diff.UpdateInserts[Index]; - Table->OnUpdate.Broadcast(Context, OldRow, NewRow); + const TSharedPtr& OldRow = Diff.UpdateDeletes[Index]; + const TSharedPtr& NewRow = Diff.UpdateInserts[Index]; + checkf(OldRow.IsValid() && NewRow.IsValid(), TEXT("Invalid SpacetimeDB update diff row.")); + Table->OnUpdate.Broadcast(Context, *OldRow, *NewRow); } } } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h index b9e6f91378d..15d375f7f2b 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h @@ -29,11 +29,11 @@ DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketBinaryMessageReceived, c * Handles connecting, disconnecting, sending messages, and receiving messages. */ UCLASS(BlueprintType) -class SPACETIMEDBSDK_API UWebsocketManager : public UObject -{ - GENERATED_BODY() - -public: +class SPACETIMEDBSDK_API UWebsocketManager : public UObject +{ + GENERATED_BODY() + +public: UWebsocketManager(); virtual void BeginDestroy() override; @@ -66,14 +66,16 @@ class SPACETIMEDBSDK_API UWebsocketManager : public UObject /** * Checks if the WebSocket connection is currently active. * @return True if connected, false otherwise. - */ - bool IsConnected() const; - - /** - * Sets the initial auth token used when connecting. - * @param Token JWT or session token expected by the server. - */ - void SetInitToken(FString Token); + */ + bool IsConnected() const; + + void SetNativeBinaryMessageTarget(class UDbConnectionBase* Target); + + /** + * Sets the initial auth token used when connecting. + * @param Token JWT or session token expected by the server. + */ + void SetInitToken(FString Token); /** Delegates for WebSocket events */ UPROPERTY() @@ -105,16 +107,18 @@ class SPACETIMEDBSDK_API UWebsocketManager : public UObject void HandleConnectionError(const FString& Error); /** Handler for incoming text messages */ void HandleMessageReceived(const FString& Message); - /** Handler for incoming binary messages */ - void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); - /** Handler for socket close */ - void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - - FString InitToken; - - /** Buffer used to accumulate binary fragments until a complete message - * is received. */ - TArray IncompleteMessage; + /** Handler for incoming binary messages */ + void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); + void DispatchCompleteBinaryMessage(TArray&& Message); + /** Handler for socket close */ + void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + + FString InitToken; + TWeakObjectPtr NativeBinaryMessageTarget; + + /** Buffer used to accumulate binary fragments until a complete message + * is received. */ + TArray IncompleteMessage; /** Tracks if we are waiting for additional binary fragments. */ bool bAwaitingBinaryFragments = false; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h index 70693f4338c..1d90345eecb 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h @@ -1,7 +1,59 @@ #pragma once #include "CoreMinimal.h" +#include "BSATN/UESpacetimeDB.h" #include "TableCache.h" #include "TableAppliedDiff.h" +#include "WithBsatn.h" + +#include +#include + +namespace UE::SpacetimeDB +{ +enum class ETableCacheApplyMode : uint8 +{ + PersistentIndexed, + DirectNativeDiff +}; + +enum class ERuntimeBTreeIndexApplyMode : uint8 +{ + Apply, + Skip +}; + +template +struct TCompactPrimaryKeyTraits +{ + static constexpr bool bEnabled = false; + using KeyType = uint64; + + static KeyType GetKey(const RowType& Row) + { + (void)Row; + static_assert(bEnabled, "SpacetimeDB compact cache key trait is not generated for this row type."); + return 0; + } + + static const TCHAR* GetUniqueIndexName() + { + static_assert(bEnabled, "SpacetimeDB compact cache key trait is not generated for this row type."); + return TEXT(""); + } +}; + +template +struct TTableCachePolicy +{ + static constexpr ETableCacheApplyMode ApplyMode = ETableCacheApplyMode::PersistentIndexed; + + static ERuntimeBTreeIndexApplyMode GetRuntimeBTreeIndexApplyMode(const FString& IndexName) + { + (void)IndexName; + return ERuntimeBTreeIndexApplyMode::Apply; + } +}; +} /* ============================================================================ * * ClientCache.h (2025-05-28) @@ -20,6 +72,40 @@ class UClientCache */ TSharedPtr> Table; + void SetApplyMode(UE::SpacetimeDB::ETableCacheApplyMode InApplyMode) + { + ApplyMode = InApplyMode; + } + + UE::SpacetimeDB::ETableCacheApplyMode GetApplyMode() const + { + return ApplyMode; + } + + void SetRuntimeBTreeIndexApplyMode( + const FString& IndexName, + UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode InApplyMode) + { + checkf(!IndexName.IsEmpty(), TEXT("Cannot configure runtime B-Tree index policy for an empty index name.")); + switch (InApplyMode) + { + case UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Apply: + RuntimeBTreeIndexApplyModes.Add(IndexName, InApplyMode); + break; + case UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Skip: + RuntimeBTreeIndexApplyModes.Add(IndexName, InApplyMode); + break; + default: + checkf(false, TEXT("Unknown runtime B-Tree index apply mode for index '%s'."), *IndexName); + break; + } + } + + void ClearRuntimeBTreeIndexApplyModes() + { + RuntimeBTreeIndexApplyModes.Reset(); + } + /** * Retrieves the existing table cache or creates a new one if none exists. @@ -70,6 +156,261 @@ class UClientCache } + FTableAppliedDiff ApplyDiffByPrimaryKey( + const FString& Name, + TArray>&& Inserts, + TArray>&& Deletes, + const TCHAR* ExpectedUniqueIndexName) + { + using FCompactPrimaryKeyTraits = UE::SpacetimeDB::TCompactPrimaryKeyTraits; + static_assert(FCompactPrimaryKeyTraits::bEnabled, "ApplyDiffByPrimaryKey requires a generated compact primary-key trait."); + using KeyType = typename FCompactPrimaryKeyTraits::KeyType; + + checkf(!Name.IsEmpty(), TEXT("ApplyDiffByPrimaryKey called with empty table name.")); + checkf(Table.IsValid(), TEXT("ApplyDiffByPrimaryKey could not find table cache for %s."), *Name); + checkf(ExpectedUniqueIndexName != nullptr && ExpectedUniqueIndexName[0] != TEXT('\0'), + TEXT("ApplyDiffByPrimaryKey for %s requires a generated unique index name."), *Name); + const FString ExpectedIndexName(ExpectedUniqueIndexName); + checkf(Table->UniqueIndices.Contains(ExpectedIndexName), + TEXT("ApplyDiffByPrimaryKey for %s requires generated unique index %s."), + *Name, + ExpectedUniqueIndexName); + + if (ShouldApplyDirectNativeDiff()) + { + return BuildDirectDiffByPrimaryKey(Name, MoveTemp(Inserts), MoveTemp(Deletes)); + } + + struct FDeletedRow + { + TArray CacheKey; + TSharedPtr Row; + int32 PendingCount = 0; + bool bUpdateApplied = false; + }; + + auto BuildCacheKey = [](const KeyType& Key) + { + static constexpr int32 CompactPrimaryKeyBytes = sizeof(KeyType); + TArray CacheKey; + CacheKey.SetNumUninitialized(CompactPrimaryKeyBytes); + uint64 Remaining = Key; + for (int32 ByteIndex = 0; ByteIndex < CompactPrimaryKeyBytes; ++ByteIndex) + { + CacheKey[ByteIndex] = static_cast(Remaining & 0xffu); + Remaining >>= 8; + } + return CacheKey; + }; + + auto RemoveFromIndices = [this](const TArray& Key, const TSharedPtr& Row) + { + checkf(Row.IsValid(), TEXT("Cannot remove invalid row from table indices.")); + for (auto& IndexPair : Table->UniqueIndices) + { + IndexPair.Value->RemoveRow(Row); + } + for (auto& IndexPair : Table->BTreeIndices) + { + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) + { + continue; + } + IndexPair.Value->RemoveRow(Key, Row); + } + }; + + auto AddToIndices = [this](const TArray& Key, const TSharedPtr& Row) + { + checkf(Row.IsValid(), TEXT("Cannot add invalid row to table indices.")); + for (auto& IndexPair : Table->UniqueIndices) + { + IndexPair.Value->AddRow(Row); + } + for (auto& IndexPair : Table->BTreeIndices) + { + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) + { + continue; + } + IndexPair.Value->AddRow(Key, Row); + } + }; + + FTableAppliedDiff Diff; + Diff.bPrimaryKeyUpdatesClassified = true; + Diff.Inserts.Reserve(Inserts.Num()); + Diff.Deletes.Reserve(Deletes.Num()); + Diff.UpdateDeletes.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + Diff.UpdateInserts.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + + TMap DeletedRows; + DeletedRows.Reserve(Deletes.Num()); + + for (const FWithBsatn& Delete : Deletes) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Delete.Row); + const TArray CacheKey = BuildCacheKey(PrimaryKey); + FRowEntry* Entry = Table->Entries.Find(CacheKey); + if (!Entry) + { + continue; + } + + checkf(Entry->RefCount > 0, TEXT("Table cache row for %s has invalid refcount before primary-key delete."), *Name); + checkf(Entry->Row.IsValid(), TEXT("Table cache row for %s is invalid before primary-key delete."), *Name); + FDeletedRow& Deleted = DeletedRows.FindOrAdd(PrimaryKey); + if (!Deleted.Row.IsValid()) + { + Deleted.CacheKey = CacheKey; + Deleted.Row = Entry->Row; + } + checkf(Deleted.CacheKey == CacheKey, TEXT("Mismatched compact cache key for primary-key delete on %s."), *Name); + ++Deleted.PendingCount; + --Entry->RefCount; + checkf(Entry->RefCount >= 0, TEXT("Table cache row for %s has negative refcount after primary-key delete."), *Name); + } + + for (FWithBsatn& Insert : Inserts) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Insert.Row); + const TArray CacheKey = BuildCacheKey(PrimaryKey); + FDeletedRow* MatchingDelete = DeletedRows.Find(PrimaryKey); + if (MatchingDelete && MatchingDelete->PendingCount > 0) + { + checkf(MatchingDelete->CacheKey == CacheKey, + TEXT("Mismatched compact cache key for primary-key update on %s."), *Name); + FRowEntry* Entry = Table->Entries.Find(CacheKey); + checkf(Entry != nullptr, TEXT("Missing table cache row for primary-key update on %s."), *Name); + checkf(Entry->RefCount >= 0, TEXT("Table cache row for %s has invalid refcount before primary-key update insert."), *Name); + checkf(MatchingDelete->Row.IsValid(), TEXT("Invalid deleted row for primary-key update on %s."), *Name); + + ++Entry->RefCount; + if (!MatchingDelete->bUpdateApplied) + { + TSharedPtr OldRow = MatchingDelete->Row; + TSharedPtr NewRow = MakeShared(MoveTemp(Insert.Row)); + RemoveFromIndices(CacheKey, OldRow); + Entry->Row = NewRow; + AddToIndices(CacheKey, NewRow); + + Diff.UpdateDeletes.Add(OldRow); + Diff.UpdateInserts.Add(NewRow); + MatchingDelete->bUpdateApplied = true; + MatchingDelete->Row = NewRow; + } + --MatchingDelete->PendingCount; + continue; + } + + FRowEntry* ExistingEntry = Table->Entries.Find(CacheKey); + if (ExistingEntry) + { + checkf(ExistingEntry->RefCount > 0, + TEXT("Primary-key insert for %s found an existing row with invalid refcount."), *Name); + ++ExistingEntry->RefCount; + continue; + } + + TSharedPtr NewRow = MakeShared(MoveTemp(Insert.Row)); + Table->Entries.Add(CacheKey, FRowEntry{NewRow, 1}); + AddToIndices(CacheKey, NewRow); + Diff.Inserts.Add(NewRow); + } + + for (const TPair& DeletedPair : DeletedRows) + { + const FDeletedRow& Deleted = DeletedPair.Value; + if (Deleted.PendingCount <= 0) + { + continue; + } + + FRowEntry* Entry = Table->Entries.Find(Deleted.CacheKey); + if (!Entry || Entry->RefCount > 0) + { + continue; + } + + checkf(Deleted.Row.IsValid(), TEXT("Invalid deleted row for primary-key delete on %s."), *Name); + checkf(Entry->RefCount == 0, TEXT("Primary-key delete for %s reached impossible refcount state."), *Name); + RemoveFromIndices(Deleted.CacheKey, Deleted.Row); + Diff.Deletes.Add(Deleted.Row); + Table->Entries.Remove(Deleted.CacheKey); + } + + return Diff; + } + +private: + FTableAppliedDiff BuildDirectDiffByPrimaryKey( + const FString& Name, + TArray>&& Inserts, + TArray>&& Deletes) + { + using FCompactPrimaryKeyTraits = UE::SpacetimeDB::TCompactPrimaryKeyTraits; + static_assert(FCompactPrimaryKeyTraits::bEnabled, "BuildDirectDiffByPrimaryKey requires a generated compact primary-key trait."); + using KeyType = typename FCompactPrimaryKeyTraits::KeyType; + + checkf(!Name.IsEmpty(), TEXT("BuildDirectDiffByPrimaryKey called with empty table name.")); + + FTableAppliedDiff Diff; + Diff.bPrimaryKeyUpdatesClassified = true; + Diff.Inserts.Reserve(Inserts.Num()); + Diff.Deletes.Reserve(Deletes.Num()); + Diff.UpdateDeletes.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + Diff.UpdateInserts.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + + TMap>> DeletesByKey; + DeletesByKey.Reserve(Deletes.Num()); + for (FWithBsatn& Delete : Deletes) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Delete.Row); + TArray>& Rows = DeletesByKey.FindOrAdd(PrimaryKey); + Rows.Add(MakeShared(MoveTemp(Delete.Row))); + } + + for (FWithBsatn& Insert : Inserts) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Insert.Row); + if (TArray>* MatchingDeletes = DeletesByKey.Find(PrimaryKey)) + { + checkf(!MatchingDeletes->IsEmpty(), + TEXT("Direct compact diff for %s found an empty delete bucket."), + *Name); + TSharedPtr OldRow = MatchingDeletes->Pop(EAllowShrinking::No); + checkf(OldRow.IsValid(), + TEXT("Direct compact diff for %s found an invalid deleted row."), + *Name); + if (MatchingDeletes->IsEmpty()) + { + DeletesByKey.Remove(PrimaryKey); + } + + Diff.UpdateDeletes.Add(OldRow); + Diff.UpdateInserts.Add(MakeShared(MoveTemp(Insert.Row))); + continue; + } + + Diff.Inserts.Add(MakeShared(MoveTemp(Insert.Row))); + } + + for (TPair>>& DeletePair : DeletesByKey) + { + for (TSharedPtr& DeletedRow : DeletePair.Value) + { + checkf(DeletedRow.IsValid(), + TEXT("Direct compact diff for %s retained an invalid deleted row."), + *Name); + Diff.Deletes.Add(MoveTemp(DeletedRow)); + } + } + + return Diff; + } + +public: + /** * Apply Inserts + Deletes to the specified table. * Inserts: increment refCount, add new entry when needed. @@ -77,8 +418,8 @@ class UClientCache */ FTableAppliedDiff ApplyDiff( const FString& Name, - const TArray, RowType>>& Inserts, - const TArray>& Deletes) + TArray>&& Inserts, + TArray>&& Deletes) { if (Name.IsEmpty()) { @@ -91,96 +432,150 @@ class UClientCache UE_LOG(LogTemp, Error, TEXT("Failed to create or retrieve table: %s"), *Name); return FTableAppliedDiff(); } + checkf(ApplyMode != UE::SpacetimeDB::ETableCacheApplyMode::DirectNativeDiff, + TEXT("DirectNativeDiff for table %s requires a generated compact uint64 primary-key trait."), + *Name); - FTableAppliedDiff Diff; - - // Map of deleted SerializedBytes -> (Key, Row) - // The key type is now generic TArray - TMap, TPair, TSharedPtr>> DeletedEntries; - - // Phase 1: Pre-process Deletes - for (const TArray& Key : Deletes) + struct FDeletedRow { + TSharedPtr Row; + bool bMatchedInsert = false; + }; + struct FInsertedRow + { + TArray Key; + TSharedPtr Row; + }; - FRowEntry* Entry = Table->Entries.Find(Key); - if (!Entry) continue; - - // Decrement refcount and store the entry if it's about to be deleted - if (--Entry->RefCount == 0) + auto RemoveFromIndices = [this](const TArray& Key, const TSharedPtr& Row) + { + checkf(Row.IsValid(), TEXT("Cannot remove invalid row from table indices.")); + for (auto& IndexPair : Table->UniqueIndices) { - DeletedEntries.Add(Key, TPair, TSharedPtr>(Key, Entry->Row)); + IndexPair.Value->RemoveRow(Row); } - } + for (auto& IndexPair : Table->BTreeIndices) + { + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) + { + continue; + } + IndexPair.Value->RemoveRow(Key, Row); + } + }; - // Phase 2: Process Inserts and Updates - for (const auto& Ins : Inserts) + auto AddToIndices = [this](const TArray& Key, const TSharedPtr& Row) { - const TArray& Key = Ins.Key; - const RowType& Row = Ins.Value; - - TSharedPtr NewRow = MakeShared(Row); - - FRowEntry* Entry = Table->Entries.Find(Key); - if (!Entry) + checkf(Row.IsValid(), TEXT("Cannot add invalid row to table indices.")); + for (auto& IndexPair : Table->UniqueIndices) { - // True insert — these row-bytes are not cached yet. Either a - // genuinely new row, or the insert half of an update (the - // paired delete of different bytes is tracked in phase 1/3; - // DeriveUpdatesByPrimaryKey reconciles them by PK afterward). - Table->Entries.Add(Key, FRowEntry{NewRow, 1}); - Diff.Inserts.Add(Key, *NewRow); + IndexPair.Value->AddRow(Row); } - else + for (auto& IndexPair : Table->BTreeIndices) { - // Refcount bump — an overlapping subscription brought an - // identical row already in cache. Mirror the delete path - // (which only emits Diff.Deletes on refcount == 0) by not - // emitting a spurious Diff.Inserts entry here. - Table->Entries.Add(Key, FRowEntry{NewRow, Entry->RefCount + 1}); + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) + { + continue; + } + IndexPair.Value->AddRow(Key, Row); } - } + }; - // Phase 3: Finalize Deletes and Update Indices - for (const auto& KeyValue : DeletedEntries) - { - // Add to diff before removal - Diff.Deletes.Add(KeyValue.Key, *KeyValue.Value.Value); - Table->Entries.Remove(KeyValue.Key); - } + FTableAppliedDiff Diff; + Diff.Inserts.Reserve(Inserts.Num()); + Diff.Deletes.Reserve(Deletes.Num()); + Diff.UpdateDeletes.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + Diff.UpdateInserts.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + + TMap, FDeletedRow> DeletedRows; + DeletedRows.Reserve(Deletes.Num()); + TArray InsertedRows; + InsertedRows.Reserve(Inserts.Num()); - // Now, update all indices with the completed diff - for (const auto& DeletePair : Diff.Deletes) + for (const FWithBsatn& Delete : Deletes) { - for (auto& IndexPair : Table->UniqueIndices) + const TArray& Key = Delete.Bsatn; + FRowEntry* Entry = Table->Entries.Find(Key); + if (!Entry) + { + continue; + } + + checkf(Entry->RefCount > 0, TEXT("Table cache row for %s has invalid refcount before delete."), *Name); + checkf(Entry->Row.IsValid(), TEXT("Table cache row for %s is invalid before delete."), *Name); + FDeletedRow& Deleted = DeletedRows.FindOrAdd(Key); + if (!Deleted.Row.IsValid()) { - // Assuming RemoveRow takes the TSharedPtr directly - IndexPair.Value->RemoveRow(MakeShared(DeletePair.Value)); + Deleted.Row = Entry->Row; } + --Entry->RefCount; } - for (const auto& InsertPair : Diff.Inserts) + for (FWithBsatn& Insert : Inserts) { - for (auto& IndexPair : Table->UniqueIndices) + const TArray& Key = Insert.Bsatn; + FRowEntry* ExistingEntry = Table->Entries.Find(Key); + FDeletedRow* MatchingDelete = DeletedRows.Find(Key); + if (ExistingEntry) { - // Assuming AddRow takes TSharedPtr directly - IndexPair.Value->AddRow(MakeShared(InsertPair.Value)); + ++ExistingEntry->RefCount; + if (MatchingDelete) + { + MatchingDelete->bMatchedInsert = true; + } + continue; } + + TSharedPtr NewRow = MakeShared(MoveTemp(Insert.Row)); + Table->Entries.Add(Key, FRowEntry{NewRow, 1}); + InsertedRows.Add(FInsertedRow{Key, NewRow}); + Diff.Inserts.Add(NewRow); } - // And for BtreeIndices... - for (auto& Pair : Table->BTreeIndices) + for (const TPair, FDeletedRow>& DeletedPair : DeletedRows) { - for (const auto& DeletePair : Diff.Deletes) + const TArray& Key = DeletedPair.Key; + const FDeletedRow& Deleted = DeletedPair.Value; + if (Deleted.bMatchedInsert) { - Pair.Value->RemoveRow(DeletePair.Key, MakeShared(DeletePair.Value)); + continue; } - for (const auto& InsertPair : Diff.Inserts) + FRowEntry* Entry = Table->Entries.Find(Key); + if (!Entry || Entry->RefCount > 0) { - Pair.Value->AddRow(InsertPair.Key, MakeShared(InsertPair.Value)); + continue; } + + RemoveFromIndices(Key, Deleted.Row); + Diff.Deletes.Add(Deleted.Row); + Table->Entries.Remove(Key); + } + + for (const FInsertedRow& Inserted : InsertedRows) + { + AddToIndices(Inserted.Key, Inserted.Row); } return Diff; } +private: + bool ShouldApplyDirectNativeDiff() const + { + return ApplyMode == UE::SpacetimeDB::ETableCacheApplyMode::DirectNativeDiff; + } + + bool ShouldApplyRuntimeBTreeIndex(const FString& IndexName) const + { + checkf(!IndexName.IsEmpty(), TEXT("Cannot apply runtime B-Tree index policy for an empty index name.")); + if (const UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode* RuntimeMode = RuntimeBTreeIndexApplyModes.Find(IndexName)) + { + return *RuntimeMode == UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Apply; + } + return UE::SpacetimeDB::TTableCachePolicy::GetRuntimeBTreeIndexApplyMode(IndexName) + == UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Apply; + } + + UE::SpacetimeDB::ETableCacheApplyMode ApplyMode = UE::SpacetimeDB::TTableCachePolicy::ApplyMode; + TMap RuntimeBTreeIndexApplyModes; }; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md index 7ff0ed9ee3d..d1a23afdaaa 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md @@ -8,7 +8,7 @@ Utilities used to maintain a client side cache of database tables. These classe - `ClientCache.h` – Owns `FTableCache` objects and applies insert/delete diffs sent over the network. - `IUniqueIndex.h` – Interface that unique index implementations conform to. - `RowEntry.h` – Wrapper storing a row value with a reference count used by overlapping subscriptions. -- `TableAppliedDiff.h` – Describes the inserts, deletes and updates detected when applying a diff. +- `TableAppliedDiff.h` – Describes the inserts, deletes and updates detected when applying a diff. Direct C++ consumers receive `TSharedPtr` row arrays; generated dynamic delegates continue broadcasting value references. - `TableCache.h` – In-memory representation of a table and its unique indices. - `TableHandle.h` – Lightweight helper exposing read only access to a cached table. - `UniqueConstraintHandle.h` – Helper that allows typed lookups against a unique constraint. @@ -126,4 +126,4 @@ Table.GetValues(AllRows); - `TArray` keys allow serialized identifiers (network-friendly). - Adding indices after inserting rows is **not supported** without manual rebuild. -- B-Tree indices can later be extended for **range queries**.ed to a **single column** per call. \ No newline at end of file +- B-Tree indices can later be extended for **range queries** and are currently scoped to a **single column** per call. diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h index cc76786999e..640b50343a3 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h @@ -4,23 +4,25 @@ /* ============================================================================ * * TableAppliedDiff.h (2025-05-28) * ---------------------------------------------------------------------------- - * Captures the semantic result of applying a low‑level diff (inserts/deletes) - * to a table cache. Rows that transition from dead→live are inserts, live→dead - * are deletes, and a delete+insert with the same primary‑key is surfaced as an - * update pair. + * Captures the semantic result of applying a low-level diff (inserts/deletes) + * to a table cache. Rows that transition from dead to live are inserts, live + * to dead are deletes, and a delete+insert with the same primary key is + * surfaced as an update pair. + * + * This is the SDK's lower-copy diff representation. Direct C++ consumers read + * row payloads from TSharedPtr-backed arrays; generated dynamic delegates still + * broadcast value references for Blueprint and existing dynamic delegate code. * ============================================================================ */ template struct FTableAppliedDiff { - // SerializedKey -> Row copy. Keeping the rows by value ensures - // the memory stays valid even if the underlying table reallocates - // or removes entries while this diff is alive. - TMap, RowType> Deletes; - TMap, RowType> Inserts; + TArray> Deletes; + TArray> Inserts; - // Parallel arrays for (old, new) row update pairs. - TArray UpdateDeletes; - TArray UpdateInserts; + TArray> UpdateDeletes; + TArray> UpdateInserts; + + bool bPrimaryKeyUpdatesClassified = false; bool IsEmpty() const { @@ -36,38 +38,79 @@ struct FTableAppliedDiff template void DeriveUpdatesByPrimaryKey(TFunctionRef DerivePK) { - if (Deletes.IsEmpty()) return; + if (bPrimaryKeyUpdatesClassified) + { + checkf(UpdateDeletes.Num() == UpdateInserts.Num(), TEXT("Pre-classified primary-key update diff arrays are mismatched.")); + return; + } + if (Deletes.IsEmpty() || Inserts.IsEmpty()) return; - // Build PK->(key,row) map for deletes. - TMap, RowType>> DeletePK; - for (const auto& Pair : Deletes) + const int32 DeleteCount = Deletes.Num(); + const int32 InsertCount = Inserts.Num(); + TMap>> DeletePK; + DeletePK.Reserve(Deletes.Num()); + for (int32 DeleteIndex = DeleteCount - 1; DeleteIndex >= 0; --DeleteIndex) { - DeletePK.Add(DerivePK(Pair.Value), { Pair.Key, Pair.Value }); + const TSharedPtr& DeletedRow = Deletes[DeleteIndex]; + checkf(DeletedRow.IsValid(), TEXT("Invalid deleted row while deriving SpacetimeDB table updates.")); + const KeyType PK = DerivePK(*DeletedRow); + DeletePK.FindOrAdd(PK).Add(DeleteIndex); } - // Scan inserts for matching PKs. - TArray> DeleteKeys; - TArray> InsertKeys; - for (const auto& Pair : Inserts) + const int32 MaxUpdatePairs = FMath::Min(DeleteCount, InsertCount); + TArray MatchedDeletes; + TArray MatchedInserts; + MatchedDeletes.Init(0, DeleteCount); + MatchedInserts.Init(0, InsertCount); + UpdateDeletes.Reserve(UpdateDeletes.Num() + MaxUpdatePairs); + UpdateInserts.Reserve(UpdateInserts.Num() + MaxUpdatePairs); + int32 MatchedPairCount = 0; + for (int32 InsertIndex = 0; InsertIndex < InsertCount; ++InsertIndex) { - KeyType PK = DerivePK(Pair.Value); - if (const auto* Found = DeletePK.Find(PK)) + const TSharedPtr& InsertedRow = Inserts[InsertIndex]; + checkf(InsertedRow.IsValid(), TEXT("Invalid inserted row while deriving SpacetimeDB table updates.")); + KeyType PK = DerivePK(*InsertedRow); + if (TArray>* DeleteIndices = DeletePK.Find(PK)) { - UpdateDeletes.Add(Found->Value); - UpdateInserts.Add(Pair.Value); - DeleteKeys.Add(Found->Key); - InsertKeys.Add(Pair.Key); + checkf(!DeleteIndices->IsEmpty(), TEXT("Empty deleted row index list while deriving SpacetimeDB table updates.")); + const int32 DeleteIndex = DeleteIndices->Pop(EAllowShrinking::No); + checkf(Deletes.IsValidIndex(DeleteIndex), TEXT("Invalid deleted row index while deriving SpacetimeDB table updates.")); + UpdateDeletes.Add(Deletes[DeleteIndex]); + UpdateInserts.Add(InsertedRow); + MatchedDeletes[DeleteIndex] = 1; + MatchedInserts[InsertIndex] = 1; + if (DeleteIndices->IsEmpty()) + { + DeletePK.Remove(PK); + } + ++MatchedPairCount; } } - // Remove update pairs from base maps. - for (const auto& K : DeleteKeys) + if (MatchedPairCount == 0) + { + return; + } + + TArray> RemainingDeletes; + TArray> RemainingInserts; + RemainingDeletes.Reserve(DeleteCount - MatchedPairCount); + RemainingInserts.Reserve(InsertCount - MatchedPairCount); + for (int32 DeleteIndex = 0; DeleteIndex < DeleteCount; ++DeleteIndex) { - Deletes.Remove(K); + if (MatchedDeletes[DeleteIndex] == 0) + { + RemainingDeletes.Add(MoveTemp(Deletes[DeleteIndex])); + } } - for (const auto& K : InsertKeys) + for (int32 InsertIndex = 0; InsertIndex < InsertCount; ++InsertIndex) { - Inserts.Remove(K); + if (MatchedInserts[InsertIndex] == 0) + { + RemainingInserts.Add(MoveTemp(Inserts[InsertIndex])); + } } + Deletes = MoveTemp(RemainingDeletes); + Inserts = MoveTemp(RemainingInserts); } }; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h index e4c2a405459..353d88e59ec 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h @@ -10,8 +10,11 @@ struct FWithBsatn /** Deserialized row value */ RowType Row; - FWithBsatn() = default; - FWithBsatn(const TArray& InBsatn, const RowType& InRow) - : Bsatn(InBsatn), Row(InRow) { - } -}; \ No newline at end of file + FWithBsatn() = default; + FWithBsatn(const TArray& InBsatn, const RowType& InRow) + : Bsatn(InBsatn), Row(InRow) { + } + FWithBsatn(TArray&& InBsatn, RowType&& InRow) + : Bsatn(MoveTemp(InBsatn)), Row(MoveTemp(InRow)) { + } +}; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h index cb331255ec7..daa21cde936 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h @@ -29,8 +29,8 @@ class SPACETIMEDBSDK_API URemoteTable : public UObject */ template FTableAppliedDiff BaseUpdate( - const TArray>& InsertsRef, - const TArray>& DeletesRef, + TArray>& InsertsRef, + TArray>& DeletesRef, const TSharedPtr>& ClientCache, const FString& InTableName ) @@ -42,19 +42,19 @@ class SPACETIMEDBSDK_API URemoteTable : public UObject return {}; } - TArray, T>> Inserts; - for (const FWithBsatn& Insert : InsertsRef) + // Forward ownership of the worker-preprocessed row arrays to avoid rebuilding them on the game thread. + using FCompactPrimaryKeyTraits = UE::SpacetimeDB::TCompactPrimaryKeyTraits; + if constexpr (FCompactPrimaryKeyTraits::bEnabled) { - Inserts.Add({ Insert.Bsatn, Insert.Row }); + return ClientCache->ApplyDiffByPrimaryKey( + InTableName, + MoveTemp(InsertsRef), + MoveTemp(DeletesRef), + FCompactPrimaryKeyTraits::GetUniqueIndexName()); } - - TArray> Deletes; - for (const FWithBsatn& Delete : DeletesRef) + else { - Deletes.Add(Delete.Bsatn); + return ClientCache->ApplyDiff(InTableName, MoveTemp(InsertsRef), MoveTemp(DeletesRef)); } - - // Forward to the shared client cache implementation - return ClientCache->ApplyDiff(InTableName, Inserts, Deletes); } -}; \ No newline at end of file +};