diff --git a/change/react-native-windows-c3827e14-777b-475a-bf00-dc169bf89f3d.json b/change/react-native-windows-c3827e14-777b-475a-bf00-dc169bf89f3d.json new file mode 100644 index 00000000000..e48ac161fcd --- /dev/null +++ b/change/react-native-windows-c3827e14-777b-475a-bf00-dc169bf89f3d.json @@ -0,0 +1,7 @@ +{ + "type": "prerelease", + "comment": "Fix WebSocket binaryType handling — stop unconditional Blob interception of binary messages", + "packageName": "react-native-windows", + "email": "gordomacmaster@gmail.com", + "dependentChangeType": "patch" +} \ No newline at end of file diff --git a/vnext/Shared/Modules/IWebSocketModuleContentHandler.h b/vnext/Shared/Modules/IWebSocketModuleContentHandler.h index 4d508603865..f64b0fc65ab 100644 --- a/vnext/Shared/Modules/IWebSocketModuleContentHandler.h +++ b/vnext/Shared/Modules/IWebSocketModuleContentHandler.h @@ -18,11 +18,49 @@ namespace Microsoft::React { struct IWebSocketModuleContentHandler { virtual ~IWebSocketModuleContentHandler() noexcept {} + /// Returns true if this handler should process messages for the given socket. + /// Default returns true for backward compatibility; BlobModule overrides to + /// check whether binaryType='blob' was set for this socket via addWebSocketHandler. + /// + /// WARNING: Subclasses that override Supports() with a stateful or lock-protected + /// check MUST also override both TryProcessMessage() overloads to perform the + /// check-and-process atomically. The default TryProcessMessage() calls Supports() + /// and ProcessMessage() as two separate operations with no lock held between them. + virtual bool Supports(int64_t /*socketId*/) noexcept { + return true; + } + virtual void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept = 0; virtual void ProcessMessage( std::vector &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept = 0; + + /// Check Supports() then ProcessMessage() in one call. + /// Returns true if the message was handled. + /// + /// The default implementation does NOT hold any lock across both operations. + /// Subclasses with a stateful Supports() MUST override these to make the + /// check-and-process atomic (see BlobWebSocketModuleContentHandler for an example). + virtual bool TryProcessMessage( + int64_t socketId, + std::string &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept { + if (!Supports(socketId)) + return false; + ProcessMessage(std::move(message), params); + return true; + } + + virtual bool TryProcessMessage( + int64_t socketId, + std::vector &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept { + if (!Supports(socketId)) + return false; + ProcessMessage(std::move(message), params); + return true; + } }; } // namespace Microsoft::React diff --git a/vnext/Shared/Modules/WebSocketModule.cpp b/vnext/Shared/Modules/WebSocketModule.cpp index 7b0f5bb12a9..6bbc800207f 100644 --- a/vnext/Shared/Modules/WebSocketModule.cpp +++ b/vnext/Shared/Modules/WebSocketModule.cpp @@ -83,6 +83,7 @@ shared_ptr WebSocketTurboModule::CreateResource(int64_t id, if (auto prop = propBag.Get(BlobModuleContentHandlerPropertyId())) contentHandler = prop.Value().lock(); + bool handled = false; if (contentHandler) { if (isBinary) { auto buffer = CryptographicBuffer::DecodeFromBase64String(winrt::to_hstring(message)); @@ -90,11 +91,12 @@ shared_ptr WebSocketTurboModule::CreateResource(int64_t id, CryptographicBuffer::CopyToByteArray(buffer, arr); auto data = vector(arr.begin(), arr.end()); - contentHandler->ProcessMessage(std::move(data), args); + handled = contentHandler->TryProcessMessage(id, std::move(data), args); } else { - contentHandler->ProcessMessage(string{message}, args); + handled = contentHandler->TryProcessMessage(id, string{message}, args); } - } else { + } + if (!handled) { args["data"] = message; } diff --git a/vnext/Shared/Networking/DefaultBlobResource.cpp b/vnext/Shared/Networking/DefaultBlobResource.cpp index 31fdfd6b061..a774d3b700f 100644 --- a/vnext/Shared/Networking/DefaultBlobResource.cpp +++ b/vnext/Shared/Networking/DefaultBlobResource.cpp @@ -221,6 +221,11 @@ BlobWebSocketModuleContentHandler::BlobWebSocketModuleContentHandler(shared_ptr< #pragma region IWebSocketModuleContentHandler +bool BlobWebSocketModuleContentHandler::Supports(int64_t socketId) noexcept /*override*/ { + scoped_lock lock{m_mutex}; + return m_socketIds.find(socketId) != m_socketIds.end(); +} + void BlobWebSocketModuleContentHandler::ProcessMessage( string &&message, msrn::JSValueObject ¶ms) noexcept /*override*/ @@ -241,6 +246,38 @@ void BlobWebSocketModuleContentHandler::ProcessMessage( params[blobKeys.Type] = blobKeys.Blob; } +bool BlobWebSocketModuleContentHandler::TryProcessMessage( + int64_t socketId, + string &&message, + msrn::JSValueObject ¶ms) noexcept /*override*/ +{ + scoped_lock lock{m_mutex}; + if (m_socketIds.find(socketId) == m_socketIds.end()) + return false; + + params[blobKeys.Data] = std::move(message); + return true; +} + +bool BlobWebSocketModuleContentHandler::TryProcessMessage( + int64_t socketId, + vector &&message, + msrn::JSValueObject ¶ms) noexcept /*override*/ +{ + scoped_lock lock{m_mutex}; + if (m_socketIds.find(socketId) == m_socketIds.end()) + return false; + + auto blob = msrn::JSValueObject{ + {blobKeys.Offset, 0}, + {blobKeys.Size, message.size()}, + {blobKeys.BlobId, m_blobPersistor->StoreMessage(std::move(message))}}; + + params[blobKeys.Data] = std::move(blob); + params[blobKeys.Type] = blobKeys.Blob; + return true; +} + #pragma endregion IWebSocketModuleContentHandler void BlobWebSocketModuleContentHandler::Register(int64_t socketID) noexcept { diff --git a/vnext/Shared/Networking/DefaultBlobResource.h b/vnext/Shared/Networking/DefaultBlobResource.h index 4dfdf5f18aa..68f5903ce47 100644 --- a/vnext/Shared/Networking/DefaultBlobResource.h +++ b/vnext/Shared/Networking/DefaultBlobResource.h @@ -51,11 +51,23 @@ class BlobWebSocketModuleContentHandler final : public IWebSocketModuleContentHa #pragma region IWebSocketModuleContentHandler + bool Supports(int64_t socketId) noexcept override; + void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; void ProcessMessage(std::vector &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; + bool TryProcessMessage( + int64_t socketId, + std::string &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; + + bool TryProcessMessage( + int64_t socketId, + std::vector &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; + #pragma endregion IWebSocketModuleContentHandler void Register(int64_t socketID) noexcept;