From dcd8c046994a1302304da6f94333e0e1eee2b1ef Mon Sep 17 00:00:00 2001 From: Brougkr Date: Mon, 4 May 2026 22:19:18 -0400 Subject: [PATCH 1/2] feat(unreal): overhaul inbound SDK pipeline Replace the Unreal SDK inbound message path with a connection-owned worker thread, bounded raw and parsed queues, connection epoch guards, and deterministic lifecycle cleanup. This removes the previous fire-and-forget thread-pool preprocessing and reorder buffer, which could grow without backpressure and could allow stale async work to outlive a connection boundary. Move table preprocessing to message-scoped data, add zero-copy BSATN row parsing through DeserializeView, move-enabled FWithBsatn rows, and rvalue forwarding from remote tables into client caches. Cache apply now reuses shared row storage for diffs, validates structural invariants with checkf, and avoids spurious insert diffs on refcount bumps. Add compact uint64 primary-key apply support for generated FrameKey and BatchKey rows, native table listener bindings for reflection-free hot-path dispatch, multi-diff table broadcast ordering, and profiler scopes for inbound enqueue, worker preprocess, game-thread apply, cache apply, and broadcasts. Include the WebSocket native binary-message target and BSATN helper APIs required by the worker and zero-copy preprocessing path. The original plan listed only six files, but audit found those omitted files were compile dependencies for the copied implementation, so this commit keeps the PR source-complete rather than preserving a stale file count. Validation: ran git diff --check successfully; ran the required ./Scripts/full_rebuild.sh from /Users/brougkr/Documents/Unreal Projects/FACTIONS successfully, including SpacetimeDB publish, binding regeneration, and Unreal C++ build. --- .../Private/Connection/DbConnectionBase.cpp | 1415 +++++++++++++---- .../Connection/DbConnectionBuilder.cpp | 15 +- .../Private/Connection/Websocket.cpp | 118 +- .../Public/BSATN/UEBSATNHelpers.h | 128 +- .../Public/BSATN/UESpacetimeDB.h | 60 +- .../Public/Connection/DbConnectionBase.h | 1023 ++++++++---- .../Public/Connection/Websocket.h | 50 +- .../Public/DBCache/ClientCache.h | 522 +++++- .../Public/DBCache/TableAppliedDiff.h | 95 +- .../SpacetimeDbSdk/Public/DBCache/WithBsatn.h | 13 +- .../Public/Tables/RemoteTable.h | 26 +- 11 files changed, 2587 insertions(+), 878 deletions(-) diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp index 669f8079789..1607ca986c2 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp @@ -1,13 +1,16 @@ -#include "Connection/DbConnectionBase.h" +#include "Connection/DbConnectionBase.h" #include "Connection/DbConnectionBuilder.h" #include "Connection/Credentials.h" #include "Connection/LogCategory.h" -#include "Containers/Ticker.h" #include "ModuleBindings/Types/ClientMessageType.g.h" #include "ModuleBindings/Types/SubscriptionErrorType.g.h" #include "Misc/Compression.h" #include "Misc/ScopeLock.h" -#include "Async/Async.h" +#include "HAL/Event.h" +#include "HAL/PlatformProcess.h" +#include "HAL/Runnable.h" +#include "HAL/RunnableThread.h" +#include "ProfilingDebugging/CpuProfilerTrace.h" #include "BSATN/UEBSATNHelpers.h" #include "Connection/ProcedureFlags.h" @@ -20,7 +23,19 @@ enum class EWsCompressionTag : uint8 Gzip = 2, }; -static FDatabaseUpdateType QueryRowsToDatabaseUpdate(const FQueryRowsType& Rows, bool bAsDeletes) +constexpr int32 MaxQueuedInboundRawMessages = 8192; +constexpr int64 MaxQueuedInboundRawBytes = 128ll * 1024ll * 1024ll; +constexpr int32 MaxPendingInboundParsedMessages = 8192; +constexpr int64 MaxPendingInboundParsedPayloadBytes = 128ll * 1024ll * 1024ll; +constexpr int32 PendingInboundCompactionMinConsumedMessages = 512; +constexpr uint32 InboundWorkerStackSizeBytes = 0; +constexpr EThreadPriority InboundWorkerThreadPriority = TPri_Normal; +constexpr const TCHAR* InboundWorkerThreadName = TEXT("SpacetimeDBInboundWorker"); +constexpr int32 SpacetimeDbCompressionTagBytes = 1; +constexpr int32 GzipFooterUncompressedSizeBytes = 4; +constexpr int32 MaxInboundApplyLogTableContributors = 6; + +static FDatabaseUpdateType QueryRowsToDatabaseUpdate(const FQueryRowsType& Rows, UE::SpacetimeDB::EQueryRowsApplyMode Mode) { FDatabaseUpdateType Update; for (const FSingleTableRowsType& TableRows : Rows.Tables) @@ -29,13 +44,17 @@ static FDatabaseUpdateType QueryRowsToDatabaseUpdate(const FQueryRowsType& Rows, TableUpdate.TableName = TableRows.Table; FPersistentTableRowsType PersistentRows; - if (bAsDeletes) + switch (Mode) { + case UE::SpacetimeDB::EQueryRowsApplyMode::Deletes: PersistentRows.Deletes = TableRows.Rows; - } - else - { + break; + case UE::SpacetimeDB::EQueryRowsApplyMode::Inserts: PersistentRows.Inserts = TableRows.Rows; + break; + default: + checkf(false, TEXT("Unsupported query-row apply mode for table %s"), *TableRows.Table); + continue; } TableUpdate.Rows.Add(FTableUpdateRowsType::PersistentTable(PersistentRows)); Update.Tables.Add(TableUpdate); @@ -64,99 +83,195 @@ static FString DecodeReducerErrorMessage(const TArray& ErrorBytes) } return UE::SpacetimeDB::Deserialize(ErrorBytes); } + +static const TCHAR* DescribeServerMessageTag(EServerMessageTag Tag) +{ + switch (Tag) + { + case EServerMessageTag::InitialConnection: + return TEXT("InitialConnection"); + case EServerMessageTag::TransactionUpdate: + return TEXT("TransactionUpdate"); + case EServerMessageTag::OneOffQueryResult: + return TEXT("OneOffQueryResult"); + case EServerMessageTag::SubscribeApplied: + return TEXT("SubscribeApplied"); + case EServerMessageTag::UnsubscribeApplied: + return TEXT("UnsubscribeApplied"); + case EServerMessageTag::SubscriptionError: + return TEXT("SubscriptionError"); + case EServerMessageTag::ReducerResult: + return TEXT("ReducerResult"); + case EServerMessageTag::ProcedureResult: + return TEXT("ProcedureResult"); + default: + return TEXT("Unknown"); + } } - -UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) - : Super(ObjectInitializer) + +static FString FormatInboundTableApplyStats(const FSpacetimeDBTableApplyStats& Stats) { - NextRequestId = 1; - NextSubscriptionId = 1; - ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); + return FString::Printf( + TEXT("%s rows=%d ins=%d del=%d bytes=%lld cache=%.2fus broadcast=%.2fus diff=%d"), + *Stats.TableName, + Stats.RowSetCount, + Stats.InsertRowCount, + Stats.DeleteRowCount, + Stats.InsertRowBytes + Stats.DeleteRowBytes, + Stats.CacheMicros, + Stats.BroadcastMicros, + Stats.bProducedDiff ? 1 : 0); +} } -void UDbConnectionBase::SetAutoTicking(bool bAutoTick) +class FSpacetimeDbInboundWorker final : public FRunnable { - if (bIsAutoTicking == bAutoTick) +public: + explicit FSpacetimeDbInboundWorker(UDbConnectionBase& InConnection) + : Connection(&InConnection) { - return; + WorkAvailableEvent = FPlatformProcess::GetSynchEventFromPool(false); + checkf(WorkAvailableEvent != nullptr, TEXT("Failed to allocate SpacetimeDB inbound worker event.")); + + Thread = FRunnableThread::Create( + this, + InboundWorkerThreadName, + InboundWorkerStackSizeBytes, + InboundWorkerThreadPriority); + checkf(Thread != nullptr, TEXT("Failed to create SpacetimeDB inbound worker thread.")); } - bIsAutoTicking = bAutoTick; + virtual ~FSpacetimeDbInboundWorker() override + { + StopAndJoin(); + } - if (bIsAutoTicking) + void Notify() { - if (!TickerHandle.IsValid()) + if (WorkAvailableEvent) { - TickerHandle = FTSTicker::GetCoreTicker().AddTicker(FTickerDelegate::CreateUObject(this, &UDbConnectionBase::OnTickerTick)); + WorkAvailableEvent->Trigger(); } } - else if (TickerHandle.IsValid()) + + virtual uint32 Run() override { - FTSTicker::GetCoreTicker().RemoveTicker(TickerHandle); - TickerHandle.Reset(); + while (!bStopRequested) + { + checkf(WorkAvailableEvent != nullptr, TEXT("SpacetimeDB inbound worker event was not initialized.")); + WorkAvailableEvent->Wait(); + + if (bStopRequested) + { + break; + } + + if (Connection) + { + Connection->DrainInboundRawMessagesOnWorker(); + } + } + + return 0; } + + virtual void Stop() override + { + bStopRequested = true; + Notify(); + } + + void StopAndJoin() + { + Stop(); + + if (Thread) + { + Thread->WaitForCompletion(); + delete Thread; + Thread = nullptr; + } + + if (WorkAvailableEvent) + { + FPlatformProcess::ReturnSynchEventToPool(WorkAvailableEvent); + WorkAvailableEvent = nullptr; + } + + Connection = nullptr; + } + +private: + UDbConnectionBase* Connection = nullptr; + FEvent* WorkAvailableEvent = nullptr; + FRunnableThread* Thread = nullptr; + FThreadSafeBool bStopRequested = false; +}; + +UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) + : Super(ObjectInitializer) +{ + NextRequestId = 1; + NextSubscriptionId = 1; + ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); } void UDbConnectionBase::BeginDestroy() { - if (TickerHandle.IsValid()) + StopInboundMessageWorker(); + Super::BeginDestroy(); +} + +void UDbConnectionBase::Disconnect() +{ + StopInboundMessageWorker(); + if (WebSocket) { - FTSTicker::GetCoreTicker().RemoveTicker(TickerHandle); - TickerHandle.Reset(); + WebSocket->Disconnect(); } - bIsAutoTicking = false; +} - Super::BeginDestroy(); +bool UDbConnectionBase::IsActive() const +{ + return WebSocket && WebSocket->IsConnected(); +} + + +bool UDbConnectionBase::TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const +{ + if (bIsIdentitySet) + { + OutIdentity = Identity; + return true; + } + + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("TryGetIdentity called before identity was set")); + return false; } - -void UDbConnectionBase::Disconnect() -{ - if (WebSocket) - { - WebSocket->Disconnect(); - } -} - -bool UDbConnectionBase::IsActive() const -{ - return WebSocket && WebSocket->IsConnected(); -} - - -bool UDbConnectionBase::TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const -{ - if (bIsIdentitySet) - { - OutIdentity = Identity; - return true; - } - - UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("TryGetIdentity called before identity was set")); - return false; -} - -FSpacetimeDBConnectionId UDbConnectionBase::GetConnectionId() const -{ - return ConnectionId; -} - -bool UDbConnectionBase::SendRawMessage(const FString& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} - -bool UDbConnectionBase::SendRawMessage(const TArray& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} - -USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() -{ - return NewObject(); -} - + +FSpacetimeDBConnectionId UDbConnectionBase::GetConnectionId() const +{ + return ConnectionId; +} + +bool UDbConnectionBase::SendRawMessage(const FString& Message) +{ + return WebSocket && WebSocket->SendMessage(Message); +} + +bool UDbConnectionBase::SendRawMessage(const TArray& Message) +{ + return WebSocket && WebSocket->SendMessage(Message); +} + +USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() +{ + return NewObject(); +} + void UDbConnectionBase::HandleWSError(const FString& Error) { + StopInboundMessageWorker(); bProtocolViolationHandled = false; ClearPendingOperations(Error); if (OnConnectErrorDelegate.IsBound()) @@ -167,6 +282,7 @@ void UDbConnectionBase::HandleWSError(const FString& Error) void UDbConnectionBase::HandleWSClosed(int32 /*StatusCode*/, const FString& Reason, bool /*bWasClean*/) { + StopInboundMessageWorker(); bProtocolViolationHandled = false; ClearPendingOperations(Reason); if (OnDisconnectBaseDelegate.IsBound()) @@ -185,6 +301,7 @@ void UDbConnectionBase::HandleProtocolViolation(const FString& ErrorMessage) UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("%s"), *ErrorMessage); TriggerError(ErrorMessage); + StopInboundMessageWorker(); ClearPendingOperations(ErrorMessage); // Match Rust/C# behavior: parse/protocol violations are fatal for the connection. @@ -197,103 +314,605 @@ void UDbConnectionBase::HandleProtocolViolation(const FString& ErrorMessage) OnConnectErrorDelegate.Execute(ErrorMessage); } } - -void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) -{ - //tag for arrival order - const int32 Id = NextPreprocessId.GetValue(); - NextPreprocessId.Increment(); - - //do expensive work off-thread - TWeakObjectPtr WeakThis(this); - Async(EAsyncExecution::Thread, [WeakThis, Message, Id]() - { - if (!WeakThis.IsValid()) - { - return; - } - UDbConnectionBase* This = WeakThis.Get(); - - //parse the message, decompress if needed - FServerMessageType Parsed; - if (!This->PreProcessMessage(Message, Parsed)) - { - AsyncTask(ENamedThreads::GameThread, [WeakThis]() + +void UDbConnectionBase::StartInboundMessageWorker() +{ + FScopeLock Lock(&InboundWorkerMutex); + if (InboundWorker) + { + return; + } + + { + FScopeLock RawLock(&InboundRawMessagesMutex); + InboundRawMessages.Reset(); + InboundQueuedRawBytes = 0; + ++InboundConnectionEpoch; + NextInboundSequenceId = 0; + bInboundAcceptingMessages = true; + bInboundProtocolErrorQueued = false; + } + + { + FScopeLock PendingLock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedPayloadBytes = 0; + } + + ActivePreprocessedTableData = nullptr; + InboundWorker = new FSpacetimeDbInboundWorker(*this); +} + +void UDbConnectionBase::StopInboundMessageWorker() +{ + FSpacetimeDbInboundWorker* WorkerToStop = nullptr; + { + FScopeLock Lock(&InboundWorkerMutex); + { + FScopeLock RawLock(&InboundRawMessagesMutex); + InboundRawMessages.Reset(); + InboundQueuedRawBytes = 0; + ++InboundConnectionEpoch; + bInboundAcceptingMessages = false; + bInboundProtocolErrorQueued = false; + } + + { + FScopeLock PendingLock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedPayloadBytes = 0; + } + + ActivePreprocessedTableData = nullptr; + WorkerToStop = InboundWorker; + InboundWorker = nullptr; + } + + if (WorkerToStop) + { + delete WorkerToStop; + } + + ClearInboundMessageQueues(); +} + +void UDbConnectionBase::ClearInboundMessageQueues() +{ + { + FScopeLock Lock(&InboundRawMessagesMutex); + InboundRawMessages.Reset(); + InboundQueuedRawBytes = 0; + } + + { + FScopeLock Lock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedPayloadBytes = 0; + } + + ActivePreprocessedTableData = nullptr; +} + +void UDbConnectionBase::NotifyInboundWorkerIfNeeded() +{ + FScopeLock WorkerLock(&InboundWorkerMutex); + if (InboundWorker == nullptr) + { + return; + } + + bool bShouldNotify = false; + { + FScopeLock RawLock(&InboundRawMessagesMutex); + bShouldNotify = InboundRawMessages.Num() > 0 && bInboundAcceptingMessages && !bInboundProtocolErrorQueued; + } + + if (bShouldNotify) + { + InboundWorker->Notify(); + } +} + +bool UDbConnectionBase::IsInboundProtocolErrorQueued() const +{ + FScopeLock Lock(&InboundRawMessagesMutex); + return bInboundProtocolErrorQueued; +} + +bool UDbConnectionBase::IsInboundEpochCurrentAndAccepting(uint64 ConnectionEpoch) const +{ + FScopeLock Lock(&InboundRawMessagesMutex); + return bInboundAcceptingMessages && !bInboundProtocolErrorQueued && InboundConnectionEpoch == ConnectionEpoch; +} + +void UDbConnectionBase::MarkInboundProtocolErrorQueued() +{ + FScopeLock Lock(&InboundRawMessagesMutex); + bInboundProtocolErrorQueued = true; + bInboundAcceptingMessages = false; + InboundRawMessages.Reset(); + InboundQueuedRawBytes = 0; +} + +void UDbConnectionBase::EnqueueInboundProtocolError(uint64 SequenceId, int32 PayloadSizeBytes, uint8 CompressionTag, int32 QueueDepthAtEnqueue, int64 QueuedBytesAtEnqueue, const FString& ErrorMessage) +{ + UE_LOG( + LogSpacetimeDb_Connection, + Error, + TEXT("SpacetimeDB inbound protocol error: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld detail=%s"), + SequenceId, + PayloadSizeBytes, + static_cast(CompressionTag), + QueueDepthAtEnqueue, + QueuedBytesAtEnqueue, + *ErrorMessage); + + FInboundParsedMessage Parsed; + Parsed.SequenceId = SequenceId; + Parsed.PayloadSizeBytes = PayloadSizeBytes; + Parsed.CompressionTag = CompressionTag; + Parsed.QueueDepthAtEnqueue = QueueDepthAtEnqueue; + Parsed.QueuedBytesAtEnqueue = QueuedBytesAtEnqueue; + Parsed.bProtocolError = true; + Parsed.ProtocolError = ErrorMessage; + + FScopeLock Lock(&PendingMessagesMutex); + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedPayloadBytes = 0; + PendingMessages.Add(MoveTemp(Parsed)); + PendingParsedPayloadBytes += static_cast(PayloadSizeBytes); +} + +void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) +{ + TArray OwnedMessage = Message; + HandleWSBinaryMessageOwned(MoveTemp(OwnedMessage)); +} + +void UDbConnectionBase::HandleWSBinaryMessageOwned(TArray&& Message) +{ + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_InboundEnqueue); + + const int32 PayloadSizeBytes = Message.Num(); + const uint8 CompressionTag = PayloadSizeBytes > 0 ? Message[0] : 0; + uint64 SequenceId = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + uint64 ConnectionEpoch = 0; + bool bQueueOverloaded = false; + FString QueueOverloadError; + + { + FScopeLock WorkerLock(&InboundWorkerMutex); + FScopeLock RawLock(&InboundRawMessagesMutex); + if (!bInboundAcceptingMessages || bInboundProtocolErrorQueued) + { + return; + } + checkf(InboundWorker != nullptr, TEXT("SpacetimeDB inbound worker missing while inbound connection epoch %llu is accepting messages."), InboundConnectionEpoch); + + ConnectionEpoch = InboundConnectionEpoch; + SequenceId = NextInboundSequenceId++; + const int64 NewQueuedRawBytes = InboundQueuedRawBytes + static_cast(PayloadSizeBytes); + const int32 NewQueuedRawMessageCount = InboundRawMessages.Num() + 1; + QueueDepthAtEnqueue = NewQueuedRawMessageCount; + QueuedBytesAtEnqueue = NewQueuedRawBytes; + if (NewQueuedRawMessageCount > MaxQueuedInboundRawMessages || NewQueuedRawBytes > MaxQueuedInboundRawBytes) + { + bInboundProtocolErrorQueued = true; + bInboundAcceptingMessages = false; + InboundRawMessages.Reset(); + InboundQueuedRawBytes = 0; + bQueueOverloaded = true; + QueueOverloadError = FString::Printf( + TEXT("SpacetimeDB inbound queue overload: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld max_messages=%d max_bytes=%lld"), + SequenceId, + PayloadSizeBytes, + static_cast(CompressionTag), + NewQueuedRawMessageCount, + NewQueuedRawBytes, + MaxQueuedInboundRawMessages, + MaxQueuedInboundRawBytes); + } + else + { + FInboundRawMessage RawMessage; + RawMessage.ConnectionEpoch = ConnectionEpoch; + RawMessage.SequenceId = SequenceId; + RawMessage.QueueDepthAtEnqueue = QueueDepthAtEnqueue; + RawMessage.QueuedBytesAtEnqueue = QueuedBytesAtEnqueue; + RawMessage.Payload = MoveTemp(Message); + InboundRawMessages.Add(MoveTemp(RawMessage)); + InboundQueuedRawBytes = NewQueuedRawBytes; + InboundWorker->Notify(); + } + } + + if (bQueueOverloaded) + { + EnqueueInboundProtocolError(SequenceId, PayloadSizeBytes, CompressionTag, QueueDepthAtEnqueue, QueuedBytesAtEnqueue, QueueOverloadError); + return; + } +} + +void UDbConnectionBase::FrameTick() +{ + int32 MessagesProcessed = 0; + int64 PayloadBytesProcessed = 0; + const uint64 FrameStartCycles = FPlatformTime::Cycles64(); + const bool bDrainAllPendingMessages = InboundApplyBudget.bDrainAllPendingMessages; + { + FScopeLock Lock(&PendingMessagesMutex); + const int32 PendingCount = PendingMessages.Num() - PendingMessageReadIndex; + if (PendingCount <= 0) + { + //nothing to process, return early + return; + } + + } + + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_GameThreadApplyInbound); + + while (true) + { + FInboundParsedMessage Msg; + { + FScopeLock Lock(&PendingMessagesMutex); + if (PendingMessageReadIndex >= PendingMessages.Num()) + { + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedPayloadBytes = 0; + break; + } + + if (!bDrainAllPendingMessages && MessagesProcessed >= InboundApplyBudget.MaxMessagesPerFrame) { - if (!WeakThis.IsValid()) + break; + } + + FInboundParsedMessage& PendingMessage = PendingMessages[PendingMessageReadIndex]; + const int64 PendingPayloadBytes = static_cast(PendingMessage.PayloadSizeBytes); + if (!bDrainAllPendingMessages && MessagesProcessed > 0 && PayloadBytesProcessed + PendingPayloadBytes > InboundApplyBudget.MaxPayloadBytesPerFrame) + { + break; + } + + PayloadBytesProcessed += PendingPayloadBytes; + PendingParsedPayloadBytes = FMath::Max(0, PendingParsedPayloadBytes - PendingPayloadBytes); + Msg = MoveTemp(PendingMessage); + ++PendingMessageReadIndex; + + if (PendingMessageReadIndex == PendingMessages.Num()) + { + PendingMessages.Reset(); + PendingMessageReadIndex = 0; + PendingParsedPayloadBytes = 0; + } + else if (PendingMessageReadIndex >= PendingInboundCompactionMinConsumedMessages) + { + PendingMessages.RemoveAt(0, PendingMessageReadIndex, EAllowShrinking::No); + PendingMessageReadIndex = 0; + } + } + + if (Msg.bProtocolError) + { + HandleProtocolViolation(Msg.ProtocolError); + break; + } + + const uint64 MessageStartCycles = FPlatformTime::Cycles64(); + FSpacetimeDBInboundMessageApplyStats ApplyStats; + ApplyStats.MessageKind = DescribeServerMessageTag(Msg.Message.Tag); + ApplyStats.SequenceId = Msg.SequenceId; + ApplyStats.PayloadSizeBytes = Msg.PayloadSizeBytes; + ApplyStats.QueueDepthAtEnqueue = Msg.QueueDepthAtEnqueue; + ApplyStats.QueuedBytesAtEnqueue = Msg.QueuedBytesAtEnqueue; + if (Msg.Message.Tag == EServerMessageTag::ReducerResult) + { + const FReducerResultType& Payload = Msg.Message.MessageData.Get(); + ApplyStats.RequestId = Payload.RequestId; + if (const FReducerCallInfoType* FoundReducerCall = PendingReducerCalls.Find(Payload.RequestId)) + { + ApplyStats.ReducerName = FoundReducerCall->ReducerName; + } + } + + ProcessInboundServerMessage(Msg, ApplyStats); + const double MessageElapsedMicros = + FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - MessageStartCycles) * 1000.0; + if (!bDrainAllPendingMessages && + InboundApplyBudget.SoftTimeBudgetMicros > 0 && + MessageElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + TArray SortedStats = ApplyStats.TableStats; + SortedStats.Sort([](const FSpacetimeDBTableApplyStats& A, const FSpacetimeDBTableApplyStats& B) + { + return (A.CacheMicros + A.BroadcastMicros) > (B.CacheMicros + B.BroadcastMicros); + }); + TArray TopTableSummaries; + const int32 LoggedTableCount = FMath::Min(SortedStats.Num(), MaxInboundApplyLogTableContributors); + TopTableSummaries.Reserve(LoggedTableCount); + for (int32 TableIndex = 0; TableIndex < LoggedTableCount; ++TableIndex) + { + TopTableSummaries.Add(FormatInboundTableApplyStats(SortedStats[TableIndex])); + } + const FString TopTablesText = TopTableSummaries.IsEmpty() + ? TEXT("") + : FString::Join(TopTableSummaries, TEXT(" | ")); + UE_LOG(LogSpacetimeDb_Connection, + Warning, + TEXT("SpacetimeDB inbound single-message apply exceeded soft budget: %.2fus >= %lldus kind=%s sequence=%llu request_id=%u reducer=%s payload_bytes=%d queued_messages=%d queued_bytes=%lld messages_processed_before=%d tables=%d top_tables=%s"), + MessageElapsedMicros, + InboundApplyBudget.SoftTimeBudgetMicros, + *ApplyStats.MessageKind, + ApplyStats.SequenceId, + ApplyStats.RequestId, + ApplyStats.ReducerName.IsEmpty() ? TEXT("") : *ApplyStats.ReducerName, + ApplyStats.PayloadSizeBytes, + ApplyStats.QueueDepthAtEnqueue, + ApplyStats.QueuedBytesAtEnqueue, + MessagesProcessed, + ApplyStats.TableStats.Num(), + *TopTablesText); + } + ++MessagesProcessed; + + if (!bDrainAllPendingMessages && + MessagesProcessed >= InboundApplyBudget.MinMessagesPerFrame && + InboundApplyBudget.SoftTimeBudgetMicros > 0) + { + const double ElapsedMicros = FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - FrameStartCycles) * 1000.0; + if (ElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + break; + } + } + } + + if (MessagesProcessed > 0) + { + NotifyInboundWorkerIfNeeded(); + } +} + +void UDbConnectionBase::DrainInboundRawMessagesOnWorker() +{ + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_InboundWorkerDrain); + + while (!IsInboundProtocolErrorQueued()) + { + int32 ParsedMessageCapacity = 0; + int64 ParsedPayloadByteCapacity = 0; + { + FScopeLock Lock(&PendingMessagesMutex); + const int32 LivePendingMessages = PendingMessages.Num() - PendingMessageReadIndex; + ParsedMessageCapacity = MaxPendingInboundParsedMessages - LivePendingMessages; + ParsedPayloadByteCapacity = MaxPendingInboundParsedPayloadBytes - PendingParsedPayloadBytes; + } + + if (ParsedMessageCapacity <= 0 || ParsedPayloadByteCapacity <= 0) + { + return; + } + + TArray LocalRawMessages; + int64 DrainedRawBytes = 0; + { + FScopeLock Lock(&InboundRawMessagesMutex); + if (InboundRawMessages.Num() == 0 || !bInboundAcceptingMessages || bInboundProtocolErrorQueued) + { + return; + } + + int32 DrainCount = 0; + for (; DrainCount < InboundRawMessages.Num() && DrainCount < ParsedMessageCapacity; ++DrainCount) + { + const int64 NextPayloadBytes = static_cast(InboundRawMessages[DrainCount].Payload.Num()); + if (DrainCount > 0 && DrainedRawBytes + NextPayloadBytes > ParsedPayloadByteCapacity) + { + break; + } + if (DrainCount == 0 && NextPayloadBytes > ParsedPayloadByteCapacity) { return; } - UDbConnectionBase* Conn = WeakThis.Get(); - Conn->HandleProtocolViolation(TEXT("Failed to parse/decompress incoming WebSocket message")); - }); + + DrainedRawBytes += NextPayloadBytes; + } + + if (DrainCount == 0) + { + return; + } + + LocalRawMessages.Reserve(DrainCount); + for (int32 Index = 0; Index < DrainCount; ++Index) + { + LocalRawMessages.Add(MoveTemp(InboundRawMessages[Index])); + } + + InboundRawMessages.RemoveAt(0, DrainCount, EAllowShrinking::No); + InboundQueuedRawBytes = FMath::Max(0, InboundQueuedRawBytes - DrainedRawBytes); + } + + TArray LocalParsedMessages; + LocalParsedMessages.Reserve(LocalRawMessages.Num()); + + for (FInboundRawMessage& RawMessage : LocalRawMessages) + { + if (!IsInboundEpochCurrentAndAccepting(RawMessage.ConnectionEpoch)) + { + return; + } + + FInboundParsedMessage ParsedMessage; + if (!BuildInboundParsedMessage(RawMessage, ParsedMessage)) + { + if (!IsInboundEpochCurrentAndAccepting(RawMessage.ConnectionEpoch)) + { + return; + } + MarkInboundProtocolErrorQueued(); + LocalParsedMessages.Add(MoveTemp(ParsedMessage)); + break; + } + + if (!IsInboundEpochCurrentAndAccepting(RawMessage.ConnectionEpoch)) + { + return; + } + LocalParsedMessages.Add(MoveTemp(ParsedMessage)); + } + + if (LocalParsedMessages.Num() == 0) + { + continue; + } + + const bool bBatchEndsWithProtocolError = LocalParsedMessages.Last().bProtocolError; + if (IsInboundProtocolErrorQueued() && !bBatchEndsWithProtocolError) + { return; } - - //queue: re-order buffer - TArray Ready; - { - FScopeLock Lock(&This->PreprocessMutex); - // Move the parsed message into the map to avoid copying - This->PreprocessedMessages.Add(Id, MoveTemp(Parsed)); - //check if we can release any messages in order - while (This->PreprocessedMessages.Contains(This->NextReleaseId)) - { - Ready.Add(This->PreprocessedMessages.FindAndRemoveChecked(This->NextReleaseId)); - ++This->NextReleaseId; - } - } - //if we have any ready messages, append them to the pending messages list that is processed in Tick - if (Ready.Num() > 0) - { - FScopeLock Lock(&This->PendingMessagesMutex); - This->PendingMessages.Append(MoveTemp(Ready)); - } - }); -} - -void UDbConnectionBase::FrameTick() -{ - TArray Local; - { - FScopeLock Lock(&PendingMessagesMutex); - if (PendingMessages.Num() == 0) - { - //nothing to process, return early - return; - } - //move pending messages to local array for processing - Local = MoveTemp(PendingMessages); - PendingMessages.Empty(); - } - - //process all messages in the local array - for (const FServerMessageType& Msg : Local) - { - //process the message, this will call DbUpdate or trigger subscription events as needed - ProcessServerMessage(Msg); + if (!bBatchEndsWithProtocolError && !IsInboundEpochCurrentAndAccepting(LocalParsedMessages[0].ConnectionEpoch)) + { + return; + } + + uint64 OverloadSequenceId = LocalParsedMessages[0].SequenceId; + int32 OverloadPayloadSizeBytes = LocalParsedMessages[0].PayloadSizeBytes; + uint8 OverloadCompressionTag = LocalParsedMessages[0].CompressionTag; + { + FScopeLock Lock(&PendingMessagesMutex); + int64 AddedPayloadBytes = 0; + for (const FInboundParsedMessage& ParsedMessage : LocalParsedMessages) + { + AddedPayloadBytes += static_cast(ParsedMessage.PayloadSizeBytes); + } + + const int32 LivePendingMessages = PendingMessages.Num() - PendingMessageReadIndex; + const int32 NewPendingMessageCount = LivePendingMessages + LocalParsedMessages.Num(); + const int64 NewPendingPayloadBytes = PendingParsedPayloadBytes + AddedPayloadBytes; + checkf(bBatchEndsWithProtocolError || + (NewPendingMessageCount <= MaxPendingInboundParsedMessages && + NewPendingPayloadBytes <= MaxPendingInboundParsedPayloadBytes), + TEXT("SpacetimeDB parsed inbound queue overflow despite worker backpressure: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld max_messages=%d max_bytes=%lld"), + OverloadSequenceId, + OverloadPayloadSizeBytes, + static_cast(OverloadCompressionTag), + NewPendingMessageCount, + NewPendingPayloadBytes, + MaxPendingInboundParsedMessages, + MaxPendingInboundParsedPayloadBytes); + + PendingMessages.Append(MoveTemp(LocalParsedMessages)); + PendingParsedPayloadBytes = NewPendingPayloadBytes; + } } } -bool UDbConnectionBase::OnTickerTick(float DeltaTime) +bool UDbConnectionBase::BuildInboundParsedMessage(const FInboundRawMessage& RawMessage, FInboundParsedMessage& OutMessage) { - if (HasAnyFlags(RF_BeginDestroyed | RF_FinishDestroyed) || !bIsAutoTicking) + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_InboundPreprocess); + + OutMessage.ConnectionEpoch = RawMessage.ConnectionEpoch; + OutMessage.SequenceId = RawMessage.SequenceId; + OutMessage.PayloadSizeBytes = RawMessage.Payload.Num(); + OutMessage.CompressionTag = RawMessage.Payload.Num() > 0 ? RawMessage.Payload[0] : 0; + OutMessage.QueueDepthAtEnqueue = RawMessage.QueueDepthAtEnqueue; + OutMessage.QueuedBytesAtEnqueue = RawMessage.QueuedBytesAtEnqueue; + + if (!PreProcessMessage(RawMessage.Payload, OutMessage)) { + OutMessage.bProtocolError = true; + OutMessage.ProtocolError = FString::Printf( + TEXT("Failed to parse/decompress incoming WebSocket message: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld"), + OutMessage.SequenceId, + OutMessage.PayloadSizeBytes, + static_cast(OutMessage.CompressionTag), + OutMessage.QueueDepthAtEnqueue, + OutMessage.QueuedBytesAtEnqueue); return false; } - FrameTick(); return true; } - - + +void UDbConnectionBase::Tick(float DeltaTime) +{ + if (bIsAutoTicking) + { + FrameTick(); + } +} + +TStatId UDbConnectionBase::GetStatId() const +{ + // This is used by the engine to track tickables, we return a unique stat ID for this class + RETURN_QUICK_DECLARE_CYCLE_STAT(UMyTickableObject, STATGROUP_Tickables); +} + +bool UDbConnectionBase::IsTickable() const +{ + return bIsAutoTicking; +} + +bool UDbConnectionBase::IsTickableInEditor() const +{ + return bIsAutoTicking; +} + + +void UDbConnectionBase::ProcessInboundServerMessage(FInboundParsedMessage& InboundMessage, FSpacetimeDBInboundMessageApplyStats& ApplyStats) +{ + struct FActivePreprocessedDataGuard + { + FPreprocessedTableDataMap*& Target; + FSpacetimeDBInboundMessageApplyStats*& StatsTarget; + + FActivePreprocessedDataGuard( + FPreprocessedTableDataMap*& InTarget, + FPreprocessedTableDataMap* InValue, + FSpacetimeDBInboundMessageApplyStats*& InStatsTarget, + FSpacetimeDBInboundMessageApplyStats* InStatsValue) + : Target(InTarget) + , StatsTarget(InStatsTarget) + { + checkf(Target == nullptr, TEXT("Nested SpacetimeDB inbound table preprocessing scope detected.")); + checkf(StatsTarget == nullptr, TEXT("Nested SpacetimeDB inbound apply stats scope detected.")); + Target = InValue; + StatsTarget = InStatsValue; + } + + ~FActivePreprocessedDataGuard() + { + Target = nullptr; + StatsTarget = nullptr; + } + }; + + FActivePreprocessedDataGuard Guard( + ActivePreprocessedTableData, + &InboundMessage.PreprocessedTableData, + ActiveInboundMessageApplyStats, + &ApplyStats); + ProcessServerMessage(InboundMessage.Message); +} + void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) { switch (Message.Tag) { case EServerMessageTag::InitialConnection: { - const FInitialConnectionType Payload = Message.GetAsInitialConnection(); + const FInitialConnectionType& Payload = Message.MessageData.Get(); Token = Payload.Token; UCredentials::SaveToken(Token); Identity = Payload.Identity; @@ -307,7 +926,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::TransactionUpdate: { - const FTransactionUpdateType Payload = Message.GetAsTransactionUpdate(); + const FTransactionUpdateType& Payload = Message.MessageData.Get(); const FDatabaseUpdateType Update = TransactionUpdateToDatabaseUpdate(Payload); DbUpdate(Update, FSpacetimeDBEvent::Transaction(FSpacetimeDBUnit())); break; @@ -319,8 +938,8 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::SubscribeApplied: { - const FSubscribeAppliedType Payload = Message.GetAsSubscribeApplied(); - const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows, false); + const FSubscribeAppliedType& Payload = Message.MessageData.Get(); + const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows, UE::SpacetimeDB::EQueryRowsApplyMode::Inserts); DbUpdate(Update, FSpacetimeDBEvent::SubscribeApplied(FSpacetimeDBUnit())); if (TObjectPtr* HandlePtr = ActiveSubscriptions.Find(Payload.QuerySetId.Id)) @@ -339,10 +958,10 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::UnsubscribeApplied: { - const FUnsubscribeAppliedType Payload = Message.GetAsUnsubscribeApplied(); + const FUnsubscribeAppliedType& Payload = Message.MessageData.Get(); if (Payload.Rows.IsSet()) { - const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows.Value, true); + const FDatabaseUpdateType Update = QueryRowsToDatabaseUpdate(Payload.Rows.Value, UE::SpacetimeDB::EQueryRowsApplyMode::Deletes); DbUpdate(Update, FSpacetimeDBEvent::UnsubscribeApplied(FSpacetimeDBUnit())); } @@ -369,7 +988,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::SubscriptionError: { - const FSubscriptionErrorType Payload = Message.GetAsSubscriptionError(); + const FSubscriptionErrorType& Payload = Message.MessageData.Get(); UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("SubscriptionError received for QuerySetId=%u Error=%s"), Payload.QuerySetId.Id, *Payload.Error); @@ -391,7 +1010,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::ReducerResult: { - const FReducerResultType Payload = Message.GetAsReducerResult(); + const FReducerResultType& Payload = Message.MessageData.Get(); const FReducerCallInfoType* FoundReducerCall = PendingReducerCalls.Find(Payload.RequestId); if (!FoundReducerCall) { @@ -415,7 +1034,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) if (Payload.Result.IsOk()) { RedEvent.Status = FSpacetimeDBStatus::Committed(FSpacetimeDBUnit()); - const FReducerOkType Ok = Payload.Result.GetAsOk(); + const FReducerOkType& Ok = Payload.Result.MessageData.Get(); const FDatabaseUpdateType Update = TransactionUpdateToDatabaseUpdate(Ok.TransactionUpdate); DbUpdate(Update, FSpacetimeDBEvent::Reducer(RedEvent)); ReducerEvent(RedEvent); @@ -430,11 +1049,11 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) FString ErrorMessage; if (Payload.Result.IsErr()) { - ErrorMessage = DecodeReducerErrorMessage(Payload.Result.GetAsErr()); + ErrorMessage = DecodeReducerErrorMessage(Payload.Result.MessageData.Get>()); } else { - ErrorMessage = Payload.Result.GetAsInternalError(); + ErrorMessage = Payload.Result.MessageData.Get(); } RedEvent.Status = FSpacetimeDBStatus::Failed(ErrorMessage); ReducerEvent(RedEvent); @@ -444,7 +1063,7 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) } case EServerMessageTag::ProcedureResult: { - const FProcedureResultType Payload = Message.GetAsProcedureResult(); + const FProcedureResultType& Payload = Message.MessageData.Get(); FProcedureEvent ProcEvent; ProcEvent.Status = Payload.Status; ProcEvent.Timestamp = Payload.Timestamp; @@ -488,54 +1107,39 @@ void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) break; } } - -bool UDbConnectionBase::DecompressBrotli(const TArray& InData, TArray& OutData) -{ - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Brotli decompression unavilable")); - return false; -} - -bool UDbConnectionBase::DecompressGzip(const TArray& InData, TArray& OutData) -{ - if (InData.Num() < 4) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip data too small")); - return false; - } - - // Gzip data ends with 4 bytes indicating the uncompressed size - const uint8* SizePtr = InData.GetData() + InData.Num() - 4; - uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24); - - // Validate the output size - OutData.SetNumUninitialized(OutSize); - // Attempt to decompress the Gzip data - if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData.GetData(), InData.Num())) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip decompression failed")); - return false; - } - - OutData.SetNum(OutSize); - return true; -} - -bool UDbConnectionBase::DecompressPayload(uint8 Variant, const TArray& In, TArray& Out) -{ - switch (static_cast(Variant)) + +bool UDbConnectionBase::DecompressBrotli(const uint8* InData, int32 InSize, TArray& OutData) +{ + (void)InData; + (void)InSize; + (void)OutData; + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Brotli decompression unavilable")); + return false; +} + +bool UDbConnectionBase::DecompressGzip(const uint8* InData, int32 InSize, TArray& OutData) +{ + if (InData == nullptr || InSize < GzipFooterUncompressedSizeBytes) { - case EWsCompressionTag::Uncompressed: - // No compression, just copy the data - Out = In; - return true; - case EWsCompressionTag::Brotli: - return DecompressBrotli(In, Out); - case EWsCompressionTag::Gzip: - return DecompressGzip(In, Out); - default: - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant")); + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip data too small")); return false; } + + // Gzip data ends with 4 bytes indicating the uncompressed size + const uint8* SizePtr = InData + InSize - GzipFooterUncompressedSizeBytes; + uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24); + + // Validate the output size + OutData.SetNumUninitialized(OutSize); + // Attempt to decompress the Gzip data + if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData, InSize)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip decompression failed")); + return false; + } + + OutData.SetNum(OutSize); + return true; } void UDbConnectionBase::ClearPendingOperations(const FString& Reason) @@ -550,97 +1154,159 @@ void UDbConnectionBase::ClearPendingOperations(const FString& Reason) UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("Cleared pending operations due to connection issue: %s"), *Reason); } } - -void UDbConnectionBase::PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update) + +void UDbConnectionBase::PreProcessTableUpdateRows( + const FString& TableName, + const TArray& RowSets, + FPreprocessedTableDataMap& OutPreprocessedTableData) { - for (const FTableUpdateType& TableUpdate : Update.Tables) + TSharedPtr Deserializer = FindTableDeserializerForPreprocess(TableName); + if (!Deserializer.IsValid()) { - // Attempt to deserialize rows after payload decode. - TSharedPtr Deserializer; - { - // Find the deserializer for this table - FScopeLock Lock(&TableDeserializersMutex); - if (TSharedPtr* Found = TableDeserializers.Find(TableUpdate.TableName)) - { - // If found, use the deserializer - Deserializer = *Found; - } - else - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("No deserializer found for table %s"), *TableUpdate.TableName); - } - } - if (Deserializer) - { - TSharedPtr Data = Deserializer->PreProcess(TableUpdate.Rows, TableUpdate.TableName); - if (Data.IsValid()) - { - FScopeLock Lock(&PreprocessedDataMutex); - FPreprocessedTableKey Key(TableUpdate.TableName); - TArray>& Queue = PreprocessedTableData.FindOrAdd(Key); - Queue.Add(Data); - } - } - else - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableUpdate.TableName); - } - } -} - -bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FServerMessageType& OutMessage) -{ - if (Message.Num() == 0) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Empty message recived from server, ignored")); + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableName); + return; + } + + StorePreprocessedTableData(TableName, Deserializer->PreProcess(RowSets, TableName), OutPreprocessedTableData); +} + +void UDbConnectionBase::PreProcessQueryRows( + const FQueryRowsType& Rows, + UE::SpacetimeDB::EQueryRowsApplyMode Mode, + FPreprocessedTableDataMap& OutPreprocessedTableData) +{ + for (const FSingleTableRowsType& TableRows : Rows.Tables) + { + TSharedPtr Deserializer = FindTableDeserializerForPreprocess(TableRows.Table); + if (!Deserializer.IsValid()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s query rows due to missing deserializer"), *TableRows.Table); + continue; + } + + StorePreprocessedTableData(TableRows.Table, Deserializer->PreProcessQueryRows(TableRows.Rows, Mode, TableRows.Table), OutPreprocessedTableData); + } +} + +void UDbConnectionBase::PreProcessTransactionUpdate( + const FTransactionUpdateType& Update, + FPreprocessedTableDataMap& OutPreprocessedTableData) +{ + for (const FQuerySetUpdateType& QuerySet : Update.QuerySets) + { + for (const FTableUpdateType& TableUpdate : QuerySet.Tables) + { + PreProcessTableUpdateRows(TableUpdate.TableName, TableUpdate.Rows, OutPreprocessedTableData); + } + } +} + +TSharedPtr UDbConnectionBase::FindTableDeserializerForPreprocess(const FString& TableName) +{ + FScopeLock Lock(&TableDeserializersMutex); + if (TSharedPtr* Found = TableDeserializers.Find(TableName)) + { + return *Found; + } + return nullptr; +} + +void UDbConnectionBase::StorePreprocessedTableData( + const FString& TableName, + TSharedPtr Data, + FPreprocessedTableDataMap& OutPreprocessedTableData) +{ + checkf(Data.IsValid(), TEXT("Invalid message-scoped preprocessed data generated for table '%s'."), *TableName); + + FPreprocessedTableKey Key(TableName); + TArray>& Queue = OutPreprocessedTableData.FindOrAdd(Key); + Queue.Add(MoveTemp(Data)); +} + +bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FInboundParsedMessage& OutMessage) +{ + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_PreProcessMessage); + + if (Message.Num() <= SpacetimeDbCompressionTagBytes) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Empty message received from server, ignored")); return false; } // The first byte indicates compression format for the payload. const uint8 Compression = Message[0]; - TArray CompressedPayload; - CompressedPayload.Append(Message.GetData() + 1, Message.Num() - 1); - - // Decompress the payload based on the compression tag - TArray Decompressed; - if (!DecompressPayload(Compression, CompressedPayload, Decompressed)) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress incoming message")); + + const uint8* DecodedPayload = Message.GetData() + SpacetimeDbCompressionTagBytes; + int32 DecodedPayloadSize = Message.Num() - SpacetimeDbCompressionTagBytes; + TArray DecompressedStorage; + switch (static_cast(Compression)) + { + case EWsCompressionTag::Uncompressed: + break; + case EWsCompressionTag::Brotli: + if (!DecompressBrotli(DecodedPayload, DecodedPayloadSize, DecompressedStorage)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress Brotli incoming message")); + return false; + } + DecodedPayload = DecompressedStorage.GetData(); + DecodedPayloadSize = DecompressedStorage.Num(); + break; + case EWsCompressionTag::Gzip: + if (!DecompressGzip(DecodedPayload, DecodedPayloadSize, DecompressedStorage)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress Gzip incoming message")); + return false; + } + DecodedPayload = DecompressedStorage.GetData(); + DecodedPayloadSize = DecompressedStorage.Num(); + break; + default: + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant")); + return false; + } + if (DecodedPayloadSize <= 0 || DecodedPayload == nullptr) + { + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("SpacetimeDB decoded server message payload is empty after compression tag %u."), + static_cast(Compression)); return false; } // Deserialize the decompressed data into a UServerMessageType object - OutMessage = UE::SpacetimeDB::Deserialize(Decompressed); + OutMessage.Message = UE::SpacetimeDB::DeserializeView(DecodedPayload, DecodedPayloadSize); // Preprocess row-bearing payloads for table deserializers. - switch (OutMessage.Tag) + switch (OutMessage.Message.Tag) { case EServerMessageTag::SubscribeApplied: { - const FSubscribeAppliedType Payload = OutMessage.GetAsSubscribeApplied(); - PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows, false)); + const FSubscribeAppliedType& Payload = OutMessage.Message.MessageData.Get(); + PreProcessQueryRows(Payload.Rows, UE::SpacetimeDB::EQueryRowsApplyMode::Inserts, OutMessage.PreprocessedTableData); break; } case EServerMessageTag::UnsubscribeApplied: { - const FUnsubscribeAppliedType Payload = OutMessage.GetAsUnsubscribeApplied(); + const FUnsubscribeAppliedType& Payload = OutMessage.Message.MessageData.Get(); if (Payload.Rows.IsSet()) { - PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows.Value, true)); + PreProcessQueryRows(Payload.Rows.Value, UE::SpacetimeDB::EQueryRowsApplyMode::Deletes, OutMessage.PreprocessedTableData); } break; } case EServerMessageTag::TransactionUpdate: { - const FTransactionUpdateType Payload = OutMessage.GetAsTransactionUpdate(); - PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload)); + const FTransactionUpdateType& Payload = OutMessage.Message.MessageData.Get(); + PreProcessTransactionUpdate(Payload, OutMessage.PreprocessedTableData); break; } case EServerMessageTag::ReducerResult: { - const FReducerResultType Payload = OutMessage.GetAsReducerResult(); + const FReducerResultType& Payload = OutMessage.Message.MessageData.Get(); if (Payload.Result.IsOk()) { - PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload.Result.GetAsOk().TransactionUpdate)); + const FReducerOkType& Ok = Payload.Result.MessageData.Get(); + PreProcessTransactionUpdate(Ok.TransactionUpdate, OutMessage.PreprocessedTableData); } break; } @@ -649,8 +1315,8 @@ bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FServerM } return true; } - - + + uint32 UDbConnectionBase::GetNextRequestId() { return NextRequestId++; @@ -660,21 +1326,21 @@ uint32 UDbConnectionBase::GetNextSubscriptionId() { return NextSubscriptionId++; } - + void UDbConnectionBase::StartSubscription(USubscriptionHandleBase* Handle) -{ - if (!Handle) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with null handle")); - return; - } - - if (Handle->QuerySqls.Num() == 0) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with empty query list")); - return; - } - +{ + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with null handle")); + return; + } + + if (Handle->QuerySqls.Num() == 0) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with empty query list")); + return; + } + const uint32 QuerySetId = GetNextSubscriptionId(); Handle->QuerySetId = QuerySetId; Handle->ConnInternal = this; @@ -689,14 +1355,14 @@ void UDbConnectionBase::StartSubscription(USubscriptionHandleBase* Handle) TArray Data = UE::SpacetimeDB::Serialize(Msg); SendRawMessage(Data); } - -void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) -{ - if (!Handle || Handle->bEnded) - { - return; - } - + +void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) +{ + if (!Handle || Handle->bEnded) + { + return; + } + const uint32 QuerySetId = Handle->QuerySetId; FUnsubscribeType MsgData; MsgData.RequestId = GetNextRequestId(); @@ -707,7 +1373,7 @@ void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) TArray Data = UE::SpacetimeDB::Serialize(Msg); SendRawMessage(Data); } - + uint32 UDbConnectionBase::InternalCallReducer(const FString& Reducer, TArray Args) { if (!WebSocket || !WebSocket->IsConnected()) @@ -732,51 +1398,124 @@ uint32 UDbConnectionBase::InternalCallReducer(const FString& Reducer, TArray Args, const FOnProcedureCompleteDelegate& Callback) -{ - if (!WebSocket || !WebSocket->IsConnected()) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Cannot call proceduer, not connected to server!")); - return; - } - FCallProcedureType MsgData; - MsgData.Procedure = ProcedureName; - MsgData.Args = Args; - MsgData.RequestId = ProcedureCallbacks->RegisterCallback(Callback); - MsgData.Flags = static_cast(EProcedureFlags::Default); - - FClientMessageType Msg = FClientMessageType::CallProcedure(MsgData); - TArray Data = UE::SpacetimeDB::Serialize(Msg); - SendRawMessage(Data); -} - -void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context) -{ - // Ensure we have a valid context for the update - TArray> Handlers; - for (const FTableUpdateType& TableUpdate : Update.Tables) - { - TSharedPtr Handler; - { - // Find the handler for this table update - FScopeLock Lock(&RegisteredTablesMutex); - if (TSharedPtr* Found = RegisteredTables.Find(TableUpdate.TableName)) - { - Handler = *Found; - } - } - if (Handler.IsValid()) - { - // Update the cache for the handler with the table update and context - Handler->UpdateCache(this, TableUpdate, Context); - Handlers.Add(Handler); - } - } - - for (TSharedPtr& Handler : Handlers) - { - // Broadcast the diff for each handler - Handler->BroadcastDiff(this, Context); - } + +void UDbConnectionBase::InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback) +{ + if (!WebSocket || !WebSocket->IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Cannot call proceduer, not connected to server!")); + return; + } + FCallProcedureType MsgData; + MsgData.Procedure = ProcedureName; + MsgData.Args = Args; + MsgData.RequestId = ProcedureCallbacks->RegisterCallback(Callback); + MsgData.Flags = static_cast(EProcedureFlags::Default); + + FClientMessageType Msg = FClientMessageType::CallProcedure(MsgData); + TArray Data = UE::SpacetimeDB::Serialize(Msg); + SendRawMessage(Data); +} + +void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context) +{ + checkf(ActivePreprocessedTableData != nullptr, TEXT("ApplyRegisteredTableUpdates requires active message-scoped preprocessed data.")); + + TSharedPtr>> RegisteredHandlers; + { + FScopeLock Lock(&RegisteredTablesMutex); + RegisteredHandlers = RegisteredTablesSnapshot; + } + if (!RegisteredHandlers.IsValid()) + { + return; + } + + TableUpdateHandlersScratch.Reset(); + TableUpdateHandlersScratch.Reserve(Update.Tables.Num()); + if (ActiveInboundMessageApplyStats != nullptr) + { + ActiveInboundMessageApplyStats->TableStats.Reserve( + ActiveInboundMessageApplyStats->TableStats.Num() + Update.Tables.Num()); + } + for (const FTableUpdateType& TableUpdate : Update.Tables) + { + if (TableUpdate.Rows.IsEmpty()) + { + continue; + } + + TSharedPtr Handler; + if (const TSharedPtr* Found = RegisteredHandlers->Find(TableUpdate.TableName)) + { + Handler = *Found; + } + if (Handler.IsValid()) + { + FSpacetimeDBTableApplyStats* TableStats = nullptr; + int32 TableStatsIndex = INDEX_NONE; + if (ActiveInboundMessageApplyStats != nullptr) + { + TableStatsIndex = ActiveInboundMessageApplyStats->TableStats.AddDefaulted(); + TableStats = &ActiveInboundMessageApplyStats->TableStats[TableStatsIndex]; + TableStats->TableName = Handler->GetTableName(); + } + + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_TableUpdateCache); + const uint64 TableCacheStartCycles = FPlatformTime::Cycles64(); + const bool bHasNonEmptyDiff = Handler->UpdateCache(this, TableUpdate, Context, TableStats); + const double TableCacheElapsedMicros = + FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - TableCacheStartCycles) * 1000.0; + if (TableStats != nullptr) + { + TableStats->CacheMicros = TableCacheElapsedMicros; + } + if (InboundApplyBudget.SoftTimeBudgetMicros > 0 && + TableCacheElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + UE_LOG(LogSpacetimeDb_Connection, + Warning, + TEXT("SpacetimeDB table cache apply exceeded soft budget: table=%s elapsed=%.2fus budget=%lldus row_ops=%d"), + *Handler->GetTableName(), + TableCacheElapsedMicros, + InboundApplyBudget.SoftTimeBudgetMicros, + TableUpdate.Rows.Num()); + } + if (bHasNonEmptyDiff) + { + FPendingTableBroadcast& PendingBroadcast = TableUpdateHandlersScratch.AddDefaulted_GetRef(); + PendingBroadcast.Handler = Handler; + PendingBroadcast.StatsIndex = TableStatsIndex; + } + } + } + + for (FPendingTableBroadcast& PendingBroadcast : TableUpdateHandlersScratch) + { + TSharedPtr& Handler = PendingBroadcast.Handler; + checkf(Handler.IsValid(), TEXT("Invalid pending SpacetimeDB table broadcast handler.")); + TRACE_CPUPROFILER_EVENT_SCOPE(SpacetimeDB_TableBroadcastDiff); + const uint64 BroadcastStartCycles = FPlatformTime::Cycles64(); + Handler->BroadcastDiff(this, Context); + const double BroadcastElapsedMicros = + FPlatformTime::ToMilliseconds64(FPlatformTime::Cycles64() - BroadcastStartCycles) * 1000.0; + if (ActiveInboundMessageApplyStats != nullptr && PendingBroadcast.StatsIndex != INDEX_NONE) + { + checkf(ActiveInboundMessageApplyStats->TableStats.IsValidIndex(PendingBroadcast.StatsIndex), + TEXT("Invalid SpacetimeDB inbound apply stats index %d."), + PendingBroadcast.StatsIndex); + ActiveInboundMessageApplyStats->TableStats[PendingBroadcast.StatsIndex].BroadcastMicros = BroadcastElapsedMicros; + } + if (InboundApplyBudget.SoftTimeBudgetMicros > 0 && + BroadcastElapsedMicros >= static_cast(InboundApplyBudget.SoftTimeBudgetMicros)) + { + UE_LOG(LogSpacetimeDb_Connection, + Warning, + TEXT("SpacetimeDB table broadcast exceeded soft budget: table=%s elapsed=%.2fus budget=%lldus"), + *Handler->GetTableName(), + BroadcastElapsedMicros, + InboundApplyBudget.SoftTimeBudgetMicros); + } + } + TableUpdateHandlersScratch.Reset(); } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp index 64f2bbc66b0..e17fbb04150 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp @@ -136,13 +136,14 @@ UDbConnectionBase* UDbConnectionBuilderBase::BuildConnection(UDbConnectionBase* *ModuleName, *CompressionName); - Connection->WebSocket->OnConnectionError.AddDynamic(Connection, &UDbConnectionBase::HandleWSError); - Connection->WebSocket->OnClosed.AddDynamic(Connection, &UDbConnectionBase::HandleWSClosed); - Connection->WebSocket->OnBinaryMessageReceived.AddDynamic(Connection, &UDbConnectionBase::HandleWSBinaryMessage); - // Set the initialization token for the WebSocket connection - Connection->WebSocket->SetInitToken(Token); - // Connect the WebSocket to the constructed URL - Connection->WebSocket->Connect(WebSocketUrl); + Connection->WebSocket->OnConnectionError.AddDynamic(Connection, &UDbConnectionBase::HandleWSError); + Connection->WebSocket->OnClosed.AddDynamic(Connection, &UDbConnectionBase::HandleWSClosed); + Connection->WebSocket->SetNativeBinaryMessageTarget(Connection); + // Set the initialization token for the WebSocket connection + Connection->WebSocket->SetInitToken(Token); + Connection->StartInboundMessageWorker(); + // Connect the WebSocket to the constructed URL + Connection->WebSocket->Connect(WebSocketUrl); return Connection; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp index 7b4bbe53f40..5aacb0b2b9b 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp @@ -1,5 +1,6 @@ #include "Connection/Websocket.h" +#include "Connection/DbConnectionBase.h" #include "WebSocketsModule.h" // Required for FWebSocketsModule #include "SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h" #include "ModuleBindings/Types/ServerMessageType.g.h" @@ -16,13 +17,14 @@ UWebsocketManager::UWebsocketManager() void UWebsocketManager::BeginDestroy() { - UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::BeginDestroy: Cleaning up WebSocket.")); - if (!HasAnyFlags(RF_ClassDefaultObject)) - { - Disconnect(); - } - Super::BeginDestroy(); -} + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::BeginDestroy: Cleaning up WebSocket.")); + if (!HasAnyFlags(RF_ClassDefaultObject)) + { + Disconnect(); + } + NativeBinaryMessageTarget.Reset(); + Super::BeginDestroy(); +} void UWebsocketManager::Connect(const FString& ServerUrl) { @@ -128,15 +130,20 @@ bool UWebsocketManager::SendMessage(const TArray& Data) return true; } -bool UWebsocketManager::IsConnected() const -{ - return WebSocket.IsValid() && WebSocket->IsConnected(); -} - -void UWebsocketManager::SetInitToken(FString Token) -{ - InitToken = Token; -} +bool UWebsocketManager::IsConnected() const +{ + return WebSocket.IsValid() && WebSocket->IsConnected(); +} + +void UWebsocketManager::SetNativeBinaryMessageTarget(UDbConnectionBase* Target) +{ + NativeBinaryMessageTarget = Target; +} + +void UWebsocketManager::SetInitToken(FString Token) +{ + InitToken = Token; +} void UWebsocketManager::HandleConnected() { @@ -157,39 +164,56 @@ void UWebsocketManager::HandleMessageReceived(const FString& Message) OnMessageReceived.Broadcast(Message); } -void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment) -{ - if (Size == 0) - { +void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment) +{ + if (Size == 0) + { return; } - - // Handle binary messages, which may be fragmented - const uint8* Bytes = static_cast(Data); - - // Append this fragment to our buffer - IncompleteMessage.Append(Bytes, Size); - - // If this is the last fragment, we have the complete message - if (bIsLastFragment) - { - // We have the complete message - TArray MessageBytes = IncompleteMessage; - IncompleteMessage.Reset(); - bAwaitingBinaryFragments = false; - - // Forward the complete binary payload to listeners. - OnBinaryMessageReceived.Broadcast(MessageBytes); - } - else - { - // More fragments are coming - bAwaitingBinaryFragments = true; - } -} - -void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) -{ + + // Handle binary messages, which may be fragmented + const uint8* Bytes = static_cast(Data); + + if (!bAwaitingBinaryFragments && bIsLastFragment) + { + TArray MessageBytes; + MessageBytes.Append(Bytes, Size); + DispatchCompleteBinaryMessage(MoveTemp(MessageBytes)); + return; + } + + // Append this fragment to our buffer + IncompleteMessage.Append(Bytes, Size); + + // If this is the last fragment, we have the complete message + if (bIsLastFragment) + { + // We have the complete message + TArray MessageBytes = MoveTemp(IncompleteMessage); + bAwaitingBinaryFragments = false; + + DispatchCompleteBinaryMessage(MoveTemp(MessageBytes)); + } + else + { + // More fragments are coming + bAwaitingBinaryFragments = true; + } +} + +void UWebsocketManager::DispatchCompleteBinaryMessage(TArray&& Message) +{ + if (UDbConnectionBase* Target = NativeBinaryMessageTarget.Get()) + { + Target->HandleWSBinaryMessageOwned(MoveTemp(Message)); + return; + } + + OnBinaryMessageReceived.Broadcast(Message); +} + +void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) +{ UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Closed. Status: %d, Reason: %s, Clean: %s"), StatusCode, *Reason, bWasClean ? TEXT("true") : TEXT("false")); // Notify listeners about the closure diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h index 7d44d69c24f..7e93e81d248 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h @@ -9,6 +9,29 @@ namespace UE::SpacetimeDB { + enum class EQueryRowsApplyMode : uint8 + { + Inserts, + Deletes + }; + + namespace Private + { + template + static void AddParsedRowWithBsatn(const uint8* RowData, int32 RowLength, TArray>& OutRows) + { + checkf(RowLength >= 0, TEXT("Cannot parse a negative BSATN row length: %d"), RowLength); + checkf(RowData != nullptr || RowLength == 0, TEXT("Cannot parse null BSATN row data with length %d"), RowLength); + + RowType Row = DeserializeView(RowData, RowLength); + TArray Bsatn; + if (RowLength > 0) + { + Bsatn.Append(RowData, RowLength); + } + OutRows.Add(FWithBsatn(MoveTemp(Bsatn), MoveTemp(Row))); + } + } /** Parse a single row list based on its size hint and retain BSATN bytes */ template @@ -21,20 +44,19 @@ namespace UE::SpacetimeDB if (List.SizeHint.IsFixedSize()) { // Get the fixed size from the size hint - uint16 Size = List.SizeHint.GetAsFixedSize(); + const uint16 Size = List.SizeHint.GetAsFixedSize(); if (Size > 0) { + checkf(List.RowsData.Num() % Size == 0, + TEXT("Fixed-size BSATN row list has %d bytes, which is not divisible by row size %u"), + List.RowsData.Num(), + static_cast(Size)); // If the size is valid, parse the rows based on the fixed size - int32 Count = List.RowsData.Num() / Size; + const int32 Count = List.RowsData.Num() / Size; + OutRows.Reserve(OutRows.Num() + Count); for (int32 i = 0; i < Count; ++i) { - // Create a slice of the row data based on the fixed size - TArray Slice; - Slice.Append(List.RowsData.GetData() + i * Size, Size); - // Deserialize the row from the slice - RowType Row = UE::SpacetimeDB::Deserialize(Slice); - // Add the row with its BSATN bytes to the output array - OutRows.Add(FWithBsatn(Slice, Row)); + Private::AddParsedRowWithBsatn(List.RowsData.GetData() + i * Size, Size, OutRows); } return; } @@ -43,26 +65,29 @@ namespace UE::SpacetimeDB else if (List.SizeHint.IsRowOffsets()) { // Get the offsets from the size hint - TArray Offsets = List.SizeHint.GetAsRowOffsets(); + const TArray& Offsets = List.SizeHint.MessageData.Get>(); if (Offsets.Num() > 0) { // If the offsets are valid, parse the rows based on the offsets - UEReader Reader(List.RowsData); + OutRows.Reserve(OutRows.Num() + Offsets.Num()); for (int32 i = 0; i < Offsets.Num(); ++i) { // If this is the last offset, read until the end of the data - int64 Start = Offsets[i]; - int64 End = (i + 1 < Offsets.Num()) ? Offsets[i + 1] : List.RowsData.Num(); - int64 Length = End - Start; - TArray Slice; - Slice.Append(List.RowsData.GetData() + Start, Length); - - // Deserialize the row from the slice - UEReader SliceReader(Slice); - RowType Row = deserialize(SliceReader); - - // Add the row with its BSATN bytes to the output array - OutRows.Add(FWithBsatn(Slice, Row)); + const uint64 Start = Offsets[i]; + const uint64 End = (i + 1 < Offsets.Num()) ? Offsets[i + 1] : static_cast(List.RowsData.Num()); + checkf(Start <= End, + TEXT("BSATN row offsets are not sorted: start=%llu end=%llu row_index=%d"), + Start, + End, + i); + checkf(End <= static_cast(List.RowsData.Num()), + TEXT("BSATN row offset %llu exceeds row data size %d at row_index=%d"), + End, + List.RowsData.Num(), + i); + const int32 RowStart = static_cast(Start); + const int32 RowLength = static_cast(End - Start); + Private::AddParsedRowWithBsatn(List.RowsData.GetData() + RowStart, RowLength, OutRows); } } } @@ -79,14 +104,14 @@ namespace UE::SpacetimeDB { if (RowSet.IsPersistentTable()) { - const FPersistentTableRowsType Persistent = RowSet.GetAsPersistentTable(); + const FPersistentTableRowsType& Persistent = RowSet.MessageData.Get(); ParseRowListWithBsatn(Persistent.Inserts, Inserts); ParseRowListWithBsatn(Persistent.Deletes, Deletes); } // Event-table rows are callback-only inserts and should not create delete paths. else if (RowSet.IsEventTable()) { - const FEventTableRowsType EventRows = RowSet.GetAsEventTable(); + const FEventTableRowsType& EventRows = RowSet.MessageData.Get(); ParseRowListWithBsatn(EventRows.Events, Inserts); } else @@ -100,6 +125,11 @@ namespace UE::SpacetimeDB struct FPreprocessedTableDataBase { virtual ~FPreprocessedTableDataBase() {} + int32 InsertRowCount = 0; + int32 DeleteRowCount = 0; + int32 RowSetCount = 0; + int64 InsertRowBytes = 0; + int64 DeleteRowBytes = 0; }; /** A wrapper for a row type that includes its bsatn value. Used to store rows with their bsatn values. */ @@ -117,7 +147,8 @@ namespace UE::SpacetimeDB public: virtual ~ITableRowDeserializer() {} /** Preprocess the table update and return a shared pointer to preprocessed data. */ - virtual TSharedPtr PreProcess(const TArray& RowSets, const FString TableName) const = 0; + virtual TSharedPtr PreProcess(const TArray& RowSets, const FString& TableName) const = 0; + virtual TSharedPtr PreProcessQueryRows(const FBsatnRowListType& Rows, EQueryRowsApplyMode Mode, const FString& TableName) const = 0; }; /** Specialization of ITableRowDeserializer for a specific row type not defined in SDK. Used to deserialize rows of a specific type from a database update. */ @@ -125,24 +156,34 @@ namespace UE::SpacetimeDB class TTableRowDeserializer : public ITableRowDeserializer { public: - virtual TSharedPtr PreProcess(const TArray& RowSets, const FString TableName) const override + virtual TSharedPtr PreProcess(const TArray& RowSets, const FString& TableName) const override { // Create a new preprocessed table data object for the specific row type TSharedPtr> Result = MakeShared>(); + Result->RowSetCount = RowSets.Num(); // Process each row-set update in the table update for (const FTableUpdateRowsType& RowSet : RowSets) { if (RowSet.IsPersistentTable()) { - const FPersistentTableRowsType Persistent = RowSet.GetAsPersistentTable(); + const FPersistentTableRowsType& Persistent = RowSet.MessageData.Get(); + const int32 InsertCountBefore = Result->Inserts.Num(); + const int32 DeleteCountBefore = Result->Deletes.Num(); ParseRowListWithBsatn(Persistent.Inserts, Result->Inserts); ParseRowListWithBsatn(Persistent.Deletes, Result->Deletes); + Result->InsertRowCount += Result->Inserts.Num() - InsertCountBefore; + Result->DeleteRowCount += Result->Deletes.Num() - DeleteCountBefore; + Result->InsertRowBytes += Persistent.Inserts.RowsData.Num(); + Result->DeleteRowBytes += Persistent.Deletes.RowsData.Num(); } else if (RowSet.IsEventTable()) { // Event rows are insert-style callback payloads only. - const FEventTableRowsType Events = RowSet.GetAsEventTable(); + const FEventTableRowsType& Events = RowSet.MessageData.Get(); + const int32 InsertCountBefore = Result->Inserts.Num(); ParseRowListWithBsatn(Events.Events, Result->Inserts); + Result->InsertRowCount += Result->Inserts.Num() - InsertCountBefore; + Result->InsertRowBytes += Events.Events.RowsData.Num(); } else { @@ -151,5 +192,34 @@ namespace UE::SpacetimeDB } return Result; } + + virtual TSharedPtr PreProcessQueryRows(const FBsatnRowListType& Rows, EQueryRowsApplyMode Mode, const FString& TableName) const override + { + TSharedPtr> Result = MakeShared>(); + Result->RowSetCount = 1; + switch (Mode) + { + case EQueryRowsApplyMode::Inserts: + { + const int32 InsertCountBefore = Result->Inserts.Num(); + ParseRowListWithBsatn(Rows, Result->Inserts); + Result->InsertRowCount += Result->Inserts.Num() - InsertCountBefore; + Result->InsertRowBytes += Rows.RowsData.Num(); + break; + } + case EQueryRowsApplyMode::Deletes: + { + const int32 DeleteCountBefore = Result->Deletes.Num(); + ParseRowListWithBsatn(Rows, Result->Deletes); + Result->DeleteRowCount += Result->Deletes.Num() - DeleteCountBefore; + Result->DeleteRowBytes += Rows.RowsData.Num(); + break; + } + default: + checkf(false, TEXT("Unsupported query-row apply mode for table %s"), *TableName); + break; + } + return Result; + } }; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h index 8b25806e282..283d2261383 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h @@ -167,14 +167,33 @@ namespace UE::SpacetimeDB { * This class provides a UE-friendly interface for deserializing BSATN data, * handling conversions between standard C++ types and UE types. * - * The reader stores a copy of the input data to prevent lifetime issues - * that could occur if the original TArray is destroyed. + * The TArray and std::vector constructors retain copy semantics for callers + * that need owned reader storage. The pointer and array-view constructors are + * caller-owned views intended for immediate, scoped parsing. */ class UEReader { private: - std::vector stored_data; ///< Local copy of data to ensure lifetime + std::vector stored_data; ///< Local copy of data when ownership is requested ::SpacetimeDb::bsatn::Reader core_reader; ///< Underlying BSATN reader + static const uint8_t* ValidateReaderData(const uint8_t* data, int32 size) + { + checkf(size >= 0, TEXT("UEReader cannot read a negative byte count: %d"), size); + checkf(data != nullptr || size == 0, TEXT("UEReader received null data for %d bytes"), size); + static constexpr uint8_t EmptyReaderByte = 0; + if (data == nullptr) + { + return &EmptyReaderByte; + } + return data; + } + + static size_t ValidateReaderSize(int32 size) + { + checkf(size >= 0, TEXT("UEReader cannot read a negative byte count: %d"), size); + return static_cast(size); + } + public: /** * Construct a reader from a UE byte array @@ -198,6 +217,12 @@ namespace UE::SpacetimeDB { : stored_data(data), core_reader(stored_data) {} + explicit UEReader(const uint8_t* data, int32 size) + : core_reader(ValidateReaderData(data, size), ValidateReaderSize(size)) {} + + explicit UEReader(TConstArrayView data) + : UEReader(data.GetData(), data.Num()) {} + // ------------------------------------------------------------------------- // Primitive Type Readers // ------------------------------------------------------------------------- @@ -894,6 +919,27 @@ namespace UE::SpacetimeDB { } } + template + T DeserializeView(const uint8* data, int32 size) { + + UEReader reader(data, size); + + if constexpr (is_tarray_v) { + return DeserializeHelper::deserialize(reader); + } + else if constexpr (is_toptional_v) { + return DeserializeHelper::deserialize(reader); + } + else { + return deserialize(reader); + } + } + + template + T DeserializeView(TConstArrayView data) { + return DeserializeView(data.GetData(), data.Num()); + } + /** @} */ // end of HighLevelAPI group // ============================================================================= @@ -932,7 +978,7 @@ namespace UE::SpacetimeDB { * @brief Helper macro to generate deserialize specialization for TOptional * * Use this macro when you have structs with TOptional fields of custom types. - * + * * @Note: We are not using TOptional directly becouse it is not compatable wiht blueprints. * This macro is kept for future compatibility if things change on the Engine side. * @@ -1322,7 +1368,7 @@ namespace UE::SpacetimeDB { UE_SPACETIMEDB_ENABLE_TARRAY(float) UE_SPACETIMEDB_ENABLE_TARRAY(double) - + UE_SPACETIMEDB_ENABLE_TARRAY(bool) // Large integer type containers @@ -1330,8 +1376,8 @@ namespace UE::SpacetimeDB { UE_SPACETIMEDB_ENABLE_TARRAY(FSpacetimeDBUInt256) UE_SPACETIMEDB_ENABLE_TARRAY(FSpacetimeDBInt128) UE_SPACETIMEDB_ENABLE_TARRAY(FSpacetimeDBInt256) - + /** @} */ // end of CommonSpecializations group -} // namespace UE::SpacetimeDB \ No newline at end of file +} // namespace UE::SpacetimeDB diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h index b32221ce641..7a8da5d1a1a 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h @@ -1,48 +1,47 @@ -#pragma once - +#pragma once + #include "CoreMinimal.h" -#include "Containers/Ticker.h" #include "UObject/NoExportTypes.h" -#include "Types/Builtins.h" -#include "Websocket.h" -#include "Subscription.h" -#include "ModuleBindings/Types/ServerMessageType.g.h" -#include "DBCache/TableAppliedDiff.h" -#include "HAL/CriticalSection.h" -#include "Containers/Queue.h" +#include "Types/Builtins.h" +#include "Websocket.h" +#include "Subscription.h" +#include "ModuleBindings/Types/ServerMessageType.g.h" +#include "DBCache/TableAppliedDiff.h" +#include "HAL/CriticalSection.h" #include "HAL/ThreadSafeBool.h" #include "BSATN/UEBSATNHelpers.h" #include "Connection/Callback.h" #include "LogCategory.h" #include - -#include "DbConnectionBase.generated.h" - -// Forward declarations -class UDbConnectionBuilder; -class UProcedureCallbacks; - -/** Macro for safae way to bind delegate without needing to write Function name as an FName. */ -#define BIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ - DelegateVar.BindUFunction(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) - -/** Macro for safe way to unbind delegate without needing to write Function name as an FName. */ -#define UNBIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ - DelegateVar.Remove(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) - + +#include "DbConnectionBase.generated.h" + +// Forward declarations +class UDbConnectionBuilder; +class UProcedureCallbacks; +class FSpacetimeDbInboundWorker; + +/** Macro for safae way to bind delegate without needing to write Function name as an FName. */ +#define BIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ + DelegateVar.BindUFunction(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) + +/** Macro for safe way to unbind delegate without needing to write Function name as an FName. */ +#define UNBIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ + DelegateVar.Remove(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) + /** Delegate called when the connection attempt fails. */ -DECLARE_DYNAMIC_DELEGATE_OneParam( - FOnConnectErrorDelegate, - const FString&, ErrorMessage); - -/** Called when a connection is established. */ -DECLARE_DYNAMIC_DELEGATE_ThreeParams( - FOnConnectBaseDelegate, - UDbConnectionBase*, Connection, - FSpacetimeDBIdentity, Identity, - const FString&, Token); - -/** Called when a connection closes. */ +DECLARE_DYNAMIC_DELEGATE_OneParam( + FOnConnectErrorDelegate, + const FString&, ErrorMessage); + +/** Called when a connection is established. */ +DECLARE_DYNAMIC_DELEGATE_ThreeParams( + FOnConnectBaseDelegate, + UDbConnectionBase*, Connection, + FSpacetimeDBIdentity, Identity, + const FString&, Token); + +/** Called when a connection closes. */ DECLARE_DYNAMIC_DELEGATE_TwoParams( FOnDisconnectBaseDelegate, UDbConnectionBase*, Connection, @@ -95,6 +94,92 @@ FORCEINLINE uint32 GetTypeHash(const FPreprocessedTableKey& Key) return GetTypeHash(Key.TableName); } +using FPreprocessedTableDataMap = TMap>>; + +struct FInboundRawMessage +{ + uint64 ConnectionEpoch = 0; + uint64 SequenceId = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + TArray Payload; +}; + +struct FInboundParsedMessage +{ + uint64 ConnectionEpoch = 0; + uint64 SequenceId = 0; + int32 PayloadSizeBytes = 0; + uint8 CompressionTag = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + bool bProtocolError = false; + FString ProtocolError; + FServerMessageType Message; + FPreprocessedTableDataMap PreprocessedTableData; +}; + +struct FSpacetimeDBTableApplyStats +{ + FString TableName; + int32 RowSetCount = 0; + int32 InsertRowCount = 0; + int32 DeleteRowCount = 0; + int64 InsertRowBytes = 0; + int64 DeleteRowBytes = 0; + double CacheMicros = 0.0; + double BroadcastMicros = 0.0; + bool bProducedDiff = false; +}; + +struct FSpacetimeDBInboundMessageApplyStats +{ + FString MessageKind; + FString ReducerName; + uint32 RequestId = 0; + uint64 SequenceId = 0; + int32 PayloadSizeBytes = 0; + int32 QueueDepthAtEnqueue = 0; + int64 QueuedBytesAtEnqueue = 0; + TArray TableStats; +}; + +USTRUCT(BlueprintType) +struct SPACETIMEDBSDK_API FSpacetimeDBInboundApplyBudget +{ + GENERATED_BODY() + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int32 MaxMessagesPerFrame = 256; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int64 MaxPayloadBytesPerFrame = 4 * 1024 * 1024; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int32 MinMessagesPerFrame = 1; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + int64 SoftTimeBudgetMicros = 0; + + UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "SpacetimeDB") + bool bDrainAllPendingMessages = false; + + static FSpacetimeDBInboundApplyBudget MakeDrainAllPendingMessages() + { + FSpacetimeDBInboundApplyBudget Budget; + Budget.bDrainAllPendingMessages = true; + return Budget; + } + + void Sanitize() + { + MaxMessagesPerFrame = FMath::Max(1, MaxMessagesPerFrame); + MaxPayloadBytesPerFrame = FMath::Max(1, MaxPayloadBytesPerFrame); + MinMessagesPerFrame = FMath::Clamp(MinMessagesPerFrame, 1, MaxMessagesPerFrame); + SoftTimeBudgetMicros = FMath::Max(0, SoftTimeBudgetMicros); + } +}; + template struct THasOnDeleteDelegate : std::false_type { @@ -114,251 +199,537 @@ template struct THasOnUpdateDelegate> : std::true_type { }; - + +template +const void* GetNativeTableListenerTypeId() +{ + static const uint8 TypeId = 0; + return &TypeId; +} + +struct FNativeTableListenerBinding +{ + using FInsertThunk = void(*)(void* Owner, const void* Context, const void* Row); + using FUpdateThunk = void(*)(void* Owner, const void* Context, const void* OldRow, const void* NewRow); + using FDeleteThunk = void(*)(void* Owner, const void* Context, const void* Row); + using FDiffThunk = void(*)(void* Owner, const void* Context, const void* Diff); + + void* Owner = nullptr; + const void* RowTypeId = nullptr; + const void* EventContextTypeId = nullptr; + FInsertThunk InsertThunk = nullptr; + FUpdateThunk UpdateThunk = nullptr; + FDeleteThunk DeleteThunk = nullptr; + FDiffThunk DiffThunk = nullptr; + + bool IsComplete() const + { + return Owner != nullptr && + RowTypeId != nullptr && + EventContextTypeId != nullptr && + (DiffThunk != nullptr || + (InsertThunk != nullptr && + UpdateThunk != nullptr && + DeleteThunk != nullptr)); + } +}; + UCLASS() -class SPACETIMEDBSDK_API UDbConnectionBase : public UObject -{ - GENERATED_BODY() - -public: - - /** The default constructor is private to prevent instantiation without using the builder. */ - explicit UDbConnectionBase(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get()); - - /** Disconnect from the server. */ - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void Disconnect(); - - /** Check if the underlying WebSocket is connected. */ - UFUNCTION(BlueprintPure, Category="SpacetimeDB") - bool IsActive() const; - - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void FrameTick(); - +class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGameObject +{ + GENERATED_BODY() + +public: + + /** The default constructor is private to prevent instantiation without using the builder. */ + explicit UDbConnectionBase(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get()); + + virtual void BeginDestroy() override; + + /** Disconnect from the server. */ + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void Disconnect(); + + /** Check if the underlying WebSocket is connected. */ + UFUNCTION(BlueprintPure, Category="SpacetimeDB") + bool IsActive() const; + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void SetAutoTicking(bool bAutoTick); - - /** Send a raw JSON message to the server. */ - bool SendRawMessage(const FString& Message); - /** Send a raw binary message to the server. */ - bool SendRawMessage(const TArray& Message); - - /** Get the current subscription builder. This is used to create subscriptions. */ - UFUNCTION() - USubscriptionBuilderBase* SubscriptionBuilderBase(); - - /** Get the current identity of the SpacetimeDB instance. This is used to identify the connection. */ - UFUNCTION(BlueprintPure, Category = "SpacetimeDB") - bool TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const; - - /** Get the current connection id. This is used to identify the connection. */ - UFUNCTION(BlueprintPure, Category = "SpacetimeDB") - FSpacetimeDBConnectionId GetConnectionId() const; - - // Typed reducer call helper: hides BSATN bytes from callers. + void FrameTick(); + + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void SetAutoTicking(bool bAutoTick) { bIsAutoTicking = bAutoTick; } + + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void SetInboundApplyBudget(FSpacetimeDBInboundApplyBudget InBudget) + { + InBudget.Sanitize(); + InboundApplyBudget = InBudget; + } + + /** Send a raw JSON message to the server. */ + bool SendRawMessage(const FString& Message); + /** Send a raw binary message to the server. */ + bool SendRawMessage(const TArray& Message); + + /** Get the current subscription builder. This is used to create subscriptions. */ + UFUNCTION() + USubscriptionBuilderBase* SubscriptionBuilderBase(); + + /** Get the current identity of the SpacetimeDB instance. This is used to identify the connection. */ + UFUNCTION(BlueprintPure, Category = "SpacetimeDB") + bool TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const; + + /** Get the current connection id. This is used to identify the connection. */ + UFUNCTION(BlueprintPure, Category = "SpacetimeDB") + FSpacetimeDBConnectionId GetConnectionId() const; + + // Typed reducer call helper: hides BSATN bytes from callers. template uint32 CallReducerTyped(const FString& Reducer, const ArgsStruct& Args) { TArray Bytes = UE::SpacetimeDB::Serialize(Args); return InternalCallReducer(Reducer, MoveTemp(Bytes)); } - - template - void CallProcedureTyped(const FString& ProcedureName, const ArgsStruct& Args, const FOnProcedureCompleteDelegate& Callback) - { - TArray Bytes = UE::SpacetimeDB::Serialize(Args); - InternalCallProcedure(ProcedureName, MoveTemp(Bytes), Callback); - } - - template - void RegisterTable(const FString& TableName) - { - FScopeLock Lock(&TableDeserializersMutex); - TableDeserializers.Add(TableName, MakeShared>()); - } - - /** Internal interface for applying table updates generically */ - class ITableUpdateHandler - { - public: - virtual ~ITableUpdateHandler() {} - - /** Update the in-memory cache for the table and store the diff */ - virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) = 0; - - /** Broadcast the previously stored diff */ - virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; - }; - - template - class TTableUpdateHandler : public ITableUpdateHandler - { - public: - explicit TTableUpdateHandler(TableClass* InTable) : Table(InTable) {} - - //** Update the in-memory cache for the table and store the diff */ - virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) override - { - // Attempt to take preprocessed data if available - TSharedPtr> Pre; - if (Conn->TakePreprocessedTableData(Update, Pre)) - { - // If preprocessed data is available, use it to update the table - LastDiff = Table->Update(Pre->Inserts, Pre->Deletes); - } - else - { - // If no preprocessed data, process the update directly. Backup - UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("No preprocessed data for table update. Processing directly.")); - TArray> Inserts, Deletes; - UE::SpacetimeDB::ProcessTableUpdateWithBsatn(Update, Inserts, Deletes); - LastDiff = Table->Update(Inserts, Deletes); - } - } - //** Broadcast the last stored diff to the table's delegates */ - virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) override - { - EventContext& Ctx = *reinterpret_cast(Context); - Conn->BroadcastDiff(Table, LastDiff, Ctx); - } - - private: - TableClass* Table; - FTableAppliedDiff LastDiff; - }; - //** Register a table with the connection. This will allow the connection to handle updates for the table. - template - void RegisterTable(const FString& TableName, TableClass* Table) - { - RegisterTable(TableName); - FScopeLock Lock(&RegisteredTablesMutex); - RegisteredTables.Add(TableName, MakeShared>(Table)); - } - //** Take preprocessed table row data. */ - template + + template + void CallProcedureTyped(const FString& ProcedureName, const ArgsStruct& Args, const FOnProcedureCompleteDelegate& Callback) + { + TArray Bytes = UE::SpacetimeDB::Serialize(Args); + InternalCallProcedure(ProcedureName, MoveTemp(Bytes), Callback); + } + + template + void RegisterTable(const FString& TableName) + { + FScopeLock Lock(&TableDeserializersMutex); + TableDeserializers.Add(TableName, MakeShared>()); + } + + /** Internal interface for applying table updates generically */ + class ITableUpdateHandler + { + public: + virtual ~ITableUpdateHandler() {} + + /** Update the in-memory cache for the table and store the diff */ + virtual bool UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context, FSpacetimeDBTableApplyStats* OutStats) = 0; + + /** Broadcast the previously stored diff */ + virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; + virtual const FString& GetTableName() const = 0; + virtual void RegisterNativeListener(const FNativeTableListenerBinding& Binding) = 0; + virtual void UnregisterNativeListener(void* Owner) = 0; + }; + + template + class TTableUpdateHandler : public ITableUpdateHandler + { + public: + explicit TTableUpdateHandler(const FString& InTableName, TableClass* InTable) + : TableName(InTableName) + , Table(InTable) + { + } + + //** Update the in-memory cache for the table and store the diff */ + virtual bool UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context, FSpacetimeDBTableApplyStats* OutStats) override + { + if (PendingDiffReadIndex == PendingDiffs.Num()) + { + PendingDiffs.Reset(); + PendingDiffReadIndex = 0; + } + + TSharedPtr> Pre; + const bool bTookPreprocessedData = Conn->TakePreprocessedTableData(Update, Pre); + checkf(bTookPreprocessedData && Pre.IsValid(), TEXT("Missing message-scoped preprocessed data for table '%s'."), *Update.TableName); + if (OutStats != nullptr) + { + OutStats->TableName = TableName; + OutStats->RowSetCount = Pre->RowSetCount; + OutStats->InsertRowCount = Pre->InsertRowCount; + OutStats->DeleteRowCount = Pre->DeleteRowCount; + OutStats->InsertRowBytes = Pre->InsertRowBytes; + OutStats->DeleteRowBytes = Pre->DeleteRowBytes; + } + FTableAppliedDiff AppliedDiff = Table->Update(MoveTemp(Pre->Inserts), MoveTemp(Pre->Deletes)); + if (AppliedDiff.IsEmpty()) + { + if (OutStats != nullptr) + { + OutStats->bProducedDiff = false; + } + return false; + } + if (OutStats != nullptr) + { + OutStats->bProducedDiff = true; + } + PendingDiffs.Add(MoveTemp(AppliedDiff)); + return true; + } + //** Broadcast the last stored diff to the table's delegates */ + virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) override + { + checkf(PendingDiffReadIndex < PendingDiffs.Num(), TEXT("Missing pending SpacetimeDB table diff for broadcast.")); + EventContext& Ctx = *reinterpret_cast(Context); + const FTableAppliedDiff& Diff = PendingDiffs[PendingDiffReadIndex]; + if (!NativeListeners.IsEmpty()) + { + BroadcastNativeDiff(Diff, Ctx); + } + else + { + Conn->BroadcastDiff(Table, Diff, Ctx); + } + ++PendingDiffReadIndex; + if (PendingDiffReadIndex == PendingDiffs.Num()) + { + PendingDiffs.Reset(); + PendingDiffReadIndex = 0; + } + } + + virtual const FString& GetTableName() const override + { + return TableName; + } + + virtual void RegisterNativeListener(const FNativeTableListenerBinding& Binding) override + { + checkf(!bBroadcastingNativeListeners, + TEXT("Cannot register native SpacetimeDB table listener during broadcast for table '%s'."), + *TableName); + checkf(Binding.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); + checkf(Binding.RowTypeId == GetNativeTableListenerTypeId(), + TEXT("Native SpacetimeDB table listener row type mismatch for table '%s'."), *TableName); + checkf(Binding.EventContextTypeId == GetNativeTableListenerTypeId(), + TEXT("Native SpacetimeDB table listener context type mismatch for table '%s'."), *TableName); + for (const FNativeTableListenerBinding& ExistingBinding : NativeListeners) + { + checkf(ExistingBinding.Owner != Binding.Owner, + TEXT("Duplicate native SpacetimeDB table listener owner for table '%s'."), + *TableName); + } + NativeListeners.Add(Binding); + } + + virtual void UnregisterNativeListener(void* Owner) override + { + checkf(!bBroadcastingNativeListeners, + TEXT("Cannot unregister native SpacetimeDB table listener during broadcast for table '%s'."), + *TableName); + checkf(Owner != nullptr, TEXT("Cannot unregister null native SpacetimeDB table listener owner for table '%s'."), *TableName); + const int32 ListenerIndex = NativeListeners.IndexOfByPredicate( + [Owner](const FNativeTableListenerBinding& Binding) + { + return Binding.Owner == Owner; + }); + checkf(ListenerIndex != INDEX_NONE, + TEXT("Missing native SpacetimeDB table listener for table '%s'."), + *TableName); + NativeListeners.RemoveAtSwap(ListenerIndex, 1, EAllowShrinking::No); + } + + private: + void BroadcastNativeDiff(const FTableAppliedDiff& Diff, const EventContext& Context) + { + TGuardValue BroadcastingScope(bBroadcastingNativeListeners, true); + for (const FNativeTableListenerBinding& Listener : NativeListeners) + { + BroadcastNativeDiffToListener(Diff, Context, Listener); + } + } + + void BroadcastNativeDiffToListener( + const FTableAppliedDiff& Diff, + const EventContext& Context, + const FNativeTableListenerBinding& Listener) + { + checkf(Listener.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); + if (Listener.DiffThunk != nullptr) + { + Listener.DiffThunk(Listener.Owner, &Context, &Diff); + return; + } + + for (const TSharedPtr& Row : Diff.Inserts) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native insert diff row for table '%s'."), *TableName); + Listener.InsertThunk(Listener.Owner, &Context, Row.Get()); + } + + for (const TSharedPtr& Row : Diff.Deletes) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native delete diff row for table '%s'."), *TableName); + Listener.DeleteThunk(Listener.Owner, &Context, Row.Get()); + } + + checkf(Diff.UpdateDeletes.Num() == Diff.UpdateInserts.Num(), + TEXT("Mismatched SpacetimeDB native update diff counts for table '%s'."), *TableName); + for (int32 Index = 0; Index < Diff.UpdateInserts.Num(); ++Index) + { + const TSharedPtr& OldRow = Diff.UpdateDeletes[Index]; + const TSharedPtr& NewRow = Diff.UpdateInserts[Index]; + checkf(OldRow.IsValid() && NewRow.IsValid(), TEXT("Invalid SpacetimeDB native update diff row for table '%s'."), *TableName); + Listener.UpdateThunk(Listener.Owner, &Context, OldRow.Get(), NewRow.Get()); + } + } + + FString TableName; + TableClass* Table; + TArray> PendingDiffs; + int32 PendingDiffReadIndex = 0; + TArray NativeListeners; + bool bBroadcastingNativeListeners = false; + }; + //** Register a table with the connection. This will allow the connection to handle updates for the table. + template + void RegisterTable(const FString& TableName, TableClass* Table) + { + RegisterTable(TableName); + FScopeLock Lock(&RegisteredTablesMutex); + RegisteredTables.Add(TableName, MakeShared>(TableName, Table)); + RegisteredTablesSnapshot = MakeShared>>(RegisteredTables); + } + + template + void RegisterNativeTableListener(const FString& TableName, OwnerType* Owner) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot register null native SpacetimeDB table listener owner for table '%s'."), *TableName); + + FNativeTableListenerBinding Binding; + Binding.Owner = Owner; + Binding.RowTypeId = GetNativeTableListenerTypeId(); + Binding.EventContextTypeId = GetNativeTableListenerTypeId(); + Binding.InsertThunk = [](void* RawOwner, const void* RawContext, const void* RawRow) + { + (static_cast(RawOwner)->*InsertFn)( + *static_cast(RawContext), + *static_cast(RawRow)); + }; + Binding.UpdateThunk = [](void* RawOwner, const void* RawContext, const void* RawOldRow, const void* RawNewRow) + { + (static_cast(RawOwner)->*UpdateFn)( + *static_cast(RawContext), + *static_cast(RawOldRow), + *static_cast(RawNewRow)); + }; + Binding.DeleteThunk = [](void* RawOwner, const void* RawContext, const void* RawRow) + { + (static_cast(RawOwner)->*DeleteFn)( + *static_cast(RawContext), + *static_cast(RawRow)); + }; + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while registering native listener for table '%s'."), *TableName); + (*Handler)->RegisterNativeListener(Binding); + } + + template&)> + void RegisterNativeTableDiffListener(const FString& TableName, OwnerType* Owner) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table diff listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot register null native SpacetimeDB table diff listener owner for table '%s'."), *TableName); + + FNativeTableListenerBinding Binding; + Binding.Owner = Owner; + Binding.RowTypeId = GetNativeTableListenerTypeId(); + Binding.EventContextTypeId = GetNativeTableListenerTypeId(); + Binding.DiffThunk = [](void* RawOwner, const void* RawContext, const void* RawDiff) + { + (static_cast(RawOwner)->*DiffFn)( + *static_cast(RawContext), + *static_cast*>(RawDiff)); + }; + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while registering native diff listener for table '%s'."), *TableName); + (*Handler)->RegisterNativeListener(Binding); + } + + template + void UnregisterNativeTableListener(const FString& TableName, OwnerType* Owner) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot unregister null native SpacetimeDB table listener owner for table '%s'."), *TableName); + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while unregistering native listener for table '%s'."), *TableName); + (*Handler)->UnregisterNativeListener(Owner); + } + + template&)> + void UnregisterNativeTableDiffListener(const FString& TableName, OwnerType* Owner) + { + static_assert(std::is_base_of_v, "Native SpacetimeDB table diff listener owner must derive from UObject."); + checkf(Owner != nullptr, TEXT("Cannot unregister null native SpacetimeDB table diff listener owner for table '%s'."), *TableName); + + FScopeLock Lock(&RegisteredTablesMutex); + TSharedPtr* Handler = RegisteredTables.Find(TableName); + checkf(Handler != nullptr && Handler->IsValid(), + TEXT("Missing SpacetimeDB table handler while unregistering native diff listener for table '%s'."), *TableName); + (*Handler)->UnregisterNativeListener(Owner); + } + //** Take preprocessed table row data. */ + template bool TakePreprocessedTableData(const FTableUpdateType& Update, TSharedPtr>& OutData) { - FScopeLock Lock(&PreprocessedDataMutex); + checkf(ActivePreprocessedTableData != nullptr, TEXT("No active inbound message while applying table update '%s'."), *Update.TableName); FPreprocessedTableKey Key(Update.TableName); - if (TArray>* Found = PreprocessedTableData.Find(Key)) + TArray>* Found = ActivePreprocessedTableData->Find(Key); + checkf(Found != nullptr && Found->Num() > 0, TEXT("Missing message-scoped preprocessed data for table '%s'."), *Update.TableName); + OutData = StaticCastSharedPtr>((*Found)[0]); + Found->RemoveAt(0, 1, EAllowShrinking::No); + if (Found->Num() == 0) { - if (Found->Num() > 0) - { - OutData = StaticCastSharedPtr>((*Found)[0]); - Found->RemoveAt(0); - if (Found->Num() == 0) - { - PreprocessedTableData.Remove(Key); - } - return OutData.IsValid(); - } - } - return false; - } - - + ActivePreprocessedTableData->Remove(Key); + } + checkf(OutData.IsValid(), TEXT("Invalid message-scoped preprocessed data for table '%s'."), *Update.TableName); + return true; + } + + protected: - - virtual void BeginDestroy() override; friend class UDbConnectionBuilderBase; - friend class UDbConnectionBuilder; - friend class USubscriptionHandleBase; - friend class USubscriptionBuilder; - friend class URemoteReducers; - - /** Allow derived classes to override the delegates used when connecting */ - void SetOnConnectDelegate(const FOnConnectBaseDelegate& Delegate) { OnConnectBaseDelegate = Delegate; } - void SetOnDisconnectDelegate(const FOnDisconnectBaseDelegate& Delegate) { OnDisconnectBaseDelegate = Delegate; } - - UFUNCTION() - void HandleWSError(const FString& Error); - UFUNCTION() - void HandleWSClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - UFUNCTION() - void HandleWSBinaryMessage(const TArray& Message); - - bool OnTickerTick(float DeltaTime); - + friend class UDbConnectionBuilder; + friend class FSpacetimeDbInboundWorker; + friend class UWebsocketManager; + friend class USubscriptionHandleBase; + friend class USubscriptionBuilder; + friend class URemoteReducers; + + /** Allow derived classes to override the delegates used when connecting */ + void SetOnConnectDelegate(const FOnConnectBaseDelegate& Delegate) { OnConnectBaseDelegate = Delegate; } + void SetOnDisconnectDelegate(const FOnDisconnectBaseDelegate& Delegate) { OnDisconnectBaseDelegate = Delegate; } + + UFUNCTION() + void HandleWSError(const FString& Error); + UFUNCTION() + void HandleWSClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + UFUNCTION() + void HandleWSBinaryMessage(const TArray& Message); + void HandleWSBinaryMessageOwned(TArray&& Message); + void StartInboundMessageWorker(); + void StopInboundMessageWorker(); + void ClearInboundMessageQueues(); + void NotifyInboundWorkerIfNeeded(); + void DrainInboundRawMessagesOnWorker(); + bool BuildInboundParsedMessage(const FInboundRawMessage& RawMessage, FInboundParsedMessage& OutMessage); + void EnqueueInboundProtocolError(uint64 SequenceId, int32 PayloadSizeBytes, uint8 CompressionTag, int32 QueueDepthAtEnqueue, int64 QueuedBytesAtEnqueue, const FString& ErrorMessage); + bool IsInboundProtocolErrorQueued() const; + bool IsInboundEpochCurrentAndAccepting(uint64 ConnectionEpoch) const; + void MarkInboundProtocolErrorQueued(); + + virtual void Tick(float DeltaTime) override; + + virtual TStatId GetStatId() const override; + + virtual bool IsTickable() const override; + + virtual bool IsTickableInEditor() const override; + /** Internal handler that processes a single server message. */ void ProcessServerMessage(const FServerMessageType& Message); - void PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update); + void ProcessInboundServerMessage(FInboundParsedMessage& InboundMessage, FSpacetimeDBInboundMessageApplyStats& ApplyStats); + void PreProcessTableUpdateRows(const FString& TableName, const TArray& RowSets, FPreprocessedTableDataMap& OutPreprocessedTableData); + void PreProcessQueryRows(const FQueryRowsType& Rows, UE::SpacetimeDB::EQueryRowsApplyMode Mode, FPreprocessedTableDataMap& OutPreprocessedTableData); + void PreProcessTransactionUpdate(const FTransactionUpdateType& Update, FPreprocessedTableDataMap& OutPreprocessedTableData); + TSharedPtr FindTableDeserializerForPreprocess(const FString& TableName); + void StorePreprocessedTableData(const FString& TableName, TSharedPtr Data, FPreprocessedTableDataMap& OutPreprocessedTableData); /** Decompress and parse a raw message. */ - bool PreProcessMessage(const TArray& Message, FServerMessageType& OutMessage); - bool DecompressPayload(uint8 Variant, const TArray& In, TArray& Out); - bool DecompressGzip(const TArray& InData, TArray& OutData); - bool DecompressBrotli(const TArray& InData, TArray& OutData); + bool PreProcessMessage(const TArray& Message, FInboundParsedMessage& OutMessage); + bool DecompressGzip(const uint8* InData, int32 InSize, TArray& OutData); + bool DecompressBrotli(const uint8* InData, int32 InSize, TArray& OutData); void ClearPendingOperations(const FString& Reason); void HandleProtocolViolation(const FString& ErrorMessage); - - /** Pending messages awaiting processing on the game thread. */ - TArray PendingMessages; - - /** Mutex protecting access to PendingMessages. */ - FCriticalSection PendingMessagesMutex; - - /** Map of preprocessed messages keyed by their sequential id. */ - TMap PreprocessedMessages; - - /** Protects PreprocessedMessages and PendingMessages ordering state. */ - FCriticalSection PreprocessMutex; - - /** Counter for assigning ids to incoming messages. */ - FThreadSafeCounter NextPreprocessId; - - /** Id of the next message expected to be released. */ - int32 NextReleaseId = 0; - - // Map of table name to row deserializer - TMap> TableDeserializers; - FCriticalSection TableDeserializersMutex; - - // Map from table update pointer to preprocessed data - TMap>> PreprocessedTableData; - FCriticalSection PreprocessedDataMutex; - - // Map of table name to generic table update handler - TMap> RegisteredTables; - FCriticalSection RegisteredTablesMutex; - - - /** Start a subscription. This will add the subscription to the active list and send a subscribe message to the server. */ - void StartSubscription(USubscriptionHandleBase* Handle); - /** Unsubscribe from a subscription. This will remove the subscription from the active list and send an unsubscribe message to the server. */ - void UnsubscribeInternal(USubscriptionHandleBase* Handle); - - /** Call a reducer on the connected SpacetimeDB instance. */ + + /** Parsed inbound messages awaiting processing on the game thread. */ + TArray PendingMessages; + + /** Mutex protecting access to PendingMessages. */ + FCriticalSection PendingMessagesMutex; + int32 PendingMessageReadIndex = 0; + int64 PendingParsedPayloadBytes = 0; + + /** Raw inbound messages awaiting FIFO processing by the connection-owned worker. */ + TArray InboundRawMessages; + mutable FCriticalSection InboundRawMessagesMutex; + int64 InboundQueuedRawBytes = 0; + uint64 InboundConnectionEpoch = 0; + uint64 NextInboundSequenceId = 0; + bool bInboundAcceptingMessages = false; + bool bInboundProtocolErrorQueued = false; + FSpacetimeDbInboundWorker* InboundWorker = nullptr; + FCriticalSection InboundWorkerMutex; + + // Map of table name to row deserializer + TMap> TableDeserializers; + FCriticalSection TableDeserializersMutex; + + // Message-scoped preprocessed table rows active only while applying one inbound server message. + FPreprocessedTableDataMap* ActivePreprocessedTableData = nullptr; + + // Map of table name to generic table update handler + TMap> RegisteredTables; + TSharedPtr>> RegisteredTablesSnapshot; + FCriticalSection RegisteredTablesMutex; + + + /** Start a subscription. This will add the subscription to the active list and send a subscribe message to the server. */ + void StartSubscription(USubscriptionHandleBase* Handle); + /** Unsubscribe from a subscription. This will remove the subscription from the active list and send an unsubscribe message to the server. */ + void UnsubscribeInternal(USubscriptionHandleBase* Handle); + + /** Call a reducer on the connected SpacetimeDB instance. */ uint32 InternalCallReducer(const FString& Reducer, TArray Args); - - /** Call a reducer on the connected SpacetimeDB instance. */ - void InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback); - - /** - * Update function to apply database changes. - * Must be implemented by child classes. - * @param Update - Struct containing update data. - */ - virtual void DbUpdate(const FDatabaseUpdateType& Update, const FSpacetimeDBEvent& Event) {}; - - /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ - virtual void ReducerEvent(const FReducerEvent& Event) {}; - - /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ - virtual void ReducerEventFailed(const FReducerEvent& Event, const FString ErrorMessage) {}; - - /** Event handler for procedure events. This can should overridden by child classes to handle specific procedure events. */ - virtual void ProcedureEventFailed(const FProcedureEvent& Event, const FString ErrorMessage) {}; - - /** Event handler for error events. This can should overridden by child classes to handle specific error events. */ - virtual void TriggerError(const FString& ErrorMessage) {}; - - /** Event handler for subscription events. This can should overridden by child classes to handle specific subscription events. */ - virtual void TriggerSubscription() {}; - - /** Apply updates for all registered tables using the provided context pointer */ - void ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context); - + + /** Call a reducer on the connected SpacetimeDB instance. */ + void InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback); + + /** + * Update function to apply database changes. + * Must be implemented by child classes. + * @param Update - Struct containing update data. + */ + virtual void DbUpdate(const FDatabaseUpdateType& Update, const FSpacetimeDBEvent& Event) {}; + + /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ + virtual void ReducerEvent(const FReducerEvent& Event) {}; + + /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ + virtual void ReducerEventFailed(const FReducerEvent& Event, const FString ErrorMessage) {}; + + /** Event handler for procedure events. This can should overridden by child classes to handle specific procedure events. */ + virtual void ProcedureEventFailed(const FProcedureEvent& Event, const FString ErrorMessage) {}; + + /** Event handler for error events. This can should overridden by child classes to handle specific error events. */ + virtual void TriggerError(const FString& ErrorMessage) {}; + + /** Event handler for subscription events. This can should overridden by child classes to handle specific subscription events. */ + virtual void TriggerSubscription() {}; + + /** Apply updates for all registered tables using the provided context pointer */ + void ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context); + /** Called when a subscription is updated. */ UPROPERTY() TMap> ActiveSubscriptions; @@ -366,9 +737,9 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject /** Pending reducer call metadata keyed by request id for ReducerResult correlation. */ UPROPERTY() TMap PendingReducerCalls; - - UPROPERTY() - TObjectPtr ProcedureCallbacks; + + UPROPERTY() + TObjectPtr ProcedureCallbacks; /** Get the next request id for a message. This is used to track requests and responses. */ uint32 NextRequestId; /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ @@ -377,67 +748,78 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject uint32 GetNextRequestId(); /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ uint32 GetNextSubscriptionId(); - - /** The WebSocket manager used to connect to the server. */ - UPROPERTY() - UWebsocketManager* WebSocket = nullptr; - - /** The URI of the SpacetimeDB server to connect to. */ - UPROPERTY() - FString Uri; - /** The module name to connect to. This is used to identify the SpacetimeDB instance. */ - UPROPERTY() - FString ModuleName; - /** The token used to authenticate the connection. */ - UPROPERTY() - FString Token; - - /** The identity of the SpacetimeDB instance. This is used to identify the connection. */ - UPROPERTY() - FSpacetimeDBIdentity Identity; - UPROPERTY() - /** Whether the identity has been set. This is used to prevent multiple identity sets. */ - bool bIsIdentitySet = false; - /** The connection id of the SpacetimeDB instance. This is used to identify the connection. */ - UPROPERTY() - FSpacetimeDBConnectionId ConnectionId; - + + /** The WebSocket manager used to connect to the server. */ + UPROPERTY() + UWebsocketManager* WebSocket = nullptr; + + /** The URI of the SpacetimeDB server to connect to. */ + UPROPERTY() + FString Uri; + /** The module name to connect to. This is used to identify the SpacetimeDB instance. */ + UPROPERTY() + FString ModuleName; + /** The token used to authenticate the connection. */ + UPROPERTY() + FString Token; + + /** The identity of the SpacetimeDB instance. This is used to identify the connection. */ + UPROPERTY() + FSpacetimeDBIdentity Identity; + UPROPERTY() + /** Whether the identity has been set. This is used to prevent multiple identity sets. */ + bool bIsIdentitySet = false; + /** The connection id of the SpacetimeDB instance. This is used to identify the connection. */ + UPROPERTY() + FSpacetimeDBConnectionId ConnectionId; + UPROPERTY() bool bIsAutoTicking = false; - FTSTicker::FDelegateHandle TickerHandle; + + FSpacetimeDBInboundApplyBudget InboundApplyBudget; + struct FPendingTableBroadcast + { + TSharedPtr Handler; + int32 StatsIndex = INDEX_NONE; + }; + TArray TableUpdateHandlersScratch; + FSpacetimeDBInboundMessageApplyStats* ActiveInboundMessageApplyStats = nullptr; + /** Guard to avoid repeatedly handling the same fatal protocol error. */ FThreadSafeBool bProtocolViolationHandled = false; - - UPROPERTY() - FOnConnectErrorDelegate OnConnectErrorDelegate; - UPROPERTY() - FOnDisconnectBaseDelegate OnDisconnectBaseDelegate; - UPROPERTY() - FOnConnectBaseDelegate OnConnectBaseDelegate; - - /** Called when the connection is established. */ - template - void BroadcastDiff(TableClass* Table, const FTableAppliedDiff& Diff, const EventContext& Context) - { - if (!Table) return; - - // Broadcast the diff to the table's delegates - if (Table->OnInsert.IsBound()) - { - for (const TPair, RowType>& Pair : Diff.Inserts) - { - Table->OnInsert.Broadcast(Context, Pair.Value); - } - } - + + UPROPERTY() + FOnConnectErrorDelegate OnConnectErrorDelegate; + UPROPERTY() + FOnDisconnectBaseDelegate OnDisconnectBaseDelegate; + UPROPERTY() + FOnConnectBaseDelegate OnConnectBaseDelegate; + + /** Called when the connection is established. */ + template + void BroadcastDiff(TableClass* Table, const FTableAppliedDiff& Diff, const EventContext& Context) + { + if (!Table) return; + + // Broadcast the diff to the table's delegates + if (Table->OnInsert.IsBound()) + { + for (const TSharedPtr& Row : Diff.Inserts) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB insert diff row.")); + Table->OnInsert.Broadcast(Context, *Row); + } + } + // Event tables intentionally omit delete/update delegates. if constexpr (THasOnDeleteDelegate::value) { if (Table->OnDelete.IsBound()) { - for (const TPair, RowType>& Pair : Diff.Deletes) + for (const TSharedPtr& Row : Diff.Deletes) { - Table->OnDelete.Broadcast(Context, Pair.Value); + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB delete diff row.")); + Table->OnDelete.Broadcast(Context, *Row); } } } @@ -446,12 +828,13 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject { if (Table->OnUpdate.IsBound()) { - int32 Count = FMath::Min(Diff.UpdateDeletes.Num(), Diff.UpdateInserts.Num()); - for (int32 Index = 0; Index < Count; ++Index) + checkf(Diff.UpdateDeletes.Num() == Diff.UpdateInserts.Num(), TEXT("Mismatched SpacetimeDB update diff counts.")); + for (int32 Index = 0; Index < Diff.UpdateInserts.Num(); ++Index) { - const RowType& OldRow = Diff.UpdateDeletes[Index]; - const RowType& NewRow = Diff.UpdateInserts[Index]; - Table->OnUpdate.Broadcast(Context, OldRow, NewRow); + const TSharedPtr& OldRow = Diff.UpdateDeletes[Index]; + const TSharedPtr& NewRow = Diff.UpdateInserts[Index]; + checkf(OldRow.IsValid() && NewRow.IsValid(), TEXT("Invalid SpacetimeDB update diff row.")); + Table->OnUpdate.Broadcast(Context, *OldRow, *NewRow); } } } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h index b9e6f91378d..15d375f7f2b 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h @@ -29,11 +29,11 @@ DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketBinaryMessageReceived, c * Handles connecting, disconnecting, sending messages, and receiving messages. */ UCLASS(BlueprintType) -class SPACETIMEDBSDK_API UWebsocketManager : public UObject -{ - GENERATED_BODY() - -public: +class SPACETIMEDBSDK_API UWebsocketManager : public UObject +{ + GENERATED_BODY() + +public: UWebsocketManager(); virtual void BeginDestroy() override; @@ -66,14 +66,16 @@ class SPACETIMEDBSDK_API UWebsocketManager : public UObject /** * Checks if the WebSocket connection is currently active. * @return True if connected, false otherwise. - */ - bool IsConnected() const; - - /** - * Sets the initial auth token used when connecting. - * @param Token JWT or session token expected by the server. - */ - void SetInitToken(FString Token); + */ + bool IsConnected() const; + + void SetNativeBinaryMessageTarget(class UDbConnectionBase* Target); + + /** + * Sets the initial auth token used when connecting. + * @param Token JWT or session token expected by the server. + */ + void SetInitToken(FString Token); /** Delegates for WebSocket events */ UPROPERTY() @@ -105,16 +107,18 @@ class SPACETIMEDBSDK_API UWebsocketManager : public UObject void HandleConnectionError(const FString& Error); /** Handler for incoming text messages */ void HandleMessageReceived(const FString& Message); - /** Handler for incoming binary messages */ - void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); - /** Handler for socket close */ - void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - - FString InitToken; - - /** Buffer used to accumulate binary fragments until a complete message - * is received. */ - TArray IncompleteMessage; + /** Handler for incoming binary messages */ + void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); + void DispatchCompleteBinaryMessage(TArray&& Message); + /** Handler for socket close */ + void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + + FString InitToken; + TWeakObjectPtr NativeBinaryMessageTarget; + + /** Buffer used to accumulate binary fragments until a complete message + * is received. */ + TArray IncompleteMessage; /** Tracks if we are waiting for additional binary fragments. */ bool bAwaitingBinaryFragments = false; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h index 70693f4338c..6682607598c 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h @@ -1,7 +1,96 @@ #pragma once #include "CoreMinimal.h" +#include "BSATN/UESpacetimeDB.h" #include "TableCache.h" #include "TableAppliedDiff.h" +#include "WithBsatn.h" + +#include +#include + +namespace UE::SpacetimeDB +{ +enum class ETableCacheApplyMode : uint8 +{ + PersistentIndexed, + DirectNativeDiff +}; + +namespace Private +{ + static constexpr const TCHAR* MatchIdIndexName = TEXT("match_id"); + static constexpr const TCHAR* MobRuntimeSnapshotBatchTableName = TEXT("mob_runtime_snapshot_batch"); + static constexpr const TCHAR* MobCombatStateFrameTableName = TEXT("mob_combat_state_frame"); + static constexpr const TCHAR* MobAttackVisualBatchTableName = TEXT("mob_attack_visual_batch"); + static constexpr const TCHAR* MobProjectileVisualBatchTableName = TEXT("mob_projectile_visual_batch"); + static constexpr const TCHAR* AbilityCastVisualBatchTableName = TEXT("ability_cast_visual_batch"); + static constexpr const TCHAR* PlayerMotionFrameTableName = TEXT("player_motion_frame"); + static constexpr const TCHAR* PlayerCombatStateFrameTableName = TEXT("player_combat_state_frame"); + + template + struct THasUint64FrameKey + { + static constexpr bool Value = false; + }; + + template + struct THasUint64FrameKey().FrameKey)>> + { + using FieldType = std::remove_cv_t().FrameKey)>>; + static constexpr bool Value = std::is_same_v; + }; + + template + struct THasUint64BatchKey + { + static constexpr bool Value = false; + }; + + template + struct THasUint64BatchKey().BatchKey)>> + { + using FieldType = std::remove_cv_t().BatchKey)>>; + static constexpr bool Value = std::is_same_v; + }; +} + +template +struct TCompactPrimaryKeyTraits +{ + static constexpr bool bHasFrameKey = Private::THasUint64FrameKey::Value; + static constexpr bool bHasBatchKey = Private::THasUint64BatchKey::Value; + static_assert(!(bHasFrameKey && bHasBatchKey), "SpacetimeDB compact cache key trait requires exactly one generated key field."); + static constexpr bool bEnabled = bHasFrameKey || bHasBatchKey; + + using KeyType = uint64; + + static KeyType GetKey(const RowType& Row) + { + if constexpr (bHasFrameKey) + { + return Row.FrameKey; + } + else + { + static_assert(bHasBatchKey, "SpacetimeDB compact cache key trait is not active for this row type."); + return Row.BatchKey; + } + } + + static const TCHAR* GetUniqueIndexName() + { + if constexpr (bHasFrameKey) + { + return TEXT("frame_key"); + } + else + { + static_assert(bHasBatchKey, "SpacetimeDB compact cache key trait is not active for this row type."); + return TEXT("batch_key"); + } + } +}; +} /* ============================================================================ * * ClientCache.h (2025-05-28) @@ -20,6 +109,16 @@ class UClientCache */ TSharedPtr> Table; + void SetApplyMode(UE::SpacetimeDB::ETableCacheApplyMode InApplyMode) + { + ApplyMode = InApplyMode; + } + + UE::SpacetimeDB::ETableCacheApplyMode GetApplyMode() const + { + return ApplyMode; + } + /** * Retrieves the existing table cache or creates a new one if none exists. @@ -70,6 +169,261 @@ class UClientCache } + FTableAppliedDiff ApplyDiffByPrimaryKey( + const FString& Name, + TArray>&& Inserts, + TArray>&& Deletes, + const TCHAR* ExpectedUniqueIndexName) + { + using FCompactPrimaryKeyTraits = UE::SpacetimeDB::TCompactPrimaryKeyTraits; + static_assert(FCompactPrimaryKeyTraits::bEnabled, "ApplyDiffByPrimaryKey requires a generated compact primary-key trait."); + using KeyType = typename FCompactPrimaryKeyTraits::KeyType; + + checkf(!Name.IsEmpty(), TEXT("ApplyDiffByPrimaryKey called with empty table name.")); + checkf(Table.IsValid(), TEXT("ApplyDiffByPrimaryKey could not find table cache for %s."), *Name); + checkf(ExpectedUniqueIndexName != nullptr && ExpectedUniqueIndexName[0] != TEXT('\0'), + TEXT("ApplyDiffByPrimaryKey for %s requires a generated unique index name."), *Name); + const FString ExpectedIndexName(ExpectedUniqueIndexName); + checkf(Table->UniqueIndices.Contains(ExpectedIndexName), + TEXT("ApplyDiffByPrimaryKey for %s requires generated unique index %s."), + *Name, + ExpectedUniqueIndexName); + + if (ShouldApplyDirectNativeDiff(Name)) + { + return BuildDirectDiffByPrimaryKey(Name, MoveTemp(Inserts), MoveTemp(Deletes)); + } + + struct FDeletedRow + { + TArray CacheKey; + TSharedPtr Row; + int32 PendingCount = 0; + bool bUpdateApplied = false; + }; + + auto BuildCacheKey = [](const KeyType& Key) + { + static constexpr int32 CompactPrimaryKeyBytes = sizeof(KeyType); + TArray CacheKey; + CacheKey.SetNumUninitialized(CompactPrimaryKeyBytes); + uint64 Remaining = Key; + for (int32 ByteIndex = 0; ByteIndex < CompactPrimaryKeyBytes; ++ByteIndex) + { + CacheKey[ByteIndex] = static_cast(Remaining & 0xffu); + Remaining >>= 8; + } + return CacheKey; + }; + + auto RemoveFromIndices = [this, &Name](const TArray& Key, const TSharedPtr& Row) + { + checkf(Row.IsValid(), TEXT("Cannot remove invalid row from table indices.")); + for (auto& IndexPair : Table->UniqueIndices) + { + IndexPair.Value->RemoveRow(Row); + } + for (auto& IndexPair : Table->BTreeIndices) + { + if (ShouldSkipRuntimeApplyBTreeIndex(Name, IndexPair.Key)) + { + continue; + } + IndexPair.Value->RemoveRow(Key, Row); + } + }; + + auto AddToIndices = [this, &Name](const TArray& Key, const TSharedPtr& Row) + { + checkf(Row.IsValid(), TEXT("Cannot add invalid row to table indices.")); + for (auto& IndexPair : Table->UniqueIndices) + { + IndexPair.Value->AddRow(Row); + } + for (auto& IndexPair : Table->BTreeIndices) + { + if (ShouldSkipRuntimeApplyBTreeIndex(Name, IndexPair.Key)) + { + continue; + } + IndexPair.Value->AddRow(Key, Row); + } + }; + + FTableAppliedDiff Diff; + Diff.bPrimaryKeyUpdatesClassified = true; + Diff.Inserts.Reserve(Inserts.Num()); + Diff.Deletes.Reserve(Deletes.Num()); + Diff.UpdateDeletes.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + Diff.UpdateInserts.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + + TMap DeletedRows; + DeletedRows.Reserve(Deletes.Num()); + + for (const FWithBsatn& Delete : Deletes) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Delete.Row); + const TArray CacheKey = BuildCacheKey(PrimaryKey); + FRowEntry* Entry = Table->Entries.Find(CacheKey); + if (!Entry) + { + continue; + } + + checkf(Entry->RefCount > 0, TEXT("Table cache row for %s has invalid refcount before primary-key delete."), *Name); + checkf(Entry->Row.IsValid(), TEXT("Table cache row for %s is invalid before primary-key delete."), *Name); + FDeletedRow& Deleted = DeletedRows.FindOrAdd(PrimaryKey); + if (!Deleted.Row.IsValid()) + { + Deleted.CacheKey = CacheKey; + Deleted.Row = Entry->Row; + } + checkf(Deleted.CacheKey == CacheKey, TEXT("Mismatched compact cache key for primary-key delete on %s."), *Name); + ++Deleted.PendingCount; + --Entry->RefCount; + checkf(Entry->RefCount >= 0, TEXT("Table cache row for %s has negative refcount after primary-key delete."), *Name); + } + + for (FWithBsatn& Insert : Inserts) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Insert.Row); + const TArray CacheKey = BuildCacheKey(PrimaryKey); + FDeletedRow* MatchingDelete = DeletedRows.Find(PrimaryKey); + if (MatchingDelete && MatchingDelete->PendingCount > 0) + { + checkf(MatchingDelete->CacheKey == CacheKey, + TEXT("Mismatched compact cache key for primary-key update on %s."), *Name); + FRowEntry* Entry = Table->Entries.Find(CacheKey); + checkf(Entry != nullptr, TEXT("Missing table cache row for primary-key update on %s."), *Name); + checkf(Entry->RefCount >= 0, TEXT("Table cache row for %s has invalid refcount before primary-key update insert."), *Name); + checkf(MatchingDelete->Row.IsValid(), TEXT("Invalid deleted row for primary-key update on %s."), *Name); + + ++Entry->RefCount; + if (!MatchingDelete->bUpdateApplied) + { + TSharedPtr OldRow = MatchingDelete->Row; + TSharedPtr NewRow = MakeShared(MoveTemp(Insert.Row)); + RemoveFromIndices(CacheKey, OldRow); + Entry->Row = NewRow; + AddToIndices(CacheKey, NewRow); + + Diff.UpdateDeletes.Add(OldRow); + Diff.UpdateInserts.Add(NewRow); + MatchingDelete->bUpdateApplied = true; + MatchingDelete->Row = NewRow; + } + --MatchingDelete->PendingCount; + continue; + } + + FRowEntry* ExistingEntry = Table->Entries.Find(CacheKey); + if (ExistingEntry) + { + checkf(ExistingEntry->RefCount > 0, + TEXT("Primary-key insert for %s found an existing row with invalid refcount."), *Name); + ++ExistingEntry->RefCount; + continue; + } + + TSharedPtr NewRow = MakeShared(MoveTemp(Insert.Row)); + Table->Entries.Add(CacheKey, FRowEntry{NewRow, 1}); + AddToIndices(CacheKey, NewRow); + Diff.Inserts.Add(NewRow); + } + + for (const TPair& DeletedPair : DeletedRows) + { + const FDeletedRow& Deleted = DeletedPair.Value; + if (Deleted.PendingCount <= 0) + { + continue; + } + + FRowEntry* Entry = Table->Entries.Find(Deleted.CacheKey); + if (!Entry || Entry->RefCount > 0) + { + continue; + } + + checkf(Deleted.Row.IsValid(), TEXT("Invalid deleted row for primary-key delete on %s."), *Name); + checkf(Entry->RefCount == 0, TEXT("Primary-key delete for %s reached impossible refcount state."), *Name); + RemoveFromIndices(Deleted.CacheKey, Deleted.Row); + Diff.Deletes.Add(Deleted.Row); + Table->Entries.Remove(Deleted.CacheKey); + } + + return Diff; + } + +private: + FTableAppliedDiff BuildDirectDiffByPrimaryKey( + const FString& Name, + TArray>&& Inserts, + TArray>&& Deletes) + { + using FCompactPrimaryKeyTraits = UE::SpacetimeDB::TCompactPrimaryKeyTraits; + static_assert(FCompactPrimaryKeyTraits::bEnabled, "BuildDirectDiffByPrimaryKey requires a generated compact primary-key trait."); + using KeyType = typename FCompactPrimaryKeyTraits::KeyType; + + checkf(!Name.IsEmpty(), TEXT("BuildDirectDiffByPrimaryKey called with empty table name.")); + + FTableAppliedDiff Diff; + Diff.bPrimaryKeyUpdatesClassified = true; + Diff.Inserts.Reserve(Inserts.Num()); + Diff.Deletes.Reserve(Deletes.Num()); + Diff.UpdateDeletes.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + Diff.UpdateInserts.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + + TMap>> DeletesByKey; + DeletesByKey.Reserve(Deletes.Num()); + for (FWithBsatn& Delete : Deletes) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Delete.Row); + TArray>& Rows = DeletesByKey.FindOrAdd(PrimaryKey); + Rows.Add(MakeShared(MoveTemp(Delete.Row))); + } + + for (FWithBsatn& Insert : Inserts) + { + const KeyType PrimaryKey = FCompactPrimaryKeyTraits::GetKey(Insert.Row); + if (TArray>* MatchingDeletes = DeletesByKey.Find(PrimaryKey)) + { + checkf(!MatchingDeletes->IsEmpty(), + TEXT("Direct compact diff for %s found an empty delete bucket."), + *Name); + TSharedPtr OldRow = MatchingDeletes->Pop(EAllowShrinking::No); + checkf(OldRow.IsValid(), + TEXT("Direct compact diff for %s found an invalid deleted row."), + *Name); + if (MatchingDeletes->IsEmpty()) + { + DeletesByKey.Remove(PrimaryKey); + } + + Diff.UpdateDeletes.Add(OldRow); + Diff.UpdateInserts.Add(MakeShared(MoveTemp(Insert.Row))); + continue; + } + + Diff.Inserts.Add(MakeShared(MoveTemp(Insert.Row))); + } + + for (TPair>>& DeletePair : DeletesByKey) + { + for (TSharedPtr& DeletedRow : DeletePair.Value) + { + checkf(DeletedRow.IsValid(), + TEXT("Direct compact diff for %s retained an invalid deleted row."), + *Name); + Diff.Deletes.Add(MoveTemp(DeletedRow)); + } + } + + return Diff; + } + +public: + /** * Apply Inserts + Deletes to the specified table. * Inserts: increment refCount, add new entry when needed. @@ -77,8 +431,8 @@ class UClientCache */ FTableAppliedDiff ApplyDiff( const FString& Name, - const TArray, RowType>>& Inserts, - const TArray>& Deletes) + TArray>&& Inserts, + TArray>&& Deletes) { if (Name.IsEmpty()) { @@ -92,95 +446,141 @@ class UClientCache return FTableAppliedDiff(); } - FTableAppliedDiff Diff; - - // Map of deleted SerializedBytes -> (Key, Row) - // The key type is now generic TArray - TMap, TPair, TSharedPtr>> DeletedEntries; - - // Phase 1: Pre-process Deletes - for (const TArray& Key : Deletes) + struct FDeletedRow { + TSharedPtr Row; + bool bMatchedInsert = false; + }; + struct FInsertedRow + { + TArray Key; + TSharedPtr Row; + }; - FRowEntry* Entry = Table->Entries.Find(Key); - if (!Entry) continue; - - // Decrement refcount and store the entry if it's about to be deleted - if (--Entry->RefCount == 0) + auto RemoveFromIndices = [this](const TArray& Key, const TSharedPtr& Row) + { + checkf(Row.IsValid(), TEXT("Cannot remove invalid row from table indices.")); + for (auto& IndexPair : Table->UniqueIndices) { - DeletedEntries.Add(Key, TPair, TSharedPtr>(Key, Entry->Row)); + IndexPair.Value->RemoveRow(Row); } - } + for (auto& IndexPair : Table->BTreeIndices) + { + IndexPair.Value->RemoveRow(Key, Row); + } + }; - // Phase 2: Process Inserts and Updates - for (const auto& Ins : Inserts) + auto AddToIndices = [this](const TArray& Key, const TSharedPtr& Row) { - const TArray& Key = Ins.Key; - const RowType& Row = Ins.Value; - - TSharedPtr NewRow = MakeShared(Row); - - FRowEntry* Entry = Table->Entries.Find(Key); - if (!Entry) + checkf(Row.IsValid(), TEXT("Cannot add invalid row to table indices.")); + for (auto& IndexPair : Table->UniqueIndices) { - // True insert — these row-bytes are not cached yet. Either a - // genuinely new row, or the insert half of an update (the - // paired delete of different bytes is tracked in phase 1/3; - // DeriveUpdatesByPrimaryKey reconciles them by PK afterward). - Table->Entries.Add(Key, FRowEntry{NewRow, 1}); - Diff.Inserts.Add(Key, *NewRow); + IndexPair.Value->AddRow(Row); } - else + for (auto& IndexPair : Table->BTreeIndices) { - // Refcount bump — an overlapping subscription brought an - // identical row already in cache. Mirror the delete path - // (which only emits Diff.Deletes on refcount == 0) by not - // emitting a spurious Diff.Inserts entry here. - Table->Entries.Add(Key, FRowEntry{NewRow, Entry->RefCount + 1}); + IndexPair.Value->AddRow(Key, Row); } - } + }; - // Phase 3: Finalize Deletes and Update Indices - for (const auto& KeyValue : DeletedEntries) - { - // Add to diff before removal - Diff.Deletes.Add(KeyValue.Key, *KeyValue.Value.Value); - Table->Entries.Remove(KeyValue.Key); - } + FTableAppliedDiff Diff; + Diff.Inserts.Reserve(Inserts.Num()); + Diff.Deletes.Reserve(Deletes.Num()); + Diff.UpdateDeletes.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + Diff.UpdateInserts.Reserve(FMath::Min(Inserts.Num(), Deletes.Num())); + + TMap, FDeletedRow> DeletedRows; + DeletedRows.Reserve(Deletes.Num()); + TArray InsertedRows; + InsertedRows.Reserve(Inserts.Num()); - // Now, update all indices with the completed diff - for (const auto& DeletePair : Diff.Deletes) + for (const FWithBsatn& Delete : Deletes) { - for (auto& IndexPair : Table->UniqueIndices) + const TArray& Key = Delete.Bsatn; + FRowEntry* Entry = Table->Entries.Find(Key); + if (!Entry) { - // Assuming RemoveRow takes the TSharedPtr directly - IndexPair.Value->RemoveRow(MakeShared(DeletePair.Value)); + continue; } + + checkf(Entry->RefCount > 0, TEXT("Table cache row for %s has invalid refcount before delete."), *Name); + checkf(Entry->Row.IsValid(), TEXT("Table cache row for %s is invalid before delete."), *Name); + FDeletedRow& Deleted = DeletedRows.FindOrAdd(Key); + if (!Deleted.Row.IsValid()) + { + Deleted.Row = Entry->Row; + } + --Entry->RefCount; } - for (const auto& InsertPair : Diff.Inserts) + for (FWithBsatn& Insert : Inserts) { - for (auto& IndexPair : Table->UniqueIndices) + const TArray& Key = Insert.Bsatn; + FRowEntry* ExistingEntry = Table->Entries.Find(Key); + FDeletedRow* MatchingDelete = DeletedRows.Find(Key); + if (ExistingEntry) { - // Assuming AddRow takes TSharedPtr directly - IndexPair.Value->AddRow(MakeShared(InsertPair.Value)); + ++ExistingEntry->RefCount; + if (MatchingDelete) + { + MatchingDelete->bMatchedInsert = true; + } + continue; } + + TSharedPtr NewRow = MakeShared(MoveTemp(Insert.Row)); + Table->Entries.Add(Key, FRowEntry{NewRow, 1}); + InsertedRows.Add(FInsertedRow{Key, NewRow}); + Diff.Inserts.Add(NewRow); } - // And for BtreeIndices... - for (auto& Pair : Table->BTreeIndices) + for (const TPair, FDeletedRow>& DeletedPair : DeletedRows) { - for (const auto& DeletePair : Diff.Deletes) + const TArray& Key = DeletedPair.Key; + const FDeletedRow& Deleted = DeletedPair.Value; + if (Deleted.bMatchedInsert) { - Pair.Value->RemoveRow(DeletePair.Key, MakeShared(DeletePair.Value)); + continue; } - for (const auto& InsertPair : Diff.Inserts) + FRowEntry* Entry = Table->Entries.Find(Key); + if (!Entry || Entry->RefCount > 0) { - Pair.Value->AddRow(InsertPair.Key, MakeShared(InsertPair.Value)); + continue; } + + RemoveFromIndices(Key, Deleted.Row); + Diff.Deletes.Add(Deleted.Row); + Table->Entries.Remove(Key); + } + + for (const FInsertedRow& Inserted : InsertedRows) + { + AddToIndices(Inserted.Key, Inserted.Row); } return Diff; } +private: + bool ShouldApplyDirectNativeDiff(const FString& Name) const + { + return ApplyMode == UE::SpacetimeDB::ETableCacheApplyMode::DirectNativeDiff + || Name == UE::SpacetimeDB::Private::MobAttackVisualBatchTableName + || Name == UE::SpacetimeDB::Private::MobProjectileVisualBatchTableName + || Name == UE::SpacetimeDB::Private::AbilityCastVisualBatchTableName + || Name == UE::SpacetimeDB::Private::MobCombatStateFrameTableName + || Name == UE::SpacetimeDB::Private::PlayerMotionFrameTableName + || Name == UE::SpacetimeDB::Private::PlayerCombatStateFrameTableName; + } + + bool ShouldSkipRuntimeApplyBTreeIndex(const FString& Name, const FString& IndexName) const + { + return IndexName == UE::SpacetimeDB::Private::MatchIdIndexName + && (Name == UE::SpacetimeDB::Private::MobRuntimeSnapshotBatchTableName + || Name == UE::SpacetimeDB::Private::MobAttackVisualBatchTableName + || Name == UE::SpacetimeDB::Private::MobProjectileVisualBatchTableName + || Name == UE::SpacetimeDB::Private::AbilityCastVisualBatchTableName); + } + + UE::SpacetimeDB::ETableCacheApplyMode ApplyMode = UE::SpacetimeDB::ETableCacheApplyMode::PersistentIndexed; }; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h index cc76786999e..3e38e90d564 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h @@ -12,15 +12,13 @@ template struct FTableAppliedDiff { - // SerializedKey -> Row copy. Keeping the rows by value ensures - // the memory stays valid even if the underlying table reallocates - // or removes entries while this diff is alive. - TMap, RowType> Deletes; - TMap, RowType> Inserts; + TArray> Deletes; + TArray> Inserts; - // Parallel arrays for (old, new) row update pairs. - TArray UpdateDeletes; - TArray UpdateInserts; + TArray> UpdateDeletes; + TArray> UpdateInserts; + + bool bPrimaryKeyUpdatesClassified = false; bool IsEmpty() const { @@ -36,38 +34,79 @@ struct FTableAppliedDiff template void DeriveUpdatesByPrimaryKey(TFunctionRef DerivePK) { - if (Deletes.IsEmpty()) return; + if (bPrimaryKeyUpdatesClassified) + { + checkf(UpdateDeletes.Num() == UpdateInserts.Num(), TEXT("Pre-classified primary-key update diff arrays are mismatched.")); + return; + } + if (Deletes.IsEmpty() || Inserts.IsEmpty()) return; - // Build PK->(key,row) map for deletes. - TMap, RowType>> DeletePK; - for (const auto& Pair : Deletes) + const int32 DeleteCount = Deletes.Num(); + const int32 InsertCount = Inserts.Num(); + TMap>> DeletePK; + DeletePK.Reserve(Deletes.Num()); + for (int32 DeleteIndex = DeleteCount - 1; DeleteIndex >= 0; --DeleteIndex) { - DeletePK.Add(DerivePK(Pair.Value), { Pair.Key, Pair.Value }); + const TSharedPtr& DeletedRow = Deletes[DeleteIndex]; + checkf(DeletedRow.IsValid(), TEXT("Invalid deleted row while deriving SpacetimeDB table updates.")); + const KeyType PK = DerivePK(*DeletedRow); + DeletePK.FindOrAdd(PK).Add(DeleteIndex); } - // Scan inserts for matching PKs. - TArray> DeleteKeys; - TArray> InsertKeys; - for (const auto& Pair : Inserts) + const int32 MaxUpdatePairs = FMath::Min(DeleteCount, InsertCount); + TArray MatchedDeletes; + TArray MatchedInserts; + MatchedDeletes.Init(0, DeleteCount); + MatchedInserts.Init(0, InsertCount); + UpdateDeletes.Reserve(UpdateDeletes.Num() + MaxUpdatePairs); + UpdateInserts.Reserve(UpdateInserts.Num() + MaxUpdatePairs); + int32 MatchedPairCount = 0; + for (int32 InsertIndex = 0; InsertIndex < InsertCount; ++InsertIndex) { - KeyType PK = DerivePK(Pair.Value); - if (const auto* Found = DeletePK.Find(PK)) + const TSharedPtr& InsertedRow = Inserts[InsertIndex]; + checkf(InsertedRow.IsValid(), TEXT("Invalid inserted row while deriving SpacetimeDB table updates.")); + KeyType PK = DerivePK(*InsertedRow); + if (TArray>* DeleteIndices = DeletePK.Find(PK)) { - UpdateDeletes.Add(Found->Value); - UpdateInserts.Add(Pair.Value); - DeleteKeys.Add(Found->Key); - InsertKeys.Add(Pair.Key); + checkf(!DeleteIndices->IsEmpty(), TEXT("Empty deleted row index list while deriving SpacetimeDB table updates.")); + const int32 DeleteIndex = DeleteIndices->Pop(EAllowShrinking::No); + checkf(Deletes.IsValidIndex(DeleteIndex), TEXT("Invalid deleted row index while deriving SpacetimeDB table updates.")); + UpdateDeletes.Add(Deletes[DeleteIndex]); + UpdateInserts.Add(InsertedRow); + MatchedDeletes[DeleteIndex] = 1; + MatchedInserts[InsertIndex] = 1; + if (DeleteIndices->IsEmpty()) + { + DeletePK.Remove(PK); + } + ++MatchedPairCount; } } - // Remove update pairs from base maps. - for (const auto& K : DeleteKeys) + if (MatchedPairCount == 0) + { + return; + } + + TArray> RemainingDeletes; + TArray> RemainingInserts; + RemainingDeletes.Reserve(DeleteCount - MatchedPairCount); + RemainingInserts.Reserve(InsertCount - MatchedPairCount); + for (int32 DeleteIndex = 0; DeleteIndex < DeleteCount; ++DeleteIndex) { - Deletes.Remove(K); + if (MatchedDeletes[DeleteIndex] == 0) + { + RemainingDeletes.Add(MoveTemp(Deletes[DeleteIndex])); + } } - for (const auto& K : InsertKeys) + for (int32 InsertIndex = 0; InsertIndex < InsertCount; ++InsertIndex) { - Inserts.Remove(K); + if (MatchedInserts[InsertIndex] == 0) + { + RemainingInserts.Add(MoveTemp(Inserts[InsertIndex])); + } } + Deletes = MoveTemp(RemainingDeletes); + Inserts = MoveTemp(RemainingInserts); } }; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h index e4c2a405459..353d88e59ec 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/WithBsatn.h @@ -10,8 +10,11 @@ struct FWithBsatn /** Deserialized row value */ RowType Row; - FWithBsatn() = default; - FWithBsatn(const TArray& InBsatn, const RowType& InRow) - : Bsatn(InBsatn), Row(InRow) { - } -}; \ No newline at end of file + FWithBsatn() = default; + FWithBsatn(const TArray& InBsatn, const RowType& InRow) + : Bsatn(InBsatn), Row(InRow) { + } + FWithBsatn(TArray&& InBsatn, RowType&& InRow) + : Bsatn(MoveTemp(InBsatn)), Row(MoveTemp(InRow)) { + } +}; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h index cb331255ec7..daa21cde936 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Tables/RemoteTable.h @@ -29,8 +29,8 @@ class SPACETIMEDBSDK_API URemoteTable : public UObject */ template FTableAppliedDiff BaseUpdate( - const TArray>& InsertsRef, - const TArray>& DeletesRef, + TArray>& InsertsRef, + TArray>& DeletesRef, const TSharedPtr>& ClientCache, const FString& InTableName ) @@ -42,19 +42,19 @@ class SPACETIMEDBSDK_API URemoteTable : public UObject return {}; } - TArray, T>> Inserts; - for (const FWithBsatn& Insert : InsertsRef) + // Forward ownership of the worker-preprocessed row arrays to avoid rebuilding them on the game thread. + using FCompactPrimaryKeyTraits = UE::SpacetimeDB::TCompactPrimaryKeyTraits; + if constexpr (FCompactPrimaryKeyTraits::bEnabled) { - Inserts.Add({ Insert.Bsatn, Insert.Row }); + return ClientCache->ApplyDiffByPrimaryKey( + InTableName, + MoveTemp(InsertsRef), + MoveTemp(DeletesRef), + FCompactPrimaryKeyTraits::GetUniqueIndexName()); } - - TArray> Deletes; - for (const FWithBsatn& Delete : DeletesRef) + else { - Deletes.Add(Delete.Bsatn); + return ClientCache->ApplyDiff(InTableName, MoveTemp(InsertsRef), MoveTemp(DeletesRef)); } - - // Forward to the shared client cache implementation - return ClientCache->ApplyDiff(InTableName, Inserts, Deletes); } -}; \ No newline at end of file +}; From c0ac97f856b1e0eab863aa94787bc996b728fe5f Mon Sep 17 00:00:00 2001 From: Brougkr Date: Tue, 5 May 2026 04:22:38 -0400 Subject: [PATCH 2/2] Harden Unreal SDK inbound cache and listener policy Remove FACTIONS-specific table policy from the generic Unreal SDK cache path. Direct native diffs now require explicit ETableCacheApplyMode::DirectNativeDiff, runtime B-tree index application is controlled through explicit cache policy hooks, and compact cache keys are driven by generated TCompactPrimaryKeyTraits specializations for schema-declared uint64 primary keys instead of FrameKey/BatchKey field-name inference. Make native listener dispatch lifetime-safe and explicit. Native listeners now store weak UObject owners plus an unregister identity key, expired owners are logged and removed before dispatch continues, and the default dispatch path invokes native listeners before dynamic delegates unless a registration explicitly opts into NativeOnly suppression for hot tables. Add decoded and parsed-memory backpressure. Gzip payloads whose declared decoded size exceeds the inbound cap are rejected before allocation, parsed queue accounting now uses estimated decoded/preprocessed memory instead of only raw payload bytes, parsed queue overload reports a protocol error, and raw inbound queue draining uses a read index with periodic compaction instead of front-removing every drain. Update public diff documentation to describe the lower-copy TSharedPtr-backed C++ diff representation and clarify that generated dynamic delegates remain value-reference based. This documents the direct FTableAppliedDiff source-shape change rather than claiming a purely additive or zero-copy API. --- crates/codegen/src/unrealcpp.rs | 96 +++++++++- .../Private/Connection/DbConnectionBase.cpp | 175 ++++++++++++++---- .../Public/BSATN/UEBSATNHelpers.h | 21 +++ .../Public/Connection/DbConnectionBase.h | 143 ++++++++++---- .../Public/DBCache/ClientCache.h | 153 ++++++++------- .../SpacetimeDbSdk/Public/DBCache/README.md | 4 +- .../Public/DBCache/TableAppliedDiff.h | 12 +- 7 files changed, 438 insertions(+), 166 deletions(-) diff --git a/crates/codegen/src/unrealcpp.rs b/crates/codegen/src/unrealcpp.rs index bab5d0bfed5..2e980843cb1 100644 --- a/crates/codegen/src/unrealcpp.rs +++ b/crates/codegen/src/unrealcpp.rs @@ -47,6 +47,7 @@ impl Lang for UnrealCpp<'_> { &[ "Types/Builtins.h", &format!("ModuleBindings/Types/{struct_name}Type.g.h"), + "DBCache/ClientCache.h", "Tables/RemoteTable.h", "DBCache/WithBsatn.h", "DBCache/TableHandle.h", @@ -63,6 +64,35 @@ impl Lang for UnrealCpp<'_> { // Generate unique index classes first let product_type = module.typespace_for_generate()[table.product_type_ref].as_product(); + if let (false, Some(pk)) = (table.is_event, schema.pk()) { + let (pk_name, pk_ty) = &product_type.unwrap().elements[pk.col_pos.idx()]; + let pk_type_str = cpp_ty_fmt_with_module(self.module_prefix, module, pk_ty, self.module_name).to_string(); + if pk_type_str == "uint64" { + let pk_field_name = pk_name.deref().to_case(Case::Pascal); + let pk_index_name = pk.col_name.deref(); + writeln!(output, "namespace UE::SpacetimeDB"); + writeln!(output, "{{"); + writeln!(output, "template<>"); + writeln!(output, "struct TCompactPrimaryKeyTraits<::{row_struct}>"); + writeln!(output, "{{"); + writeln!(output, " static constexpr bool bEnabled = true;"); + writeln!(output, " using KeyType = uint64;"); + writeln!(output); + writeln!(output, " static KeyType GetKey(const ::{row_struct}& Row)"); + writeln!(output, " {{"); + writeln!(output, " return Row.{pk_field_name};"); + writeln!(output, " }}"); + writeln!(output); + writeln!(output, " static const TCHAR* GetUniqueIndexName()"); + writeln!(output, " {{"); + writeln!(output, " return TEXT(\"{pk_index_name}\");"); + writeln!(output, " }}"); + writeln!(output, "}};"); + writeln!(output, "}}"); + writeln!(output); + } + } + let mut unique_indexes = Vec::new(); let mut multi_key_indexes = Vec::new(); @@ -308,6 +338,17 @@ impl Lang for UnrealCpp<'_> { writeln!(output, " void PostInitialize();"); writeln!(output); + writeln!( + output, + " void SetCacheApplyMode(UE::SpacetimeDB::ETableCacheApplyMode InApplyMode);" + ); + writeln!( + output, + " UE::SpacetimeDB::ETableCacheApplyMode GetCacheApplyMode() const;" + ); + writeln!(output, " void SetRuntimeBTreeIndexApplyMode(const FString& IndexName, UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode InApplyMode);"); + writeln!(output, " void ClearRuntimeBTreeIndexApplyModes();"); + writeln!(output); writeln!(output, " /** Update function for {table_name} table*/"); writeln!( @@ -1205,7 +1246,7 @@ fn generate_table_cpp( writeln!( output, " {table_pascal}Table->AddUniqueConstraint<{field_type}>(\"{}\", [](const {row_struct}& Row) -> const {field_type}& {{", - field_name.to_lowercase() + field_name ); writeln!(output, " return Row.{}; }});", field_name.to_case(Case::Pascal)); } @@ -1255,6 +1296,52 @@ fn generate_table_cpp( writeln!(output, "}}"); writeln!(output); + writeln!( + output, + "void U{table_pascal}Table::SetCacheApplyMode(UE::SpacetimeDB::ETableCacheApplyMode InApplyMode)" + ); + writeln!(output, "{{"); + writeln!( + output, + " checkf(Data.IsValid(), TEXT(\"{table_pascal} table cache policy set before PostInitialize.\"));" + ); + writeln!(output, " Data->SetApplyMode(InApplyMode);"); + writeln!(output, "}}"); + writeln!(output); + + writeln!( + output, + "UE::SpacetimeDB::ETableCacheApplyMode U{table_pascal}Table::GetCacheApplyMode() const" + ); + writeln!(output, "{{"); + writeln!( + output, + " checkf(Data.IsValid(), TEXT(\"{table_pascal} table cache policy read before PostInitialize.\"));" + ); + writeln!(output, " return Data->GetApplyMode();"); + writeln!(output, "}}"); + writeln!(output); + + writeln!(output, "void U{table_pascal}Table::SetRuntimeBTreeIndexApplyMode(const FString& IndexName, UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode InApplyMode)"); + writeln!(output, "{{"); + writeln!( + output, + " checkf(Data.IsValid(), TEXT(\"{table_pascal} runtime B-Tree index policy set before PostInitialize.\"));" + ); + writeln!( + output, + " Data->SetRuntimeBTreeIndexApplyMode(IndexName, InApplyMode);" + ); + writeln!(output, "}}"); + writeln!(output); + + writeln!(output, "void U{table_pascal}Table::ClearRuntimeBTreeIndexApplyModes()"); + writeln!(output, "{{"); + writeln!(output, " checkf(Data.IsValid(), TEXT(\"{table_pascal} runtime B-Tree index policy reset before PostInitialize.\"));"); + writeln!(output, " Data->ClearRuntimeBTreeIndexApplyModes();"); + writeln!(output, "}}"); + writeln!(output); + // Generate Update implementation writeln!( output, @@ -1267,9 +1354,12 @@ fn generate_table_cpp( " // Event tables are callback-only: do not persist rows in the local cache." ); writeln!(output, " FTableAppliedDiff<{row_struct}> Diff;"); - writeln!(output, " for (const FWithBsatn<{row_struct}>& Insert : InsertsRef)"); + writeln!(output, " for (FWithBsatn<{row_struct}>& Insert : InsertsRef)"); writeln!(output, " {{"); - writeln!(output, " Diff.Inserts.Add(Insert.Bsatn, Insert.Row);"); + writeln!( + output, + " Diff.Inserts.Add(MakeShared<{row_struct}>(MoveTemp(Insert.Row)));" + ); writeln!(output, " }}"); } else { writeln!( diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp index 1607ca986c2..04a8dc9fd7f 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp @@ -26,8 +26,10 @@ enum class EWsCompressionTag : uint8 constexpr int32 MaxQueuedInboundRawMessages = 8192; constexpr int64 MaxQueuedInboundRawBytes = 128ll * 1024ll * 1024ll; constexpr int32 MaxPendingInboundParsedMessages = 8192; -constexpr int64 MaxPendingInboundParsedPayloadBytes = 128ll * 1024ll * 1024ll; +constexpr int64 MaxInboundDecodedPayloadBytes = 128ll * 1024ll * 1024ll; +constexpr int64 MaxPendingInboundParsedEstimatedBytes = 256ll * 1024ll * 1024ll; constexpr int32 PendingInboundCompactionMinConsumedMessages = 512; +constexpr int32 InboundRawCompactionMinConsumedMessages = 512; constexpr uint32 InboundWorkerStackSizeBytes = 0; constexpr EThreadPriority InboundWorkerThreadPriority = TPri_Normal; constexpr const TCHAR* InboundWorkerThreadName = TEXT("SpacetimeDBInboundWorker"); @@ -122,6 +124,32 @@ static FString FormatInboundTableApplyStats(const FSpacetimeDBTableApplyStats& S Stats.BroadcastMicros, Stats.bProducedDiff ? 1 : 0); } + +static int64 EstimatePreprocessedTableDataBytes(const FPreprocessedTableDataMap& PreprocessedTableData) +{ + int64 EstimatedBytes = PreprocessedTableData.GetAllocatedSize(); + for (const TPair>>& TablePair : PreprocessedTableData) + { + EstimatedBytes += TablePair.Key.TableName.GetAllocatedSize(); + EstimatedBytes += TablePair.Value.GetAllocatedSize(); + for (const TSharedPtr& Data : TablePair.Value) + { + if (Data.IsValid()) + { + EstimatedBytes += Data->EstimateMemoryBytes(); + } + } + } + return EstimatedBytes; +} + +static int64 EstimateInboundParsedMessageBytes(const FInboundParsedMessage& Message) +{ + return sizeof(FInboundParsedMessage) + + static_cast(Message.DecodedPayloadSizeBytes) + + Message.ProtocolError.GetAllocatedSize() + + EstimatePreprocessedTableDataBytes(Message.PreprocessedTableData); +} } class FSpacetimeDbInboundWorker final : public FRunnable @@ -326,6 +354,7 @@ void UDbConnectionBase::StartInboundMessageWorker() { FScopeLock RawLock(&InboundRawMessagesMutex); InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; InboundQueuedRawBytes = 0; ++InboundConnectionEpoch; NextInboundSequenceId = 0; @@ -337,7 +366,7 @@ void UDbConnectionBase::StartInboundMessageWorker() FScopeLock PendingLock(&PendingMessagesMutex); PendingMessages.Reset(); PendingMessageReadIndex = 0; - PendingParsedPayloadBytes = 0; + PendingParsedEstimatedBytes = 0; } ActivePreprocessedTableData = nullptr; @@ -352,6 +381,7 @@ void UDbConnectionBase::StopInboundMessageWorker() { FScopeLock RawLock(&InboundRawMessagesMutex); InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; InboundQueuedRawBytes = 0; ++InboundConnectionEpoch; bInboundAcceptingMessages = false; @@ -362,7 +392,7 @@ void UDbConnectionBase::StopInboundMessageWorker() FScopeLock PendingLock(&PendingMessagesMutex); PendingMessages.Reset(); PendingMessageReadIndex = 0; - PendingParsedPayloadBytes = 0; + PendingParsedEstimatedBytes = 0; } ActivePreprocessedTableData = nullptr; @@ -383,6 +413,7 @@ void UDbConnectionBase::ClearInboundMessageQueues() { FScopeLock Lock(&InboundRawMessagesMutex); InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; InboundQueuedRawBytes = 0; } @@ -390,7 +421,7 @@ void UDbConnectionBase::ClearInboundMessageQueues() FScopeLock Lock(&PendingMessagesMutex); PendingMessages.Reset(); PendingMessageReadIndex = 0; - PendingParsedPayloadBytes = 0; + PendingParsedEstimatedBytes = 0; } ActivePreprocessedTableData = nullptr; @@ -407,7 +438,9 @@ void UDbConnectionBase::NotifyInboundWorkerIfNeeded() bool bShouldNotify = false; { FScopeLock RawLock(&InboundRawMessagesMutex); - bShouldNotify = InboundRawMessages.Num() > 0 && bInboundAcceptingMessages && !bInboundProtocolErrorQueued; + checkf(InboundRawMessageReadIndex <= InboundRawMessages.Num(), + TEXT("SpacetimeDB inbound raw queue read index exceeded queued messages while notifying worker.")); + bShouldNotify = InboundRawMessages.Num() > InboundRawMessageReadIndex && bInboundAcceptingMessages && !bInboundProtocolErrorQueued; } if (bShouldNotify) @@ -434,6 +467,7 @@ void UDbConnectionBase::MarkInboundProtocolErrorQueued() bInboundProtocolErrorQueued = true; bInboundAcceptingMessages = false; InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; InboundQueuedRawBytes = 0; } @@ -458,13 +492,14 @@ void UDbConnectionBase::EnqueueInboundProtocolError(uint64 SequenceId, int32 Pay Parsed.QueuedBytesAtEnqueue = QueuedBytesAtEnqueue; Parsed.bProtocolError = true; Parsed.ProtocolError = ErrorMessage; + Parsed.EstimatedMemoryBytes = EstimateInboundParsedMessageBytes(Parsed); FScopeLock Lock(&PendingMessagesMutex); PendingMessages.Reset(); PendingMessageReadIndex = 0; - PendingParsedPayloadBytes = 0; + PendingParsedEstimatedBytes = 0; PendingMessages.Add(MoveTemp(Parsed)); - PendingParsedPayloadBytes += static_cast(PayloadSizeBytes); + PendingParsedEstimatedBytes += PendingMessages[0].EstimatedMemoryBytes; } void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) @@ -497,8 +532,11 @@ void UDbConnectionBase::HandleWSBinaryMessageOwned(TArray&& Message) ConnectionEpoch = InboundConnectionEpoch; SequenceId = NextInboundSequenceId++; + checkf(InboundRawMessageReadIndex <= InboundRawMessages.Num(), + TEXT("SpacetimeDB inbound raw queue read index exceeded queued messages while enqueuing payload.")); const int64 NewQueuedRawBytes = InboundQueuedRawBytes + static_cast(PayloadSizeBytes); - const int32 NewQueuedRawMessageCount = InboundRawMessages.Num() + 1; + const int32 LiveQueuedRawMessageCount = InboundRawMessages.Num() - InboundRawMessageReadIndex; + const int32 NewQueuedRawMessageCount = LiveQueuedRawMessageCount + 1; QueueDepthAtEnqueue = NewQueuedRawMessageCount; QueuedBytesAtEnqueue = NewQueuedRawBytes; if (NewQueuedRawMessageCount > MaxQueuedInboundRawMessages || NewQueuedRawBytes > MaxQueuedInboundRawBytes) @@ -506,6 +544,7 @@ void UDbConnectionBase::HandleWSBinaryMessageOwned(TArray&& Message) bInboundProtocolErrorQueued = true; bInboundAcceptingMessages = false; InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; InboundQueuedRawBytes = 0; bQueueOverloaded = true; QueueOverloadError = FString::Printf( @@ -567,7 +606,7 @@ void UDbConnectionBase::FrameTick() { PendingMessages.Reset(); PendingMessageReadIndex = 0; - PendingParsedPayloadBytes = 0; + PendingParsedEstimatedBytes = 0; break; } @@ -578,13 +617,14 @@ void UDbConnectionBase::FrameTick() FInboundParsedMessage& PendingMessage = PendingMessages[PendingMessageReadIndex]; const int64 PendingPayloadBytes = static_cast(PendingMessage.PayloadSizeBytes); + const int64 PendingEstimatedBytes = PendingMessage.EstimatedMemoryBytes; if (!bDrainAllPendingMessages && MessagesProcessed > 0 && PayloadBytesProcessed + PendingPayloadBytes > InboundApplyBudget.MaxPayloadBytesPerFrame) { break; } PayloadBytesProcessed += PendingPayloadBytes; - PendingParsedPayloadBytes = FMath::Max(0, PendingParsedPayloadBytes - PendingPayloadBytes); + PendingParsedEstimatedBytes = FMath::Max(0, PendingParsedEstimatedBytes - PendingEstimatedBytes); Msg = MoveTemp(PendingMessage); ++PendingMessageReadIndex; @@ -592,7 +632,7 @@ void UDbConnectionBase::FrameTick() { PendingMessages.Reset(); PendingMessageReadIndex = 0; - PendingParsedPayloadBytes = 0; + PendingParsedEstimatedBytes = 0; } else if (PendingMessageReadIndex >= PendingInboundCompactionMinConsumedMessages) { @@ -689,15 +729,15 @@ void UDbConnectionBase::DrainInboundRawMessagesOnWorker() while (!IsInboundProtocolErrorQueued()) { int32 ParsedMessageCapacity = 0; - int64 ParsedPayloadByteCapacity = 0; + int64 ParsedEstimatedByteCapacity = 0; { FScopeLock Lock(&PendingMessagesMutex); const int32 LivePendingMessages = PendingMessages.Num() - PendingMessageReadIndex; ParsedMessageCapacity = MaxPendingInboundParsedMessages - LivePendingMessages; - ParsedPayloadByteCapacity = MaxPendingInboundParsedPayloadBytes - PendingParsedPayloadBytes; + ParsedEstimatedByteCapacity = MaxPendingInboundParsedEstimatedBytes - PendingParsedEstimatedBytes; } - if (ParsedMessageCapacity <= 0 || ParsedPayloadByteCapacity <= 0) + if (ParsedMessageCapacity <= 0 || ParsedEstimatedByteCapacity <= 0) { return; } @@ -706,20 +746,24 @@ void UDbConnectionBase::DrainInboundRawMessagesOnWorker() int64 DrainedRawBytes = 0; { FScopeLock Lock(&InboundRawMessagesMutex); - if (InboundRawMessages.Num() == 0 || !bInboundAcceptingMessages || bInboundProtocolErrorQueued) + checkf(InboundRawMessageReadIndex <= InboundRawMessages.Num(), + TEXT("SpacetimeDB inbound raw queue read index exceeded queued messages while draining worker queue.")); + const int32 LiveRawMessageCount = InboundRawMessages.Num() - InboundRawMessageReadIndex; + if (LiveRawMessageCount <= 0 || !bInboundAcceptingMessages || bInboundProtocolErrorQueued) { return; } int32 DrainCount = 0; - for (; DrainCount < InboundRawMessages.Num() && DrainCount < ParsedMessageCapacity; ++DrainCount) + for (; DrainCount < LiveRawMessageCount && DrainCount < ParsedMessageCapacity; ++DrainCount) { - const int64 NextPayloadBytes = static_cast(InboundRawMessages[DrainCount].Payload.Num()); - if (DrainCount > 0 && DrainedRawBytes + NextPayloadBytes > ParsedPayloadByteCapacity) + const FInboundRawMessage& Candidate = InboundRawMessages[InboundRawMessageReadIndex + DrainCount]; + const int64 NextPayloadBytes = static_cast(Candidate.Payload.Num()); + if (DrainCount > 0 && DrainedRawBytes + NextPayloadBytes > ParsedEstimatedByteCapacity) { break; } - if (DrainCount == 0 && NextPayloadBytes > ParsedPayloadByteCapacity) + if (DrainCount == 0 && NextPayloadBytes > ParsedEstimatedByteCapacity) { return; } @@ -735,10 +779,20 @@ void UDbConnectionBase::DrainInboundRawMessagesOnWorker() LocalRawMessages.Reserve(DrainCount); for (int32 Index = 0; Index < DrainCount; ++Index) { - LocalRawMessages.Add(MoveTemp(InboundRawMessages[Index])); + LocalRawMessages.Add(MoveTemp(InboundRawMessages[InboundRawMessageReadIndex + Index])); } - InboundRawMessages.RemoveAt(0, DrainCount, EAllowShrinking::No); + InboundRawMessageReadIndex += DrainCount; + if (InboundRawMessageReadIndex == InboundRawMessages.Num()) + { + InboundRawMessages.Reset(); + InboundRawMessageReadIndex = 0; + } + else if (InboundRawMessageReadIndex >= InboundRawCompactionMinConsumedMessages) + { + InboundRawMessages.RemoveAt(0, InboundRawMessageReadIndex, EAllowShrinking::No); + InboundRawMessageReadIndex = 0; + } InboundQueuedRawBytes = FMath::Max(0, InboundQueuedRawBytes - DrainedRawBytes); } @@ -789,31 +843,52 @@ void UDbConnectionBase::DrainInboundRawMessagesOnWorker() uint64 OverloadSequenceId = LocalParsedMessages[0].SequenceId; int32 OverloadPayloadSizeBytes = LocalParsedMessages[0].PayloadSizeBytes; uint8 OverloadCompressionTag = LocalParsedMessages[0].CompressionTag; + bool bParsedQueueOverloaded = false; + FString ParsedQueueOverloadError; { FScopeLock Lock(&PendingMessagesMutex); - int64 AddedPayloadBytes = 0; + int64 AddedEstimatedBytes = 0; for (const FInboundParsedMessage& ParsedMessage : LocalParsedMessages) { - AddedPayloadBytes += static_cast(ParsedMessage.PayloadSizeBytes); + AddedEstimatedBytes += ParsedMessage.EstimatedMemoryBytes; } const int32 LivePendingMessages = PendingMessages.Num() - PendingMessageReadIndex; const int32 NewPendingMessageCount = LivePendingMessages + LocalParsedMessages.Num(); - const int64 NewPendingPayloadBytes = PendingParsedPayloadBytes + AddedPayloadBytes; - checkf(bBatchEndsWithProtocolError || - (NewPendingMessageCount <= MaxPendingInboundParsedMessages && - NewPendingPayloadBytes <= MaxPendingInboundParsedPayloadBytes), - TEXT("SpacetimeDB parsed inbound queue overflow despite worker backpressure: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d queued_bytes=%lld max_messages=%d max_bytes=%lld"), + const int64 NewPendingEstimatedBytes = PendingParsedEstimatedBytes + AddedEstimatedBytes; + bParsedQueueOverloaded = !bBatchEndsWithProtocolError && + (NewPendingMessageCount > MaxPendingInboundParsedMessages || + NewPendingEstimatedBytes > MaxPendingInboundParsedEstimatedBytes); + if (bParsedQueueOverloaded) + { + ParsedQueueOverloadError = FString::Printf( + TEXT("SpacetimeDB parsed inbound queue overload: sequence=%llu payload_bytes=%d compression_tag=%u queued_messages=%d estimated_bytes=%lld max_messages=%d max_estimated_bytes=%lld"), + OverloadSequenceId, + OverloadPayloadSizeBytes, + static_cast(OverloadCompressionTag), + NewPendingMessageCount, + NewPendingEstimatedBytes, + MaxPendingInboundParsedMessages, + MaxPendingInboundParsedEstimatedBytes); + } + else + { + PendingMessages.Append(MoveTemp(LocalParsedMessages)); + PendingParsedEstimatedBytes = NewPendingEstimatedBytes; + } + } + + if (bParsedQueueOverloaded) + { + MarkInboundProtocolErrorQueued(); + EnqueueInboundProtocolError( OverloadSequenceId, OverloadPayloadSizeBytes, - static_cast(OverloadCompressionTag), - NewPendingMessageCount, - NewPendingPayloadBytes, - MaxPendingInboundParsedMessages, - MaxPendingInboundParsedPayloadBytes); - - PendingMessages.Append(MoveTemp(LocalParsedMessages)); - PendingParsedPayloadBytes = NewPendingPayloadBytes; + OverloadCompressionTag, + LocalParsedMessages[0].QueueDepthAtEnqueue, + LocalParsedMessages[0].QueuedBytesAtEnqueue, + ParsedQueueOverloadError); + return; } } } @@ -839,6 +914,7 @@ bool UDbConnectionBase::BuildInboundParsedMessage(const FInboundRawMessage& RawM static_cast(OutMessage.CompressionTag), OutMessage.QueueDepthAtEnqueue, OutMessage.QueuedBytesAtEnqueue); + OutMessage.EstimatedMemoryBytes = EstimateInboundParsedMessageBytes(OutMessage); return false; } @@ -1127,7 +1203,20 @@ bool UDbConnectionBase::DecompressGzip(const uint8* InData, int32 InSize, TArray // Gzip data ends with 4 bytes indicating the uncompressed size const uint8* SizePtr = InData + InSize - GzipFooterUncompressedSizeBytes; - uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24); + const uint32 OutSize = + static_cast(SizePtr[0]) | + (static_cast(SizePtr[1]) << 8) | + (static_cast(SizePtr[2]) << 16) | + (static_cast(SizePtr[3]) << 24); + if (static_cast(OutSize) > MaxInboundDecodedPayloadBytes) + { + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("Gzip payload declares %u decoded bytes, exceeding max decoded bytes %lld"), + OutSize, + MaxInboundDecodedPayloadBytes); + return false; + } // Validate the output size OutData.SetNumUninitialized(OutSize); @@ -1264,6 +1353,16 @@ bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FInbound UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant")); return false; } + if (static_cast(DecodedPayloadSize) > MaxInboundDecodedPayloadBytes) + { + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("Decoded server message payload has %d bytes, exceeding max decoded bytes %lld after compression tag %u."), + DecodedPayloadSize, + MaxInboundDecodedPayloadBytes, + static_cast(Compression)); + return false; + } if (DecodedPayloadSize <= 0 || DecodedPayload == nullptr) { UE_LOG(LogSpacetimeDb_Connection, @@ -1273,6 +1372,7 @@ bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FInbound return false; } + OutMessage.DecodedPayloadSizeBytes = DecodedPayloadSize; // Deserialize the decompressed data into a UServerMessageType object OutMessage.Message = UE::SpacetimeDB::DeserializeView(DecodedPayload, DecodedPayloadSize); @@ -1313,6 +1413,7 @@ bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FInbound default: break; } + OutMessage.EstimatedMemoryBytes = EstimateInboundParsedMessageBytes(OutMessage); return true; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h index 7e93e81d248..8999bf9fd9c 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/BSATN/UEBSATNHelpers.h @@ -125,6 +125,10 @@ namespace UE::SpacetimeDB struct FPreprocessedTableDataBase { virtual ~FPreprocessedTableDataBase() {} + virtual int64 EstimateMemoryBytes() const + { + return sizeof(FPreprocessedTableDataBase); + } int32 InsertRowCount = 0; int32 DeleteRowCount = 0; int32 RowSetCount = 0; @@ -139,6 +143,23 @@ namespace UE::SpacetimeDB // The type of the row being processed TArray> Inserts; TArray> Deletes; + + virtual int64 EstimateMemoryBytes() const override + { + auto EstimateRowsBytes = [](const TArray>& Rows) + { + int64 Bytes = Rows.GetAllocatedSize(); + for (const FWithBsatn& Row : Rows) + { + Bytes += Row.Bsatn.GetAllocatedSize(); + } + return Bytes; + }; + + return sizeof(TPreprocessedTableData) + + EstimateRowsBytes(Inserts) + + EstimateRowsBytes(Deletes); + } }; /** Interface for deserializing table rows from a database update. Allows for different row types to be processed in SDK. */ diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h index 7a8da5d1a1a..b2b130743bb 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h @@ -110,6 +110,8 @@ struct FInboundParsedMessage uint64 ConnectionEpoch = 0; uint64 SequenceId = 0; int32 PayloadSizeBytes = 0; + int32 DecodedPayloadSizeBytes = 0; + int64 EstimatedMemoryBytes = 0; uint8 CompressionTag = 0; int32 QueueDepthAtEnqueue = 0; int64 QueuedBytesAtEnqueue = 0; @@ -207,6 +209,12 @@ const void* GetNativeTableListenerTypeId() return &TypeId; } +enum class ESpacetimeDBNativeListenerDispatchMode : uint8 +{ + NativeAndDynamic, + NativeOnly +}; + struct FNativeTableListenerBinding { using FInsertThunk = void(*)(void* Owner, const void* Context, const void* Row); @@ -214,17 +222,19 @@ struct FNativeTableListenerBinding using FDeleteThunk = void(*)(void* Owner, const void* Context, const void* Row); using FDiffThunk = void(*)(void* Owner, const void* Context, const void* Diff); - void* Owner = nullptr; + TWeakObjectPtr Owner; + void* OwnerKey = nullptr; const void* RowTypeId = nullptr; const void* EventContextTypeId = nullptr; FInsertThunk InsertThunk = nullptr; FUpdateThunk UpdateThunk = nullptr; FDeleteThunk DeleteThunk = nullptr; FDiffThunk DiffThunk = nullptr; + ESpacetimeDBNativeListenerDispatchMode DispatchMode = ESpacetimeDBNativeListenerDispatchMode::NativeAndDynamic; bool IsComplete() const { - return Owner != nullptr && + return OwnerKey != nullptr && RowTypeId != nullptr && EventContextTypeId != nullptr && (DiffThunk != nullptr || @@ -319,7 +329,7 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; virtual const FString& GetTableName() const = 0; virtual void RegisterNativeListener(const FNativeTableListenerBinding& Binding) = 0; - virtual void UnregisterNativeListener(void* Owner) = 0; + virtual void UnregisterNativeListener(void* OwnerKey) = 0; }; template @@ -375,11 +385,8 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam checkf(PendingDiffReadIndex < PendingDiffs.Num(), TEXT("Missing pending SpacetimeDB table diff for broadcast.")); EventContext& Ctx = *reinterpret_cast(Context); const FTableAppliedDiff& Diff = PendingDiffs[PendingDiffReadIndex]; - if (!NativeListeners.IsEmpty()) - { - BroadcastNativeDiff(Diff, Ctx); - } - else + const bool bSuppressDynamicDispatch = BroadcastNativeDiff(Diff, Ctx); + if (!bSuppressDynamicDispatch) { Conn->BroadcastDiff(Table, Diff, Ctx); } @@ -402,29 +409,32 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam TEXT("Cannot register native SpacetimeDB table listener during broadcast for table '%s'."), *TableName); checkf(Binding.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); + checkf(Binding.Owner.IsValid(), + TEXT("Cannot register invalid native SpacetimeDB table listener owner for table '%s'."), + *TableName); checkf(Binding.RowTypeId == GetNativeTableListenerTypeId(), TEXT("Native SpacetimeDB table listener row type mismatch for table '%s'."), *TableName); checkf(Binding.EventContextTypeId == GetNativeTableListenerTypeId(), TEXT("Native SpacetimeDB table listener context type mismatch for table '%s'."), *TableName); for (const FNativeTableListenerBinding& ExistingBinding : NativeListeners) { - checkf(ExistingBinding.Owner != Binding.Owner, + checkf(ExistingBinding.OwnerKey != Binding.OwnerKey, TEXT("Duplicate native SpacetimeDB table listener owner for table '%s'."), *TableName); } NativeListeners.Add(Binding); } - virtual void UnregisterNativeListener(void* Owner) override + virtual void UnregisterNativeListener(void* OwnerKey) override { checkf(!bBroadcastingNativeListeners, TEXT("Cannot unregister native SpacetimeDB table listener during broadcast for table '%s'."), *TableName); - checkf(Owner != nullptr, TEXT("Cannot unregister null native SpacetimeDB table listener owner for table '%s'."), *TableName); + checkf(OwnerKey != nullptr, TEXT("Cannot unregister null native SpacetimeDB table listener owner for table '%s'."), *TableName); const int32 ListenerIndex = NativeListeners.IndexOfByPredicate( - [Owner](const FNativeTableListenerBinding& Binding) + [OwnerKey](const FNativeTableListenerBinding& Binding) { - return Binding.Owner == Owner; + return Binding.OwnerKey == OwnerKey; }); checkf(ListenerIndex != INDEX_NONE, TEXT("Missing native SpacetimeDB table listener for table '%s'."), @@ -433,49 +443,89 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam } private: - void BroadcastNativeDiff(const FTableAppliedDiff& Diff, const EventContext& Context) + bool BroadcastNativeDiff(const FTableAppliedDiff& Diff, const EventContext& Context) { - TGuardValue BroadcastingScope(bBroadcastingNativeListeners, true); - for (const FNativeTableListenerBinding& Listener : NativeListeners) + if (NativeListeners.IsEmpty()) { - BroadcastNativeDiffToListener(Diff, Context, Listener); + return false; } - } - void BroadcastNativeDiffToListener( - const FTableAppliedDiff& Diff, - const EventContext& Context, - const FNativeTableListenerBinding& Listener) - { - checkf(Listener.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); - if (Listener.DiffThunk != nullptr) + bool bSuppressDynamicDispatch = false; + TArray ExpiredOwnerKeys; { - Listener.DiffThunk(Listener.Owner, &Context, &Diff); - return; + TGuardValue BroadcastingScope(bBroadcastingNativeListeners, true); + for (const FNativeTableListenerBinding& Listener : NativeListeners) + { + UObject* OwnerObject = Listener.Owner.Get(); + if (OwnerObject == nullptr) + { + ExpiredOwnerKeys.Add(Listener.OwnerKey); + UE_LOG(LogSpacetimeDb_Connection, + Error, + TEXT("Removing expired native SpacetimeDB table listener owner for table '%s'. Native listeners must be unregistered before owner destruction."), + *TableName); + continue; + } + + BroadcastNativeDiffToListener(Diff, Context, Listener, OwnerObject); + if (Listener.DispatchMode == ESpacetimeDBNativeListenerDispatchMode::NativeOnly) + { + bSuppressDynamicDispatch = true; + } + } } - for (const TSharedPtr& Row : Diff.Inserts) + for (void* ExpiredOwnerKey : ExpiredOwnerKeys) { - checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native insert diff row for table '%s'."), *TableName); - Listener.InsertThunk(Listener.Owner, &Context, Row.Get()); + NativeListeners.RemoveAllSwap( + [ExpiredOwnerKey](const FNativeTableListenerBinding& Listener) + { + return Listener.OwnerKey == ExpiredOwnerKey; + }, + EAllowShrinking::No); } + return bSuppressDynamicDispatch; + } - for (const TSharedPtr& Row : Diff.Deletes) + void BroadcastNativeDiffToListener( + const FTableAppliedDiff& Diff, + const EventContext& Context, + const FNativeTableListenerBinding& Listener, + UObject* OwnerObject) { - checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native delete diff row for table '%s'."), *TableName); - Listener.DeleteThunk(Listener.Owner, &Context, Row.Get()); - } + checkf(Listener.IsComplete(), TEXT("Incomplete native SpacetimeDB table listener for table '%s'."), *TableName); + checkf(OwnerObject != nullptr, TEXT("Cannot dispatch native SpacetimeDB table listener to a null owner for table '%s'."), *TableName); + checkf(Listener.Owner.Get() == OwnerObject, + TEXT("Native SpacetimeDB table listener owner identity mismatch for table '%s'."), + *TableName); + if (Listener.DiffThunk != nullptr) + { + Listener.DiffThunk(Listener.OwnerKey, &Context, &Diff); + return; + } + + for (const TSharedPtr& Row : Diff.Inserts) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native insert diff row for table '%s'."), *TableName); + Listener.InsertThunk(Listener.OwnerKey, &Context, Row.Get()); + } + + for (const TSharedPtr& Row : Diff.Deletes) + { + checkf(Row.IsValid(), TEXT("Invalid SpacetimeDB native delete diff row for table '%s'."), *TableName); + Listener.DeleteThunk(Listener.OwnerKey, &Context, Row.Get()); + } checkf(Diff.UpdateDeletes.Num() == Diff.UpdateInserts.Num(), TEXT("Mismatched SpacetimeDB native update diff counts for table '%s'."), *TableName); for (int32 Index = 0; Index < Diff.UpdateInserts.Num(); ++Index) { const TSharedPtr& OldRow = Diff.UpdateDeletes[Index]; - const TSharedPtr& NewRow = Diff.UpdateInserts[Index]; - checkf(OldRow.IsValid() && NewRow.IsValid(), TEXT("Invalid SpacetimeDB native update diff row for table '%s'."), *TableName); - Listener.UpdateThunk(Listener.Owner, &Context, OldRow.Get(), NewRow.Get()); + const TSharedPtr& NewRow = Diff.UpdateInserts[Index]; + checkf(OldRow.IsValid() && NewRow.IsValid(), TEXT("Invalid SpacetimeDB native update diff row for table '%s'."), *TableName); + Listener.UpdateThunk(Listener.OwnerKey, &Context, OldRow.Get(), NewRow.Get()); + } } - } FString TableName; TableClass* Table; @@ -498,15 +548,20 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam void (OwnerType::*InsertFn)(const EventContext&, const RowType&), void (OwnerType::*UpdateFn)(const EventContext&, const RowType&, const RowType&), void (OwnerType::*DeleteFn)(const EventContext&, const RowType&)> - void RegisterNativeTableListener(const FString& TableName, OwnerType* Owner) + void RegisterNativeTableListener( + const FString& TableName, + OwnerType* Owner, + ESpacetimeDBNativeListenerDispatchMode DispatchMode = ESpacetimeDBNativeListenerDispatchMode::NativeAndDynamic) { static_assert(std::is_base_of_v, "Native SpacetimeDB table listener owner must derive from UObject."); checkf(Owner != nullptr, TEXT("Cannot register null native SpacetimeDB table listener owner for table '%s'."), *TableName); FNativeTableListenerBinding Binding; Binding.Owner = Owner; + Binding.OwnerKey = Owner; Binding.RowTypeId = GetNativeTableListenerTypeId(); Binding.EventContextTypeId = GetNativeTableListenerTypeId(); + Binding.DispatchMode = DispatchMode; Binding.InsertThunk = [](void* RawOwner, const void* RawContext, const void* RawRow) { (static_cast(RawOwner)->*InsertFn)( @@ -536,15 +591,20 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam template&)> - void RegisterNativeTableDiffListener(const FString& TableName, OwnerType* Owner) + void RegisterNativeTableDiffListener( + const FString& TableName, + OwnerType* Owner, + ESpacetimeDBNativeListenerDispatchMode DispatchMode = ESpacetimeDBNativeListenerDispatchMode::NativeAndDynamic) { static_assert(std::is_base_of_v, "Native SpacetimeDB table diff listener owner must derive from UObject."); checkf(Owner != nullptr, TEXT("Cannot register null native SpacetimeDB table diff listener owner for table '%s'."), *TableName); FNativeTableListenerBinding Binding; Binding.Owner = Owner; + Binding.OwnerKey = Owner; Binding.RowTypeId = GetNativeTableListenerTypeId(); Binding.EventContextTypeId = GetNativeTableListenerTypeId(); + Binding.DispatchMode = DispatchMode; Binding.DiffThunk = [](void* RawOwner, const void* RawContext, const void* RawDiff) { (static_cast(RawOwner)->*DiffFn)( @@ -668,11 +728,12 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam /** Mutex protecting access to PendingMessages. */ FCriticalSection PendingMessagesMutex; int32 PendingMessageReadIndex = 0; - int64 PendingParsedPayloadBytes = 0; + int64 PendingParsedEstimatedBytes = 0; /** Raw inbound messages awaiting FIFO processing by the connection-owned worker. */ TArray InboundRawMessages; mutable FCriticalSection InboundRawMessagesMutex; + int32 InboundRawMessageReadIndex = 0; int64 InboundQueuedRawBytes = 0; uint64 InboundConnectionEpoch = 0; uint64 NextInboundSequenceId = 0; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h index 6682607598c..1d90345eecb 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/ClientCache.h @@ -16,78 +16,41 @@ enum class ETableCacheApplyMode : uint8 DirectNativeDiff }; -namespace Private +enum class ERuntimeBTreeIndexApplyMode : uint8 { - static constexpr const TCHAR* MatchIdIndexName = TEXT("match_id"); - static constexpr const TCHAR* MobRuntimeSnapshotBatchTableName = TEXT("mob_runtime_snapshot_batch"); - static constexpr const TCHAR* MobCombatStateFrameTableName = TEXT("mob_combat_state_frame"); - static constexpr const TCHAR* MobAttackVisualBatchTableName = TEXT("mob_attack_visual_batch"); - static constexpr const TCHAR* MobProjectileVisualBatchTableName = TEXT("mob_projectile_visual_batch"); - static constexpr const TCHAR* AbilityCastVisualBatchTableName = TEXT("ability_cast_visual_batch"); - static constexpr const TCHAR* PlayerMotionFrameTableName = TEXT("player_motion_frame"); - static constexpr const TCHAR* PlayerCombatStateFrameTableName = TEXT("player_combat_state_frame"); - - template - struct THasUint64FrameKey - { - static constexpr bool Value = false; - }; - - template - struct THasUint64FrameKey().FrameKey)>> - { - using FieldType = std::remove_cv_t().FrameKey)>>; - static constexpr bool Value = std::is_same_v; - }; - - template - struct THasUint64BatchKey - { - static constexpr bool Value = false; - }; - - template - struct THasUint64BatchKey().BatchKey)>> - { - using FieldType = std::remove_cv_t().BatchKey)>>; - static constexpr bool Value = std::is_same_v; - }; -} + Apply, + Skip +}; template struct TCompactPrimaryKeyTraits { - static constexpr bool bHasFrameKey = Private::THasUint64FrameKey::Value; - static constexpr bool bHasBatchKey = Private::THasUint64BatchKey::Value; - static_assert(!(bHasFrameKey && bHasBatchKey), "SpacetimeDB compact cache key trait requires exactly one generated key field."); - static constexpr bool bEnabled = bHasFrameKey || bHasBatchKey; - + static constexpr bool bEnabled = false; using KeyType = uint64; static KeyType GetKey(const RowType& Row) { - if constexpr (bHasFrameKey) - { - return Row.FrameKey; - } - else - { - static_assert(bHasBatchKey, "SpacetimeDB compact cache key trait is not active for this row type."); - return Row.BatchKey; - } + (void)Row; + static_assert(bEnabled, "SpacetimeDB compact cache key trait is not generated for this row type."); + return 0; } static const TCHAR* GetUniqueIndexName() { - if constexpr (bHasFrameKey) - { - return TEXT("frame_key"); - } - else - { - static_assert(bHasBatchKey, "SpacetimeDB compact cache key trait is not active for this row type."); - return TEXT("batch_key"); - } + static_assert(bEnabled, "SpacetimeDB compact cache key trait is not generated for this row type."); + return TEXT(""); + } +}; + +template +struct TTableCachePolicy +{ + static constexpr ETableCacheApplyMode ApplyMode = ETableCacheApplyMode::PersistentIndexed; + + static ERuntimeBTreeIndexApplyMode GetRuntimeBTreeIndexApplyMode(const FString& IndexName) + { + (void)IndexName; + return ERuntimeBTreeIndexApplyMode::Apply; } }; } @@ -119,6 +82,30 @@ class UClientCache return ApplyMode; } + void SetRuntimeBTreeIndexApplyMode( + const FString& IndexName, + UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode InApplyMode) + { + checkf(!IndexName.IsEmpty(), TEXT("Cannot configure runtime B-Tree index policy for an empty index name.")); + switch (InApplyMode) + { + case UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Apply: + RuntimeBTreeIndexApplyModes.Add(IndexName, InApplyMode); + break; + case UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Skip: + RuntimeBTreeIndexApplyModes.Add(IndexName, InApplyMode); + break; + default: + checkf(false, TEXT("Unknown runtime B-Tree index apply mode for index '%s'."), *IndexName); + break; + } + } + + void ClearRuntimeBTreeIndexApplyModes() + { + RuntimeBTreeIndexApplyModes.Reset(); + } + /** * Retrieves the existing table cache or creates a new one if none exists. @@ -189,7 +176,7 @@ class UClientCache *Name, ExpectedUniqueIndexName); - if (ShouldApplyDirectNativeDiff(Name)) + if (ShouldApplyDirectNativeDiff()) { return BuildDirectDiffByPrimaryKey(Name, MoveTemp(Inserts), MoveTemp(Deletes)); } @@ -216,7 +203,7 @@ class UClientCache return CacheKey; }; - auto RemoveFromIndices = [this, &Name](const TArray& Key, const TSharedPtr& Row) + auto RemoveFromIndices = [this](const TArray& Key, const TSharedPtr& Row) { checkf(Row.IsValid(), TEXT("Cannot remove invalid row from table indices.")); for (auto& IndexPair : Table->UniqueIndices) @@ -225,7 +212,7 @@ class UClientCache } for (auto& IndexPair : Table->BTreeIndices) { - if (ShouldSkipRuntimeApplyBTreeIndex(Name, IndexPair.Key)) + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) { continue; } @@ -233,7 +220,7 @@ class UClientCache } }; - auto AddToIndices = [this, &Name](const TArray& Key, const TSharedPtr& Row) + auto AddToIndices = [this](const TArray& Key, const TSharedPtr& Row) { checkf(Row.IsValid(), TEXT("Cannot add invalid row to table indices.")); for (auto& IndexPair : Table->UniqueIndices) @@ -242,7 +229,7 @@ class UClientCache } for (auto& IndexPair : Table->BTreeIndices) { - if (ShouldSkipRuntimeApplyBTreeIndex(Name, IndexPair.Key)) + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) { continue; } @@ -445,6 +432,9 @@ class UClientCache UE_LOG(LogTemp, Error, TEXT("Failed to create or retrieve table: %s"), *Name); return FTableAppliedDiff(); } + checkf(ApplyMode != UE::SpacetimeDB::ETableCacheApplyMode::DirectNativeDiff, + TEXT("DirectNativeDiff for table %s requires a generated compact uint64 primary-key trait."), + *Name); struct FDeletedRow { @@ -466,6 +456,10 @@ class UClientCache } for (auto& IndexPair : Table->BTreeIndices) { + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) + { + continue; + } IndexPair.Value->RemoveRow(Key, Row); } }; @@ -479,6 +473,10 @@ class UClientCache } for (auto& IndexPair : Table->BTreeIndices) { + if (!ShouldApplyRuntimeBTreeIndex(IndexPair.Key)) + { + continue; + } IndexPair.Value->AddRow(Key, Row); } }; @@ -562,25 +560,22 @@ class UClientCache return Diff; } private: - bool ShouldApplyDirectNativeDiff(const FString& Name) const + bool ShouldApplyDirectNativeDiff() const { - return ApplyMode == UE::SpacetimeDB::ETableCacheApplyMode::DirectNativeDiff - || Name == UE::SpacetimeDB::Private::MobAttackVisualBatchTableName - || Name == UE::SpacetimeDB::Private::MobProjectileVisualBatchTableName - || Name == UE::SpacetimeDB::Private::AbilityCastVisualBatchTableName - || Name == UE::SpacetimeDB::Private::MobCombatStateFrameTableName - || Name == UE::SpacetimeDB::Private::PlayerMotionFrameTableName - || Name == UE::SpacetimeDB::Private::PlayerCombatStateFrameTableName; + return ApplyMode == UE::SpacetimeDB::ETableCacheApplyMode::DirectNativeDiff; } - bool ShouldSkipRuntimeApplyBTreeIndex(const FString& Name, const FString& IndexName) const + bool ShouldApplyRuntimeBTreeIndex(const FString& IndexName) const { - return IndexName == UE::SpacetimeDB::Private::MatchIdIndexName - && (Name == UE::SpacetimeDB::Private::MobRuntimeSnapshotBatchTableName - || Name == UE::SpacetimeDB::Private::MobAttackVisualBatchTableName - || Name == UE::SpacetimeDB::Private::MobProjectileVisualBatchTableName - || Name == UE::SpacetimeDB::Private::AbilityCastVisualBatchTableName); + checkf(!IndexName.IsEmpty(), TEXT("Cannot apply runtime B-Tree index policy for an empty index name.")); + if (const UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode* RuntimeMode = RuntimeBTreeIndexApplyModes.Find(IndexName)) + { + return *RuntimeMode == UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Apply; + } + return UE::SpacetimeDB::TTableCachePolicy::GetRuntimeBTreeIndexApplyMode(IndexName) + == UE::SpacetimeDB::ERuntimeBTreeIndexApplyMode::Apply; } - UE::SpacetimeDB::ETableCacheApplyMode ApplyMode = UE::SpacetimeDB::ETableCacheApplyMode::PersistentIndexed; + UE::SpacetimeDB::ETableCacheApplyMode ApplyMode = UE::SpacetimeDB::TTableCachePolicy::ApplyMode; + TMap RuntimeBTreeIndexApplyModes; }; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md index 7ff0ed9ee3d..d1a23afdaaa 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/README.md @@ -8,7 +8,7 @@ Utilities used to maintain a client side cache of database tables. These classe - `ClientCache.h` – Owns `FTableCache` objects and applies insert/delete diffs sent over the network. - `IUniqueIndex.h` – Interface that unique index implementations conform to. - `RowEntry.h` – Wrapper storing a row value with a reference count used by overlapping subscriptions. -- `TableAppliedDiff.h` – Describes the inserts, deletes and updates detected when applying a diff. +- `TableAppliedDiff.h` – Describes the inserts, deletes and updates detected when applying a diff. Direct C++ consumers receive `TSharedPtr` row arrays; generated dynamic delegates continue broadcasting value references. - `TableCache.h` – In-memory representation of a table and its unique indices. - `TableHandle.h` – Lightweight helper exposing read only access to a cached table. - `UniqueConstraintHandle.h` – Helper that allows typed lookups against a unique constraint. @@ -126,4 +126,4 @@ Table.GetValues(AllRows); - `TArray` keys allow serialized identifiers (network-friendly). - Adding indices after inserting rows is **not supported** without manual rebuild. -- B-Tree indices can later be extended for **range queries**.ed to a **single column** per call. \ No newline at end of file +- B-Tree indices can later be extended for **range queries** and are currently scoped to a **single column** per call. diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h index 3e38e90d564..640b50343a3 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/DBCache/TableAppliedDiff.h @@ -4,10 +4,14 @@ /* ============================================================================ * * TableAppliedDiff.h (2025-05-28) * ---------------------------------------------------------------------------- - * Captures the semantic result of applying a low‑level diff (inserts/deletes) - * to a table cache. Rows that transition from dead→live are inserts, live→dead - * are deletes, and a delete+insert with the same primary‑key is surfaced as an - * update pair. + * Captures the semantic result of applying a low-level diff (inserts/deletes) + * to a table cache. Rows that transition from dead to live are inserts, live + * to dead are deletes, and a delete+insert with the same primary key is + * surfaced as an update pair. + * + * This is the SDK's lower-copy diff representation. Direct C++ consumers read + * row payloads from TSharedPtr-backed arrays; generated dynamic delegates still + * broadcast value references for Blueprint and existing dynamic delegate code. * ============================================================================ */ template struct FTableAppliedDiff