diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/ChatClientStateExtensions.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/ChatClientStateExtensions.kt index f0be86cb937..44ffac23a82 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/ChatClientStateExtensions.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/ChatClientStateExtensions.kt @@ -164,6 +164,30 @@ public fun ChatClient.watchChannelAsState( } } +/** + * Same as [watchChannelAsState] but loads messages around the given [aroundMessageId] in the initial request. + * Use this when opening a channel with focus on a specific message (e.g. deep-link) to avoid a race between + * the initial watch and a separate loadAround call. + * + * @param aroundMessageId When non-null, the initial watch uses AROUND_ID pagination to load messages + * around this ID. Only use for channel-level focus (not when opening a thread). + */ +@InternalStreamChatApi +@JvmOverloads +public fun ChatClient.watchChannelAsState( + cid: String, + messageLimit: Int, + aroundMessageId: String?, + coroutineScope: CoroutineScope = CoroutineScope(DispatcherProvider.IO), +): StateFlow { + StreamLog.i(TAG) { + "[watchChannelAsState] cid: $cid, messageLimit: $messageLimit, aroundMessageId: $aroundMessageId" + } + return getStateOrNull(coroutineScope) { + requestsAsState(coroutineScope).watchChannel(cid, messageLimit, stateConfig.userPresence, aroundMessageId) + } +} + /** * Performs [ChatClient.queryThreadsResult] under the hood and returns [QueryThreadsState]. * The [QueryThreadsState] cannot be created before connecting the user therefore, the method returns a StateFlow diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/internal/ChatClientStateCalls.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/internal/ChatClientStateCalls.kt index 5a204416e5d..4bc2b5b1378 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/internal/ChatClientStateCalls.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/internal/ChatClientStateCalls.kt @@ -18,6 +18,7 @@ package io.getstream.chat.android.client.internal.state.plugin.state.internal import io.getstream.chat.android.client.ChatClient import io.getstream.chat.android.client.api.event.ChatEventHandlerFactory +import io.getstream.chat.android.client.api.models.Pagination import io.getstream.chat.android.client.api.models.QueryChannelRequest import io.getstream.chat.android.client.api.models.QueryChannelsRequest import io.getstream.chat.android.client.api.models.QueryThreadsRequest @@ -84,15 +85,27 @@ internal class ChatClientStateCalls( } /** Reference request of the watch channel query. */ - internal suspend fun watchChannel(cid: String, messageLimit: Int, userPresence: Boolean): ChannelState { - logger.d { "[watchChannel] cid: $cid, messageLimit: $messageLimit, userPresence: $userPresence" } + internal suspend fun watchChannel( + cid: String, + messageLimit: Int, + userPresence: Boolean, + aroundMessageId: String? = null, + ): ChannelState { + logger.d { + "[watchChannel] cid: $cid, messageLimit: $messageLimit, userPresence: $userPresence, " + + "aroundMessageId: $aroundMessageId" + } val (channelType, channelId) = cid.cidToTypeAndId() - val request = QueryChannelPaginationRequest(messageLimit) - .toWatchChannelRequest(userPresence) - .apply { - this.shouldRefresh = false - this.isWatchChannel = true + val request = when (aroundMessageId) { + null -> QueryChannelPaginationRequest(messageLimit) + else -> QueryChannelPaginationRequest(messageLimit).apply { + messageFilterDirection = Pagination.AROUND_ID + messageFilterValue = aroundMessageId } + }.toWatchChannelRequest(userPresence).apply { + shouldRefresh = false + isWatchChannel = true + } return queryChannel(channelType, channelId, request) } diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/extensions/ChatClientExtensionTests.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/extensions/ChatClientExtensionTests.kt index 77a233ffbf8..ede3c9666fe 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/extensions/ChatClientExtensionTests.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/extensions/ChatClientExtensionTests.kt @@ -18,6 +18,7 @@ package io.getstream.chat.android.client.internal.state.extensions import io.getstream.chat.android.client.ChatClient import io.getstream.chat.android.client.api.StateConfig +import io.getstream.chat.android.client.api.models.Pagination import io.getstream.chat.android.client.api.state.GlobalState import io.getstream.chat.android.client.api.state.StateRegistry import io.getstream.chat.android.client.api.state.watchChannelAsState @@ -28,6 +29,7 @@ import io.getstream.chat.android.client.internal.state.plugin.internal.StatePlug import io.getstream.chat.android.client.plugin.Plugin import io.getstream.chat.android.client.plugin.factory.PluginFactory import io.getstream.chat.android.client.setup.state.ClientState +import io.getstream.chat.android.core.internal.InternalStreamChatApi import io.getstream.chat.android.models.Channel import io.getstream.chat.android.models.ConnectionData import io.getstream.chat.android.models.InitializationState @@ -55,6 +57,7 @@ import org.mockito.kotlin.times import org.mockito.kotlin.verify import java.util.concurrent.atomic.AtomicInteger +@OptIn(InternalStreamChatApi::class) internal class ChatClientExtensionTests { companion object { @@ -64,6 +67,7 @@ internal class ChatClientExtensionTests { private const val MESSAGE_LIMIT = 30 private const val ONE_SEC = 1000L + private const val AROUND_MESSAGE_ID = "msg-id-42" } private lateinit var channel: Channel @@ -73,6 +77,7 @@ internal class ChatClientExtensionTests { private lateinit var pluginFactories: List private lateinit var plugins: List + @Suppress("LongMethod") @BeforeEach fun setUp() { channel = randomChannel() @@ -95,6 +100,14 @@ internal class ChatClientExtensionTests { this.isWatchChannel = true } + val aroundMessageRequest = QueryChannelPaginationRequest(MESSAGE_LIMIT).apply { + messageFilterDirection = Pagination.AROUND_ID + messageFilterValue = AROUND_MESSAGE_ID + }.toWatchChannelRequest(stateConfig.userPresence).apply { + this.shouldRefresh = false + this.isWatchChannel = true + } + pluginFactories = listOf( mock { on(it.resolveDependency(StateConfig::class)) doReturn stateConfig @@ -127,6 +140,7 @@ internal class ChatClientExtensionTests { on(it.pluginFactories) doReturn pluginFactories on(it.plugins) doReturn plugins on(it.queryChannel(channel.type, channel.id, request)) doReturn channel.asCall() + on(it.queryChannel(channel.type, channel.id, aroundMessageRequest)) doReturn channel.asCall() on(it.connectAnonymousUser()) doAnswer { userFlow.value = connectionData.user initializationStateFlow.value = InitializationState.COMPLETE @@ -212,4 +226,76 @@ internal class ChatClientExtensionTests { userFlow.value = randomUser(id = "123") verify(chatClient, times(2)).queryChannel(any(), any(), any(), any()) } + + @Test + fun `watchChannelAsState with aroundMessageId should emit when InitializationState becomes COMPLETE`() = runTest { + val channelStateEmissions = AtomicInteger(0) + val channelStateFlow = MutableStateFlow(null) + val localScope = testCoroutines.scope + Job() + localScope.launch { + chatClient.watchChannelAsState(channel.cid, MESSAGE_LIMIT, AROUND_MESSAGE_ID).collect { + channelStateFlow.value = it + channelStateEmissions.incrementAndGet() + } + } + delay(ONE_SEC) + + chatClient.connectAnonymousUser().await() + + delay(ONE_SEC) + + val channelState = channelStateFlow.first { it != null } + channelState?.cid shouldBeEqualTo channel.cid + channelStateEmissions.get() shouldBeEqualTo 2 + } + + @Test + fun `watchChannelAsState with aroundMessageId should emit when InitializationState becomes NOT_INITIALIZED`() = runTest { + userFlow.value = randomUser() + initializationStateFlow.value = InitializationState.COMPLETE + val channelStateEmissions = AtomicInteger(0) + val channelStateFlow = MutableStateFlow(null) + val localScope = testCoroutines.scope + Job() + localScope.launch { + chatClient.watchChannelAsState(channel.cid, MESSAGE_LIMIT, AROUND_MESSAGE_ID).collect { + channelStateFlow.value = it + channelStateEmissions.incrementAndGet() + } + } + delay(ONE_SEC) + val channelState = channelStateFlow.first { it != null } + channelState?.cid shouldBeEqualTo channel.cid + channelStateEmissions.get() shouldBeEqualTo 2 + + chatClient.disconnect(flushPersistence = true).await() + + delay(ONE_SEC) + + channelStateFlow.value shouldBeEqualTo null + channelStateEmissions.get() shouldBeEqualTo 3 + } + + @Test + fun `watchChannelAsState with aroundMessageId shouldn't queryChannel twice when the user is the same but some properties has changed`() = runTest { + userFlow.value = randomUser(id = "jc") + initializationStateFlow.value = InitializationState.COMPLETE + val localScope = testCoroutines.scope + Job() + localScope.launch { + chatClient.watchChannelAsState(channel.cid, MESSAGE_LIMIT, AROUND_MESSAGE_ID).collect { } + } + userFlow.value = randomUser(id = "jc") + verify(chatClient, times(1)).queryChannel(any(), any(), any(), any()) + } + + @Test + fun `watchChannelAsState with aroundMessageId should queryChannel again when the user is a different one`() = runTest { + userFlow.value = randomUser(id = "jc") + initializationStateFlow.value = InitializationState.COMPLETE + val localScope = testCoroutines.scope + Job() + localScope.launch { + chatClient.watchChannelAsState(channel.cid, MESSAGE_LIMIT, AROUND_MESSAGE_ID).collect { } + } + userFlow.value = randomUser(id = "123") + verify(chatClient, times(2)).queryChannel(any(), any(), any(), any()) + } } diff --git a/stream-chat-android-compose/src/main/java/io/getstream/chat/android/compose/viewmodel/messages/MessagesViewModelFactory.kt b/stream-chat-android-compose/src/main/java/io/getstream/chat/android/compose/viewmodel/messages/MessagesViewModelFactory.kt index 19087aaf531..f35e11c7763 100644 --- a/stream-chat-android-compose/src/main/java/io/getstream/chat/android/compose/viewmodel/messages/MessagesViewModelFactory.kt +++ b/stream-chat-android-compose/src/main/java/io/getstream/chat/android/compose/viewmodel/messages/MessagesViewModelFactory.kt @@ -115,11 +115,13 @@ public class MessagesViewModelFactory( ) : ViewModelProvider.Factory { private val channelStateFlow: StateFlow by lazy { - chatClient.watchChannelAsState( - cid = channelId, - messageLimit = messageLimit, - coroutineScope = chatClient.inheritScope { SupervisorJob(it) + DispatcherProvider.Immediate }, - ) + val scope = chatClient.inheritScope { SupervisorJob(it) + DispatcherProvider.Immediate } + when { + messageId != null && parentMessageId == null -> + chatClient.watchChannelAsState(channelId, messageLimit, messageId, scope) + else -> + chatClient.watchChannelAsState(channelId, messageLimit, scope) + } } private val storageHelper by lazy { AttachmentStorageHelper(context) } diff --git a/stream-chat-android-ui-components/src/main/kotlin/io/getstream/chat/android/ui/viewmodel/messages/MessageListViewModelFactory.kt b/stream-chat-android-ui-components/src/main/kotlin/io/getstream/chat/android/ui/viewmodel/messages/MessageListViewModelFactory.kt index a2b4224567d..64b680230f9 100644 --- a/stream-chat-android-ui-components/src/main/kotlin/io/getstream/chat/android/ui/viewmodel/messages/MessageListViewModelFactory.kt +++ b/stream-chat-android-ui-components/src/main/kotlin/io/getstream/chat/android/ui/viewmodel/messages/MessageListViewModelFactory.kt @@ -103,11 +103,13 @@ public class MessageListViewModelFactory @JvmOverloads constructor( ) : ViewModelProvider.Factory { private val channelStateFlow: StateFlow by lazy { - chatClient.watchChannelAsState( - cid = cid, - messageLimit = messageLimit, - coroutineScope = chatClient.inheritScope { SupervisorJob(it) }, - ) + val scope = chatClient.inheritScope { SupervisorJob(it) } + when { + messageId != null && parentMessageId == null -> + chatClient.watchChannelAsState(cid, messageLimit, messageId, scope) + else -> + chatClient.watchChannelAsState(cid, messageLimit, scope) + } } override fun create(modelClass: Class, extras: CreationExtras): T =