From 1ca4677b2876f0b68bde843b961b14e69ffce342 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 7 Nov 2025 21:36:28 +0000 Subject: [PATCH 01/11] smp server: messaging services (#1565) * smp server: refactor message delivery to always respond SOK to subscriptions * refactor ntf subscribe * cancel subscription thread and reduce service subscription count when queue is deleted * subscribe rcv service, deliver sent messages to subscribed service * subscribe rcv service to messages (TODO delivery on subscription) * WIP * efficient initial delivery of messages to subscribed service * test: delivery to client with service certificate * test: upgrade/downgrade to/from service subscriptions * remove service association from agent API, add per-user flag to use the service * agent client (WIP) * service certificates in the client * rfc about drift detection, and SALL to mark end of message delivery * fix test * fix test * add function for postgresql message storage * update migration --- rfcs/2025-08-20-service-subs-drift.md | 101 ++++++++ simplexmq.cabal | 2 + src/Simplex/Messaging/Agent.hs | 150 +++++++----- src/Simplex/Messaging/Agent/Client.hs | 66 +++++- src/Simplex/Messaging/Agent/Env/SQLite.hs | 1 + src/Simplex/Messaging/Agent/Protocol.hs | 18 +- src/Simplex/Messaging/Agent/Store.hs | 6 +- .../Messaging/Agent/Store/AgentStore.hs | 105 ++++++++- .../Agent/Store/SQLite/Migrations/App.hs | 4 +- .../Migrations/M20250517_service_certs.hs | 40 ---- .../Migrations/M20251020_service_certs.hs | 40 ++++ .../Store/SQLite/Migrations/agent_schema.sql | 17 ++ src/Simplex/Messaging/Client.hs | 4 +- src/Simplex/Messaging/Client/Agent.hs | 6 +- src/Simplex/Messaging/Crypto.hs | 18 +- src/Simplex/Messaging/Protocol.hs | 21 +- src/Simplex/Messaging/Server.hs | 218 ++++++++++++------ .../Messaging/Server/MsgStore/Journal.hs | 20 ++ .../Messaging/Server/MsgStore/Postgres.hs | 33 +++ src/Simplex/Messaging/Server/MsgStore/STM.hs | 5 + .../Messaging/Server/MsgStore/Types.hs | 1 + .../Messaging/Server/QueueStore/Postgres.hs | 17 +- .../Messaging/Server/QueueStore/STM.hs | 13 +- src/Simplex/Messaging/Transport.hs | 23 +- tests/AgentTests/FunctionalAPITests.hs | 70 +++--- tests/AgentTests/SQLiteTests.hs | 4 +- tests/AgentTests/ServerChoice.hs | 1 + tests/SMPAgentClient.hs | 1 + tests/SMPClient.hs | 34 ++- tests/SMPProxyTests.hs | 18 +- tests/ServerTests.hs | 217 ++++++++++++++++- 31 files changed, 969 insertions(+), 305 deletions(-) create mode 100644 rfcs/2025-08-20-service-subs-drift.md delete mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs diff --git a/rfcs/2025-08-20-service-subs-drift.md b/rfcs/2025-08-20-service-subs-drift.md new file mode 100644 index 000000000..1ca9e6018 --- /dev/null +++ b/rfcs/2025-08-20-service-subs-drift.md @@ -0,0 +1,101 @@ +# Detecting and fixing state with service subscriptions + +## Problem + +While service certificates and subscriptions hugely decrease startup time and delivery delays on server restarts, they introduce the risk of losing subscriptions in case of state drifts. They also do not provide efficient mechanism for validating that the list of subscribed queues is in sync. + +How can the state drift happen? + +There are several possibilities: +- lost broker response would make the broker consider that the queue is associated, but the client won't know it, and will have to re-associate. While in itself it is not a problem, as it'll be resolved, it would make drift detected more frequently (regardless of the detection logic used). That service certificates are used on clients with good connection would make it less likely though. +- server state restored from the backup, in case of some failure. Nothing can be done to recover lost queues, but we may restore lost service associations. +- queue blocking or removal by server operator because of policy violation. +- server downgrade (when it loses all service associations) with subsequent upgrade - the client would think queues are associated, while they are not, and won't receive any messages at all in this scenario. +- any other server-side error or logic error. + +In addition to the possibility of the drift, we simply need to have confidence that service subscriptions work as intended, without skipping queues. We ignored this consideration for notifications, as the tolerance to lost notifications is higher, but we can't ignore it for messages. + +## Solution + +Previously considered approach of sending NIL to all queues without messages is very expensive for traffic (most queues don't have messages), and it is also very expensive to detect and validate drift in the client because of asynchronous / concurrent events. + +We cannot read all queues into memory, and we cannot aggregate all responses in memory, and we cannot create database writes on every single service subscription to say 1m queues (a realistic number), as it simply won't work well even at the current scale. + +An approach of having an efficient way to detect drift, but load the full list of IDs when drift is detected, also won't work well, as drifts may be common, so we need both efficient way to detect there is diff and also to reconcile it. + +### Drift detection + +Both client and server would maintain the number of associated queues and the "symmetric" hash over the set of queue IDs. The requirements for this hash algorithm are: +- not cryptographically strong, to be fast. +- 128 bits to minimize collisions over the large set of millions of queues. +- symmetric - the result should not depend on ID order. +- allows fast additions and removals. + +In this way, every time association is added or removed (including queue marked as deleted), both peers would recompute this hash in the same transaction. + +The client would suspend sending and processing any other commands on the server and the queues of this server until SOKS response is received from this server, to prevent drift. It can be achieved with per-server semaphores/locks in memory. UI clients need to become responsive sooner than these responses are received, but we do not service certificates on UI clients, and chat relays may prevent operations on server queues until SOKS response is received. + +SOKS response would include both the count of associated queues (as now) and the hash over all associated queue IDs (to be added). If both count and hash match, the client will not do anything. If either does not match the client would perform full sync (see below). + +There is a value from doing the same in notification server as well to detect and "fix" drifts. + +The algorithm to compute hashes can be the following. + +1. Compute hash of each queue ID using xxHash3_128 ([xxhash-ffi](https://hackage.haskell.org/package/xxhash-ffi) library). They don't need to be stored or loaded at once, initially, it can be done with streaming if it is detected on start that there is no pre-computed hash. +2. Combine hashes using XOR. XOR is both commutative and associative, so it would produce the same aggregate hash irrespective of the ID order. +3. Adding queue ID to pre-computed hash requires a single XOR with ID hash: `new_aggregate = aggregate XOR hash(queue_id)`. +4. Removing queue ID from pre-computed hash also requires the same XOR (XOR is involutory, it undoes itself): `new_aggregate = aggregate XOR hash(queue_id)`. + +These hashes need to be computed per user/server in the client and per service certificate in the server - on startup both have to validate and compute them once if necessary. + +There can be also a start-up option to recompute hashe(s) to detect and fix any errors. + +This is all rather simple and would help detecting drifts. + +### Synchronization when drift is detected + +The assumption here is that in most cases drifts are rare, and isolated to few IDs (e.g., this is the case with notification server). + +But the algorithm should be resilient to losing all associations, and it should not be substantially worse than simply restoring all associations or loading all IDs. + +We have `c_n` and `c_hash` for client-side count and hash of queue IDs and `s_n` and `s_hash` for server-side, which are returned in SOKS response to SUBS command. + +1. If `c_n /= s_n || c_hash /= s_hash`, the client must perform sync. + +2. If `abs(c_n - s_n) / max(c_n, s_n) > 0.5`, the client will request the full list of queues (more than half of the queues are different), and will perform diff with the queues it has. While performing the diff the client will continue block operations with this user/server. + +3. Otherwise would perform some algorithm for determining the difference between queue IDs between client and server. This algorithm can be made efficient (`O(log N)`) by relying on efficient sorting of IDs and database loading of ranges, via computing and communicating hashes of ranges, and performing a binary search on ranges, with batching to optimize network traffic. + +This algorithm is similar to Merkle tree reconcilliation, but it is optimized for database reading of ordered ranges, and for our 16kb block size to minimize network requests. + +The algorithm: +1. The client would request all ranges from the server. +2. The server would compute hashes for N ranges of IDs and send them to the client. Each range would include start_id, optional end_id (for single ID ranges) and XOR-hash of the range. N is determined based on the block size and the range size. +3. The client would perform the same computation for the same ranges, and compare them with the returned ranges from the server, while detecting any gaps between ranges and missing range boundaries. +4. If more than half of the ranges don't match, the client would request the full list. Otherwise it would repeat the same algorithm for each mismatched range and for gaps. + +It can be further optimized by merging adjacent ranges and by batching all range requests, it is quite simple. + +Once the client determines the list of missing and extra queues it can: +- create associations (via SUB) for missing queues, +- request removal of association (a new command, e.g. BUS) for extra queues on the server. + +The pseudocode for the algorightm: + +For the server to return all ranges or subranges of requested range: + +```haskell +getSubRanges :: Maybe (RecipientId, RecipientId) -> [(RecipientId, Maybe RecipientId, Hash)] +getSubRanges range_ = do + ((min_id, max_id), s_n) <- case range_ of + Nothing -> getAssociatedQueueRange -- with the certificate in the client session. + Just range -> (range,) <$> getAssociatedQueueCount range + if + | s_n <= max_N -> reply_with_single_queue_ranges + | otherwise -> do + let range_size = s_n `div` max_N + read_all_ranges -- in a recursive loop, with max_id, range_hash and next_min_id in each step + reply_ranges +``` + +We don't need to implement this synchronization logic right now, so not including client logic here, it's sufficient to implement drift detection, and the action to fix the drift would be to disable and to re-enable certificates via some command-line parameter of CLI. diff --git a/simplexmq.cabal b/simplexmq.cabal index 7fd1396e1..081c05bca 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -216,6 +216,7 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250702_conn_invitations_remove_cascade_delete Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251009_queue_to_subscribe Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251010_client_notices + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs if flag(client_postgres) || flag(server_postgres) exposed-modules: Simplex.Messaging.Agent.Store.Postgres @@ -553,6 +554,7 @@ test-suite simplexmq-test , text , time , timeit ==2.0.* + , tls >=1.9.0 && <1.10 , transformers , unliftio , unliftio-core diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index c19d4aeea..f9f1dc089 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -47,6 +47,7 @@ module Simplex.Messaging.Agent withInvLock, createUser, deleteUser, + setUserService, connRequestPQSupport, createConnectionAsync, joinConnectionAsync, @@ -78,7 +79,7 @@ module Simplex.Messaging.Agent getNotificationConns, resubscribeConnection, resubscribeConnections, - subscribeClientService, + subscribeClientServices, sendMessage, sendMessages, sendMessagesB, @@ -210,6 +211,7 @@ import Simplex.Messaging.Protocol ErrorType (AUTH), MsgBody, MsgFlags (..), + IdsHash, NtfServer, ProtoServerWithAuth (..), ProtocolServer (..), @@ -340,6 +342,11 @@ deleteUser :: AgentClient -> UserId -> Bool -> AE () deleteUser c = withAgentEnv c .: deleteUser' c {-# INLINE deleteUser #-} +-- | Enable using service certificate for this user +setUserService :: AgentClient -> UserId -> Bool -> AE () +setUserService c = withAgentEnv c .: setUserService' c +{-# INLINE setUserService #-} + -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id createConnectionAsync :: ConnectionModeI c => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> AE ConnId createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs @@ -381,7 +388,7 @@ deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' {-# INLINE deleteConnectionsAsync #-} -- | Create SMP agent connection (NEW command) -createConnection :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) +createConnection :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, CreatedConnLink c) createConnection c nm userId enableNtfs checkNotices = withAgentEnv c .::. newConn c nm userId enableNtfs checkNotices {-# INLINE createConnection #-} @@ -424,7 +431,7 @@ prepareConnectionToAccept c userId enableNtfs = withAgentEnv c .: newConnToAccep {-# INLINE prepareConnectionToAccept #-} -- | Join SMP agent connection (JOIN command). -joinConnection :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (SndQueueSecured, Maybe ClientServiceId) +joinConnection :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE SndQueueSecured joinConnection c nm userId connId enableNtfs = withAgentEnv c .:: joinConn c nm userId connId enableNtfs {-# INLINE joinConnection #-} @@ -434,7 +441,7 @@ allowConnection c = withAgentEnv c .:. allowConnection' c {-# INLINE allowConnection #-} -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (SndQueueSecured, Maybe ClientServiceId) +acceptContact :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE SndQueueSecured acceptContact c userId connId enableNtfs = withAgentEnv c .::. acceptContact' c userId connId enableNtfs {-# INLINE acceptContact #-} @@ -462,12 +469,12 @@ syncConnections c = withAgentEnv c .: syncConnections' c {-# INLINE syncConnections #-} -- | Subscribe to receive connection messages (SUB command) -subscribeConnection :: AgentClient -> ConnId -> AE (Maybe ClientServiceId) +subscribeConnection :: AgentClient -> ConnId -> AE () subscribeConnection c = withAgentEnv c . subscribeConnection' c {-# INLINE subscribeConnection #-} -- | Subscribe to receive connection messages from multiple connections, batching commands when possible -subscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) +subscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) subscribeConnections c = withAgentEnv c . subscribeConnections' c {-# INLINE subscribeConnections #-} @@ -485,18 +492,17 @@ getNotificationConns :: AgentClient -> C.CbNonce -> ByteString -> AE (NonEmpty N getNotificationConns c = withAgentEnv c .: getNotificationConns' c {-# INLINE getNotificationConns #-} -resubscribeConnection :: AgentClient -> ConnId -> AE (Maybe ClientServiceId) +resubscribeConnection :: AgentClient -> ConnId -> AE () resubscribeConnection c = withAgentEnv c . resubscribeConnection' c {-# INLINE resubscribeConnection #-} -resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) +resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) resubscribeConnections c = withAgentEnv c . resubscribeConnections' c {-# INLINE resubscribeConnections #-} --- TODO [certs rcv] how to communicate that service ID changed - as error or as result? -subscribeClientService :: AgentClient -> ClientServiceId -> AE Int -subscribeClientService c = withAgentEnv c . subscribeClientService' c -{-# INLINE subscribeClientService #-} +subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType (Int64, IdsHash))) +subscribeClientServices c = withAgentEnv c . subscribeClientServices' c +{-# INLINE subscribeClientServices #-} -- | Send message to the connection (SEND command) sendMessage :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AE (AgentMsgId, PQEncryption) @@ -746,6 +752,7 @@ createUser' c smp xftp = do userId <- withStore' c createUserRecord atomically $ TM.insert userId (mkUserServers smp) $ smpServers c atomically $ TM.insert userId (mkUserServers xftp) $ xftpServers c + atomically $ TM.insert userId False $ useClientServices c pure userId deleteUser' :: AgentClient -> UserId -> Bool -> AM () @@ -755,6 +762,7 @@ deleteUser' c@AgentClient {smpServersStats, xftpServersStats} userId delSMPQueue else withStore c (`deleteUserRecord` userId) atomically $ TM.delete userId $ smpServers c atomically $ TM.delete userId $ xftpServers c + atomically $ TM.delete userId $ useClientServices c atomically $ modifyTVar' smpServersStats $ M.filterWithKey (\(userId', _) _ -> userId' /= userId) atomically $ modifyTVar' xftpServersStats $ M.filterWithKey (\(userId', _) _ -> userId' /= userId) lift $ saveServersStats c @@ -763,6 +771,13 @@ deleteUser' c@AgentClient {smpServersStats, xftpServersStats} userId delSMPQueue whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ DEL_USER userId) +setUserService' :: AgentClient -> UserId -> Bool -> AM () +setUserService' c userId enable = do + wasEnabled <- liftIO $ fromMaybe False <$> TM.lookupIO userId (useClientServices c) + when (enable /= wasEnabled) $ do + atomically $ TM.insert userId enable $ useClientServices c + unless enable $ withStore' c (`deleteClientServices` userId) + newConnAsync :: ConnectionModeI c => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> AM ConnId newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do connId <- newConnNoQueues c userId enableNtfs cMode (CR.connPQEncryption pqInitKeys) @@ -865,7 +880,7 @@ switchConnectionAsync' c corrId connId = connectionStats c $ DuplexConnection cData rqs' sqs _ -> throwE $ CMD PROHIBITED "switchConnectionAsync: not duplex" -newConn :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) +newConn :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, CreatedConnLink c) newConn c nm userId enableNtfs checkNotices cMode linkData_ clientData pqInitKeys subMode = do srv <- getSMPServer c userId when (checkNotices && connMode cMode == CMContact) $ checkClientNotices c srv @@ -989,7 +1004,7 @@ changeConnectionUser' c oldUserId connId newUserId = do where updateConn = withStore' c $ \db -> setConnUserId db oldUserId connId newUserId -newRcvConnSrv :: forall c. ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (CreatedConnLink c, Maybe ClientServiceId) +newRcvConnSrv :: forall c. ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (CreatedConnLink c) newRcvConnSrv c nm userId connId enableNtfs cMode userLinkData_ clientData pqInitKeys subMode srvWithAuth@(ProtoServerWithAuth srv _) = do case (cMode, pqInitKeys) of (SCMContact, CR.IKUsePQ) -> throwE $ CMD PROHIBITED "newRcvConnSrv" @@ -1000,12 +1015,12 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userLinkData_ clientData pqIni (nonce, qUri, cReq, qd) <- prepareLinkData d $ fst e2eKeys (rq, qUri') <- createRcvQueue (Just nonce) qd e2eKeys ccLink <- connReqWithShortLink qUri cReq qUri' (shortLink rq) - pure (ccLink, clientServiceId rq) + pure ccLink Nothing -> do let qd = case cMode of SCMContact -> CQRContact Nothing; SCMInvitation -> CQRMessaging Nothing - (rq, qUri) <- createRcvQueue Nothing qd e2eKeys + (_rq, qUri) <- createRcvQueue Nothing qd e2eKeys cReq <- createConnReq qUri - pure (CCLink cReq Nothing, clientServiceId rq) + pure $ CCLink cReq Nothing where createRcvQueue :: Maybe C.CbNonce -> ClntQueueReqData -> C.KeyPairX25519 -> AM (RcvQueue, SMPQueueUri) createRcvQueue nonce_ qd e2eKeys = do @@ -1107,7 +1122,7 @@ newConnToAccept c userId connId enableNtfs invId pqSup = do Invitation {connReq} <- withStore c $ \db -> getInvitation db "newConnToAccept" invId newConnToJoin c userId connId enableNtfs connReq pqSup -joinConn :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) +joinConn :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM SndQueueSecured joinConn c nm userId connId enableNtfs cReq cInfo pqSupport subMode = do srv <- getNextSMPServer c userId [qServer $ connReqQueue cReq] joinConnSrv c nm userId connId enableNtfs cReq cInfo pqSupport subMode srv @@ -1187,7 +1202,7 @@ versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ {-# INLINE versionPQSupport_ #-} -joinConnSrv :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM (SndQueueSecured, Maybe ClientServiceId) +joinConnSrv :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM SndQueueSecured joinConnSrv c nm userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do SomeConn cType conn <- withStore c (`getConn` connId) @@ -1198,7 +1213,7 @@ joinConnSrv c nm userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup sub | sqStatus == New || sqStatus == Secured -> doJoin (Just rq) (Just sq) _ -> throwE $ CMD PROHIBITED $ "joinConnSrv: bad connection " <> show cType where - doJoin :: Maybe RcvQueue -> Maybe SndQueue -> AM (SndQueueSecured, Maybe ClientServiceId) + doJoin :: Maybe RcvQueue -> Maybe SndQueue -> AM SndQueueSecured doJoin rq_ sq_ = do (cData, sq, e2eSndParams, lnkId_) <- startJoinInvitation c userId connId sq_ enableNtfs inv pqSup secureConfirmQueue c nm cData rq_ sq srv cInfo (Just e2eSndParams) subMode @@ -1209,14 +1224,14 @@ joinConnSrv c nm userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup su withInvLock c (strEncode cReqUri) "joinConnSrv" $ do SomeConn cType conn <- withStore c (`getConn` connId) let pqInitKeys = CR.joinContactInitialKeys (v >= pqdrSMPAgentVersion) pqSup - (CCLink cReq _, service) <- case conn of + CCLink cReq _ <- case conn of NewConnection _ -> newRcvConnSrv c NRMBackground userId connId enableNtfs SCMInvitation Nothing Nothing pqInitKeys subMode srv RcvConnection _ rq -> mkJoinInvitation rq pqInitKeys _ -> throwE $ CMD PROHIBITED $ "joinConnSrv: bad connection " <> show cType void $ sendInvitation c nm userId connId qInfo vrsn cReq cInfo - pure (False, service) + pure False where - mkJoinInvitation rq@RcvQueue {clientService} pqInitKeys = do + mkJoinInvitation rq pqInitKeys = do g <- asks random AgentConfig {smpClientVRange = vr, smpAgentVRange, e2eEncryptVRange = e2eVR} <- asks config let qUri = SMPQueueUri vr $ (rcvSMPQueueAddress rq) {queueMode = Just QMMessaging} @@ -1231,7 +1246,7 @@ joinConnSrv c nm userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup su createRatchetX3dhKeys db connId pk1 pk2 pKem pure e2eRcvParams let cReq = CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eVR - pure (CCLink cReq Nothing, dbServiceId <$> clientService) + pure $ CCLink cReq Nothing Nothing -> throwE $ AGENT A_VERSION delInvSL :: AgentClient -> ConnId -> SMPServerWithAuth -> SMP.LinkId -> AM () @@ -1239,7 +1254,7 @@ delInvSL c connId srv lnkId = withStore' c (\db -> deleteInvShortLink db (protoServer srv) lnkId) `catchE` \e -> liftIO $ nonBlockingWriteTBQueue (subQ c) ("", connId, AEvt SAEConn (ERR $ INTERNAL $ "error deleting short link " <> show e)) -joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM (SndQueueSecured, Maybe ClientServiceId) +joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM SndQueueSecured joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do SomeConn cType conn <- withStore c (`getConn` connId) case conn of @@ -1251,7 +1266,7 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo | sqStatus == New || sqStatus == Secured -> doJoin (Just rq) (Just sq) _ -> throwE $ CMD PROHIBITED $ "joinConnSrvAsync: bad connection " <> show cType where - doJoin :: Maybe RcvQueue -> Maybe SndQueue -> AM (SndQueueSecured, Maybe ClientServiceId) + doJoin :: Maybe RcvQueue -> Maybe SndQueue -> AM SndQueueSecured doJoin rq_ sq_ = do (cData, sq, e2eSndParams, lnkId_) <- startJoinInvitation c userId connId sq_ enableNtfs inv pqSupport secureConfirmQueueAsync c cData rq_ sq srv cInfo (Just e2eSndParams) subMode @@ -1259,7 +1274,7 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqSupport _srv = do throwE $ CMD PROHIBITED "joinConnSrvAsync" -createReplyQueue :: AgentClient -> NetworkRequestMode -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM (SMPQueueInfo, Maybe ClientServiceId) +createReplyQueue :: AgentClient -> NetworkRequestMode -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM SMPQueueInfo createReplyQueue c nm ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing (rq, qUri, tSess, sessId) <- newRcvQueue c nm userId connId srv (versionToRange smpClientVersion) SCMInvitation (isJust ntfServer_) subMode @@ -1268,7 +1283,7 @@ createReplyQueue c nm ConnData {userId, connId, enableNtfs} SndQueue {smpClientV rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq subMode lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId mapM_ (newQueueNtfSubscription c rq') ntfServer_ - pure (qInfo, clientServiceId rq') + pure qInfo -- | Approve confirmation (LET command) in Reader monad allowConnection' :: AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> AM () @@ -1281,7 +1296,7 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwE $ CMD PROHIBITED "allowConnection" -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) +acceptContact' :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM SndQueueSecured acceptContact' c nm userId connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {connReq} <- withStore c $ \db -> getInvitation db "acceptContact'" invId r <- joinConn c nm userId connId enableNtfs connReq ownConnInfo pqSupport subMode @@ -1316,7 +1331,7 @@ databaseDiff passed known = in DatabaseDiff {missingIds, extraIds} -- | Subscribe to receive connection messages (SUB command) in Reader monad -subscribeConnection' :: AgentClient -> ConnId -> AM (Maybe ClientServiceId) +subscribeConnection' :: AgentClient -> ConnId -> AM () subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId] {-# INLINE subscribeConnection' #-} @@ -1332,12 +1347,13 @@ type QDelResult = QCmdResult () type QSubResult = QCmdResult (Maybe SMP.ServiceId) -subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) +subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections' _ [] = pure M.empty subscribeConnections' c connIds = subscribeConnections_ c . zip connIds =<< withStore' c (`getConnSubs` connIds) -subscribeConnections_ :: AgentClient -> [(ConnId, Either StoreError SomeConnSub)] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) +subscribeConnections_ :: AgentClient -> [(ConnId, Either StoreError SomeConnSub)] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections_ c conns = do + -- TODO [certs rcv] - it should exclude connections already associated, and then if some don't deliver any response they may be unassociated let (subRs, cs) = foldr partitionResultsConns ([], []) conns resumeDelivery cs resumeConnCmds c $ map fst cs @@ -1351,8 +1367,8 @@ subscribeConnections_ c conns = do pure rs where partitionResultsConns :: (ConnId, Either StoreError SomeConnSub) -> - (Map ConnId (Either AgentErrorType (Maybe ClientServiceId)), [(ConnId, SomeConnSub)]) -> - (Map ConnId (Either AgentErrorType (Maybe ClientServiceId)), [(ConnId, SomeConnSub)]) + (Map ConnId (Either AgentErrorType ()), [(ConnId, SomeConnSub)]) -> + (Map ConnId (Either AgentErrorType ()), [(ConnId, SomeConnSub)]) partitionResultsConns (connId, conn_) (rs, cs) = case conn_ of Left e -> (M.insert connId (Left $ storeError e) rs, cs) Right c'@(SomeConn _ conn) -> case conn of @@ -1360,12 +1376,12 @@ subscribeConnections_ c conns = do SndConnection _ sq -> (M.insert connId (sndSubResult sq) rs, cs') RcvConnection _ _ -> (rs, cs') ContactConnection _ _ -> (rs, cs') - NewConnection _ -> (M.insert connId (Right Nothing) rs, cs') + NewConnection _ -> (M.insert connId (Right ()) rs, cs') where cs' = (connId, c') : cs - sndSubResult :: SndQueue -> Either AgentErrorType (Maybe ClientServiceId) + sndSubResult :: SndQueue -> Either AgentErrorType () sndSubResult SndQueue {status} = case status of - Confirmed -> Right Nothing + Confirmed -> Right () Active -> Left $ CONN SIMPLEX "subscribeConnections" _ -> Left $ INTERNAL "unexpected queue status" rcvQueues :: (ConnId, SomeConnSub) -> [RcvQueueSub] @@ -1386,9 +1402,9 @@ subscribeConnections_ c conns = do order (_, Right _) = 3 order _ = 4 -- TODO [certs rcv] store associations of queues with client service ID - storeClientServiceAssocs :: Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) - storeClientServiceAssocs = pure . M.map (Nothing <$) - sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> [(ConnId, SomeConnSub)] -> AM' () + storeClientServiceAssocs :: Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) -> AM (Map ConnId (Either AgentErrorType ())) + storeClientServiceAssocs = pure . M.map (() <$) + sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> [(ConnId, SomeConnSub)] -> AM' () sendNtfCreate ns rcvRs cs = do let oks = M.keysSet $ M.filter (either temporaryAgentError $ const True) rcvRs (csCreate, csDelete) = foldr (groupConnIds oks) ([], []) cs @@ -1412,7 +1428,7 @@ subscribeConnections_ c conns = do DuplexConnection _ _ sqs -> L.toList sqs SndConnection _ sq -> [sq] _ -> [] - notifyResultError :: Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> AM () + notifyResultError :: Map ConnId (Either AgentErrorType ()) -> AM () notifyResultError rs = do let actual = M.size rs expected = length conns @@ -1472,15 +1488,15 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do sqs <- withStore' c getAllSndQueuesForDelivery lift $ mapM_ (resumeMsgDelivery c) sqs -resubscribeConnection' :: AgentClient -> ConnId -> AM (Maybe ClientServiceId) +resubscribeConnection' :: AgentClient -> ConnId -> AM () resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId] {-# INLINE resubscribeConnection' #-} -resubscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) +resubscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) resubscribeConnections' _ [] = pure M.empty resubscribeConnections' c connIds = do conns <- zip connIds <$> withStore' c (`getConnSubs` connIds) - let r = M.fromList $ map (,Right Nothing) connIds -- TODO [certs rcv] + let r = M.fromList $ map (,Right ()) connIds conns' <- filterM (fmap not . isActiveConn . snd) conns -- union is left-biased, so results returned by subscribeConnections' take precedence (`M.union` r) <$> subscribeConnections_ c conns' @@ -1491,9 +1507,15 @@ resubscribeConnections' c connIds = do [] -> pure True rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs' --- TODO [certs rcv] -subscribeClientService' :: AgentClient -> ClientServiceId -> AM Int -subscribeClientService' = undefined +-- TODO [certs rcv] compare hash with lock +subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType (Int64, IdsHash))) +subscribeClientServices' c userId = + ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed" + where + useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c) + subscribe = do + srvs <- withStore' c (`getClientServiceServers` userId) + lift $ M.fromList . zip srvs <$> mapConcurrently (tryAllErrors' . subscribeClientService c userId) srvs -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) @@ -1655,13 +1677,13 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do triedHosts <- newTVarIO S.empty tryCommand . withNextSrv c userId storageSrvs triedHosts [] $ \srv -> do - (CCLink cReq _, service) <- newRcvConnSrv c NRMBackground userId connId enableNtfs cMode Nothing Nothing pqEnc subMode srv - notify $ INV (ACR cMode cReq) service + CCLink cReq _ <- newRcvConnSrv c NRMBackground userId connId enableNtfs cMode Nothing Nothing pqEnc subMode srv + notify $ INV (ACR cMode cReq) JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) pqEnc subMode connInfo -> noServer $ do triedHosts <- newTVarIO S.empty tryCommand . withNextSrv c userId storageSrvs triedHosts [qServer q] $ \srv -> do - (sqSecured, service) <- joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv - notify $ JOINED sqSecured service + sqSecured <- joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv + notify $ JOINED sqSecured LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK SWCH -> @@ -2818,7 +2840,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId SMP.SUB -> case respOrErr of Right SMP.OK -> liftIO $ processSubOk rq upConnIds -- TODO [certs rcv] associate queue with the service - Right (SMP.SOK serviceId_) -> liftIO $ processSubOk rq upConnIds + Right (SMP.SOK _serviceId_) -> liftIO $ processSubOk rq upConnIds Right msg@SMP.MSG {} -> do liftIO $ processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails runProcessSMP rq conn (toConnData conn) msg @@ -3053,7 +3075,9 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId notifyEnd removed | removed = notify END >> logServer "<--" c srv rId "END" | otherwise = logServer "<--" c srv rId "END from disconnected client - ignored" - -- Possibly, we need to add some flag to connection that it was deleted + -- TODO [certs rcv] + r@(SMP.ENDS _) -> unexpected r + -- TODO [certs rcv] Possibly, we need to add some flag to connection that it was deleted SMP.DELD -> atomically (removeSubscription c tSess connId rq) >> notify DELD SMP.ERR e -> notify $ ERR $ SMP (B.unpack $ strEncode srv) e r -> unexpected r @@ -3439,22 +3463,22 @@ connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo sq_ (qInfo :| _ (sq, _) <- lift $ newSndQueue userId connId qInfo' Nothing withStore c $ \db -> upgradeRcvConnToDuplex db connId sq -secureConfirmQueueAsync :: AgentClient -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) +secureConfirmQueueAsync :: AgentClient -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured secureConfirmQueueAsync c cData rq_ sq srv connInfo e2eEncryption_ subMode = do sqSecured <- agentSecureSndQueue c NRMBackground cData sq - (qInfo, service) <- mkAgentConfirmation c NRMBackground cData rq_ sq srv connInfo subMode + qInfo <- mkAgentConfirmation c NRMBackground cData rq_ sq srv connInfo subMode storeConfirmation c cData sq e2eEncryption_ qInfo lift $ submitPendingMsg c sq - pure (sqSecured, service) + pure sqSecured -secureConfirmQueue :: AgentClient -> NetworkRequestMode -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) +secureConfirmQueue :: AgentClient -> NetworkRequestMode -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured secureConfirmQueue c nm cData@ConnData {connId, connAgentVersion, pqSupport} rq_ sq srv connInfo e2eEncryption_ subMode = do sqSecured <- agentSecureSndQueue c nm cData sq - (qInfo, service) <- mkAgentConfirmation c nm cData rq_ sq srv connInfo subMode + qInfo <- mkAgentConfirmation c nm cData rq_ sq srv connInfo subMode msg <- mkConfirmation qInfo void $ sendConfirmation c nm sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed - pure (sqSecured, service) + pure sqSecured where mkConfirmation :: AgentMessage -> AM MsgBody mkConfirmation aMessage = do @@ -3480,12 +3504,12 @@ agentSecureSndQueue c nm ConnData {connAgentVersion} sq@SndQueue {queueMode, sta sndSecure = senderCanSecure queueMode initiatorRatchetOnConf = connAgentVersion >= ratchetOnConfSMPAgentVersion -mkAgentConfirmation :: AgentClient -> NetworkRequestMode -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM (AgentMessage, Maybe ClientServiceId) +mkAgentConfirmation :: AgentClient -> NetworkRequestMode -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM AgentMessage mkAgentConfirmation c nm cData rq_ sq srv connInfo subMode = do - (qInfo, service) <- case rq_ of + qInfo <- case rq_ of Nothing -> createReplyQueue c nm cData sq subMode srv - Just rq@RcvQueue {smpClientVersion = v, clientService} -> pure (SMPQueueInfo v $ rcvSMPQueueAddress rq, dbServiceId <$> clientService) - pure (AgentConnInfoReply (qInfo :| []) connInfo, service) + Just rq@RcvQueue {smpClientVersion = v} -> pure $ SMPQueueInfo v $ rcvSMPQueueAddress rq + pure $ AgentConnInfoReply (qInfo :| []) connInfo enqueueConfirmation :: AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AM () enqueueConfirmation c cData sq connInfo e2eEncryption_ = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 217a1682a..4a10d07ef 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -49,6 +49,7 @@ module Simplex.Messaging.Agent.Client newRcvQueue_, subscribeQueues, subscribeUserServerQueues, + subscribeClientService, processClientNotices, getQueueMessage, decryptSMPMessage, @@ -223,6 +224,7 @@ import Data.Text.Encoding import Data.Time (UTCTime, addUTCTime, defaultTimeLocale, formatTime, getCurrentTime) import Data.Time.Clock.System (getSystemTime) import Data.Word (Word16) +import qualified Data.X509.Validation as XV import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X @@ -238,7 +240,7 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store -import Simplex.Messaging.Agent.Store.AgentStore (getClientNotices, updateClientNotices) +import Simplex.Messaging.Agent.Store.AgentStore import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction) import qualified Simplex.Messaging.Agent.Store.DB as DB import Simplex.Messaging.Agent.Store.Entity @@ -262,6 +264,7 @@ import Simplex.Messaging.Protocol NetworkError (..), MsgFlags (..), MsgId, + IdsHash, NtfServer, NtfServerWithAuth, ProtoServer, @@ -296,8 +299,9 @@ import Simplex.Messaging.Session import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (SMPVersion, SessionId, THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) +import Simplex.Messaging.Transport (SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) import Simplex.Messaging.Transport.Client (TransportHost (..)) +import Simplex.Messaging.Transport.Credentials import Simplex.Messaging.Util import Simplex.Messaging.Version import System.Mem.Weak (Weak, deRefWeak) @@ -331,6 +335,7 @@ data AgentClient = AgentClient msgQ :: TBQueue (ServerTransmissionBatch SMPVersion ErrorType BrokerMsg), smpServers :: TMap UserId (UserServers 'PSMP), smpClients :: TMap SMPTransportSession SMPClientVar, + useClientServices :: TMap UserId Bool, -- smpProxiedRelays: -- SMPTransportSession defines connection from proxy to relay, -- SMPServerWithAuth defines client connected to SMP proxy (with the same userId and entityId in TransportSession) @@ -495,7 +500,7 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Map (Maybe SMPServer) (Maybe SystemSeconds) -> Env -> IO AgentClient -newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomains, presetServers} currentTs notices agentEnv = do +newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, useServices, presetDomains, presetServers} currentTs notices agentEnv = do let cfg = config agentEnv qSize = tbqSize cfg proxySessTs <- newTVarIO =<< getCurrentTime @@ -505,6 +510,7 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomai msgQ <- newTBQueueIO qSize smpServers <- newTVarIO $ M.map mkUserServers smp smpClients <- TM.emptyIO + useClientServices <- newTVarIO useServices smpProxiedRelays <- TM.emptyIO ntfServers <- newTVarIO ntf ntfClients <- TM.emptyIO @@ -544,6 +550,7 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomai msgQ, smpServers, smpClients, + useClientServices, smpProxiedRelays, ntfServers, ntfClients, @@ -598,6 +605,28 @@ agentDRG :: AgentClient -> TVar ChaChaDRG agentDRG AgentClient {agentEnv = Env {random}} = random {-# INLINE agentDRG #-} +getServiceCredentials :: AgentClient -> UserId -> SMPServer -> AM (Maybe (ServiceCredentials, Maybe ServiceId)) +getServiceCredentials c userId srv = + liftIO (TM.lookupIO userId $ useClientServices c) + $>>= \useService -> if useService then Just <$> getService else pure Nothing + where + getService :: AM (ServiceCredentials, Maybe ServiceId) + getService = do + let g = agentDRG c + ((C.KeyHash kh, serviceCreds), serviceId_) <- + withStore' c $ \db -> + getClientService db userId srv >>= \case + Just service -> pure service + Nothing -> do + cred <- genCredentials g Nothing (25, 24 * 999999) "simplex" + let tlsCreds = tlsCredentials [cred] + createClientService db userId srv tlsCreds + pure (tlsCreds, Nothing) + (_, pk) <- atomically $ C.generateKeyPair g + let serviceSignKey = C.APrivateSignKey C.SEd25519 pk + creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey} + pure (creds, serviceId_) + class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where type Client msg = c | c -> msg getProtocolServerClient :: AgentClient -> NetworkRequestMode -> TransportSession msg -> AM (Client msg) @@ -701,7 +730,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq Nothing -> Left $ BROKER (B.unpack $ strEncode srv) TIMEOUT smpConnectClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient -smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(_, srv, _) prs v = +smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(userId, srv, _) prs v = newProtocolClient c tSess smpClients connectClient v `catchAllErrors` \e -> lift (resubscribeSMPSession c tSess) >> throwE e where @@ -709,12 +738,22 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm connectClient v' = do cfg <- lift $ getClientConfig c smpCfg g <- asks random + service <- getServiceCredentials c userId srv + let cfg' = cfg {serviceCredentials = fst <$> service} env <- ask - liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do + smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do ts <- readTVarIO proxySessTs - smp <- ExceptT $ getProtocolClient g nm tSess cfg presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs - atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c - pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} + ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs + atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c + updateClientService service smp + pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} + updateClientService service smp = case (service, smpClientService smp) of + (Just (_, serviceId_), Just THClientService {serviceId}) + | serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId + | otherwise -> pure () + (Just _, Nothing) -> withStore' c $ \db -> deleteClientService db userId srv -- e.g., server version downgrade + (Nothing, Just _) -> logError "server returned serviceId without service credentials in request" + (Nothing, Nothing) -> pure () smpClientDisconnected :: AgentClient -> SMPTransportSession -> Env -> SMPClientVar -> TMap SMPServer ProxiedRelayVar -> SMPClient -> IO () smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess@(userId, srv, cId) env v prs client = do @@ -862,7 +901,6 @@ waitForProtocolClient c nm tSess@(_, srv, _) clients v = do (throwE e) Nothing -> throwE $ BROKER (B.unpack $ strEncode srv) TIMEOUT --- clientConnected arg is only passed for SMP server newProtocolClient :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => @@ -1399,7 +1437,8 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl withClient c nm tSess $ \(SMPConnectedClient smp _) -> do (ntfKeys, ntfCreds) <- liftIO $ mkNtfCreds a g smp (thParams smp,ntfKeys,) <$> createSMPQueue smp nm nonce_ rKeys dhKey auth subMode (queueReqData cqrd) ntfCreds - -- TODO [certs rcv] validate that serviceId is the same as in the client session + -- TODO [certs rcv] validate that serviceId is the same as in the client session, fail otherwise + -- possibly, it should allow returning Nothing - it would indicate incorrect old version liftIO . logServer "<--" c srv NoEntity $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] shortLink <- mkShortLinkCreds thParams' qik let rq = @@ -1415,7 +1454,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl sndId, queueMode, shortLink, - clientService = ClientService DBNewEntity <$> serviceId, + rcvServiceAssoc = isJust serviceId, status = New, enableNtfs, clientNoticeId = Nothing, @@ -1650,6 +1689,11 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do logError $ "processClientNotices error: " <> tshow e notifySub' c "" $ ERR e +subscribeClientService :: AgentClient -> UserId -> SMPServer -> AM (Int64, IdsHash) +subscribeClientService c userId srv = + withLogClient c NRMBackground (userId, srv, Nothing) B.empty "SUBS" $ + (`subscribeService` SMP.SRecipientService) . connectedClient + activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c) where diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 57bc11e3c..129a58239 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -90,6 +90,7 @@ data InitialAgentServers = InitialAgentServers ntf :: [NtfServer], xftp :: Map UserId (NonEmpty (ServerCfg 'PXFTP)), netCfg :: NetworkConfig, + useServices :: Map UserId Bool, presetDomains :: [HostName], presetServers :: [SMPServer] } diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 05ebc1b27..15d51aed9 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -126,9 +126,6 @@ module Simplex.Messaging.Agent.Protocol ContactConnType (..), ShortLinkScheme (..), LinkKey (..), - StoredClientService (..), - ClientService, - ClientServiceId, sameConnReqContact, sameShortLinkContact, simplexChat, @@ -212,7 +209,6 @@ import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.FileTransfer.Types (FileErrorType) import Simplex.Messaging.Agent.QueryString import Simplex.Messaging.Agent.Store.DB (Binary (..), FromField (..), ToField (..), blobFieldDecoder, fromTextField_) -import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Client (ProxyClientError) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet @@ -381,7 +377,7 @@ type SndQueueSecured = Bool -- | Parameterized type for SMP agent events data AEvent (e :: AEntity) where - INV :: AConnectionRequestUri -> Maybe ClientServiceId -> AEvent AEConn + INV :: AConnectionRequestUri -> AEvent AEConn CONF :: ConfirmationId -> PQSupport -> [SMPServer] -> ConnInfo -> AEvent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake REQ :: InvitationId -> PQSupport -> NonEmpty SMPServer -> ConnInfo -> AEvent AEConn -- ConnInfo is from sender INFO :: PQSupport -> ConnInfo -> AEvent AEConn @@ -407,7 +403,7 @@ data AEvent (e :: AEntity) where DEL_USER :: Int64 -> AEvent AENone STAT :: ConnectionStats -> AEvent AEConn OK :: AEvent AEConn - JOINED :: SndQueueSecured -> Maybe ClientServiceId -> AEvent AEConn + JOINED :: SndQueueSecured -> AEvent AEConn ERR :: AgentErrorType -> AEvent AEConn ERRS :: NonEmpty (ConnId, AgentErrorType) -> AEvent AENone SUSPENDED :: AEvent AENone @@ -1783,16 +1779,6 @@ instance Encoding UserLinkData where smpP = UserLinkData <$> ((A.char '\255' *> (unLarge <$> smpP)) <|> smpP) {-# INLINE smpP #-} -data StoredClientService (s :: DBStored) = ClientService - { dbServiceId :: DBEntityId' s, - serviceId :: SMP.ServiceId - } - deriving (Eq, Show) - -type ClientService = StoredClientService 'DBStored - -type ClientServiceId = DBEntityId - -- | SMP queue status. data QueueStatus = -- | queue is created diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index c054cb267..ab831ad38 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -85,7 +85,7 @@ data StoredRcvQueue (q :: DBStored) = RcvQueue -- | short link ID and credentials shortLink :: Maybe ShortLinkCreds, -- | associated client service - clientService :: Maybe (StoredClientService q), + rcvServiceAssoc :: ServiceAssoc, -- | queue status status :: QueueStatus, -- | to enable notifications for this queue - this field is duplicated from ConnData @@ -134,9 +134,7 @@ data ShortLinkCreds = ShortLinkCreds } deriving (Show) -clientServiceId :: RcvQueue -> Maybe ClientServiceId -clientServiceId = fmap dbServiceId . clientService -{-# INLINE clientServiceId #-} +type ServiceAssoc = Bool rcvSMPQueueAddress :: RcvQueue -> SMPQueueAddress rcvSMPQueueAddress RcvQueue {server, sndId, e2ePrivKey, queueMode} = diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index ef66eca38..0b2c632fa 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -35,6 +35,14 @@ module Simplex.Messaging.Agent.Store.AgentStore deleteUsersWithoutConns, checkUser, + -- * Client services + createClientService, + getClientService, + getClientServiceServers, + setClientServiceId, + deleteClientService, + deleteClientServices, + -- * Queues and connections createNewConn, updateNewConnRcv, @@ -274,7 +282,9 @@ import qualified Data.Set as S import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Word (Word32) +import qualified Data.X509 as X import Network.Socket (ServiceName) +import qualified Network.TLS as TLS import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..)) @@ -390,6 +400,75 @@ deleteUsersWithoutConns db = do forM_ userIds $ DB.execute db "DELETE FROM users WHERE user_id = ?" . Only pure userIds +createClientService :: DB.Connection -> UserId -> SMPServer -> (C.KeyHash, TLS.Credential) -> IO () +createClientService db userId srv (kh, (cert, pk)) = + DB.execute + db + [sql| + INSERT INTO client_services + (user_id, host, port, service_cert_hash, service_cert, service_priv_key) + VALUES (?,?,?,?,?,?) + ON CONFLICT (user_id, host, port) + DO UPDATE SET + service_cert_hash = EXCLUDED.service_cert_hash, + service_cert = EXCLUDED.service_cert, + service_priv_key = EXCLUDED.service_priv_key, + rcv_service_id = NULL + |] + (userId, host srv, port srv, kh, cert, pk) + +getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) +getClientService db userId srv = + maybeFirstRow toService $ + DB.query + db + [sql| + SELECT service_cert_hash, service_cert, service_priv_key, rcv_service_id + FROM client_services + WHERE user_id = ? AND host = ? AND port = ? + |] + (userId, host srv, port srv) + where + toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) + +getClientServiceServers :: DB.Connection -> UserId -> IO [SMPServer] +getClientServiceServers db userId = + map toServer + <$> DB.query + db + [sql| + SELECT c.host, c.port, s.key_hash + FROM client_services c + JOIN servers s ON s.host = c.host AND s.port = c.port + |] + (Only userId) + where + toServer (host, port, kh) = SMPServer host port kh + +setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO () +setClientServiceId db userId srv serviceId = + DB.execute + db + [sql| + UPDATE client_services + SET rcv_service_id = ? + WHERE user_id = ? AND host = ? AND port = ? + |] + (serviceId, userId, host srv, port srv) + +deleteClientService :: DB.Connection -> UserId -> SMPServer -> IO () +deleteClientService db userId srv = + DB.execute + db + [sql| + DELETE FROM client_services + WHERE user_id = ? AND host = ? AND port = ? + |] + (userId, host srv, port srv) + +deleteClientServices :: DB.Connection -> UserId -> IO () +deleteClientServices db userId = DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId) + createConn_ :: TVar ChaChaDRG -> ConnData -> @@ -1926,6 +2005,15 @@ deriving newtype instance ToField ChunkReplicaId deriving newtype instance FromField ChunkReplicaId +instance ToField X.CertificateChain where toField = toField . Binary . smpEncode . C.encodeCertChain + +instance FromField X.CertificateChain where fromField = blobFieldDecoder (parseAll C.certChainP) + +instance ToField X.PrivKey where toField = toField . Binary . C.encodeASNObj + +instance FromField X.PrivKey where + fromField = blobFieldDecoder $ C.decodeASNKey >=> \case (pk, []) -> Right pk; r -> C.asnKeyError r + fromOnlyBI :: Only BoolInt -> Bool fromOnlyBI (Only (BI b)) = b {-# INLINE fromOnlyBI #-} @@ -2005,19 +2093,18 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} subMode serverKeyHash_ = do db [sql| INSERT INTO rcv_queues - ( host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, + ( host, port, rcv_id, rcv_service_assoc, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, queue_mode, status, to_subscribe, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash, link_id, link_key, link_priv_sig_key, link_enc_fixed_data, ntf_public_key, ntf_private_key, ntf_id, rcv_ntf_dh_secret - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - ( (host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) + ( (host server, port server, rcvId, rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, queueMode, status, BI toSubscribe, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_) :. (shortLinkId <$> shortLink, shortLinkKey <$> shortLink, linkPrivSigKey <$> shortLink, linkEncFixedData <$> shortLink) :. ntfCredsFields ) - -- TODO [certs rcv] save client service - pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId, clientService = Nothing} + pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId} where toSubscribe = subMode == SMOnlyCreate ntfCredsFields = case clientNtfCreds of @@ -2371,7 +2458,7 @@ rcvQueueQuery = [sql| SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.queue_mode, q.status, c.enable_ntfs, q.client_notice_id, - q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.switch_status, q.smp_client_version, q.delete_errors, + q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.switch_status, q.smp_client_version, q.delete_errors, q.rcv_service_assoc, q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret, q.link_id, q.link_key, q.link_priv_sig_key, q.link_enc_fixed_data FROM rcv_queues q @@ -2381,13 +2468,13 @@ rcvQueueQuery = toRcvQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe QueueMode) - :. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) + :. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, ServiceAssoc) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) :. (Maybe SMP.LinkId, Maybe LinkKey, Maybe C.PrivateKeyEd25519, Maybe EncDataBytes) -> RcvQueue toRcvQueue ( (userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode) - :. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) + :. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, rcvServiceAssoc) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_) ) = @@ -2401,7 +2488,7 @@ toRcvQueue _ -> Nothing enableNtfs = maybe True unBI enableNtfs_ -- TODO [certs rcv] read client service - in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, clientService = Nothing, status, enableNtfs, clientNoticeId, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} + in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, rcvServiceAssoc, status, enableNtfs, clientNoticeId, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} -- | returns all connection queue credentials, the first queue is the primary one getRcvQueueSubsByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueueSub)) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs index 7371d9584..ae9b3d80e 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs @@ -46,6 +46,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250322_short_links import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250702_conn_invitations_remove_cascade_delete import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251009_queue_to_subscribe import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251010_client_notices +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs import Simplex.Messaging.Agent.Store.Shared (Migration (..)) schemaMigrations :: [(String, Query, Maybe Query)] @@ -91,7 +92,8 @@ schemaMigrations = ("m20250322_short_links", m20250322_short_links, Just down_m20250322_short_links), ("m20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete), ("m20251009_queue_to_subscribe", m20251009_queue_to_subscribe, Just down_m20251009_queue_to_subscribe), - ("m20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices) + ("m20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices), + ("m20251020_service_certs", m20251020_service_certs, Just down_m20251020_service_certs) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs deleted file mode 100644 index 7708fd6d2..000000000 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs +++ /dev/null @@ -1,40 +0,0 @@ -{-# LANGUAGE QuasiQuotes #-} - -module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250517_service_certs where - -import Database.SQLite.Simple (Query) -import Database.SQLite.Simple.QQ (sql) - --- TODO move date forward, create migration for postgres -m20250517_service_certs :: Query -m20250517_service_certs = - [sql| -CREATE TABLE server_certs( - server_cert_id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL REFERENCES users ON UPDATE RESTRICT ON DELETE CASCADE, - host TEXT NOT NULL, - port TEXT NOT NULL, - certificate BLOB NOT NULL, - priv_key BLOB NOT NULL, - service_id BLOB, - FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT, -); - -CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON server_certs(user_id, host, port); - -CREATE INDEX idx_server_certs_host_port ON server_certs(host, port); - -ALTER TABLE rcv_queues ADD COLUMN rcv_service_id BLOB; - |] - -down_m20250517_service_certs :: Query -down_m20250517_service_certs = - [sql| -ALTER TABLE rcv_queues DROP COLUMN rcv_service_id; - -DROP INDEX idx_server_certs_host_port; - -DROP INDEX idx_server_certs_user_id_host_port; - -DROP TABLE server_certs; - |] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs new file mode 100644 index 000000000..780ced1d4 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +-- TODO move date forward, create migration for postgres +m20251020_service_certs :: Query +m20251020_service_certs = + [sql| +CREATE TABLE client_services( + user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, + host TEXT NOT NULL, + port TEXT NOT NULL, + service_cert BLOB NOT NULL, + service_cert_hash BLOB NOT NULL, + service_priv_key BLOB NOT NULL, + rcv_service_id BLOB, + FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT +); + +CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port); + +CREATE INDEX idx_server_certs_host_port ON client_services(host, port); + +ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc INTEGER NOT NULL DEFAULT 0; + |] + +down_m20251020_service_certs :: Query +down_m20251020_service_certs = + [sql| +ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc; + +DROP INDEX idx_server_certs_host_port; + +DROP INDEX idx_server_certs_user_id_host_port; + +DROP TABLE client_services; + |] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index d2838a7b0..8013313ac 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -63,6 +63,7 @@ CREATE TABLE rcv_queues( to_subscribe INTEGER NOT NULL DEFAULT 0, client_notice_id INTEGER REFERENCES client_notices ON UPDATE RESTRICT ON DELETE SET NULL, + rcv_service_assoc INTEGER NOT NULL DEFAULT 0, PRIMARY KEY(host, port, rcv_id), FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE, @@ -450,6 +451,16 @@ CREATE TABLE client_notices( created_at INTEGER NOT NULL, updated_at INTEGER NOT NULL ); +CREATE TABLE client_services( + user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, + host TEXT NOT NULL, + port TEXT NOT NULL, + service_cert BLOB NOT NULL, + service_cert_hash BLOB NOT NULL, + service_priv_key BLOB NOT NULL, + rcv_service_id BLOB, + FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT +); CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id); CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id); CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id); @@ -593,3 +604,9 @@ CREATE UNIQUE INDEX idx_client_notices_entity ON client_notices( entity_id ); CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id); +CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services( + user_id, + host, + port +); +CREATE INDEX idx_server_certs_host_port ON client_services(host, port); diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 27840b092..4f70efcf2 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -909,12 +909,12 @@ nsubResponse_ = \case {-# INLINE nsubResponse_ #-} -- This command is always sent in background request mode -subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64 +subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO (Int64, IdsHash) subscribeService c party = case smpClientService c of Just THClientService {serviceId, serviceKey} -> do liftIO $ enablePings c sendSMPCommand c NRMBackground (Just (C.APrivateAuthKey C.SEd25519 serviceKey)) serviceId subCmd >>= \case - SOKS n -> pure n + SOKS n idsHash -> pure (n, idsHash) r -> throwE $ unexpectedResponse r where subCmd :: Command p diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 604960360..722a86c7e 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -479,14 +479,14 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService (True <$ processSubscription r) (pure False) if ok - then case r of - Right n -> notify ca $ CAServiceSubscribed srv serviceSub n + then case r of -- TODO [certs rcv] compare hash + Right (n, _idsHash) -> notify ca $ CAServiceSubscribed srv serviceSub n Left e | smpClientServiceError e -> notifyUnavailable | temporaryClientError e -> reconnectClient ca srv | otherwise -> notify ca $ CAServiceSubError srv serviceSub e else reconnectClient ca srv - processSubscription = mapM_ $ \n -> do + processSubscription = mapM_ $ \(n, _idsHash) -> do -- TODO [certs rcv] validate hash here? setActiveServiceSub ca srv $ Just ((serviceId, n), sessId) setPendingServiceSub ca srv Nothing serviceAvailable THClientService {serviceRole, serviceId = serviceId'} = diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 9cc78acb3..3d24f0bcb 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -87,6 +87,8 @@ module Simplex.Messaging.Crypto signatureKeyPair, publicToX509, encodeASNObj, + decodeASNKey, + asnKeyError, -- * key encoding/decoding encodePubKey, @@ -1493,11 +1495,11 @@ encodeASNObj k = toStrict . encodeASN1 DER $ toASN1 k [] -- Decoding of binary X509 'CryptoPublicKey'. decodePubKey :: CryptoPublicKey k => ByteString -> Either String k -decodePubKey = decodeKey >=> x509ToPublic >=> pubKey +decodePubKey = decodeASNKey >=> x509ToPublic >=> pubKey -- Decoding of binary PKCS8 'PrivateKey'. decodePrivKey :: CryptoPrivateKey k => ByteString -> Either String k -decodePrivKey = decodeKey >=> x509ToPrivate >=> privKey +decodePrivKey = decodeASNKey >=> x509ToPrivate >=> privKey x509ToPublic :: (X.PubKey, [ASN1]) -> Either String APublicKey x509ToPublic = \case @@ -1505,7 +1507,7 @@ x509ToPublic = \case (X.PubKeyEd448 k, []) -> Right . APublicKey SEd448 $ PublicKeyEd448 k (X.PubKeyX25519 k, []) -> Right . APublicKey SX25519 $ PublicKeyX25519 k (X.PubKeyX448 k, []) -> Right . APublicKey SX448 $ PublicKeyX448 k - r -> keyError r + r -> asnKeyError r x509ToPublic' :: CryptoPublicKey k => X.PubKey -> Either String k x509ToPublic' k = x509ToPublic (k, []) >>= pubKey @@ -1517,16 +1519,16 @@ x509ToPrivate = \case (X.PrivKeyEd448 k, []) -> Right $ APrivateKey SEd448 $ PrivateKeyEd448 k (X.PrivKeyX25519 k, []) -> Right $ APrivateKey SX25519 $ PrivateKeyX25519 k (X.PrivKeyX448 k, []) -> Right $ APrivateKey SX448 $ PrivateKeyX448 k - r -> keyError r + r -> asnKeyError r x509ToPrivate' :: CryptoPrivateKey k => X.PrivKey -> Either String k x509ToPrivate' pk = x509ToPrivate (pk, []) >>= privKey {-# INLINE x509ToPrivate' #-} -decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) -decodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict +decodeASNKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) +decodeASNKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict -keyError :: (a, [ASN1]) -> Either String b -keyError = \case +asnKeyError :: (a, [ASN1]) -> Either String b +asnKeyError = \case (_, []) -> Left "unknown key algorithm" _ -> Left "more than one key" diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 13ac3f182..3be4515cc 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -140,6 +140,7 @@ module Simplex.Messaging.Protocol RcvMessage (..), MsgId, MsgBody, + IdsHash, MaxMessageLen, MaxRcvMessageLen, EncRcvMsgBody (..), @@ -698,11 +699,13 @@ data BrokerMsg where -- | Service subscription success - confirms when queue was associated with the service SOK :: Maybe ServiceId -> BrokerMsg -- | The number of queues subscribed with SUBS command - SOKS :: Int64 -> BrokerMsg + SOKS :: Int64 -> IdsHash -> BrokerMsg -- MSG v1/2 has to be supported for encoding/decoding -- v1: MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg -- v2: MsgId -> SystemTime -> MsgFlags -> MsgBody -> BrokerMsg MSG :: RcvMessage -> BrokerMsg + -- sent once delivering messages to SUBS command is complete + SALL :: BrokerMsg NID :: NotifierId -> RcvNtfPublicDhKey -> BrokerMsg NMSG :: C.CbNonce -> EncNMsgMeta -> BrokerMsg -- Should include certificate chain @@ -939,6 +942,7 @@ data BrokerMsgTag | SOK_ | SOKS_ | MSG_ + | SALL_ | NID_ | NMSG_ | PKEY_ @@ -1031,6 +1035,7 @@ instance Encoding BrokerMsgTag where SOK_ -> "SOK" SOKS_ -> "SOKS" MSG_ -> "MSG" + SALL_ -> "SALL" NID_ -> "NID" NMSG_ -> "NMSG" PKEY_ -> "PKEY" @@ -1052,6 +1057,7 @@ instance ProtocolMsgTag BrokerMsgTag where "SOK" -> Just SOK_ "SOKS" -> Just SOKS_ "MSG" -> Just MSG_ + "SALL" -> Just SALL_ "NID" -> Just NID_ "NMSG" -> Just NMSG_ "PKEY" -> Just PKEY_ @@ -1454,6 +1460,8 @@ type MsgId = ByteString -- | SMP message body. type MsgBody = ByteString +type IdsHash = ByteString + data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock -- | Type for protocol errors. @@ -1834,9 +1842,12 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where SOK serviceId_ | v >= serviceCertsSMPVersion -> e (SOK_, ' ', serviceId_) | otherwise -> e OK_ -- won't happen, the association with the service requires v >= serviceCertsSMPVersion - SOKS n -> e (SOKS_, ' ', n) + SOKS n idsHash + | v >= rcvServiceSMPVersion -> e (SOKS_, ' ', n, idsHash) + | otherwise -> e (SOKS_, ' ', n) MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e (MSG_, ' ', msgId, Tail body) + SALL -> e SALL_ NID nId srvNtfDh -> e (NID_, ' ', nId, srvNtfDh) NMSG nmsgNonce encNMsgMeta -> e (NMSG_, ' ', nmsgNonce, encNMsgMeta) PKEY sid vr certKey -> e (PKEY_, ' ', sid, vr, certKey) @@ -1867,6 +1878,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where MSG . RcvMessage msgId <$> bodyP where bodyP = EncRcvMsgBody . unTail <$> smpP + SALL_ -> pure SALL IDS_ | v >= newNtfCredsSMPVersion -> ids smpP smpP smpP smpP | v >= serviceCertsSMPVersion -> ids smpP smpP smpP nothing @@ -1887,7 +1899,9 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId, serviceId, serverNtfCreds} LNK_ -> LNK <$> _smpP <*> smpP SOK_ -> SOK <$> _smpP - SOKS_ -> SOKS <$> _smpP + SOKS_ + | v >= rcvServiceSMPVersion -> SOKS <$> _smpP <*> smpP + | otherwise -> SOKS <$> _smpP <*> pure B.empty NID_ -> NID <$> _smpP <*> smpP NMSG_ -> NMSG <$> _smpP <*> smpP PKEY_ -> PKEY <$> _smpP <*> smpP <*> smpP @@ -1917,6 +1931,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where PONG -> noEntityMsg PKEY {} -> noEntityMsg RRES _ -> noEntityMsg + SALL -> noEntityMsg -- other broker responses must have queue ID _ | B.null entId -> Left $ CMD NO_ENTITY diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index ec75a07d4..1e5e94fd6 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -1359,7 +1359,7 @@ client -- TODO [certs rcv] rcv subscriptions Server {subscribers, ntfSubscribers} ms - clnt@Client {clientId, ntfSubscriptions, ntfServiceSubscribed, serviceSubsCount = _todo', ntfServiceSubsCount, rcvQ, sndQ, clientTHParams = thParams'@THandleParams {sessionId}, procThreads} = do + clnt@Client {clientId, rcvQ, sndQ, msgQ, clientTHParams = thParams'@THandleParams {sessionId}, procThreads} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands" let THandleParams {thVersion} = thParams' clntServiceId = (\THClientService {serviceId} -> serviceId) <$> (peerClientService =<< thAuth thParams') @@ -1495,7 +1495,9 @@ client OFF -> response <$> maybe (pure $ err INTERNAL) suspendQueue_ q_ DEL -> response <$> maybe (pure $ err INTERNAL) delQueueAndMsgs q_ QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr - Cmd SRecipientService SUBS -> pure $ response $ err (CMD PROHIBITED) -- "TODO [certs rcv]" + Cmd SRecipientService SUBS -> response . (corrId,entId,) <$> case clntServiceId of + Just serviceId -> subscribeServiceMessages serviceId + Nothing -> pure $ ERR INTERNAL -- it's "internal" because it should never get to this branch where createQueue :: NewQueueReq -> M s (Transmission BrokerMsg) createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData, ntfCreds} @@ -1615,11 +1617,13 @@ client suspendQueue_ :: (StoreQueue s, QueueRec) -> M s (Transmission BrokerMsg) suspendQueue_ (q, _) = liftIO $ either err (const ok) <$> suspendQueue (queueStore ms) q - -- TODO [certs rcv] if serviceId is passed, associate with the service and respond with SOK subscribeQueueAndDeliver :: StoreQueue s -> QueueRec -> M s ResponseAndMessage - subscribeQueueAndDeliver q qr = + subscribeQueueAndDeliver q qr@QueueRec {rcvServiceId} = liftIO (TM.lookupIO entId $ subscriptions clnt) >>= \case - Nothing -> subscribeRcvQueue qr >>= deliver False + Nothing -> + sharedSubscribeQueue q SRecipientService rcvServiceId subscribers subscriptions serviceSubsCount (newSubscription NoSub) rcvServices >>= \case + Left e -> pure (err e, Nothing) + Right s -> deliver s Just s@Sub {subThread} -> do stats <- asks serverStats case subThread of @@ -1629,27 +1633,29 @@ client pure (err (CMD PROHIBITED), Nothing) _ -> do incStat $ qSubDuplicate stats - atomically (writeTVar (delivered s) Nothing) >> deliver True s + atomically (writeTVar (delivered s) Nothing) >> deliver (True, Just s) where - deliver :: Bool -> Sub -> M s ResponseAndMessage - deliver hasSub sub = do + deliver :: (Bool, Maybe Sub) -> M s ResponseAndMessage + deliver (hasSub, sub_) = do stats <- asks serverStats fmap (either ((,Nothing) . err) id) $ liftIO $ runExceptT $ do msg_ <- tryPeekMsg ms q msg' <- forM msg_ $ \msg -> liftIO $ do ts <- getSystemSeconds + sub <- maybe (atomically getSub) pure sub_ atomically $ setDelivered sub msg ts unless hasSub $ incStat $ qSub stats pure (NoCorrId, entId, MSG (encryptMsg qr msg)) pure ((corrId, entId, SOK clntServiceId), msg') - -- TODO [certs rcv] combine with subscribing ntf queues - subscribeRcvQueue :: QueueRec -> M s Sub - subscribeRcvQueue QueueRec {rcvServiceId} = atomically $ do - writeTQueue (subQ subscribers) (CSClient entId rcvServiceId Nothing, clientId) - sub <- newSubscription NoSub - TM.insert entId sub $ subscriptions clnt - pure sub + getSub :: STM Sub + getSub = + TM.lookup entId (subscriptions clnt) >>= \case + Just sub -> pure sub + Nothing -> do + sub <- newSubscription NoSub + TM.insert entId sub $ subscriptions clnt + pure sub subscribeNewQueue :: RecipientId -> QueueRec -> M s () subscribeNewQueue rId QueueRec {rcvServiceId} = do @@ -1719,74 +1725,131 @@ client else liftIO (updateQueueTime (queueStore ms) q t) >>= either (pure . err') (action q) subscribeNotifications :: StoreQueue s -> NtfCreds -> M s BrokerMsg - subscribeNotifications q NtfCreds {ntfServiceId} = do + subscribeNotifications q NtfCreds {ntfServiceId} = + sharedSubscribeQueue q SNotifierService ntfServiceId ntfSubscribers ntfSubscriptions ntfServiceSubsCount (pure ()) ntfServices >>= \case + Left e -> pure $ ERR e + Right (hasSub, _) -> do + when (isNothing clntServiceId) $ + asks serverStats >>= incStat . (if hasSub then ntfSubDuplicate else ntfSub) + pure $ SOK clntServiceId + + sharedSubscribeQueue :: + (PartyI p, ServiceParty p) => + StoreQueue s -> + SParty p -> + Maybe ServiceId -> + ServerSubscribers s -> + (Client s -> TMap QueueId sub) -> + (Client s -> TVar Int64) -> + STM sub -> + (ServerStats -> ServiceStats) -> + M s (Either ErrorType (Bool, Maybe sub)) + sharedSubscribeQueue q party queueServiceId srvSubscribers clientSubs clientServiceSubs mkSub servicesSel = do stats <- asks serverStats - let incNtfSrvStat sel = incStat $ sel $ ntfServices stats - case clntServiceId of + let incSrvStat sel = incStat $ sel $ servicesSel stats + writeSub = writeTQueue (subQ srvSubscribers) (CSClient entId queueServiceId clntServiceId, clientId) + liftIO $ case clntServiceId of Just serviceId - | ntfServiceId == Just serviceId -> do + | queueServiceId == Just serviceId -> do -- duplicate queue-service association - can only happen in case of response error/timeout - hasSub <- atomically $ ifM hasServiceSub (pure True) (False <$ newServiceQueueSub) + hasSub <- atomically $ ifM hasServiceSub (pure True) (False <$ incServiceQueueSubs) unless hasSub $ do - incNtfSrvStat srvSubCount - incNtfSrvStat srvSubQueues - incNtfSrvStat srvAssocDuplicate - pure $ SOK $ Just serviceId - | otherwise -> + atomically writeSub + incSrvStat srvSubCount + incSrvStat srvSubQueues + incSrvStat srvAssocDuplicate + pure $ Right (hasSub, Nothing) + | otherwise -> runExceptT $ do -- new or updated queue-service association - liftIO (setQueueService (queueStore ms) q SNotifierService (Just serviceId)) >>= \case - Left e -> pure $ ERR e - Right () -> do - hasSub <- atomically $ (<$ newServiceQueueSub) =<< hasServiceSub - unless hasSub $ incNtfSrvStat srvSubCount - incNtfSrvStat srvSubQueues - incNtfSrvStat $ maybe srvAssocNew (const srvAssocUpdated) ntfServiceId - pure $ SOK $ Just serviceId + ExceptT $ setQueueService (queueStore ms) q party (Just serviceId) + hasSub <- atomically $ (<$ incServiceQueueSubs) =<< hasServiceSub + atomically writeSub + liftIO $ do + unless hasSub $ incSrvStat srvSubCount + incSrvStat srvSubQueues + incSrvStat $ maybe srvAssocNew (const srvAssocUpdated) queueServiceId + pure (hasSub, Nothing) where - hasServiceSub = (0 /=) <$> readTVar ntfServiceSubsCount - -- This function is used when queue is associated with the service. - newServiceQueueSub = do - writeTQueue (subQ ntfSubscribers) (CSClient entId ntfServiceId (Just serviceId), clientId) - modifyTVar' ntfServiceSubsCount (+ 1) -- service count - modifyTVar' (totalServiceSubs ntfSubscribers) (+ 1) -- server count for all services - Nothing -> case ntfServiceId of - Just _ -> - liftIO (setQueueService (queueStore ms) q SNotifierService Nothing) >>= \case - Left e -> pure $ ERR e - Right () -> do - -- hasSubscription should never be True in this branch, because queue was associated with service. - -- So unless storage and session states diverge, this check is redundant. - hasSub <- atomically $ hasSubscription >>= newSub - incNtfSrvStat srvAssocRemoved - sok hasSub + hasServiceSub = (0 /=) <$> readTVar (clientServiceSubs clnt) + -- This function is used when queue association with the service is created. + incServiceQueueSubs = modifyTVar' (clientServiceSubs clnt) (+ 1) -- service count + Nothing -> case queueServiceId of + Just _ -> runExceptT $ do + ExceptT $ setQueueService (queueStore ms) q party Nothing + liftIO $ incSrvStat srvAssocRemoved + -- getSubscription may be Just for receiving service, where clientSubs also hold active deliveries for service subscriptions. + -- For notification service it can only be Just if storage and session states diverge. + r <- atomically $ getSubscription >>= newSub + atomically writeSub + pure r Nothing -> do - hasSub <- atomically $ ifM hasSubscription (pure True) (newSub False) - sok hasSub + r@(hasSub, _) <- atomically $ getSubscription >>= newSub + unless hasSub $ atomically writeSub + pure $ Right r where - hasSubscription = TM.member entId ntfSubscriptions - newSub hasSub = do - writeTQueue (subQ ntfSubscribers) (CSClient entId ntfServiceId Nothing, clientId) - unless (hasSub) $ TM.insert entId () ntfSubscriptions - pure hasSub - sok hasSub = do - incStat $ if hasSub then ntfSubDuplicate stats else ntfSub stats - pure $ SOK Nothing + getSubscription = TM.lookup entId $ clientSubs clnt + newSub = \case + Just sub -> pure (True, Just sub) + Nothing -> do + sub <- mkSub + TM.insert entId sub $ clientSubs clnt + pure (False, Just sub) + + subscribeServiceMessages :: ServiceId -> M s BrokerMsg + subscribeServiceMessages serviceId = + sharedSubscribeService SRecipientService serviceId subscribers serviceSubscribed serviceSubsCount >>= \case + Left e -> pure $ ERR e + Right (hasSub, (count, idsHash)) -> do + unless hasSub $ forkClient clnt "deliverServiceMessages" $ liftIO $ deliverServiceMessages count + pure $ SOKS count idsHash + where + deliverServiceMessages expectedCnt = do + (qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0) + atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, SALL)] + -- TODO [cert rcv] compare with expected + logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)" + deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int) + deliverQueueMsg (!qCnt, !msgCnt, !dupCnt, !errCnt) rId = \case + Left e -> pure (qCnt + 1, msgCnt, dupCnt, errCnt + 1) -- TODO [certs rcv] deliver subscription error + Right qMsg_ -> case qMsg_ of + Nothing -> pure (qCnt + 1, msgCnt, dupCnt, errCnt) + Just (qr, msg) -> + atomically (getSubscription rId) >>= \case + Nothing -> pure (qCnt + 1, msgCnt, dupCnt + 1, errCnt) + Just sub -> do + ts <- getSystemSeconds + atomically $ setDelivered sub msg ts + atomically $ writeTBQueue msgQ [(NoCorrId, rId, MSG (encryptMsg qr msg))] + pure (qCnt + 1, msgCnt + 1, dupCnt, errCnt) + getSubscription rId = + TM.lookup rId (subscriptions clnt) >>= \case + -- If delivery subscription already exists, then there is no need to deliver message. + -- It may have been created when the message is sent after service subscription is created. + Just _sub -> pure Nothing + Nothing -> do + sub <- newSubscription NoSub + TM.insert rId sub $ subscriptions clnt + pure $ Just sub subscribeServiceNotifications :: ServiceId -> M s BrokerMsg - subscribeServiceNotifications serviceId = do - subscribed <- readTVarIO ntfServiceSubscribed - if subscribed - then SOKS <$> readTVarIO ntfServiceSubsCount - else - liftIO (getServiceQueueCount @(StoreQueue s) (queueStore ms) SNotifierService serviceId) >>= \case - Left e -> pure $ ERR e - Right !count' -> do + subscribeServiceNotifications serviceId = + either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount + + sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> M s (Either ErrorType (Bool, (Int64, IdsHash))) + sharedSubscribeService party serviceId srvSubscribers clientServiceSubscribed clientServiceSubs = do + subscribed <- readTVarIO $ clientServiceSubscribed clnt + liftIO $ runExceptT $ + (subscribed,) + <$> if subscribed + then (,B.empty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash + else do + count' <- ExceptT $ getServiceQueueCount @(StoreQueue s) (queueStore ms) party serviceId incCount <- atomically $ do - writeTVar ntfServiceSubscribed True - count <- swapTVar ntfServiceSubsCount count' + writeTVar (clientServiceSubscribed clnt) True + count <- swapTVar (clientServiceSubs clnt) count' pure $ count' - count - atomically $ writeTQueue (subQ ntfSubscribers) (CSService serviceId incCount, clientId) - pure $ SOKS count' + atomically $ writeTQueue (subQ srvSubscribers) (CSService serviceId incCount, clientId) + pure (count', B.empty) -- TODO [certs rcv] get IDs hash acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) acknowledgeMsg msgId q qr = @@ -1904,10 +1967,13 @@ client tryDeliverMessage msg = -- the subscribed client var is read outside of STM to avoid transaction cost -- in case no client is subscribed. - getSubscribedClient rId (queueSubscribers subscribers) + getSubscribed $>>= deliverToSub >>= mapM_ forkDeliver where + getSubscribed = case rcvServiceId qr of + Just serviceId -> getSubscribedClient serviceId $ serviceSubscribers subscribers + Nothing -> getSubscribedClient rId $ queueSubscribers subscribers rId = recipientId q deliverToSub rcv = do ts <- getSystemSeconds @@ -1918,6 +1984,7 @@ client -- the new client will receive message in response to SUB. readTVar rcv $>>= \rc@Client {subscriptions = subs, sndQ = sndQ'} -> TM.lookup rId subs + >>= maybe (newServiceDeliverySub subs) (pure . Just) $>>= \s@Sub {subThread, delivered} -> case subThread of ProhibitSub -> pure Nothing ServerSub st -> readTVar st >>= \case @@ -1930,6 +1997,12 @@ client (writeTVar st SubPending $> Just (rc, s, st)) (deliver sndQ' s ts $> Nothing) _ -> pure Nothing + newServiceDeliverySub subs + | isJust (rcvServiceId qr) = do + sub <- newSubscription NoSub + TM.insert rId sub subs + pure $ Just sub + | otherwise = pure Nothing deliver sndQ' s ts = do let encMsg = encryptMsg qr msg writeTBQueue sndQ' ([(NoCorrId, rId, MSG encMsg)], []) @@ -2051,6 +2124,7 @@ client -- we delete subscription here, so the client with no subscriptions can be disconnected. sub <- atomically $ TM.lookupDelete entId $ subscriptions clnt liftIO $ mapM_ cancelSub sub + when (isJust rcvServiceId) $ atomically $ modifyTVar' (serviceSubsCount clnt) $ \n -> max 0 (n - 1) atomically $ writeTQueue (subQ subscribers) (CSDeleted entId rcvServiceId, clientId) forM_ (notifier qr) $ \NtfCreds {notifierId = nId, ntfServiceId} -> do -- queue is deleted by a different client from the one subscribed to notifications, diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index 5038c8826..d9a1ff6ec 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -444,6 +444,26 @@ instance MsgStoreClass (JournalMsgStore s) where getLoadedQueue :: JournalQueue s -> IO (JournalQueue s) getLoadedQueue q = fromMaybe q <$> TM.lookupIO (recipientId q) (loadedQueues $ queueStore_ ms) + foldRcvServiceMessages :: JournalMsgStore s -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a + foldRcvServiceMessages ms serviceId f acc = case queueStore_ ms of + MQStore st -> foldRcvServiceQueues st serviceId f' acc + where + f' a (q, qr) = runExceptT (tryPeekMsg ms q) >>= f a (recipientId q) . ((qr,) <$$>) +#if defined(dbServerPostgres) + PQStore st -> foldRcvServiceQueueRecs st serviceId f' acc + where + JournalMsgStore {queueLocks, sharedLock} = ms + f' a (rId, qr) = do + q <- mkQueue ms False rId qr + qMsg_ <- + withSharedWaitLock rId queueLocks sharedLock $ runExceptT $ tryStore' "foldRcvServiceMessages" rId $ + (qr,) . snd <$$> (getLoadedQueue q >>= unStoreIO . getPeekMsgQueue ms) + f a rId qMsg_ + -- Use cached queue if available. + -- Also see the comment in loadQueue in PostgresQueueStore + getLoadedQueue q = fromMaybe q <$> TM.lookupIO (recipientId q) (loadedQueues $ queueStore_ ms) +#endif + logQueueStates :: JournalMsgStore s -> IO () logQueueStates ms = withActiveMsgQueues ms $ unStoreIO . logQueueState diff --git a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs index a0eb1d1ca..f3000811b 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs @@ -119,6 +119,34 @@ instance MsgStoreClass PostgresMsgStore where toMessageStats (expiredMsgsCount, storedMsgsCount, storedQueues) = MessageStats {expiredMsgsCount, storedMsgsCount, storedQueues} + foldRcvServiceMessages :: PostgresMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a + foldRcvServiceMessages ms serviceId f acc = + withTransaction (dbStore $ queueStore_ ms) $ \db -> + DB.fold + db + [sql| + SELECT q.recipient_id, q.recipient_keys, q.rcv_dh_secret, + q.sender_id, q.sender_key, q.queue_mode, + q.notifier_id, q.notifier_key, q.rcv_ntf_dh_secret, q.ntf_service_id, + q.status, q.updated_at, q.link_id, q.rcv_service_id, + m.msg_id, m.msg_ts, m.msg_quota, m.msg_ntf_flag, m.msg_body + FROM msg_queues q + LEFT JOIN ( + SELECT recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body, + ROW_NUMBER() OVER (PARTITION BY recipient_id ORDER BY message_id ASC) AS row_num + FROM messages + ) m ON q.recipient_id = m.recipient_id AND m.row_num = 1 + WHERE q.rcv_service_id = ? AND q.deleted_at IS NULL; + |] + (Only serviceId) + acc + f' + where + f' a (qRow :. mRow) = + let (rId, qr) = rowToQueueRec qRow + msg_ = toMaybeMessage mRow + in f a rId $ Right ((qr,) <$> msg_) + logQueueStates _ = error "logQueueStates not used" logQueueState _ = error "logQueueState not used" @@ -247,6 +275,11 @@ uninterruptibleMask_ :: ExceptT ErrorType IO a -> ExceptT ErrorType IO a uninterruptibleMask_ = ExceptT . E.uninterruptibleMask_ . runExceptT {-# INLINE uninterruptibleMask_ #-} +toMaybeMessage :: (Maybe (Binary MsgId), Maybe Int64, Maybe Bool, Maybe Bool, Maybe (Binary MsgBody)) -> Maybe Message +toMaybeMessage = \case + (Just msgId, Just ts, Just msgQuota, Just ntf, Just body) -> Just $ toMessage (msgId, ts, msgQuota, ntf, body) + _ -> Nothing + toMessage :: (Binary MsgId, Int64, Bool, Bool, Binary MsgBody) -> Message toMessage (Binary msgId, ts, msgQuota, ntf, Binary body) | msgQuota = MessageQuota {msgId, msgTs} diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 73e1bf398..24d489acc 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -87,6 +87,11 @@ instance MsgStoreClass STMMsgStore where expireOldMessages _tty ms now ttl = withLoadedQueues (queueStore_ ms) $ atomically . expireQueueMsgs ms now (now - ttl) + foldRcvServiceMessages :: STMMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a + foldRcvServiceMessages ms serviceId f= + foldRcvServiceQueues (queueStore_ ms) serviceId $ \a (q, qr) -> + runExceptT (tryPeekMsg ms q) >>= f a (recipientId q) . ((qr,) <$$>) + logQueueStates _ = pure () {-# INLINE logQueueStates #-} logQueueState _ = pure () diff --git a/src/Simplex/Messaging/Server/MsgStore/Types.hs b/src/Simplex/Messaging/Server/MsgStore/Types.hs index 98c12d4be..e186da05a 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Types.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Types.hs @@ -45,6 +45,7 @@ class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => M unsafeWithAllMsgQueues :: Monoid a => Bool -> s -> (StoreQueue s -> IO a) -> IO a -- tty, store, now, ttl expireOldMessages :: Bool -> s -> Int64 -> Int64 -> IO MessageStats + foldRcvServiceMessages :: s -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a logQueueStates :: s -> IO () logQueueState :: StoreQueue s -> StoreMonad s () queueStore :: s -> QueueStore s diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index e86bec07b..2fabbfa33 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -24,9 +24,11 @@ module Simplex.Messaging.Server.QueueStore.Postgres batchInsertServices, batchInsertQueues, foldServiceRecs, + foldRcvServiceQueueRecs, foldQueueRecs, foldRecentQueueRecs, handleDuplicate, + rowToQueueRec, withLog_, withDB, withDB', @@ -577,12 +579,17 @@ insertServiceQuery = VALUES (?,?,?,?,?) |] -foldServiceRecs :: forall a q. Monoid a => PostgresQueueStore q -> (ServiceRec -> IO a) -> IO a +foldServiceRecs :: Monoid a => PostgresQueueStore q -> (ServiceRec -> IO a) -> IO a foldServiceRecs st f = withTransaction (dbStore st) $ \db -> DB.fold_ db "SELECT service_id, service_role, service_cert, service_cert_hash, created_at FROM services" mempty $ \ !acc -> fmap (acc <>) . f . rowToServiceRec +foldRcvServiceQueueRecs :: PostgresQueueStore q -> ServiceId -> (a -> (RecipientId, QueueRec) -> IO a) -> a -> IO a +foldRcvServiceQueueRecs st serviceId f acc = + withTransaction (dbStore st) $ \db -> + DB.fold db (queueRecQuery <> " WHERE rcv_service_id = ? AND deleted_at IS NULL") (Only serviceId) acc $ \a -> f a . rowToQueueRec + foldQueueRecs :: Monoid a => Bool -> Bool -> PostgresQueueStore q -> ((RecipientId, QueueRec) -> IO a) -> IO a foldQueueRecs withData = foldQueueRecs_ foldRecs where @@ -769,10 +776,6 @@ instance ToField SMPServiceRole where toField = toField . decodeLatin1 . smpEnco instance FromField SMPServiceRole where fromField = fromTextField_ $ eitherToMaybe . smpDecode . encodeUtf8 -instance ToField X.CertificateChain where toField = toField . Binary . smpEncode . C.encodeCertChain - -instance FromField X.CertificateChain where fromField = blobFieldDecoder (parseAll C.certChainP) - #if !defined(dbPostgres) instance ToField EntityId where toField (EntityId s) = toField $ Binary s @@ -797,4 +800,8 @@ deriving newtype instance FromField EncDataBytes deriving newtype instance ToField (RoundedSystemTime t) deriving newtype instance FromField (RoundedSystemTime t) + +instance ToField X.CertificateChain where toField = toField . Binary . smpEncode . C.encodeCertChain + +instance FromField X.CertificateChain where fromField = blobFieldDecoder (parseAll C.certChainP) #endif diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index ad98698db..ad3e00a03 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -17,6 +17,7 @@ module Simplex.Messaging.Server.QueueStore.STM ( STMQueueStore (..), STMService (..), + foldRcvServiceQueues, setStoreLog, withLog', readQueueRecIO, @@ -45,7 +46,7 @@ import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPServiceRole (..)) -import Simplex.Messaging.Util (anyM, ifM, tshow, ($>>), ($>>=), (<$$)) +import Simplex.Messaging.Util (anyM, ifM, tshow, ($>>), ($>>=), (<$$), (<$$>)) import System.IO import UnliftIO.STM @@ -359,6 +360,16 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where SRecipientService -> serviceRcvQueues SNotifierService -> serviceNtfQueues +foldRcvServiceQueues :: StoreQueueClass q => STMQueueStore q -> ServiceId -> (a -> (q, QueueRec) -> IO a) -> a -> IO a +foldRcvServiceQueues st serviceId f acc = + TM.lookupIO serviceId (services st) >>= \case + Nothing -> pure acc + Just s -> + readTVarIO (serviceRcvQueues s) + >>= foldM (\a -> get >=> maybe (pure a) (f a)) acc + where + get rId = TM.lookupIO rId (queues st) $>>= \q -> (q,) <$$> readTVarIO (queueRec q) + withQueueRec :: TVar (Maybe QueueRec) -> (QueueRec -> STM a) -> IO (Either ErrorType a) withQueueRec qr a = atomically $ readQueueRec qr >>= mapM a diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index e2e912875..2d959410d 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -56,6 +56,7 @@ module Simplex.Messaging.Transport serviceCertsSMPVersion, newNtfCredsSMPVersion, clientNoticesSMPVersion, + rcvServiceSMPVersion, simplexMQVersion, smpBlockSize, TransportConfig (..), @@ -170,6 +171,7 @@ smpBlockSize = 16384 -- 16 - service certificates (5/31/2025) -- 17 - create notification credentials with NEW (7/12/2025) -- 18 - support client notices (10/10/2025) +-- 19 - service subscriptions to messages (10/20/2025) data SMPVersion @@ -218,6 +220,9 @@ newNtfCredsSMPVersion = VersionSMP 17 clientNoticesSMPVersion :: VersionSMP clientNoticesSMPVersion = VersionSMP 18 +rcvServiceSMPVersion :: VersionSMP +rcvServiceSMPVersion = VersionSMP 19 + minClientSMPRelayVersion :: VersionSMP minClientSMPRelayVersion = VersionSMP 6 @@ -225,13 +230,13 @@ minServerSMPRelayVersion :: VersionSMP minServerSMPRelayVersion = VersionSMP 6 currentClientSMPRelayVersion :: VersionSMP -currentClientSMPRelayVersion = VersionSMP 18 +currentClientSMPRelayVersion = VersionSMP 19 legacyServerSMPRelayVersion :: VersionSMP legacyServerSMPRelayVersion = VersionSMP 6 currentServerSMPRelayVersion :: VersionSMP -currentServerSMPRelayVersion = VersionSMP 18 +currentServerSMPRelayVersion = VersionSMP 19 -- Max SMP protocol version to be used in e2e encrypted -- connection between client and server, as defined by SMP proxy. @@ -239,7 +244,7 @@ currentServerSMPRelayVersion = VersionSMP 18 -- to prevent client version fingerprinting by the -- destination relays when clients upgrade at different times. proxiedSMPRelayVersion :: VersionSMP -proxiedSMPRelayVersion = VersionSMP 17 +proxiedSMPRelayVersion = VersionSMP 18 -- minimal supported protocol version is 6 -- TODO remove code that supports sending commands without batching @@ -823,7 +828,7 @@ smpClientHandshake c ks_ keyHash@(C.KeyHash kh) vRange proxyServer serviceKeys_ serviceKeys = case serviceKeys_ of Just sks | v >= serviceCertsSMPVersion && certificateSent c -> Just sks _ -> Nothing - clientService = mkClientService <$> serviceKeys + clientService = mkClientService v =<< serviceKeys hs = SMPClientHandshake {smpVersion = v, keyHash, authPubKey = fst <$> ks_, proxyServer, clientService} sendHandshake th hs service <- mapM getClientService serviceKeys @@ -831,10 +836,12 @@ smpClientHandshake c ks_ keyHash@(C.KeyHash kh) vRange proxyServer serviceKeys_ Nothing -> throwE TEVersion where th@THandle {params = THandleParams {sessionId}} = smpTHandle c - mkClientService :: (ServiceCredentials, C.KeyPairEd25519) -> SMPClientHandshakeService - mkClientService (ServiceCredentials {serviceRole, serviceCreds, serviceSignKey}, (k, _)) = - let sk = C.signX509 serviceSignKey $ C.publicToX509 k - in SMPClientHandshakeService {serviceRole, serviceCertKey = CertChainPubKey (fst serviceCreds) sk} + mkClientService :: VersionSMP -> (ServiceCredentials, C.KeyPairEd25519) -> Maybe SMPClientHandshakeService + mkClientService v (ServiceCredentials {serviceRole, serviceCreds, serviceSignKey}, (k, _)) + | serviceRole == SRMessaging && v < rcvServiceSMPVersion = Nothing + | otherwise = + let sk = C.signX509 serviceSignKey $ C.publicToX509 k + in Just SMPClientHandshakeService {serviceRole, serviceCertKey = CertChainPubKey (fst serviceCreds) sk} getClientService :: (ServiceCredentials, C.KeyPairEd25519) -> ExceptT TransportError IO THClientService getClientService (ServiceCredentials {serviceRole, serviceCertHash}, (_, pk)) = getHandshake th >>= \case diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index fcdd5be29..017958890 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -85,7 +85,7 @@ import Simplex.Messaging.Agent hiding (acceptContact, createConnection, deleteCo import qualified Simplex.Messaging.Agent as A import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), ServerQueueInfo (..), UserNetworkInfo (..), UserNetworkType (..), waitForUserNetwork) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), Env (..), InitialAgentServers (..), createAgentStore) -import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT, INV, JOINED) +import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT) import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Agent.Store (Connection' (..), SomeConn' (..), StoredRcvQueue (..)) import Simplex.Messaging.Agent.Store.AgentStore (getConn) @@ -219,12 +219,6 @@ pattern SENT msgId = A.SENT msgId Nothing pattern Rcvd :: AgentMsgId -> AEvent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] -pattern INV :: AConnectionRequestUri -> AEvent 'AEConn -pattern INV cReq = A.INV cReq Nothing - -pattern JOINED :: SndQueueSecured -> AEvent 'AEConn -pattern JOINED sndSecure = A.JOINED sndSecure Nothing - smpCfgVPrev :: ProtocolClientConfig SMPVersion smpCfgVPrev = (smpCfg agentCfg) {serverVRange = prevRange $ serverVRange $ smpCfg agentCfg} @@ -282,16 +276,16 @@ inAnyOrder g rs = withFrozenCallStack $ do createConnection :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs cMode clientData subMode = do - (connId, (CCLink cReq _, Nothing)) <- A.createConnection c NRMInteractive userId enableNtfs True cMode Nothing clientData IKPQOn subMode + (connId, CCLink cReq _) <- A.createConnection c NRMInteractive userId enableNtfs True cMode Nothing clientData IKPQOn subMode pure (connId, cReq) joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE (ConnId, SndQueueSecured) joinConnection c userId enableNtfs cReq connInfo subMode = do connId <- A.prepareConnectionToJoin c userId enableNtfs cReq PQSupportOn - (sndSecure, Nothing) <- A.joinConnection c NRMInteractive userId connId enableNtfs cReq connInfo PQSupportOn subMode + sndSecure <- A.joinConnection c NRMInteractive userId connId enableNtfs cReq connInfo PQSupportOn subMode pure (connId, sndSecure) -acceptContact :: AgentClient -> UserId -> ConnId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (SndQueueSecured, Maybe ClientServiceId) +acceptContact :: AgentClient -> UserId -> ConnId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE SndQueueSecured acceptContact c = A.acceptContact c NRMInteractive subscribeConnection :: AgentClient -> ConnId -> AE () @@ -708,9 +702,9 @@ runAgentClientTest pqSupport sqSecured viaProxy alice bob baseId = runAgentClientTestPQ :: HasCallStack => SndQueueSecured -> Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientTestPQ sqSecured viaProxy (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing aPQ SMSubscribe + (bobId, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ - (sqSecured', Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe + sqSecured' <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` CR.connPQEncryption aPQ @@ -910,14 +904,14 @@ runAgentClientContactTest pqSupport sqSecured viaProxy alice bob baseId = runAgentClientContactTestPQ :: HasCallStack => SndQueueSecured -> Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do - (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing aPQ SMSubscribe + (_, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ - (sqSecuredJoin, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe + sqSecuredJoin <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` reqPQSupport bobId <- A.prepareConnectionToAccept alice 1 True invId (CR.connPQEncryption aPQ) - (sqSecured', Nothing) <- acceptContact alice 1 bobId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + sqSecured' <- acceptContact alice 1 bobId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob liftIO $ pqSup'' `shouldBe` bPQ @@ -954,7 +948,7 @@ runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, b runAgentClientContactTestPQ3 :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId = runRight_ $ do - (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing aPQ SMSubscribe + (_, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing aPQ SMSubscribe (bAliceId, bobId, abPQEnc) <- connectViaContact bob bPQ qInfo sentMessages abPQEnc alice bobId bob bAliceId (tAliceId, tomId, atPQEnc) <- connectViaContact tom tPQ qInfo @@ -963,12 +957,12 @@ runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId msgId = subtract baseId . fst connectViaContact b pq qInfo = do aId <- A.prepareConnectionToJoin b 1 True qInfo pq - (sqSecuredJoin, Nothing) <- A.joinConnection b NRMInteractive 1 aId True qInfo "bob's connInfo" pq SMSubscribe + sqSecuredJoin <- A.joinConnection b NRMInteractive 1 aId True qInfo "bob's connInfo" pq SMSubscribe liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn bId <- A.prepareConnectionToAccept alice 1 True invId (CR.connPQEncryption aPQ) - (sqSecuredAccept, Nothing) <- acceptContact alice 1 bId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + sqSecuredAccept <- acceptContact alice 1 bId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe liftIO $ sqSecuredAccept `shouldBe` False -- agent cfg is v8 ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get b liftIO $ pqSup'' `shouldBe` pq @@ -1007,9 +1001,9 @@ noMessages_ ingoreQCONT c err = tryGet `shouldReturn` () testRejectContactRequest :: HasCallStack => IO () testRejectContactRequest = withAgentClients2 $ \alice bob -> runRight_ $ do - (_addrConnId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + (_addrConnId, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - (sqSecured, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + sqSecured <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId PQSupportOn _ "bob's connInfo") <- get alice rejectContact alice invId @@ -1022,7 +1016,7 @@ testUpdateConnectionUserId = newUserId <- createUser alice [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] _ <- changeConnectionUser alice 1 connId newUserId aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - (sqSecured', Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + sqSecured' <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured' `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn @@ -1206,7 +1200,7 @@ testInvitationErrors ps restart = do threadDelay 200000 let loopConfirm n = runExceptT (A.joinConnection b' NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe) >>= \case - Right (True, Nothing) -> pure n + Right True -> pure n Right r -> error $ "unexpected result " <> show r Left _ -> putStrLn "retrying confirm" >> threadDelay 200000 >> loopConfirm (n + 1) n <- loopConfirm 1 @@ -1268,7 +1262,7 @@ testContactErrors ps restart = do let loopSend = do -- sends the invitation to testPort runExceptT (A.joinConnection b'' NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe) >>= \case - Right (False, Nothing) -> pure () + Right False -> pure () Right r -> error $ "unexpected result " <> show r Left _ -> putStrLn "retrying send" >> threadDelay 200000 >> loopSend loopSend @@ -1297,7 +1291,7 @@ testContactErrors ps restart = do ("", "", UP _ [_]) <- nGet b'' let loopConfirm n = runExceptT (acceptContact a' 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe) >>= \case - Right (True, Nothing) -> pure n + Right True -> pure n Right r -> error $ "unexpected result " <> show r Left _ -> putStrLn "retrying accept confirm" >> threadDelay 200000 >> loopConfirm (n + 1) n <- loopConfirm 1 @@ -1334,7 +1328,7 @@ testInvitationShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do let userData = UserLinkData "some user data" newLinkData = UserInvLinkData userData - (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMSubscribe + (bId, CCLink connReq (Just shortLink)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMSubscribe (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq @@ -1356,7 +1350,7 @@ testInvitationShortLink viaProxy a b = testJoinConn_ :: Bool -> Bool -> AgentClient -> ConnId -> AgentClient -> ConnectionRequestUri c -> ExceptT AgentErrorType IO () testJoinConn_ viaProxy sndSecure a bId b connReq = do aId <- A.prepareConnectionToJoin b 1 True connReq PQSupportOn - (sndSecure', Nothing) <- A.joinConnection b NRMInteractive 1 aId True connReq "bob's connInfo" PQSupportOn SMSubscribe + sndSecure' <- A.joinConnection b NRMInteractive 1 aId True connReq "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sndSecure' `shouldBe` sndSecure ("", _, CONF confId _ "bob's connInfo") <- get a allowConnection a bId confId "alice's connInfo" @@ -1370,14 +1364,14 @@ testInvitationShortLinkPrev viaProxy sndSecure a b = runRight_ $ do let userData = UserLinkData "some user data" newLinkData = UserInvLinkData userData -- can't create short link with previous version - (bId, (CCLink connReq Nothing, Nothing)) <- A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKPQOn SMSubscribe + (bId, CCLink connReq Nothing) <- A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKPQOn SMSubscribe testJoinConn_ viaProxy sndSecure a bId b connReq testInvitationShortLinkAsync :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () testInvitationShortLinkAsync viaProxy a b = do let userData = UserLinkData "some user data" newLinkData = UserInvLinkData userData - (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMSubscribe + (bId, CCLink connReq (Just shortLink)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMSubscribe (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq @@ -1404,7 +1398,7 @@ testContactShortLink viaProxy a b = let userData = UserLinkData "some user data" userCtData = UserContactData {direct = True, owners = [], relays = [], userData} newLinkData = UserContactLinkData userCtData - (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMContact (Just newLinkData) Nothing CR.IKPQOn SMSubscribe + (contactId, CCLink connReq0 (Just shortLink)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMContact (Just newLinkData) Nothing CR.IKPQOn SMSubscribe Right connReq <- pure $ smpDecode (smpEncode connReq0) (connReq', ContactLinkData _ userCtData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink @@ -1423,7 +1417,7 @@ testContactShortLink viaProxy a b = liftIO $ sndSecure `shouldBe` False ("", _, REQ invId _ "bob's connInfo") <- get a bId <- A.prepareConnectionToAccept a 1 True invId PQSupportOn - (sndSecure', Nothing) <- acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe + sndSecure' <- acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe liftIO $ sndSecure' `shouldBe` True ("", _, CONF confId _ "alice's connInfo") <- get b allowConnection b aId confId "bob's connInfo" @@ -1451,7 +1445,7 @@ testContactShortLink viaProxy a b = testAddContactShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () testAddContactShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do - (contactId, (CCLink connReq0 Nothing, Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMSubscribe + (contactId, CCLink connReq0 Nothing) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMSubscribe Right connReq <- pure $ smpDecode (smpEncode connReq0) -- let userData = UserLinkData "some user data" userCtData = UserContactData {direct = True, owners = [], relays = [], userData} @@ -1474,7 +1468,7 @@ testAddContactShortLink viaProxy a b = liftIO $ sndSecure `shouldBe` False ("", _, REQ invId _ "bob's connInfo") <- get a bId <- A.prepareConnectionToAccept a 1 True invId PQSupportOn - (sndSecure', Nothing) <- acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe + sndSecure' <- acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe liftIO $ sndSecure' `shouldBe` True ("", _, CONF confId _ "alice's connInfo") <- get b allowConnection b aId confId "bob's connInfo" @@ -1496,7 +1490,7 @@ testInvitationShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> testInvitationShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = UserLinkData "some user data" newLinkData = UserInvLinkData userData - (bId, (CCLink connReq (Just shortLink), Nothing)) <- withSmpServer ps $ + (bId, CCLink connReq (Just shortLink)) <- withSmpServer ps $ runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMOnlyCreate withSmpServer ps $ do runRight_ $ subscribeConnection a bId @@ -1510,7 +1504,7 @@ testContactShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = UserLinkData "some user data" userCtData = UserContactData {direct = True, owners = [], relays = [], userData} newLinkData = UserContactLinkData userCtData - (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- withSmpServer ps $ + (contactId, CCLink connReq0 (Just shortLink)) <- withSmpServer ps $ runRight $ A.createConnection a NRMInteractive 1 True True SCMContact (Just newLinkData) Nothing CR.IKPQOn SMOnlyCreate Right connReq <- pure $ smpDecode (smpEncode connReq0) let updatedData = UserLinkData "updated user data" @@ -1534,7 +1528,7 @@ testAddContactShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = UserLinkData "some user data" userCtData = UserContactData {direct = True, owners = [], relays = [], userData} newLinkData = UserContactLinkData userCtData - ((contactId, (CCLink connReq0 Nothing, Nothing)), shortLink) <- withSmpServer ps $ runRight $ do + ((contactId, CCLink connReq0 Nothing), shortLink) <- withSmpServer ps $ runRight $ do r@(contactId, _) <- A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate (r,) <$> setConnShortLink a contactId SCMContact newLinkData Nothing Right connReq <- pure $ smpDecode (smpEncode connReq0) @@ -1556,7 +1550,7 @@ testAddContactShortLinkRestart ps = withAgentClients2 $ \a b -> do testOldContactQueueShortLink :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testOldContactQueueShortLink ps@(_, msType) = withAgentClients2 $ \a b -> do - (contactId, (CCLink connReq Nothing, Nothing)) <- withSmpServer ps $ runRight $ + (contactId, CCLink connReq Nothing) <- withSmpServer ps $ runRight $ A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate -- make it an "old" queue let updateStoreLog f = replaceSubstringInFile f " queue_mode=C" "" @@ -2301,9 +2295,9 @@ makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn True makeConnectionForUsers_ :: HasCallStack => PQSupport -> SndQueueSecured -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers_ pqSupport sqSecured alice aliceUserId bob bobUserId = do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive aliceUserId True True SCMInvitation Nothing Nothing (IKLinkPQ pqSupport) SMSubscribe + (bobId, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive aliceUserId True True SCMInvitation Nothing Nothing (IKLinkPQ pqSupport) SMSubscribe aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport - (sqSecured', Nothing) <- A.joinConnection bob NRMInteractive bobUserId aliceId True qInfo "bob's connInfo" pqSupport SMSubscribe + sqSecured' <- A.joinConnection bob NRMInteractive bobUserId aliceId True qInfo "bob's connInfo" pqSupport SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index dff79c861..f66dfe5df 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -227,7 +227,7 @@ rcvQueue1 = sndId = EntityId "2345", queueMode = Just QMMessaging, shortLink = Nothing, - clientService = Nothing, + rcvServiceAssoc = False, status = New, enableNtfs = True, clientNoticeId = Nothing, @@ -441,7 +441,7 @@ testUpgradeSndConnToDuplex = sndId = EntityId "4567", queueMode = Just QMMessaging, shortLink = Nothing, - clientService = Nothing, + rcvServiceAssoc = False, status = New, enableNtfs = True, clientNoticeId = Nothing, diff --git a/tests/AgentTests/ServerChoice.hs b/tests/AgentTests/ServerChoice.hs index a27678cb6..8412c6761 100644 --- a/tests/AgentTests/ServerChoice.hs +++ b/tests/AgentTests/ServerChoice.hs @@ -64,6 +64,7 @@ initServers = ntf = [testNtfServer], xftp = userServers [testXFTPServer], netCfg = defaultNetworkConfig, + useServices = M.empty, presetDomains = [], presetServers = [] } diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 02bee9ae7..935775050 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -65,6 +65,7 @@ initAgentServers = ntf = [testNtfServer], xftp = userServers [testXFTPServer], netCfg = defaultNetworkConfig {tcpTimeout = NetworkTimeout 500000 500000, tcpConnectTimeout = NetworkTimeout 500000 500000}, + useServices = M.empty, presetDomains = [], presetServers = [] } diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 3c1ac0150..361bc4f1d 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -15,10 +15,14 @@ module SMPClient where +import Control.Monad import Control.Monad.Except (runExceptT) import Data.ByteString.Char8 (ByteString) import Data.List.NonEmpty (NonEmpty) +import qualified Data.X509 as X +import qualified Data.X509.Validation as XV import Network.Socket +import qualified Network.TLS as TLS import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) @@ -33,6 +37,7 @@ import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (.. import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.Server +import Simplex.Messaging.Transport.Shared (ChainCertificates (..), chainIdCaCerts) import Simplex.Messaging.Util (ifM) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal @@ -151,13 +156,26 @@ testSMPClient = testSMPClientVR supportedClientSMPRelayVRange testSMPClientVR :: Transport c => VersionRangeSMP -> (THandleSMP c 'TClient -> IO a) -> IO a testSMPClientVR vr client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost - testSMPClient_ useHost testPort vr client + testSMPClient_ useHost testPort vr Nothing client -testSMPClient_ :: Transport c => TransportHost -> ServiceName -> VersionRangeSMP -> (THandleSMP c 'TClient -> IO a) -> IO a -testSMPClient_ host port vr client = do - let tcConfig = defaultTransportClientConfig {clientALPN} :: TransportClientConfig +testSMPServiceClient :: Transport c => (TLS.Credential, C.KeyPairEd25519) -> (THandleSMP c 'TClient -> IO a) -> IO a +testSMPServiceClient serviceCreds client = do + Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost + testSMPClient_ useHost testPort supportedClientSMPRelayVRange (Just serviceCreds) client + +testSMPClient_ :: Transport c => TransportHost -> ServiceName -> VersionRangeSMP -> Maybe (TLS.Credential, C.KeyPairEd25519) -> (THandleSMP c 'TClient -> IO a) -> IO a +testSMPClient_ host port vr serviceCreds_ client = do + serviceAndKeys_ <- forM serviceCreds_ $ \(serviceCreds@(cc, pk), keys) -> do + Right serviceSignKey <- pure $ C.x509ToPrivate' pk + let idCert' = case chainIdCaCerts cc of + CCSelf cert -> cert + CCValid {idCert} -> idCert + _ -> error "bad certificate" + serviceCertHash = XV.getFingerprint idCert' X.HashSHA256 + pure (ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash, serviceSignKey}, keys) + let tcConfig = defaultTransportClientConfig {clientALPN, clientCredentials = fst <$> serviceCreds_} :: TransportClientConfig runTransportClient tcConfig Nothing host port (Just testKeyHash) $ \h -> - runExceptT (smpClientHandshake h Nothing testKeyHash vr False Nothing) >>= \case + runExceptT (smpClientHandshake h Nothing testKeyHash vr False serviceAndKeys_) >>= \case Right th -> client th Left e -> error $ show e where @@ -165,6 +183,12 @@ testSMPClient_ host port vr client = do | authCmdsSMPVersion `isCompatible` vr = Just alpnSupportedSMPHandshakes | otherwise = Nothing +runSMPClient :: Transport c => TProxy c 'TServer -> (THandleSMP c 'TClient -> IO a) -> IO a +runSMPClient _ test' = testSMPClient test' + +runSMPServiceClient :: Transport c => TProxy c 'TServer -> (TLS.Credential, C.KeyPairEd25519) -> (THandleSMP c 'TClient -> IO a) -> IO a +runSMPServiceClient _ serviceCreds test' = testSMPServiceClient serviceCreds test' + testNtfServiceClient :: Transport c => TProxy c 'TServer -> C.KeyPairEd25519 -> (THandleSMP c 'TClient -> IO a) -> IO a testNtfServiceClient _ keys client = do tlsNtfServerCreds <- loadServerCredential ntfTestServerCredentials diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index b756ce7c9..09f20c1dd 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -224,9 +224,9 @@ agentDeliverMessageViaProxy :: (C.AlgorithmI a, C.AuthAlgorithm a) => (NonEmpty agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, bViaProxy) alg msg1 msg2 baseId = withAgent 1 aCfg (servers aTestCfg) testDB $ \alice -> withAgent 2 aCfg (servers bTestCfg) testDB2 $ \bob -> runRight_ $ do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (bobId, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - (sqSecured, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + sqSecured <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn @@ -280,9 +280,9 @@ agentDeliverMessagesViaProxyConc agentServers msgs = -- agent connections have to be set up in advance -- otherwise the CONF messages would get mixed with MSG prePair alice bob = do - (bobId, (CCLink qInfo Nothing, Nothing)) <- runExceptT' $ A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (bobId, CCLink qInfo Nothing) <- runExceptT' $ A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- runExceptT' $ A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - (sqSecured, Nothing) <- runExceptT' $ A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + sqSecured <- runExceptT' $ A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True confId <- get alice >>= \case @@ -331,7 +331,7 @@ agentViaProxyVersionError = withAgent 1 agentCfg (servers [SMPServer testHost testPort testKeyHash]) testDB $ \alice -> do Left (A.BROKER _ (TRANSPORT TEVersion)) <- withAgent 2 agentCfg (servers [SMPServer testHost2 testPort2 testKeyHash]) testDB2 $ \bob -> runExceptT $ do - (_bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (_bobId, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe pure () @@ -351,9 +351,9 @@ agentViaProxyRetryOffline = do let pqEnc = CR.PQEncOn withServer $ \_ -> do (aliceId, bobId) <- withServer2 $ \_ -> runRight $ do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (bobId, CCLink qInfo Nothing) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - (sqSecured, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + sqSecured <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn @@ -434,14 +434,14 @@ agentViaProxyRetryNoSession = do testNoProxy :: AStoreType -> IO () testNoProxy msType = do withSmpServerConfigOn (transport @TLS) (cfgMS msType) testPort2 $ \_ -> do - testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do + testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> do (_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing) reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) testProxyAuth :: AStoreType -> IO () testProxyAuth msType = do withSmpServerConfigOn (transport @TLS) proxyCfgAuth testPort $ \_ -> do - testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do + testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> do (_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong") reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) where diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index b2c2d997c..39009794c 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -29,9 +29,11 @@ import Data.Bifunctor (first) import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Foldable (foldrM) import Data.Hashable (hash) import qualified Data.IntSet as IS import Data.List.NonEmpty (NonEmpty) +import Data.Maybe (catMaybes) import Data.String (IsString (..)) import Data.Type.Equality import qualified Data.X509.Validation as XV @@ -50,6 +52,7 @@ import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), QSType (..), import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..)) import Simplex.Messaging.Server.StoreLog (StoreLogRecord (..), closeStoreLog) import Simplex.Messaging.Transport +import Simplex.Messaging.Transport.Credentials import Simplex.Messaging.Util (whenM) import Simplex.Messaging.Version (mkVersionRange) import System.Directory (doesDirectoryExist, doesFileExist, removeDirectoryRecursive, removeFile) @@ -84,6 +87,9 @@ serverTests = do describe "GET & SUB commands" testGetSubCommands describe "Exceeding queue quota" testExceedQueueQuota describe "Concurrent sending and delivery" testConcurrentSendDelivery + describe "Service message subscriptions" $ do + testServiceDeliverSubscribe + testServiceUpgradeAndDowngrade describe "Store log" testWithStoreLog describe "Restore messages" testRestoreMessages describe "Restore messages (old / v2)" testRestoreExpireMessages @@ -111,6 +117,9 @@ pattern New rPub dhPub = NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just ( pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh _sndSecure _linkId Nothing Nothing) +pattern Ids_ :: RecipientId -> SenderId -> RcvPublicDhKey -> ServiceId -> BrokerMsg +pattern Ids_ rId sId srvDh serviceId <- IDS (QIK rId sId srvDh _sndSecure _linkId (Just serviceId) Nothing) + pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} @@ -135,11 +144,21 @@ serviceSignSendRecv h pk serviceKey t = do [r] <- signSendRecv_ h pk (Just serviceKey) t pure r +serviceSignSendRecv2 :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg), Transmission (Either ErrorType BrokerMsg)) +serviceSignSendRecv2 h pk serviceKey t = do + [r1, r2] <- signSendRecv_ h pk (Just serviceKey) t + pure (r1, r2) + signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (NonEmpty (Transmission (Either ErrorType BrokerMsg))) -signSendRecv_ h@THandle {params} (C.APrivateAuthKey a pk) serviceKey_ (corrId, qId, cmd) = do +signSendRecv_ h pk serviceKey_ t = do + signSend_ h pk serviceKey_ t + tGetClient h + +signSend_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO () +signSend_ h@THandle {params} (C.APrivateAuthKey a pk) serviceKey_ (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) - liftIO $ tGetClient h + pure () where authorize t = (,(`C.sign'` t) <$> serviceKey_) <$> case a of C.SEd25519 -> Just . TASignature . C.ASignature C.SEd25519 $ C.sign' pk t' @@ -660,6 +679,194 @@ testConcurrentSendDelivery = Resp "4" _ OK <- signSendRecv rh rKey ("4", rId, ACK mId2) pure () +testServiceDeliverSubscribe :: SpecWith (ASrvTransport, AStoreType) +testServiceDeliverSubscribe = + it "should create queue as service and subscribe with SUBS after reconnect" $ \(at@(ATransport t), msType) -> do + g <- C.newRandom + creds <- genCredentials g Nothing (0, 2400) "localhost" + let (_fp, tlsCred) = tlsCredentials [creds] + serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g + let aServicePK = C.APrivateAuthKey C.SEd25519 servicePK + withSmpServerConfigOn at (cfgMS msType) testPort $ \_ -> runSMPClient t $ \h -> do + (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (dhPub, dhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g + (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + + (rId, sId, dec, serviceId) <- runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do + Resp "1" NoEntity (ERR SERVICE) <- signSendRecv sh rKey ("1", NoEntity, New rPub dhPub) + Resp "2" NoEntity (Ids_ rId sId srvDh serviceId) <- serviceSignSendRecv sh rKey servicePK ("2", NoEntity, New rPub dhPub) + let dec = decryptMsgV3 $ C.dh' srvDh dhPriv + Resp "3" sId' OK <- signSendRecv h sKey ("3", sId, SKEY sPub) + sId' `shouldBe` sId + Resp "4" _ OK <- signSendRecv h sKey ("4", sId, _SEND "hello") + Resp "5" _ OK <- signSendRecv h sKey ("5", sId, _SEND "hello 2") + Resp "" rId' (Msg mId1 msg1) <- tGet1 sh + rId' `shouldBe` rId + dec mId1 msg1 `shouldBe` Right "hello" + -- ACK doesn't need service signature + Resp "6" _ (Msg mId2 msg2) <- signSendRecv sh rKey ("6", rId, ACK mId1) + dec mId2 msg2 `shouldBe` Right "hello 2" + Resp "7" _ (ERR NO_MSG) <- signSendRecv sh rKey ("7", rId, ACK mId1) + Resp "8" _ OK <- signSendRecv sh rKey ("8", rId, ACK mId2) + Resp "9" _ OK <- signSendRecv h sKey ("9", sId, _SEND "hello 3") + pure (rId, sId, dec, serviceId) + + runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do + Resp "10" NoEntity (ERR (CMD NO_AUTH)) <- signSendRecv sh aServicePK ("10", NoEntity, SUBS) + signSend_ sh aServicePK Nothing ("11", serviceId, SUBS) + [mId3] <- + fmap catMaybes $ + receiveInAnyOrder -- race between SOKS and MSG, clients can handle it + sh + [ \case + Resp "11" serviceId' (SOKS n _) -> do + n `shouldBe` 1 + serviceId' `shouldBe` serviceId + pure $ Just Nothing + _ -> pure Nothing, + \case + Resp "" rId'' (Msg mId3 msg3) -> do + rId'' `shouldBe` rId + dec mId3 msg3 `shouldBe` Right "hello 3" + pure $ Just $ Just mId3 + _ -> pure Nothing + ] + Resp "" NoEntity SALL <- tGet1 sh + Resp "12" _ OK <- signSendRecv sh rKey ("12", rId, ACK mId3) + Resp "14" _ OK <- signSendRecv h sKey ("14", sId, _SEND "hello 4") + Resp "" _ (Msg mId4 msg4) <- tGet1 sh + dec mId4 msg4 `shouldBe` Right "hello 4" + Resp "15" _ OK <- signSendRecv sh rKey ("15", rId, ACK mId4) + pure () + +testServiceUpgradeAndDowngrade :: SpecWith (ASrvTransport, AStoreType) +testServiceUpgradeAndDowngrade = + it "should create queue as client and switch to service and back" $ \(at@(ATransport t), msType) -> do + g <- C.newRandom + creds <- genCredentials g Nothing (0, 2400) "localhost" + let (_fp, tlsCred) = tlsCredentials [creds] + serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g + let aServicePK = C.APrivateAuthKey C.SEd25519 servicePK + withSmpServerConfigOn at (cfgMS msType) testPort $ \_ -> runSMPClient t $ \h -> do + (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (dhPub, dhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g + (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rPub2, rKey2) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (dhPub2, dhPriv2 :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g + (sPub2, sKey2) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rPub3, rKey3) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (dhPub3, dhPriv3 :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g + (sPub3, sKey3) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + + (rId, sId, dec) <- runSMPClient t $ \sh -> do + Resp "1" NoEntity (Ids rId sId srvDh) <- signSendRecv sh rKey ("1", NoEntity, New rPub dhPub) + let dec = decryptMsgV3 $ C.dh' srvDh dhPriv + Resp "2" sId' OK <- signSendRecv h sKey ("2", sId, SKEY sPub) + sId' `shouldBe` sId + Resp "3" _ OK <- signSendRecv h sKey ("3", sId, _SEND "hello") + Resp "" rId' (Msg mId1 msg1) <- tGet1 sh + rId' `shouldBe` rId + dec mId1 msg1 `shouldBe` Right "hello" + Resp "4" _ OK <- signSendRecv sh rKey ("4", rId, ACK mId1) + Resp "5" _ OK <- signSendRecv h sKey ("5", sId, _SEND "hello 2") + pure (rId, sId, dec) + + -- split to prevent message delivery + (rId2, sId2, dec2) <- runSMPClient t $ \sh -> do + Resp "6" NoEntity (Ids rId2 sId2 srvDh2) <- signSendRecv sh rKey2 ("6", NoEntity, New rPub2 dhPub2) + let dec2 = decryptMsgV3 $ C.dh' srvDh2 dhPriv2 + Resp "7" sId2' OK <- signSendRecv h sKey2 ("7", sId2, SKEY sPub2) + sId2' `shouldBe` sId2 + pure (rId2, sId2, dec2) + + (rId3, _sId3, _dec3) <- runSMPClient t $ \sh -> do + Resp "6" NoEntity (Ids rId3 sId3 srvDh3) <- signSendRecv sh rKey3 ("6", NoEntity, New rPub3 dhPub3) + let dec3 = decryptMsgV3 $ C.dh' srvDh3 dhPriv3 + Resp "7" sId3' OK <- signSendRecv h sKey3 ("7", sId3, SKEY sPub3) + sId3' `shouldBe` sId3 + pure (rId3, sId3, dec3) + + serviceId <- runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do + Resp "8" _ (ERR SERVICE) <- signSendRecv sh rKey ("8", rId, SUB) + (Resp "9" rId' (SOK (Just serviceId)), Resp "" rId'' (Msg mId2 msg2)) <- serviceSignSendRecv2 sh rKey servicePK ("9", rId, SUB) + rId' `shouldBe` rId + rId'' `shouldBe` rId + dec mId2 msg2 `shouldBe` Right "hello 2" + (Resp "10" rId2' (SOK (Just serviceId'))) <- serviceSignSendRecv sh rKey2 servicePK ("10", rId2, SUB) + rId2' `shouldBe` rId2 + serviceId' `shouldBe` serviceId + Resp "10.1" _ OK <- signSendRecv sh rKey ("10.1", rId, ACK mId2) + (Resp "10.2" rId3' (SOK (Just serviceId''))) <- serviceSignSendRecv sh rKey3 servicePK ("10.2", rId3, SUB) + rId3' `shouldBe` rId3 + serviceId'' `shouldBe` serviceId + pure serviceId + + Resp "11" _ OK <- signSendRecv h sKey ("11", sId, _SEND "hello 3.1") + Resp "12" _ OK <- signSendRecv h sKey2 ("12", sId2, _SEND "hello 3.2") + + runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do + signSend_ sh aServicePK Nothing ("14", serviceId, SUBS) + [(rKey3_1, rId3_1, mId3_1), (rKey3_2, rId3_2, mId3_2)] <- + fmap catMaybes $ + receiveInAnyOrder -- race between SOKS and MSG, clients can handle it + sh + [ \case + Resp "14" serviceId' (SOKS n _) -> do + n `shouldBe` 3 + serviceId' `shouldBe` serviceId + pure $ Just Nothing + _ -> pure Nothing, + \case + Resp "" rId'' (Msg mId3 msg3) | rId'' == rId -> do + dec mId3 msg3 `shouldBe` Right "hello 3.1" + pure $ Just $ Just (rKey, rId, mId3) + _ -> pure Nothing, + \case + Resp "" rId'' (Msg mId3 msg3) | rId'' == rId2 -> do + dec2 mId3 msg3 `shouldBe` Right "hello 3.2" + pure $ Just $ Just (rKey2, rId2, mId3) + _ -> pure Nothing + ] + Resp "" NoEntity SALL <- tGet1 sh + Resp "15" _ OK <- signSendRecv sh rKey3_1 ("15", rId3_1, ACK mId3_1) + Resp "16" _ OK <- signSendRecv sh rKey3_2 ("16", rId3_2, ACK mId3_2) + pure () + + Resp "17" _ OK <- signSendRecv h sKey ("17", sId, _SEND "hello 4") + + runSMPClient t $ \sh -> do + Resp "18" _ (ERR SERVICE) <- signSendRecv sh aServicePK ("18", serviceId, SUBS) + (Resp "19" rId' (SOK Nothing), Resp "" rId'' (Msg mId4 msg4)) <- signSendRecv2 sh rKey ("19", rId, SUB) + rId' `shouldBe` rId + rId'' `shouldBe` rId + dec mId4 msg4 `shouldBe` Right "hello 4" + Resp "20" _ OK <- signSendRecv sh rKey ("20", rId, ACK mId4) + Resp "21" _ OK <- signSendRecv h sKey ("21", sId, _SEND "hello 5") + Resp "" _ (Msg mId5 msg5) <- tGet1 sh + dec mId5 msg5 `shouldBe` Right "hello 5" + Resp "22" _ OK <- signSendRecv sh rKey ("22", rId, ACK mId5) + + Resp "23" rId2' (SOK Nothing) <- signSendRecv sh rKey2 ("23", rId2, SUB) + rId2' `shouldBe` rId2 + Resp "24" _ OK <- signSendRecv h sKey ("24", sId, _SEND "hello 6") + Resp "" _ (Msg mId6 msg6) <- tGet1 sh + dec mId6 msg6 `shouldBe` Right "hello 6" + Resp "25" _ OK <- signSendRecv sh rKey ("25", rId, ACK mId6) + pure () + +receiveInAnyOrder :: (HasCallStack, Transport c) => THandleSMP c 'TClient -> [(CorrId, EntityId, Either ErrorType BrokerMsg) -> IO (Maybe b)] -> IO [b] +receiveInAnyOrder h = fmap reverse . go [] + where + go rs [] = pure rs + go rs ps = withFrozenCallStack $ do + r <- 5000000 `timeout` tGet1 h >>= maybe (error "inAnyOrder timeout") pure + (r_, ps') <- foldrM (choose r) (Nothing, []) ps + case r_ of + Just r' -> go (r' : rs) ps' + Nothing -> error $ "unexpected event: " <> show r + choose r p (Nothing, ps') = (maybe (Nothing, p : ps') ((,ps') . Just)) <$> p r + choose _ p (Just r, ps') = pure (Just r, p : ps') + testWithStoreLog :: SpecWith (ASrvTransport, AStoreType) testWithStoreLog = it "should store simplex queues to log and restore them after server restart" $ \(at@(ATransport t), msType) -> do @@ -1159,7 +1366,7 @@ testMessageServiceNotifications = deliverMessage rh rId rKey sh sId sKey nh2 "connection 1" dec deliverMessage rh rId'' rKey'' sh sId'' sKey'' nh2 "connection 2" dec'' -- -- another client makes service subscription - Resp "12" serviceId5 (SOKS 2) <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS) + Resp "12" serviceId5 (SOKS 2 _) <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS) serviceId5 `shouldBe` serviceId Resp "" serviceId6 (ENDS 2) <- tGet1 nh2 serviceId6 `shouldBe` serviceId @@ -1193,7 +1400,7 @@ testServiceNotificationsTwoRestarts = threadDelay 250000 withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> testNtfServiceClient t serviceKeys $ \nh -> do - Resp "2.1" serviceId' (SOKS n) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("2.1", serviceId, NSUBS) + Resp "2.1" serviceId' (SOKS n _) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("2.1", serviceId, NSUBS) n `shouldBe` 1 Resp "2.2" _ (SOK Nothing) <- signSendRecv rh rKey ("2.2", rId, SUB) serviceId' `shouldBe` serviceId @@ -1201,7 +1408,7 @@ testServiceNotificationsTwoRestarts = threadDelay 250000 withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> testNtfServiceClient t serviceKeys $ \nh -> do - Resp "3.1" _ (SOKS n) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("3.1", serviceId, NSUBS) + Resp "3.1" _ (SOKS n _) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("3.1", serviceId, NSUBS) n `shouldBe` 1 Resp "3.2" _ (SOK Nothing) <- signSendRecv rh rKey ("3.2", rId, SUB) deliverMessage rh rId rKey sh sId sKey nh "hello 3" dec From 3ccf8548658d809b0eaaf64c95208cc1b0f7a5ea Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 25 Nov 2025 16:55:59 +0000 Subject: [PATCH 02/11] servers: maintain xor-hash of all associated queue IDs in PostgreSQL (#1668) * servers: maintain xor-hash of all associated queue IDs in PostgreSQL (#1615) * ntf server: maintain xor-hash of all associated queue IDs via PostgreSQL triggers * smp server: xor hash with triggers * fix sql and using pgcrypto extension in tests * track counts and hashes in smp/ntf servers via triggers, smp server stats for service subscription, update SMP protocol to pass expected count and hash in SSUB/NSSUB commands * agent migrations with functions/triggers * remove agent triggers * try tracking service subs in the agent (WIP, does not compile) * Revert "try tracking service subs in the agent (WIP, does not compile)" This reverts commit 59e908100d21ddb6eb95c75d49821d2349fc4d6c. * comment * agent database triggers * service subscriptions in the client * test / fix client services * update schema * fix postgres migration * update schema * move schema test to the end * use static function with SQLite to avoid dynamic wrapper --- simplexmq.cabal | 3 + src/Simplex/Messaging/Agent.hs | 27 +- src/Simplex/Messaging/Agent/Client.hs | 124 +- .../Messaging/Agent/NtfSubSupervisor.hs | 2 +- .../Messaging/Agent/Store/AgentStore.hs | 41 +- .../Agent/Store/Postgres/Migrations/App.hs | 4 +- .../Migrations/M20251020_service_certs.hs | 114 ++ .../Agent/Store/Postgres/Migrations/Util.hs | 46 + .../Migrations/agent_postgres_schema.sql | 1469 +++++++++++++++++ .../Messaging/Agent/Store/Postgres/Util.hs | 112 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 35 +- .../Messaging/Agent/Store/SQLite/Common.hs | 6 + .../Migrations/M20251020_service_certs.hs | 63 +- .../Store/SQLite/Migrations/agent_schema.sql | 52 +- .../Messaging/Agent/Store/SQLite/Util.hs | 41 + src/Simplex/Messaging/Agent/TSessionSubs.hs | 84 +- src/Simplex/Messaging/Client.hs | 10 +- src/Simplex/Messaging/Client/Agent.hs | 57 +- src/Simplex/Messaging/Crypto.hs | 6 +- .../Messaging/Notifications/Protocol.hs | 14 +- src/Simplex/Messaging/Notifications/Server.hs | 19 +- .../Messaging/Notifications/Server/Stats.hs | 1 + .../Notifications/Server/Store/Migrations.hs | 126 +- .../Notifications/Server/Store/Postgres.hs | 35 +- .../Server/Store/ntf_server_schema.sql | 133 +- src/Simplex/Messaging/Protocol.hs | 72 +- src/Simplex/Messaging/Server.hs | 49 +- .../Messaging/Server/MsgStore/Journal.hs | 4 +- src/Simplex/Messaging/Server/Prometheus.hs | 1 + src/Simplex/Messaging/Server/QueueStore.hs | 1 + .../Messaging/Server/QueueStore/Postgres.hs | 18 +- .../Server/QueueStore/Postgres/Migrations.hs | 140 +- .../QueueStore/Postgres/server_schema.sql | 146 +- .../Messaging/Server/QueueStore/STM.hs | 44 +- .../Messaging/Server/QueueStore/Types.hs | 2 +- src/Simplex/Messaging/Server/Stats.hs | 82 +- .../Messaging/Server/StoreLog/ReadWrite.hs | 2 +- tests/AgentTests/EqInstances.hs | 5 + tests/AgentTests/FunctionalAPITests.hs | 28 + tests/CoreTests/TSessionSubs.hs | 24 +- tests/Fixtures.hs | 5 + tests/SMPAgentClient.hs | 3 + tests/ServerTests.hs | 29 +- tests/Test.hs | 19 +- 44 files changed, 2968 insertions(+), 330 deletions(-) create mode 100644 src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251020_service_certs.hs create mode 100644 src/Simplex/Messaging/Agent/Store/Postgres/Migrations/Util.hs create mode 100644 src/Simplex/Messaging/Agent/Store/Postgres/Migrations/agent_postgres_schema.sql create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Util.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index 081c05bca..0eeec3cfd 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -167,6 +167,7 @@ library Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices + Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs else exposed-modules: Simplex.Messaging.Agent.Store.SQLite @@ -217,12 +218,14 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251009_queue_to_subscribe Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251010_client_notices Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs + Simplex.Messaging.Agent.Store.SQLite.Util if flag(client_postgres) || flag(server_postgres) exposed-modules: Simplex.Messaging.Agent.Store.Postgres Simplex.Messaging.Agent.Store.Postgres.Common Simplex.Messaging.Agent.Store.Postgres.DB Simplex.Messaging.Agent.Store.Postgres.Migrations + Simplex.Messaging.Agent.Store.Postgres.Migrations.Util Simplex.Messaging.Agent.Store.Postgres.Util if !flag(client_library) exposed-modules: diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f9f1dc089..63516ada4 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -211,7 +211,6 @@ import Simplex.Messaging.Protocol ErrorType (AUTH), MsgBody, MsgFlags (..), - IdsHash, NtfServer, ProtoServerWithAuth (..), ProtocolServer (..), @@ -222,6 +221,7 @@ import Simplex.Messaging.Protocol SMPMsgMeta, SParty (..), SProtocolType (..), + ServiceSub (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, @@ -500,7 +500,7 @@ resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either Agen resubscribeConnections c = withAgentEnv c . resubscribeConnections' c {-# INLINE resubscribeConnections #-} -subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType (Int64, IdsHash))) +subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSub)) subscribeClientServices c = withAgentEnv c . subscribeClientServices' c {-# INLINE subscribeClientServices #-} @@ -594,6 +594,7 @@ testProtocolServer c nm userId srv = withAgentEnv' c $ case protocolTypeI @p of SPNTF -> runNTFServerTest c nm userId srv -- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network +-- TODO [certs rcv] should fail if any user is enabled to use services and per-connection isolation is chosen setNetworkConfig :: AgentClient -> NetworkConfig -> IO () setNetworkConfig c@AgentClient {useNetworkConfig, proxySessTs} cfg' = do ts <- getCurrentTime @@ -771,6 +772,7 @@ deleteUser' c@AgentClient {smpServersStats, xftpServersStats} userId delSMPQueue whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ DEL_USER userId) +-- TODO [certs rcv] should fail enabling if per-connection isolation is set setUserService' :: AgentClient -> UserId -> Bool -> AM () setUserService' c userId enable = do wasEnabled <- liftIO $ fromMaybe False <$> TM.lookupIO userId (useClientServices c) @@ -1507,15 +1509,15 @@ resubscribeConnections' c connIds = do [] -> pure True rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs' --- TODO [certs rcv] compare hash with lock -subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType (Int64, IdsHash))) +-- TODO [certs rcv] compare hash. possibly, it should return both expected and returned counts +subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSub)) subscribeClientServices' c userId = ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed" where useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c) subscribe = do srvs <- withStore' c (`getClientServiceServers` userId) - lift $ M.fromList . zip srvs <$> mapConcurrently (tryAllErrors' . subscribeClientService c userId) srvs + lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c userId srv n idsHash) srvs -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) @@ -2829,12 +2831,13 @@ processSMPTransmissions :: AgentClient -> ServerTransmissionBatch SMPVersion Err processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId, ts) = do upConnIds <- newTVarIO [] forM_ ts $ \(entId, t) -> case t of - STEvent msgOrErr -> - withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of - Right msg -> runProcessSMP rq conn (toConnData conn) msg - Left e -> lift $ do - processClientNotice rq e - notifyErr connId e + STEvent msgOrErr + | entId == SMP.NoEntity -> pure () -- TODO [certs rcv] process SALL + | otherwise -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of + Right msg -> runProcessSMP rq conn (toConnData conn) msg + Left e -> lift $ do + processClientNotice rq e + notifyErr connId e STResponse (Cmd SRecipient cmd) respOrErr -> withRcvConn entId $ \rq conn -> case cmd of SMP.SUB -> case respOrErr of @@ -2870,7 +2873,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId processSubOk :: RcvQueue -> TVar [ConnId] -> IO () processSubOk rq@RcvQueue {connId} upConnIds = atomically . whenM (isPendingSub rq) $ do - SS.addActiveSub tSess sessId (rcvQueueSub rq) $ currentSubs c + SS.addActiveSub tSess sessId rq $ currentSubs c modifyTVar' upConnIds (connId :) processSubErr :: RcvQueue -> SMPClientError -> AM' () processSubErr rq@RcvQueue {connId} e = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 4a10d07ef..68d7ef62b 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -241,7 +241,7 @@ import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.AgentStore -import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction) +import Simplex.Messaging.Agent.Store.Common (DBStore) import qualified Simplex.Messaging.Agent.Store.DB as DB import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.TSessionSubs (TSessionSubs) @@ -279,6 +279,7 @@ import Simplex.Messaging.Protocol RcvNtfPublicDhKey, SMPMsgMeta (..), SProtocolType (..), + ServiceSub (..), SndPublicAuthKey, SubscriptionMode (..), NewNtfCreds (..), @@ -499,6 +500,7 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther deriving (Eq, Show) -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. +-- TODO [certs rcv] should fail if both per-connection isolation is set and any users use services newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Map (Maybe SMPServer) (Maybe SystemSeconds) -> Env -> IO AgentClient newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, useServices, presetDomains, presetServers} currentTs notices agentEnv = do let cfg = config agentEnv @@ -622,9 +624,8 @@ getServiceCredentials c userId srv = let tlsCreds = tlsCredentials [cred] createClientService db userId srv tlsCreds pure (tlsCreds, Nothing) - (_, pk) <- atomically $ C.generateKeyPair g - let serviceSignKey = C.APrivateSignKey C.SEd25519 pk - creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey} + serviceSignKey <- liftEitherWith INTERNAL $ C.x509ToPrivate' $ snd serviceCreds + let creds = ServiceCredentials {serviceRole = SRMessaging, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey} pure (creds, serviceId_) class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where @@ -744,9 +745,11 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do ts <- readTVarIO proxySessTs ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs + -- TODO [certs rcv] add service to SS, possibly combine with SS.setSessionId atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c updateClientService service smp pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} + -- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queue updateClientService service smp = case (service, smpClientService smp) of (Just (_, serviceId_), Just THClientService {serviceId}) | serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId @@ -763,32 +766,34 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess -- we make active subscriptions pending only if the client for tSess was current (in the map) and active, -- because we can have a race condition when a new current client could have already -- made subscriptions active, and the old client would be processing diconnection later. - removeClientAndSubs :: IO ([RcvQueueSub], [ConnId]) + removeClientAndSubs :: IO ([RcvQueueSub], [ConnId], Maybe ServiceSub) removeClientAndSubs = atomically $ do removeSessVar v tSess smpClients - ifM (readTVar active) removeSubs (pure ([], [])) + ifM (readTVar active) removeSubs (pure ([], [], Nothing)) where sessId = sessionId $ thParams client removeSubs = do mode <- getSessionMode c - subs <- SS.setSubsPending mode tSess sessId $ currentSubs c + (subs, serviceSub_) <- SS.setSubsPending mode tSess sessId $ currentSubs c let qs = M.elems subs cs = nubOrd $ map qConnId qs -- this removes proxied relays that this client created sessions to destSrvs <- M.keys <$> readTVar prs forM_ destSrvs $ \destSrv -> TM.delete (userId, destSrv, cId) smpProxiedRelays - pure (qs, cs) + pure (qs, cs, serviceSub_) - serverDown :: ([RcvQueueSub], [ConnId]) -> IO () - serverDown (qs, conns) = whenM (readTVarIO active) $ do + serverDown :: ([RcvQueueSub], [ConnId], Maybe ServiceSub) -> IO () + serverDown (qs, conns, serviceSub_) = whenM (readTVarIO active) $ do notifySub c $ hostEvent' DISCONNECT client unless (null conns) $ notifySub c $ DOWN srv conns - unless (null qs) $ do + unless (null qs && isNothing serviceSub_) $ do releaseGetLocksIO c qs mode <- getSessionModeIO c let resubscribe | (mode == TSMEntity) == isJust cId = resubscribeSMPSession c tSess - | otherwise = void $ subscribeQueues c True qs + | otherwise = do + mapM_ (runExceptT . resubscribeClientService c tSess) serviceSub_ + unless (null qs) $ void $ subscribeQueues c True qs runReaderT resubscribe env resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' () @@ -807,11 +812,12 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do runSubWorker = do ri <- asks $ reconnectInterval . config withRetryForeground ri isForeground (isNetworkOnline c) $ \_ loop -> do - pending <- atomically $ SS.getPendingSubs tSess $ currentSubs c - unless (M.null pending) $ do + (pendingSubs, pendingSS) <- atomically $ SS.getPendingSubs tSess $ currentSubs c + unless (M.null pendingSubs && isNothing pendingSS) $ do liftIO $ waitUntilForeground c liftIO $ waitForUserNetwork c - handleNotify $ resubscribeSessQueues c tSess $ M.elems pending + mapM_ (handleNotify . void . runExceptT . resubscribeClientService c tSess) pendingSS + unless (M.null pendingSubs) $ handleNotify $ resubscribeSessQueues c tSess $ M.elems pendingSubs loop isForeground = (ASForeground ==) <$> readTVar (agentState c) cleanup :: SessionVar (Async ()) -> STM () @@ -1508,25 +1514,25 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl newErr :: String -> AM (Maybe ShortLinkCreds) newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>) -processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM [(RcvQueueSub, Maybe ClientNotice)] -processSubResults c tSess@(userId, srv, _) sessId rs = do - pendingSubs <- SS.getPendingSubs tSess $ currentSubs c - let (failed, subscribed, notices, ignored) = foldr (partitionResults pendingSubs) (M.empty, [], [], 0) rs +processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> Maybe ServiceId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM ([RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)]) +processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do + pending <- SS.getPendingSubs tSess $ currentSubs c + let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pending) (M.empty, ([], []), [], 0) rs unless (M.null failed) $ do incSMPServerStat' c userId srv connSubErrs $ M.size failed failSubscriptions c tSess failed - unless (null subscribed) $ do - incSMPServerStat' c userId srv connSubscribed $ length subscribed + unless (null qs && null sQs) $ do + incSMPServerStat' c userId srv connSubscribed $ length qs + length sQs SS.batchAddActiveSubs tSess sessId subscribed $ currentSubs c unless (ignored == 0) $ incSMPServerStat' c userId srv connSubIgnored ignored - pure notices + pure (sQs, notices) where partitionResults :: - Map SMP.RecipientId RcvQueueSub -> + (Map SMP.RecipientId RcvQueueSub, Maybe ServiceSub) -> (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> - (Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int) -> - (Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int) - partitionResults pendingSubs (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed, notices, ignored) = case r of + (Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) -> + (Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) + partitionResults (pendingSubs, pendingSS) (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed@(qs, sQs), notices, ignored) = case r of Left e -> case smpErrorClientNotice e of Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored) where @@ -1536,8 +1542,12 @@ processSubResults c tSess@(userId, srv, _) sessId rs = do | otherwise -> (failed', subscribed, notices, ignored) where failed' = M.insert rcvId e failed - Right _serviceId -- TODO [certs rcv] store association with the service - | rcvId `M.member` pendingSubs -> (failed, rq : subscribed, notices', ignored) + Right serviceId_ + | rcvId `M.member` pendingSubs -> + let subscribed' = case (smpServiceId, serviceId_, pendingSS) of + (Just sId, Just sId', Just ServiceSub {serviceId}) | sId == sId' && sId == serviceId -> (qs, rq : sQs) + _ -> (rq : qs, sQs) + in (failed, subscribed', notices', ignored) | otherwise -> (failed, subscribed, notices', ignored + 1) where notices' = if isJust clientNoticeId then (rq, Nothing) : notices else notices @@ -1576,6 +1586,7 @@ serverHostError = \case -- | Batch by transport session and subscribe queues. The list of results can have a different order. subscribeQueues :: AgentClient -> Bool -> [RcvQueueSub] -> AM' [(RcvQueueSub, Either AgentErrorType (Maybe ServiceId))] +subscribeQueues _ _ [] = pure [] subscribeQueues c withEvents qs = do (errs, qs') <- checkQueues c qs atomically $ modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs')) @@ -1632,6 +1643,7 @@ checkQueues c = fmap partitionEithers . mapM checkQueue -- This function expects that all queues belong to one transport session, -- and that they are already added to pending subscriptions. resubscribeSessQueues :: AgentClient -> SMPTransportSession -> [RcvQueueSub] -> AM' () +resubscribeSessQueues _ _ [] = pure () resubscribeSessQueues c tSess qs = do (errs, qs_) <- checkQueues c qs forM_ (L.nonEmpty qs_) $ \qs' -> void $ subscribeSessQueues_ c True (tSess, qs') @@ -1650,13 +1662,15 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c then Just . S.fromList . map qConnId . M.elems <$> atomically (SS.getActiveSubs tSess $ currentSubs c) else pure Nothing active <- E.uninterruptibleMask_ $ do - (active, notices) <- atomically $ do - r@(_, notices) <- ifM + (active, (serviceQs, notices)) <- atomically $ do + r@(_, (_, notices)) <- ifM (activeClientSession c tSess sessId) - ((True,) <$> processSubResults c tSess sessId rs) - ((False, []) <$ incSMPServerStat' c userId srv connSubIgnored (length rs)) + ((True,) <$> processSubResults c tSess sessId smpServiceId rs) + ((False, ([], [])) <$ incSMPServerStat' c userId srv connSubIgnored (length rs)) unless (null notices) $ takeTMVar $ clientNoticesLock c pure r + unless (null serviceQs) $ void $ + processRcvServiceAssocs c serviceQs `runReaderT` agentEnv c unless (null notices) $ void $ (processClientNotices c tSess notices `runReaderT` agentEnv c) `E.finally` atomically (putTMVar (clientNoticesLock c) ()) @@ -1677,6 +1691,13 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c where tSess = transportSession' smp sessId = sessionId $ thParams smp + smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp + +processRcvServiceAssocs :: AgentClient -> [RcvQueueSub] -> AM' () +processRcvServiceAssocs c serviceQs = + withStore' c (`setRcvServiceAssocs` serviceQs) `catchAllErrors'` \e -> do + logError $ "processClientNotices error: " <> tshow e + notifySub' c "" $ ERR e processClientNotices :: AgentClient -> SMPTransportSession -> [(RcvQueueSub, Maybe ClientNotice)] -> AM' () processClientNotices c@AgentClient {presetServers} tSess notices = do @@ -1689,10 +1710,35 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do logError $ "processClientNotices error: " <> tshow e notifySub' c "" $ ERR e -subscribeClientService :: AgentClient -> UserId -> SMPServer -> AM (Int64, IdsHash) -subscribeClientService c userId srv = - withLogClient c NRMBackground (userId, srv, Nothing) B.empty "SUBS" $ - (`subscribeService` SMP.SRecipientService) . connectedClient +resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSub +resubscribeClientService c tSess (ServiceSub _ n idsHash) = + withServiceClient c tSess $ \smp _ -> do + subscribeClientService_ c tSess smp n idsHash + +subscribeClientService :: AgentClient -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSub +subscribeClientService c userId srv n idsHash = + withServiceClient c tSess $ \smp smpServiceId -> do + let serviceSub = ServiceSub smpServiceId n idsHash + atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c + subscribeClientService_ c tSess smp n idsHash + where + tSess = (userId, srv, Nothing) + +withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a +withServiceClient c tSess action = + withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) -> + case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of + Just smpServiceId -> action smp smpServiceId + Nothing -> throwE PCEServiceUnavailable + +subscribeClientService_ :: AgentClient -> SMPTransportSession -> SMPClient -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub +subscribeClientService_ c tSess smp n idsHash = do + -- TODO [certs rcv] handle error + serviceSub' <- subscribeService smp SMP.SRecipientService n idsHash + let sessId = sessionId $ thParams smp + atomically $ whenM (activeClientSession c tSess sessId) $ + SS.setActiveServiceSub tSess sessId serviceSub' $ currentSubs c + pure serviceSub' activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c) @@ -1762,7 +1808,7 @@ addNewQueueSubscription c rq' tSess sessId = do modifyTVar' (subscrConns c) $ S.insert $ qConnId rq active <- activeClientSession c tSess sessId if active - then SS.addActiveSub tSess sessId rq $ currentSubs c + then SS.addActiveSub tSess sessId rq' $ currentSubs c else SS.addPendingSub tSess rq $ currentSubs c pure active unless same $ resubscribeSMPSession c tSess @@ -1951,6 +1997,7 @@ releaseGetLock c rq = {-# INLINE releaseGetLock #-} releaseGetLocksIO :: SomeRcvQueue q => AgentClient -> [q] -> IO () +releaseGetLocksIO _ [] = pure () releaseGetLocksIO c rqs = do locks <- readTVarIO $ getMsgLocks c forM_ rqs $ \rq -> @@ -2301,7 +2348,8 @@ withStore c action = do [ E.Handler $ \(e :: SQL.SQLError) -> let se = SQL.sqlError e busy = se == SQL.ErrorBusy || se == SQL.ErrorLocked - in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ bshow se, + err = tshow se <> ": " <> SQL.sqlErrorDetails e <> ", " <> SQL.sqlErrorContext e + in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ encodeUtf8 err, E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e ] #endif diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index fe852ac64..f5a2b281d 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -314,7 +314,7 @@ runNtfWorker c srv Worker {doWork} = _ -> ((ntfSubConnId sub, INTERNAL "NSACheck - no subscription ID") : errs, subs, subIds) updateSub :: DB.Connection -> NtfServer -> UTCTime -> UTCTime -> (NtfSubscription, NtfSubStatus) -> IO (Maybe SMPServer) updateSub db ntfServer ts nextCheckTs (sub, status) - | ntfShouldSubscribe status = + | status `elem` subscribeNtfStatuses = let sub' = sub {ntfSubStatus = NASCreated status} in Nothing <$ updateNtfSubscription db sub' (NSANtf NSACheck) nextCheckTs -- ntf server stopped subscribing to this queue diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 0b2c632fa..b519f381e 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -53,6 +53,7 @@ module Simplex.Messaging.Agent.Store.AgentStore getSubscriptionServers, getUserServerRcvQueueSubs, unsetQueuesToSubscribe, + setRcvServiceAssocs, getConnIds, getConn, getDeletedConn, @@ -401,29 +402,31 @@ deleteUsersWithoutConns db = do pure userIds createClientService :: DB.Connection -> UserId -> SMPServer -> (C.KeyHash, TLS.Credential) -> IO () -createClientService db userId srv (kh, (cert, pk)) = +createClientService db userId srv (kh, (cert, pk)) = do + serverKeyHash_ <- createServer_ db srv DB.execute db [sql| INSERT INTO client_services - (user_id, host, port, service_cert_hash, service_cert, service_priv_key) - VALUES (?,?,?,?,?,?) - ON CONFLICT (user_id, host, port) + (user_id, host, port, server_key_hash, service_cert_hash, service_cert, service_priv_key) + VALUES (?,?,?,?,?,?,?) + ON CONFLICT (user_id, host, port, server_key_hash) DO UPDATE SET service_cert_hash = EXCLUDED.service_cert_hash, service_cert = EXCLUDED.service_cert, service_priv_key = EXCLUDED.service_priv_key, - rcv_service_id = NULL + service_id = NULL |] - (userId, host srv, port srv, kh, cert, pk) + (userId, host srv, port srv, serverKeyHash_, kh, cert, pk) +-- TODO [certs rcv] get correct service based on key hash of the server getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) getClientService db userId srv = maybeFirstRow toService $ DB.query db [sql| - SELECT service_cert_hash, service_cert, service_priv_key, rcv_service_id + SELECT service_cert_hash, service_cert, service_priv_key, service_id FROM client_services WHERE user_id = ? AND host = ? AND port = ? |] @@ -431,19 +434,21 @@ getClientService db userId srv = where toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) -getClientServiceServers :: DB.Connection -> UserId -> IO [SMPServer] +getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)] getClientServiceServers db userId = map toServer <$> DB.query db [sql| - SELECT c.host, c.port, s.key_hash + SELECT c.host, c.port, s.key_hash, c.service_id, c.service_queue_count, c.service_queue_ids_hash FROM client_services c JOIN servers s ON s.host = c.host AND s.port = c.port + WHERE c.user_id = ? |] (Only userId) where - toServer (host, port, kh) = SMPServer host port kh + toServer (host, port, kh, serviceId, n, Binary idsHash) = + (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO () setClientServiceId db userId srv serviceId = @@ -451,7 +456,7 @@ setClientServiceId db userId srv serviceId = db [sql| UPDATE client_services - SET rcv_service_id = ? + SET service_id = ? WHERE user_id = ? AND host = ? AND port = ? |] (serviceId, userId, host srv, port srv) @@ -2099,7 +2104,7 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} subMode serverKeyHash_ = do ntf_public_key, ntf_private_key, ntf_id, rcv_ntf_dh_secret ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - ( (host server, port server, rcvId, rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) + ( (host server, port server, rcvId, BI rcvServiceAssoc, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, queueMode, status, BI toSubscribe, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_) :. (shortLinkId <$> shortLink, shortLinkKey <$> shortLink, linkPrivSigKey <$> shortLink, linkEncFixedData <$> shortLink) :. ntfCredsFields @@ -2248,6 +2253,14 @@ getUserServerRcvQueueSubs db userId srv onlyNeeded = unsetQueuesToSubscribe :: DB.Connection -> IO () unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" +setRcvServiceAssocs :: DB.Connection -> [RcvQueueSub] -> IO () +setRcvServiceAssocs db rqs = +#if defined(dbPostgres) + DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs) +#else + DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = " $ map (Only . queueId) rqs +#endif + -- * getConn helpers getConnIds :: DB.Connection -> IO [ConnId] @@ -2468,13 +2481,13 @@ rcvQueueQuery = toRcvQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe QueueMode) - :. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, ServiceAssoc) + :. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int, BoolInt) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) :. (Maybe SMP.LinkId, Maybe LinkKey, Maybe C.PrivateKeyEd25519, Maybe EncDataBytes) -> RcvQueue toRcvQueue ( (userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode) - :. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, rcvServiceAssoc) + :. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors, BI rcvServiceAssoc) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_) ) = diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs index 011d89031..41090aa20 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs @@ -10,6 +10,7 @@ import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250322_short_links import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices +import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs import Simplex.Messaging.Agent.Store.Shared (Migration (..)) schemaMigrations :: [(String, Text, Maybe Text)] @@ -19,7 +20,8 @@ schemaMigrations = ("20250322_short_links", m20250322_short_links, Just down_m20250322_short_links), ("20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete), ("20251009_queue_to_subscribe", m20251009_queue_to_subscribe, Just down_m20251009_queue_to_subscribe), - ("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices) + ("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices), + ("20251020_service_certs", m20251020_service_certs, Just down_m20251020_service_certs) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251020_service_certs.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251020_service_certs.hs new file mode 100644 index 000000000..aee45de82 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251020_service_certs.hs @@ -0,0 +1,114 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251020_service_certs where + +import Data.Text (Text) +import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util +import Text.RawString.QQ (r) + +m20251020_service_certs :: Text +m20251020_service_certs = + createXorHashFuncs <> [r| +CREATE TABLE client_services( + user_id BIGINT NOT NULL REFERENCES users ON UPDATE RESTRICT ON DELETE CASCADE, + host TEXT NOT NULL, + port TEXT NOT NULL, + server_key_hash BYTEA, + service_cert BYTEA NOT NULL, + service_cert_hash BYTEA NOT NULL, + service_priv_key BYTEA NOT NULL, + service_id BYTEA, + service_queue_count BIGINT NOT NULL DEFAULT 0, + service_queue_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000', + FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT +); + +CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port, server_key_hash); +CREATE INDEX idx_server_certs_host_port ON client_services(host, port); + +ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc SMALLINT NOT NULL DEFAULT 0; + +CREATE FUNCTION update_aggregates(p_conn_id BYTEA, p_host TEXT, p_port TEXT, p_change BIGINT, p_rcv_id BYTEA) RETURNS VOID +LANGUAGE plpgsql +AS $$ +DECLARE q_user_id BIGINT; +BEGIN + SELECT user_id INTO q_user_id FROM connections WHERE conn_id = p_conn_id; + UPDATE client_services + SET service_queue_count = service_queue_count + p_change, + service_queue_ids_hash = xor_combine(service_queue_ids_hash, public.digest(p_rcv_id, 'md5')) + WHERE user_id = q_user_id AND host = p_host AND port = p_port; +END; +$$; + +CREATE FUNCTION on_rcv_queue_insert() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN + PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id); + END IF; + RETURN NEW; +END; +$$; + +CREATE FUNCTION on_rcv_queue_delete() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN + PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id); + END IF; + RETURN OLD; +END; +$$; + +CREATE FUNCTION on_rcv_queue_update() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN + IF NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0) THEN + PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id); + END IF; + ELSIF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN + PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id); + END IF; + RETURN NEW; +END; +$$; + +CREATE TRIGGER tr_rcv_queue_insert +AFTER INSERT ON rcv_queues +FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_insert(); + +CREATE TRIGGER tr_rcv_queue_delete +AFTER DELETE ON rcv_queues +FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_delete(); + +CREATE TRIGGER tr_rcv_queue_update +AFTER UPDATE ON rcv_queues +FOR EACH ROW EXECUTE PROCEDURE on_rcv_queue_update(); + |] + +down_m20251020_service_certs :: Text +down_m20251020_service_certs = + [r| +DROP TRIGGER tr_rcv_queue_insert ON rcv_queues; +DROP TRIGGER tr_rcv_queue_delete ON rcv_queues; +DROP TRIGGER tr_rcv_queue_update ON rcv_queues; + +DROP FUNCTION on_rcv_queue_insert; +DROP FUNCTION on_rcv_queue_delete; +DROP FUNCTION on_rcv_queue_update; + +DROP FUNCTION update_aggregates; + +ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc; + +DROP INDEX idx_server_certs_host_port; +DROP INDEX idx_server_certs_user_id_host_port; +DROP TABLE client_services; + |] + <> dropXorHashFuncs diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/Util.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/Util.hs new file mode 100644 index 000000000..b51d487e4 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/Util.hs @@ -0,0 +1,46 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.Postgres.Migrations.Util where + +import Data.Text (Text) +import qualified Data.Text as T +import Text.RawString.QQ (r) + +-- xor_combine is only applied to locally computed md5 hashes (128 bits/16 bytes), +-- so it is safe to require that all values are of the same length. +createXorHashFuncs :: Text +createXorHashFuncs = + T.pack + [r| +CREATE OR REPLACE FUNCTION xor_combine(state BYTEA, value BYTEA) RETURNS BYTEA +LANGUAGE plpgsql IMMUTABLE STRICT +AS $$ +DECLARE + result BYTEA := state; + i INTEGER; + len INTEGER := octet_length(value); +BEGIN + IF octet_length(state) != len THEN + RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len; + END IF; + FOR i IN 0..len-1 LOOP + result := set_byte(result, i, get_byte(state, i) # get_byte(value, i)); + END LOOP; + RETURN result; +END; +$$; + +CREATE OR REPLACE AGGREGATE xor_aggregate(BYTEA) ( + SFUNC = xor_combine, + STYPE = BYTEA, + INITCOND = '\x00000000000000000000000000000000' -- 16 bytes +); + |] + +dropXorHashFuncs :: Text +dropXorHashFuncs = + T.pack + [r| +DROP AGGREGATE xor_aggregate(BYTEA); +DROP FUNCTION xor_combine; + |] diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/agent_postgres_schema.sql b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/agent_postgres_schema.sql new file mode 100644 index 000000000..c56efb226 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/agent_postgres_schema.sql @@ -0,0 +1,1469 @@ + + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + + +CREATE SCHEMA smp_agent_test_protocol_schema; + + + +CREATE FUNCTION smp_agent_test_protocol_schema.on_rcv_queue_delete() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN + PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id); + END IF; + RETURN OLD; +END; +$$; + + + +CREATE FUNCTION smp_agent_test_protocol_schema.on_rcv_queue_insert() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN + PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id); + END IF; + RETURN NEW; +END; +$$; + + + +CREATE FUNCTION smp_agent_test_protocol_schema.on_rcv_queue_update() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 THEN + IF NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0) THEN + PERFORM update_aggregates(OLD.conn_id, OLD.host, OLD.port, -1, OLD.rcv_id); + END IF; + ELSIF NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 THEN + PERFORM update_aggregates(NEW.conn_id, NEW.host, NEW.port, 1, NEW.rcv_id); + END IF; + RETURN NEW; +END; +$$; + + + +CREATE FUNCTION smp_agent_test_protocol_schema.update_aggregates(p_conn_id bytea, p_host text, p_port text, p_change bigint, p_rcv_id bytea) RETURNS void + LANGUAGE plpgsql + AS $$ +DECLARE q_user_id BIGINT; +BEGIN + SELECT user_id INTO q_user_id FROM connections WHERE conn_id = p_conn_id; + UPDATE client_services + SET service_queue_count = service_queue_count + p_change, + service_queue_ids_hash = xor_combine(service_queue_ids_hash, public.digest(p_rcv_id, 'md5')) + WHERE user_id = q_user_id AND host = p_host AND port = p_port; +END; +$$; + + + +CREATE FUNCTION smp_agent_test_protocol_schema.xor_combine(state bytea, value bytea) RETURNS bytea + LANGUAGE plpgsql IMMUTABLE STRICT + AS $$ +DECLARE + result BYTEA := state; + i INTEGER; + len INTEGER := octet_length(value); +BEGIN + IF octet_length(state) != len THEN + RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len; + END IF; + FOR i IN 0..len-1 LOOP + result := set_byte(result, i, get_byte(state, i) # get_byte(value, i)); + END LOOP; + RETURN result; +END; +$$; + + + +CREATE AGGREGATE smp_agent_test_protocol_schema.xor_aggregate(bytea) ( + SFUNC = smp_agent_test_protocol_schema.xor_combine, + STYPE = bytea, + INITCOND = '\x00000000000000000000000000000000' +); + + +SET default_table_access_method = heap; + + +CREATE TABLE smp_agent_test_protocol_schema.client_notices ( + client_notice_id bigint NOT NULL, + protocol text NOT NULL, + host text NOT NULL, + port text NOT NULL, + entity_id bytea NOT NULL, + server_key_hash bytea, + notice_ttl bigint, + created_at bigint NOT NULL, + updated_at bigint NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.client_notices ALTER COLUMN client_notice_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.client_notices_client_notice_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.client_services ( + user_id bigint NOT NULL, + host text NOT NULL, + port text NOT NULL, + server_key_hash bytea, + service_cert bytea NOT NULL, + service_cert_hash bytea NOT NULL, + service_priv_key bytea NOT NULL, + service_id bytea, + service_queue_count bigint DEFAULT 0 NOT NULL, + service_queue_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.commands ( + command_id bigint NOT NULL, + conn_id bytea NOT NULL, + host text, + port text, + corr_id bytea NOT NULL, + command_tag bytea NOT NULL, + command bytea NOT NULL, + agent_version integer DEFAULT 1 NOT NULL, + server_key_hash bytea, + created_at timestamp with time zone DEFAULT '1970-01-01 00:00:00+01'::timestamp with time zone NOT NULL, + failed smallint DEFAULT 0 +); + + + +ALTER TABLE smp_agent_test_protocol_schema.commands ALTER COLUMN command_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.commands_command_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.conn_confirmations ( + confirmation_id bytea NOT NULL, + conn_id bytea NOT NULL, + e2e_snd_pub_key bytea NOT NULL, + sender_key bytea, + ratchet_state bytea NOT NULL, + sender_conn_info bytea NOT NULL, + accepted smallint NOT NULL, + own_conn_info bytea, + created_at timestamp with time zone DEFAULT now() NOT NULL, + smp_reply_queues bytea, + smp_client_version integer +); + + + +CREATE TABLE smp_agent_test_protocol_schema.conn_invitations ( + invitation_id bytea NOT NULL, + contact_conn_id bytea, + cr_invitation bytea NOT NULL, + recipient_conn_info bytea NOT NULL, + accepted smallint DEFAULT 0 NOT NULL, + own_conn_info bytea, + created_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.connections ( + conn_id bytea NOT NULL, + conn_mode text NOT NULL, + last_internal_msg_id bigint DEFAULT 0 NOT NULL, + last_internal_rcv_msg_id bigint DEFAULT 0 NOT NULL, + last_internal_snd_msg_id bigint DEFAULT 0 NOT NULL, + last_external_snd_msg_id bigint DEFAULT 0 NOT NULL, + last_rcv_msg_hash bytea DEFAULT '\x'::bytea NOT NULL, + last_snd_msg_hash bytea DEFAULT '\x'::bytea NOT NULL, + smp_agent_version integer DEFAULT 1 NOT NULL, + duplex_handshake smallint DEFAULT 0, + enable_ntfs smallint, + deleted smallint DEFAULT 0 NOT NULL, + user_id bigint NOT NULL, + ratchet_sync_state text DEFAULT 'ok'::text NOT NULL, + deleted_at_wait_delivery timestamp with time zone, + pq_support smallint DEFAULT 0 NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.deleted_snd_chunk_replicas ( + deleted_snd_chunk_replica_id bigint NOT NULL, + user_id bigint NOT NULL, + xftp_server_id bigint NOT NULL, + replica_id bytea NOT NULL, + replica_key bytea NOT NULL, + chunk_digest bytea NOT NULL, + delay bigint, + retries bigint DEFAULT 0 NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + failed smallint DEFAULT 0 +); + + + +ALTER TABLE smp_agent_test_protocol_schema.deleted_snd_chunk_replicas ALTER COLUMN deleted_snd_chunk_replica_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.deleted_snd_chunk_replicas_deleted_snd_chunk_replica_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.encrypted_rcv_message_hashes ( + encrypted_rcv_message_hash_id bigint NOT NULL, + conn_id bytea NOT NULL, + hash bytea NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.encrypted_rcv_message_hashes ALTER COLUMN encrypted_rcv_message_hash_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.encrypted_rcv_message_hashes_encrypted_rcv_message_hash_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.inv_short_links ( + inv_short_link_id bigint NOT NULL, + host text NOT NULL, + port text NOT NULL, + server_key_hash bytea, + link_id bytea NOT NULL, + link_key bytea NOT NULL, + snd_private_key bytea NOT NULL, + snd_id bytea +); + + + +ALTER TABLE smp_agent_test_protocol_schema.inv_short_links ALTER COLUMN inv_short_link_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.inv_short_links_inv_short_link_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.messages ( + conn_id bytea NOT NULL, + internal_id bigint NOT NULL, + internal_ts timestamp with time zone NOT NULL, + internal_rcv_id bigint, + internal_snd_id bigint, + msg_type bytea NOT NULL, + msg_body bytea DEFAULT '\x'::bytea NOT NULL, + msg_flags text, + pq_encryption smallint DEFAULT 0 NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.migrations ( + name text NOT NULL, + ts timestamp without time zone NOT NULL, + down text +); + + + +CREATE TABLE smp_agent_test_protocol_schema.ntf_servers ( + ntf_host text NOT NULL, + ntf_port text NOT NULL, + ntf_key_hash bytea NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.ntf_subscriptions ( + conn_id bytea NOT NULL, + smp_host text, + smp_port text, + smp_ntf_id bytea, + ntf_host text NOT NULL, + ntf_port text NOT NULL, + ntf_sub_id bytea, + ntf_sub_status text NOT NULL, + ntf_sub_action bytea, + ntf_sub_smp_action bytea, + ntf_sub_action_ts timestamp with time zone, + updated_by_supervisor smallint DEFAULT 0 NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + smp_server_key_hash bytea, + ntf_failed smallint DEFAULT 0, + smp_failed smallint DEFAULT 0 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.ntf_tokens ( + provider text NOT NULL, + device_token text NOT NULL, + ntf_host text NOT NULL, + ntf_port text NOT NULL, + tkn_id bytea, + tkn_pub_key bytea NOT NULL, + tkn_priv_key bytea NOT NULL, + tkn_pub_dh_key bytea NOT NULL, + tkn_priv_dh_key bytea NOT NULL, + tkn_dh_secret bytea, + tkn_status text NOT NULL, + tkn_action bytea, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + ntf_mode bytea +); + + + +CREATE TABLE smp_agent_test_protocol_schema.ntf_tokens_to_delete ( + ntf_token_to_delete_id bigint NOT NULL, + ntf_host text NOT NULL, + ntf_port text NOT NULL, + ntf_key_hash bytea NOT NULL, + tkn_id bytea NOT NULL, + tkn_priv_key bytea NOT NULL, + del_failed smallint DEFAULT 0, + created_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.ntf_tokens_to_delete ALTER COLUMN ntf_token_to_delete_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.ntf_tokens_to_delete_ntf_token_to_delete_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.processed_ratchet_key_hashes ( + processed_ratchet_key_hash_id bigint NOT NULL, + conn_id bytea NOT NULL, + hash bytea NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.processed_ratchet_key_hashes ALTER COLUMN processed_ratchet_key_hash_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.processed_ratchet_key_hashes_processed_ratchet_key_hash_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.ratchets ( + conn_id bytea NOT NULL, + x3dh_priv_key_1 bytea, + x3dh_priv_key_2 bytea, + ratchet_state bytea, + e2e_version integer DEFAULT 1 NOT NULL, + x3dh_pub_key_1 bytea, + x3dh_pub_key_2 bytea, + pq_priv_kem bytea, + pq_pub_kem bytea +); + + + +CREATE TABLE smp_agent_test_protocol_schema.rcv_file_chunk_replicas ( + rcv_file_chunk_replica_id bigint NOT NULL, + rcv_file_chunk_id bigint NOT NULL, + replica_number bigint NOT NULL, + xftp_server_id bigint NOT NULL, + replica_id bytea NOT NULL, + replica_key bytea NOT NULL, + received smallint DEFAULT 0 NOT NULL, + delay bigint, + retries bigint DEFAULT 0 NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.rcv_file_chunk_replicas ALTER COLUMN rcv_file_chunk_replica_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.rcv_file_chunk_replicas_rcv_file_chunk_replica_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.rcv_file_chunks ( + rcv_file_chunk_id bigint NOT NULL, + rcv_file_id bigint NOT NULL, + chunk_no bigint NOT NULL, + chunk_size bigint NOT NULL, + digest bytea NOT NULL, + tmp_path text, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.rcv_file_chunks ALTER COLUMN rcv_file_chunk_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.rcv_file_chunks_rcv_file_chunk_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.rcv_files ( + rcv_file_id bigint NOT NULL, + rcv_file_entity_id bytea NOT NULL, + user_id bigint NOT NULL, + size bigint NOT NULL, + digest bytea NOT NULL, + key bytea NOT NULL, + nonce bytea NOT NULL, + chunk_size bigint NOT NULL, + prefix_path text NOT NULL, + tmp_path text, + save_path text NOT NULL, + status text NOT NULL, + deleted smallint DEFAULT 0 NOT NULL, + error text, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + save_file_key bytea, + save_file_nonce bytea, + failed smallint DEFAULT 0, + redirect_id bigint, + redirect_entity_id bytea, + redirect_size bigint, + redirect_digest bytea, + approved_relays smallint DEFAULT 0 NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.rcv_files ALTER COLUMN rcv_file_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.rcv_files_rcv_file_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.rcv_messages ( + conn_id bytea NOT NULL, + internal_rcv_id bigint NOT NULL, + internal_id bigint NOT NULL, + external_snd_id bigint NOT NULL, + broker_id bytea NOT NULL, + broker_ts timestamp with time zone NOT NULL, + internal_hash bytea NOT NULL, + external_prev_snd_hash bytea NOT NULL, + integrity bytea NOT NULL, + user_ack smallint DEFAULT 0, + rcv_queue_id bigint NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.rcv_queues ( + host text NOT NULL, + port text NOT NULL, + rcv_id bytea NOT NULL, + conn_id bytea NOT NULL, + rcv_private_key bytea NOT NULL, + rcv_dh_secret bytea NOT NULL, + e2e_priv_key bytea NOT NULL, + e2e_dh_secret bytea, + snd_id bytea NOT NULL, + snd_key bytea, + status text NOT NULL, + smp_server_version integer DEFAULT 1 NOT NULL, + smp_client_version integer, + ntf_public_key bytea, + ntf_private_key bytea, + ntf_id bytea, + rcv_ntf_dh_secret bytea, + rcv_queue_id bigint NOT NULL, + rcv_primary smallint NOT NULL, + replace_rcv_queue_id bigint, + delete_errors bigint DEFAULT 0 NOT NULL, + server_key_hash bytea, + switch_status text, + deleted smallint DEFAULT 0 NOT NULL, + last_broker_ts timestamp with time zone, + link_id bytea, + link_key bytea, + link_priv_sig_key bytea, + link_enc_fixed_data bytea, + queue_mode text, + to_subscribe smallint DEFAULT 0 NOT NULL, + client_notice_id bigint, + rcv_service_assoc smallint DEFAULT 0 NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.servers ( + host text NOT NULL, + port text NOT NULL, + key_hash bytea NOT NULL +); + + + +CREATE TABLE smp_agent_test_protocol_schema.servers_stats ( + servers_stats_id bigint NOT NULL, + servers_stats text, + started_at timestamp with time zone DEFAULT now() NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.servers_stats ALTER COLUMN servers_stats_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.servers_stats_servers_stats_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.skipped_messages ( + skipped_message_id bigint NOT NULL, + conn_id bytea NOT NULL, + header_key bytea NOT NULL, + msg_n bigint NOT NULL, + msg_key bytea NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.skipped_messages ALTER COLUMN skipped_message_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.skipped_messages_skipped_message_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_file_chunk_replica_recipients ( + snd_file_chunk_replica_recipient_id bigint NOT NULL, + snd_file_chunk_replica_id bigint NOT NULL, + rcv_replica_id bytea NOT NULL, + rcv_replica_key bytea NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.snd_file_chunk_replica_recipients ALTER COLUMN snd_file_chunk_replica_recipient_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.snd_file_chunk_replica_recipi_snd_file_chunk_replica_recipi_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_file_chunk_replicas ( + snd_file_chunk_replica_id bigint NOT NULL, + snd_file_chunk_id bigint NOT NULL, + replica_number bigint NOT NULL, + xftp_server_id bigint NOT NULL, + replica_id bytea NOT NULL, + replica_key bytea NOT NULL, + replica_status text NOT NULL, + delay bigint, + retries bigint DEFAULT 0 NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.snd_file_chunk_replicas ALTER COLUMN snd_file_chunk_replica_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.snd_file_chunk_replicas_snd_file_chunk_replica_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_file_chunks ( + snd_file_chunk_id bigint NOT NULL, + snd_file_id bigint NOT NULL, + chunk_no bigint NOT NULL, + chunk_offset bigint NOT NULL, + chunk_size bigint NOT NULL, + digest bytea NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.snd_file_chunks ALTER COLUMN snd_file_chunk_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.snd_file_chunks_snd_file_chunk_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_files ( + snd_file_id bigint NOT NULL, + snd_file_entity_id bytea NOT NULL, + user_id bigint NOT NULL, + num_recipients bigint NOT NULL, + digest bytea, + key bytea NOT NULL, + nonce bytea NOT NULL, + path text NOT NULL, + prefix_path text, + status text NOT NULL, + deleted smallint DEFAULT 0 NOT NULL, + error text, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + src_file_key bytea, + src_file_nonce bytea, + failed smallint DEFAULT 0, + redirect_size bigint, + redirect_digest bytea +); + + + +ALTER TABLE smp_agent_test_protocol_schema.snd_files ALTER COLUMN snd_file_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.snd_files_snd_file_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_message_bodies ( + snd_message_body_id bigint NOT NULL, + agent_msg bytea DEFAULT '\x'::bytea NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.snd_message_bodies ALTER COLUMN snd_message_body_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.snd_message_bodies_snd_message_body_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_message_deliveries ( + snd_message_delivery_id bigint NOT NULL, + conn_id bytea NOT NULL, + snd_queue_id bigint NOT NULL, + internal_id bigint NOT NULL, + failed smallint DEFAULT 0 +); + + + +ALTER TABLE smp_agent_test_protocol_schema.snd_message_deliveries ALTER COLUMN snd_message_delivery_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.snd_message_deliveries_snd_message_delivery_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_messages ( + conn_id bytea NOT NULL, + internal_snd_id bigint NOT NULL, + internal_id bigint NOT NULL, + internal_hash bytea NOT NULL, + previous_msg_hash bytea DEFAULT '\x'::bytea NOT NULL, + retry_int_slow bigint, + retry_int_fast bigint, + rcpt_internal_id bigint, + rcpt_status text, + msg_encrypt_key bytea, + padded_msg_len bigint, + snd_message_body_id bigint +); + + + +CREATE TABLE smp_agent_test_protocol_schema.snd_queues ( + host text NOT NULL, + port text NOT NULL, + snd_id bytea NOT NULL, + conn_id bytea NOT NULL, + snd_private_key bytea NOT NULL, + e2e_dh_secret bytea NOT NULL, + status text NOT NULL, + smp_server_version integer DEFAULT 1 NOT NULL, + smp_client_version integer DEFAULT 1 NOT NULL, + snd_public_key bytea, + e2e_pub_key bytea, + snd_queue_id bigint NOT NULL, + snd_primary smallint NOT NULL, + replace_snd_queue_id bigint, + server_key_hash bytea, + switch_status text, + queue_mode text +); + + + +CREATE TABLE smp_agent_test_protocol_schema.users ( + user_id bigint NOT NULL, + deleted smallint DEFAULT 0 NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.users ALTER COLUMN user_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.users_user_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +CREATE TABLE smp_agent_test_protocol_schema.xftp_servers ( + xftp_server_id bigint NOT NULL, + xftp_host text NOT NULL, + xftp_port text NOT NULL, + xftp_key_hash bytea NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + + + +ALTER TABLE smp_agent_test_protocol_schema.xftp_servers ALTER COLUMN xftp_server_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_agent_test_protocol_schema.xftp_servers_xftp_server_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.client_notices + ADD CONSTRAINT client_notices_pkey PRIMARY KEY (client_notice_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.commands + ADD CONSTRAINT commands_pkey PRIMARY KEY (command_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.conn_confirmations + ADD CONSTRAINT conn_confirmations_pkey PRIMARY KEY (confirmation_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.conn_invitations + ADD CONSTRAINT conn_invitations_pkey PRIMARY KEY (invitation_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.connections + ADD CONSTRAINT connections_pkey PRIMARY KEY (conn_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.deleted_snd_chunk_replicas + ADD CONSTRAINT deleted_snd_chunk_replicas_pkey PRIMARY KEY (deleted_snd_chunk_replica_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.encrypted_rcv_message_hashes + ADD CONSTRAINT encrypted_rcv_message_hashes_pkey PRIMARY KEY (encrypted_rcv_message_hash_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.inv_short_links + ADD CONSTRAINT inv_short_links_pkey PRIMARY KEY (inv_short_link_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.messages + ADD CONSTRAINT messages_pkey PRIMARY KEY (conn_id, internal_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.migrations + ADD CONSTRAINT migrations_pkey PRIMARY KEY (name); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_servers + ADD CONSTRAINT ntf_servers_pkey PRIMARY KEY (ntf_host, ntf_port); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_subscriptions + ADD CONSTRAINT ntf_subscriptions_pkey PRIMARY KEY (conn_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_tokens + ADD CONSTRAINT ntf_tokens_pkey PRIMARY KEY (provider, device_token, ntf_host, ntf_port); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_tokens_to_delete + ADD CONSTRAINT ntf_tokens_to_delete_pkey PRIMARY KEY (ntf_token_to_delete_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.processed_ratchet_key_hashes + ADD CONSTRAINT processed_ratchet_key_hashes_pkey PRIMARY KEY (processed_ratchet_key_hash_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ratchets + ADD CONSTRAINT ratchets_pkey PRIMARY KEY (conn_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_file_chunk_replicas + ADD CONSTRAINT rcv_file_chunk_replicas_pkey PRIMARY KEY (rcv_file_chunk_replica_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_file_chunks + ADD CONSTRAINT rcv_file_chunks_pkey PRIMARY KEY (rcv_file_chunk_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_files + ADD CONSTRAINT rcv_files_pkey PRIMARY KEY (rcv_file_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_files + ADD CONSTRAINT rcv_files_rcv_file_entity_id_key UNIQUE (rcv_file_entity_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_messages + ADD CONSTRAINT rcv_messages_pkey PRIMARY KEY (conn_id, internal_rcv_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_queues + ADD CONSTRAINT rcv_queues_host_port_snd_id_key UNIQUE (host, port, snd_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_queues + ADD CONSTRAINT rcv_queues_pkey PRIMARY KEY (host, port, rcv_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.servers + ADD CONSTRAINT servers_pkey PRIMARY KEY (host, port); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.servers_stats + ADD CONSTRAINT servers_stats_pkey PRIMARY KEY (servers_stats_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.skipped_messages + ADD CONSTRAINT skipped_messages_pkey PRIMARY KEY (skipped_message_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunk_replica_recipients + ADD CONSTRAINT snd_file_chunk_replica_recipients_pkey PRIMARY KEY (snd_file_chunk_replica_recipient_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunk_replicas + ADD CONSTRAINT snd_file_chunk_replicas_pkey PRIMARY KEY (snd_file_chunk_replica_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunks + ADD CONSTRAINT snd_file_chunks_pkey PRIMARY KEY (snd_file_chunk_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_files + ADD CONSTRAINT snd_files_pkey PRIMARY KEY (snd_file_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_message_bodies + ADD CONSTRAINT snd_message_bodies_pkey PRIMARY KEY (snd_message_body_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_message_deliveries + ADD CONSTRAINT snd_message_deliveries_pkey PRIMARY KEY (snd_message_delivery_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_messages + ADD CONSTRAINT snd_messages_pkey PRIMARY KEY (conn_id, internal_snd_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_queues + ADD CONSTRAINT snd_queues_pkey PRIMARY KEY (host, port, snd_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.users + ADD CONSTRAINT users_pkey PRIMARY KEY (user_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.xftp_servers + ADD CONSTRAINT xftp_servers_pkey PRIMARY KEY (xftp_server_id); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.xftp_servers + ADD CONSTRAINT xftp_servers_xftp_host_xftp_port_xftp_key_hash_key UNIQUE (xftp_host, xftp_port, xftp_key_hash); + + + +CREATE UNIQUE INDEX idx_client_notices_entity ON smp_agent_test_protocol_schema.client_notices USING btree (protocol, host, port, entity_id); + + + +CREATE INDEX idx_commands_conn_id ON smp_agent_test_protocol_schema.commands USING btree (conn_id); + + + +CREATE INDEX idx_commands_host_port ON smp_agent_test_protocol_schema.commands USING btree (host, port); + + + +CREATE INDEX idx_commands_server_commands ON smp_agent_test_protocol_schema.commands USING btree (host, port, created_at, command_id); + + + +CREATE INDEX idx_conn_confirmations_conn_id ON smp_agent_test_protocol_schema.conn_confirmations USING btree (conn_id); + + + +CREATE INDEX idx_conn_invitations_contact_conn_id ON smp_agent_test_protocol_schema.conn_invitations USING btree (contact_conn_id); + + + +CREATE INDEX idx_connections_user ON smp_agent_test_protocol_schema.connections USING btree (user_id); + + + +CREATE INDEX idx_deleted_snd_chunk_replicas_pending ON smp_agent_test_protocol_schema.deleted_snd_chunk_replicas USING btree (created_at); + + + +CREATE INDEX idx_deleted_snd_chunk_replicas_user_id ON smp_agent_test_protocol_schema.deleted_snd_chunk_replicas USING btree (user_id); + + + +CREATE INDEX idx_deleted_snd_chunk_replicas_xftp_server_id ON smp_agent_test_protocol_schema.deleted_snd_chunk_replicas USING btree (xftp_server_id); + + + +CREATE INDEX idx_encrypted_rcv_message_hashes_created_at ON smp_agent_test_protocol_schema.encrypted_rcv_message_hashes USING btree (created_at); + + + +CREATE INDEX idx_encrypted_rcv_message_hashes_hash ON smp_agent_test_protocol_schema.encrypted_rcv_message_hashes USING btree (conn_id, hash); + + + +CREATE UNIQUE INDEX idx_inv_short_links_link_id ON smp_agent_test_protocol_schema.inv_short_links USING btree (host, port, link_id); + + + +CREATE INDEX idx_messages_conn_id ON smp_agent_test_protocol_schema.messages USING btree (conn_id); + + + +CREATE INDEX idx_messages_conn_id_internal_rcv_id ON smp_agent_test_protocol_schema.messages USING btree (conn_id, internal_rcv_id); + + + +CREATE INDEX idx_messages_conn_id_internal_snd_id ON smp_agent_test_protocol_schema.messages USING btree (conn_id, internal_snd_id); + + + +CREATE INDEX idx_messages_internal_ts ON smp_agent_test_protocol_schema.messages USING btree (internal_ts); + + + +CREATE INDEX idx_messages_snd_expired ON smp_agent_test_protocol_schema.messages USING btree (conn_id, internal_snd_id, internal_ts); + + + +CREATE INDEX idx_ntf_subscriptions_ntf_host_ntf_port ON smp_agent_test_protocol_schema.ntf_subscriptions USING btree (ntf_host, ntf_port); + + + +CREATE INDEX idx_ntf_subscriptions_smp_host_smp_port ON smp_agent_test_protocol_schema.ntf_subscriptions USING btree (smp_host, smp_port); + + + +CREATE INDEX idx_ntf_tokens_ntf_host_ntf_port ON smp_agent_test_protocol_schema.ntf_tokens USING btree (ntf_host, ntf_port); + + + +CREATE INDEX idx_processed_ratchet_key_hashes_created_at ON smp_agent_test_protocol_schema.processed_ratchet_key_hashes USING btree (created_at); + + + +CREATE INDEX idx_processed_ratchet_key_hashes_hash ON smp_agent_test_protocol_schema.processed_ratchet_key_hashes USING btree (conn_id, hash); + + + +CREATE INDEX idx_ratchets_conn_id ON smp_agent_test_protocol_schema.ratchets USING btree (conn_id); + + + +CREATE INDEX idx_rcv_file_chunk_replicas_pending ON smp_agent_test_protocol_schema.rcv_file_chunk_replicas USING btree (received, replica_number); + + + +CREATE INDEX idx_rcv_file_chunk_replicas_rcv_file_chunk_id ON smp_agent_test_protocol_schema.rcv_file_chunk_replicas USING btree (rcv_file_chunk_id); + + + +CREATE INDEX idx_rcv_file_chunk_replicas_xftp_server_id ON smp_agent_test_protocol_schema.rcv_file_chunk_replicas USING btree (xftp_server_id); + + + +CREATE INDEX idx_rcv_file_chunks_rcv_file_id ON smp_agent_test_protocol_schema.rcv_file_chunks USING btree (rcv_file_id); + + + +CREATE INDEX idx_rcv_files_redirect_id ON smp_agent_test_protocol_schema.rcv_files USING btree (redirect_id); + + + +CREATE INDEX idx_rcv_files_status_created_at ON smp_agent_test_protocol_schema.rcv_files USING btree (status, created_at); + + + +CREATE INDEX idx_rcv_files_user_id ON smp_agent_test_protocol_schema.rcv_files USING btree (user_id); + + + +CREATE INDEX idx_rcv_messages_conn_id_internal_id ON smp_agent_test_protocol_schema.rcv_messages USING btree (conn_id, internal_id); + + + +CREATE UNIQUE INDEX idx_rcv_queue_id ON smp_agent_test_protocol_schema.rcv_queues USING btree (conn_id, rcv_queue_id); + + + +CREATE INDEX idx_rcv_queues_client_notice_id ON smp_agent_test_protocol_schema.rcv_queues USING btree (client_notice_id); + + + +CREATE UNIQUE INDEX idx_rcv_queues_link_id ON smp_agent_test_protocol_schema.rcv_queues USING btree (host, port, link_id); + + + +CREATE UNIQUE INDEX idx_rcv_queues_ntf ON smp_agent_test_protocol_schema.rcv_queues USING btree (host, port, ntf_id); + + + +CREATE INDEX idx_rcv_queues_to_subscribe ON smp_agent_test_protocol_schema.rcv_queues USING btree (to_subscribe); + + + +CREATE INDEX idx_server_certs_host_port ON smp_agent_test_protocol_schema.client_services USING btree (host, port); + + + +CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON smp_agent_test_protocol_schema.client_services USING btree (user_id, host, port, server_key_hash); + + + +CREATE INDEX idx_skipped_messages_conn_id ON smp_agent_test_protocol_schema.skipped_messages USING btree (conn_id); + + + +CREATE INDEX idx_snd_file_chunk_replica_recipients_snd_file_chunk_replica_id ON smp_agent_test_protocol_schema.snd_file_chunk_replica_recipients USING btree (snd_file_chunk_replica_id); + + + +CREATE INDEX idx_snd_file_chunk_replicas_pending ON smp_agent_test_protocol_schema.snd_file_chunk_replicas USING btree (replica_status, replica_number); + + + +CREATE INDEX idx_snd_file_chunk_replicas_snd_file_chunk_id ON smp_agent_test_protocol_schema.snd_file_chunk_replicas USING btree (snd_file_chunk_id); + + + +CREATE INDEX idx_snd_file_chunk_replicas_xftp_server_id ON smp_agent_test_protocol_schema.snd_file_chunk_replicas USING btree (xftp_server_id); + + + +CREATE INDEX idx_snd_file_chunks_snd_file_id ON smp_agent_test_protocol_schema.snd_file_chunks USING btree (snd_file_id); + + + +CREATE INDEX idx_snd_files_snd_file_entity_id ON smp_agent_test_protocol_schema.snd_files USING btree (snd_file_entity_id); + + + +CREATE INDEX idx_snd_files_status_created_at ON smp_agent_test_protocol_schema.snd_files USING btree (status, created_at); + + + +CREATE INDEX idx_snd_files_user_id ON smp_agent_test_protocol_schema.snd_files USING btree (user_id); + + + +CREATE INDEX idx_snd_message_deliveries ON smp_agent_test_protocol_schema.snd_message_deliveries USING btree (conn_id, snd_queue_id); + + + +CREATE INDEX idx_snd_message_deliveries_conn_id_internal_id ON smp_agent_test_protocol_schema.snd_message_deliveries USING btree (conn_id, internal_id); + + + +CREATE INDEX idx_snd_message_deliveries_expired ON smp_agent_test_protocol_schema.snd_message_deliveries USING btree (conn_id, snd_queue_id, failed, internal_id); + + + +CREATE INDEX idx_snd_messages_conn_id_internal_id ON smp_agent_test_protocol_schema.snd_messages USING btree (conn_id, internal_id); + + + +CREATE INDEX idx_snd_messages_rcpt_internal_id ON smp_agent_test_protocol_schema.snd_messages USING btree (conn_id, rcpt_internal_id); + + + +CREATE INDEX idx_snd_messages_snd_message_body_id ON smp_agent_test_protocol_schema.snd_messages USING btree (snd_message_body_id); + + + +CREATE UNIQUE INDEX idx_snd_queue_id ON smp_agent_test_protocol_schema.snd_queues USING btree (conn_id, snd_queue_id); + + + +CREATE INDEX idx_snd_queues_host_port ON smp_agent_test_protocol_schema.snd_queues USING btree (host, port); + + + +CREATE TRIGGER tr_rcv_queue_delete AFTER DELETE ON smp_agent_test_protocol_schema.rcv_queues FOR EACH ROW EXECUTE FUNCTION smp_agent_test_protocol_schema.on_rcv_queue_delete(); + + + +CREATE TRIGGER tr_rcv_queue_insert AFTER INSERT ON smp_agent_test_protocol_schema.rcv_queues FOR EACH ROW EXECUTE FUNCTION smp_agent_test_protocol_schema.on_rcv_queue_insert(); + + + +CREATE TRIGGER tr_rcv_queue_update AFTER UPDATE ON smp_agent_test_protocol_schema.rcv_queues FOR EACH ROW EXECUTE FUNCTION smp_agent_test_protocol_schema.on_rcv_queue_update(); + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.client_services + ADD CONSTRAINT client_services_host_port_fkey FOREIGN KEY (host, port) REFERENCES smp_agent_test_protocol_schema.servers(host, port) ON DELETE RESTRICT; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.client_services + ADD CONSTRAINT client_services_user_id_fkey FOREIGN KEY (user_id) REFERENCES smp_agent_test_protocol_schema.users(user_id) ON UPDATE RESTRICT ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.commands + ADD CONSTRAINT commands_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.commands + ADD CONSTRAINT commands_host_port_fkey FOREIGN KEY (host, port) REFERENCES smp_agent_test_protocol_schema.servers(host, port) ON UPDATE CASCADE ON DELETE RESTRICT; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.conn_confirmations + ADD CONSTRAINT conn_confirmations_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.conn_invitations + ADD CONSTRAINT conn_invitations_contact_conn_id_fkey FOREIGN KEY (contact_conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE SET NULL; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.connections + ADD CONSTRAINT connections_user_id_fkey FOREIGN KEY (user_id) REFERENCES smp_agent_test_protocol_schema.users(user_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.deleted_snd_chunk_replicas + ADD CONSTRAINT deleted_snd_chunk_replicas_user_id_fkey FOREIGN KEY (user_id) REFERENCES smp_agent_test_protocol_schema.users(user_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.deleted_snd_chunk_replicas + ADD CONSTRAINT deleted_snd_chunk_replicas_xftp_server_id_fkey FOREIGN KEY (xftp_server_id) REFERENCES smp_agent_test_protocol_schema.xftp_servers(xftp_server_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.encrypted_rcv_message_hashes + ADD CONSTRAINT encrypted_rcv_message_hashes_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.messages + ADD CONSTRAINT fk_messages_rcv_messages FOREIGN KEY (conn_id, internal_rcv_id) REFERENCES smp_agent_test_protocol_schema.rcv_messages(conn_id, internal_rcv_id) ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.messages + ADD CONSTRAINT fk_messages_snd_messages FOREIGN KEY (conn_id, internal_snd_id) REFERENCES smp_agent_test_protocol_schema.snd_messages(conn_id, internal_snd_id) ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.inv_short_links + ADD CONSTRAINT inv_short_links_host_port_fkey FOREIGN KEY (host, port) REFERENCES smp_agent_test_protocol_schema.servers(host, port) ON UPDATE CASCADE ON DELETE RESTRICT; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.messages + ADD CONSTRAINT messages_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_subscriptions + ADD CONSTRAINT ntf_subscriptions_ntf_host_ntf_port_fkey FOREIGN KEY (ntf_host, ntf_port) REFERENCES smp_agent_test_protocol_schema.ntf_servers(ntf_host, ntf_port) ON UPDATE CASCADE ON DELETE RESTRICT; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_subscriptions + ADD CONSTRAINT ntf_subscriptions_smp_host_smp_port_fkey FOREIGN KEY (smp_host, smp_port) REFERENCES smp_agent_test_protocol_schema.servers(host, port) ON UPDATE CASCADE ON DELETE SET NULL; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ntf_tokens + ADD CONSTRAINT ntf_tokens_ntf_host_ntf_port_fkey FOREIGN KEY (ntf_host, ntf_port) REFERENCES smp_agent_test_protocol_schema.ntf_servers(ntf_host, ntf_port) ON UPDATE CASCADE ON DELETE RESTRICT; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.processed_ratchet_key_hashes + ADD CONSTRAINT processed_ratchet_key_hashes_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.ratchets + ADD CONSTRAINT ratchets_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_file_chunk_replicas + ADD CONSTRAINT rcv_file_chunk_replicas_rcv_file_chunk_id_fkey FOREIGN KEY (rcv_file_chunk_id) REFERENCES smp_agent_test_protocol_schema.rcv_file_chunks(rcv_file_chunk_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_file_chunk_replicas + ADD CONSTRAINT rcv_file_chunk_replicas_xftp_server_id_fkey FOREIGN KEY (xftp_server_id) REFERENCES smp_agent_test_protocol_schema.xftp_servers(xftp_server_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_file_chunks + ADD CONSTRAINT rcv_file_chunks_rcv_file_id_fkey FOREIGN KEY (rcv_file_id) REFERENCES smp_agent_test_protocol_schema.rcv_files(rcv_file_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_files + ADD CONSTRAINT rcv_files_redirect_id_fkey FOREIGN KEY (redirect_id) REFERENCES smp_agent_test_protocol_schema.rcv_files(rcv_file_id) ON DELETE SET NULL; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_files + ADD CONSTRAINT rcv_files_user_id_fkey FOREIGN KEY (user_id) REFERENCES smp_agent_test_protocol_schema.users(user_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_messages + ADD CONSTRAINT rcv_messages_conn_id_internal_id_fkey FOREIGN KEY (conn_id, internal_id) REFERENCES smp_agent_test_protocol_schema.messages(conn_id, internal_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_queues + ADD CONSTRAINT rcv_queues_client_notice_id_fkey FOREIGN KEY (client_notice_id) REFERENCES smp_agent_test_protocol_schema.client_notices(client_notice_id) ON UPDATE RESTRICT ON DELETE SET NULL; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_queues + ADD CONSTRAINT rcv_queues_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.rcv_queues + ADD CONSTRAINT rcv_queues_host_port_fkey FOREIGN KEY (host, port) REFERENCES smp_agent_test_protocol_schema.servers(host, port) ON UPDATE CASCADE ON DELETE RESTRICT; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.skipped_messages + ADD CONSTRAINT skipped_messages_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.ratchets(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunk_replica_recipients + ADD CONSTRAINT snd_file_chunk_replica_recipient_snd_file_chunk_replica_id_fkey FOREIGN KEY (snd_file_chunk_replica_id) REFERENCES smp_agent_test_protocol_schema.snd_file_chunk_replicas(snd_file_chunk_replica_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunk_replicas + ADD CONSTRAINT snd_file_chunk_replicas_snd_file_chunk_id_fkey FOREIGN KEY (snd_file_chunk_id) REFERENCES smp_agent_test_protocol_schema.snd_file_chunks(snd_file_chunk_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunk_replicas + ADD CONSTRAINT snd_file_chunk_replicas_xftp_server_id_fkey FOREIGN KEY (xftp_server_id) REFERENCES smp_agent_test_protocol_schema.xftp_servers(xftp_server_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_file_chunks + ADD CONSTRAINT snd_file_chunks_snd_file_id_fkey FOREIGN KEY (snd_file_id) REFERENCES smp_agent_test_protocol_schema.snd_files(snd_file_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_files + ADD CONSTRAINT snd_files_user_id_fkey FOREIGN KEY (user_id) REFERENCES smp_agent_test_protocol_schema.users(user_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_message_deliveries + ADD CONSTRAINT snd_message_deliveries_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_message_deliveries + ADD CONSTRAINT snd_message_deliveries_conn_id_internal_id_fkey FOREIGN KEY (conn_id, internal_id) REFERENCES smp_agent_test_protocol_schema.messages(conn_id, internal_id) ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_messages + ADD CONSTRAINT snd_messages_conn_id_internal_id_fkey FOREIGN KEY (conn_id, internal_id) REFERENCES smp_agent_test_protocol_schema.messages(conn_id, internal_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_messages + ADD CONSTRAINT snd_messages_snd_message_body_id_fkey FOREIGN KEY (snd_message_body_id) REFERENCES smp_agent_test_protocol_schema.snd_message_bodies(snd_message_body_id) ON DELETE SET NULL; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_queues + ADD CONSTRAINT snd_queues_conn_id_fkey FOREIGN KEY (conn_id) REFERENCES smp_agent_test_protocol_schema.connections(conn_id) ON DELETE CASCADE; + + + +ALTER TABLE ONLY smp_agent_test_protocol_schema.snd_queues + ADD CONSTRAINT snd_queues_host_port_fkey FOREIGN KEY (host, port) REFERENCES smp_agent_test_protocol_schema.servers(host, port) ON UPDATE CASCADE ON DELETE RESTRICT; + + + diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs index 0913c76e3..bcbb0e281 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs @@ -21,30 +21,32 @@ import Database.PostgreSQL.Simple.SqlQQ (sql) createDBAndUserIfNotExists :: ConnectInfo -> IO () createDBAndUserIfNotExists ConnectInfo {connectUser = user, connectDatabase = dbName} = do -- connect to the default "postgres" maintenance database - bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ - \postgresDB -> do - void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING" - -- check if the user exists, create if not - [Only userExists] <- - PSQL.query - postgresDB - [sql| - SELECT EXISTS ( - SELECT 1 FROM pg_catalog.pg_roles - WHERE rolname = ? - ) - |] - (Only user) - unless userExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE USER " <> user) - -- check if the database exists, create if not - dbExists <- checkDBExists postgresDB dbName - unless dbExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE DATABASE " <> dbName <> " OWNER " <> user) + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ \db -> do + execSQL db "SET client_min_messages TO WARNING" + -- check if the user exists, create if not + [Only userExists] <- + PSQL.query + db + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_roles + WHERE rolname = ? + ) + |] + (Only user) + unless userExists $ execSQL db $ "CREATE USER " <> user + -- check if the database exists, create if not + dbExists <- checkDBExists db dbName + unless dbExists $ do + execSQL db $ "CREATE DATABASE " <> dbName <> " OWNER " <> user + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = dbName}) PSQL.close $ + (`execSQL` "CREATE EXTENSION IF NOT EXISTS pgcrypto") checkDBExists :: PSQL.Connection -> String -> IO Bool -checkDBExists postgresDB dbName = do +checkDBExists db dbName = do [Only dbExists] <- PSQL.query - postgresDB + db [sql| SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database @@ -56,45 +58,45 @@ checkDBExists postgresDB dbName = do dropSchema :: ConnectInfo -> String -> IO () dropSchema connectInfo schema = - bracket (PSQL.connect connectInfo) PSQL.close $ - \db -> do - void $ PSQL.execute_ db "SET client_min_messages TO WARNING" - void $ PSQL.execute_ db (fromString $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE") + bracket (PSQL.connect connectInfo) PSQL.close $ \db -> do + execSQL db "SET client_min_messages TO WARNING" + execSQL db $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE" dropAllSchemasExceptSystem :: ConnectInfo -> IO () dropAllSchemasExceptSystem connectInfo = - bracket (PSQL.connect connectInfo) PSQL.close $ - \db -> do - void $ PSQL.execute_ db "SET client_min_messages TO WARNING" - schemaNames :: [Only String] <- - PSQL.query_ + bracket (PSQL.connect connectInfo) PSQL.close $ \db -> do + execSQL db "SET client_min_messages TO WARNING" + schemaNames :: [Only String] <- + PSQL.query_ + db + [sql| + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema') + |] + forM_ schemaNames $ \(Only schema) -> + execSQL db $ "DROP SCHEMA " <> schema <> " CASCADE" + +dropDatabaseAndUser :: ConnectInfo -> IO () +dropDatabaseAndUser ConnectInfo {connectUser = user, connectDatabase = dbName} = + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ \db -> do + execSQL db "SET client_min_messages TO WARNING" + dbExists <- checkDBExists db dbName + when dbExists $ do + execSQL db $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false" + -- terminate all connections to the database + _r :: [Only Bool] <- + PSQL.query db [sql| - SELECT schema_name - FROM information_schema.schemata - WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema') + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE datname = ? + AND pid <> pg_backend_pid() |] - forM_ schemaNames $ \(Only schema) -> - PSQL.execute_ db (fromString $ "DROP SCHEMA " <> schema <> " CASCADE") + (Only dbName) + execSQL db $ "DROP DATABASE " <> dbName + execSQL db $ "DROP USER IF EXISTS " <> user -dropDatabaseAndUser :: ConnectInfo -> IO () -dropDatabaseAndUser ConnectInfo {connectUser = user, connectDatabase = dbName} = - bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ - \postgresDB -> do - void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING" - dbExists <- checkDBExists postgresDB dbName - when dbExists $ do - void $ PSQL.execute_ postgresDB (fromString $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false") - -- terminate all connections to the database - _r :: [Only Bool] <- - PSQL.query - postgresDB - [sql| - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE datname = ? - AND pid <> pg_backend_pid() - |] - (Only dbName) - void $ PSQL.execute_ postgresDB (fromString $ "DROP DATABASE " <> dbName) - void $ PSQL.execute_ postgresDB (fromString $ "DROP USER IF EXISTS " <> user) +execSQL :: PSQL.Connection -> String -> IO () +execSQL db = void . PSQL.execute_ db . fromString diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 688eae0d2..45c1f26ad 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -42,9 +42,15 @@ module Simplex.Messaging.Agent.Store.SQLite ) where +import Control.Concurrent.MVar +import Control.Concurrent.STM +import Control.Exception (bracketOnError, onException, throwIO) import Control.Monad +import Data.Bits (xor) import Data.ByteArray (ScrubbedBytes) import qualified Data.ByteArray as BA +import Data.ByteString (ByteString) +import qualified Data.ByteString as B import Data.Functor (($>)) import Data.IORef import Data.Maybe (fromMaybe) @@ -54,17 +60,19 @@ import Database.SQLite.Simple (Query (..)) import qualified Database.SQLite.Simple as SQL import Database.SQLite.Simple.QQ (sql) import qualified Database.SQLite3 as SQLite3 +import Database.SQLite3.Bindings +import Foreign.C.Types +import Foreign.Ptr import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSchema) import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Agent.Store.SQLite.Common import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..)) +import Simplex.Messaging.Agent.Store.SQLite.Util (SQLiteFunc, createStaticFunction, mkSQLiteFunc) +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Util (ifM, safeDecodeUtf8) import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.FilePath (takeDirectory, takeFileName, ()) -import UnliftIO.Exception (bracketOnError, onException) -import UnliftIO.MVar -import UnliftIO.STM -- * SQLite Store implementation @@ -109,9 +117,9 @@ connectDB path key track = do pure db where prepare db = do - let exec = SQLite3.exec $ SQL.connectionHandle $ DB.conn db - unless (BA.null key) . exec $ "PRAGMA key = " <> keyString key <> ";" - exec . fromQuery $ + let db' = SQL.connectionHandle $ DB.conn db + unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";" + SQLite3.exec db' . fromQuery $ [sql| PRAGMA busy_timeout = 100; PRAGMA foreign_keys = ON; @@ -119,6 +127,21 @@ connectDB path key track = do PRAGMA secure_delete = ON; PRAGMA auto_vacuum = FULL; |] + createStaticFunction db' "simplex_xor_md5_combine" 2 True sqliteXorMd5CombinePtr + >>= either (throwIO . userError . show) pure + +foreign export ccall "simplex_xor_md5_combine" sqliteXorMd5Combine :: SQLiteFunc + +foreign import ccall "&simplex_xor_md5_combine" sqliteXorMd5CombinePtr :: FunPtr SQLiteFunc + +sqliteXorMd5Combine :: SQLiteFunc +sqliteXorMd5Combine = mkSQLiteFunc $ \cxt args -> do + idsHash <- SQLite3.funcArgBlob args 0 + rId <- SQLite3.funcArgBlob args 1 + SQLite3.funcResultBlob cxt $ xorMd5Combine idsHash rId + +xorMd5Combine :: ByteString -> ByteString -> ByteString +xorMd5Combine idsHash rId = B.packZipWith xor idsHash $ C.md5Hash rId closeDBStore :: DBStore -> IO () closeDBStore st@DBStore {dbClosed} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 3800dc362..af70c41f5 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -53,6 +53,12 @@ withConnectionPriority DBStore {dbSem, dbConnection} priority action | priority = E.bracket_ signal release $ withMVar dbConnection action | otherwise = lowPriority where + -- To debug FK errors, set foreign_keys = OFF in Simplex.Messaging.Agent.Store.SQLite and use action' instead of action + -- action' conn = do + -- r <- action conn + -- violations <- DB.query_ conn "PRAGMA foreign_key_check" :: IO [ (String, Int, String, Int)] + -- unless (null violations) $ print violations + -- pure r lowPriority = wait >> withMVar dbConnection (\db -> ifM free (Just <$> action db) (pure Nothing)) >>= maybe lowPriority pure signal = atomically $ modifyTVar' dbSem (+ 1) release = atomically $ modifyTVar' dbSem $ \sem -> if sem > 0 then sem - 1 else 0 diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs index 780ced1d4..ee6a0095a 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251020_service_certs.hs @@ -5,7 +5,6 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251020_service_certs w import Database.SQLite.Simple (Query) import Database.SQLite.Simple.QQ (sql) --- TODO move date forward, create migration for postgres m20251020_service_certs :: Query m20251020_service_certs = [sql| @@ -13,27 +12,81 @@ CREATE TABLE client_services( user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, host TEXT NOT NULL, port TEXT NOT NULL, + server_key_hash BLOB, service_cert BLOB NOT NULL, service_cert_hash BLOB NOT NULL, service_priv_key BLOB NOT NULL, - rcv_service_id BLOB, + service_id BLOB, + service_queue_count INTEGER NOT NULL DEFAULT 0, + service_queue_ids_hash BLOB NOT NULL DEFAULT x'00000000000000000000000000000000', FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT ); -CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port); - +CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services(user_id, host, port, server_key_hash); CREATE INDEX idx_server_certs_host_port ON client_services(host, port); ALTER TABLE rcv_queues ADD COLUMN rcv_service_assoc INTEGER NOT NULL DEFAULT 0; + +CREATE TRIGGER tr_rcv_queue_insert +AFTER INSERT ON rcv_queues +FOR EACH ROW +WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count + 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id) + AND host = NEW.host AND port = NEW.port; +END; + +CREATE TRIGGER tr_rcv_queue_delete +AFTER DELETE ON rcv_queues +FOR EACH ROW +WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count - 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id) + AND host = OLD.host AND port = OLD.port; +END; + +CREATE TRIGGER tr_rcv_queue_update_remove +AFTER UPDATE ON rcv_queues +FOR EACH ROW +WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 AND NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0) +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count - 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id) + AND host = OLD.host AND port = OLD.port; +END; + +CREATE TRIGGER tr_rcv_queue_update_add +AFTER UPDATE ON rcv_queues +FOR EACH ROW +WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 AND NOT (OLD.rcv_service_assoc != 0 AND OLD.deleted = 0) +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count + 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id) + AND host = NEW.host AND port = NEW.port; +END; |] down_m20251020_service_certs :: Query down_m20251020_service_certs = [sql| +DROP TRIGGER tr_rcv_queue_insert; +DROP TRIGGER tr_rcv_queue_delete; +DROP TRIGGER tr_rcv_queue_update_remove; +DROP TRIGGER tr_rcv_queue_update_add; + ALTER TABLE rcv_queues DROP COLUMN rcv_service_assoc; DROP INDEX idx_server_certs_host_port; - DROP INDEX idx_server_certs_user_id_host_port; DROP TABLE client_services; diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index 8013313ac..339e3a8ee 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -455,10 +455,13 @@ CREATE TABLE client_services( user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, host TEXT NOT NULL, port TEXT NOT NULL, + server_key_hash BLOB, service_cert BLOB NOT NULL, service_cert_hash BLOB NOT NULL, service_priv_key BLOB NOT NULL, - rcv_service_id BLOB, + service_id BLOB, + service_queue_count INTEGER NOT NULL DEFAULT 0, + service_queue_ids_hash BLOB NOT NULL DEFAULT x'00000000000000000000000000000000', FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT ); CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id); @@ -607,6 +610,51 @@ CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id); CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON client_services( user_id, host, - port + port, + server_key_hash ); CREATE INDEX idx_server_certs_host_port ON client_services(host, port); +CREATE TRIGGER tr_rcv_queue_insert +AFTER INSERT ON rcv_queues +FOR EACH ROW +WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count + 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id) + AND host = NEW.host AND port = NEW.port; +END; +CREATE TRIGGER tr_rcv_queue_delete +AFTER DELETE ON rcv_queues +FOR EACH ROW +WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count - 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id) + AND host = OLD.host AND port = OLD.port; +END; +CREATE TRIGGER tr_rcv_queue_update_remove +AFTER UPDATE ON rcv_queues +FOR EACH ROW +WHEN OLD.rcv_service_assoc != 0 AND OLD.deleted = 0 AND NOT (NEW.rcv_service_assoc != 0 AND NEW.deleted = 0) +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count - 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, OLD.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = OLD.conn_id) + AND host = OLD.host AND port = OLD.port; +END; +CREATE TRIGGER tr_rcv_queue_update_add +AFTER UPDATE ON rcv_queues +FOR EACH ROW +WHEN NEW.rcv_service_assoc != 0 AND NEW.deleted = 0 AND NOT (OLD.rcv_service_assoc != 0 AND OLD.deleted = 0) +BEGIN + UPDATE client_services + SET service_queue_count = service_queue_count + 1, + service_queue_ids_hash = simplex_xor_md5_combine(service_queue_ids_hash, NEW.rcv_id) + WHERE user_id = (SELECT user_id FROM connections WHERE conn_id = NEW.conn_id) + AND host = NEW.host AND port = NEW.port; +END; diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs new file mode 100644 index 000000000..a3c3b94ac --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs @@ -0,0 +1,41 @@ +module Simplex.Messaging.Agent.Store.SQLite.Util where + +import Control.Exception (SomeException, catch, mask_) +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..)) +import Database.SQLite3.Bindings +import Foreign.C.String +import Foreign.Ptr +import Foreign.StablePtr + +data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal) + +type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO () + +mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc +mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals) +{-# INLINE mkSQLiteFunc #-} + +-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL. +createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ()) +createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do + u <- newStablePtr $ CFuncPtrs funPtr nullFunPtr nullFunPtr + let flags = if isDet then c_SQLITE_DETERMINISTIC else 0 + B.useAsCString name $ \namePtr -> + toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr + +-- Convert a 'CError' to a 'Either Error', in the common case where +-- SQLITE_OK signals success and anything else signals an error. +-- +-- Note that SQLITE_OK == 0. +toResult :: a -> CError -> Either Error a +toResult a (CError 0) = Right a +toResult _ code = Left $ decodeError code + +-- call c_sqlite3_result_error in the event of an error +catchAsResultError :: Ptr CContext -> IO () -> IO () +catchAsResultError ctx action = catch action $ \exn -> do + let msg = show (exn :: SomeException) + withCAStringLen msg $ \(ptr, len) -> + c_sqlite3_result_error ctx ptr (fromIntegral len) diff --git a/src/Simplex/Messaging/Agent/TSessionSubs.hs b/src/Simplex/Messaging/Agent/TSessionSubs.hs index cce103fe6..ab15b9793 100644 --- a/src/Simplex/Messaging/Agent/TSessionSubs.hs +++ b/src/Simplex/Messaging/Agent/TSessionSubs.hs @@ -2,6 +2,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} module Simplex.Messaging.Agent.TSessionSubs ( TSessionSubs (sessionSubs), @@ -12,7 +13,10 @@ module Simplex.Messaging.Agent.TSessionSubs hasPendingSub, addPendingSub, setSessionId, + setPendingServiceSub, + setActiveServiceSub, addActiveSub, + addActiveSub', batchAddActiveSubs, batchAddPendingSubs, deletePendingSub, @@ -38,13 +42,13 @@ import qualified Data.Map.Strict as M import Data.Maybe (isJust) import qualified Data.Set as S import Simplex.Messaging.Agent.Protocol (SMPQueue (..)) -import Simplex.Messaging.Agent.Store (RcvQueueSub (..), SomeRcvQueue) +import Simplex.Messaging.Agent.Store (RcvQueue, RcvQueueSub (..), SomeRcvQueue, StoredRcvQueue (rcvServiceAssoc), rcvQueueSub) import Simplex.Messaging.Client (SMPTransportSession, TransportSessionMode (..)) -import Simplex.Messaging.Protocol (RecipientId) +import Simplex.Messaging.Protocol (RecipientId, ServiceSub (..), queueIdHash) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport -import Simplex.Messaging.Util (($>>=)) +import Simplex.Messaging.Util (anyM, ($>>=)) data TSessionSubs = TSessionSubs { sessionSubs :: TMap SMPTransportSession SessSubs @@ -53,7 +57,9 @@ data TSessionSubs = TSessionSubs data SessSubs = SessSubs { subsSessId :: TVar (Maybe SessionId), activeSubs :: TMap RecipientId RcvQueueSub, - pendingSubs :: TMap RecipientId RcvQueueSub + pendingSubs :: TMap RecipientId RcvQueueSub, + activeServiceSub :: TVar (Maybe ServiceSub), + pendingServiceSub :: TVar (Maybe ServiceSub) } emptyIO :: IO TSessionSubs @@ -72,7 +78,7 @@ getSessSubs :: SMPTransportSession -> TSessionSubs -> STM SessSubs getSessSubs tSess ss = lookupSubs tSess ss >>= maybe new pure where new = do - s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty + s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty <*> newTVar Nothing <*> newTVar Nothing TM.insert tSess s $ sessionSubs ss pure s @@ -98,8 +104,27 @@ setSessionId tSess sessId ss = do Nothing -> writeTVar (subsSessId s) (Just sessId) Just sessId' -> unless (sessId == sessId') $ void $ setSubsPending_ s $ Just sessId -addActiveSub :: SMPTransportSession -> SessionId -> RcvQueueSub -> TSessionSubs -> STM () -addActiveSub tSess sessId rq ss = do +setPendingServiceSub :: SMPTransportSession -> ServiceSub -> TSessionSubs -> STM () +setPendingServiceSub tSess serviceSub ss = do + s <- getSessSubs tSess ss + writeTVar (pendingServiceSub s) $ Just serviceSub + +setActiveServiceSub :: SMPTransportSession -> SessionId -> ServiceSub -> TSessionSubs -> STM () +setActiveServiceSub tSess sessId serviceSub ss = do + s <- getSessSubs tSess ss + sessId' <- readTVar $ subsSessId s + if Just sessId == sessId' + then do + writeTVar (activeServiceSub s) $ Just serviceSub + writeTVar (pendingServiceSub s) Nothing + else writeTVar (pendingServiceSub s) $ Just serviceSub + +addActiveSub :: SMPTransportSession -> SessionId -> RcvQueue -> TSessionSubs -> STM () +addActiveSub tSess sessId rq = addActiveSub' tSess sessId (rcvQueueSub rq) (rcvServiceAssoc rq) +{-# INLINE addActiveSub #-} + +addActiveSub' :: SMPTransportSession -> SessionId -> RcvQueueSub -> Bool -> TSessionSubs -> STM () +addActiveSub' tSess sessId rq serviceAssoc ss = do s <- getSessSubs tSess ss sessId' <- readTVar $ subsSessId s let rId = rcvId rq @@ -107,10 +132,13 @@ addActiveSub tSess sessId rq ss = do then do TM.insert rId rq $ activeSubs s TM.delete rId $ pendingSubs s + when serviceAssoc $ + let updateServiceSub (ServiceSub serviceId n idsHash) = ServiceSub serviceId (n + 1) (idsHash <> queueIdHash rId) + in modifyTVar' (activeServiceSub s) (updateServiceSub <$>) else TM.insert rId rq $ pendingSubs s -batchAddActiveSubs :: SMPTransportSession -> SessionId -> [RcvQueueSub] -> TSessionSubs -> STM () -batchAddActiveSubs tSess sessId rqs ss = do +batchAddActiveSubs :: SMPTransportSession -> SessionId -> ([RcvQueueSub], [RcvQueueSub]) -> TSessionSubs -> STM () +batchAddActiveSubs tSess sessId (rqs, serviceRQs) ss = do s <- getSessSubs tSess ss sessId' <- readTVar $ subsSessId s let qs = M.fromList $ map (\rq -> (rcvId rq, rq)) rqs @@ -118,6 +146,12 @@ batchAddActiveSubs tSess sessId rqs ss = do then do TM.union qs $ activeSubs s modifyTVar' (pendingSubs s) (`M.difference` qs) + serviceSub_ <- readTVar $ activeServiceSub s + forM_ serviceSub_ $ \(ServiceSub serviceId n idsHash) -> do + unless (null serviceRQs) $ do + let idsHash' = idsHash <> mconcat (map (queueIdHash . rcvId) serviceRQs) + n' = n + fromIntegral (length serviceRQs) + writeTVar (activeServiceSub s) $ Just $ ServiceSub serviceId n' idsHash' else TM.union qs $ pendingSubs s batchAddPendingSubs :: SMPTransportSession -> [RcvQueueSub] -> TSessionSubs -> STM () @@ -143,11 +177,15 @@ batchDeleteSubs tSess rqs = lookupSubs tSess >=> mapM_ (\s -> delete (activeSubs delete = (`modifyTVar'` (`M.withoutKeys` rIds)) hasPendingSubs :: SMPTransportSession -> TSessionSubs -> STM Bool -hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (fmap (not . null) . readTVar . pendingSubs) +hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (\s -> anyM [hasSubs s, hasServiceSub s]) + where + hasSubs = fmap (not . null) . readTVar . pendingSubs + hasServiceSub = fmap isJust . readTVar . pendingServiceSub -getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) -getPendingSubs = getSubs_ pendingSubs -{-# INLINE getPendingSubs #-} +getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub) +getPendingSubs tSess = lookupSubs tSess >=> maybe (pure (M.empty, Nothing)) get + where + get s = liftM2 (,) (readTVar $ pendingSubs s) (readTVar $ pendingServiceSub s) getActiveSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) getActiveSubs = getSubs_ activeSubs @@ -156,7 +194,7 @@ getActiveSubs = getSubs_ activeSubs getSubs_ :: (SessSubs -> TMap RecipientId RcvQueueSub) -> SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) getSubs_ subs tSess = lookupSubs tSess >=> maybe (pure M.empty) (readTVar . subs) -setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) +setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub) setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss) | entitySession == isJust connId_ = TM.lookup tSess ss >>= withSessSubs (`setSubsPending_` Nothing) @@ -166,17 +204,17 @@ setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss) entitySession = mode == TSMEntity sessEntId = if entitySession then Just else const Nothing withSessSubs run = \case - Nothing -> pure M.empty + Nothing -> pure (M.empty, Nothing) Just s -> do sessId' <- readTVar $ subsSessId s - if Just sessId == sessId' then run s else pure M.empty + if Just sessId == sessId' then run s else pure (M.empty, Nothing) setPendingChangeMode s = do subs <- M.union <$> readTVar (activeSubs s) <*> readTVar (pendingSubs s) unless (null subs) $ forM_ subs $ \rq -> addPendingSub (uId, srv, sessEntId (connId rq)) rq tss - pure subs + (subs,) <$> setServiceSubPending_ s -setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub) +setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub, Maybe ServiceSub) setSubsPending_ s sessId_ = do writeTVar (subsSessId s) sessId_ let as = activeSubs s @@ -184,7 +222,15 @@ setSubsPending_ s sessId_ = do unless (null subs) $ do writeTVar as M.empty modifyTVar' (pendingSubs s) $ M.union subs - pure subs + (subs,) <$> setServiceSubPending_ s + +setServiceSubPending_ :: SessSubs -> STM (Maybe ServiceSub) +setServiceSubPending_ s = do + serviceSub_ <- readTVar $ activeServiceSub s + forM_ serviceSub_ $ \serviceSub -> do + writeTVar (activeServiceSub s) Nothing + writeTVar (pendingServiceSub s) $ Just serviceSub + pure serviceSub_ updateClientNotices :: SMPTransportSession -> [(RecipientId, Maybe Int64)] -> TSessionSubs -> STM () updateClientNotices tSess noticeIds ss = do diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 4f70efcf2..58ffd1418 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -909,18 +909,18 @@ nsubResponse_ = \case {-# INLINE nsubResponse_ #-} -- This command is always sent in background request mode -subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO (Int64, IdsHash) -subscribeService c party = case smpClientService c of +subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub +subscribeService c party n idsHash = case smpClientService c of Just THClientService {serviceId, serviceKey} -> do liftIO $ enablePings c sendSMPCommand c NRMBackground (Just (C.APrivateAuthKey C.SEd25519 serviceKey)) serviceId subCmd >>= \case - SOKS n idsHash -> pure (n, idsHash) + SOKS n' idsHash' -> pure $ ServiceSub serviceId n' idsHash' r -> throwE $ unexpectedResponse r where subCmd :: Command p subCmd = case party of - SRecipientService -> SUBS - SNotifierService -> NSUBS + SRecipientService -> SUBS n idsHash + SNotifierService -> NSUBS n idsHash Nothing -> throwE PCEServiceUnavailable smpClientService :: SMPClient -> Maybe THClientService diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 722a86c7e..45d747d21 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -45,7 +45,6 @@ import Crypto.Random (ChaChaDRG) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Constraint (Dict (..)) -import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -69,10 +68,12 @@ import Simplex.Messaging.Protocol ProtocolServer (..), QueueId, SMPServer, + ServiceSub (..), SParty (..), ServiceParty, serviceParty, - partyServiceRole + partyServiceRole, + queueIdsHash, ) import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) @@ -91,14 +92,14 @@ data SMPClientAgentEvent | CADisconnected SMPServer (NonEmpty QueueId) | CASubscribed SMPServer (Maybe ServiceId) (NonEmpty QueueId) | CASubError SMPServer (NonEmpty (QueueId, SMPClientError)) - | CAServiceDisconnected SMPServer (ServiceId, Int64) - | CAServiceSubscribed SMPServer (ServiceId, Int64) Int64 - | CAServiceSubError SMPServer (ServiceId, Int64) SMPClientError + | CAServiceDisconnected SMPServer ServiceSub + | CAServiceSubscribed {subServer :: SMPServer, expected :: ServiceSub, subscribed :: ServiceSub} + | CAServiceSubError SMPServer ServiceSub SMPClientError -- CAServiceUnavailable is used when service ID in pending subscription is different from the current service in connection. -- This will require resubscribing to all queues associated with this service ID individually, creating new associations. -- It may happen if, for example, SMP server deletes service information (e.g. via downgrade and upgrade) -- and assigns different service ID to the service certificate. - | CAServiceUnavailable SMPServer (ServiceId, Int64) + | CAServiceUnavailable SMPServer ServiceSub data SMPClientAgentConfig = SMPClientAgentConfig { smpCfg :: ProtocolClientConfig SMPVersion, @@ -142,11 +143,11 @@ data SMPClientAgent p = SMPClientAgent -- Only one service subscription can exist per server with this agent. -- With correctly functioning SMP server, queue and service subscriptions can't be -- active at the same time. - activeServiceSubs :: TMap SMPServer (TVar (Maybe ((ServiceId, Int64), SessionId))), + activeServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceSub, SessionId))), activeQueueSubs :: TMap SMPServer (TMap QueueId (SessionId, C.APrivateAuthKey)), -- Pending service subscriptions can co-exist with pending queue subscriptions -- on the same SMP server during subscriptions being transitioned from per-queue to service. - pendingServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceId, Int64))), + pendingServiceSubs :: TMap SMPServer (TVar (Maybe ServiceSub)), pendingQueueSubs :: TMap SMPServer (TMap QueueId C.APrivateAuthKey), smpSubWorkers :: TMap SMPServer (SessionVar (Async ())), workerSeq :: TVar Int @@ -256,7 +257,7 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random removeClientAndSubs smp >>= serverDown logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeClientAndSubs :: SMPClient -> IO (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) + removeClientAndSubs :: SMPClient -> IO (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) removeClientAndSubs smp = do -- Looking up subscription vars outside of STM transaction to reduce re-evaluation. -- It is possible because these vars are never removed, they are only added. @@ -287,7 +288,7 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random then pure Nothing else Just subs <$ addSubs_ (pendingQueueSubs ca) srv subs - serverDown :: (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> IO () + serverDown :: (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) -> IO () serverDown (sSub, qSubs) = do mapM_ (notify ca . CAServiceDisconnected srv) sSub let qIds = L.nonEmpty . M.keys =<< qSubs @@ -317,7 +318,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s loop ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg noPending (sSub, qSubs) = isNothing sSub && maybe True M.null qSubs - getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) + getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) getPending lkup rd = do sSub <- lkup srv (pendingServiceSubs ca) $>>= rd qSubs <- lkup srv (pendingQueueSubs ca) >>= mapM rd @@ -329,7 +330,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s whenM (isEmptyTMVar $ sessionVar v) retry removeSessVar v srv smpSubWorkers -reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO () +reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe ServiceSub, Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO () reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) = withSMP ca srv $ \smp -> liftIO $ case serviceParty agentParty of Just Dict -> resubscribe smp @@ -430,7 +431,7 @@ smpSubscribeQueues ca smp srv subs = do let acc@(_, _, (qOks, sQs), notPending) = foldr (groupSub pending) (False, [], ([], []), []) (L.zip subs rs) unless (null qOks) $ addActiveSubs ca srv qOks unless (null sQs) $ forM_ smpServiceId $ \serviceId -> - updateActiveServiceSub ca srv ((serviceId, fromIntegral $ length sQs), sessId) + updateActiveServiceSub ca srv (ServiceSub serviceId (fromIntegral $ length sQs) (queueIdsHash sQs), sessId) unless (null notPending) $ removePendingSubs ca srv notPending pure acc sessId = sessionId $ thParams smp @@ -454,24 +455,24 @@ smpSubscribeQueues ca smp srv subs = do notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO () notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs -subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> (ServiceId, Int64) -> IO () +subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> ServiceSub -> IO () subscribeServiceNtfs = subscribeService_ {-# INLINE subscribeServiceNtfs #-} -subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO () +subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> ServiceSub -> IO () subscribeService_ ca srv serviceSub = do atomically $ setPendingServiceSub ca srv $ Just serviceSub runExceptT (getSMPServerClient' ca srv) >>= \case Right smp -> smpSubscribeService ca smp srv serviceSub Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that -smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO () -smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of +smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> ServiceSub -> IO () +smpSubscribeService ca smp srv serviceSub@(ServiceSub serviceId n idsHash) = case smpClientService smp of Just service | serviceAvailable service -> subscribe _ -> notifyUnavailable where subscribe = do - r <- runExceptT $ subscribeService smp $ agentParty ca + r <- runExceptT $ subscribeService smp (agentParty ca) n idsHash ok <- atomically $ ifM @@ -479,15 +480,15 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService (True <$ processSubscription r) (pure False) if ok - then case r of -- TODO [certs rcv] compare hash - Right (n, _idsHash) -> notify ca $ CAServiceSubscribed srv serviceSub n + then case r of + Right serviceSub' -> notify ca $ CAServiceSubscribed srv serviceSub serviceSub' Left e | smpClientServiceError e -> notifyUnavailable | temporaryClientError e -> reconnectClient ca srv | otherwise -> notify ca $ CAServiceSubError srv serviceSub e else reconnectClient ca srv - processSubscription = mapM_ $ \(n, _idsHash) -> do -- TODO [certs rcv] validate hash here? - setActiveServiceSub ca srv $ Just ((serviceId, n), sessId) + processSubscription = mapM_ $ \serviceSub' -> do -- TODO [certs rcv] validate hash here? + setActiveServiceSub ca srv $ Just (serviceSub', sessId) setPendingServiceSub ca srv Nothing serviceAvailable THClientService {serviceRole, serviceId = serviceId'} = serviceId == serviceId' && partyServiceRole (agentParty ca) == serviceRole @@ -529,11 +530,11 @@ addSubs_ subs srv ss = Just m -> TM.union ss m _ -> TM.insertM srv (newTVar ss) subs -setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ((ServiceId, Int64), SessionId) -> STM () +setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceSub, SessionId) -> STM () setActiveServiceSub = setServiceSub_ activeServiceSubs {-# INLINE setActiveServiceSub #-} -setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceId, Int64) -> STM () +setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ServiceSub -> STM () setPendingServiceSub = setServiceSub_ pendingServiceSubs {-# INLINE setPendingServiceSub #-} @@ -548,12 +549,12 @@ setServiceSub_ subsSel ca srv sub = Just v -> writeTVar v sub Nothing -> TM.insertM srv (newTVar sub) (subsSel ca) -updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> ((ServiceId, Int64), SessionId) -> STM () -updateActiveServiceSub ca srv sub@((serviceId', n'), sessId') = +updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> (ServiceSub, SessionId) -> STM () +updateActiveServiceSub ca srv sub@(ServiceSub serviceId' n' idsHash', sessId') = TM.lookup srv (activeServiceSubs ca) >>= \case Just v -> modifyTVar' v $ \case - Just ((serviceId, n), sessId) | serviceId == serviceId' && sessId == sessId' -> - Just ((serviceId, n + n'), sessId) + Just (ServiceSub serviceId n idsHash, sessId) | serviceId == serviceId' && sessId == sessId' -> + Just (ServiceSub serviceId (n + n') (idsHash <> idsHash'), sessId) _ -> Just sub Nothing -> TM.insertM srv (newTVar $ Just sub) (activeServiceSubs ca) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 3d24f0bcb..c7b539641 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -178,6 +178,7 @@ module Simplex.Messaging.Crypto sha512Hash, sha3_256, sha3_384, + md5Hash, -- * Message padding / un-padding canPad, @@ -216,7 +217,7 @@ import Crypto.Cipher.AES (AES256) import qualified Crypto.Cipher.Types as AES import qualified Crypto.Cipher.XSalsa as XSalsa import qualified Crypto.Error as CE -import Crypto.Hash (Digest, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize) +import Crypto.Hash (Digest, MD5, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize) import qualified Crypto.KDF.HKDF as H import qualified Crypto.MAC.Poly1305 as Poly1305 import qualified Crypto.PubKey.Curve25519 as X25519 @@ -1024,6 +1025,9 @@ sha3_384 :: ByteString -> ByteString sha3_384 = BA.convert . (hash :: ByteString -> Digest SHA3_384) {-# INLINE sha3_384 #-} +md5Hash :: ByteString -> ByteString +md5Hash = BA.convert . (hash :: ByteString -> Digest MD5) + -- | AEAD-GCM encryption with associated data. -- -- Used as part of double ratchet encryption. diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 0b5889bb7..7acb714c0 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -489,17 +489,9 @@ data NtfSubStatus NSErr ByteString deriving (Eq, Ord, Show) -ntfShouldSubscribe :: NtfSubStatus -> Bool -ntfShouldSubscribe = \case - NSNew -> True - NSPending -> True - NSActive -> True - NSInactive -> True - NSEnd -> False - NSDeleted -> False - NSAuth -> False - NSService -> True - NSErr _ -> False +-- if these statuses change, the queue ID hashes for services need to be updated in a new migration (see m20250830_queue_ids_hash) +subscribeNtfStatuses :: [NtfSubStatus] +subscribeNtfStatuses = [NSNew, NSPending, NSActive, NSInactive] instance Encoding NtfSubStatus where smpEncode = \case diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 43d97988e..f06e9c7b1 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag import Simplex.Messaging.Notifications.Server.Store.Postgres import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Transport -import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut) +import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceSub (..), SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server import Simplex.Messaging.Server.Control (CPClientRole (..)) @@ -257,9 +257,9 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} srvSubscribers <- getSMPWorkerMetrics a smpSubscribers srvClients <- getSMPWorkerMetrics a smpClients srvSubWorkers <- getSMPWorkerMetrics a smpSubWorkers - ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ snd . fst + ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ smpQueueCount . fst ntfActiveQueueSubs <- getSMPSubMetrics a activeQueueSubs - ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs snd + ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs smpQueueCount ntfPendingQueueSubs <- getSMPSubMetrics a pendingQueueSubs smpSessionCount <- M.size <$> readTVarIO smpSessions apnsPushQLength <- atomically $ lengthTBQueue pushQ @@ -452,13 +452,13 @@ resubscribe NtfSubscriber {smpAgent = ca} = do counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)" -subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int +subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe ServiceSub) -> IO Int subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv) logNote $ "Starting SMP resubscriptions for " <> srvStr - forM_ service_ $ \(serviceId, n) -> do - logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow n <> " associated queues" - subscribeServiceNtfs ca srv (serviceId, n) + forM_ service_ $ \serviceSub -> do + logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow (smpQueueCount serviceSub) <> " associated queues" + subscribeServiceNtfs ca srv serviceSub n <- subscribeLoop 0 Nothing logNote $ "Completed SMP resubscriptions for " <> srvStr <> " (" <> tshow n <> " subscriptions)" pure n @@ -576,7 +576,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = -- TODO [certs] resubscribe queues with statuses NSErr and NSService CAServiceDisconnected srv serviceSub -> logNote $ "SMP server service disconnected " <> showService srv serviceSub - CAServiceSubscribed srv serviceSub@(_, expected) n + CAServiceSubscribed srv serviceSub@(ServiceSub _ expected _) (ServiceSub _ n _) -- TODO [certs rcv] compare hash | expected == n -> logNote msg | otherwise -> logWarn $ msg <> ", confirmed subs: " <> tshow n where @@ -593,7 +593,8 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing) Left e -> logError $ "SMP server update and resubscription error " <> tshow e where - showService srv (serviceId, n) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs" + -- TODO [certs rcv] compare hash + showService srv (ServiceSub serviceId n _idsHash) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs" logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int -> IO () logSubErrors srv subs updated = forM_ (L.group $ L.sort $ L.map snd subs) $ \ss -> do diff --git a/src/Simplex/Messaging/Notifications/Server/Stats.hs b/src/Simplex/Messaging/Notifications/Server/Stats.hs index a20e41c34..7125ce290 100644 --- a/src/Simplex/Messaging/Notifications/Server/Stats.hs +++ b/src/Simplex/Messaging/Notifications/Server/Stats.hs @@ -17,6 +17,7 @@ import Simplex.Messaging.Server.Stats import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +-- TODO [certs rcv] track service subscriptions and count/hash diffs for own and other servers + prometheus data NtfServerStats = NtfServerStats { fromTime :: IORef UTCTime, tknCreated :: IORef Int, diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs index 6a53ff4a2..8c0da7c07 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -6,13 +6,15 @@ module Simplex.Messaging.Notifications.Server.Store.Migrations where import Data.List (sortOn) import Data.Text (Text) +import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util import Simplex.Messaging.Agent.Store.Shared import Text.RawString.QQ (r) ntfServerSchemaMigrations :: [(String, Text, Maybe Text)] ntfServerSchemaMigrations = [ ("20250417_initial", m20250417_initial, Nothing), - ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert) + ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert), + ("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash) ] -- | The list of migrations in ascending order by date @@ -101,3 +103,125 @@ ALTER TABLE smp_servers DROP COLUMN ntf_service_id; ALTER TABLE subscriptions DROP COLUMN ntf_service_assoc; |] + +m20250830_queue_ids_hash :: Text +m20250830_queue_ids_hash = + createXorHashFuncs + <> [r| +ALTER TABLE smp_servers + ADD COLUMN smp_notifier_count BIGINT NOT NULL DEFAULT 0, + ADD COLUMN smp_notifier_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000'; + +CREATE FUNCTION should_subscribe_status(p_status TEXT) RETURNS BOOLEAN +LANGUAGE plpgsql IMMUTABLE STRICT +AS $$ +BEGIN + RETURN p_status IN ('NEW', 'PENDING', 'ACTIVE', 'INACTIVE'); +END; +$$; + +CREATE FUNCTION update_all_aggregates() RETURNS VOID +LANGUAGE plpgsql +AS $$ +BEGIN + WITH acc AS ( + SELECT + s.smp_server_id, + count(smp_notifier_id) as notifier_count, + xor_aggregate(public.digest(s.smp_notifier_id, 'md5')) AS notifier_hash + FROM subscriptions s + WHERE s.ntf_service_assoc = true AND should_subscribe_status(s.status) + GROUP BY s.smp_server_id + ) + UPDATE smp_servers srv + SET smp_notifier_count = COALESCE(acc.notifier_count, 0), + smp_notifier_ids_hash = COALESCE(acc.notifier_hash, '\x00000000000000000000000000000000') + FROM acc + WHERE srv.smp_server_id = acc.smp_server_id; +END; +$$; + +SELECT update_all_aggregates(); + +CREATE FUNCTION update_aggregates(p_server_id BIGINT, p_change BIGINT, p_notifier_id BYTEA) RETURNS VOID +LANGUAGE plpgsql +AS $$ +BEGIN + UPDATE smp_servers + SET smp_notifier_count = smp_notifier_count + p_change, + smp_notifier_ids_hash = xor_combine(smp_notifier_ids_hash, public.digest(p_notifier_id, 'md5')) + WHERE smp_server_id = p_server_id; +END; +$$; + +CREATE FUNCTION on_subscription_insert() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN + PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id); + END IF; + RETURN NEW; +END; +$$; + +CREATE FUNCTION on_subscription_delete() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN + PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id); + END IF; + RETURN OLD; +END; +$$; + +CREATE FUNCTION on_subscription_update() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN + IF NOT (NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status)) THEN + PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id); + END IF; + ELSIF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN + PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id); + END IF; + RETURN NEW; +END; +$$; + +CREATE TRIGGER tr_subscriptions_insert +AFTER INSERT ON subscriptions +FOR EACH ROW EXECUTE PROCEDURE on_subscription_insert(); + +CREATE TRIGGER tr_subscriptions_delete +AFTER DELETE ON subscriptions +FOR EACH ROW EXECUTE PROCEDURE on_subscription_delete(); + +CREATE TRIGGER tr_subscriptions_update +AFTER UPDATE ON subscriptions +FOR EACH ROW EXECUTE PROCEDURE on_subscription_update(); + |] + +down_m20250830_queue_ids_hash :: Text +down_m20250830_queue_ids_hash = + [r| +DROP TRIGGER tr_subscriptions_insert ON subscriptions; +DROP TRIGGER tr_subscriptions_delete ON subscriptions; +DROP TRIGGER tr_subscriptions_update ON subscriptions; + +DROP FUNCTION on_subscription_insert; +DROP FUNCTION on_subscription_delete; +DROP FUNCTION on_subscription_update; + +DROP FUNCTION update_aggregates; +DROP FUNCTION update_all_aggregates; + +DROP FUNCTION should_subscribe_status; + +ALTER TABLE smp_servers + DROP COLUMN smp_notifier_count, + DROP COLUMN smp_notifier_ids_hash; + |] + <> dropXorHashFuncs diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 80d946c8b..60e81a68b 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -64,7 +64,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Server.StoreLog -import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer) +import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, ServiceSub (..), pattern SMPServer) import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) import Simplex.Messaging.Server.StoreLog (openWriteStoreLog) @@ -239,7 +239,7 @@ updateTknCronInterval st tknId cronInt = -- Reads servers that have subscriptions that need subscribing. -- It is executed on server start, and it is supposed to crash on database error -getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe (ServiceId, Int64))] +getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe ServiceSub)] getUsedSMPServers st = withTransaction (dbStore st) $ \db -> map rowToSrvSubs <$> @@ -247,25 +247,17 @@ getUsedSMPServers st = db [sql| SELECT - p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id, - SUM(CASE WHEN s.ntf_service_assoc THEN s.subs_count ELSE 0 END) :: BIGINT as service_subs_count - FROM smp_servers p - JOIN ( - SELECT - smp_server_id, - ntf_service_assoc, - COUNT(1) as subs_count - FROM subscriptions - WHERE status IN ? - GROUP BY smp_server_id, ntf_service_assoc - ) s ON s.smp_server_id = p.smp_server_id - GROUP BY p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id + smp_host, smp_port, smp_keyhash, smp_server_id, + ntf_service_id, smp_notifier_count, smp_notifier_ids_hash + FROM smp_servers + WHERE EXISTS (SELECT 1 FROM subscriptions WHERE status IN ?) |] - (Only (In [NSNew, NSPending, NSActive, NSInactive])) + (Only (In subscribeNtfStatuses)) where - rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64) -> (SMPServer, Int64, Maybe (ServiceId, Int64)) - rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, subsCount)) = - (SMPServer host port kh, srvId, (,subsCount) <$> serviceId_) + rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64, IdsHash) -> (SMPServer, Int64, Maybe ServiceSub) + rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, n, idsHash)) = + let service_ = (\serviceId -> ServiceSub serviceId n idsHash) <$> serviceId_ + in (SMPServer host port kh, srvId, service_) getServerNtfSubscriptions :: NtfPostgresStore -> Int64 -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub]) getServerNtfSubscriptions st srvId afterSubId_ count = @@ -273,9 +265,9 @@ getServerNtfSubscriptions st srvId afterSubId_ count = subs <- map toServerNtfSub <$> case afterSubId_ of Nothing -> - DB.query db (query <> orderLimit) (srvId, statusIn, count) + DB.query db (query <> orderLimit) (srvId, In subscribeNtfStatuses, count) Just afterSubId -> - DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, statusIn, afterSubId, count) + DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, In subscribeNtfStatuses, afterSubId, count) void $ DB.executeMany db @@ -296,7 +288,6 @@ getServerNtfSubscriptions st srvId afterSubId_ count = WHERE smp_server_id = ? AND NOT ntf_service_assoc AND status IN ? |] orderLimit = " ORDER BY subscription_id LIMIT ?" - statusIn = In [NSNew, NSPending, NSActive, NSInactive] toServerNtfSub (ntfSubId, notifierId, notifierKey) = (ntfSubId, (notifierId, notifierKey)) -- Returns token and subscription. diff --git a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql index 3b155fa1a..b73995684 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql +++ b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql @@ -15,6 +15,123 @@ SET row_security = off; CREATE SCHEMA ntf_server; + +CREATE FUNCTION ntf_server.on_subscription_delete() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN + PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id); + END IF; + RETURN OLD; +END; +$$; + + + +CREATE FUNCTION ntf_server.on_subscription_insert() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN + PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id); + END IF; + RETURN NEW; +END; +$$; + + + +CREATE FUNCTION ntf_server.on_subscription_update() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF OLD.ntf_service_assoc = true AND should_subscribe_status(OLD.status) THEN + IF NOT (NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status)) THEN + PERFORM update_aggregates(OLD.smp_server_id, -1, OLD.smp_notifier_id); + END IF; + ELSIF NEW.ntf_service_assoc = true AND should_subscribe_status(NEW.status) THEN + PERFORM update_aggregates(NEW.smp_server_id, 1, NEW.smp_notifier_id); + END IF; + RETURN NEW; +END; +$$; + + + +CREATE FUNCTION ntf_server.should_subscribe_status(p_status text) RETURNS boolean + LANGUAGE plpgsql IMMUTABLE STRICT + AS $$ +BEGIN + RETURN p_status IN ('NEW', 'PENDING', 'ACTIVE', 'INACTIVE'); +END; +$$; + + + +CREATE FUNCTION ntf_server.update_aggregates(p_server_id bigint, p_change bigint, p_notifier_id bytea) RETURNS void + LANGUAGE plpgsql + AS $$ +BEGIN + UPDATE smp_servers + SET smp_notifier_count = smp_notifier_count + p_change, + smp_notifier_ids_hash = xor_combine(smp_notifier_ids_hash, public.digest(p_notifier_id, 'md5')) + WHERE smp_server_id = p_server_id; +END; +$$; + + + +CREATE FUNCTION ntf_server.update_all_aggregates() RETURNS void + LANGUAGE plpgsql + AS $$ +BEGIN + WITH acc AS ( + SELECT + s.smp_server_id, + count(smp_notifier_id) as notifier_count, + xor_aggregate(public.digest(s.smp_notifier_id, 'md5')) AS notifier_hash + FROM subscriptions s + WHERE s.ntf_service_assoc = true AND should_subscribe_status(s.status) + GROUP BY s.smp_server_id + ) + UPDATE smp_servers srv + SET smp_notifier_count = COALESCE(acc.notifier_count, 0), + smp_notifier_ids_hash = COALESCE(acc.notifier_hash, '\x00000000000000000000000000000000') + FROM acc + WHERE srv.smp_server_id = acc.smp_server_id; +END; +$$; + + + +CREATE FUNCTION ntf_server.xor_combine(state bytea, value bytea) RETURNS bytea + LANGUAGE plpgsql IMMUTABLE STRICT + AS $$ +DECLARE + result BYTEA := state; + i INTEGER; + len INTEGER := octet_length(value); +BEGIN + IF octet_length(state) != len THEN + RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len; + END IF; + FOR i IN 0..len-1 LOOP + result := set_byte(result, i, get_byte(state, i) # get_byte(value, i)); + END LOOP; + RETURN result; +END; +$$; + + + +CREATE AGGREGATE ntf_server.xor_aggregate(bytea) ( + SFUNC = ntf_server.xor_combine, + STYPE = bytea, + INITCOND = '\x00000000000000000000000000000000' +); + + SET default_table_access_method = heap; @@ -53,7 +170,9 @@ CREATE TABLE ntf_server.smp_servers ( smp_host text NOT NULL, smp_port text NOT NULL, smp_keyhash bytea NOT NULL, - ntf_service_id bytea + ntf_service_id bytea, + smp_notifier_count bigint DEFAULT 0 NOT NULL, + smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL ); @@ -158,6 +277,18 @@ CREATE INDEX idx_tokens_status_cron_interval_sent_at ON ntf_server.tokens USING +CREATE TRIGGER tr_subscriptions_delete AFTER DELETE ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_delete(); + + + +CREATE TRIGGER tr_subscriptions_insert AFTER INSERT ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_insert(); + + + +CREATE TRIGGER tr_subscriptions_update AFTER UPDATE ON ntf_server.subscriptions FOR EACH ROW EXECUTE FUNCTION ntf_server.on_subscription_update(); + + + ALTER TABLE ONLY ntf_server.last_notifications ADD CONSTRAINT last_notifications_subscription_id_fkey FOREIGN KEY (subscription_id) REFERENCES ntf_server.subscriptions(subscription_id) ON UPDATE RESTRICT ON DELETE CASCADE; diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 3be4515cc..c00899e1c 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -140,7 +140,10 @@ module Simplex.Messaging.Protocol RcvMessage (..), MsgId, MsgBody, - IdsHash, + IdsHash (..), + ServiceSub (..), + queueIdsHash, + queueIdHash, MaxMessageLen, MaxRcvMessageLen, EncRcvMsgBody (..), @@ -223,6 +226,8 @@ import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (bimap, first) +import Data.Bits (xor) +import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -232,6 +237,7 @@ import Data.Constraint (Dict (..)) import Data.Functor (($>)) import Data.Int (Int64) import Data.Kind +import Data.List (foldl') import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Maybe (isJust, isNothing) @@ -241,7 +247,7 @@ import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock.System (SystemTime (..), systemToUTCTime) import Data.Type.Equality -import Data.Word (Word16) +import Data.Word (Word8, Word16) import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import qualified GHC.TypeLits as TE import qualified GHC.TypeLits as Type @@ -548,7 +554,8 @@ data Command (p :: Party) where NEW :: NewQueueReq -> Command Creator SUB :: Command Recipient -- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command. - SUBS :: Command RecipientService + -- Parameters are expected queue count and hash of all subscribed queues, it allows to monitor "state drift" on the server + SUBS :: Int64 -> IdsHash -> Command RecipientService KEY :: SndPublicAuthKey -> Command Recipient RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient LSET :: LinkId -> QueueLinkData -> Command Recipient @@ -572,7 +579,7 @@ data Command (p :: Party) where -- SMP notification subscriber commands NSUB :: Command Notifier -- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command. - NSUBS :: Command NotifierService + NSUBS :: Int64 -> IdsHash -> Command NotifierService PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI -- Transmission to proxy: -- - entity ID: ID of the session with relay returned in PKEY (response to PRXY) @@ -698,7 +705,7 @@ data BrokerMsg where LNK :: SenderId -> QueueLinkData -> BrokerMsg -- | Service subscription success - confirms when queue was associated with the service SOK :: Maybe ServiceId -> BrokerMsg - -- | The number of queues subscribed with SUBS command + -- | The number of queues and XOR-hash of their IDs subscribed with SUBS command SOKS :: Int64 -> IdsHash -> BrokerMsg -- MSG v1/2 has to be supported for encoding/decoding -- v1: MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg @@ -1460,7 +1467,42 @@ type MsgId = ByteString -- | SMP message body. type MsgBody = ByteString -type IdsHash = ByteString +data ServiceSub = ServiceSub + { serviceId :: ServiceId, + smpQueueCount :: Int64, + smpQueueIdsHash :: IdsHash + } + +newtype IdsHash = IdsHash {unIdsHash :: BS.ByteString} + deriving (Eq, Show) + deriving newtype (Encoding, FromField) + +instance ToField IdsHash where + toField (IdsHash s) = toField (Binary s) + {-# INLINE toField #-} + +instance Semigroup IdsHash where + (IdsHash s1) <> (IdsHash s2) = IdsHash $! BS.pack $ BS.zipWith xor s1 s2 + +instance Monoid IdsHash where + mempty = IdsHash $ BS.replicate 16 0 + mconcat ss = + let !s' = BS.pack $ foldl' (\ !r (IdsHash s) -> zipWith xor' r (BS.unpack s)) (replicate 16 0) ss -- to prevent packing/unpacking in <> on each step with default mappend + in IdsHash s' + +xor' :: Word8 -> Word8 -> Word8 +xor' x y = let !r = xor x y in r + +noIdsHash ::IdsHash +noIdsHash = IdsHash B.empty +{-# INLINE noIdsHash #-} + +queueIdsHash :: [QueueId] -> IdsHash +queueIdsHash = mconcat . map queueIdHash + +queueIdHash :: QueueId -> IdsHash +queueIdHash = IdsHash . C.md5Hash . unEntityId +{-# INLINE queueIdHash #-} data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock @@ -1695,7 +1737,9 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where new = e (NEW_, ' ', rKey, dhKey) auth = maybe "" (e . ('A',)) auth_ SUB -> e SUB_ - SUBS -> e SUBS_ + SUBS n idsHash + | v >= rcvServiceSMPVersion -> e (SUBS_, ' ', n, idsHash) + | otherwise -> e SUBS_ KEY k -> e (KEY_, ' ', k) RKEY ks -> e (RKEY_, ' ', ks) LSET lnkId d -> e (LSET_, ' ', lnkId, d) @@ -1711,7 +1755,9 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where SEND flags msg -> e (SEND_, ' ', flags, ' ', Tail msg) PING -> e PING_ NSUB -> e NSUB_ - NSUBS -> e NSUBS_ + NSUBS n idsHash + | v >= rcvServiceSMPVersion -> e (NSUBS_, ' ', n, idsHash) + | otherwise -> e NSUBS_ LKEY k -> e (LKEY_, ' ', k) LGET -> e LGET_ PRXY host auth_ -> e (PRXY_, ' ', host, auth_) @@ -1802,7 +1848,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where OFF_ -> pure OFF DEL_ -> pure DEL QUE_ -> pure QUE - CT SRecipientService SUBS_ -> pure $ Cmd SRecipientService SUBS + CT SRecipientService SUBS_ + | v >= rcvServiceSMPVersion -> Cmd SRecipientService <$> (SUBS <$> _smpP <*> smpP) + | otherwise -> pure $ Cmd SRecipientService $ SUBS (-1) noIdsHash CT SSender tag -> Cmd SSender <$> case tag of SKEY_ -> SKEY <$> _smpP @@ -1819,7 +1867,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP) PRXY_ -> PRXY <$> _smpP <*> smpP CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB - CT SNotifierService NSUBS_ -> pure $ Cmd SNotifierService NSUBS + CT SNotifierService NSUBS_ + | v >= rcvServiceSMPVersion -> Cmd SNotifierService <$> (NSUBS <$> _smpP <*> smpP) + | otherwise -> pure $ Cmd SNotifierService $ NSUBS (-1) noIdsHash fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} @@ -1901,7 +1951,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where SOK_ -> SOK <$> _smpP SOKS_ | v >= rcvServiceSMPVersion -> SOKS <$> _smpP <*> smpP - | otherwise -> SOKS <$> _smpP <*> pure B.empty + | otherwise -> SOKS <$> _smpP <*> pure noIdsHash NID_ -> NID <$> _smpP <*> smpP NMSG_ -> NMSG <$> _smpP <*> smpP PKEY_ -> PKEY <$> _smpP <*> smpP <*> smpP diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 1e5e94fd6..a05743a06 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -6,6 +6,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedLists #-} @@ -1247,7 +1248,7 @@ verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, entId, comma vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q) vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q) - vc SRecipientService SUBS = verifyServiceCmd + vc SRecipientService SUBS {} = verifyServiceCmd vc SSender (SKEY k) = verifySecure k -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH @@ -1255,7 +1256,7 @@ verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, entId, comma vc SSenderLink (LKEY k) = verifySecure k vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q) - vc SNotifierService NSUBS = verifyServiceCmd + vc SNotifierService NSUBS {} = verifyServiceCmd vc SProxiedClient _ = VRVerified Nothing vc SProxyService (RFWD _) = VRVerified Nothing checkRole = case (service, partyClientRole p) of @@ -1465,8 +1466,8 @@ client Cmd SNotifier NSUB -> response . (corrId,entId,) <$> case q_ of Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds _ -> pure $ ERR INTERNAL - Cmd SNotifierService NSUBS -> response . (corrId,entId,) <$> case clntServiceId of - Just serviceId -> subscribeServiceNotifications serviceId + Cmd SNotifierService (NSUBS n idsHash) -> response . (corrId,entId,) <$> case clntServiceId of + Just serviceId -> subscribeServiceNotifications serviceId (n, idsHash) Nothing -> pure $ ERR INTERNAL Cmd SCreator (NEW nqr@NewQueueReq {auth_}) -> response <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH)) @@ -1495,8 +1496,8 @@ client OFF -> response <$> maybe (pure $ err INTERNAL) suspendQueue_ q_ DEL -> response <$> maybe (pure $ err INTERNAL) delQueueAndMsgs q_ QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr - Cmd SRecipientService SUBS -> response . (corrId,entId,) <$> case clntServiceId of - Just serviceId -> subscribeServiceMessages serviceId + Cmd SRecipientService (SUBS n idsHash)-> response . (corrId,entId,) <$> case clntServiceId of + Just serviceId -> subscribeServiceMessages serviceId (n, idsHash) Nothing -> pure $ ERR INTERNAL -- it's "internal" because it should never get to this branch where createQueue :: NewQueueReq -> M s (Transmission BrokerMsg) @@ -1795,9 +1796,9 @@ client TM.insert entId sub $ clientSubs clnt pure (False, Just sub) - subscribeServiceMessages :: ServiceId -> M s BrokerMsg - subscribeServiceMessages serviceId = - sharedSubscribeService SRecipientService serviceId subscribers serviceSubscribed serviceSubsCount >>= \case + subscribeServiceMessages :: ServiceId -> (Int64, IdsHash) -> M s BrokerMsg + subscribeServiceMessages serviceId expected = + sharedSubscribeService SRecipientService serviceId expected subscribers serviceSubscribed serviceSubsCount rcvServices >>= \case Left e -> pure $ ERR e Right (hasSub, (count, idsHash)) -> do unless hasSub $ forkClient clnt "deliverServiceMessages" $ liftIO $ deliverServiceMessages count @@ -1806,7 +1807,7 @@ client deliverServiceMessages expectedCnt = do (qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0) atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, SALL)] - -- TODO [cert rcv] compare with expected + -- TODO [certs rcv] compare with expected logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)" deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int) deliverQueueMsg (!qCnt, !msgCnt, !dupCnt, !errCnt) rId = \case @@ -1831,25 +1832,33 @@ client TM.insert rId sub $ subscriptions clnt pure $ Just sub - subscribeServiceNotifications :: ServiceId -> M s BrokerMsg - subscribeServiceNotifications serviceId = - either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount + subscribeServiceNotifications :: ServiceId -> (Int64, IdsHash) -> M s BrokerMsg + subscribeServiceNotifications serviceId expected = + either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId expected ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount ntfServices - sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> M s (Either ErrorType (Bool, (Int64, IdsHash))) - sharedSubscribeService party serviceId srvSubscribers clientServiceSubscribed clientServiceSubs = do + sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> (Int64, IdsHash) -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> (ServerStats -> ServiceStats) -> M s (Either ErrorType (Bool, (Int64, IdsHash))) + sharedSubscribeService party serviceId (count, idsHash) srvSubscribers clientServiceSubscribed clientServiceSubs servicesSel = do subscribed <- readTVarIO $ clientServiceSubscribed clnt + stats <- asks serverStats liftIO $ runExceptT $ (subscribed,) <$> if subscribed - then (,B.empty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash + then (,mempty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash else do - count' <- ExceptT $ getServiceQueueCount @(StoreQueue s) (queueStore ms) party serviceId + (count', idsHash') <- ExceptT $ getServiceQueueCountHash @(StoreQueue s) (queueStore ms) party serviceId incCount <- atomically $ do writeTVar (clientServiceSubscribed clnt) True - count <- swapTVar (clientServiceSubs clnt) count' - pure $ count' - count + currCount <- swapTVar (clientServiceSubs clnt) count' -- TODO [certs rcv] maintain IDs hash here? + pure $ count' - currCount + let incSrvStat sel n = liftIO $ atomicModifyIORef'_ (sel $ servicesSel stats) (+ n) + diff = fromIntegral $ count' - count + if -- TODO [certs rcv] account for not provided counts/hashes (expected n = -1) + | diff == 0 && idsHash == idsHash' -> incSrvStat srvSubOk 1 + | diff > 0 -> incSrvStat srvSubMore 1 >> incSrvStat srvSubMoreTotal diff + | diff < 0 -> incSrvStat srvSubFewer 1 >> incSrvStat srvSubFewerTotal (- diff) + | otherwise -> incSrvStat srvSubDiff 1 atomically $ writeTQueue (subQ srvSubscribers) (CSService serviceId incCount, clientId) - pure (count', B.empty) -- TODO [certs rcv] get IDs hash + pure (count', idsHash') acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) acknowledgeMsg msgId q qr = diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index d9a1ff6ec..89e9f0383 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -355,8 +355,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where {-# INLINE setQueueService #-} getQueueNtfServices = withQS (getQueueNtfServices @(JournalQueue s)) {-# INLINE getQueueNtfServices #-} - getServiceQueueCount = withQS (getServiceQueueCount @(JournalQueue s)) - {-# INLINE getServiceQueueCount #-} + getServiceQueueCountHash = withQS (getServiceQueueCountHash @(JournalQueue s)) + {-# INLINE getServiceQueueCountHash #-} makeQueue_ :: JournalMsgStore s -> RecipientId -> QueueRec -> Lock -> IO (JournalQueue s) makeQueue_ JournalMsgStore {sharedLock} rId qr queueLock = do diff --git a/src/Simplex/Messaging/Server/Prometheus.hs b/src/Simplex/Messaging/Server/Prometheus.hs index 859587b60..e4d6a2774 100644 --- a/src/Simplex/Messaging/Server/Prometheus.hs +++ b/src/Simplex/Messaging/Server/Prometheus.hs @@ -21,6 +21,7 @@ import Simplex.Messaging.Transport (simplexMQVersion) import Simplex.Messaging.Transport.Server (SocketStats (..)) import Simplex.Messaging.Util (tshow) +-- TODO [certs rcv] add service subscriptions and count/hash diffs data ServerMetrics = ServerMetrics { statsData :: ServerStatsData, activeQueueCounts :: PeriodStatCounts, diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index e05719cf6..7caca7669 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -65,6 +65,7 @@ data ServiceRec = ServiceRec serviceCert :: X.CertificateChain, serviceCertHash :: XV.Fingerprint, -- SHA512 hash of long-term service client certificate. See comment for ClientHandshake. serviceCreatedAt :: SystemDate + -- entitiesHash :: IdsHash -- a xor-hash of all associated entities } deriving (Show) diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index 2fabbfa33..eb1ba3b2c 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -524,15 +524,11 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs' in ((serviceId, sNtfs) : ssNtfs, restNtfs) - getServiceQueueCount :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64) - getServiceQueueCount st party serviceId = - E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCount" st $ \db -> - maybeFirstRow' 0 fromOnly $ - DB.query db query (Only serviceId) - where - query = case party of - SRecipientService -> "SELECT count(1) FROM msg_queues WHERE rcv_service_id = ? AND deleted_at IS NULL" - SNotifierService -> "SELECT count(1) FROM msg_queues WHERE ntf_service_id = ? AND deleted_at IS NULL" + getServiceQueueCountHash :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash)) + getServiceQueueCountHash st party serviceId = + E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCountHash" st $ \db -> + maybeFirstRow' (0, mempty) id $ + DB.query db ("SELECT queue_count, queue_ids_hash FROM services WHERE service_id = ? AND service_role = ?") (serviceId, partyServiceRole party) batchInsertServices :: [STMService] -> PostgresQueueStore q -> IO Int64 batchInsertServices services' toStore = @@ -793,6 +789,10 @@ instance ToField C.APublicAuthKey where toField = toField . Binary . C.encodePub instance FromField C.APublicAuthKey where fromField = blobFieldDecoder C.decodePubKey +instance ToField IdsHash where toField (IdsHash s) = toField (Binary s) + +deriving newtype instance FromField IdsHash + instance ToField EncDataBytes where toField (EncDataBytes s) = toField (Binary s) deriving newtype instance FromField EncDataBytes diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs index 7ff8b9862..5a4d470eb 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs @@ -7,6 +7,7 @@ module Simplex.Messaging.Server.QueueStore.Postgres.Migrations where import Data.List (sortOn) import Data.Text (Text) import Simplex.Messaging.Agent.Store.Shared +import Simplex.Messaging.Agent.Store.Postgres.Migrations.Util import Text.RawString.QQ (r) serverSchemaMigrations :: [(String, Text, Maybe Text)] @@ -15,7 +16,8 @@ serverSchemaMigrations = ("20250319_updated_index", m20250319_updated_index, Just down_m20250319_updated_index), ("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links), ("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs), - ("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages) + ("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages), + ("20250915_queue_ids_hash", m20250915_queue_ids_hash, Just down_m20250915_queue_ids_hash) ] -- | The list of migrations in ascending order by date @@ -447,3 +449,139 @@ ALTER TABLE msg_queues DROP TABLE messages; |] + +m20250915_queue_ids_hash :: Text +m20250915_queue_ids_hash = + createXorHashFuncs + <> [r| +ALTER TABLE services + ADD COLUMN queue_count BIGINT NOT NULL DEFAULT 0, + ADD COLUMN queue_ids_hash BYTEA NOT NULL DEFAULT '\x00000000000000000000000000000000'; + +CREATE FUNCTION update_all_aggregates() RETURNS VOID +LANGUAGE plpgsql +AS $$ +BEGIN + WITH acc AS ( + SELECT + s.service_id, + count(1) as q_count, + xor_aggregate(public.digest(CASE WHEN s.service_role = 'M' THEN q.recipient_id ELSE COALESCE(q.notifier_id, '\x00000000000000000000000000000000') END, 'md5')) AS q_ids_hash + FROM services s + JOIN msg_queues q ON (s.service_id = q.rcv_service_id AND s.service_role = 'M') OR (s.service_id = q.ntf_service_id AND s.service_role = 'N') + WHERE q.deleted_at IS NULL + GROUP BY s.service_id + ) + UPDATE services s + SET queue_count = COALESCE(acc.q_count, 0), + queue_ids_hash = COALESCE(acc.q_ids_hash, '\x00000000000000000000000000000000') + FROM acc + WHERE s.service_id = acc.service_id; +END; +$$; + +SELECT update_all_aggregates(); + +CREATE FUNCTION update_aggregates(p_service_id BYTEA, p_role TEXT, p_queue_id BYTEA, p_change BIGINT) RETURNS VOID +LANGUAGE plpgsql +AS $$ +BEGIN + UPDATE services + SET queue_count = queue_count + p_change, + queue_ids_hash = xor_combine(queue_ids_hash, public.digest(p_queue_id, 'md5')) + WHERE service_id = p_service_id AND service_role = p_role; +END; +$$; + +CREATE FUNCTION on_queue_insert() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF NEW.rcv_service_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1); + END IF; + IF NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1); + END IF; + RETURN NEW; +END; +$$; + +CREATE FUNCTION on_queue_delete() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF OLD.deleted_at IS NULL THEN + IF OLD.rcv_service_id IS NOT NULL THEN + PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1); + END IF; + IF OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN + PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1); + END IF; + END IF; + RETURN OLD; +END; +$$; + +CREATE FUNCTION on_queue_update() RETURNS TRIGGER +LANGUAGE plpgsql +AS $$ +BEGIN + IF OLD.deleted_at IS NULL AND OLD.rcv_service_id IS NOT NULL THEN + IF NOT (NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL) THEN + PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1); + ELSIF OLD.rcv_service_id IS DISTINCT FROM NEW.rcv_service_id THEN + PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1); + PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1); + END IF; + ELSIF NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1); + END IF; + + IF OLD.deleted_at IS NULL AND OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN + IF NOT (NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL) THEN + PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1); + ELSIF OLD.ntf_service_id IS DISTINCT FROM NEW.ntf_service_id OR OLD.notifier_id IS DISTINCT FROM NEW.notifier_id THEN + PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1); + PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1); + END IF; + ELSIF NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1); + END IF; + RETURN NEW; +END; +$$; + +CREATE TRIGGER tr_queue_insert +AFTER INSERT ON msg_queues +FOR EACH ROW EXECUTE PROCEDURE on_queue_insert(); + +CREATE TRIGGER tr_queue_delete +AFTER DELETE ON msg_queues +FOR EACH ROW EXECUTE PROCEDURE on_queue_delete(); + +CREATE TRIGGER tr_queue_update +AFTER UPDATE ON msg_queues +FOR EACH ROW EXECUTE PROCEDURE on_queue_update(); + |] + +down_m20250915_queue_ids_hash :: Text +down_m20250915_queue_ids_hash = + [r| +DROP TRIGGER tr_queue_insert ON msg_queues; +DROP TRIGGER tr_queue_delete ON msg_queues; +DROP TRIGGER tr_queue_update ON msg_queues; + +DROP FUNCTION on_queue_insert; +DROP FUNCTION on_queue_delete; +DROP FUNCTION on_queue_update; + +DROP FUNCTION update_aggregates; + +DROP FUNCTION update_all_aggregates; + +ALTER TABLE services + DROP COLUMN queue_count, + DROP COLUMN queue_ids_hash; + |] + <> dropXorHashFuncs diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql b/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql index 433d45473..f0da5272d 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql @@ -104,6 +104,71 @@ $$; +CREATE FUNCTION smp_server.on_queue_delete() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF OLD.deleted_at IS NULL THEN + IF OLD.rcv_service_id IS NOT NULL THEN + PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1); + END IF; + IF OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN + PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1); + END IF; + END IF; + RETURN OLD; +END; +$$; + + + +CREATE FUNCTION smp_server.on_queue_insert() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF NEW.rcv_service_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1); + END IF; + IF NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1); + END IF; + RETURN NEW; +END; +$$; + + + +CREATE FUNCTION smp_server.on_queue_update() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF OLD.deleted_at IS NULL AND OLD.rcv_service_id IS NOT NULL THEN + IF NOT (NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL) THEN + PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1); + ELSIF OLD.rcv_service_id IS DISTINCT FROM NEW.rcv_service_id THEN + PERFORM update_aggregates(OLD.rcv_service_id, 'M', OLD.recipient_id, -1); + PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1); + END IF; + ELSIF NEW.deleted_at IS NULL AND NEW.rcv_service_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.rcv_service_id, 'M', NEW.recipient_id, 1); + END IF; + + IF OLD.deleted_at IS NULL AND OLD.ntf_service_id IS NOT NULL AND OLD.notifier_id IS NOT NULL THEN + IF NOT (NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL) THEN + PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1); + ELSIF OLD.ntf_service_id IS DISTINCT FROM NEW.ntf_service_id OR OLD.notifier_id IS DISTINCT FROM NEW.notifier_id THEN + PERFORM update_aggregates(OLD.ntf_service_id, 'N', OLD.notifier_id, -1); + PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1); + END IF; + ELSIF NEW.deleted_at IS NULL AND NEW.ntf_service_id IS NOT NULL AND NEW.notifier_id IS NOT NULL THEN + PERFORM update_aggregates(NEW.ntf_service_id, 'N', NEW.notifier_id, 1); + END IF; + RETURN NEW; +END; +$$; + + + CREATE FUNCTION smp_server.try_del_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea) LANGUAGE plpgsql AS $$ @@ -225,6 +290,43 @@ $$; +CREATE FUNCTION smp_server.update_aggregates(p_service_id bytea, p_role text, p_queue_id bytea, p_change bigint) RETURNS void + LANGUAGE plpgsql + AS $$ +BEGIN + UPDATE services + SET queue_count = queue_count + p_change, + queue_ids_hash = xor_combine(queue_ids_hash, public.digest(p_queue_id, 'md5')) + WHERE service_id = p_service_id AND service_role = p_role; +END; +$$; + + + +CREATE FUNCTION smp_server.update_all_aggregates() RETURNS void + LANGUAGE plpgsql + AS $$ +BEGIN + WITH acc AS ( + SELECT + s.service_id, + count(1) as q_count, + xor_aggregate(public.digest(CASE WHEN s.service_role = 'M' THEN q.recipient_id ELSE COALESCE(q.notifier_id, '\x00000000000000000000000000000000') END, 'md5')) AS q_ids_hash + FROM services s + JOIN msg_queues q ON (s.service_id = q.rcv_service_id AND s.service_role = 'M') OR (s.service_id = q.ntf_service_id AND s.service_role = 'N') + WHERE q.deleted_at IS NULL + GROUP BY s.service_id + ) + UPDATE services s + SET queue_count = COALESCE(acc.q_count, 0), + queue_ids_hash = COALESCE(acc.q_ids_hash, '\x00000000000000000000000000000000') + FROM acc + WHERE s.service_id = acc.service_id; +END; +$$; + + + CREATE FUNCTION smp_server.write_message(p_recipient_id bytea, p_msg_id bytea, p_msg_ts bigint, p_msg_quota boolean, p_msg_ntf_flag boolean, p_msg_body bytea, p_quota integer) RETURNS TABLE(quota_written boolean, was_empty boolean) LANGUAGE plpgsql AS $$ @@ -256,6 +358,34 @@ END; $$; + +CREATE FUNCTION smp_server.xor_combine(state bytea, value bytea) RETURNS bytea + LANGUAGE plpgsql IMMUTABLE STRICT + AS $$ +DECLARE + result BYTEA := state; + i INTEGER; + len INTEGER := octet_length(value); +BEGIN + IF octet_length(state) != len THEN + RAISE EXCEPTION 'Inputs must be equal length (% != %)', octet_length(state), len; + END IF; + FOR i IN 0..len-1 LOOP + result := set_byte(result, i, get_byte(state, i) # get_byte(value, i)); + END LOOP; + RETURN result; +END; +$$; + + + +CREATE AGGREGATE smp_server.xor_aggregate(bytea) ( + SFUNC = smp_server.xor_combine, + STYPE = bytea, + INITCOND = '\x00000000000000000000000000000000' +); + + SET default_table_access_method = heap; @@ -320,7 +450,9 @@ CREATE TABLE smp_server.services ( service_role text NOT NULL, service_cert bytea NOT NULL, service_cert_hash bytea NOT NULL, - created_at bigint NOT NULL + created_at bigint NOT NULL, + queue_count bigint DEFAULT 0 NOT NULL, + queue_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL ); @@ -390,6 +522,18 @@ CREATE INDEX idx_services_service_role ON smp_server.services USING btree (servi +CREATE TRIGGER tr_queue_delete AFTER DELETE ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_delete(); + + + +CREATE TRIGGER tr_queue_insert AFTER INSERT ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_insert(); + + + +CREATE TRIGGER tr_queue_update AFTER UPDATE ON smp_server.msg_queues FOR EACH ROW EXECUTE FUNCTION smp_server.on_queue_update(); + + + ALTER TABLE ONLY smp_server.messages ADD CONSTRAINT messages_recipient_id_fkey FOREIGN KEY (recipient_id) REFERENCES smp_server.msg_queues(recipient_id) ON UPDATE RESTRICT ON DELETE CASCADE; diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index ad3e00a03..8b64db55a 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -28,6 +28,7 @@ where import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad +import Data.Bifunctor (first) import Data.Bitraversable (bimapM) import Data.Functor (($>)) import Data.Int (Int64) @@ -62,8 +63,8 @@ data STMQueueStore q = STMQueueStore data STMService = STMService { serviceRec :: ServiceRec, - serviceRcvQueues :: TVar (Set RecipientId), - serviceNtfQueues :: TVar (Set NotifierId) + serviceRcvQueues :: TVar (Set RecipientId, IdsHash), -- TODO [certs rcv] get/maintain hash + serviceNtfQueues :: TVar (Set NotifierId, IdsHash) -- TODO [certs rcv] get/maintain hash } setStoreLog :: STMQueueStore q -> StoreLog 'WriteMode -> IO () @@ -113,7 +114,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where } where serviceCount role = M.foldl' (\ !n s -> if serviceRole (serviceRec s) == role then n + 1 else n) 0 - serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size <$> readTVarIO (serviceSel s)) 0 + serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size . fst <$> readTVarIO (serviceSel s)) 0 addQueue_ :: STMQueueStore q -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q) addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData, rcvServiceId} = do @@ -304,8 +305,8 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where TM.insert fp newSrvId serviceCerts pure $ Right (newSrvId, True) newSTMService = do - serviceRcvQueues <- newTVar S.empty - serviceNtfQueues <- newTVar S.empty + serviceRcvQueues <- newTVar (S.empty, mempty) + serviceNtfQueues <- newTVar (S.empty, mempty) pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues} setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) @@ -331,7 +332,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where let !q' = Just q {notifier = Just nc {ntfServiceId = serviceId}} updateServiceQueues serviceNtfQueues nId prevNtfSrvId writeTVar qr q' $> Right () - updateServiceQueues :: (STMService -> TVar (Set QueueId)) -> QueueId -> Maybe ServiceId -> STM () + updateServiceQueues :: (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> Maybe ServiceId -> STM () updateServiceQueues serviceSel qId prevSrvId = do mapM_ (removeServiceQueue st serviceSel qId) prevSrvId mapM_ (addServiceQueue st serviceSel qId) serviceId @@ -346,16 +347,16 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where pure $ Right (ssNtfs', deleteNtfs) where addService (ssNtfs, ntfs') (serviceId, s) = do - snIds <- readTVarIO $ serviceNtfQueues s + (snIds, _) <- readTVarIO $ serviceNtfQueues s let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs' pure ((Just serviceId, sNtfs) : ssNtfs, restNtfs) - getServiceQueueCount :: (PartyI p, ServiceParty p) => STMQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64) - getServiceQueueCount st party serviceId = + getServiceQueueCountHash :: (PartyI p, ServiceParty p) => STMQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash)) + getServiceQueueCountHash st party serviceId = TM.lookupIO serviceId (services st) >>= - maybe (pure $ Left AUTH) (fmap (Right . fromIntegral . S.size) . readTVarIO . serviceSel) + maybe (pure $ Left AUTH) (fmap (Right . first (fromIntegral . S.size)) . readTVarIO . serviceSel) where - serviceSel :: STMService -> TVar (Set QueueId) + serviceSel :: STMService -> TVar (Set QueueId, IdsHash) serviceSel = case party of SRecipientService -> serviceRcvQueues SNotifierService -> serviceNtfQueues @@ -366,7 +367,7 @@ foldRcvServiceQueues st serviceId f acc = Nothing -> pure acc Just s -> readTVarIO (serviceRcvQueues s) - >>= foldM (\a -> get >=> maybe (pure a) (f a)) acc + >>= foldM (\a -> get >=> maybe (pure a) (f a)) acc . fst where get rId = TM.lookupIO rId (queues st) $>>= \q -> (q,) <$$> readTVarIO (queueRec q) @@ -379,16 +380,23 @@ setStatus qr status = Just q -> (Right (), Just q {status}) Nothing -> (Left AUTH, Nothing) -addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM () -addServiceQueue st serviceSel qId serviceId = - TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.insert qId)) +addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM () +addServiceQueue = setServiceQueues_ S.insert {-# INLINE addServiceQueue #-} -removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM () -removeServiceQueue st serviceSel qId serviceId = - TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.delete qId)) +removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM () +removeServiceQueue = setServiceQueues_ S.delete {-# INLINE removeServiceQueue #-} +setServiceQueues_ :: (QueueId -> Set QueueId -> Set QueueId) -> STMQueueStore q -> (STMService -> TVar (Set QueueId, IdsHash)) -> QueueId -> ServiceId -> STM () +setServiceQueues_ updateSet st serviceSel qId serviceId = + TM.lookup serviceId (services st) >>= mapM_ (\v -> modifyTVar' (serviceSel v) update) + where + update (s, idsHash) = + let !s' = updateSet qId s + !idsHash' = queueIdHash qId <> idsHash + in (s', idsHash') + removeNotifier :: STMQueueStore q -> NtfCreds -> STM () removeNotifier st NtfCreds {notifierId = nId, ntfServiceId} = do TM.delete nId $ notifiers st diff --git a/src/Simplex/Messaging/Server/QueueStore/Types.hs b/src/Simplex/Messaging/Server/QueueStore/Types.hs index 8de015421..723930e9f 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Types.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Types.hs @@ -47,7 +47,7 @@ class StoreQueueClass q => QueueStoreClass q s where getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId) setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)])) - getServiceQueueCount :: (PartyI p, ServiceParty p) => s -> SParty p -> ServiceId -> IO (Either ErrorType Int64) + getServiceQueueCountHash :: (PartyI p, ServiceParty p) => s -> SParty p -> ServiceId -> IO (Either ErrorType (Int64, IdsHash)) data EntityCounts = EntityCounts { queueCount :: Int, diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index e60f87815..120fad7b6 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -821,7 +821,15 @@ data ServiceStats = ServiceStats srvSubCount :: IORef Int, srvSubDuplicate :: IORef Int, srvSubQueues :: IORef Int, - srvSubEnd :: IORef Int + srvSubEnd :: IORef Int, + -- counts of subscriptions + srvSubOk :: IORef Int, -- server has the same queues as expected + srvSubMore :: IORef Int, -- server has more queues than expected + srvSubFewer :: IORef Int, -- server has fewer queues than expected + srvSubDiff :: IORef Int, -- server has the same count, but different queues than expected (based on xor-hash) + -- adds actual deviations + srvSubMoreTotal :: IORef Int, -- server has more queues than expected, adds diff + srvSubFewerTotal :: IORef Int } data ServiceStatsData = ServiceStatsData @@ -832,7 +840,13 @@ data ServiceStatsData = ServiceStatsData _srvSubCount :: Int, _srvSubDuplicate :: Int, _srvSubQueues :: Int, - _srvSubEnd :: Int + _srvSubEnd :: Int, + _srvSubOk :: Int, + _srvSubMore :: Int, + _srvSubFewer :: Int, + _srvSubDiff :: Int, + _srvSubMoreTotal :: Int, + _srvSubFewerTotal :: Int } deriving (Show) @@ -846,7 +860,13 @@ newServiceStatsData = _srvSubCount = 0, _srvSubDuplicate = 0, _srvSubQueues = 0, - _srvSubEnd = 0 + _srvSubEnd = 0, + _srvSubOk = 0, + _srvSubMore = 0, + _srvSubFewer = 0, + _srvSubDiff = 0, + _srvSubMoreTotal = 0, + _srvSubFewerTotal = 0 } newServiceStats :: IO ServiceStats @@ -859,6 +879,12 @@ newServiceStats = do srvSubDuplicate <- newIORef 0 srvSubQueues <- newIORef 0 srvSubEnd <- newIORef 0 + srvSubOk <- newIORef 0 + srvSubMore <- newIORef 0 + srvSubFewer <- newIORef 0 + srvSubDiff <- newIORef 0 + srvSubMoreTotal <- newIORef 0 + srvSubFewerTotal <- newIORef 0 pure ServiceStats { srvAssocNew, @@ -868,7 +894,13 @@ newServiceStats = do srvSubCount, srvSubDuplicate, srvSubQueues, - srvSubEnd + srvSubEnd, + srvSubOk, + srvSubMore, + srvSubFewer, + srvSubDiff, + srvSubMoreTotal, + srvSubFewerTotal } getServiceStatsData :: ServiceStats -> IO ServiceStatsData @@ -881,6 +913,12 @@ getServiceStatsData s = do _srvSubDuplicate <- readIORef $ srvSubDuplicate s _srvSubQueues <- readIORef $ srvSubQueues s _srvSubEnd <- readIORef $ srvSubEnd s + _srvSubOk <- readIORef $ srvSubOk s + _srvSubMore <- readIORef $ srvSubMore s + _srvSubFewer <- readIORef $ srvSubFewer s + _srvSubDiff <- readIORef $ srvSubDiff s + _srvSubMoreTotal <- readIORef $ srvSubMoreTotal s + _srvSubFewerTotal <- readIORef $ srvSubFewerTotal s pure ServiceStatsData { _srvAssocNew, @@ -890,7 +928,13 @@ getServiceStatsData s = do _srvSubCount, _srvSubDuplicate, _srvSubQueues, - _srvSubEnd + _srvSubEnd, + _srvSubOk, + _srvSubMore, + _srvSubFewer, + _srvSubDiff, + _srvSubMoreTotal, + _srvSubFewerTotal } getResetServiceStatsData :: ServiceStats -> IO ServiceStatsData @@ -903,6 +947,12 @@ getResetServiceStatsData s = do _srvSubDuplicate <- atomicSwapIORef (srvSubDuplicate s) 0 _srvSubQueues <- atomicSwapIORef (srvSubQueues s) 0 _srvSubEnd <- atomicSwapIORef (srvSubEnd s) 0 + _srvSubOk <- atomicSwapIORef (srvSubOk s) 0 + _srvSubMore <- atomicSwapIORef (srvSubMore s) 0 + _srvSubFewer <- atomicSwapIORef (srvSubFewer s) 0 + _srvSubDiff <- atomicSwapIORef (srvSubDiff s) 0 + _srvSubMoreTotal <- atomicSwapIORef (srvSubMoreTotal s) 0 + _srvSubFewerTotal <- atomicSwapIORef (srvSubFewerTotal s) 0 pure ServiceStatsData { _srvAssocNew, @@ -912,7 +962,13 @@ getResetServiceStatsData s = do _srvSubCount, _srvSubDuplicate, _srvSubQueues, - _srvSubEnd + _srvSubEnd, + _srvSubOk, + _srvSubMore, + _srvSubFewer, + _srvSubDiff, + _srvSubMoreTotal, + _srvSubFewerTotal } -- this function is not thread safe, it is used on server start only @@ -926,6 +982,12 @@ setServiceStats s d = do writeIORef (srvSubDuplicate s) $! _srvSubDuplicate d writeIORef (srvSubQueues s) $! _srvSubQueues d writeIORef (srvSubEnd s) $! _srvSubEnd d + writeIORef (srvSubOk s) $! _srvSubOk d + writeIORef (srvSubMore s) $! _srvSubMore d + writeIORef (srvSubFewer s) $! _srvSubFewer d + writeIORef (srvSubDiff s) $! _srvSubDiff d + writeIORef (srvSubMoreTotal s) $! _srvSubMoreTotal d + writeIORef (srvSubFewerTotal s) $! _srvSubFewerTotal d instance StrEncoding ServiceStatsData where strEncode ServiceStatsData {_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd} = @@ -963,7 +1025,13 @@ instance StrEncoding ServiceStatsData where _srvSubCount, _srvSubDuplicate, _srvSubQueues, - _srvSubEnd + _srvSubEnd, + _srvSubOk = 0, + _srvSubMore = 0, + _srvSubFewer = 0, + _srvSubDiff = 0, + _srvSubMoreTotal = 0, + _srvSubFewerTotal = 0 } data TimeBuckets = TimeBuckets diff --git a/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs b/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs index ea6c9ed4a..2fd4ca6d8 100644 --- a/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs +++ b/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs @@ -61,7 +61,7 @@ readQueueStore tty mkQ f st = readLogLines tty f $ \_ -> processLine Left e -> logError $ errPfx <> tshow e where errPfx = "STORE: getCreateService, stored service " <> decodeLatin1 (strEncode serviceId) <> ", " - QueueService rId (ASP party) serviceId -> withQueue rId "QueueService" $ \q -> setQueueService st q party serviceId + QueueService qId (ASP party) serviceId -> withQueue qId "QueueService" $ \q -> setQueueService st q party serviceId printError :: String -> IO () printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s withQueue :: forall a. RecipientId -> T.Text -> (q -> IO (Either ErrorType a)) -> IO () diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs index 63c493861..e142c6177 100644 --- a/tests/AgentTests/EqInstances.hs +++ b/tests/AgentTests/EqInstances.hs @@ -8,6 +8,7 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol (ConnLinkData (..), OwnerAuth (..), UserContactData (..), UserLinkData (..)) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client (ProxiedRelay (..)) +import Simplex.Messaging.Protocol (ServiceSub (..)) instance (Eq rq, Eq sq) => Eq (SomeConn' rq sq) where SomeConn d c == SomeConn d' c' = case testEquality d d' of @@ -47,3 +48,7 @@ deriving instance Eq OwnerAuth deriving instance Show ProxiedRelay deriving instance Eq ProxiedRelay + +deriving instance Show ServiceSub + +deriving instance Eq ServiceSub diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 017958890..7f9641a5b 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -476,6 +476,8 @@ functionalAPITests ps = do testUsersNoServer ps it "should connect two users and switch session mode" $ withSmpServer ps testTwoUsers + describe "Client service certificates" $ do + it "should connect, subscribe and reconnect as a service" $ testClientServiceConnection ps describe "Connection switch" $ do describe "should switch delivery to the new queue" $ testServerMatrix2 ps testSwitchConnection @@ -3664,6 +3666,32 @@ testTwoUsers = withAgentClients2 $ \a b -> do hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO () hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n +testClientServiceConnection :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testClientServiceConnection ps = do + (sId, uId) <- withSmpServerStoreLogOn ps testPort $ \_ -> do + conns@(sId, uId) <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do + conns@(sId, uId) <- makeConnection service user + exchangeGreetings service uId user sId + pure conns + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do + subscribeClientServices service 1 + subscribeConnection user sId + exchangeGreetingsMsgId 4 service uId user sId + pure conns + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeClientServices service 1 + subscribeConnection user sId + exchangeGreetingsMsgId 6 service uId user sId + ("", "", DOWN _ [_]) <- nGet user + -- TODO [certs rcv] how to integrate service counts into stats + -- r <- nGet service -- TODO [certs rcv] some event when service disconnects with count + -- print r + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + ("", "", UP _ [_]) <- nGet user + -- r <- nGet service -- TODO [certs rcv] some event when service reconnects with count + exchangeGreetingsMsgId 8 service uId user sId + getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient getSMPAgentClient' clientId cfg' initServers dbPath = do Right st <- liftIO $ createStore dbPath diff --git a/tests/CoreTests/TSessionSubs.hs b/tests/CoreTests/TSessionSubs.hs index e3f819332..e9038b9d9 100644 --- a/tests/CoreTests/TSessionSubs.hs +++ b/tests/CoreTests/TSessionSubs.hs @@ -58,9 +58,9 @@ testSessionSubs = do atomically (SS.hasPendingSubs tSess2 ss) `shouldReturn` True atomically (SS.batchAddPendingSubs tSess1 [q1, q2] ss') atomically (SS.batchAddPendingSubs tSess2 [q3] ss') - atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1), ("r2", q2)] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` (M.fromList [("r1", q1), ("r2", q2)], Nothing) atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [] - atomically (SS.getPendingSubs tSess2 ss) `shouldReturn` M.fromList [("r3", q3)] + atomically (SS.getPendingSubs tSess2 ss) `shouldReturn` (M.fromList [("r3", q3)], Nothing) st <- dumpSessionSubs ss dumpSessionSubs ss' `shouldReturn` st countSubs ss `shouldReturn` (0, 3) @@ -69,41 +69,41 @@ testSessionSubs = do atomically (SS.hasPendingSub tSess1 (rcvId q4) ss) `shouldReturn` False atomically (SS.hasActiveSub tSess1 (rcvId q4) ss) `shouldReturn` False -- setting active queue without setting session ID would keep it as pending - atomically $ SS.addActiveSub tSess1 "123" q1 ss + atomically $ SS.addActiveSub' tSess1 "123" q1 False ss atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` True atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` False dumpSessionSubs ss `shouldReturn` st countSubs ss `shouldReturn` (0, 3) -- setting active queues atomically $ SS.setSessionId tSess1 "123" ss - atomically $ SS.addActiveSub tSess1 "123" q1 ss + atomically $ SS.addActiveSub' tSess1 "123" q1 False ss atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` False atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` True atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1)] - atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r2", q2)] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` (M.fromList [("r2", q2)], Nothing) countSubs ss `shouldReturn` (1, 2) atomically $ SS.setSessionId tSess2 "456" ss - atomically $ SS.addActiveSub tSess2 "456" q4 ss + atomically $ SS.addActiveSub' tSess2 "456" q4 False ss atomically (SS.hasPendingSub tSess2 (rcvId q4) ss) `shouldReturn` False atomically (SS.hasActiveSub tSess2 (rcvId q4) ss) `shouldReturn` True atomically (SS.hasActiveSub tSess1 (rcvId q4) ss) `shouldReturn` False -- wrong transport session atomically (SS.getActiveSubs tSess2 ss) `shouldReturn` M.fromList [("r4", q4)] - atomically (SS.getPendingSubs tSess2 ss) `shouldReturn` M.fromList [("r3", q3)] + atomically (SS.getPendingSubs tSess2 ss) `shouldReturn` (M.fromList [("r3", q3)], Nothing) countSubs ss `shouldReturn` (2, 2) -- setting pending queues st' <- dumpSessionSubs ss - atomically (SS.setSubsPending TSMUser tSess1 "abc" ss) `shouldReturn` M.empty -- wrong session + atomically (SS.setSubsPending TSMUser tSess1 "abc" ss) `shouldReturn` (M.empty, Nothing) -- wrong session dumpSessionSubs ss `shouldReturn` st' - atomically (SS.setSubsPending TSMUser tSess1 "123" ss) `shouldReturn` M.fromList [("r1", q1)] + atomically (SS.setSubsPending TSMUser tSess1 "123" ss) `shouldReturn` (M.fromList [("r1", q1)], Nothing) atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [] - atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1), ("r2", q2)] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` (M.fromList [("r1", q1), ("r2", q2)], Nothing) countSubs ss `shouldReturn` (1, 3) -- delete subs atomically $ SS.deletePendingSub tSess1 (rcvId q1) ss - atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r2", q2)] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` (M.fromList [("r2", q2)], Nothing) countSubs ss `shouldReturn` (1, 2) atomically $ SS.deleteSub tSess1 (rcvId q2) ss - atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` (M.fromList [], Nothing) countSubs ss `shouldReturn` (1, 1) atomically (SS.getActiveSubs tSess2 ss) `shouldReturn` M.fromList [("r4", q4)] atomically $ SS.deleteSub tSess2 (rcvId q4) ss diff --git a/tests/Fixtures.hs b/tests/Fixtures.hs index 2360a7ba6..f2f314fed 100644 --- a/tests/Fixtures.hs +++ b/tests/Fixtures.hs @@ -3,7 +3,9 @@ module Fixtures where import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as B import Database.PostgreSQL.Simple (ConnectInfo (..), defaultConnectInfo) +import Simplex.Messaging.Agent.Store.Postgres.Options testDBConnstr :: ByteString testDBConnstr = "postgresql://test_agent_user@/test_agent_db" @@ -14,3 +16,6 @@ testDBConnectInfo = connectUser = "test_agent_user", connectDatabase = "test_agent_db" } + +testDBOpts :: String -> DBOpts +testDBOpts schema' = DBOpts testDBConnstr (B.pack schema') 1 True diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 935775050..41aab2039 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -83,6 +83,9 @@ initAgentServersProxy_ smpProxyMode smpProxyFallback = initAgentServersProxy2 :: InitialAgentServers initAgentServersProxy2 = initAgentServersProxy {smp = userServers [testSMPServer2]} +initAgentServersClientService :: InitialAgentServers +initAgentServersClientService = initAgentServers {useServices = M.fromList [(1, True)]} + agentCfg :: AgentConfig agentCfg = defaultAgentConfig diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 39009794c..d3e1b21d0 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -712,15 +712,17 @@ testServiceDeliverSubscribe = pure (rId, sId, dec, serviceId) runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do - Resp "10" NoEntity (ERR (CMD NO_AUTH)) <- signSendRecv sh aServicePK ("10", NoEntity, SUBS) - signSend_ sh aServicePK Nothing ("11", serviceId, SUBS) + let idsHash = queueIdsHash [rId] + Resp "10" NoEntity (ERR (CMD NO_AUTH)) <- signSendRecv sh aServicePK ("10", NoEntity, SUBS 1 idsHash) + signSend_ sh aServicePK Nothing ("11", serviceId, SUBS 1 idsHash) -- TODO [certs rcv] compute and compare hashes [mId3] <- fmap catMaybes $ receiveInAnyOrder -- race between SOKS and MSG, clients can handle it sh [ \case - Resp "11" serviceId' (SOKS n _) -> do + Resp "11" serviceId' (SOKS n idsHash') -> do n `shouldBe` 1 + idsHash' `shouldBe` idsHash serviceId' `shouldBe` serviceId pure $ Just Nothing _ -> pure Nothing, @@ -805,14 +807,16 @@ testServiceUpgradeAndDowngrade = Resp "12" _ OK <- signSendRecv h sKey2 ("12", sId2, _SEND "hello 3.2") runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do - signSend_ sh aServicePK Nothing ("14", serviceId, SUBS) + let idsHash = queueIdsHash [rId, rId2, rId3] + signSend_ sh aServicePK Nothing ("14", serviceId, SUBS 3 idsHash) -- TODO [certs rcv] compute hash [(rKey3_1, rId3_1, mId3_1), (rKey3_2, rId3_2, mId3_2)] <- fmap catMaybes $ receiveInAnyOrder -- race between SOKS and MSG, clients can handle it sh [ \case - Resp "14" serviceId' (SOKS n _) -> do + Resp "14" serviceId' (SOKS n idsHash') -> do n `shouldBe` 3 + idsHash' `shouldBe` idsHash serviceId' `shouldBe` serviceId pure $ Just Nothing _ -> pure Nothing, @@ -835,7 +839,7 @@ testServiceUpgradeAndDowngrade = Resp "17" _ OK <- signSendRecv h sKey ("17", sId, _SEND "hello 4") runSMPClient t $ \sh -> do - Resp "18" _ (ERR SERVICE) <- signSendRecv sh aServicePK ("18", serviceId, SUBS) + Resp "18" _ (ERR SERVICE) <- signSendRecv sh aServicePK ("18", serviceId, SUBS 3 mempty) (Resp "19" rId' (SOK Nothing), Resp "" rId'' (Msg mId4 msg4)) <- signSendRecv2 sh rKey ("19", rId, SUB) rId' `shouldBe` rId rId'' `shouldBe` rId @@ -1366,7 +1370,9 @@ testMessageServiceNotifications = deliverMessage rh rId rKey sh sId sKey nh2 "connection 1" dec deliverMessage rh rId'' rKey'' sh sId'' sKey'' nh2 "connection 2" dec'' -- -- another client makes service subscription - Resp "12" serviceId5 (SOKS 2 _) <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS) + let idsHash = queueIdsHash [nId', nId''] + Resp "12" serviceId5 (SOKS 2 idsHash') <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS 2 idsHash) -- TODO [certs rcv] compute and compare hashes + idsHash' `shouldBe` idsHash serviceId5 `shouldBe` serviceId Resp "" serviceId6 (ENDS 2) <- tGet1 nh2 serviceId6 `shouldBe` serviceId @@ -1389,18 +1395,19 @@ testServiceNotificationsTwoRestarts = (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g (rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g - (rId, rKey, sId, dec, serviceId) <- withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> do + (rId, rKey, sId, dec, nId, serviceId) <- withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> do (sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub let dec = decryptMsgV3 dhShared Resp "0" _ (NID nId _) <- signSendRecv rh rKey ("0", rId, NKEY nPub rcvNtfPubDhKey) testNtfServiceClient t serviceKeys $ \nh -> do Resp "1" _ (SOK (Just serviceId)) <- serviceSignSendRecv nh nKey servicePK ("1", nId, NSUB) deliverMessage rh rId rKey sh sId sKey nh "hello" dec - pure (rId, rKey, sId, dec, serviceId) + pure (rId, rKey, sId, dec, nId, serviceId) + let idsHash = queueIdsHash [nId] threadDelay 250000 withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> testNtfServiceClient t serviceKeys $ \nh -> do - Resp "2.1" serviceId' (SOKS n _) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("2.1", serviceId, NSUBS) + Resp "2.1" serviceId' (SOKS n _) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("2.1", serviceId, NSUBS 1 idsHash) n `shouldBe` 1 Resp "2.2" _ (SOK Nothing) <- signSendRecv rh rKey ("2.2", rId, SUB) serviceId' `shouldBe` serviceId @@ -1408,7 +1415,7 @@ testServiceNotificationsTwoRestarts = threadDelay 250000 withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> testNtfServiceClient t serviceKeys $ \nh -> do - Resp "3.1" _ (SOKS n _) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("3.1", serviceId, NSUBS) + Resp "3.1" _ (SOKS n _) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("3.1", serviceId, NSUBS 1 idsHash) n `shouldBe` 1 Resp "3.2" _ (SOK Nothing) <- signSendRecv rh rKey ("3.2", rId, SUB) deliverMessage rh rId rKey sh sId sKey nh "hello 3" dec diff --git a/tests/Test.hs b/tests/Test.hs index 3e36e192d..260366fc8 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -38,6 +38,8 @@ import XFTPServerTests (xftpServerTests) #if defined(dbPostgres) import Fixtures +import SMPAgentClient (testDB) +import Simplex.Messaging.Agent.Store.Postgres.Migrations.App #else import AgentTests.SchemaDump (schemaDumpTest) #endif @@ -45,13 +47,13 @@ import AgentTests.SchemaDump (schemaDumpTest) #if defined(dbServerPostgres) import NtfServerTests (ntfServerTests) import NtfClient (ntfTestServerDBConnectInfo, ntfTestStoreDBOpts) -import PostgresSchemaDump (postgresSchemaDumpTest) import SMPClient (testServerDBConnectInfo, testStoreDBOpts) import Simplex.Messaging.Notifications.Server.Store.Migrations (ntfServerMigrations) import Simplex.Messaging.Server.QueueStore.Postgres.Migrations (serverMigrations) #endif #if defined(dbPostgres) || defined(dbServerPostgres) +import PostgresSchemaDump (postgresSchemaDumpTest) import SMPClient (postgressBracket) #endif @@ -71,10 +73,6 @@ main = do . before_ (createDirectoryIfMissing False "tests/tmp") . after_ (eventuallyRemove "tests/tmp" 3) $ do --- TODO [postgres] schema dump for postgres -#if !defined(dbPostgres) - describe "Agent SQLite schema dump" schemaDumpTest -#endif describe "Core tests" $ do describe "Batching tests" batchingTests describe "Encoding tests" encodingTests @@ -151,6 +149,17 @@ main = do describe "XFTP agent" xftpAgentTests describe "XRCP" remoteControlTests describe "Server CLIs" cliTests +#if defined(dbPostgres) + around_ (postgressBracket testDBConnectInfo) $ + describe "Agent PostgreSQL schema dump" $ + postgresSchemaDumpTest + appMigrations + ["20250322_short_links"] -- snd_secure and last_broker_ts columns swap order on down migration + (testDBOpts testDB) + "src/Simplex/Messaging/Agent/Store/Postgres/Migrations/agent_postgres_schema.sql" +#else + describe "Agent SQLite schema dump" schemaDumpTest +#endif eventuallyRemove :: FilePath -> Int -> IO () eventuallyRemove path retries = case retries of From 5e9b164f4e81e8a28dd845d392a0a1563ae7dbb8 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 25 Nov 2025 23:17:47 +0000 Subject: [PATCH 03/11] agent: fail when per-connection transport isolation is used with services (#1670) --- src/Simplex/Messaging/Agent.hs | 54 +++++++++++-------- src/Simplex/Messaging/Agent/Client.hs | 3 +- .../Messaging/Agent/Store/AgentStore.hs | 4 -- src/Simplex/Messaging/Client.hs | 2 +- src/Simplex/Messaging/Notifications/Server.hs | 4 +- src/Simplex/Messaging/Server.hs | 2 +- src/Simplex/Messaging/Server/Main.hs | 2 +- src/Simplex/Messaging/Transport.hs | 2 +- tests/AgentTests/FunctionalAPITests.hs | 10 ++-- tests/CoreTests/BatchingTests.hs | 2 +- tests/ServerTests.hs | 6 +-- 11 files changed, 49 insertions(+), 42 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 63516ada4..18bc0afbb 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -194,7 +194,7 @@ import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Interface (closeDBStore, execSQL, getCurrentMigrations) import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration) import qualified Simplex.Messaging.Agent.TSessionSubs as SS -import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) +import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, TransportSessionMode (..), nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -249,13 +249,15 @@ import UnliftIO.STM type AE a = ExceptT AgentErrorType IO a -- | Creates an SMP agent client instance -getSMPAgentClient :: AgentConfig -> InitialAgentServers -> DBStore -> Bool -> IO AgentClient +getSMPAgentClient :: AgentConfig -> InitialAgentServers -> DBStore -> Bool -> AE AgentClient getSMPAgentClient = getSMPAgentClient_ 1 {-# INLINE getSMPAgentClient #-} -getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> DBStore -> Bool -> IO AgentClient -getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp, presetServers} store backgroundMode = - newSMPAgentEnv cfg store >>= runReaderT runAgent +getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> DBStore -> Bool -> AE AgentClient +getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp, netCfg, useServices, presetServers} store backgroundMode = do + -- This error should be prevented in the app + when (any id useServices && sessionMode netCfg == TSMEntity) $ throwE $ CMD PROHIBITED "newAgentClient" + liftIO $ newSMPAgentEnv cfg store >>= runReaderT runAgent where runAgent = do liftIO $ checkServers "SMP" smp >> checkServers "XFTP" xftp @@ -594,18 +596,22 @@ testProtocolServer c nm userId srv = withAgentEnv' c $ case protocolTypeI @p of SPNTF -> runNTFServerTest c nm userId srv -- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network --- TODO [certs rcv] should fail if any user is enabled to use services and per-connection isolation is chosen -setNetworkConfig :: AgentClient -> NetworkConfig -> IO () +setNetworkConfig :: AgentClient -> NetworkConfig -> AE () setNetworkConfig c@AgentClient {useNetworkConfig, proxySessTs} cfg' = do - ts <- getCurrentTime - changed <- atomically $ do - (_, cfg) <- readTVar useNetworkConfig - let changed = cfg /= cfg' - !cfgSlow = slowNetworkConfig cfg' - when changed $ writeTVar useNetworkConfig (cfgSlow, cfg') - when (socksProxy cfg /= socksProxy cfg') $ writeTVar proxySessTs ts - pure changed - when changed $ reconnectAllServers c + ts <- liftIO getCurrentTime + (ok, changed) <- atomically $ do + useServices <- readTVar $ useClientServices c + if any id useServices && sessionMode cfg' == TSMEntity + then pure (False, False) + else do + (_, cfg) <- readTVar useNetworkConfig + let changed = cfg /= cfg' + !cfgSlow = slowNetworkConfig cfg' + when changed $ writeTVar useNetworkConfig (cfgSlow, cfg') + when (socksProxy cfg /= socksProxy cfg') $ writeTVar proxySessTs ts + pure (True, changed) + unless ok $ throwE $ CMD PROHIBITED "setNetworkConfig" + when changed $ liftIO $ reconnectAllServers c setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO () setUserNetworkInfo c@AgentClient {userNetworkInfo, userNetworkUpdated} ni = withAgentEnv' c $ do @@ -772,13 +778,19 @@ deleteUser' c@AgentClient {smpServersStats, xftpServersStats} userId delSMPQueue whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ DEL_USER userId) --- TODO [certs rcv] should fail enabling if per-connection isolation is set setUserService' :: AgentClient -> UserId -> Bool -> AM () setUserService' c userId enable = do - wasEnabled <- liftIO $ fromMaybe False <$> TM.lookupIO userId (useClientServices c) - when (enable /= wasEnabled) $ do - atomically $ TM.insert userId enable $ useClientServices c - unless enable $ withStore' c (`deleteClientServices` userId) + (ok, changed) <- atomically $ do + (cfg, _) <- readTVar $ useNetworkConfig c + if enable && sessionMode cfg == TSMEntity + then pure (False, False) + else do + wasEnabled <- fromMaybe False <$> TM.lookup userId (useClientServices c) + let changed = enable /= wasEnabled + when changed $ TM.insert userId enable $ useClientServices c + pure (True, changed) + unless ok $ throwE $ CMD PROHIBITED "setNetworkConfig" + when (changed && not enable) $ withStore' c (`deleteClientServices` userId) newConnAsync :: ConnectionModeI c => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> AM ConnId newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 68d7ef62b..e4324e088 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -500,7 +500,6 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther deriving (Eq, Show) -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. --- TODO [certs rcv] should fail if both per-connection isolation is set and any users use services newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Map (Maybe SMPServer) (Maybe SystemSeconds) -> Env -> IO AgentClient newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, useServices, presetDomains, presetServers} currentTs notices agentEnv = do let cfg = config agentEnv @@ -749,7 +748,7 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c updateClientService service smp pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} - -- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queue + -- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queues updateClientService service smp = case (service, smpClientService smp) of (Just (_, serviceId_), Just THClientService {serviceId}) | serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index b519f381e..6e42aac9d 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -487,7 +487,6 @@ createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode createNewConn db gVar cData cMode = do fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) --- TODO [certs rcv] store clientServiceId from NewRcvQueue updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) updateNewConnRcv db connId rq subMode = getConn db connId $>>= \case @@ -577,7 +576,6 @@ upgradeRcvConnToDuplex db connId sq = (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq (SomeConn c _) -> pure . Left . SEBadConnType "upgradeRcvConnToDuplex" $ connType c --- TODO [certs rcv] store clientServiceId from NewRcvQueue upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) upgradeSndConnToDuplex db connId rq subMode = getConn db connId >>= \case @@ -585,7 +583,6 @@ upgradeSndConnToDuplex db connId rq subMode = Right (SomeConn c _) -> pure . Left . SEBadConnType "upgradeSndConnToDuplex" $ connType c _ -> pure $ Left SEConnNotFound --- TODO [certs rcv] store clientServiceId from NewRcvQueue addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) addConnRcvQueue db connId rq subMode = getConn db connId >>= \case @@ -2500,7 +2497,6 @@ toRcvQueue (Just shortLinkId, Just shortLinkKey, Just linkPrivSigKey, Just linkEncFixedData) -> Just ShortLinkCreds {shortLinkId, shortLinkKey, linkPrivSigKey, linkEncFixedData} _ -> Nothing enableNtfs = maybe True unBI enableNtfs_ - -- TODO [certs rcv] read client service in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, rcvServiceAssoc, status, enableNtfs, clientNoticeId, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} -- | returns all connection queue credentials, the first queue is the primary one diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 58ffd1418..4d4086cfd 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -781,7 +781,7 @@ temporaryClientError = \case smpClientServiceError :: SMPClientError -> Bool smpClientServiceError = \case PCEServiceUnavailable -> True - PCETransportError (TEHandshake BAD_SERVICE) -> True -- TODO [certs] this error may be temporary, so we should possibly resubscribe. + PCETransportError (TEHandshake BAD_SERVICE) -> True -- TODO [certs rcv] this error may be temporary, so we should possibly resubscribe. PCEProtocolError SERVICE -> True PCEProtocolError (PROXY (BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen. _ -> False diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index f06e9c7b1..143d417c6 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -573,7 +573,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = forM_ (L.nonEmpty $ mapMaybe (\(nId, err) -> (nId,) <$> queueSubErrorStatus err) $ L.toList errs) $ \subStatuses -> do updated <- batchUpdateSrvSubErrors st srv subStatuses logSubErrors srv subStatuses updated - -- TODO [certs] resubscribe queues with statuses NSErr and NSService + -- TODO [certs rcv] resubscribe queues with statuses NSErr and NSService CAServiceDisconnected srv serviceSub -> logNote $ "SMP server service disconnected " <> showService srv serviceSub CAServiceSubscribed srv serviceSub@(ServiceSub _ expected _) (ServiceSub _ n _) -- TODO [certs rcv] compare hash @@ -603,7 +603,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = queueSubErrorStatus :: SMPClientError -> Maybe NtfSubStatus queueSubErrorStatus = \case PCEProtocolError AUTH -> Just NSAuth - -- TODO [certs] we could allow making individual subscriptions within service session to handle SERVICE error. + -- TODO [certs rcv] we could allow making individual subscriptions within service session to handle SERVICE error. -- This would require full stack changes in SMP server, SMP client and SMP service agent. PCEProtocolError SERVICE -> Just NSService PCEProtocolError e -> updateErr "SMP error " e diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index a05743a06..0598f3c53 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -923,7 +923,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt putSubscribersInfo protoName ServerSubscribers {queueSubscribers, subClients} showIds = do activeSubs <- getSubscribedClients queueSubscribers hPutStrLn h $ protoName <> " subscriptions: " <> show (M.size activeSubs) - -- TODO [certs] service subscriptions + -- TODO [certs rcv] service subscriptions clnts <- countSubClients activeSubs hPutStrLn h $ protoName <> " subscribed clients: " <> show (IS.size clnts) <> (if showIds then " " <> show (IS.toList clnts) else "") clnts' <- readTVarIO subClients diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 64d18088d..7de966c36 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -556,7 +556,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = mkTransportServerConfig (fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini) (Just $ alpnSupportedSMPHandshakes <> httpALPN) - (fromMaybe True $ iniOnOff "TRANSPORT" "accept_service_credentials" ini), -- TODO [certs] remove this option + (fromMaybe True $ iniOnOff "TRANSPORT" "accept_service_credentials" ini), -- TODO [certs rcv] remove this option controlPort = eitherToMaybe $ T.unpack <$> lookupValue "TRANSPORT" "control_port" ini, smpAgentCfg = defaultSMPClientAgentConfig diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 2d959410d..a14118ce4 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -560,7 +560,7 @@ data SMPClientHandshake = SMPClientHandshake keyHash :: C.KeyHash, -- | pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. authPubKey :: Maybe C.PublicKeyX25519, - -- TODO [certs] remove proxyServer, as serviceInfo includes it as clientRole + -- TODO [certs rcv] remove proxyServer, as serviceInfo includes it as clientRole -- | Whether connecting client is a proxy server (send from SMP v12). -- This property, if True, disables additional transport encrytion inside TLS. -- (Proxy server connection already has additional encryption, so this layer is not needed there). diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 7f9641a5b..f3f7e817c 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -3607,7 +3607,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetings a bId1' b aId1' a `hasClients` 1 b `hasClients` 1 - liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} + setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", UP _ _) <- nGet a @@ -3617,7 +3617,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetingsMsgId 4 a bId1 b aId1 exchangeGreetingsMsgId 4 a bId1' b aId1' liftIO $ threadDelay 250000 - liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} + setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", DOWN _ _) <- nGet a @@ -3632,7 +3632,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetings a bId2' b aId2' a `hasClients` 2 b `hasClients` 1 - liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} + setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", DOWN _ _) <- nGet a @@ -3646,7 +3646,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetingsMsgId 4 a bId2 b aId2 exchangeGreetingsMsgId 4 a bId2' b aId2' liftIO $ threadDelay 250000 - liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} + setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", DOWN _ _) <- nGet a @@ -3695,7 +3695,7 @@ testClientServiceConnection ps = do getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient getSMPAgentClient' clientId cfg' initServers dbPath = do Right st <- liftIO $ createStore dbPath - c <- getSMPAgentClient_ clientId cfg' initServers st False + Right c <- runExceptT $ getSMPAgentClient_ clientId cfg' initServers st False when (dbNew st) $ insertUser st pure c diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index d013c0db4..8a285721b 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -334,7 +334,7 @@ randomSUBv6 = randomSUB_ C.SEd25519 minServerSMPRelayVersion randomSUB :: ByteString -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSUB = randomSUB_ C.SEd25519 currentClientSMPRelayVersion --- TODO [certs] test with the additional certificate signature +-- TODO [certs rcv] test with the additional certificate signature randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSUB_ a v sessId = do g <- C.newRandom diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index d3e1b21d0..dd97781c2 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -714,7 +714,7 @@ testServiceDeliverSubscribe = runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do let idsHash = queueIdsHash [rId] Resp "10" NoEntity (ERR (CMD NO_AUTH)) <- signSendRecv sh aServicePK ("10", NoEntity, SUBS 1 idsHash) - signSend_ sh aServicePK Nothing ("11", serviceId, SUBS 1 idsHash) -- TODO [certs rcv] compute and compare hashes + signSend_ sh aServicePK Nothing ("11", serviceId, SUBS 1 idsHash) [mId3] <- fmap catMaybes $ receiveInAnyOrder -- race between SOKS and MSG, clients can handle it @@ -808,7 +808,7 @@ testServiceUpgradeAndDowngrade = runSMPServiceClient t (tlsCred, serviceKeys) $ \sh -> do let idsHash = queueIdsHash [rId, rId2, rId3] - signSend_ sh aServicePK Nothing ("14", serviceId, SUBS 3 idsHash) -- TODO [certs rcv] compute hash + signSend_ sh aServicePK Nothing ("14", serviceId, SUBS 3 idsHash) [(rKey3_1, rId3_1, mId3_1), (rKey3_2, rId3_2, mId3_2)] <- fmap catMaybes $ receiveInAnyOrder -- race between SOKS and MSG, clients can handle it @@ -1371,7 +1371,7 @@ testMessageServiceNotifications = deliverMessage rh rId'' rKey'' sh sId'' sKey'' nh2 "connection 2" dec'' -- -- another client makes service subscription let idsHash = queueIdsHash [nId', nId''] - Resp "12" serviceId5 (SOKS 2 idsHash') <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS 2 idsHash) -- TODO [certs rcv] compute and compare hashes + Resp "12" serviceId5 (SOKS 2 idsHash') <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS 2 idsHash) idsHash' `shouldBe` idsHash serviceId5 `shouldBe` serviceId Resp "" serviceId6 (ENDS 2) <- tGet1 nh2 From 38e899957f5c5618f46c4c70eb37d4a683d18917 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Thu, 27 Nov 2025 21:37:19 +0000 Subject: [PATCH 04/11] agent: service subscription events (#1671) * agent: use server keyhash when loading service record * agent: process queue/service associations with delayed subscription results * agent: service subscription events --- src/Simplex/Messaging/Agent.hs | 67 ++++++++++--------- src/Simplex/Messaging/Agent/Client.hs | 41 +++++++----- src/Simplex/Messaging/Agent/Protocol.hs | 11 +++ .../Messaging/Agent/Store/AgentStore.hs | 15 +++-- src/Simplex/Messaging/Client.hs | 5 +- src/Simplex/Messaging/Notifications/Server.hs | 2 +- src/Simplex/Messaging/Protocol.hs | 37 +++++++--- src/Simplex/Messaging/Server.hs | 2 +- tests/AgentTests/EqInstances.hs | 5 -- tests/AgentTests/FunctionalAPITests.hs | 16 +++-- tests/SMPProxyTests.hs | 4 +- tests/ServerTests.hs | 4 +- 12 files changed, 125 insertions(+), 84 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 18bc0afbb..18e9d0465 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -194,7 +194,7 @@ import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Interface (closeDBStore, execSQL, getCurrentMigrations) import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration) import qualified Simplex.Messaging.Agent.TSessionSubs as SS -import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, TransportSessionMode (..), nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) +import Simplex.Messaging.Client (NetworkRequestMode (..), ProtocolClientError (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, TransportSessionMode (..), nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -222,6 +222,7 @@ import Simplex.Messaging.Protocol SParty (..), SProtocolType (..), ServiceSub (..), + ServiceSubResult, SndPublicAuthKey, SubscriptionMode (..), UserProtocol, @@ -232,7 +233,7 @@ import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.SystemTime import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (SMPVersion) +import Simplex.Messaging.Transport (SMPVersion, THClientService' (..), THandleAuth (..), THandleParams (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version import Simplex.RemoteControl.Client @@ -502,7 +503,7 @@ resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either Agen resubscribeConnections c = withAgentEnv c . resubscribeConnections' c {-# INLINE resubscribeConnections #-} -subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSub)) +subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSubResult)) subscribeClientServices c = withAgentEnv c . subscribeClientServices' c {-# INLINE subscribeClientServices #-} @@ -1355,11 +1356,7 @@ toConnResult connId rs = case M.lookup connId rs of Just (Left e) -> throwE e _ -> throwE $ INTERNAL $ "no result for connection " <> B.unpack connId -type QCmdResult a = (QueueStatus, Either AgentErrorType a) - -type QDelResult = QCmdResult () - -type QSubResult = QCmdResult (Maybe SMP.ServiceId) +type QCmdResult = (QueueStatus, Either AgentErrorType ()) subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections' _ [] = pure M.empty @@ -1367,16 +1364,15 @@ subscribeConnections' c connIds = subscribeConnections_ c . zip connIds =<< with subscribeConnections_ :: AgentClient -> [(ConnId, Either StoreError SomeConnSub)] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections_ c conns = do - -- TODO [certs rcv] - it should exclude connections already associated, and then if some don't deliver any response they may be unassociated let (subRs, cs) = foldr partitionResultsConns ([], []) conns resumeDelivery cs resumeConnCmds c $ map fst cs + -- queue/service association is handled in the client rcvRs <- lift $ connResults <$> subscribeQueues c False (concatMap rcvQueues cs) - rcvRs' <- storeClientServiceAssocs rcvRs ns <- asks ntfSupervisor - lift $ whenM (liftIO $ hasInstantNotifications ns) . void . forkIO . void $ sendNtfCreate ns rcvRs' cs + lift $ whenM (liftIO $ hasInstantNotifications ns) . void . forkIO . void $ sendNtfCreate ns rcvRs cs -- union is left-biased - let rs = rcvRs' `M.union` subRs + let rs = rcvRs `M.union` subRs notifyResultError rs pure rs where @@ -1400,24 +1396,21 @@ subscribeConnections_ c conns = do _ -> Left $ INTERNAL "unexpected queue status" rcvQueues :: (ConnId, SomeConnSub) -> [RcvQueueSub] rcvQueues (_, SomeConn _ conn) = connRcvQueues conn - connResults :: [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) + connResults :: [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType ()) connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QSubResult -> (RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QSubResult - addResult rs (RcvQueueSub {connId, status}, r) = M.alter (combineRes (status, r)) connId rs + addResult :: Map ConnId QCmdResult -> (RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QCmdResult + addResult rs (RcvQueueSub {connId, status}, r) = M.alter (combineRes (status, () <$ r)) connId rs -- combines two results for one connection, by using only Active queues (if there is at least one Active queue) - combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult + combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' combineRes r' _ = Just r' - order :: QSubResult -> Int + order :: QCmdResult -> Int order (Active, Right _) = 1 order (Active, _) = 2 order (_, Right _) = 3 order _ = 4 - -- TODO [certs rcv] store associations of queues with client service ID - storeClientServiceAssocs :: Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) -> AM (Map ConnId (Either AgentErrorType ())) - storeClientServiceAssocs = pure . M.map (() <$) sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> [(ConnId, SomeConnSub)] -> AM' () sendNtfCreate ns rcvRs cs = do let oks = M.keysSet $ M.filter (either temporaryAgentError $ const True) rcvRs @@ -1522,14 +1515,14 @@ resubscribeConnections' c connIds = do rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs' -- TODO [certs rcv] compare hash. possibly, it should return both expected and returned counts -subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSub)) +subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSubResult)) subscribeClientServices' c userId = ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed" where useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c) subscribe = do srvs <- withStore' c (`getClientServiceServers` userId) - lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c userId srv n idsHash) srvs + lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv n idsHash) srvs -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) @@ -2383,13 +2376,13 @@ deleteConnQueues c nm waitDelivery ntf rqs = do connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QDelResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QDelResult + addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs -- combines two results for one connection, by prioritizing errors in Active queues - combineRes :: QDelResult -> Maybe QDelResult -> Maybe QDelResult + combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' combineRes r' _ = Just r' - order :: QDelResult -> Int + order :: QCmdResult -> Int order (Active, Left _) = 1 order (_, Left _) = 2 order _ = 3 @@ -2840,11 +2833,17 @@ data ACKd = ACKd | ACKPending -- It cannot be finally, as sometimes it needs to be ACK+DEL, -- and sometimes ACK has to be sent from the consumer. processSMPTransmissions :: AgentClient -> ServerTransmissionBatch SMPVersion ErrorType BrokerMsg -> AM' () -processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId, ts) = do +processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), THandleParams {thAuth, sessionId = sessId}, ts) = do upConnIds <- newTVarIO [] + serviceRQs <- newTVarIO ([] :: [RcvQueue]) forM_ ts $ \(entId, t) -> case t of STEvent msgOrErr - | entId == SMP.NoEntity -> pure () -- TODO [certs rcv] process SALL + | entId == SMP.NoEntity -> case msgOrErr of + Right msg -> case msg of + SMP.ALLS -> notifySub c $ SERVICE_ALL srv + SMP.ERR e -> notifyErr "" $ PCEProtocolError e + _ -> logError $ "unexpected event: " <> tshow msg + Left e -> notifyErr "" e | otherwise -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of Right msg -> runProcessSMP rq conn (toConnData conn) msg Left e -> lift $ do @@ -2853,11 +2852,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId STResponse (Cmd SRecipient cmd) respOrErr -> withRcvConn entId $ \rq conn -> case cmd of SMP.SUB -> case respOrErr of - Right SMP.OK -> liftIO $ processSubOk rq upConnIds - -- TODO [certs rcv] associate queue with the service - Right (SMP.SOK _serviceId_) -> liftIO $ processSubOk rq upConnIds + Right SMP.OK -> liftIO $ processSubOk rq upConnIds serviceRQs Nothing + Right (SMP.SOK serviceId_) -> liftIO $ processSubOk rq upConnIds serviceRQs serviceId_ Right msg@SMP.MSG {} -> do - liftIO $ processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails + liftIO $ processSubOk rq upConnIds serviceRQs Nothing -- the connection is UP even when processing this particular message fails runProcessSMP rq conn (toConnData conn) msg Right r -> lift $ processSubErr rq $ unexpectedResponse r Left e -> lift $ unless (temporaryClientError e) $ processSubErr rq e -- timeout/network was already reported @@ -2873,6 +2871,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId unless (null connIds) $ do notify' "" $ UP srv connIds atomically $ incSMPServerStat' c userId srv connSubscribed $ length connIds + readTVarIO serviceRQs >>= processRcvServiceAssocs c where withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' () withRcvConn rId a = do @@ -2882,11 +2881,13 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId tryAllErrors' (a rq conn) >>= \case Left e -> notify' connId (ERR e) Right () -> pure () - processSubOk :: RcvQueue -> TVar [ConnId] -> IO () - processSubOk rq@RcvQueue {connId} upConnIds = + processSubOk :: RcvQueue -> TVar [ConnId] -> TVar [RcvQueue] -> Maybe SMP.ServiceId -> IO () + processSubOk rq@RcvQueue {connId} upConnIds serviceRQs serviceId_ = atomically . whenM (isPendingSub rq) $ do SS.addActiveSub tSess sessId rq $ currentSubs c modifyTVar' upConnIds (connId :) + when (isJust serviceId_ && serviceId_ == clientServiceId_) $ modifyTVar' serviceRQs (rq :) + clientServiceId_ = (\THClientService {serviceId} -> serviceId) <$> (clientService =<< thAuth) processSubErr :: RcvQueue -> SMPClientError -> AM' () processSubErr rq@RcvQueue {connId} e = do atomically . whenM (isPendingSub rq) $ diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index e4324e088..77d73027d 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -50,6 +50,7 @@ module Simplex.Messaging.Agent.Client subscribeQueues, subscribeUserServerQueues, subscribeClientService, + processRcvServiceAssocs, processClientNotices, getQueueMessage, decryptSMPMessage, @@ -280,6 +281,7 @@ import Simplex.Messaging.Protocol SMPMsgMeta (..), SProtocolType (..), ServiceSub (..), + ServiceSubResult (..), SndPublicAuthKey, SubscriptionMode (..), NewNtfCreds (..), @@ -292,6 +294,7 @@ import Simplex.Messaging.Protocol XFTPServerWithAuth, pattern NoEntity, senderCanSecure, + serviceSubResult, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Protocol.Types @@ -785,6 +788,7 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess serverDown (qs, conns, serviceSub_) = whenM (readTVarIO active) $ do notifySub c $ hostEvent' DISCONNECT client unless (null conns) $ notifySub c $ DOWN srv conns + mapM_ (notifySub c . SERVICE_DOWN srv) serviceSub_ unless (null qs && isNothing serviceSub_) $ do releaseGetLocksIO c qs mode <- getSessionModeIO c @@ -1514,7 +1518,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>) processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> Maybe ServiceId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM ([RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)]) -processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do +processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do pending <- SS.getPendingSubs tSess $ currentSubs c let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pending) (M.empty, ([], []), [], 0) rs unless (M.null failed) $ do @@ -1541,10 +1545,10 @@ processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do | otherwise -> (failed', subscribed, notices, ignored) where failed' = M.insert rcvId e failed - Right serviceId_ + Right serviceId_' | rcvId `M.member` pendingSubs -> - let subscribed' = case (smpServiceId, serviceId_, pendingSS) of - (Just sId, Just sId', Just ServiceSub {serviceId}) | sId == sId' && sId == serviceId -> (qs, rq : sQs) + let subscribed' = case (serviceId_, serviceId_', pendingSS) of + (Just sId, Just sId', Just ServiceSub {smpServiceId}) | sId == sId' && sId == smpServiceId -> (qs, rq : sQs) _ -> (rq : qs, sQs) in (failed, subscribed', notices', ignored) | otherwise -> (failed, subscribed, notices', ignored + 1) @@ -1692,7 +1696,8 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c sessId = sessionId $ thParams smp smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp -processRcvServiceAssocs :: AgentClient -> [RcvQueueSub] -> AM' () +processRcvServiceAssocs :: SMPQueue q => AgentClient -> [q] -> AM' () +processRcvServiceAssocs _ [] = pure () processRcvServiceAssocs c serviceQs = withStore' c (`setRcvServiceAssocs` serviceQs) `catchAllErrors'` \e -> do logError $ "processClientNotices error: " <> tshow e @@ -1709,17 +1714,16 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do logError $ "processClientNotices error: " <> tshow e notifySub' c "" $ ERR e -resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSub -resubscribeClientService c tSess (ServiceSub _ n idsHash) = - withServiceClient c tSess $ \smp _ -> do - subscribeClientService_ c tSess smp n idsHash +resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSubResult +resubscribeClientService c tSess serviceSub = + withServiceClient c tSess $ \smp _ -> subscribeClientService_ c True tSess smp serviceSub -subscribeClientService :: AgentClient -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSub -subscribeClientService c userId srv n idsHash = +subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSubResult +subscribeClientService c withEvent userId srv n idsHash = withServiceClient c tSess $ \smp smpServiceId -> do let serviceSub = ServiceSub smpServiceId n idsHash atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c - subscribeClientService_ c tSess smp n idsHash + subscribeClientService_ c withEvent tSess smp serviceSub where tSess = (userId, srv, Nothing) @@ -1730,14 +1734,15 @@ withServiceClient c tSess action = Just smpServiceId -> action smp smpServiceId Nothing -> throwE PCEServiceUnavailable -subscribeClientService_ :: AgentClient -> SMPTransportSession -> SMPClient -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub -subscribeClientService_ c tSess smp n idsHash = do - -- TODO [certs rcv] handle error - serviceSub' <- subscribeService smp SMP.SRecipientService n idsHash +subscribeClientService_ :: AgentClient -> Bool -> SMPTransportSession -> SMPClient -> ServiceSub -> ExceptT SMPClientError IO ServiceSubResult +subscribeClientService_ c withEvent tSess@(_, srv, _) smp expected@(ServiceSub _ n idsHash) = do + subscribed <- subscribeService smp SMP.SRecipientService n idsHash let sessId = sessionId $ thParams smp + r = serviceSubResult expected subscribed atomically $ whenM (activeClientSession c tSess sessId) $ - SS.setActiveServiceSub tSess sessId serviceSub' $ currentSubs c - pure serviceSub' + SS.setActiveServiceSub tSess sessId subscribed $ currentSubs c + when withEvent $ notifySub c $ SERVICE_UP srv r + pure r activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 15d51aed9..d5b35611b 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -234,6 +234,8 @@ import Simplex.Messaging.Protocol NMsgMeta, ProtocolServer (..), QueueMode (..), + ServiceSub, + ServiceSubResult, SMPClientVersion, SMPServer, SMPServerWithAuth, @@ -388,6 +390,9 @@ data AEvent (e :: AEntity) where DISCONNECT :: AProtocolType -> TransportHost -> AEvent AENone DOWN :: SMPServer -> [ConnId] -> AEvent AENone UP :: SMPServer -> [ConnId] -> AEvent AENone + SERVICE_ALL :: SMPServer -> AEvent AENone -- all service messages are delivered + SERVICE_DOWN :: SMPServer -> ServiceSub -> AEvent AENone + SERVICE_UP :: SMPServer -> ServiceSubResult -> AEvent AENone SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> AEvent AEConn RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> AEvent AEConn SENT :: AgentMsgId -> Maybe SMPServer -> AEvent AEConn @@ -459,6 +464,9 @@ data AEventTag (e :: AEntity) where DISCONNECT_ :: AEventTag AENone DOWN_ :: AEventTag AENone UP_ :: AEventTag AENone + SERVICE_ALL_ :: AEventTag AENone + SERVICE_DOWN_ :: AEventTag AENone + SERVICE_UP_ :: AEventTag AENone SWITCH_ :: AEventTag AEConn RSYNC_ :: AEventTag AEConn SENT_ :: AEventTag AEConn @@ -514,6 +522,9 @@ aEventTag = \case DISCONNECT {} -> DISCONNECT_ DOWN {} -> DOWN_ UP {} -> UP_ + SERVICE_ALL _ -> SERVICE_ALL_ + SERVICE_DOWN {} -> SERVICE_DOWN_ + SERVICE_UP {} -> SERVICE_UP_ SWITCH {} -> SWITCH_ RSYNC {} -> RSYNC_ SENT {} -> SENT_ diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 6e42aac9d..a732d28d4 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -419,18 +419,19 @@ createClientService db userId srv (kh, (cert, pk)) = do |] (userId, host srv, port srv, serverKeyHash_, kh, cert, pk) --- TODO [certs rcv] get correct service based on key hash of the server getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) getClientService db userId srv = maybeFirstRow toService $ DB.query db [sql| - SELECT service_cert_hash, service_cert, service_priv_key, service_id - FROM client_services - WHERE user_id = ? AND host = ? AND port = ? + SELECT c.service_cert_hash, c.service_cert, c.service_priv_key, c.service_id + FROM client_services c + JOIN servers s ON c.host = s.host AND c.port = s.port + WHERE c.user_id = ? AND c.host = ? AND c.port = ? + AND COALESCE(c.server_key_hash, s.key_hash) = ? |] - (userId, host srv, port srv) + (userId, host srv, port srv, keyHash srv) where toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) @@ -2250,12 +2251,12 @@ getUserServerRcvQueueSubs db userId srv onlyNeeded = unsetQueuesToSubscribe :: DB.Connection -> IO () unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" -setRcvServiceAssocs :: DB.Connection -> [RcvQueueSub] -> IO () +setRcvServiceAssocs :: SMPQueue q => DB.Connection -> [q] -> IO () setRcvServiceAssocs db rqs = #if defined(dbPostgres) DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs) #else - DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = " $ map (Only . queueId) rqs + DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs #endif -- * getConn helpers diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 4d4086cfd..81e9820a2 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -251,7 +251,7 @@ type ClientCommand msg = (EntityId, Maybe C.APrivateAuthKey, ProtoCommand msg) -- | Type synonym for transmission from SPM servers. -- Batch response is presented as a single `ServerTransmissionBatch` tuple. -type ServerTransmissionBatch v err msg = (TransportSession msg, Version v, SessionId, NonEmpty (EntityId, ServerTransmission err msg)) +type ServerTransmissionBatch v err msg = (TransportSession msg, THandleParams v 'TClient, NonEmpty (EntityId, ServerTransmission err msg)) data ServerTransmission err msg = STEvent (Either (ProtocolClientError err) msg) @@ -864,8 +864,7 @@ writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c [(rId, STEvent (Right msg))]) (msgQ $ client_ c) serverTransmission :: ProtocolClient v err msg -> NonEmpty (RecipientId, ServerTransmission err msg) -> ServerTransmissionBatch v err msg -serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} ts = - (transportSession, thVersion, sessionId, ts) +serverTransmission ProtocolClient {thParams, client_ = PClient {transportSession}} ts = (transportSession, thParams, ts) -- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue -- diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 143d417c6..67ed89d71 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -524,7 +524,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = NtfPushServer {pushQ} <- asks pushServer stats <- asks serverStats liftIO $ forever $ do - ((_, srv@(SMPServer (h :| _) _ _), _), _thVersion, sessionId, ts) <- atomically $ readTBQueue msgQ + ((_, srv@(SMPServer (h :| _) _ _), _), THandleParams {sessionId}, ts) <- atomically $ readTBQueue msgQ forM ts $ \(ntfId, t) -> case t of STUnexpectedError e -> logError $ "SMP client unexpected error: " <> tshow e -- uncorrelated response, should not happen STResponse {} -> pure () -- it was already reported as timeout error diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index c00899e1c..a5f94960e 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -142,6 +142,8 @@ module Simplex.Messaging.Protocol MsgBody, IdsHash (..), ServiceSub (..), + ServiceSubResult (..), + serviceSubResult, queueIdsHash, queueIdHash, MaxMessageLen, @@ -712,7 +714,7 @@ data BrokerMsg where -- v2: MsgId -> SystemTime -> MsgFlags -> MsgBody -> BrokerMsg MSG :: RcvMessage -> BrokerMsg -- sent once delivering messages to SUBS command is complete - SALL :: BrokerMsg + ALLS :: BrokerMsg NID :: NotifierId -> RcvNtfPublicDhKey -> BrokerMsg NMSG :: C.CbNonce -> EncNMsgMeta -> BrokerMsg -- Should include certificate chain @@ -949,7 +951,7 @@ data BrokerMsgTag | SOK_ | SOKS_ | MSG_ - | SALL_ + | ALLS_ | NID_ | NMSG_ | PKEY_ @@ -1042,7 +1044,7 @@ instance Encoding BrokerMsgTag where SOK_ -> "SOK" SOKS_ -> "SOKS" MSG_ -> "MSG" - SALL_ -> "SALL" + ALLS_ -> "ALLS" NID_ -> "NID" NMSG_ -> "NMSG" PKEY_ -> "PKEY" @@ -1064,7 +1066,7 @@ instance ProtocolMsgTag BrokerMsgTag where "SOK" -> Just SOK_ "SOKS" -> Just SOKS_ "MSG" -> Just MSG_ - "SALL" -> Just SALL_ + "ALLS" -> Just ALLS_ "NID" -> Just NID_ "NMSG" -> Just NMSG_ "PKEY" -> Just PKEY_ @@ -1468,10 +1470,29 @@ type MsgId = ByteString type MsgBody = ByteString data ServiceSub = ServiceSub - { serviceId :: ServiceId, + { smpServiceId :: ServiceId, smpQueueCount :: Int64, smpQueueIdsHash :: IdsHash } + deriving (Eq, Show) + +data ServiceSubResult = ServiceSubResult (Maybe ServiceSubError) ServiceSub + deriving (Eq, Show) + +data ServiceSubError + = SSErrorServiceId {expectedServiceId :: ServiceId, subscribedServiceId :: ServiceId} + | SSErrorQueueCount {expectedQueueCount :: Int64, subscribedQueueCount :: Int64} + | SSErrorQueueIdsHash {expectedQueueIdsHash :: IdsHash, subscribedQueueIdsHash :: IdsHash} + deriving (Eq, Show) + +serviceSubResult :: ServiceSub -> ServiceSub -> ServiceSubResult +serviceSubResult s s' = ServiceSubResult subError_ s' + where + subError_ + | smpServiceId s /= smpServiceId s' = Just $ SSErrorServiceId (smpServiceId s) (smpServiceId s') + | smpQueueCount s /= smpQueueCount s' = Just $ SSErrorQueueCount (smpQueueCount s) (smpQueueCount s') + | smpQueueIdsHash s /= smpQueueIdsHash s' = Just $ SSErrorQueueIdsHash (smpQueueIdsHash s) (smpQueueIdsHash s') + | otherwise = Nothing newtype IdsHash = IdsHash {unIdsHash :: BS.ByteString} deriving (Eq, Show) @@ -1897,7 +1918,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where | otherwise -> e (SOKS_, ' ', n) MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e (MSG_, ' ', msgId, Tail body) - SALL -> e SALL_ + ALLS -> e ALLS_ NID nId srvNtfDh -> e (NID_, ' ', nId, srvNtfDh) NMSG nmsgNonce encNMsgMeta -> e (NMSG_, ' ', nmsgNonce, encNMsgMeta) PKEY sid vr certKey -> e (PKEY_, ' ', sid, vr, certKey) @@ -1928,7 +1949,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where MSG . RcvMessage msgId <$> bodyP where bodyP = EncRcvMsgBody . unTail <$> smpP - SALL_ -> pure SALL + ALLS_ -> pure ALLS IDS_ | v >= newNtfCredsSMPVersion -> ids smpP smpP smpP smpP | v >= serviceCertsSMPVersion -> ids smpP smpP smpP nothing @@ -1981,7 +2002,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where PONG -> noEntityMsg PKEY {} -> noEntityMsg RRES _ -> noEntityMsg - SALL -> noEntityMsg + ALLS -> noEntityMsg -- other broker responses must have queue ID _ | B.null entId -> Left $ CMD NO_ENTITY diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 0598f3c53..0fc15b3e3 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -1806,7 +1806,7 @@ client where deliverServiceMessages expectedCnt = do (qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0) - atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, SALL)] + atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, ALLS)] -- TODO [certs rcv] compare with expected logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)" deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int) diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs index e142c6177..63c493861 100644 --- a/tests/AgentTests/EqInstances.hs +++ b/tests/AgentTests/EqInstances.hs @@ -8,7 +8,6 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol (ConnLinkData (..), OwnerAuth (..), UserContactData (..), UserLinkData (..)) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client (ProxiedRelay (..)) -import Simplex.Messaging.Protocol (ServiceSub (..)) instance (Eq rq, Eq sq) => Eq (SomeConn' rq sq) where SomeConn d c == SomeConn d' c' = case testEquality d d' of @@ -48,7 +47,3 @@ deriving instance Eq OwnerAuth deriving instance Show ProxiedRelay deriving instance Eq ProxiedRelay - -deriving instance Show ServiceSub - -deriving instance Eq ServiceSub diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index f3f7e817c..cb74bc0b6 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -3668,27 +3668,35 @@ testTwoUsers = withAgentClients2 $ \a b -> do testClientServiceConnection :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testClientServiceConnection ps = do - (sId, uId) <- withSmpServerStoreLogOn ps testPort $ \_ -> do + ((sId, uId), qIdHash) <- withSmpServerStoreLogOn ps testPort $ \_ -> do conns@(sId, uId) <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do conns@(sId, uId) <- makeConnection service user exchangeGreetings service uId user sId pure conns withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do - subscribeClientServices service 1 + [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash)))] <- M.toList <$> subscribeClientServices service 1 + ("", "", SERVICE_ALL _) <- nGet service subscribeConnection user sId exchangeGreetingsMsgId 4 service uId user sId - pure conns + pure (conns, qIdHash) withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do - subscribeClientServices service 1 + [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))] <- M.toList <$> subscribeClientServices service 1 + ("", "", SERVICE_ALL _) <- nGet service + liftIO $ qIdHash' `shouldBe` qIdHash subscribeConnection user sId exchangeGreetingsMsgId 6 service uId user sId ("", "", DOWN _ [_]) <- nGet user + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 qIdHash')) <- nGet service + qIdHash' `shouldBe` qIdHash -- TODO [certs rcv] how to integrate service counts into stats -- r <- nGet service -- TODO [certs rcv] some event when service disconnects with count -- print r withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do ("", "", UP _ [_]) <- nGet user + ("", "", SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash''))) <- nGet service + ("", "", SERVICE_ALL _) <- nGet service + liftIO $ qIdHash'' `shouldBe` qIdHash -- r <- nGet service -- TODO [certs rcv] some event when service reconnects with count exchangeGreetingsMsgId 8 service uId user sId diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 09f20c1dd..0d8ccdf89 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -188,7 +188,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do runExceptT' (proxySMPMessage pc NRMInteractive sess Nothing sndId noMsgFlags msg) `shouldReturn` Right () runExceptT' (proxySMPMessage pc NRMInteractive sess {prSessionId = "bad session"} Nothing sndId noMsgFlags msg) `shouldReturn` Left (ProxyProtocolError $ SMP.PROXY SMP.NO_SESSION) -- receive 1 - (_tSess, _v, _sid, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId, msgBody = EncRcvMsgBody encBody})))]) <- atomically $ readTBQueue msgQ + (_tSess, _, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId, msgBody = EncRcvMsgBody encBody})))]) <- atomically $ readTBQueue msgQ dec msgId encBody `shouldBe` Right msg runExceptT' $ ackSMPMessage rc rPriv rcvId msgId -- secure queue @@ -200,7 +200,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do runExceptT' (proxySMPMessage pc NRMInteractive sess (Just sPriv) sndId noMsgFlags msg') `shouldReturn` Right () ) ( forM_ securedMsgs $ \msg' -> do - (_tSess, _v, _sid, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId = msgId', msgBody = EncRcvMsgBody encBody'})))]) <- atomically $ readTBQueue msgQ + (_tSess, _, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId = msgId', msgBody = EncRcvMsgBody encBody'})))]) <- atomically $ readTBQueue msgQ dec msgId' encBody' `shouldBe` Right msg' runExceptT' $ ackSMPMessage rc rPriv rcvId msgId' ) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index dd97781c2..82a39af39 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -733,7 +733,7 @@ testServiceDeliverSubscribe = pure $ Just $ Just mId3 _ -> pure Nothing ] - Resp "" NoEntity SALL <- tGet1 sh + Resp "" NoEntity ALLS <- tGet1 sh Resp "12" _ OK <- signSendRecv sh rKey ("12", rId, ACK mId3) Resp "14" _ OK <- signSendRecv h sKey ("14", sId, _SEND "hello 4") Resp "" _ (Msg mId4 msg4) <- tGet1 sh @@ -831,7 +831,7 @@ testServiceUpgradeAndDowngrade = pure $ Just $ Just (rKey2, rId2, mId3) _ -> pure Nothing ] - Resp "" NoEntity SALL <- tGet1 sh + Resp "" NoEntity ALLS <- tGet1 sh Resp "15" _ OK <- signSendRecv sh rKey3_1 ("15", rId3_1, ACK mId3_1) Resp "16" _ OK <- signSendRecv sh rKey3_2 ("16", rId3_2, ACK mId3_2) pure () From 2ea9a9a143168f2b04933bf5d6ca13a3f9a170b0 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 5 Dec 2025 20:46:48 +0000 Subject: [PATCH 05/11] agent: finalize initial service subscriptions, remove associations on service ID changes (#1672) * agent: remove service/queue associations when service ID changes * agent: check that service ID in NEW response matches session ID in transport session * agent subscription WIP * test * comment * enable tests * update queries * agent: option to add SQLite aggregates to DB connection (#1673) * agent: add build_relations_vector function to sqlite * update aggregate * use static aggregate * remove relations --------- Co-authored-by: Evgeny Poberezkin * add test, treat BAD_SERVICE as temp error, only remove queue associations on service errors * add packZipWith for backward compatibility with GHC 8.10.7 --------- Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> --- src/Simplex/Messaging/Agent.hs | 45 ++++++-- src/Simplex/Messaging/Agent/Client.hs | 52 ++++++--- .../Messaging/Agent/Store/AgentStore.hs | 107 +++++++++++++++--- src/Simplex/Messaging/Agent/Store/SQLite.hs | 20 ++-- .../Messaging/Agent/Store/SQLite/Common.hs | 11 +- .../Messaging/Agent/Store/SQLite/Util.hs | 48 ++++++++ src/Simplex/Messaging/Client.hs | 2 +- src/Simplex/Messaging/Protocol.hs | 1 + src/Simplex/Messaging/Util.hs | 26 +++++ tests/AgentTests/FunctionalAPITests.hs | 88 +++++++++++--- 10 files changed, 330 insertions(+), 70 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 18e9d0465..f155ce77b 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -153,7 +153,7 @@ import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition -import Data.Either (isRight, partitionEithers, rights) +import Data.Either (fromRight, isRight, partitionEithers, rights) import Data.Foldable (foldl', toList) import Data.Functor (($>)) import Data.Functor.Identity @@ -221,7 +221,6 @@ import Simplex.Messaging.Protocol SMPMsgMeta, SParty (..), SProtocolType (..), - ServiceSub (..), ServiceSubResult, SndPublicAuthKey, SubscriptionMode (..), @@ -1451,7 +1450,23 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do let userSrvs' = case activeUserId_ of Just activeUserId -> sortOn (\(uId, _) -> if uId == activeUserId then 0 else 1 :: Int) userSrvs Nothing -> userSrvs - rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs' + useServices <- readTVarIO $ useClientServices c + -- These options are possible below: + -- 1) services fully disabled: + -- No service subscriptions will be attempted, and existing services and association will remain in in the database, + -- but they will be ignored because of hasService parameter set to False. + -- This approach preserves performance for all clients that do not use services. + -- 2) at least one user ID has services enabled: + -- Service will be loaded for all user/server combinations: + -- a) service is enabled for user ID and service record exists: subscription will be attempted, + -- b) service is disabled and record exists: service record and all associations will be removed, + -- c) service is disabled or no record: no subscription attempt. + -- On successful service subscription, only unassociated queues will be subscribed. + userSrvs'' <- + if any id useServices + then lift $ mapConcurrently (subscribeService useServices) userSrvs' + else pure $ map (,False) userSrvs' + rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs'' let (errs, oks) = partitionEithers rs logInfo $ "subscribed " <> tshow (sum oks) <> " queues" forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map ("",) @@ -1460,21 +1475,31 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do resumeAllCommands c where handleErr = (`catchAllErrors` \e -> notifySub' c "" (ERR e) >> throwE e) - subscribeUserServer :: Int -> TVar Int -> (UserId, SMPServer) -> AM' (Either AgentErrorType Int) - subscribeUserServer maxPending currPending (userId, srv) = do + subscribeService :: Map UserId Bool -> (UserId, SMPServer) -> AM' ((UserId, SMPServer), ServiceAssoc) + subscribeService useServices us@(userId, srv) = fmap ((us,) . fromRight False) $ tryAllErrors' $ do + withStore' c (\db -> getSubscriptionService db userId srv) >>= \case + Just serviceSub -> case M.lookup userId useServices of + Just True -> tryAllErrors (subscribeClientService c True userId srv serviceSub) >>= \case + Left e | clientServiceError e -> unassocQueues $> False + _ -> pure True + _ -> unassocQueues $> False + where + unassocQueues = withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv + _ -> pure False + subscribeUserServer :: Int -> TVar Int -> ((UserId, SMPServer), ServiceAssoc) -> AM' (Either AgentErrorType Int) + subscribeUserServer maxPending currPending ((userId, srv), hasService) = do atomically $ whenM ((maxPending <=) <$> readTVar currPending) retry tryAllErrors' $ do qs <- withStore' c $ \db -> do - qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded - atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction + qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded hasService + unless (null qs) $ atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction pure qs let n = length qs - lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n) + unless (null qs) $ lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n) pure n where subscribe qs = do rs <- subscribeUserServerQueues c userId srv qs - -- TODO [certs rcv] storeClientServiceAssocs store associations of queues with client service ID ns <- asks ntfSupervisor whenM (liftIO $ hasInstantNotifications ns) $ sendNtfCreate ns rs sendNtfCreate :: NtfSupervisor -> [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> AM' () @@ -1522,7 +1547,7 @@ subscribeClientServices' c userId = useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c) subscribe = do srvs <- withStore' c (`getClientServiceServers` userId) - lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv n idsHash) srvs + lift $ M.fromList <$> mapConcurrently (\(srv, serviceSub) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv serviceSub) srvs -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 77d73027d..7acfb0b49 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -120,6 +120,7 @@ module Simplex.Messaging.Agent.Client getAgentSubscriptions, slowNetworkConfig, protocolClientError, + clientServiceError, Worker (..), SessionVar (..), SubscriptionsInfo (..), @@ -303,7 +304,7 @@ import Simplex.Messaging.Session import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) +import Simplex.Messaging.Transport (HandshakeError (..), SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleAuth (..), THandleParams (sessionId, thAuth, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) import Simplex.Messaging.Transport.Client (TransportHost (..)) import Simplex.Messaging.Transport.Credentials import Simplex.Messaging.Util @@ -619,7 +620,7 @@ getServiceCredentials c userId srv = let g = agentDRG c ((C.KeyHash kh, serviceCreds), serviceId_) <- withStore' c $ \db -> - getClientService db userId srv >>= \case + getClientServiceCredentials db userId srv >>= \case Just service -> pure service Nothing -> do cred <- genCredentials g Nothing (25, 24 * 999999) "simplex" @@ -747,15 +748,13 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do ts <- readTVarIO proxySessTs ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs - -- TODO [certs rcv] add service to SS, possibly combine with SS.setSessionId atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c updateClientService service smp pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} - -- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queues updateClientService service smp = case (service, smpClientService smp) of - (Just (_, serviceId_), Just THClientService {serviceId}) - | serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId - | otherwise -> pure () + (Just (_, serviceId_), Just THClientService {serviceId}) -> withStore' c $ \db -> do + setClientServiceId db userId srv serviceId + forM_ serviceId_ $ \sId -> when (sId /= serviceId) $ removeRcvServiceAssocs db userId srv (Just _, Nothing) -> withStore' c $ \db -> deleteClientService db userId srv -- e.g., server version downgrade (Nothing, Just _) -> logError "server returned serviceId without service credentials in request" (Nothing, Nothing) -> pure () @@ -1258,6 +1257,14 @@ protocolClientError protocolError_ host = \case PCEServiceUnavailable {} -> BROKER host NO_SERVICE PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e +-- it is consistent with smpClientServiceError +clientServiceError :: AgentErrorType -> Bool +clientServiceError = \case + BROKER _ NO_SERVICE -> True + SMP _ SMP.SERVICE -> True + SMP _ (SMP.PROXY (SMP.BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen. + _ -> False + data ProtocolTestStep = TSConnect | TSDisconnect @@ -1446,8 +1453,8 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl withClient c nm tSess $ \(SMPConnectedClient smp _) -> do (ntfKeys, ntfCreds) <- liftIO $ mkNtfCreds a g smp (thParams smp,ntfKeys,) <$> createSMPQueue smp nm nonce_ rKeys dhKey auth subMode (queueReqData cqrd) ntfCreds - -- TODO [certs rcv] validate that serviceId is the same as in the client session, fail otherwise - -- possibly, it should allow returning Nothing - it would indicate incorrect old version + let sessServiceId = (\THClientService {serviceId = sId} -> sId) <$> (clientService =<< thAuth thParams') + when (isJust serviceId && serviceId /= sessServiceId) $ logError "incorrect service ID in NEW response" liftIO . logServer "<--" c srv NoEntity $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] shortLink <- mkShortLinkCreds thParams' qik let rq = @@ -1463,7 +1470,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl sndId, queueMode, shortLink, - rcvServiceAssoc = isJust serviceId, + rcvServiceAssoc = isJust serviceId && serviceId == sessServiceId, status = New, enableNtfs, clientNoticeId = Nothing, @@ -1559,6 +1566,8 @@ temporaryAgentError :: AgentErrorType -> Bool temporaryAgentError = \case BROKER _ e -> tempBrokerError e SMP _ (SMP.PROXY (SMP.BROKER e)) -> tempBrokerError e + SMP _ (SMP.STORE _) -> True + NTF _ (SMP.STORE _) -> True XFTP _ XFTP.TIMEOUT -> True PROXY _ _ (ProxyProtocolError (SMP.PROXY (SMP.BROKER e))) -> tempBrokerError e PROXY _ _ (ProxyProtocolError (SMP.PROXY SMP.NO_SESSION)) -> True @@ -1569,6 +1578,7 @@ temporaryAgentError = \case tempBrokerError = \case NETWORK _ -> True TIMEOUT -> True + TRANSPORT (TEHandshake BAD_SERVICE) -> True -- this error is considered temporary because it is DB error _ -> False temporaryOrHostError :: AgentErrorType -> Bool @@ -1715,11 +1725,16 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do notifySub' c "" $ ERR e resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSubResult -resubscribeClientService c tSess serviceSub = - withServiceClient c tSess $ \smp _ -> subscribeClientService_ c True tSess smp serviceSub - -subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSubResult -subscribeClientService c withEvent userId srv n idsHash = +resubscribeClientService c tSess@(userId, srv, _) serviceSub = + withServiceClient c tSess (\smp _ -> subscribeClientService_ c True tSess smp serviceSub) `catchE` \e -> do + when (clientServiceError e) $ do + qs <- withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv + void $ lift $ subscribeUserServerQueues c userId srv qs + throwE e + +-- TODO [certs rcv] update service in the database if it has different ID and re-associate queues, and send event +subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> ServiceSub -> AM ServiceSubResult +subscribeClientService c withEvent userId srv (ServiceSub _ n idsHash) = withServiceClient c tSess $ \smp smpServiceId -> do let serviceSub = ServiceSub smpServiceId n idsHash atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c @@ -1728,14 +1743,15 @@ subscribeClientService c withEvent userId srv n idsHash = tSess = (userId, srv, Nothing) withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a -withServiceClient c tSess action = +withServiceClient c tSess subscribe = withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) -> case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of - Just smpServiceId -> action smp smpServiceId + Just smpServiceId -> subscribe smp smpServiceId Nothing -> throwE PCEServiceUnavailable +-- TODO [certs rcv] send subscription error event? subscribeClientService_ :: AgentClient -> Bool -> SMPTransportSession -> SMPClient -> ServiceSub -> ExceptT SMPClientError IO ServiceSubResult -subscribeClientService_ c withEvent tSess@(_, srv, _) smp expected@(ServiceSub _ n idsHash) = do +subscribeClientService_ c withEvent tSess@(userId, srv, _) smp expected@(ServiceSub _ n idsHash) = do subscribed <- subscribeService smp SMP.SRecipientService n idsHash let sessId = sessionId $ thParams smp r = serviceSubResult expected subscribed diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index a732d28d4..0d0b2af70 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -37,7 +37,9 @@ module Simplex.Messaging.Agent.Store.AgentStore -- * Client services createClientService, - getClientService, + getClientServiceCredentials, + getSubscriptionServices, + getSubscriptionService, getClientServiceServers, setClientServiceId, deleteClientService, @@ -52,8 +54,10 @@ module Simplex.Messaging.Agent.Store.AgentStore updateClientNotices, getSubscriptionServers, getUserServerRcvQueueSubs, + unassocUserServerRcvQueueSubs, unsetQueuesToSubscribe, setRcvServiceAssocs, + removeRcvServiceAssocs, getConnIds, getConn, getDeletedConn, @@ -419,8 +423,8 @@ createClientService db userId srv (kh, (cert, pk)) = do |] (userId, host srv, port srv, serverKeyHash_, kh, cert, pk) -getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) -getClientService db userId srv = +getClientServiceCredentials :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) +getClientServiceCredentials db userId srv = maybeFirstRow toService $ DB.query db @@ -435,21 +439,41 @@ getClientService db userId srv = where toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) -getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)] -getClientServiceServers db userId = - map toServer - <$> DB.query +getSubscriptionServices :: DB.Connection -> IO [(UserId, (SMPServer, ServiceSub))] +getSubscriptionServices db = map toUserService <$> DB.query_ db clientServiceQuery + where + toUserService (Only userId :. serviceRow) = (userId, toServerService serviceRow) + +getSubscriptionService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ServiceSub) +getSubscriptionService db userId (SMPServer h p kh) = + maybeFirstRow toService $ + DB.query db [sql| - SELECT c.host, c.port, s.key_hash, c.service_id, c.service_queue_count, c.service_queue_ids_hash + SELECT c.service_id, c.service_queue_count, c.service_queue_ids_hash FROM client_services c JOIN servers s ON s.host = c.host AND s.port = c.port - WHERE c.user_id = ? + WHERE c.user_id = ? AND c.host = ? AND c.port = ? AND COALESCE(c.server_key_hash, s.key_hash) = ? |] - (Only userId) + (userId, h, p, kh) where - toServer (host, port, kh, serviceId, n, Binary idsHash) = - (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) + toService (serviceId, qCnt, idsHash) = ServiceSub serviceId qCnt idsHash + +getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)] +getClientServiceServers db userId = + map toServerService <$> DB.query db (clientServiceQuery <> " WHERE c.user_id = ?") (Only userId) + +clientServiceQuery :: Query +clientServiceQuery = + [sql| + SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.service_id, c.service_queue_count, c.service_queue_ids_hash + FROM client_services c + JOIN servers s ON s.host = c.host AND s.port = c.port + |] + +toServerService :: (NonEmpty TransportHost, ServiceName, C.KeyHash, ServiceId, Int64, Binary ByteString) -> (ProtocolServer 'PSMP, ServiceSub) +toServerService (host, port, kh, serviceId, n, Binary idsHash) = + (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO () setClientServiceId db userId srv serviceId = @@ -473,7 +497,9 @@ deleteClientService db userId srv = (userId, host srv, port srv) deleteClientServices :: DB.Connection -> UserId -> IO () -deleteClientServices db userId = DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId) +deleteClientServices db userId = do + DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId) + removeUserRcvServiceAssocs db userId createConn_ :: TVar ChaChaDRG -> @@ -2236,17 +2262,36 @@ getSubscriptionServers db onlyNeeded = toUserServer :: (UserId, NonEmpty TransportHost, ServiceName, C.KeyHash) -> (UserId, SMPServer) toUserServer (userId, host, port, keyHash) = (userId, SMPServer host port keyHash) -getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> IO [RcvQueueSub] -getUserServerRcvQueueSubs db userId srv onlyNeeded = +-- TODO [certs rcv] check index for getting queues with service present +getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> ServiceAssoc -> IO [RcvQueueSub] +getUserServerRcvQueueSubs db userId srv onlyNeeded hasService = map toRcvQueueSub <$> DB.query db - (rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?") + (rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?" <> serviceCond) (userId, host srv, port srv) where toSubscribe | onlyNeeded = " WHERE q.to_subscribe = 1 AND " | otherwise = " WHERE " + serviceCond + | hasService = " AND q.rcv_service_assoc = 0" + | otherwise = "" + +unassocUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> IO [RcvQueueSub] +unassocUserServerRcvQueueSubs db userId (SMPServer h p kh) = + map toRcvQueueSub + <$> DB.query + db + (removeRcvAssocsQuery <> " " <> returningColums) + (h, p, userId, kh) + where + returningColums = + [sql| + RETURNING c.user_id, rcv_queues.conn_id, rcv_queues.host, rcv_queues.port, COALESCE(rcv_queues.server_key_hash, s.key_hash), + rcv_queues.rcv_id, rcv_queues.rcv_private_key, rcv_queues.status, c.enable_ntfs, rcv_queues.client_notice_id, + rcv_queues.rcv_queue_id, rcv_queues.rcv_primary, rcv_queues.replace_rcv_queue_id + |] unsetQueuesToSubscribe :: DB.Connection -> IO () unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" @@ -2259,6 +2304,36 @@ setRcvServiceAssocs db rqs = DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs #endif +removeRcvServiceAssocs :: DB.Connection -> UserId -> SMPServer -> IO () +removeRcvServiceAssocs db userId (SMPServer h p kh) = DB.execute db removeRcvAssocsQuery (h, p, userId, kh) + +removeRcvAssocsQuery :: Query +removeRcvAssocsQuery = + [sql| + UPDATE rcv_queues + SET rcv_service_assoc = 0 + FROM connections c, servers s + WHERE rcv_queues.host = ? + AND rcv_queues.port = ? + AND c.conn_id = rcv_queues.conn_id + AND c.user_id = ? + AND s.host = rcv_queues.host + AND s.port = rcv_queues.port + AND COALESCE(rcv_queues.server_key_hash, s.key_hash) = ? + |] + +removeUserRcvServiceAssocs :: DB.Connection -> UserId -> IO () +removeUserRcvServiceAssocs db userId = + DB.execute + db + [sql| + UPDATE rcv_queues + SET rcv_service_assoc = 0 + FROM connections c + WHERE c.conn_id = rcv_queues.conn_id AND c.user_id = ? + |] + (Only userId) + -- * getConn helpers getConnIds :: DB.Connection -> IO [ConnId] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index d5b8f8290..a670dd3e2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -67,10 +67,10 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Agent.Store.SQLite.Common import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Agent.Store.SQLite.Util (SQLiteFunc, createStaticFunction, mkSQLiteFunc) +import Simplex.Messaging.Agent.Store.SQLite.Util import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..)) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Util (ifM, safeDecodeUtf8) +import Simplex.Messaging.Util (ifM, packZipWith, safeDecodeUtf8) import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.FilePath (takeDirectory, takeFileName, ()) @@ -116,9 +116,7 @@ connectDB path functions key track = do -- _printPragmas db path pure db where - functions' = SQLiteFuncDef "simplex_xor_md5_combine" 2 True sqliteXorMd5CombinePtr : functions prepare db = do - let db' = SQL.connectionHandle $ DB.conn db unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";" SQLite3.exec db' . fromQuery $ [sql| @@ -128,9 +126,14 @@ connectDB path functions key track = do PRAGMA secure_delete = ON; PRAGMA auto_vacuum = FULL; |] - forM_ functions' $ \SQLiteFuncDef {funcName, argCount, deterministic, funcPtr} -> - createStaticFunction db' funcName argCount deterministic funcPtr - >>= either (throwIO . userError . show) pure + mapM_ addFunction functions' + where + db' = SQL.connectionHandle $ DB.conn db + functions' = SQLiteFuncDef "simplex_xor_md5_combine" 2 (SQLiteFuncPtr True sqliteXorMd5CombinePtr) : functions + addFunction SQLiteFuncDef {funcName, argCount, funcPtrs} = + either (throwIO . userError . show) pure =<< case funcPtrs of + SQLiteFuncPtr isDet funcPtr -> createStaticFunction db' funcName argCount isDet funcPtr + SQLiteAggrPtrs stepPtr finalPtr -> createStaticAggregate db' funcName argCount stepPtr finalPtr foreign export ccall "simplex_xor_md5_combine" sqliteXorMd5Combine :: SQLiteFunc @@ -143,7 +146,8 @@ sqliteXorMd5Combine = mkSQLiteFunc $ \cxt args -> do SQLite3.funcResultBlob cxt $ xorMd5Combine idsHash rId xorMd5Combine :: ByteString -> ByteString -> ByteString -xorMd5Combine idsHash rId = B.packZipWith xor idsHash $ C.md5Hash rId +xorMd5Combine idsHash rId = packZipWith xor idsHash $ C.md5Hash rId +{-# INLINE xorMd5Combine #-} closeDBStore :: DBStore -> IO () closeDBStore st@DBStore {dbClosed} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 0634360a2..448c885f2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -7,6 +7,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Common ( DBStore (..), DBOpts (..), SQLiteFuncDef (..), + SQLiteFuncPtrs (..), withConnection, withConnection', withTransaction, @@ -55,14 +56,18 @@ data DBOpts = DBOpts track :: DB.TrackQueries } --- e.g. `SQLiteFuncDef "name" 2 True f` +-- e.g. `SQLiteFuncDef "func_name" 2 (SQLiteFuncPtr True func)` +-- or `SQLiteFuncDef "aggr_name" 3 (SQLiteAggrPtrs step final)` data SQLiteFuncDef = SQLiteFuncDef { funcName :: ByteString, argCount :: CArgCount, - deterministic :: Bool, - funcPtr :: FunPtr SQLiteFunc + funcPtrs :: SQLiteFuncPtrs } +data SQLiteFuncPtrs + = SQLiteFuncPtr {deterministic :: Bool, funcPtr :: FunPtr SQLiteFunc} + | SQLiteAggrPtrs {stepPtr :: FunPtr SQLiteFunc, finalPtr :: FunPtr SQLiteFuncFinal} + withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a withConnectionPriority DBStore {dbSem, dbConnection} priority action | priority = E.bracket_ signal release $ withMVar dbConnection action diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs index a3c3b94ac..2cbd7ecff 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs @@ -3,16 +3,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Util where import Control.Exception (SomeException, catch, mask_) import Data.ByteString (ByteString) import qualified Data.ByteString as B +import Data.IORef import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..)) import Database.SQLite3.Bindings import Foreign.C.String import Foreign.Ptr import Foreign.StablePtr +import Foreign.Storable data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal) type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO () +type SQLiteFuncFinal = Ptr CContext -> IO () + mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals) {-# INLINE mkSQLiteFunc #-} @@ -25,6 +29,50 @@ createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do B.useAsCString name $ \namePtr -> toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr +mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc +mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do + -- we store the aggregate state in the buffer returned by + -- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that + -- contains the actual aggregate state + aggCtx <- getAggregateContext cxt + aggStPtr <- peek aggCtx + aggStRef <- + if castStablePtrToPtr aggStPtr /= nullPtr + then deRefStablePtr aggStPtr + else do + aggStRef <- newIORef initSt + aggStPtr' <- newStablePtr aggStRef + poke aggCtx aggStPtr' + return aggStRef + aggSt <- readIORef aggStRef + aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt + writeIORef aggStRef aggSt' + +mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal +mkSQLiteAggFinal initSt xFinal cxt = do + aggCtx <- getAggregateContext cxt + aggStPtr <- peek aggCtx + if castStablePtrToPtr aggStPtr == nullPtr + then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt + else do + catchAsResultError cxt $ do + aggStRef <- deRefStablePtr aggStPtr + aggSt <- readIORef aggStRef + xFinal (FuncContext cxt) aggSt + freeStablePtr aggStPtr + +getAggregateContext :: Ptr CContext -> IO (Ptr a) +getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize + where + stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr ()) + +-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL. +createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ()) +createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do + u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr + B.useAsCString name $ \namePtr -> + toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr + -- Convert a 'CError' to a 'Either Error', in the common case where -- SQLITE_OK signals success and anything else signals an error. -- diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 81e9820a2..ac2dc9a9d 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -778,10 +778,10 @@ temporaryClientError = \case _ -> False {-# INLINE temporaryClientError #-} +-- it is consistent with clientServiceError smpClientServiceError :: SMPClientError -> Bool smpClientServiceError = \case PCEServiceUnavailable -> True - PCETransportError (TEHandshake BAD_SERVICE) -> True -- TODO [certs rcv] this error may be temporary, so we should possibly resubscribe. PCEProtocolError SERVICE -> True PCEProtocolError (PROXY (BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen. _ -> False diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index a5f94960e..6b232f12b 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -143,6 +143,7 @@ module Simplex.Messaging.Protocol IdsHash (..), ServiceSub (..), ServiceSubResult (..), + ServiceSubError (..), serviceSubResult, queueIdsHash, queueIdHash, diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index e9f37b1ae..83a911452 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE MonadComprehensions #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -15,6 +16,7 @@ import qualified Data.Aeson as J import Data.Bifunctor (first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.ByteString.Internal (toForeignPtr, unsafeCreate) import qualified Data.ByteString.Lazy.Char8 as LB import Data.IORef import Data.Int (Int64) @@ -29,6 +31,9 @@ import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8With, encodeUtf8) import Data.Time (NominalDiffTime) import Data.Tuple (swap) +import Data.Word (Word8) +import Foreign.ForeignPtr (withForeignPtr) +import Foreign.Storable (peekByteOff, pokeByteOff) import GHC.Conc (labelThread, myThreadId, threadDelay) import UnliftIO hiding (atomicModifyIORef') import qualified UnliftIO.Exception as UE @@ -156,6 +161,27 @@ mapAccumLM_NonEmpty mapAccumLM_NonEmpty f s (x :| xs) = [(s2, x' :| xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs] +-- | Optimized from bytestring package for GHC 8.10.7 compatibility +packZipWith :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString -> ByteString +packZipWith f s1 s2 = + unsafeCreate len $ \r -> + withForeignPtr fp1 $ \p1 -> + withForeignPtr fp2 $ \p2 -> zipWith_ p1 p2 r + where + zipWith_ p1 p2 r = go 0 + where + go :: Int -> IO () + go !n + | n >= len = pure () + | otherwise = do + x <- peekByteOff p1 (off1 + n) + y <- peekByteOff p2 (off2 + n) + pokeByteOff r n (f x y) + go (n + 1) + (fp1, off1, l1) = toForeignPtr s1 + (fp2, off2, l2) = toForeignPtr s2 + len = min l1 l2 + tryWriteTBQueue :: TBQueue a -> a -> STM Bool tryWriteTBQueue q a = do full <- isFullTBQueue q diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 2a62deb45..31967917a 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -66,7 +66,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (isRight) import Data.Int (Int64) -import Data.List (find, isSuffixOf, nub) +import Data.List (find, isPrefixOf, isSuffixOf, nub) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M import Data.Maybe (isJust, isNothing) @@ -113,7 +113,7 @@ import Simplex.Messaging.Util (bshow, diffToMicroseconds) import Simplex.Messaging.Version (VersionRange (..)) import qualified Simplex.Messaging.Version as V import Simplex.Messaging.Version.Internal (Version (..)) -import System.Directory (copyFile, renameFile) +import System.Directory (copyFile, removeFile, renameFile) import Test.Hspec hiding (fit, it) import UnliftIO import Util @@ -124,12 +124,13 @@ import Fixtures #endif #if defined(dbServerPostgres) import qualified Database.PostgreSQL.Simple as PSQL -import Simplex.Messaging.Agent.Store (Connection' (..), StoredRcvQueue (..), SomeConn' (..)) -import Simplex.Messaging.Agent.Store.AgentStore (getConn) +import qualified Simplex.Messaging.Agent.Store.Postgres as Postgres +import qualified Simplex.Messaging.Agent.Store.Postgres.Common as Postgres import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) import Simplex.Messaging.Server.MsgStore.Postgres (PostgresQueue) import Simplex.Messaging.Server.MsgStore.Types (QSType (..)) import Simplex.Messaging.Server.QueueStore.Postgres +import Simplex.Messaging.Server.QueueStore.Postgres.Migrations import Simplex.Messaging.Server.QueueStore.Types (QueueStoreClass (..)) #endif @@ -478,6 +479,7 @@ functionalAPITests ps = do withSmpServer ps testTwoUsers describe "Client service certificates" $ do it "should connect, subscribe and reconnect as a service" $ testClientServiceConnection ps + it "should re-subscribe when service ID changed" $ testClientServiceIDChange ps describe "Connection switch" $ do describe "should switch delivery to the new queue" $ testServerMatrix2 ps testSwitchConnection @@ -3679,26 +3681,84 @@ testClientServiceConnection ps = do subscribeConnection user sId exchangeGreetingsMsgId 4 service uId user sId pure (conns, qIdHash) - withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + (uId', sId') <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do - [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))] <- M.toList <$> subscribeClientServices service 1 - ("", "", SERVICE_ALL _) <- nGet service - liftIO $ qIdHash' `shouldBe` qIdHash + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))) -> qIdHash' == qIdHash; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] subscribeConnection user sId exchangeGreetingsMsgId 6 service uId user sId ("", "", DOWN _ [_]) <- nGet user ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 qIdHash')) <- nGet service qIdHash' `shouldBe` qIdHash -- TODO [certs rcv] how to integrate service counts into stats - -- r <- nGet service -- TODO [certs rcv] some event when service disconnects with count - -- print r withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do ("", "", UP _ [_]) <- nGet user - ("", "", SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash''))) <- nGet service - ("", "", SERVICE_ALL _) <- nGet service - liftIO $ qIdHash'' `shouldBe` qIdHash - -- r <- nGet service -- TODO [certs rcv] some event when service reconnects with count + -- Nothing in ServiceSubResult confirms that both counts and IDs hash match + -- SERVICE_ALL may be deliverd before SERVICE_UP event in case there are no messages to deliver + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash'')))) -> qIdHash'' == qIdHash; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] exchangeGreetingsMsgId 8 service uId user sId + conns'@(uId', sId') <- makeConnection user service -- opposite direction + exchangeGreetings user sId' service uId' + pure conns' + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 2 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + -- TODO [certs rcv] test message delivery during subscription + subscribeAllConnections user False Nothing + ("", "", UP _ [_, _]) <- nGet user + exchangeGreetingsMsgId 4 user sId' service uId' + exchangeGreetingsMsgId 10 service uId user sId + +testClientServiceIDChange :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testClientServiceIDChange ps@(_, ASType qs _) = do + (sId, uId) <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + conns@(sId, uId) <- makeConnection service user + exchangeGreetings service uId user sId + pure conns + _ :: () <- case qs of + SQSPostgres -> do +#if defined(dbServerPostgres) + st <- either (error . show) pure =<< Postgres.createDBStore testStoreDBOpts serverMigrations (MigrationConfig MCError Nothing) + void $ Postgres.withTransaction st (`PSQL.execute_` "DELETE FROM services") +#else + pure () +#endif + SQSMemory -> do + s <- readFile testStoreLogFile + removeFile testStoreLogFile + writeFile testStoreLogFile $ unlines $ filter (not . ("NEW_SERVICE" `isPrefixOf`)) $ lines s + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult (Just (SMP.SSErrorQueueCount 1 0)) (SMP.ServiceSub _ 0 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False, + \case ("", "", AEvt SAENone (UP _ _)) -> True; _ -> False + ] + subscribeAllConnections user False Nothing + ("", "", UP _ [_]) <- nGet user + exchangeGreetingsMsgId 4 service uId user sId + -- disable service in the client + -- The test uses True for non-existing user to make sure it's removed for user 1, + -- because if no users use services, then it won't be checking them to optimize for most clients. + withAgentClientsServers2 (agentCfg, initAgentServers {useServices = M.fromList [(100, True)]}) (agentCfg, initAgentServers) $ \notService user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections notService False Nothing + ("", "", UP _ [_]) <- nGet notService + subscribeAllConnections user False Nothing + ("", "", UP _ [_]) <- nGet user + exchangeGreetingsMsgId 6 notService uId user sId getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient getSMPAgentClient' clientId cfg' initServers dbPath = do From f5eb735551cd36803845564fec3e75146370772f Mon Sep 17 00:00:00 2001 From: Evgeny Date: Sun, 14 Dec 2025 12:07:29 +0000 Subject: [PATCH 06/11] servers: service stats and logging, allow services without option (removed), report errors during service message delivery, remove threads when service subscription ended (#1676) * smp server: always allow services without option * smp server: maintain IDs hash in session subscription states * smp server: service message delivery error handling * ntf server: log subscription count and hash differences * smp server: remove delivery threads when service subscription ended/client disconnected --- src/Simplex/Messaging/Agent.hs | 1 - src/Simplex/Messaging/Notifications/Server.hs | 10 +- .../Messaging/Notifications/Server/Stats.hs | 1 - src/Simplex/Messaging/Protocol.hs | 11 +++ src/Simplex/Messaging/Server.hs | 95 ++++++++++--------- src/Simplex/Messaging/Server/Env/STM.hs | 14 +-- src/Simplex/Messaging/Server/Main.hs | 15 +-- .../Messaging/Server/MsgStore/Journal.hs | 4 +- .../Messaging/Server/MsgStore/Postgres.hs | 4 +- src/Simplex/Messaging/Server/MsgStore/STM.hs | 8 +- .../Messaging/Server/MsgStore/Types.hs | 2 +- src/Simplex/Messaging/Server/Prometheus.hs | 35 ++++++- .../Messaging/Server/QueueStore/Postgres.hs | 4 +- .../Messaging/Server/QueueStore/STM.hs | 4 +- src/Simplex/Messaging/Server/Stats.hs | 16 ++++ src/Simplex/Messaging/Transport.hs | 1 - tests/AgentTests/FunctionalAPITests.hs | 1 - 17 files changed, 147 insertions(+), 79 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f155ce77b..f44708fe6 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -1539,7 +1539,6 @@ resubscribeConnections' c connIds = do [] -> pure True rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs' --- TODO [certs rcv] compare hash. possibly, it should return both expected and returned counts subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSubResult)) subscribeClientServices' c userId = ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed" diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 67ed89d71..e7c1ca5f9 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -576,9 +576,10 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = -- TODO [certs rcv] resubscribe queues with statuses NSErr and NSService CAServiceDisconnected srv serviceSub -> logNote $ "SMP server service disconnected " <> showService srv serviceSub - CAServiceSubscribed srv serviceSub@(ServiceSub _ expected _) (ServiceSub _ n _) -- TODO [certs rcv] compare hash - | expected == n -> logNote msg - | otherwise -> logWarn $ msg <> ", confirmed subs: " <> tshow n + CAServiceSubscribed srv serviceSub@(ServiceSub _ n idsHash) (ServiceSub _ n' idsHash') + | n /= n' -> logWarn $ msg <> ", confirmed subs: " <> tshow n' + | idsHash /= idsHash' -> logWarn $ msg <> ", different IDs hash" + | otherwise -> logNote msg where msg = "SMP server service subscribed " <> showService srv serviceSub CAServiceSubError srv serviceSub e -> @@ -593,8 +594,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing) Left e -> logError $ "SMP server update and resubscription error " <> tshow e where - -- TODO [certs rcv] compare hash - showService srv (ServiceSub serviceId n _idsHash) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs" + showService srv (ServiceSub serviceId n _) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs" logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int -> IO () logSubErrors srv subs updated = forM_ (L.group $ L.sort $ L.map snd subs) $ \ss -> do diff --git a/src/Simplex/Messaging/Notifications/Server/Stats.hs b/src/Simplex/Messaging/Notifications/Server/Stats.hs index 7125ce290..a20e41c34 100644 --- a/src/Simplex/Messaging/Notifications/Server/Stats.hs +++ b/src/Simplex/Messaging/Notifications/Server/Stats.hs @@ -17,7 +17,6 @@ import Simplex.Messaging.Server.Stats import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM --- TODO [certs rcv] track service subscriptions and count/hash diffs for own and other servers + prometheus data NtfServerStats = NtfServerStats { fromTime :: IORef UTCTime, tknCreated :: IORef Int, diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 6b232f12b..51128597c 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -147,6 +147,9 @@ module Simplex.Messaging.Protocol serviceSubResult, queueIdsHash, queueIdHash, + noIdsHash, + addServiceSubs, + subtractServiceSubs, MaxMessageLen, MaxRcvMessageLen, EncRcvMsgBody (..), @@ -1526,6 +1529,14 @@ queueIdHash :: QueueId -> IdsHash queueIdHash = IdsHash . C.md5Hash . unEntityId {-# INLINE queueIdHash #-} +addServiceSubs :: (Int64, IdsHash) -> (Int64, IdsHash) -> (Int64, IdsHash) +addServiceSubs (n', idsHash') (n, idsHash) = (n + n', idsHash <> idsHash') + +subtractServiceSubs :: (Int64, IdsHash) -> (Int64, IdsHash) -> (Int64, IdsHash) +subtractServiceSubs (n', idsHash') (n, idsHash) + | n > n' = (n - n', idsHash <> idsHash') -- concat is a reversible xor: (x `xor` y) `xor` y == x + | otherwise = (0, noIdsHash) + data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock -- | Type for protocol errors. diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 0fc15b3e3..b7bb0efaa 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -166,8 +166,8 @@ type AttachHTTP = Socket -> TLS.Context -> IO () -- actions used in serverThread to reduce STM transaction scope data ClientSubAction = CSAEndSub QueueId -- end single direct queue subscription - | CSAEndServiceSub -- end service subscription to one queue - | CSADecreaseSubs Int64 -- reduce service subscriptions when cancelling. Fixed number is used to correctly handle race conditions when service resubscribes + | CSAEndServiceSub QueueId -- end service subscription to one queue + | CSADecreaseSubs (Int64, IdsHash) -- reduce service subscriptions when cancelling. Fixed number is used to correctly handle race conditions when service resubscribes type PrevClientSub s = (Client s, ClientSubAction, (EntityId, BrokerMsg)) @@ -251,7 +251,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt Server s -> (Server s -> ServerSubscribers s) -> (Client s -> TMap QueueId sub) -> - (Client s -> TVar Int64) -> + (Client s -> TVar (Int64, IdsHash)) -> Maybe (sub -> IO ()) -> M s () serverThread label srv srvSubscribers clientSubs clientServiceSubs unsub_ = do @@ -277,7 +277,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt as'' <- if prevServiceId == serviceId_ then pure [] else endServiceSub prevServiceId qId END case serviceId_ of Just serviceId -> do - modifyTVar' totalServiceSubs (+ 1) -- server count for all services + modifyTVar' totalServiceSubs $ addServiceSubs (1, queueIdHash qId) -- server count and IDs hash for all services as <- endQueueSub qId END as' <- cancelServiceSubs serviceId =<< upsertSubscribedClient serviceId c serviceSubscribers pure $ as ++ as' ++ as'' @@ -289,9 +289,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt as <- endQueueSub qId DELD as' <- endServiceSub serviceId qId DELD pure $ as ++ as' - CSService serviceId count -> do + CSService serviceId changedSubs -> do modifyTVar' subClients $ IS.insert clntId -- add ID to server's subscribed cients - modifyTVar' totalServiceSubs (+ count) -- server count for all services + modifyTVar' totalServiceSubs $ subtractServiceSubs changedSubs -- server count and IDs hash for all services cancelServiceSubs serviceId =<< upsertSubscribedClient serviceId c serviceSubscribers updateSubDisconnected = case clntSub of -- do not insert client if it is already disconnected, but send END/DELD to any other client subscribed to this queue or service @@ -309,15 +309,15 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt endQueueSub qId msg = prevSub qId msg (CSAEndSub qId) =<< lookupDeleteSubscribedClient qId queueSubscribers endServiceSub :: Maybe ServiceId -> QueueId -> BrokerMsg -> STM [PrevClientSub s] endServiceSub Nothing _ _ = pure [] - endServiceSub (Just serviceId) qId msg = prevSub qId msg CSAEndServiceSub =<< lookupSubscribedClient serviceId serviceSubscribers + endServiceSub (Just serviceId) qId msg = prevSub qId msg (CSAEndServiceSub qId) =<< lookupSubscribedClient serviceId serviceSubscribers prevSub :: QueueId -> BrokerMsg -> ClientSubAction -> Maybe (Client s) -> STM [PrevClientSub s] prevSub qId msg action = checkAnotherClient $ \c -> pure [(c, action, (qId, msg))] cancelServiceSubs :: ServiceId -> Maybe (Client s) -> STM [PrevClientSub s] cancelServiceSubs serviceId = checkAnotherClient $ \c -> do - n <- swapTVar (clientServiceSubs c) 0 - pure [(c, CSADecreaseSubs n, (serviceId, ENDS n))] + changedSubs@(n, _) <- swapTVar (clientServiceSubs c) (0, noIdsHash) + pure [(c, CSADecreaseSubs changedSubs, (serviceId, ENDS n))] checkAnotherClient :: (Client s -> STM [PrevClientSub s]) -> Maybe (Client s) -> STM [PrevClientSub s] checkAnotherClient mkSub = \case Just c@Client {clientId, connected} | clntId /= clientId -> @@ -332,20 +332,21 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt where a (Just unsub) (Just s) = unsub s a _ _ = pure () - CSAEndServiceSub -> atomically $ do + CSAEndServiceSub qId -> atomically $ do modifyTVar' (clientServiceSubs c) decrease modifyTVar' totalServiceSubs decrease where - decrease n = max 0 (n - 1) - -- TODO [certs rcv] for SMP subscriptions CSADecreaseSubs should also remove all delivery threads of the passed client - CSADecreaseSubs n' -> atomically $ modifyTVar' totalServiceSubs $ \n -> max 0 (n - n') + decrease = subtractServiceSubs (1, queueIdHash qId) + CSADecreaseSubs changedSubs -> do + atomically $ modifyTVar' totalServiceSubs $ subtractServiceSubs changedSubs + forM_ unsub_ $ \unsub -> atomically (swapTVar (clientSubs c) M.empty) >>= mapM_ unsub where endSub :: Client s -> QueueId -> STM (Maybe sub) endSub c qId = TM.lookupDelete qId (clientSubs c) >>= (removeWhenNoSubs c $>) -- remove client from server's subscribed cients removeWhenNoSubs c = do noClientSubs <- null <$> readTVar (clientSubs c) - noServiceSubs <- (0 ==) <$> readTVar (clientServiceSubs c) + noServiceSubs <- ((0 ==) . fst) <$> readTVar (clientServiceSubs c) when (noClientSubs && noServiceSubs) $ modifyTVar' subClients $ IS.delete (clientId c) deliverNtfsThread :: Server s -> M s () @@ -1112,10 +1113,10 @@ clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, serviceS updateSubscribers subs ServerSubscribers {queueSubscribers, subClients} = do mapM_ (\qId -> deleteSubcribedClient qId c queueSubscribers) (M.keys subs) atomically $ modifyTVar' subClients $ IS.delete clientId - updateServiceSubs :: ServiceId -> TVar Int64 -> ServerSubscribers s -> IO () + updateServiceSubs :: ServiceId -> TVar (Int64, IdsHash) -> ServerSubscribers s -> IO () updateServiceSubs serviceId subsCount ServerSubscribers {totalServiceSubs, serviceSubscribers} = do deleteSubcribedClient serviceId c serviceSubscribers - atomically . modifyTVar' totalServiceSubs . subtract =<< readTVarIO subsCount + atomically . modifyTVar' totalServiceSubs . subtractServiceSubs =<< readTVarIO subsCount cancelSub :: Sub -> IO () cancelSub s = case subThread s of @@ -1357,7 +1358,6 @@ forkClient Client {endThreads, endThreadSeq} label action = do client :: forall s. MsgStoreClass s => Server s -> s -> Client s -> M s () client - -- TODO [certs rcv] rcv subscriptions Server {subscribers, ntfSubscribers} ms clnt@Client {clientId, rcvQ, sndQ, msgQ, clientTHParams = thParams'@THandleParams {sessionId}, procThreads} = do @@ -1661,7 +1661,7 @@ client subscribeNewQueue :: RecipientId -> QueueRec -> M s () subscribeNewQueue rId QueueRec {rcvServiceId} = do case rcvServiceId of - Just _ -> atomically $ modifyTVar' (serviceSubsCount clnt) (+ 1) + Just _ -> atomically $ modifyTVar' (serviceSubsCount clnt) $ addServiceSubs (1, queueIdHash rId) Nothing -> do sub <- atomically $ newSubscription NoSub atomically $ TM.insert rId sub $ subscriptions clnt @@ -1741,7 +1741,7 @@ client Maybe ServiceId -> ServerSubscribers s -> (Client s -> TMap QueueId sub) -> - (Client s -> TVar Int64) -> + (Client s -> TVar (Int64, IdsHash)) -> STM sub -> (ServerStats -> ServiceStats) -> M s (Either ErrorType (Bool, Maybe sub)) @@ -1771,9 +1771,9 @@ client incSrvStat $ maybe srvAssocNew (const srvAssocUpdated) queueServiceId pure (hasSub, Nothing) where - hasServiceSub = (0 /=) <$> readTVar (clientServiceSubs clnt) + hasServiceSub = ((0 /=) . fst) <$> readTVar (clientServiceSubs clnt) -- This function is used when queue association with the service is created. - incServiceQueueSubs = modifyTVar' (clientServiceSubs clnt) (+ 1) -- service count + incServiceQueueSubs = modifyTVar' (clientServiceSubs clnt) $ addServiceSubs (1, queueIdHash (recipientId q)) -- service count and IDs hash Nothing -> case queueServiceId of Just _ -> runExceptT $ do ExceptT $ setQueueService (queueStore ms) q party Nothing @@ -1801,27 +1801,36 @@ client sharedSubscribeService SRecipientService serviceId expected subscribers serviceSubscribed serviceSubsCount rcvServices >>= \case Left e -> pure $ ERR e Right (hasSub, (count, idsHash)) -> do - unless hasSub $ forkClient clnt "deliverServiceMessages" $ liftIO $ deliverServiceMessages count + stats <- asks serverStats + unless hasSub $ forkClient clnt "deliverServiceMessages" $ liftIO $ deliverServiceMessages stats count pure $ SOKS count idsHash where - deliverServiceMessages expectedCnt = do - (qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0) - atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, ALLS)] - -- TODO [certs rcv] compare with expected - logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)" - deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int) - deliverQueueMsg (!qCnt, !msgCnt, !dupCnt, !errCnt) rId = \case - Left e -> pure (qCnt + 1, msgCnt, dupCnt, errCnt + 1) -- TODO [certs rcv] deliver subscription error + deliverServiceMessages stats expectedCnt = do + foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, [(NoCorrId, NoEntity, ALLS)]) >>= \case + Right (qCnt, msgCnt, dupCnt, evts) -> do + atomically $ writeTBQueue msgQ evts + atomicModifyIORef'_ (rcvServicesSubMsg stats) (+ msgCnt) + atomicModifyIORef'_ (rcvServicesSubDuplicate stats) (+ dupCnt) + let logMsg = "Subscribed service " <> tshow serviceId <> " (" + if qCnt == expectedCnt + then logNote $ logMsg <> tshow qCnt <> " queues)" + else logError $ logMsg <> "expected " <> tshow expectedCnt <> "," <> tshow qCnt <> " queues)" + Left e -> do + logError $ "Service subscription error for " <> tshow serviceId <> ": " <> tshow e + atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, ERR e)] + deliverQueueMsg :: (Int64, Int, Int, NonEmpty (Transmission BrokerMsg)) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int64, Int, Int, NonEmpty (Transmission BrokerMsg)) + deliverQueueMsg (!qCnt, !msgCnt, !dupCnt, evts) rId = \case + Left e -> pure (qCnt + 1, msgCnt, dupCnt, (NoCorrId, rId, ERR e) <| evts) Right qMsg_ -> case qMsg_ of - Nothing -> pure (qCnt + 1, msgCnt, dupCnt, errCnt) + Nothing -> pure (qCnt + 1, msgCnt, dupCnt, evts) Just (qr, msg) -> atomically (getSubscription rId) >>= \case - Nothing -> pure (qCnt + 1, msgCnt, dupCnt + 1, errCnt) + Nothing -> pure (qCnt + 1, msgCnt, dupCnt + 1, evts) Just sub -> do ts <- getSystemSeconds atomically $ setDelivered sub msg ts atomically $ writeTBQueue msgQ [(NoCorrId, rId, MSG (encryptMsg qr msg))] - pure (qCnt + 1, msgCnt + 1, dupCnt, errCnt) + pure (qCnt + 1, msgCnt + 1, dupCnt, evts) getSubscription rId = TM.lookup rId (subscriptions clnt) >>= \case -- If delivery subscription already exists, then there is no need to deliver message. @@ -1836,28 +1845,28 @@ client subscribeServiceNotifications serviceId expected = either ERR (uncurry SOKS . snd) <$> sharedSubscribeService SNotifierService serviceId expected ntfSubscribers ntfServiceSubscribed ntfServiceSubsCount ntfServices - sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> (Int64, IdsHash) -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar Int64) -> (ServerStats -> ServiceStats) -> M s (Either ErrorType (Bool, (Int64, IdsHash))) + sharedSubscribeService :: (PartyI p, ServiceParty p) => SParty p -> ServiceId -> (Int64, IdsHash) -> ServerSubscribers s -> (Client s -> TVar Bool) -> (Client s -> TVar (Int64, IdsHash)) -> (ServerStats -> ServiceStats) -> M s (Either ErrorType (Bool, (Int64, IdsHash))) sharedSubscribeService party serviceId (count, idsHash) srvSubscribers clientServiceSubscribed clientServiceSubs servicesSel = do subscribed <- readTVarIO $ clientServiceSubscribed clnt stats <- asks serverStats liftIO $ runExceptT $ (subscribed,) <$> if subscribed - then (,mempty) <$> readTVarIO (clientServiceSubs clnt) -- TODO [certs rcv] get IDs hash + then readTVarIO $ clientServiceSubs clnt else do - (count', idsHash') <- ExceptT $ getServiceQueueCountHash @(StoreQueue s) (queueStore ms) party serviceId - incCount <- atomically $ do + subs'@(count', idsHash') <- ExceptT $ getServiceQueueCountHash @(StoreQueue s) (queueStore ms) party serviceId + subsChange <- atomically $ do writeTVar (clientServiceSubscribed clnt) True - currCount <- swapTVar (clientServiceSubs clnt) count' -- TODO [certs rcv] maintain IDs hash here? - pure $ count' - currCount + currSubs <- swapTVar (clientServiceSubs clnt) subs' + pure $ subtractServiceSubs currSubs subs' let incSrvStat sel n = liftIO $ atomicModifyIORef'_ (sel $ servicesSel stats) (+ n) diff = fromIntegral $ count' - count - if -- TODO [certs rcv] account for not provided counts/hashes (expected n = -1) - | diff == 0 && idsHash == idsHash' -> incSrvStat srvSubOk 1 + if -- `count == -1` only for subscriptions by old NTF servers + | count == -1 && (diff == 0 && idsHash == idsHash') -> incSrvStat srvSubOk 1 | diff > 0 -> incSrvStat srvSubMore 1 >> incSrvStat srvSubMoreTotal diff | diff < 0 -> incSrvStat srvSubFewer 1 >> incSrvStat srvSubFewerTotal (- diff) | otherwise -> incSrvStat srvSubDiff 1 - atomically $ writeTQueue (subQ srvSubscribers) (CSService serviceId incCount, clientId) + atomically $ writeTQueue (subQ srvSubscribers) (CSService serviceId subsChange, clientId) pure (count', idsHash') acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) @@ -2133,7 +2142,7 @@ client -- we delete subscription here, so the client with no subscriptions can be disconnected. sub <- atomically $ TM.lookupDelete entId $ subscriptions clnt liftIO $ mapM_ cancelSub sub - when (isJust rcvServiceId) $ atomically $ modifyTVar' (serviceSubsCount clnt) $ \n -> max 0 (n - 1) + when (isJust rcvServiceId) $ atomically $ modifyTVar' (serviceSubsCount clnt) $ subtractServiceSubs (1, queueIdHash (recipientId q)) atomically $ writeTQueue (subQ subscribers) (CSDeleted entId rcvServiceId, clientId) forM_ (notifier qr) $ \NtfCreds {notifierId = nId, ntfServiceId} -> do -- queue is deleted by a different client from the one subscribed to notifications, diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 24cd6dfcc..02cf136c7 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -363,7 +363,7 @@ data ServerSubscribers s = ServerSubscribers { subQ :: TQueue (ClientSub, ClientId), queueSubscribers :: SubscribedClients s, serviceSubscribers :: SubscribedClients s, -- service clients with long-term certificates that have subscriptions - totalServiceSubs :: TVar Int64, + totalServiceSubs :: TVar (Int64, IdsHash), subClients :: TVar IntSet, -- clients with individual or service subscriptions pendingEvents :: TVar (IntMap (NonEmpty (EntityId, BrokerMsg))) } @@ -426,7 +426,7 @@ sameClient c cv = maybe False (sameClientId c) <$> readTVar cv data ClientSub = CSClient QueueId (Maybe ServiceId) (Maybe ServiceId) -- includes previous and new associated service IDs | CSDeleted QueueId (Maybe ServiceId) -- includes previously associated service IDs - | CSService ServiceId Int64 -- only send END to idividual client subs on message delivery, not of SSUB/NSSUB + | CSService ServiceId (Int64, IdsHash) -- only send END to idividual client subs on message delivery, not of SSUB/NSSUB newtype ProxyAgent = ProxyAgent { smpAgent :: SMPClientAgent 'Sender @@ -440,8 +440,8 @@ data Client s = Client ntfSubscriptions :: TMap NotifierId (), serviceSubscribed :: TVar Bool, -- set independently of serviceSubsCount, to track whether service subscription command was received ntfServiceSubscribed :: TVar Bool, - serviceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count - ntfServiceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count + serviceSubsCount :: TVar (Int64, IdsHash), -- only one service can be subscribed, based on its certificate, this is subscription count + ntfServiceSubsCount :: TVar (Int64, IdsHash), -- only one service can be subscribed, based on its certificate, this is subscription count rcvQ :: TBQueue (NonEmpty (VerifiedTransmission s)), sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg), [Transmission BrokerMsg]), msgQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), @@ -502,7 +502,7 @@ newServerSubscribers = do subQ <- newTQueueIO queueSubscribers <- SubscribedClients <$> TM.emptyIO serviceSubscribers <- SubscribedClients <$> TM.emptyIO - totalServiceSubs <- newTVarIO 0 + totalServiceSubs <- newTVarIO (0, noIdsHash) subClients <- newTVarIO IS.empty pendingEvents <- newTVarIO IM.empty pure ServerSubscribers {subQ, queueSubscribers, serviceSubscribers, totalServiceSubs, subClients, pendingEvents} @@ -513,8 +513,8 @@ newClient clientId qSize clientTHParams createdAt = do ntfSubscriptions <- TM.emptyIO serviceSubscribed <- newTVarIO False ntfServiceSubscribed <- newTVarIO False - serviceSubsCount <- newTVarIO 0 - ntfServiceSubsCount <- newTVarIO 0 + serviceSubsCount <- newTVarIO (0, noIdsHash) + ntfServiceSubsCount <- newTVarIO (0, noIdsHash) rcvQ <- newTBQueueIO qSize sndQ <- newTBQueueIO qSize msgQ <- newTBQueueIO qSize diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 7de966c36..86ff3d4a9 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -18,7 +18,7 @@ module Simplex.Messaging.Server.Main where import Control.Concurrent.STM -import Control.Exception (SomeException, finally, try) +import Control.Exception (finally) import Control.Logger.Simple import Control.Monad import qualified Data.Attoparsec.ByteString.Char8 as A @@ -28,10 +28,8 @@ import Data.Char (isAlpha, isAscii, toUpper) import Data.Either (fromRight) import Data.Functor (($>)) import Data.Ini (Ini, lookupValue, readIniFile) -import Data.Int (Int64) import Data.List (find, isPrefixOf) import qualified Data.List.NonEmpty as L -import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust, isNothing) import Data.Text (Text) import qualified Data.Text as T @@ -61,14 +59,17 @@ import Simplex.Messaging.Transport (supportedProxyClientSMPRelayVRange, alpnSupp import Simplex.Messaging.Transport.Client (TransportHost (..), defaultSocksProxy) import Simplex.Messaging.Transport.HTTP2 (httpALPN) import Simplex.Messaging.Transport.Server (ServerCredentials (..), mkTransportServerConfig) -import Simplex.Messaging.Util (eitherToMaybe, ifM, unlessM) +import Simplex.Messaging.Util (eitherToMaybe, ifM) import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist) import System.Exit (exitFailure) import System.FilePath (combine) -import System.IO (BufferMode (..), IOMode (..), hSetBuffering, stderr, stdout, withFile) +import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) import Text.Read (readMaybe) #if defined(dbServerPostgres) +import Control.Exception (SomeException, try) +import Data.Int (Int64) +import qualified Data.Map.Strict as M import Data.Semigroup (Sum (..)) import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) @@ -79,7 +80,9 @@ import Simplex.Messaging.Server.QueueStore.Postgres (batchInsertQueues, batchIns import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..)) import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog (closeStoreLog, logNewService, logCreateQueue, openWriteStoreLog) +import Simplex.Messaging.Util (unlessM) import System.Directory (renameFile) +import System.IO (IOMode (..), withFile) #endif smpServerCLI :: FilePath -> FilePath -> IO () @@ -556,7 +559,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = mkTransportServerConfig (fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini) (Just $ alpnSupportedSMPHandshakes <> httpALPN) - (fromMaybe True $ iniOnOff "TRANSPORT" "accept_service_credentials" ini), -- TODO [certs rcv] remove this option + True, controlPort = eitherToMaybe $ T.unpack <$> lookupValue "TRANSPORT" "control_port" ini, smpAgentCfg = defaultSMPClientAgentConfig diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index 89e9f0383..c65660c93 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -444,9 +444,9 @@ instance MsgStoreClass (JournalMsgStore s) where getLoadedQueue :: JournalQueue s -> IO (JournalQueue s) getLoadedQueue q = fromMaybe q <$> TM.lookupIO (recipientId q) (loadedQueues $ queueStore_ ms) - foldRcvServiceMessages :: JournalMsgStore s -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a + foldRcvServiceMessages :: JournalMsgStore s -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO (Either ErrorType a) foldRcvServiceMessages ms serviceId f acc = case queueStore_ ms of - MQStore st -> foldRcvServiceQueues st serviceId f' acc + MQStore st -> fmap Right $ foldRcvServiceQueues st serviceId f' acc where f' a (q, qr) = runExceptT (tryPeekMsg ms q) >>= f a (recipientId q) . ((qr,) <$$>) #if defined(dbServerPostgres) diff --git a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs index f3000811b..edf7f481c 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs @@ -119,9 +119,9 @@ instance MsgStoreClass PostgresMsgStore where toMessageStats (expiredMsgsCount, storedMsgsCount, storedQueues) = MessageStats {expiredMsgsCount, storedMsgsCount, storedQueues} - foldRcvServiceMessages :: PostgresMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a + foldRcvServiceMessages :: PostgresMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO (Either ErrorType a) foldRcvServiceMessages ms serviceId f acc = - withTransaction (dbStore $ queueStore_ ms) $ \db -> + runExceptT $ withDB' "foldRcvServiceMessages" (queueStore_ ms) $ \db -> DB.fold db [sql| diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 24d489acc..f118e007c 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -87,10 +87,10 @@ instance MsgStoreClass STMMsgStore where expireOldMessages _tty ms now ttl = withLoadedQueues (queueStore_ ms) $ atomically . expireQueueMsgs ms now (now - ttl) - foldRcvServiceMessages :: STMMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a - foldRcvServiceMessages ms serviceId f= - foldRcvServiceQueues (queueStore_ ms) serviceId $ \a (q, qr) -> - runExceptT (tryPeekMsg ms q) >>= f a (recipientId q) . ((qr,) <$$>) + foldRcvServiceMessages :: STMMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO (Either ErrorType a) + foldRcvServiceMessages ms serviceId f = fmap Right . foldRcvServiceQueues (queueStore_ ms) serviceId f' + where + f' a (q, qr) = runExceptT (tryPeekMsg ms q) >>= f a (recipientId q) . ((qr,) <$$>) logQueueStates _ = pure () {-# INLINE logQueueStates #-} diff --git a/src/Simplex/Messaging/Server/MsgStore/Types.hs b/src/Simplex/Messaging/Server/MsgStore/Types.hs index e186da05a..fc97bbc20 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Types.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Types.hs @@ -45,7 +45,7 @@ class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => M unsafeWithAllMsgQueues :: Monoid a => Bool -> s -> (StoreQueue s -> IO a) -> IO a -- tty, store, now, ttl expireOldMessages :: Bool -> s -> Int64 -> Int64 -> IO MessageStats - foldRcvServiceMessages :: s -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO a + foldRcvServiceMessages :: s -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO (Either ErrorType a) logQueueStates :: s -> IO () logQueueState :: StoreQueue s -> StoreMonad s () queueStore :: s -> QueueStore s diff --git a/src/Simplex/Messaging/Server/Prometheus.hs b/src/Simplex/Messaging/Server/Prometheus.hs index e4d6a2774..1e3c5132d 100644 --- a/src/Simplex/Messaging/Server/Prometheus.hs +++ b/src/Simplex/Messaging/Server/Prometheus.hs @@ -21,7 +21,6 @@ import Simplex.Messaging.Transport (simplexMQVersion) import Simplex.Messaging.Transport.Server (SocketStats (..)) import Simplex.Messaging.Util (tshow) --- TODO [certs rcv] add service subscriptions and count/hash diffs data ServerMetrics = ServerMetrics { statsData :: ServerStatsData, activeQueueCounts :: PeriodStatCounts, @@ -118,6 +117,8 @@ prometheusMetrics sm rtm ts = _pMsgFwdsRecv, _rcvServices, _ntfServices, + _rcvServicesSubMsg, + _rcvServicesSubDuplicate, _qCount, _msgCount, _ntfCount @@ -383,6 +384,14 @@ prometheusMetrics sm rtm ts = \# HELP simplex_smp_ntf_services_queues_count The count of queues associated with notification services.\n\ \# TYPE simplex_smp_ntf_services_queues_count gauge\n\ \simplex_smp_ntf_services_queues_count " <> mshow (ntfServiceQueuesCount entityCounts) <> "\n# ntfServiceQueuesCount\n\ + \\n\ + \# HELP simplex_smp_rcv_services_sub_msg The count of subscribed service queues with messages.\n\ + \# TYPE simplex_smp_rcv_services_sub_msg counter\n\ + \simplex_smp_rcv_services_sub_msg " <> mshow _rcvServicesSubMsg <> "\n# rcvServicesSubMsg\n\ + \\n\ + \# HELP simplex_smp_rcv_services_sub_duplicate The count of duplicate subscribed service queues.\n\ + \# TYPE simplex_smp_rcv_services_sub_duplicate counter\n\ + \simplex_smp_rcv_services_sub_duplicate " <> mshow _rcvServicesSubDuplicate <> "\n# rcvServicesSubDuplicate\n\ \\n" <> showServices _rcvServices "rcv" "receiving" <> showServices _ntfServices "ntf" "notification" @@ -418,6 +427,30 @@ prometheusMetrics sm rtm ts = \# HELP simplex_smp_" <> pfx <> "_services_sub_end Ended subscriptions with " <> name <> " services.\n\ \# TYPE simplex_smp_" <> pfx <> "_services_sub_end gauge\n\ \simplex_smp_" <> pfx <> "_services_sub_end " <> mshow (_srvSubEnd ss) <> "\n# " <> pfx <> ".srvSubEnd\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_ok Service subscriptions for " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_ok gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_ok " <> mshow (_srvSubOk ss) <> "\n# " <> pfx <> ".srvSubOk\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_more Service subscriptions for " <> name <> " services with more queues than in the client.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_more gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_more " <> mshow (_srvSubMore ss) <> "\n# " <> pfx <> ".srvSubMore\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_fewer Service subscriptions for " <> name <> " services with fewer queues than in the client.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_fewer gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_fewer " <> mshow (_srvSubFewer ss) <> "\n# " <> pfx <> ".srvSubFewer\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_diff Service subscriptions for " <> name <> " services with different hash than in the client.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_diff gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_diff " <> mshow (_srvSubDiff ss) <> "\n# " <> pfx <> ".srvSubDiff\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_more_total Service subscriptions for " <> name <> " services with more queues than in the client total.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_more_total gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_more_total " <> mshow (_srvSubMoreTotal ss) <> "\n# " <> pfx <> ".srvSubMoreTotal\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_fewer_total Service subscriptions for " <> name <> " services with fewer queues than in the client total.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_fewer_total gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_fewer_total " <> mshow (_srvSubFewerTotal ss) <> "\n# " <> pfx <> ".srvSubFewerTotal\n\ \\n" info = "# Info\n\ diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index eb1ba3b2c..a8c8c040a 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -581,9 +581,9 @@ foldServiceRecs st f = DB.fold_ db "SELECT service_id, service_role, service_cert, service_cert_hash, created_at FROM services" mempty $ \ !acc -> fmap (acc <>) . f . rowToServiceRec -foldRcvServiceQueueRecs :: PostgresQueueStore q -> ServiceId -> (a -> (RecipientId, QueueRec) -> IO a) -> a -> IO a +foldRcvServiceQueueRecs :: PostgresQueueStore q -> ServiceId -> (a -> (RecipientId, QueueRec) -> IO a) -> a -> IO (Either ErrorType a) foldRcvServiceQueueRecs st serviceId f acc = - withTransaction (dbStore st) $ \db -> + runExceptT $ withDB' "foldRcvServiceQueueRecs" st $ \db -> DB.fold db (queueRecQuery <> " WHERE rcv_service_id = ? AND deleted_at IS NULL") (Only serviceId) acc $ \a -> f a . rowToQueueRec foldQueueRecs :: Monoid a => Bool -> Bool -> PostgresQueueStore q -> ((RecipientId, QueueRec) -> IO a) -> IO a diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 8b64db55a..3a236076c 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -63,8 +63,8 @@ data STMQueueStore q = STMQueueStore data STMService = STMService { serviceRec :: ServiceRec, - serviceRcvQueues :: TVar (Set RecipientId, IdsHash), -- TODO [certs rcv] get/maintain hash - serviceNtfQueues :: TVar (Set NotifierId, IdsHash) -- TODO [certs rcv] get/maintain hash + serviceRcvQueues :: TVar (Set RecipientId, IdsHash), + serviceNtfQueues :: TVar (Set NotifierId, IdsHash) } setStoreLog :: STMQueueStore q -> StoreLog 'WriteMode -> IO () diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index 120fad7b6..613c5e8be 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -86,6 +86,8 @@ data ServerStats = ServerStats pMsgFwdsRecv :: IORef Int, rcvServices :: ServiceStats, ntfServices :: ServiceStats, + rcvServicesSubMsg :: IORef Int, + rcvServicesSubDuplicate :: IORef Int, qCount :: IORef Int, msgCount :: IORef Int, ntfCount :: IORef Int @@ -145,6 +147,8 @@ data ServerStatsData = ServerStatsData _pMsgFwdsRecv :: Int, _ntfServices :: ServiceStatsData, _rcvServices :: ServiceStatsData, + _rcvServicesSubMsg :: Int, + _rcvServicesSubDuplicate :: Int, _qCount :: Int, _msgCount :: Int, _ntfCount :: Int @@ -206,6 +210,8 @@ newServerStats ts = do pMsgFwdsRecv <- newIORef 0 rcvServices <- newServiceStats ntfServices <- newServiceStats + rcvServicesSubMsg <- newIORef 0 + rcvServicesSubDuplicate <- newIORef 0 qCount <- newIORef 0 msgCount <- newIORef 0 ntfCount <- newIORef 0 @@ -264,6 +270,8 @@ newServerStats ts = do pMsgFwdsRecv, rcvServices, ntfServices, + rcvServicesSubMsg, + rcvServicesSubDuplicate, qCount, msgCount, ntfCount @@ -324,6 +332,8 @@ getServerStatsData s = do _pMsgFwdsRecv <- readIORef $ pMsgFwdsRecv s _rcvServices <- getServiceStatsData $ rcvServices s _ntfServices <- getServiceStatsData $ ntfServices s + _rcvServicesSubMsg <- readIORef $ rcvServicesSubMsg s + _rcvServicesSubDuplicate <- readIORef $ rcvServicesSubDuplicate s _qCount <- readIORef $ qCount s _msgCount <- readIORef $ msgCount s _ntfCount <- readIORef $ ntfCount s @@ -382,6 +392,8 @@ getServerStatsData s = do _pMsgFwdsRecv, _rcvServices, _ntfServices, + _rcvServicesSubMsg, + _rcvServicesSubDuplicate, _qCount, _msgCount, _ntfCount @@ -443,6 +455,8 @@ setServerStats s d = do writeIORef (pMsgFwdsRecv s) $! _pMsgFwdsRecv d setServiceStats (rcvServices s) $! _rcvServices d setServiceStats (ntfServices s) $! _ntfServices d + writeIORef (rcvServicesSubMsg s) $! _rcvServicesSubMsg d + writeIORef (rcvServicesSubDuplicate s) $! _rcvServicesSubDuplicate d writeIORef (qCount s) $! _qCount d writeIORef (msgCount s) $! _msgCount d writeIORef (ntfCount s) $! _ntfCount d @@ -636,6 +650,8 @@ instance StrEncoding ServerStatsData where _pMsgFwdsRecv, _rcvServices, _ntfServices, + _rcvServicesSubMsg = 0, + _rcvServicesSubDuplicate = 0, _qCount, _msgCount = 0, _ntfCount = 0 diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index a14118ce4..f1eb1a8bd 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -560,7 +560,6 @@ data SMPClientHandshake = SMPClientHandshake keyHash :: C.KeyHash, -- | pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. authPubKey :: Maybe C.PublicKeyX25519, - -- TODO [certs rcv] remove proxyServer, as serviceInfo includes it as clientRole -- | Whether connecting client is a proxy server (send from SMP v12). -- This property, if True, disables additional transport encrytion inside TLS. -- (Proxy server connection already has additional encryption, so this layer is not needed there). diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 31967917a..b63e4cb48 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -3693,7 +3693,6 @@ testClientServiceConnection ps = do ("", "", DOWN _ [_]) <- nGet user ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 qIdHash')) <- nGet service qIdHash' `shouldBe` qIdHash - -- TODO [certs rcv] how to integrate service counts into stats withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do ("", "", UP _ [_]) <- nGet user -- Nothing in ServiceSubResult confirms that both counts and IDs hash match From a1277bf6bfb30015ef00bb0de58664ee00efe114 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 19 Dec 2025 21:10:12 +0000 Subject: [PATCH 07/11] agent: remove service queue association when service ID changed, process ENDS event, test migrating to/from service (#1677) * agent: remove service queue association when service ID changed * agent: process ENDS event * agent: send service subscription error event * agent: test migrating to/from service subscriptions, fixes * agent: always remove service when disabled, fix service subscriptions --- src/Simplex/Messaging/Agent.hs | 86 +++++---- src/Simplex/Messaging/Agent/Client.hs | 43 +++-- src/Simplex/Messaging/Agent/Protocol.hs | 3 + .../Messaging/Agent/Store/AgentStore.hs | 48 +++-- src/Simplex/Messaging/Agent/TSessionSubs.hs | 59 +++--- src/Simplex/Messaging/Protocol.hs | 32 ++-- src/Simplex/Messaging/Server.hs | 4 +- src/Simplex/Messaging/Server/Env/STM.hs | 6 +- tests/AgentTests/FunctionalAPITests.hs | 172 +++++++++++++++++- tests/CoreTests/TSessionSubs.hs | 6 +- tests/ServerTests.hs | 4 +- 11 files changed, 338 insertions(+), 125 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f44708fe6..e17c39a16 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -221,7 +221,9 @@ import Simplex.Messaging.Protocol SMPMsgMeta, SParty (..), SProtocolType (..), - ServiceSubResult, + ServiceSub (..), + ServiceSubResult (..), + ServiceSubError (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, @@ -1040,10 +1042,10 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userLinkData_ clientData pqIni createRcvQueue nonce_ qd e2eKeys = do AgentConfig {smpClientVRange = vr} <- asks config ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing - (rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e + (rq, qUri, tSess, sessId, serviceId_) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e atomically $ incSMPServerStat c userId srv connCreated rq' <- withStore c $ \db -> updateNewConnRcv db connId rq subMode - lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId + lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId serviceId_ mapM_ (newQueueNtfSubscription c rq') ntfServer_ pure (rq', qUri) createConnReq :: SMPQueueUri -> AM (ConnectionRequestUri c) @@ -1291,11 +1293,11 @@ joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode createReplyQueue :: AgentClient -> NetworkRequestMode -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM SMPQueueInfo createReplyQueue c nm ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing - (rq, qUri, tSess, sessId) <- newRcvQueue c nm userId connId srv (versionToRange smpClientVersion) SCMInvitation (isJust ntfServer_) subMode + (rq, qUri, tSess, sessId, serviceId_) <- newRcvQueue c nm userId connId srv (versionToRange smpClientVersion) SCMInvitation (isJust ntfServer_) subMode atomically $ incSMPServerStat c userId (qServer rq) connCreated let qInfo = toVersionT qUri smpClientVersion rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq subMode - lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId + lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId serviceId_ mapM_ (newQueueNtfSubscription c rq') ntfServer_ pure qInfo @@ -1451,22 +1453,14 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do Just activeUserId -> sortOn (\(uId, _) -> if uId == activeUserId then 0 else 1 :: Int) userSrvs Nothing -> userSrvs useServices <- readTVarIO $ useClientServices c - -- These options are possible below: - -- 1) services fully disabled: - -- No service subscriptions will be attempted, and existing services and association will remain in in the database, - -- but they will be ignored because of hasService parameter set to False. - -- This approach preserves performance for all clients that do not use services. - -- 2) at least one user ID has services enabled: - -- Service will be loaded for all user/server combinations: - -- a) service is enabled for user ID and service record exists: subscription will be attempted, - -- b) service is disabled and record exists: service record and all associations will be removed, - -- c) service is disabled or no record: no subscription attempt. + -- Service will be loaded for all user/server combinations: + -- a) service is enabled for user ID and service record exists: subscription will be attempted, + -- b) service is disabled and record exists: service record and all associations will be removed, + -- c) service is disabled or no record: no subscription attempt. -- On successful service subscription, only unassociated queues will be subscribed. - userSrvs'' <- - if any id useServices - then lift $ mapConcurrently (subscribeService useServices) userSrvs' - else pure $ map (,False) userSrvs' - rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs'' + userSrvs2 <- withStore' c $ \db -> mapM (getService db useServices) userSrvs' + userSrvs3 <- lift $ mapConcurrently subscribeService userSrvs2 + rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs3 let (errs, oks) = partitionEithers rs logInfo $ "subscribed " <> tshow (sum oks) <> " queues" forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map ("",) @@ -1475,16 +1469,30 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do resumeAllCommands c where handleErr = (`catchAllErrors` \e -> notifySub' c "" (ERR e) >> throwE e) - subscribeService :: Map UserId Bool -> (UserId, SMPServer) -> AM' ((UserId, SMPServer), ServiceAssoc) - subscribeService useServices us@(userId, srv) = fmap ((us,) . fromRight False) $ tryAllErrors' $ do - withStore' c (\db -> getSubscriptionService db userId srv) >>= \case + getService :: DB.Connection -> Map UserId Bool -> (UserId, SMPServer) -> IO ((UserId, SMPServer), Maybe ServiceSub) + getService db useServices us@(userId, srv) = + fmap (us,) $ getSubscriptionService db userId srv >>= \case Just serviceSub -> case M.lookup userId useServices of - Just True -> tryAllErrors (subscribeClientService c True userId srv serviceSub) >>= \case - Left e | clientServiceError e -> unassocQueues $> False + Just True -> pure $ Just serviceSub + _ -> Nothing <$ unassocUserServerRcvQueueSubs' db userId srv + _ -> pure Nothing + subscribeService :: ((UserId, SMPServer), Maybe ServiceSub) -> AM' ((UserId, SMPServer), ServiceAssoc) + subscribeService (us@(userId, srv), serviceSub_) = fmap ((us,) . fromRight False) $ tryAllErrors' $ + case serviceSub_ of + Just serviceSub -> tryAllErrors (subscribeClientService c True userId srv serviceSub) >>= \case + Right (ServiceSubResult e _) -> case e of + Just SSErrorServiceId {} -> unassocQueues + -- Possibly, we should always resubscribe all when expected is greater than subscribed + Just SSErrorQueueCount {expectedQueueCount = n, subscribedQueueCount = n'} | n > 0 && n' == 0 -> unassocQueues _ -> pure True - _ -> unassocQueues $> False + Left e -> do + atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR e) + if clientServiceError e + then unassocQueues + else pure True where - unassocQueues = withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv + unassocQueues :: AM Bool + unassocQueues = False <$ withStore' c (\db -> unassocUserServerRcvQueueSubs' db userId srv) _ -> pure False subscribeUserServer :: Int -> TVar Int -> ((UserId, SMPServer), ServiceAssoc) -> AM' (Either AgentErrorType Int) subscribeUserServer maxPending currPending ((userId, srv), hasService) = do @@ -2219,10 +2227,10 @@ switchDuplexConnection c nm (DuplexConnection cData@ConnData {connId, userId} rq srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth -- TODO [notications] possible improvement would be to create ntf credentials here, to avoid creating them after rotation completes. -- The problem is that currently subscription already exists, and we do not support queues with credentials but without subscriptions. - (q, qUri, tSess, sessId) <- newRcvQueue c nm userId connId srv' clientVRange SCMInvitation False SMSubscribe + (q, qUri, tSess, sessId, serviceId_) <- newRcvQueue c nm userId connId srv' clientVRange SCMInvitation False SMSubscribe let rq' = (q :: NewRcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' SMSubscribe - lift $ addNewQueueSubscription c rq'' tSess sessId + lift $ addNewQueueSubscription c rq'' tSess sessId serviceId_ void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD let rqs' = updatedQs rq1 rqs <> [rq''] @@ -2908,7 +2916,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), THandlePar processSubOk :: RcvQueue -> TVar [ConnId] -> TVar [RcvQueue] -> Maybe SMP.ServiceId -> IO () processSubOk rq@RcvQueue {connId} upConnIds serviceRQs serviceId_ = atomically . whenM (isPendingSub rq) $ do - SS.addActiveSub tSess sessId rq $ currentSubs c + SS.addActiveSub tSess sessId serviceId_ rq $ currentSubs c modifyTVar' upConnIds (connId :) when (isJust serviceId_ && serviceId_ == clientServiceId_) $ modifyTVar' serviceRQs (rq :) clientServiceId_ = (\THClientService {serviceId} -> serviceId) <$> (clientService =<< thAuth) @@ -3115,16 +3123,26 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), THandlePar notifyEnd removed | removed = notify END >> logServer "<--" c srv rId "END" | otherwise = logServer "<--" c srv rId "END from disconnected client - ignored" - -- TODO [certs rcv] - r@(SMP.ENDS _) -> unexpected r + SMP.ENDS n idsHash -> + atomically (ifM (activeClientSession c tSess sessId) (SS.deleteServiceSub tSess (currentSubs c) $> True) (pure False)) + >>= notifyEnd + where + notifyEnd removed + | removed = do + forM_ clientServiceId_ $ \serviceId -> + notify_ B.empty $ SERVICE_END srv $ ServiceSub serviceId n idsHash + logServer "<--" c srv rId "ENDS" + | otherwise = logServer "<--" c srv rId "ENDS from disconnected client - ignored" -- TODO [certs rcv] Possibly, we need to add some flag to connection that it was deleted SMP.DELD -> atomically (removeSubscription c tSess connId rq) >> notify DELD SMP.ERR e -> notify $ ERR $ SMP (B.unpack $ strEncode srv) e r -> unexpected r where notify :: forall e m. (AEntityI e, MonadIO m) => AEvent e -> m () - notify msg = - let t = ("", connId, AEvt (sAEntity @e) msg) + notify = notify_ connId + notify_ :: forall e m. (AEntityI e, MonadIO m) => ConnId -> AEvent e -> m () + notify_ connId' msg = + let t = ("", connId', AEvt (sAEntity @e) msg) in atomically $ ifM (isFullTBQueue subQ) (modifyTVar' pendingMsgs (t :)) (writeTBQueue subQ t) prohibited :: Text -> AM () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 7acfb0b49..9bf1afd8d 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -266,7 +266,6 @@ import Simplex.Messaging.Protocol NetworkError (..), MsgFlags (..), MsgId, - IdsHash, NtfServer, NtfServerWithAuth, ProtoServer, @@ -283,6 +282,7 @@ import Simplex.Messaging.Protocol SProtocolType (..), ServiceSub (..), ServiceSubResult (..), + ServiceSubError (..), SndPublicAuthKey, SubscriptionMode (..), NewNtfCreds (..), @@ -1420,7 +1420,7 @@ getSessionMode :: AgentClient -> STM TransportSessionMode getSessionMode = fmap (sessionMode . snd) . readTVar . useNetworkConfig {-# INLINE getSessionMode #-} -newRcvQueue :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SConnectionMode c -> Bool -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId) +newRcvQueue :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SConnectionMode c -> Bool -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId, Maybe ServiceId) newRcvQueue c nm userId connId srv vRange cMode enableNtfs subMode = do let qrd = case cMode of SCMInvitation -> CQRMessaging Nothing; SCMContact -> CQRContact Nothing e2eKeys <- atomically . C.generateKeyPair =<< asks random @@ -1441,7 +1441,7 @@ queueReqData = \case CQRMessaging d -> QRMessaging $ srvReq <$> d CQRContact d -> QRContact $ srvReq <$> d -newRcvQueue_ :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> ClntQueueReqData -> Bool -> SubscriptionMode -> Maybe C.CbNonce -> C.KeyPairX25519 -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId) +newRcvQueue_ :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> ClntQueueReqData -> Bool -> SubscriptionMode -> Maybe C.CbNonce -> C.KeyPairX25519 -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId, Maybe ServiceId) newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enableNtfs subMode nonce_ (e2eDhKey, e2ePrivKey) = do C.AuthAlg a <- asks (rcvAuthAlg . config) g <- asks random @@ -1483,7 +1483,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl deleteErrors = 0 } qUri = SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey queueMode - pure (rq, qUri, tSess, sessionId thParams') + pure (rq, qUri, tSess, sessionId thParams', sessServiceId) where mkNtfCreds :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> TVar ChaChaDRG -> SMPClient -> IO (Maybe (C.AAuthKeyPair, C.PrivateKeyX25519), Maybe NewNtfCreds) mkNtfCreds a g smp @@ -1526,23 +1526,23 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> Maybe ServiceId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM ([RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)]) processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do - pending <- SS.getPendingSubs tSess $ currentSubs c - let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pending) (M.empty, ([], []), [], 0) rs + pendingSubs <- SS.getPendingQueueSubs tSess $ currentSubs c + let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pendingSubs) (M.empty, ([], []), [], 0) rs unless (M.null failed) $ do incSMPServerStat' c userId srv connSubErrs $ M.size failed failSubscriptions c tSess failed unless (null qs && null sQs) $ do incSMPServerStat' c userId srv connSubscribed $ length qs + length sQs - SS.batchAddActiveSubs tSess sessId subscribed $ currentSubs c + SS.batchAddActiveSubs tSess sessId serviceId_ subscribed $ currentSubs c unless (ignored == 0) $ incSMPServerStat' c userId srv connSubIgnored ignored pure (sQs, notices) where partitionResults :: - (Map SMP.RecipientId RcvQueueSub, Maybe ServiceSub) -> + Map SMP.RecipientId RcvQueueSub -> (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> (Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) -> (Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) - partitionResults (pendingSubs, pendingSS) (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed@(qs, sQs), notices, ignored) = case r of + partitionResults pendingSubs (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed@(qs, sQs), notices, ignored) = case r of Left e -> case smpErrorClientNotice e of Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored) where @@ -1554,8 +1554,8 @@ processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do failed' = M.insert rcvId e failed Right serviceId_' | rcvId `M.member` pendingSubs -> - let subscribed' = case (serviceId_, serviceId_', pendingSS) of - (Just sId, Just sId', Just ServiceSub {smpServiceId}) | sId == sId' && sId == smpServiceId -> (qs, rq : sQs) + let subscribed' = case (serviceId_, serviceId_') of + (Just sId, Just sId') | sId == sId' -> (qs, rq : sQs) _ -> (rq : qs, sQs) in (failed, subscribed', notices', ignored) | otherwise -> (failed, subscribed, notices', ignored + 1) @@ -1726,11 +1726,18 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSubResult resubscribeClientService c tSess@(userId, srv, _) serviceSub = - withServiceClient c tSess (\smp _ -> subscribeClientService_ c True tSess smp serviceSub) `catchE` \e -> do - when (clientServiceError e) $ do + tryAllErrors (withServiceClient c tSess $ \smp _ -> subscribeClientService_ c True tSess smp serviceSub) >>= \case + Right r@(ServiceSubResult e _) -> case e of + Just SSErrorServiceId {} -> unassocSubscribeQueues $> r + _ -> pure r + Left e -> do + when (clientServiceError e) $ unassocSubscribeQueues + atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR e) + throwE e + where + unassocSubscribeQueues = do qs <- withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv void $ lift $ subscribeUserServerQueues c userId srv qs - throwE e -- TODO [certs rcv] update service in the database if it has different ID and re-associate queues, and send event subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> ServiceSub -> AM ServiceSubResult @@ -1751,7 +1758,7 @@ withServiceClient c tSess subscribe = -- TODO [certs rcv] send subscription error event? subscribeClientService_ :: AgentClient -> Bool -> SMPTransportSession -> SMPClient -> ServiceSub -> ExceptT SMPClientError IO ServiceSubResult -subscribeClientService_ c withEvent tSess@(userId, srv, _) smp expected@(ServiceSub _ n idsHash) = do +subscribeClientService_ c withEvent tSess@(_, srv, _) smp expected@(ServiceSub _ n idsHash) = do subscribed <- subscribeService smp SMP.SRecipientService n idsHash let sessId = sessionId $ thParams smp r = serviceSubResult expected subscribed @@ -1821,14 +1828,14 @@ getRemovedSubs AgentClient {removedSubs} k = TM.lookup k removedSubs >>= maybe n TM.insert k s removedSubs pure s -addNewQueueSubscription :: AgentClient -> RcvQueue -> SMPTransportSession -> SessionId -> AM' () -addNewQueueSubscription c rq' tSess sessId = do +addNewQueueSubscription :: AgentClient -> RcvQueue -> SMPTransportSession -> SessionId -> Maybe ServiceId -> AM' () +addNewQueueSubscription c rq' tSess sessId serviceId_ = do let rq = rcvQueueSub rq' same <- atomically $ do modifyTVar' (subscrConns c) $ S.insert $ qConnId rq active <- activeClientSession c tSess sessId if active - then SS.addActiveSub tSess sessId rq' $ currentSubs c + then SS.addActiveSub tSess sessId serviceId_ rq' $ currentSubs c else SS.addPendingSub tSess rq $ currentSubs c pure active unless same $ resubscribeSMPSession c tSess diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index d5b35611b..ef9bc592f 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -393,6 +393,7 @@ data AEvent (e :: AEntity) where SERVICE_ALL :: SMPServer -> AEvent AENone -- all service messages are delivered SERVICE_DOWN :: SMPServer -> ServiceSub -> AEvent AENone SERVICE_UP :: SMPServer -> ServiceSubResult -> AEvent AENone + SERVICE_END :: SMPServer -> ServiceSub -> AEvent AENone SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> AEvent AEConn RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> AEvent AEConn SENT :: AgentMsgId -> Maybe SMPServer -> AEvent AEConn @@ -467,6 +468,7 @@ data AEventTag (e :: AEntity) where SERVICE_ALL_ :: AEventTag AENone SERVICE_DOWN_ :: AEventTag AENone SERVICE_UP_ :: AEventTag AENone + SERVICE_END_ :: AEventTag AENone SWITCH_ :: AEventTag AEConn RSYNC_ :: AEventTag AEConn SENT_ :: AEventTag AEConn @@ -525,6 +527,7 @@ aEventTag = \case SERVICE_ALL _ -> SERVICE_ALL_ SERVICE_DOWN {} -> SERVICE_DOWN_ SERVICE_UP {} -> SERVICE_UP_ + SERVICE_END {} -> SERVICE_END_ SWITCH {} -> SWITCH_ RSYNC {} -> RSYNC_ SENT {} -> SENT_ diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 9508e4499..853a76908 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -38,7 +38,6 @@ module Simplex.Messaging.Agent.Store.AgentStore -- * Client services createClientService, getClientServiceCredentials, - getSubscriptionServices, getSubscriptionService, getClientServiceServers, setClientServiceId, @@ -55,6 +54,7 @@ module Simplex.Messaging.Agent.Store.AgentStore getSubscriptionServers, getUserServerRcvQueueSubs, unassocUserServerRcvQueueSubs, + unassocUserServerRcvQueueSubs', unsetQueuesToSubscribe, setRcvServiceAssocs, removeRcvServiceAssocs, @@ -344,7 +344,7 @@ handleSQLError err e = case constraintViolation e of handleSQLError :: StoreError -> SQLError -> StoreError handleSQLError err e | SQL.sqlError e == SQL.ErrorConstraint = err - | otherwise = SEInternal $ bshow e + | otherwise = SEInternal $ encodeUtf8 $ tshow e <> ": " <> SQL.sqlErrorDetails e <> ", " <> SQL.sqlErrorContext e #endif createUserRecord :: DB.Connection -> IO UserId @@ -439,11 +439,6 @@ getClientServiceCredentials db userId srv = where toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) -getSubscriptionServices :: DB.Connection -> IO [(UserId, (SMPServer, ServiceSub))] -getSubscriptionServices db = map toUserService <$> DB.query_ db clientServiceQuery - where - toUserService (Only userId :. serviceRow) = (userId, toServerService serviceRow) - getSubscriptionService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ServiceSub) getSubscriptionService db userId (SMPServer h p kh) = maybeFirstRow toService $ @@ -453,7 +448,7 @@ getSubscriptionService db userId (SMPServer h p kh) = SELECT c.service_id, c.service_queue_count, c.service_queue_ids_hash FROM client_services c JOIN servers s ON s.host = c.host AND s.port = c.port - WHERE c.user_id = ? AND c.host = ? AND c.port = ? AND COALESCE(c.server_key_hash, s.key_hash) = ? + WHERE c.user_id = ? AND c.host = ? AND c.port = ? AND COALESCE(c.server_key_hash, s.key_hash) = ? AND service_id IS NOT NULL |] (userId, h, p, kh) where @@ -461,15 +456,16 @@ getSubscriptionService db userId (SMPServer h p kh) = getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)] getClientServiceServers db userId = - map toServerService <$> DB.query db (clientServiceQuery <> " WHERE c.user_id = ?") (Only userId) - -clientServiceQuery :: Query -clientServiceQuery = - [sql| - SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.service_id, c.service_queue_count, c.service_queue_ids_hash - FROM client_services c - JOIN servers s ON s.host = c.host AND s.port = c.port - |] + map toServerService <$> + DB.query + db + [sql| + SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.service_id, c.service_queue_count, c.service_queue_ids_hash + FROM client_services c + JOIN servers s ON s.host = c.host AND s.port = c.port + WHERE c.user_id = ? AND service_id IS NOT NULL + |] + (Only userId) toServerService :: (NonEmpty TransportHost, ServiceName, C.KeyHash, ServiceId, Int64, Binary ByteString) -> (ProtocolServer 'PSMP, ServiceSub) toServerService (host, port, kh, serviceId, n, Binary idsHash) = @@ -487,14 +483,20 @@ setClientServiceId db userId srv serviceId = (serviceId, userId, host srv, port srv) deleteClientService :: DB.Connection -> UserId -> SMPServer -> IO () -deleteClientService db userId srv = +deleteClientService db userId (SMPServer h p kh) = DB.execute db [sql| DELETE FROM client_services WHERE user_id = ? AND host = ? AND port = ? + AND EXISTS ( + SELECT 1 FROM servers s + WHERE s.host = client_services.host + AND s.port = client_services.port + AND COALESCE(client_services.server_key_hash, s.key_hash) = ? + ); |] - (userId, host srv, port srv) + (userId, h, p, Just kh) deleteClientServices :: DB.Connection -> UserId -> IO () deleteClientServices db userId = do @@ -2279,7 +2281,8 @@ getUserServerRcvQueueSubs db userId (SMPServer h p kh) onlyNeeded hasService = | otherwise = "" unassocUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> IO [RcvQueueSub] -unassocUserServerRcvQueueSubs db userId (SMPServer h p kh) = +unassocUserServerRcvQueueSubs db userId srv@(SMPServer h p kh) = do + deleteClientService db userId srv map toRcvQueueSub <$> DB.query db @@ -2293,6 +2296,11 @@ unassocUserServerRcvQueueSubs db userId (SMPServer h p kh) = rcv_queues.rcv_queue_id, rcv_queues.rcv_primary, rcv_queues.replace_rcv_queue_id |] +unassocUserServerRcvQueueSubs' :: DB.Connection -> UserId -> SMPServer -> IO () +unassocUserServerRcvQueueSubs' db userId srv@(SMPServer h p kh) = do + deleteClientService db userId srv + DB.execute db removeRcvAssocsQuery (h, p, userId, kh) + unsetQueuesToSubscribe :: DB.Connection -> IO () unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" diff --git a/src/Simplex/Messaging/Agent/TSessionSubs.hs b/src/Simplex/Messaging/Agent/TSessionSubs.hs index ab15b9793..a1db48c9e 100644 --- a/src/Simplex/Messaging/Agent/TSessionSubs.hs +++ b/src/Simplex/Messaging/Agent/TSessionSubs.hs @@ -23,8 +23,10 @@ module Simplex.Messaging.Agent.TSessionSubs batchDeletePendingSubs, deleteSub, batchDeleteSubs, + deleteServiceSub, hasPendingSubs, getPendingSubs, + getPendingQueueSubs, getActiveSubs, setSubsPending, updateClientNotices, @@ -39,12 +41,12 @@ import Data.Int (Int64) import Data.List (foldl') import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (isJust) +import Data.Maybe (fromMaybe, isJust) import qualified Data.Set as S import Simplex.Messaging.Agent.Protocol (SMPQueue (..)) -import Simplex.Messaging.Agent.Store (RcvQueue, RcvQueueSub (..), SomeRcvQueue, StoredRcvQueue (rcvServiceAssoc), rcvQueueSub) +import Simplex.Messaging.Agent.Store (RcvQueue, RcvQueueSub (..), ServiceAssoc, SomeRcvQueue, StoredRcvQueue (rcvServiceAssoc), rcvQueueSub) import Simplex.Messaging.Client (SMPTransportSession, TransportSessionMode (..)) -import Simplex.Messaging.Protocol (RecipientId, ServiceSub (..), queueIdHash) +import Simplex.Messaging.Protocol (IdsHash, RecipientId, ServiceSub (..), queueIdHash) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport @@ -119,40 +121,48 @@ setActiveServiceSub tSess sessId serviceSub ss = do writeTVar (pendingServiceSub s) Nothing else writeTVar (pendingServiceSub s) $ Just serviceSub -addActiveSub :: SMPTransportSession -> SessionId -> RcvQueue -> TSessionSubs -> STM () -addActiveSub tSess sessId rq = addActiveSub' tSess sessId (rcvQueueSub rq) (rcvServiceAssoc rq) +addActiveSub :: SMPTransportSession -> SessionId -> Maybe ServiceId -> RcvQueue -> TSessionSubs -> STM () +addActiveSub tSess sessId serviceId_ rq = addActiveSub' tSess sessId serviceId_ (rcvQueueSub rq) (rcvServiceAssoc rq) {-# INLINE addActiveSub #-} -addActiveSub' :: SMPTransportSession -> SessionId -> RcvQueueSub -> Bool -> TSessionSubs -> STM () -addActiveSub' tSess sessId rq serviceAssoc ss = do +addActiveSub' :: SMPTransportSession -> SessionId -> Maybe ServiceId -> RcvQueueSub -> ServiceAssoc -> TSessionSubs -> STM () +addActiveSub' tSess sessId serviceId_ rq serviceAssoc ss = do s <- getSessSubs tSess ss sessId' <- readTVar $ subsSessId s let rId = rcvId rq if Just sessId == sessId' then do - TM.insert rId rq $ activeSubs s TM.delete rId $ pendingSubs s - when serviceAssoc $ - let updateServiceSub (ServiceSub serviceId n idsHash) = ServiceSub serviceId (n + 1) (idsHash <> queueIdHash rId) - in modifyTVar' (activeServiceSub s) (updateServiceSub <$>) + case serviceId_ of + Just serviceId | serviceAssoc -> updateActiveService s serviceId 1 (queueIdHash rId) + _ -> TM.insert rId rq $ activeSubs s else TM.insert rId rq $ pendingSubs s -batchAddActiveSubs :: SMPTransportSession -> SessionId -> ([RcvQueueSub], [RcvQueueSub]) -> TSessionSubs -> STM () -batchAddActiveSubs tSess sessId (rqs, serviceRQs) ss = do +batchAddActiveSubs :: SMPTransportSession -> SessionId -> Maybe ServiceId -> ([RcvQueueSub], [RcvQueueSub]) -> TSessionSubs -> STM () +batchAddActiveSubs tSess sessId serviceId_ (rqs, serviceRQs) ss = do s <- getSessSubs tSess ss sessId' <- readTVar $ subsSessId s - let qs = M.fromList $ map (\rq -> (rcvId rq, rq)) rqs + let qs = queuesMap rqs + serviceQs = queuesMap serviceRQs if Just sessId == sessId' then do TM.union qs $ activeSubs s modifyTVar' (pendingSubs s) (`M.difference` qs) - serviceSub_ <- readTVar $ activeServiceSub s - forM_ serviceSub_ $ \(ServiceSub serviceId n idsHash) -> do - unless (null serviceRQs) $ do - let idsHash' = idsHash <> mconcat (map (queueIdHash . rcvId) serviceRQs) - n' = n + fromIntegral (length serviceRQs) - writeTVar (activeServiceSub s) $ Just $ ServiceSub serviceId n' idsHash' - else TM.union qs $ pendingSubs s + unless (null serviceRQs) $ forM_ serviceId_ $ \serviceId -> do + modifyTVar' (pendingSubs s) (`M.difference` serviceQs) + updateActiveService s serviceId (fromIntegral $ length serviceRQs) (mconcat $ map (queueIdHash . rcvId) serviceRQs) + else do + TM.union qs $ pendingSubs s + when (isJust serviceId_ && not (null serviceRQs)) $ TM.union serviceQs $ pendingSubs s + where + queuesMap = M.fromList . map (\rq -> (rcvId rq, rq)) + +updateActiveService :: SessSubs -> ServiceId -> Int64 -> IdsHash -> STM () +updateActiveService s serviceId addN addIdsHash = do + ServiceSub serviceId' n idsHash <- + fromMaybe (ServiceSub serviceId 0 mempty) <$> readTVar (activeServiceSub s) + when (serviceId == serviceId') $ + writeTVar (activeServiceSub s) $ Just $ ServiceSub serviceId (n + addN) (idsHash <> addIdsHash) batchAddPendingSubs :: SMPTransportSession -> [RcvQueueSub] -> TSessionSubs -> STM () batchAddPendingSubs tSess rqs ss = do @@ -176,6 +186,9 @@ batchDeleteSubs tSess rqs = lookupSubs tSess >=> mapM_ (\s -> delete (activeSubs rIds = S.fromList $ map queueId rqs delete = (`modifyTVar'` (`M.withoutKeys` rIds)) +deleteServiceSub :: SMPTransportSession -> TSessionSubs -> STM () +deleteServiceSub tSess = lookupSubs tSess >=> mapM_ (\s -> writeTVar (activeServiceSub s) Nothing >> writeTVar (pendingServiceSub s) Nothing) + hasPendingSubs :: SMPTransportSession -> TSessionSubs -> STM Bool hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (\s -> anyM [hasSubs s, hasServiceSub s]) where @@ -187,6 +200,10 @@ getPendingSubs tSess = lookupSubs tSess >=> maybe (pure (M.empty, Nothing)) get where get s = liftM2 (,) (readTVar $ pendingSubs s) (readTVar $ pendingServiceSub s) +getPendingQueueSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) +getPendingQueueSubs = getSubs_ pendingSubs +{-# INLINE getPendingQueueSubs #-} + getActiveSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) getActiveSubs = getSubs_ activeSubs {-# INLINE getActiveSubs #-} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 51128597c..4993aaac8 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -147,7 +147,6 @@ module Simplex.Messaging.Protocol serviceSubResult, queueIdsHash, queueIdHash, - noIdsHash, addServiceSubs, subtractServiceSubs, MaxMessageLen, @@ -726,7 +725,7 @@ data BrokerMsg where RRES :: EncFwdResponse -> BrokerMsg -- relay to proxy PRES :: EncResponse -> BrokerMsg -- proxy to client END :: BrokerMsg - ENDS :: Int64 -> BrokerMsg + ENDS :: Int64 -> IdsHash -> BrokerMsg DELD :: BrokerMsg INFO :: QueueInfo -> BrokerMsg OK :: BrokerMsg @@ -1518,10 +1517,6 @@ instance Monoid IdsHash where xor' :: Word8 -> Word8 -> Word8 xor' x y = let !r = xor x y in r -noIdsHash ::IdsHash -noIdsHash = IdsHash B.empty -{-# INLINE noIdsHash #-} - queueIdsHash :: [QueueId] -> IdsHash queueIdsHash = mconcat . map queueIdHash @@ -1535,7 +1530,7 @@ addServiceSubs (n', idsHash') (n, idsHash) = (n + n', idsHash <> idsHash') subtractServiceSubs :: (Int64, IdsHash) -> (Int64, IdsHash) -> (Int64, IdsHash) subtractServiceSubs (n', idsHash') (n, idsHash) | n > n' = (n - n', idsHash <> idsHash') -- concat is a reversible xor: (x `xor` y) `xor` y == x - | otherwise = (0, noIdsHash) + | otherwise = (0, mempty) data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock @@ -1883,7 +1878,7 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where QUE_ -> pure QUE CT SRecipientService SUBS_ | v >= rcvServiceSMPVersion -> Cmd SRecipientService <$> (SUBS <$> _smpP <*> smpP) - | otherwise -> pure $ Cmd SRecipientService $ SUBS (-1) noIdsHash + | otherwise -> pure $ Cmd SRecipientService $ SUBS (-1) mempty CT SSender tag -> Cmd SSender <$> case tag of SKEY_ -> SKEY <$> _smpP @@ -1902,7 +1897,7 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB CT SNotifierService NSUBS_ | v >= rcvServiceSMPVersion -> Cmd SNotifierService <$> (NSUBS <$> _smpP <*> smpP) - | otherwise -> pure $ Cmd SNotifierService $ NSUBS (-1) noIdsHash + | otherwise -> pure $ Cmd SNotifierService $ NSUBS (-1) mempty fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} @@ -1925,9 +1920,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where SOK serviceId_ | v >= serviceCertsSMPVersion -> e (SOK_, ' ', serviceId_) | otherwise -> e OK_ -- won't happen, the association with the service requires v >= serviceCertsSMPVersion - SOKS n idsHash - | v >= rcvServiceSMPVersion -> e (SOKS_, ' ', n, idsHash) - | otherwise -> e (SOKS_, ' ', n) + SOKS n idsHash -> serviceResp SOKS_ n idsHash MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e (MSG_, ' ', msgId, Tail body) ALLS -> e ALLS_ @@ -1937,7 +1930,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where RRES (EncFwdResponse encBlock) -> e (RRES_, ' ', Tail encBlock) PRES (EncResponse encBlock) -> e (PRES_, ' ', Tail encBlock) END -> e END_ - ENDS n -> e (ENDS_, ' ', n) + ENDS n idsHash -> serviceResp ENDS_ n idsHash DELD | v >= deletedEventSMPVersion -> e DELD_ | otherwise -> e END_ @@ -1954,6 +1947,9 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where where e :: Encoding a => a -> ByteString e = smpEncode + serviceResp tag n idsHash + | v >= serviceCertsSMPVersion = e (tag, ' ', n, idsHash) + | otherwise = e (tag, ' ', n) protocolP v = \case MSG_ -> do @@ -1982,21 +1978,23 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId, serviceId, serverNtfCreds} LNK_ -> LNK <$> _smpP <*> smpP SOK_ -> SOK <$> _smpP - SOKS_ - | v >= rcvServiceSMPVersion -> SOKS <$> _smpP <*> smpP - | otherwise -> SOKS <$> _smpP <*> pure noIdsHash + SOKS_ -> serviceRespP SOKS NID_ -> NID <$> _smpP <*> smpP NMSG_ -> NMSG <$> _smpP <*> smpP PKEY_ -> PKEY <$> _smpP <*> smpP <*> smpP RRES_ -> RRES <$> (EncFwdResponse . unTail <$> _smpP) PRES_ -> PRES <$> (EncResponse . unTail <$> _smpP) END_ -> pure END - ENDS_ -> ENDS <$> _smpP + ENDS_ -> serviceRespP ENDS DELD_ -> pure DELD INFO_ -> INFO <$> _smpP OK_ -> pure OK ERR_ -> ERR <$> _smpP PONG_ -> pure PONG + where + serviceRespP resp + | v >= serviceCertsSMPVersion = resp <$> _smpP <*> smpP + | otherwise = resp <$> _smpP <*> pure mempty fromProtocolError = \case PECmdSyntax -> CMD SYNTAX diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index b7bb0efaa..24247e781 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -316,8 +316,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt cancelServiceSubs :: ServiceId -> Maybe (Client s) -> STM [PrevClientSub s] cancelServiceSubs serviceId = checkAnotherClient $ \c -> do - changedSubs@(n, _) <- swapTVar (clientServiceSubs c) (0, noIdsHash) - pure [(c, CSADecreaseSubs changedSubs, (serviceId, ENDS n))] + changedSubs@(n, idsHash) <- swapTVar (clientServiceSubs c) (0, mempty) + pure [(c, CSADecreaseSubs changedSubs, (serviceId, ENDS n idsHash))] checkAnotherClient :: (Client s -> STM [PrevClientSub s]) -> Maybe (Client s) -> STM [PrevClientSub s] checkAnotherClient mkSub = \case Just c@Client {clientId, connected} | clntId /= clientId -> diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 02cf136c7..e59cd5c0b 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -502,7 +502,7 @@ newServerSubscribers = do subQ <- newTQueueIO queueSubscribers <- SubscribedClients <$> TM.emptyIO serviceSubscribers <- SubscribedClients <$> TM.emptyIO - totalServiceSubs <- newTVarIO (0, noIdsHash) + totalServiceSubs <- newTVarIO (0, mempty) subClients <- newTVarIO IS.empty pendingEvents <- newTVarIO IM.empty pure ServerSubscribers {subQ, queueSubscribers, serviceSubscribers, totalServiceSubs, subClients, pendingEvents} @@ -513,8 +513,8 @@ newClient clientId qSize clientTHParams createdAt = do ntfSubscriptions <- TM.emptyIO serviceSubscribed <- newTVarIO False ntfServiceSubscribed <- newTVarIO False - serviceSubsCount <- newTVarIO (0, noIdsHash) - ntfServiceSubsCount <- newTVarIO (0, noIdsHash) + serviceSubsCount <- newTVarIO (0, mempty) + ntfServiceSubsCount <- newTVarIO (0, mempty) rcvQ <- newTBQueueIO qSize sndQ <- newTBQueueIO qSize msgQ <- newTBQueueIO qSize diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index b63e4cb48..34448fc10 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -480,6 +480,7 @@ functionalAPITests ps = do describe "Client service certificates" $ do it "should connect, subscribe and reconnect as a service" $ testClientServiceConnection ps it "should re-subscribe when service ID changed" $ testClientServiceIDChange ps + it "migrate connections to and from service" $ testMigrateConnectionsToService ps describe "Connection switch" $ do describe "should switch delivery to the new queue" $ testServerMatrix2 ps testSwitchConnection @@ -3721,10 +3722,22 @@ testClientServiceConnection ps = do testClientServiceIDChange :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testClientServiceIDChange ps@(_, ASType qs _) = do (sId, uId) <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do - withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + conns <- withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do conns@(sId, uId) <- makeConnection service user exchangeGreetings service uId user sId pure conns + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 _)) <- nGet service + ("", "", DOWN _ [_]) <- nGet user + withSmpServerStoreLogOn ps testPort $ \_ -> do + getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + ("", "", UP _ [_]) <- nGet user + pure () + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 _)) <- nGet service + ("", "", DOWN _ [_]) <- nGet user + pure conns _ :: () <- case qs of SQSPostgres -> do #if defined(dbServerPostgres) @@ -3739,19 +3752,21 @@ testClientServiceIDChange ps@(_, ASType qs _) = do writeFile testStoreLogFile $ unlines $ filter (not . ("NEW_SERVICE" `isPrefixOf`)) $ lines s withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + liftIO $ threadDelay 250000 subscribeAllConnections service False Nothing liftIO $ getInAnyOrder service [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult (Just (SMP.SSErrorQueueCount 1 0)) (SMP.ServiceSub _ 0 _)))) -> True; _ -> False, \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False, - \case ("", "", AEvt SAENone (UP _ _)) -> True; _ -> False + \case ("", "", AEvt SAENone (UP _ [_])) -> True; _ -> False ] subscribeAllConnections user False Nothing ("", "", UP _ [_]) <- nGet user exchangeGreetingsMsgId 4 service uId user sId + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 _)) <- nGet service + ("", "", DOWN _ [_]) <- nGet user + pure () -- disable service in the client - -- The test uses True for non-existing user to make sure it's removed for user 1, - -- because if no users use services, then it won't be checking them to optimize for most clients. - withAgentClientsServers2 (agentCfg, initAgentServers {useServices = M.fromList [(100, True)]}) (agentCfg, initAgentServers) $ \notService user -> do + withAgentClientsServers2 (agentCfg, initAgentServers) (agentCfg, initAgentServers) $ \notService user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do subscribeAllConnections notService False Nothing ("", "", UP _ [_]) <- nGet notService @@ -3759,6 +3774,153 @@ testClientServiceIDChange ps@(_, ASType qs _) = do ("", "", UP _ [_]) <- nGet user exchangeGreetingsMsgId 6 notService uId user sId +testMigrateConnectionsToService :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testMigrateConnectionsToService ps = do + (((sId1, uId1), (uId2, sId2)), ((sId3, uId3), (uId4, sId4)), ((sId5, uId5), (uId6, sId6))) <- + withSmpServerStoreLogOn ps testPort $ \_ -> do + -- starting without service + cs12@((sId1, uId1), (uId2, sId2)) <- + withAgentClientsServers2 (agentCfg, initAgentServers) (agentCfg, initAgentServers) $ \notService user -> + runRight $ (,) <$> makeConnection notService user <*> makeConnection user notService + -- migrating to service + cs34@((sId3, uId3), (uId4, sId4)) <- + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do + subscribeAllConnections service False Nothing + service `up` 2 + subscribeAllConnections user False Nothing + user `up` 2 + exchangeGreetingsMsgId 2 service uId1 user sId1 + exchangeGreetingsMsgId 2 service uId2 user sId2 + (,) <$> makeConnection service user <*> makeConnection user service + -- starting as service + cs56 <- + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 4 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + subscribeAllConnections user False Nothing + user `up` 4 + exchangeGreetingsMsgId 4 service uId1 user sId1 + exchangeGreetingsMsgId 4 service uId2 user sId2 + exchangeGreetingsMsgId 2 service uId3 user sId3 + exchangeGreetingsMsgId 2 service uId4 user sId4 + (,) <$> makeConnection service user <*> makeConnection user service + pure (cs12, cs34, cs56) + -- server reconnecting resubscribes service + let testSendMessages6 s u n = do + exchangeGreetingsMsgId (n + 4) s uId1 u sId1 + exchangeGreetingsMsgId (n + 4) s uId2 u sId2 + exchangeGreetingsMsgId (n + 2) s uId3 u sId3 + exchangeGreetingsMsgId (n + 2) s uId4 u sId4 + exchangeGreetingsMsgId n s uId5 u sId5 + exchangeGreetingsMsgId n s uId6 u sId6 + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 6 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + subscribeAllConnections user False Nothing + user `up` 6 + testSendMessages6 service user 2 + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 6 _)) <- nGet service + user `down` 6 + withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 6 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + user `up` 6 + testSendMessages6 service user 4 + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 6 _)) <- nGet service + user `down` 6 + -- disabling service and adding connections + ((sId7, uId7), (uId8, sId8)) <- + withAgentClientsServers2 (agentCfg, initAgentServers) (agentCfg, initAgentServers) $ \notService user -> do + cs78@((sId7, uId7), (uId8, sId8)) <- + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections notService False Nothing + notService `up` 6 + subscribeAllConnections user False Nothing + user `up` 6 + testSendMessages6 notService user 6 + (,) <$> makeConnection notService user <*> makeConnection user notService + notService `down` 8 + user `down` 8 + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + notService `up` 8 + user `up` 8 + testSendMessages6 notService user 8 + exchangeGreetingsMsgId 2 notService uId7 user sId7 + exchangeGreetingsMsgId 2 notService uId8 user sId8 + notService `down` 8 + user `down` 8 + pure cs78 + let testSendMessages8 s u n = do + testSendMessages6 s u (n + 8) + exchangeGreetingsMsgId (n + 2) s uId7 u sId7 + exchangeGreetingsMsgId (n + 2) s uId8 u sId8 + -- re-enabling service and adding connections + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do + subscribeAllConnections service False Nothing + service `up` 8 + subscribeAllConnections user False Nothing + user `up` 8 + testSendMessages8 service user 2 + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 8 _)) <- nGet service + user `down` 8 + -- re-connect to server + withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 8 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + user `up` 8 + testSendMessages8 service user 4 + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ _ _)) <- nGet service -- should be 8 here + user `down` 8 + -- restart agents + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 8 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + subscribeAllConnections user False Nothing + user `up` 8 + testSendMessages8 service user 6 + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 8 _)) <- nGet service + user `down` 8 + runRight_ $ do + void $ sendMessage user sId7 SMP.noMsgFlags "hello 1" + void $ sendMessage user sId8 SMP.noMsgFlags "hello 2" + -- re-connect to server + withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 8 _)))) -> True; _ -> False, + \case ("", c, AEvt SAEConn (Msg "hello 1")) -> c == uId7; _ -> False, + \case ("", c, AEvt SAEConn (Msg "hello 2")) -> c == uId8; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + liftIO $ getInAnyOrder user + [ \case ("", "", AEvt SAENone (UP _ [_, _, _, _, _, _, _, _])) -> True; _ -> False, + \case ("", c, AEvt SAEConn (SENT 10)) -> c == sId7; _ -> False, + \case ("", c, AEvt SAEConn (SENT 10)) -> c == sId8; _ -> False + ] + testSendMessages6 service user 16 + where + up c n = do + ("", "", UP _ conns) <- nGet c + liftIO $ length conns `shouldBe` n + down c n = do + ("", "", DOWN _ conns) <- nGet c + liftIO $ length conns `shouldBe` n + getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient getSMPAgentClient' clientId cfg' initServers dbPath = do Right st <- liftIO $ createStore dbPath diff --git a/tests/CoreTests/TSessionSubs.hs b/tests/CoreTests/TSessionSubs.hs index e9038b9d9..96975e9ef 100644 --- a/tests/CoreTests/TSessionSubs.hs +++ b/tests/CoreTests/TSessionSubs.hs @@ -69,21 +69,21 @@ testSessionSubs = do atomically (SS.hasPendingSub tSess1 (rcvId q4) ss) `shouldReturn` False atomically (SS.hasActiveSub tSess1 (rcvId q4) ss) `shouldReturn` False -- setting active queue without setting session ID would keep it as pending - atomically $ SS.addActiveSub' tSess1 "123" q1 False ss + atomically $ SS.addActiveSub' tSess1 "123" Nothing q1 False ss atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` True atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` False dumpSessionSubs ss `shouldReturn` st countSubs ss `shouldReturn` (0, 3) -- setting active queues atomically $ SS.setSessionId tSess1 "123" ss - atomically $ SS.addActiveSub' tSess1 "123" q1 False ss + atomically $ SS.addActiveSub' tSess1 "123" Nothing q1 False ss atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` False atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` True atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1)] atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` (M.fromList [("r2", q2)], Nothing) countSubs ss `shouldReturn` (1, 2) atomically $ SS.setSessionId tSess2 "456" ss - atomically $ SS.addActiveSub' tSess2 "456" q4 False ss + atomically $ SS.addActiveSub' tSess2 "456" Nothing q4 False ss atomically (SS.hasPendingSub tSess2 (rcvId q4) ss) `shouldReturn` False atomically (SS.hasActiveSub tSess2 (rcvId q4) ss) `shouldReturn` True atomically (SS.hasActiveSub tSess1 (rcvId q4) ss) `shouldReturn` False -- wrong transport session diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 82a39af39..27a72d2ac 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1334,7 +1334,7 @@ testMessageServiceNotifications = Resp "4" _ (SOK (Just serviceId')) <- serviceSignSendRecv nh2 nKey servicePK ("4", nId, NSUB) serviceId' `shouldBe` serviceId -- service subscription is terminated - Resp "" serviceId2 (ENDS 1) <- tGet1 nh1 + Resp "" serviceId2 (ENDS 1 _) <- tGet1 nh1 serviceId2 `shouldBe` serviceId deliverMessage rh rId rKey sh sId sKey nh2 "hello again" dec 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case @@ -1374,7 +1374,7 @@ testMessageServiceNotifications = Resp "12" serviceId5 (SOKS 2 idsHash') <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS 2 idsHash) idsHash' `shouldBe` idsHash serviceId5 `shouldBe` serviceId - Resp "" serviceId6 (ENDS 2) <- tGet1 nh2 + Resp "" serviceId6 (ENDS 2 _) <- tGet1 nh2 serviceId6 `shouldBe` serviceId deliverMessage rh rId rKey sh sId sKey nh1 "connection 1 one more" dec deliverMessage rh rId'' rKey'' sh sId'' sKey'' nh1 "connection 2 one more" dec'' From 11ae20ea20e0eca886f698ac2305387e3c08da83 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Mon, 22 Dec 2025 07:56:53 +0000 Subject: [PATCH 08/11] ntf server: use different client certs for each SMP server, remove support for store log (#1681) * ntf server: remove support for store log * ntf server: use different client certificates for each SMP server --- simplexmq.cabal | 1 - src/Simplex/FileTransfer/Client.hs | 3 +- src/Simplex/Messaging/Agent/Client.hs | 12 +- .../Messaging/Agent/Store/AgentStore.hs | 14 +- src/Simplex/Messaging/Client.hs | 14 +- src/Simplex/Messaging/Client/Agent.hs | 32 +- src/Simplex/Messaging/Notifications/Server.hs | 2 +- .../Messaging/Notifications/Server/Env.hs | 63 +-- .../Messaging/Notifications/Server/Main.hs | 96 +--- .../Notifications/Server/Store/Migrations.hs | 36 +- .../Notifications/Server/Store/Postgres.hs | 521 +++++++----------- .../Server/Store/ntf_server_schema.sql | 5 +- .../Notifications/Server/StoreLog.hs | 177 ------ src/Simplex/Messaging/Server.hs | 5 +- src/Simplex/Messaging/Server/Env/STM.hs | 2 +- .../Messaging/Transport/HTTP2/Client.hs | 9 +- tests/AgentTests/FunctionalAPITests.hs | 3 + 17 files changed, 322 insertions(+), 673 deletions(-) delete mode 100644 src/Simplex/Messaging/Notifications/Server/StoreLog.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index 3f9d1f61d..13759a05a 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -275,7 +275,6 @@ library Simplex.Messaging.Notifications.Server.Store.Migrations Simplex.Messaging.Notifications.Server.Store.Postgres Simplex.Messaging.Notifications.Server.Store.Types - Simplex.Messaging.Notifications.Server.StoreLog Simplex.Messaging.Server.MsgStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres.Migrations diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 62f06b7d3..a425138e5 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -11,6 +11,7 @@ module Simplex.FileTransfer.Client where +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -264,7 +265,7 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec { where errors = [ Handler $ \(e :: H.HTTP2Error) -> pure $ Left $ PCENetworkError $ NEConnectError $ displayException e, - Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError e, + Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError $ E.displayException e, Handler $ \(e :: SomeException) -> pure $ Left $ PCENetworkError $ toNetworkError e ] download cbState = diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9bf1afd8d..4fd9eb175 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -751,8 +751,8 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c updateClientService service smp pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} - updateClientService service smp = case (service, smpClientService smp) of - (Just (_, serviceId_), Just THClientService {serviceId}) -> withStore' c $ \db -> do + updateClientService service smp = case (service, smpClientServiceId smp) of + (Just (_, serviceId_), Just serviceId) -> withStore' c $ \db -> do setClientServiceId db userId srv serviceId forM_ serviceId_ $ \sId -> when (sId /= serviceId) $ removeRcvServiceAssocs db userId srv (Just _, Nothing) -> withStore' c $ \db -> deleteClientService db userId srv -- e.g., server version downgrade @@ -1255,7 +1255,7 @@ protocolClientError protocolError_ host = \case PCETransportError e -> BROKER host $ TRANSPORT e e@PCECryptoError {} -> INTERNAL $ show e PCEServiceUnavailable {} -> BROKER host NO_SERVICE - PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e + PCEIOError e -> BROKER host $ NETWORK $ NEConnectError e -- it is consistent with smpClientServiceError clientServiceError :: AgentErrorType -> Bool @@ -1546,6 +1546,7 @@ processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do Left e -> case smpErrorClientNotice e of Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored) where + -- TODO [certs rcv] not used? notices' = if isJust notice_ || isJust clientNoticeId then (rq, notice_) : notices else notices Nothing | temporaryClientError e -> acc @@ -1678,7 +1679,7 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c (active, (serviceQs, notices)) <- atomically $ do r@(_, (_, notices)) <- ifM (activeClientSession c tSess sessId) - ((True,) <$> processSubResults c tSess sessId smpServiceId rs) + ((True,) <$> processSubResults c tSess sessId (smpClientServiceId smp) rs) ((False, ([], [])) <$ incSMPServerStat' c userId srv connSubIgnored (length rs)) unless (null notices) $ takeTMVar $ clientNoticesLock c pure r @@ -1704,7 +1705,6 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c where tSess = transportSession' smp sessId = sessionId $ thParams smp - smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp processRcvServiceAssocs :: SMPQueue q => AgentClient -> [q] -> AM' () processRcvServiceAssocs _ [] = pure () @@ -1752,7 +1752,7 @@ subscribeClientService c withEvent userId srv (ServiceSub _ n idsHash) = withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a withServiceClient c tSess subscribe = withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) -> - case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of + case smpClientServiceId smp of Just smpServiceId -> subscribe smp smpServiceId Nothing -> throwE PCEServiceUnavailable diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 853a76908..2dcb76327 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -472,15 +472,21 @@ toServerService (host, port, kh, serviceId, n, Binary idsHash) = (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO () -setClientServiceId db userId srv serviceId = +setClientServiceId db userId (SMPServer h p kh) serviceId = DB.execute db [sql| UPDATE client_services SET service_id = ? - WHERE user_id = ? AND host = ? AND port = ? + FROM servers s + WHERE client_services.user_id = ? + AND client_services.host = ? + AND client_services.port = ? + AND s.host = client_services.host + AND s.port = client_services.port + AND COALESCE(client_services.server_key_hash, s.key_hash) = ? |] - (serviceId, userId, host srv, port srv) + (serviceId, userId, h, p, kh) deleteClientService :: DB.Connection -> UserId -> SMPServer -> IO () deleteClientService db userId (SMPServer h p kh) = @@ -2307,7 +2313,7 @@ unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = setRcvServiceAssocs :: SMPQueue q => DB.Connection -> [q] -> IO () setRcvServiceAssocs db rqs = #if defined(dbPostgres) - DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs) + DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN ?" $ Only $ In (map queueId rqs) #else DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs #endif diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index ac2dc9a9d..ebc458c0e 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -52,6 +52,7 @@ module Simplex.Messaging.Client subscribeSMPQueuesNtfs, subscribeService, smpClientService, + smpClientServiceId, secureSMPQueue, secureSndSMPQueue, proxySecureSndSMPQueue, @@ -128,7 +129,8 @@ import Control.Applicative ((<|>)) import Control.Concurrent (ThreadId, forkFinally, forkIO, killThread, mkWeakThreadId) import Control.Concurrent.Async import Control.Concurrent.STM -import Control.Exception +import Control.Exception (Exception, SomeException) +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -565,7 +567,7 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS case chooseTransportHost networkConfig (host srv) of Right useHost -> (getCurrentTime >>= mkProtocolClient useHost >>= runClient useTransport useHost) - `catch` \(e :: IOException) -> pure . Left $ PCEIOError e + `E.catch` \(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig @@ -638,7 +640,7 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS writeTVar (connected c) True putTMVar cVar $ Right c' raceAny_ ([send c' th, process c', receive c' th] <> [monitor c' | smpPingInterval > 0]) - `finally` disconnected c' + `E.finally` disconnected c' send :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO () send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= sendPending @@ -765,7 +767,7 @@ data ProtocolClientError err | -- | Error when cryptographically "signing" the command or when initializing crypto_box. PCECryptoError C.CryptoError | -- | IO Error - PCEIOError IOException + PCEIOError String deriving (Eq, Show, Exception) type SMPClientError = ProtocolClientError ErrorType @@ -926,6 +928,10 @@ smpClientService :: SMPClient -> Maybe THClientService smpClientService = thAuth . thParams >=> clientService {-# INLINE smpClientService #-} +smpClientServiceId :: SMPClient -> Maybe ServiceId +smpClientServiceId = fmap (\THClientService {serviceId} -> serviceId) . smpClientService +{-# INLINE smpClientServiceId #-} + enablePings :: SMPClient -> IO () enablePings ProtocolClient {client_ = PClient {sendPings}} = atomically $ writeTVar sendPings True {-# INLINE enablePings #-} diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 45d747d21..9739c19c7 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -15,6 +15,7 @@ module Simplex.Messaging.Client.Agent ( SMPClientAgent (..), SMPClientAgentConfig (..), SMPClientAgentEvent (..), + DBService (..), OwnServer, defaultSMPClientAgentConfig, newSMPClientAgent, @@ -133,6 +134,7 @@ defaultSMPClientAgentConfig = data SMPClientAgent p = SMPClientAgent { agentCfg :: SMPClientAgentConfig, agentParty :: SParty p, + dbService :: Maybe DBService, active :: TVar Bool, startedAt :: UTCTime, msgQ :: TBQueue (ServerTransmissionBatch SMPVersion ErrorType BrokerMsg), @@ -155,8 +157,8 @@ data SMPClientAgent p = SMPClientAgent type OwnServer = Bool -newSMPClientAgent :: SParty p -> SMPClientAgentConfig -> TVar ChaChaDRG -> IO (SMPClientAgent p) -newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do +newSMPClientAgent :: SParty p -> SMPClientAgentConfig -> Maybe DBService -> TVar ChaChaDRG -> IO (SMPClientAgent p) +newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} dbService randomDrg = do active <- newTVarIO True startedAt <- getCurrentTime msgQ <- newTBQueueIO msgQSize @@ -173,6 +175,7 @@ newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize SMPClientAgent { agentCfg, agentParty, + dbService, active, startedAt, msgQ, @@ -188,6 +191,11 @@ newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize workerSeq } +data DBService = DBService + { getCredentials :: SMPServer -> IO (Either SMPClientError ServiceCredentials), + updateServiceId :: SMPServer -> Maybe ServiceId -> IO (Either SMPClientError ()) + } + -- | Get or create SMP client for SMPServer getSMPServerClient' :: SMPClientAgent p -> SMPServer -> ExceptT SMPClientError IO SMPClient getSMPServerClient' ca srv = snd <$> getSMPServerClient'' ca srv @@ -218,7 +226,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke newSMPClient :: SMPClientVar -> IO (Either SMPClientError (OwnServer, SMPClient)) newSMPClient v = do - r <- connectClient ca srv v `E.catch` (pure . Left . PCEIOError) + r <- connectClient ca srv v `E.catch` \(e :: E.SomeException) -> pure $ Left $ PCEIOError $ E.displayException e case r of Right smp -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv @@ -227,8 +235,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke atomically $ do putTMVar (sessionVar v) (Right c) TM.insert (sessionId $ thParams smp) c smpSessions - let serviceId_ = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp - notify ca $ CAConnected srv serviceId_ + notify ca $ CAConnected srv $ smpClientServiceId smp pure $ Right c Left e -> do let ei = persistErrorInterval agentCfg @@ -249,9 +256,18 @@ isOwnServer SMPClientAgent {agentCfg} ProtocolServer {host} = -- | Run an SMP client for SMPClientVar connectClient :: SMPClientAgent p -> SMPServer -> SMPClientVar -> IO (Either SMPClientError SMPClient) -connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, randomDrg, startedAt} srv v = - getProtocolClient randomDrg NRMBackground (1, srv, Nothing) (smpCfg agentCfg) [] (Just msgQ) startedAt clientDisconnected +connectClient ca@SMPClientAgent {agentCfg, dbService, smpClients, smpSessions, msgQ, randomDrg, startedAt} srv v = case dbService of + Just dbs -> runExceptT $ do + creds <- ExceptT $ getCredentials dbs srv + smp <- ExceptT $ getClient cfg {serviceCredentials = Just creds} + whenM (atomically $ activeClientSession ca smp srv) $ + ExceptT $ updateServiceId dbs srv $ smpClientServiceId smp + pure smp + Nothing -> getClient cfg where + cfg = smpCfg agentCfg + getClient cfg' = getProtocolClient randomDrg NRMBackground (1, srv, Nothing) cfg' [] (Just msgQ) startedAt clientDisconnected + clientDisconnected :: SMPClient -> IO () clientDisconnected smp = do removeClientAndSubs smp >>= serverDown @@ -435,7 +451,7 @@ smpSubscribeQueues ca smp srv subs = do unless (null notPending) $ removePendingSubs ca srv notPending pure acc sessId = sessionId $ thParams smp - smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp + smpServiceId = smpClientServiceId smp groupSub :: Map QueueId C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError (Maybe ServiceId)) -> diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index e7c1ca5f9..7d9e36c99 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -588,7 +588,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = logError $ "SMP server service subscription error " <> showService srv serviceSub <> ": " <> tshow e CAServiceUnavailable srv serviceSub -> do logError $ "SMP server service unavailable: " <> showService srv serviceSub - removeServiceAssociation st srv >>= \case + removeServiceAndAssociations st srv >>= \case Right (srvId, updated) -> do logSubStatus srv "removed service association" updated updated void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing) diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index b0eafbc63..9ac89a12d 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -4,13 +4,14 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Notifications.Server.Env where import Control.Concurrent (ThreadId) -import Control.Logger.Simple -import Control.Monad +import Control.Monad.Except +import Control.Monad.Trans.Except import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -21,28 +22,26 @@ import qualified Data.X509.Validation as XV import Network.Socket import qualified Network.TLS as TLS import Numeric.Natural -import Simplex.Messaging.Client (ProtocolClientConfig (..)) +import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Stats -import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) import Simplex.Messaging.Notifications.Server.Store.Postgres import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Notifications.Server.StoreLog (readWriteNtfSTMStore) import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF) -import Simplex.Messaging.Protocol (BasicAuth, CorrId, Party (..), SMPServer, SParty (..), Transmission) +import Simplex.Messaging.Protocol (BasicAuth, CorrId, Party (..), SMPServer, SParty (..), ServiceId, Transmission) import Simplex.Messaging.Server.Env.STM (StartOptions (..)) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) -import Simplex.Messaging.Server.StoreLog (closeStoreLog) import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ASrvTransport, SMPServiceRole (..), ServiceCredentials (..), THandleParams, TransportPeer (..)) +import Simplex.Messaging.Transport.Credentials (genCredentials, tlsCredentials) import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials, TransportServerConfig, loadFingerprint, loadServerCredential) -import System.Exit (exitFailure) +import Simplex.Messaging.Util (liftEitherWith) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -96,33 +95,35 @@ data NtfEnv = NtfEnv } newNtfServerEnv :: NtfServerConfig -> IO NtfEnv -newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, useServiceCreds, startOptions} = do - when (compactLog startOptions) $ compactDbStoreLog $ dbStoreLogPath dbStoreConfig +newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, useServiceCreds} = do random <- C.newRandom store <- newNtfDbStore dbStoreConfig tlsServerCreds <- loadServerCredential ntfCredentials - serviceCertHash@(XV.Fingerprint fp) <- loadFingerprint ntfCredentials - smpAgentCfg' <- - if useServiceCreds - then do - serviceSignKey <- case C.x509ToPrivate' $ snd tlsServerCreds of - Right pk -> pure pk - Left e -> putStrLn ("Server has no valid key: " <> show e) >> exitFailure - let service = ServiceCredentials {serviceRole = SRNotifier, serviceCreds = tlsServerCreds, serviceCertHash, serviceSignKey} - pure smpAgentCfg {smpCfg = (smpCfg smpAgentCfg) {serviceCredentials = Just service}} - else pure smpAgentCfg - subscriber <- newNtfSubscriber smpAgentCfg' random + XV.Fingerprint fp <- loadFingerprint ntfCredentials + let dbService = if useServiceCreds then Just $ mkDbService random store else Nothing + subscriber <- newNtfSubscriber smpAgentCfg dbService random pushServer <- newNtfPushServer pushQSize apnsConfig serverStats <- newNtfServerStats =<< getCurrentTime pure NtfEnv {config, subscriber, pushServer, store, random, tlsServerCreds, serverIdentity = C.KeyHash fp, serverStats} where - compactDbStoreLog = \case - Just f -> do - logNote $ "compacting store log " <> T.pack f - newNtfSTMStore >>= readWriteNtfSTMStore False f >>= closeStoreLog - Nothing -> do - logError "Error: `--compact-log` used without `enable: on` option in STORE_LOG section of INI file" - exitFailure + mkDbService g st = DBService {getCredentials, updateServiceId} + where + getCredentials :: SMPServer -> IO (Either SMPClientError ServiceCredentials) + getCredentials srv = runExceptT $ do + ExceptT (withClientDB "" st $ \db -> getNtfServiceCredentials db srv >>= mapM (mkServiceCreds db)) >>= \case + Just (C.KeyHash kh, serviceCreds) -> do + serviceSignKey <- liftEitherWith PCEIOError $ C.x509ToPrivate' $ snd serviceCreds + pure ServiceCredentials {serviceRole = SRNotifier, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey} + Nothing -> throwE PCEServiceUnavailable -- this error cannot happen, as clients never connect to unknown servers + mkServiceCreds db = \case + (_, Just tlsCreds) -> pure tlsCreds + (srvId, Nothing) -> do + cred <- genCredentials g Nothing (25, 24 * 999999) "simplex" + let tlsCreds = tlsCredentials [cred] + setNtfServiceCredentials db srvId tlsCreds + pure tlsCreds + updateServiceId :: SMPServer -> Maybe ServiceId -> IO (Either SMPClientError ()) + updateServiceId srv serviceId_ = withClientDB "" st $ \db -> updateNtfServiceId db srv serviceId_ data NtfSubscriber = NtfSubscriber { smpSubscribers :: TMap SMPServer SMPSubscriberVar, @@ -132,11 +133,11 @@ data NtfSubscriber = NtfSubscriber type SMPSubscriberVar = SessionVar SMPSubscriber -newNtfSubscriber :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber -newNtfSubscriber smpAgentCfg random = do +newNtfSubscriber :: SMPClientAgentConfig -> Maybe DBService -> TVar ChaChaDRG -> IO NtfSubscriber +newNtfSubscriber smpAgentCfg dbService random = do smpSubscribers <- TM.emptyIO subscriberSeq <- newTVarIO 0 - smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg random + smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg dbService random pure NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent} data SMPSubscriber = SMPSubscriber diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index de12c33f8..e855c84d4 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -17,42 +17,32 @@ import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) import Data.Int (Int64) import Data.Maybe (fromMaybe) -import Data.Set (Set) -import qualified Data.Set as S import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.IO as T import Network.Socket (HostName, ServiceName) import Options.Applicative -import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Client (HostMode (..), NetworkConfig (..), ProtocolClientConfig (..), SMPWebPortServers (..), SocksMode (..), defaultNetworkConfig, textToHostMode) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Notifications.Protocol (NtfTokenId) -import Simplex.Messaging.Notifications.Server (runNtfServer, restoreServerLastNtfs) +import Simplex.Messaging.Notifications.Server (runNtfServer) import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..), defaultInactiveClientExpiration) import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientConfig) -import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) -import Simplex.Messaging.Notifications.Server.Store.Postgres (exportNtfDbStore, importNtfSTMStore, newNtfDbStore) -import Simplex.Messaging.Notifications.Server.StoreLog (readWriteNtfSTMStore) import Simplex.Messaging.Notifications.Transport (alpnSupportedNTFHandshakes, supportedServerNTFVRange) import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern NtfServer) import Simplex.Messaging.Server.CLI import Simplex.Messaging.Server.Env.STM (StartOptions (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Server.Main (strParse) import Simplex.Messaging.Server.Main.Init (iniDbOpts) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) -import Simplex.Messaging.Server.StoreLog (closeStoreLog) import Simplex.Messaging.Transport (ASrvTransport) import Simplex.Messaging.Transport.Client (TransportHost (..)) import Simplex.Messaging.Transport.HTTP2 (httpALPN) import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials (..), mkTransportServerConfig) -import Simplex.Messaging.Util (eitherToMaybe, ifM, tshow) -import System.Directory (createDirectoryIfMissing, doesFileExist, renameFile) -import System.Exit (exitFailure) +import Simplex.Messaging.Util (eitherToMaybe, tshow) +import System.Directory (createDirectoryIfMissing, doesFileExist) import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) import Text.Read (readMaybe) @@ -73,69 +63,11 @@ ntfServerCLI cfgPath logPath = deleteDirIfExists cfgPath deleteDirIfExists logPath putStrLn "Deleted configuration and log files" - Database cmd dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do - schemaExists <- checkSchemaExists connstr schema - storeLogExists <- doesFileExist storeLogFilePath - lastNtfsExists <- doesFileExist defaultLastNtfsFile - case cmd of - SCImport skipTokens - | schemaExists && (storeLogExists || lastNtfsExists) -> exitConfigureNtfStore connstr schema - | schemaExists -> do - putStrLn $ "Schema " <> B.unpack schema <> " already exists in PostrgreSQL database: " <> B.unpack connstr - exitFailure - | not storeLogExists -> do - putStrLn $ storeLogFilePath <> " file does not exist." - exitFailure - | not lastNtfsExists -> do - putStrLn $ defaultLastNtfsFile <> " file does not exist." - exitFailure - | otherwise -> do - storeLogFile <- getRequiredStoreLogFile ini - confirmOrExit - ("WARNING: store log file " <> storeLogFile <> " will be compacted and imported to PostrgreSQL database: " <> B.unpack connstr <> ", schema: " <> B.unpack schema) - "Notification server store not imported" - stmStore <- newNtfSTMStore - sl <- readWriteNtfSTMStore True storeLogFile stmStore - closeStoreLog sl - restoreServerLastNtfs stmStore defaultLastNtfsFile - let storeCfg = PostgresStoreCfg {dbOpts = dbOpts {createSchema = True}, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} - ps <- newNtfDbStore storeCfg - (tCnt, sCnt, nCnt, serviceCnt) <- importNtfSTMStore ps stmStore skipTokens - renameFile storeLogFile $ storeLogFile <> ".bak" - putStrLn $ "Import completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show serviceCnt <> " service associations, " <> show nCnt <> " last token notifications." - putStrLn "Configure database options in INI file." - SCExport - | schemaExists && storeLogExists -> exitConfigureNtfStore connstr schema - | not schemaExists -> do - putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr - exitFailure - | storeLogExists -> do - putStrLn $ storeLogFilePath <> " file already exists." - exitFailure - | lastNtfsExists -> do - putStrLn $ defaultLastNtfsFile <> " file already exists." - exitFailure - | otherwise -> do - confirmOrExit - ("WARNING: PostrgreSQL database schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath) - "Notification server store not imported" - let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Just storeLogFilePath, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} - st <- newNtfDbStore storeCfg - (tCnt, sCnt, nCnt) <- exportNtfDbStore st defaultLastNtfsFile - putStrLn $ "Export completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show nCnt <> " last token notifications." where withIniFile a = doesFileExist iniFile >>= \case True -> readIniFile iniFile >>= either exitError a _ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `" <> executableName <> " init`." - getRequiredStoreLogFile ini = do - case enableStoreLog' ini $> storeLogFilePath of - Just storeLogFile -> do - ifM - (doesFileExist storeLogFile) - (pure storeLogFile) - (putStrLn ("Store log file " <> storeLogFile <> " not found") >> exitFailure) - Nothing -> putStrLn "Store log disabled, see `[STORE_LOG] enable`" >> exitFailure iniFile = combine cfgPath "ntf-server.ini" serverVersion = "SMP notifications server v" <> simplexmqVersionCommit defaultServerPort = "443" @@ -289,11 +221,6 @@ ntfServerCLI cfgPath logPath = startOptions } iniDeletedTTL ini = readIniDefault (86400 * defaultDeletedTTL) "STORE_LOG" "db_deleted_ttl" ini - defaultLastNtfsFile = combine logPath "ntf-server-last-notifications.log" - exitConfigureNtfStore connstr schema = do - putStrLn $ "Error: both " <> storeLogFilePath <> " file and " <> B.unpack schema <> " schema are present (database: " <> B.unpack connstr <> ")." - putStrLn "Configure notification server storage." - exitFailure printNtfServerConfig :: [(ServiceName, ASrvTransport, AddHTTP)] -> PostgresStoreCfg -> IO () printNtfServerConfig transports PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}, dbStoreLogPath} = do @@ -305,9 +232,6 @@ data CliCommand | OnlineCert CertOptions | Start StartOptions | Delete - | Database StoreCmd DBOpts - -data StoreCmd = SCImport (Set NtfTokenId) | SCExport data InitOptions = InitOptions { enableStoreLog :: Bool, @@ -338,22 +262,8 @@ cliCommandP cfgPath logPath iniFile = <> command "cert" (info (OnlineCert <$> certOptionsP) (progDesc $ "Generate new online TLS server credentials (configuration: " <> iniFile <> ")")) <> command "start" (info (Start <$> startOptionsP) (progDesc $ "Start server (configuration: " <> iniFile <> ")")) <> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files")) - <> command "database" (info (Database <$> databaseCmdP <*> dbOptsP defaultNtfDBOpts) (progDesc "Import/export notifications server store to/from PostgreSQL database")) ) where - databaseCmdP = - hsubparser - ( command "import" (info (SCImport <$> skipTokensP) (progDesc $ "Import store logs into a new PostgreSQL database schema")) - <> command "export" (info (pure SCExport) (progDesc $ "Export PostgreSQL database schema to store logs")) - ) - skipTokensP :: Parser (Set NtfTokenId) - skipTokensP = - option - strParse - ( long "skip-tokens" - <> help "Skip tokens during import" - <> value S.empty - ) initP :: Parser InitOptions initP = do enableStoreLog <- diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs index 8c0da7c07..87e89ac8d 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -14,7 +14,8 @@ ntfServerSchemaMigrations :: [(String, Text, Maybe Text)] ntfServerSchemaMigrations = [ ("20250417_initial", m20250417_initial, Nothing), ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert), - ("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash) + ("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash), + ("20251219_service_cert_per_server", m20251219_service_cert_per_server, Just down_m20251219_service_cert_per_server) ] -- | The list of migrations in ascending order by date @@ -225,3 +226,36 @@ ALTER TABLE smp_servers DROP COLUMN smp_notifier_ids_hash; |] <> dropXorHashFuncs + +m20251219_service_cert_per_server :: Text +m20251219_service_cert_per_server = + [r| +ALTER TABLE smp_servers + ADD COLUMN ntf_service_cert BYTEA, + ADD COLUMN ntf_service_cert_hash BYTEA, + ADD COLUMN ntf_service_priv_key BYTEA; + |] + <> resetNtfServices + +down_m20251219_service_cert_per_server :: Text +down_m20251219_service_cert_per_server = + [r| +ALTER TABLE smp_servers + DROP COLUMN ntf_service_cert, + DROP COLUMN ntf_service_cert_hash, + DROP COLUMN ntf_service_priv_key; + |] + <> resetNtfServices + +resetNtfServices :: Text +resetNtfServices = + [r| +ALTER TABLE subscriptions DISABLE TRIGGER tr_subscriptions_update; +UPDATE subscriptions SET ntf_service_assoc = FALSE; +ALTER TABLE subscriptions ENABLE TRIGGER tr_subscriptions_update; + +UPDATE smp_servers +SET ntf_service_id = NULL, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT; + |] diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 60e81a68b..80ab45ca1 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -18,7 +18,6 @@ module Simplex.Messaging.Notifications.Server.Store.Postgres where -import Control.Concurrent.STM import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad @@ -26,19 +25,13 @@ import Control.Monad.Except import Control.Monad.IO.Class import Control.Monad.Trans.Except import Data.Bitraversable (bimapM) -import qualified Data.ByteString.Base64.URL as B64 import Data.ByteString.Char8 (ByteString) -import qualified Data.ByteString.Char8 as B -import Data.Containers.ListUtils (nubOrd) import Data.Either (fromRight) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (findIndex, foldl') import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust, mapMaybe) -import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -51,31 +44,30 @@ import Database.PostgreSQL.Simple.FromField (FromField (..)) import Database.PostgreSQL.Simple.SqlQQ (sql) import Database.PostgreSQL.Simple.ToField (ToField (..)) import Network.Socket (ServiceName) +import qualified Network.TLS as TLS import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Postgres.DB (fromTextField_) import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..)) +import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubData (..), NtfTknData (..), TokenNtfMessageRecord (..), ntfSubServer) import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Notifications.Server.StoreLog -import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, ServiceSub (..), pattern SMPServer) -import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) +import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, ProtocolServer (..), SMPServer, ServiceId, ServiceSub (..), pattern SMPServer) +import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) -import Simplex.Messaging.Server.StoreLog (openWriteStoreLog) import Simplex.Messaging.SystemTime import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (anyM, firstRow, maybeFirstRow, toChunks, tshow) +import Simplex.Messaging.Util (firstRow, maybeFirstRow, tshow) import System.Exit (exitFailure) -import System.IO (IOMode (..), hFlush, stdout, withFile) import Text.Hex (decodeHex) #if !defined(dbPostgres) +import qualified Data.X509 as X import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (eitherToMaybe) @@ -83,7 +75,6 @@ import Simplex.Messaging.Util (eitherToMaybe) data NtfPostgresStore = NtfPostgresStore { dbStore :: DBStore, - dbStoreLog :: Maybe (StoreLog 'WriteMode), deletedTTL :: Int64 } @@ -99,25 +90,22 @@ data NtfEntityRec (e :: NtfEntity) where NtfSub :: NtfSubRec -> NtfEntityRec 'Subscription newNtfDbStore :: PostgresStoreCfg -> IO NtfPostgresStore -newNtfDbStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do +newNtfDbStore PostgresStoreCfg {dbOpts, confirmMigrations, deletedTTL} = do dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations (MigrationConfig confirmMigrations Nothing) - dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath - pure NtfPostgresStore {dbStore, dbStoreLog, deletedTTL} + pure NtfPostgresStore {dbStore, deletedTTL} where err e = do logError $ "STORE: newNtfStore, error opening PostgreSQL database, " <> tshow e exitFailure closeNtfDbStore :: NtfPostgresStore -> IO () -closeNtfDbStore NtfPostgresStore {dbStore, dbStoreLog} = do - closeDBStore dbStore - mapM_ closeStoreLog dbStoreLog +closeNtfDbStore NtfPostgresStore {dbStore} = closeDBStore dbStore addNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) addNtfToken st tkn = withFastDB "addNtfToken" st $ \db -> - E.try (DB.execute db insertNtfTknQuery $ ntfTknToRow tkn) - >>= bimapM handleDuplicate (\_ -> withLog "addNtfToken" st (`logCreateToken` tkn)) + E.try (void $ DB.execute db insertNtfTknQuery $ ntfTknToRow tkn) + >>= bimapM handleDuplicate pure insertNtfTknQuery :: Query insertNtfTknQuery = @@ -128,7 +116,7 @@ insertNtfTknQuery = |] replaceNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) -replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), tknStatus, tknRegCode = code@(NtfRegCode regCode)} = +replaceNtfToken st NtfTknRec {ntfTknId, token = DeviceToken pp ppToken, tknStatus, tknRegCode = NtfRegCode regCode} = withFastDB "replaceNtfToken" st $ \db -> runExceptT $ do ExceptT $ assertUpdated <$> DB.execute @@ -139,7 +127,6 @@ replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), WHERE token_id = ? |] (pp, Binary ppToken, tknStatus, Binary regCode, ntfTknId) - withLog "replaceNtfToken" st $ \sl -> logUpdateToken sl ntfTknId token code ntfTknToRow :: NtfTknRec -> NtfTknRow ntfTknToRow NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} = @@ -160,15 +147,14 @@ getNtfToken_ :: ToRow q => NtfPostgresStore -> Query -> q -> IO (Either ErrorTyp getNtfToken_ st cond params = withFastDB' "getNtfToken" st $ \db -> do tkn_ <- maybeFirstRow rowToNtfTkn $ DB.query db (ntfTknQuery <> cond) params - mapM_ (updateTokenDate st db) tkn_ + mapM_ (updateTokenDate db) tkn_ pure tkn_ -updateTokenDate :: NtfPostgresStore -> DB.Connection -> NtfTknRec -> IO () -updateTokenDate st db NtfTknRec {ntfTknId, tknUpdatedAt} = do +updateTokenDate :: DB.Connection -> NtfTknRec -> IO () +updateTokenDate db NtfTknRec {ntfTknId, tknUpdatedAt} = do ts <- getSystemDate when (maybe True (ts /=) tknUpdatedAt) $ do void $ DB.execute db "UPDATE tokens SET updated_at = ? WHERE token_id = ?" (ts, ntfTknId) - withLog "updateTokenDate" st $ \sl -> logUpdateTokenTime sl ntfTknId ts type NtfTknRow = (NtfTokenId, PushProvider, Binary ByteString, NtfTknStatus, NtfPublicAuthKey, C.PrivateKeyX25519, C.DhSecretX25519, Binary ByteString, Word16, Maybe SystemDate) @@ -206,7 +192,6 @@ deleteNtfToken st tknId = |] (Only tknId) liftIO $ void $ DB.execute db "DELETE FROM tokens WHERE token_id = ?" (Only tknId) - withLog "deleteNtfToken" st (`logDeleteToken` tknId) pure subs where toServerSubs :: SMPServerRow :. Only Text -> (SMPServer, [NotifierId]) @@ -235,7 +220,6 @@ updateTknCronInterval st tknId cronInt = withFastDB "updateTknCronInterval" st $ \db -> runExceptT $ do ExceptT $ assertUpdated <$> DB.execute db "UPDATE tokens SET cron_interval = ? WHERE token_id = ?" (cronInt, tknId) - withLog "updateTknCronInterval" st $ \sl -> logTokenCron sl tknId 0 -- Reads servers that have subscriptions that need subscribing. -- It is executed on server start, and it is supposed to crash on database error @@ -259,6 +243,73 @@ getUsedSMPServers st = let service_ = (\serviceId -> ServiceSub serviceId n idsHash) <$> serviceId_ in (SMPServer host port kh, srvId, service_) +getNtfServiceCredentials :: DB.Connection -> SMPServer -> IO (Maybe (Int64, Maybe (C.KeyHash, TLS.Credential))) +getNtfServiceCredentials db srv = + maybeFirstRow toService $ + DB.query + db + [sql| + SELECT smp_server_id, ntf_service_cert_hash, ntf_service_cert, ntf_service_priv_key + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + FOR UPDATE + |] + (host srv, port srv, keyHash srv) + where + toService (Only srvId :. creds) = (srvId, toCredentials creds) + toCredentials = \case + (Just kh, Just cert, Just pk) -> Just (kh, (cert, pk)) + _ -> Nothing + +setNtfServiceCredentials :: DB.Connection -> Int64 -> (C.KeyHash, TLS.Credential) -> IO () +setNtfServiceCredentials db srvId (kh, (cert, pk)) = + void $ DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_cert_hash = ?, ntf_service_cert = ?, ntf_service_priv_key = ? + WHERE smp_server_id = ? + |] + (kh, cert, pk, srvId) + +updateNtfServiceId :: DB.Connection -> SMPServer -> Maybe ServiceId -> IO () +updateNtfServiceId db srv newServiceId_ = do + maybeFirstRow id (getSMPServiceForUpdate_ db srv) >>= mapM_ updateService + where + updateService (srvId, currServiceId_) = unless (currServiceId_ == newServiceId_) $ do + when (isJust currServiceId_) $ do + void $ removeServiceAssociation_ db srvId + logError $ "STORE: service ID for " <> enc (host srv) <> toServiceId <> ", removed sub associations" + void $ case newServiceId_ of + Just newServiceId -> + DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_id = ?, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT + WHERE smp_server_id = ? + |] + (newServiceId, srvId) + Nothing -> + DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_id = NULL, + ntf_service_cert = NULL, + ntf_service_cert_hash = NULL, + ntf_service_priv_key = NULL, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT + WHERE smp_server_id = ? + |] + (Only srvId) + toServiceId = maybe " removed" ((" changed to " <>) . enc) newServiceId_ + enc :: StrEncoding a => a -> Text + enc = decodeLatin1 . strEncode + getServerNtfSubscriptions :: NtfPostgresStore -> Int64 -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub]) getServerNtfSubscriptions st srvId afterSubId_ count = withDB' "getServerNtfSubscriptions" st $ \db -> do @@ -297,7 +348,7 @@ findNtfSubscription st tknId q = withFastDB "findNtfSubscription" st $ \db -> runExceptT $ do tkn@NtfTknRec {ntfTknId, tknStatus} <- ExceptT $ getNtfToken st tknId unless (allowNtfSubCommands tknStatus) $ throwE AUTH - liftIO $ updateTokenDate st db tkn + liftIO $ updateTokenDate db tkn sub_ <- liftIO $ maybeFirstRow (rowToNtfSub q) $ DB.query @@ -330,7 +381,7 @@ getNtfSubscription st subId = WHERE s.subscription_id = ? |] (Only subId) - liftIO $ updateTokenDate st db tkn + liftIO $ updateTokenDate db tkn unless (allowNtfSubCommands tknStatus) $ throwE AUTH pure r @@ -352,36 +403,30 @@ mkNtfSubRec ntfSubId (NewNtfSub tokenId smpQueue notifierKey) = updateTknStatus :: NtfPostgresStore -> NtfTknRec -> NtfTknStatus -> IO (Either ErrorType ()) updateTknStatus st tkn status = - withFastDB' "updateTknStatus" st $ \db -> updateTknStatus_ st db tkn status + withFastDB' "updateTknStatus" st $ \db -> updateTknStatus_ db tkn status -updateTknStatus_ :: NtfPostgresStore -> DB.Connection -> NtfTknRec -> NtfTknStatus -> IO () -updateTknStatus_ st db NtfTknRec {ntfTknId} status = do - updated <- DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ?" (status, ntfTknId, status) - when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId status +updateTknStatus_ :: DB.Connection -> NtfTknRec -> NtfTknStatus -> IO () +updateTknStatus_ db NtfTknRec {ntfTknId} status = + void $ DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ?" (status, ntfTknId, status) -- unless it was already active setTknStatusConfirmed :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) setTknStatusConfirmed st NtfTknRec {ntfTknId} = - withFastDB' "updateTknStatus" st $ \db -> do - updated <- DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ? AND status != ?" (NTConfirmed, ntfTknId, NTConfirmed, NTActive) - when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId NTConfirmed + withFastDB' "updateTknStatus" st $ \db -> + void $ DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ? AND status != ?" (NTConfirmed, ntfTknId, NTConfirmed, NTActive) setTokenActive :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) setTokenActive st tkn@NtfTknRec {ntfTknId, token = DeviceToken pp ppToken} = withFastDB' "setTokenActive" st $ \db -> do - updateTknStatus_ st db tkn NTActive + updateTknStatus_ db tkn NTActive -- this removes other instances of the same token, e.g. because of repeated token registration attempts - tknIds <- - liftIO $ map fromOnly <$> - DB.query - db - [sql| - DELETE FROM tokens - WHERE push_provider = ? AND push_provider_token = ? AND token_id != ? - RETURNING token_id - |] - (pp, Binary ppToken, ntfTknId) - withLog "deleteNtfToken" st $ \sl -> mapM_ (logDeleteToken sl) tknIds + void $ DB.execute + db + [sql| + DELETE FROM tokens + WHERE push_provider = ? AND push_provider_token = ? AND token_id != ? + |] + (pp, Binary ppToken, ntfTknId) withPeriodicNtfTokens :: NtfPostgresStore -> Int64 -> (NtfTknRec -> IO ()) -> IO Int withPeriodicNtfTokens st now notify = @@ -399,7 +444,6 @@ addNtfSubscription st sub = withFastDB "addNtfSubscription" st $ \db -> runExceptT $ do srvId :: Int64 <- ExceptT $ upsertServer db $ ntfSubServer' sub n <- liftIO $ DB.execute db insertNtfSubQuery $ ntfSubToRow srvId sub - withLog "addNtfSubscription" st (`logCreateSubscription` sub) pure (srvId, n > 0) where -- It is possible to combine these two statements into one with CTEs, @@ -442,76 +486,66 @@ ntfSubToRow srvId NtfSubRec {ntfSubId, tokenId, smpQueue = SMPQueueNtf _ nId, no deleteNtfSubscription :: NtfPostgresStore -> NtfSubscriptionId -> IO (Either ErrorType ()) deleteNtfSubscription st subId = - withFastDB "deleteNtfSubscription" st $ \db -> runExceptT $ do - ExceptT $ assertUpdated <$> + withFastDB "deleteNtfSubscription" st $ \db -> + assertUpdated <$> DB.execute db "DELETE FROM subscriptions WHERE subscription_id = ?" (Only subId) - withLog "deleteNtfSubscription" st (`logDeleteSubscription` subId) updateSubStatus :: NtfPostgresStore -> Int64 -> NotifierId -> NtfSubStatus -> IO (Either ErrorType ()) updateSubStatus st srvId nId status = withFastDB' "updateSubStatus" st $ \db -> do - sub_ :: Maybe (NtfSubscriptionId, NtfAssociatedService) <- - maybeFirstRow id $ - DB.query - db - [sql| - UPDATE subscriptions SET status = ? - WHERE smp_server_id = ? AND smp_notifier_id = ? AND status != ? - RETURNING subscription_id, ntf_service_assoc - |] - (status, srvId, nId, status) - forM_ sub_ $ \(subId, serviceAssoc) -> - withLog "updateSubStatus" st $ \sl -> logSubscriptionStatus sl (subId, status, serviceAssoc) + void $ + DB.execute + db + [sql| + UPDATE subscriptions SET status = ? + WHERE smp_server_id = ? AND smp_notifier_id = ? AND status != ? + |] + (status, srvId, nId, status) updateSrvSubStatus :: NtfPostgresStore -> SMPQueueNtf -> NtfSubStatus -> IO (Either ErrorType ()) updateSrvSubStatus st q status = - withFastDB' "updateSrvSubStatus" st $ \db -> do - sub_ :: Maybe (NtfSubscriptionId, NtfAssociatedService) <- - maybeFirstRow id $ - DB.query - db - [sql| - UPDATE subscriptions s - SET status = ? - FROM smp_servers p - WHERE p.smp_server_id = s.smp_server_id - AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? - AND s.status != ? - RETURNING s.subscription_id, s.ntf_service_assoc - |] - (Only status :. smpQueueToRow q :. Only status) - forM_ sub_ $ \(subId, serviceAssoc) -> - withLog "updateSrvSubStatus" st $ \sl -> logSubscriptionStatus sl (subId, status, serviceAssoc) + withFastDB' "updateSrvSubStatus" st $ \db -> + void $ + DB.execute + db + [sql| + UPDATE subscriptions s + SET status = ? + FROM smp_servers p + WHERE p.smp_server_id = s.smp_server_id + AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? + AND s.status != ? + |] + (Only status :. smpQueueToRow q :. Only status) batchUpdateSrvSubStatus :: NtfPostgresStore -> SMPServer -> Maybe ServiceId -> NonEmpty NotifierId -> NtfSubStatus -> IO Int batchUpdateSrvSubStatus st srv newServiceId nIds status = fmap (fromRight (-1)) $ withDB "batchUpdateSrvSubStatus" st $ \db -> runExceptT $ do - (srvId :: Int64, currServiceId) <- ExceptT $ getSMPServerService db + (srvId, currServiceId) <- ExceptT $ firstRow id AUTH $ getSMPServiceForUpdate_ db srv + -- TODO [certs rcv] should this remove associations/credentials when newServiceId is Nothing or different unless (currServiceId == newServiceId) $ liftIO $ void $ DB.execute db "UPDATE smp_servers SET ntf_service_id = ? WHERE smp_server_id = ?" (newServiceId, srvId) let params = L.toList $ L.map (srvId,isJust newServiceId,status,) nIds liftIO $ fromIntegral <$> DB.executeMany db updateSubStatusQuery params - where - getSMPServerService db = - firstRow id AUTH $ - DB.query - db - [sql| - SELECT smp_server_id, ntf_service_id - FROM smp_servers - WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? - FOR UPDATE - |] - (srvToRow srv) + +getSMPServiceForUpdate_ :: DB.Connection -> SMPServer -> IO [(Int64, Maybe ServiceId)] +getSMPServiceForUpdate_ db srv = + DB.query + db + [sql| + SELECT smp_server_id, ntf_service_id + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + FOR UPDATE + |] + (srvToRow srv) batchUpdateSrvSubErrors :: NtfPostgresStore -> SMPServer -> NonEmpty (NotifierId, NtfSubStatus) -> IO Int batchUpdateSrvSubErrors st srv subs = fmap (fromRight (-1)) $ withDB "batchUpdateSrvSubErrors" st $ \db -> runExceptT $ do srvId :: Int64 <- ExceptT $ getSMPServerId db let params = map (\(nId, status) -> (srvId, False, status, nId)) $ L.toList subs - subs' <- liftIO $ DB.returning db (updateSubStatusQuery <> " RETURNING s.subscription_id, s.status, s.ntf_service_assoc") params - withLog "batchUpdateStatus_" st $ forM_ subs' . logSubscriptionStatus - pure $ length subs' + liftIO $ fromIntegral <$> DB.executeMany db updateSubStatusQuery params where getSMPServerId db = firstRow fromOnly AUTH $ @@ -535,36 +569,51 @@ updateSubStatusQuery = AND (s.status != upd.status OR s.ntf_service_assoc != upd.ntf_service_assoc) |] -removeServiceAssociation :: NtfPostgresStore -> SMPServer -> IO (Either ErrorType (Int64, Int)) -removeServiceAssociation st srv = do - withDB "removeServiceAssociation" st $ \db -> runExceptT $ do - srvId <- ExceptT $ removeServerService db - subs <- - liftIO $ - DB.query - db - [sql| - UPDATE subscriptions s - SET status = ?, ntf_service_assoc = FALSE - WHERE smp_server_id = ? - AND (s.status != ? OR s.ntf_service_assoc != FALSE) - RETURNING s.subscription_id, s.status, s.ntf_service_assoc - |] - (NSInactive, srvId, NSInactive) - withLog "removeServiceAssociation" st $ forM_ subs . logSubscriptionStatus - pure (srvId, length subs) +removeServiceAssociation_ :: DB.Connection -> Int64 -> IO Int64 +removeServiceAssociation_ db srvId = + DB.execute + db + [sql| + UPDATE subscriptions s + SET status = ?, ntf_service_assoc = FALSE + WHERE smp_server_id = ? + AND (s.status != ? OR s.ntf_service_assoc != FALSE) + |] + (NSInactive, srvId, NSInactive) + +removeServiceAndAssociations :: NtfPostgresStore -> SMPServer -> IO (Either ErrorType (Int64, Int)) +removeServiceAndAssociations st srv = do + withDB "removeServiceAndAssociations" st $ \db -> runExceptT $ do + srvId <- ExceptT $ getServerId db + subsCount <- liftIO $ removeServiceAssociation_ db srvId + liftIO $ removeServerService db srvId + pure (srvId, fromIntegral subsCount) where - removeServerService db = + getServerId db = firstRow fromOnly AUTH $ DB.query db [sql| - UPDATE smp_servers - SET ntf_service_id = NULL + SELECT smp_server_id + FROM smp_servers WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? - RETURNING smp_server_id + FOR UPDATE |] (srvToRow srv) + removeServerService db srvId = + DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_id = NULL, + ntf_service_cert = NULL, + ntf_service_cert_hash = NULL, + ntf_service_priv_key = NULL, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT + WHERE smp_server_id = ? + |] + (Only srvId) addTokenLastNtf :: NtfPostgresStore -> PNMessageData -> IO (Either ErrorType (NtfTknRec, NonEmpty PNMessageData)) addTokenLastNtf st newNtf = @@ -646,216 +695,6 @@ getEntityCounts st = count (Only n : _) = n count [] = 0 -importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> S.Set NtfTokenId -> IO (Int64, Int64, Int64, Int64) -importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore skipTokens = do - (tIds, tCnt) <- importTokens - subLookup <- readTVarIO $ subscriptionLookup stmStore - sCnt <- importSubscriptions tIds subLookup - nCnt <- importLastNtfs tIds subLookup - serviceCnt <- importNtfServiceIds - pure (tCnt, sCnt, nCnt, serviceCnt) - where - importTokens = do - allTokens <- M.elems <$> readTVarIO (tokens stmStore) - tokens <- filterTokens allTokens - let skipped = length allTokens - length tokens - when (skipped /= 0) $ putStrLn $ "Total skipped tokens " <> show skipped - -- uncomment this line instead of the next two to import tokens one by one. - -- tCnt <- withConnection s $ \db -> foldM (importTkn db) 0 tokens - -- token interval is reset to 0 to only send notifications to devices with periodic mode, - -- and before clients are upgraded - to all active devices. - tRows <- mapM (fmap (ntfTknToRow . (\t -> t {tknCronInterval = 0} :: NtfTknRec)) . mkTknRec) tokens - tCnt <- withConnection s $ \db -> DB.executeMany db insertNtfTknQuery tRows - let tokenIds = S.fromList $ map (\NtfTknData {ntfTknId} -> ntfTknId) tokens - (tokenIds,) <$> checkCount "token" (length tokens) tCnt - where - filterTokens tokens = do - let deviceTokens = foldl' (\m t -> M.alter (Just . (t :) . fromMaybe []) (tokenKey t) m) M.empty tokens - tokenSubs <- readTVarIO (tokenSubscriptions stmStore) - filterM (keepTokenRegistration deviceTokens tokenSubs) tokens - tokenKey NtfTknData {token, tknVerifyKey} = strEncode token <> ":" <> C.toPubKey C.pubKeyBytes tknVerifyKey - keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, tknStatus} = - case M.lookup (tokenKey tkn) deviceTokens of - Just ts - | length ts < 2 -> pure True - | ntfTknId `S.member` skipTokens -> False <$ putStrLn ("Skipped token " <> enc ntfTknId <> " from --skip-tokens") - | otherwise -> - readTVarIO tknStatus >>= \case - NTConfirmed -> do - hasSubs <- maybe (pure False) (\v -> not . S.null <$> readTVarIO v) $ M.lookup ntfTknId tokenSubs - if hasSubs - then pure True - else do - anyBetterToken <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> activeOrInvalid <$> readTVarIO tknStatus') ts - if anyBetterToken - then False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId) - else case findIndex (\NtfTknData {ntfTknId = tId} -> tId == ntfTknId) ts of - Just 0 -> pure True -- keeping the first token - Just _ -> False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId <> " (no active token)") - Nothing -> True <$ putStrLn "Error: no device token in the list" - _ -> pure True - Nothing -> True <$ putStrLn "Error: no device token in lookup map" - activeOrInvalid = \case - NTActive -> True - NTInvalid _ -> True - _ -> False - -- importTkn db !n tkn@NtfTknData {ntfTknId} = do - -- tknRow <- ntfTknToRow <$> mkTknRec tkn - -- (DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) -> - -- putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n - importSubscriptions :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64 - importSubscriptions tIds subLookup = do - subs <- filterSubs . M.elems =<< readTVarIO (subscriptions stmStore) - srvIds <- importServers subs - putStrLn $ "Importing " <> show (length subs) <> " subscriptions..." - -- uncomment this line instead of the next to import subs one by one. - -- (sCnt, errTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs - sCnt <- foldM (importSubs srvIds) 0 $ toChunks 500000 subs - checkCount "subscription" (length subs) sCnt - where - filterSubs allSubs = do - let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs - skipped = length allSubs - length subs - when (skipped /= 0) $ putStrLn $ "Skipped " <> show skipped <> " subscriptions of missing tokens" - let (removedSubTokens, removeSubs, dupQueues) = foldl' addSubToken (S.empty, S.empty, S.empty) subs - unless (null removeSubs) $ putStrLn $ "Skipped " <> show (S.size removeSubs) <> " duplicate subscriptions of " <> show (S.size removedSubTokens) <> " tokens for " <> show (S.size dupQueues) <> " queues" - pure $ filter (\NtfSubData {ntfSubId} -> S.notMember ntfSubId removeSubs) subs - where - addSubToken acc@(!stIds, !sIds, !qs) NtfSubData {ntfSubId, smpQueue, tokenId} = - case M.lookup smpQueue subLookup of - Just sId | sId /= ntfSubId -> - (S.insert tokenId stIds, S.insert ntfSubId sIds, S.insert smpQueue qs) - _ -> acc - importSubs srvIds !n subs = do - rows <- mapM (ntfSubRow srvIds) subs - cnt <- withConnection s $ \db -> DB.executeMany db insertNtfSubQuery $ L.toList rows - let n' = n + cnt - putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" - hFlush stdout - pure n' - -- importSub db srvIds (!n, !errTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do - -- subRow <- ntfSubRow srvIds sub - -- E.try (DB.execute db insertNtfSubQuery subRow) >>= \case - -- Right i -> do - -- let n' = n + i - -- when (n' `mod` 100000 == 0) $ do - -- putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" - -- hFlush stdout - -- pure (n', errTkns) - -- Left (e :: E.SomeException) -> do - -- when (n `mod` 100000 == 0) $ putStrLn "" - -- putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e - -- pure (n, M.alter (Just . maybe [sId] (sId :)) tokenId errTkns) - ntfSubRow srvIds sub = case M.lookup srv srvIds of - Just sId -> ntfSubToRow sId <$> mkSubRec sub - Nothing -> E.throwIO $ userError $ "no matching server ID for server " <> show srv - where - srv = ntfSubServer sub - importServers subs = do - sIds <- withConnection s $ \db -> map fromOnly <$> DB.returning db srvQuery (map srvToRow srvs) - void $ checkCount "server" (length srvs) (length sIds) - pure $ M.fromList $ zip srvs sIds - where - srvQuery = "INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash) VALUES (?, ?, ?) RETURNING smp_server_id" - srvs = nubOrd $ map ntfSubServer subs - importLastNtfs :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64 - importLastNtfs tIds subLookup = do - ntfs <- readTVarIO (tokenLastNtfs stmStore) - ntfRows <- filterLastNtfRows ntfs - nCnt <- withConnection s $ \db -> DB.executeMany db lastNtfQuery ntfRows - checkCount "last notification" (length ntfRows) nCnt - where - lastNtfQuery = "INSERT INTO last_notifications(token_id, subscription_id, sent_at, nmsg_nonce, nmsg_data) VALUES (?,?,?,?,?)" - filterLastNtfRows ntfs = do - (skippedTkns, ntfCnt, (skippedQueues, ntfRows)) <- foldM lastNtfRows (S.empty, 0, (S.empty, [])) $ M.assocs ntfs - let skipped = ntfCnt - length ntfRows - when (skipped /= 0) $ putStrLn $ "Skipped last notifications " <> show skipped <> " for " <> show (S.size skippedTkns) <> " missing tokens and " <> show (S.size skippedQueues) <> " missing subscriptions with token present" - pure ntfRows - lastNtfRows (!stIds, !cnt, !acc) (tId, ntfVar) = do - ntfs <- L.toList <$> readTVarIO ntfVar - let cnt' = cnt + length ntfs - pure $ - if S.member tId tIds - then (stIds, cnt', foldl' ntfRow acc ntfs) - else (S.insert tId stIds, cnt', acc) - where - ntfRow (!qs, !rows) PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of - Just ntfSubId -> - let row = (tId, ntfSubId, systemToUTCTime ntfTs, nmsgNonce, Binary encNMsgMeta) - in (qs, row : rows) - Nothing -> (S.insert smpQueue qs, rows) - importNtfServiceIds = do - ss <- M.assocs <$> readTVarIO (ntfServices stmStore) - withConnection s $ \db -> DB.executeMany db serviceQuery $ map serviceToRow ss - where - serviceQuery = - [sql| - INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash, ntf_service_id) - VALUES (?, ?, ?, ?) - ON CONFLICT (smp_host, smp_port, smp_keyhash) - DO UPDATE SET ntf_service_id = EXCLUDED.ntf_service_id - |] - serviceToRow (srv, serviceId) = srvToRow srv :. Only serviceId - checkCount name expected inserted - | fromIntegral expected == inserted = do - putStrLn $ "Imported " <> show inserted <> " " <> name <> "s." - pure inserted - | otherwise = do - putStrLn $ "Incorrect " <> name <> " count: expected " <> show expected <> ", imported " <> show inserted - putStrLn "Import aborted, fix data and repeat" - exitFailure - enc = B.unpack . B64.encode . unEntityId - -exportNtfDbStore :: NtfPostgresStore -> FilePath -> IO (Int, Int, Int) -exportNtfDbStore NtfPostgresStore {dbStoreLog = Nothing} _ = - putStrLn "Internal error: export requires store log" >> exitFailure -exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFile = - (,,) <$> exportTokens <*> exportSubscriptions <*> exportLastNtfs - where - exportTokens = do - tCnt <- withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn -> - logCreateToken sl (rowToNtfTkn tkn) $> (i + 1) - putStrLn $ "Exported " <> show tCnt <> " tokens" - pure tCnt - exportSubscriptions = do - sCnt <- withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub -> do - let i' = i + 1 - logCreateSubscription sl (toNtfSub sub) - when (i' `mod` 500000 == 0) $ do - putStr $ "Exported " <> show i' <> " subscriptions" <> "\r" - hFlush stdout - pure i' - putStrLn $ "Exported " <> show sCnt <> " subscriptions" - pure sCnt - where - ntfSubQuery = - [sql| - SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status, s.ntf_service_assoc, - p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id - FROM subscriptions s - JOIN smp_servers p ON p.smp_server_id = s.smp_server_id - |] - toNtfSub :: Only NtfTokenId :. NtfSubRow :. SMPQueueNtfRow -> NtfSubRec - toNtfSub (Only tokenId :. (ntfSubId, notifierKey, subStatus, ntfServiceAssoc) :. qRow) = - let smpQueue = rowToSMPQueue qRow - in NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus, ntfServiceAssoc} - exportLastNtfs = - withFile lastNtfsFile WriteMode $ \h -> - withConnection s $ \db -> DB.fold_ db lastNtfsQuery 0 $ \ !i (Only tknId :. ntfRow) -> - B.hPutStr h (encodeLastNtf tknId $ toLastNtf ntfRow) $> (i + 1) - where - -- Note that the order here is ascending, to be compatible with how it is imported - lastNtfsQuery = - [sql| - SELECT s.token_id, p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id, - n.sent_at, n.nmsg_nonce, n.nmsg_data - FROM last_notifications n - JOIN subscriptions s ON s.subscription_id = n.subscription_id - JOIN smp_servers p ON p.smp_server_id = s.smp_server_id - ORDER BY token_ntf_id ASC - |] - encodeLastNtf tknId ntf = strEncode (TNMRv1 tknId ntf) `B.snoc` '\n' - withFastDB' :: Text -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) withFastDB' op st action = withFastDB op st $ fmap Right . action {-# INLINE withFastDB' #-} @@ -881,9 +720,12 @@ withDB_ op st priority action = where err = op <> ", withDB, " <> tshow e -withLog :: MonadIO m => Text -> NtfPostgresStore -> (StoreLog 'WriteMode -> IO ()) -> m () -withLog op NtfPostgresStore {dbStoreLog} = withLog_ op dbStoreLog -{-# INLINE withLog #-} +withClientDB :: Text -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either SMPClientError a) +withClientDB op st action = + E.uninterruptibleMask_ $ E.try (withTransaction (dbStore st) action) >>= bimapM logErr pure + where + logErr :: E.SomeException -> IO SMPClientError + logErr e = logError ("STORE: " <> op <> ", withDB, " <> tshow e) $> PCEIOError (E.displayException e) assertUpdated :: Int64 -> Either ErrorType () assertUpdated 0 = Left AUTH @@ -921,4 +763,9 @@ instance ToField C.KeyHash where toField = toField . Binary . strEncode instance FromField C.CbNonce where fromField = blobFieldDecoder $ parseAll smpP instance ToField C.CbNonce where toField = toField . Binary . smpEncode + +instance ToField X.PrivKey where toField = toField . Binary . C.encodeASNObj + +instance FromField X.PrivKey where + fromField = blobFieldDecoder $ C.decodeASNKey >=> \case (pk, []) -> Right pk; r -> C.asnKeyError r #endif diff --git a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql index b73995684..801208aaa 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql +++ b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql @@ -172,7 +172,10 @@ CREATE TABLE ntf_server.smp_servers ( smp_keyhash bytea NOT NULL, ntf_service_id bytea, smp_notifier_count bigint DEFAULT 0 NOT NULL, - smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL + smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL, + ntf_service_cert bytea, + ntf_service_cert_hash bytea, + ntf_service_priv_key bytea ); diff --git a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs deleted file mode 100644 index 7c71ddb08..000000000 --- a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs +++ /dev/null @@ -1,177 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE StrictData #-} -{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} - -module Simplex.Messaging.Notifications.Server.StoreLog - ( StoreLog, - NtfStoreLogRecord (..), - readWriteNtfSTMStore, - logCreateToken, - logTokenStatus, - logUpdateToken, - logTokenCron, - logDeleteToken, - logUpdateTokenTime, - logCreateSubscription, - logSubscriptionStatus, - logDeleteSubscription, - closeStoreLog, - ) -where - -import Control.Applicative (optional, (<|>)) -import Control.Concurrent.STM -import Control.Monad -import qualified Data.Attoparsec.ByteString.Char8 as A -import qualified Data.ByteString.Base64.URL as B64 -import qualified Data.ByteString.Char8 as B -import Data.Functor (($>)) -import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe) -import Data.Word (Word16) -import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Server.Store -import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Protocol (EntityId (..), SMPServer, ServiceId) -import Simplex.Messaging.Server.StoreLog -import Simplex.Messaging.SystemTime -import System.IO - -data NtfStoreLogRecord - = CreateToken NtfTknRec - | TokenStatus NtfTokenId NtfTknStatus - | UpdateToken NtfTokenId DeviceToken NtfRegCode - | TokenCron NtfTokenId Word16 - | DeleteToken NtfTokenId - | UpdateTokenTime NtfTokenId SystemDate - | CreateSubscription NtfSubRec - | SubscriptionStatus NtfSubscriptionId NtfSubStatus NtfAssociatedService - | DeleteSubscription NtfSubscriptionId - | SetNtfService SMPServer (Maybe ServiceId) - deriving (Show) - -instance StrEncoding NtfStoreLogRecord where - strEncode = \case - CreateToken tknRec -> strEncode (Str "TCREATE", tknRec) - TokenStatus tknId tknStatus -> strEncode (Str "TSTATUS", tknId, tknStatus) - UpdateToken tknId token regCode -> strEncode (Str "TUPDATE", tknId, token, regCode) - TokenCron tknId cronInt -> strEncode (Str "TCRON", tknId, cronInt) - DeleteToken tknId -> strEncode (Str "TDELETE", tknId) - UpdateTokenTime tknId ts -> strEncode (Str "TTIME", tknId, ts) - CreateSubscription subRec -> strEncode (Str "SCREATE", subRec) - SubscriptionStatus subId subStatus serviceAssoc -> strEncode (Str "SSTATUS", subId, subStatus) <> serviceStr - where - serviceStr = if serviceAssoc then " service=" <> strEncode True else "" - DeleteSubscription subId -> strEncode (Str "SDELETE", subId) - SetNtfService srv serviceId -> strEncode (Str "SERVICE", srv) <> " service=" <> maybe "off" strEncode serviceId - strP = - A.choice - [ "TCREATE " *> (CreateToken <$> strP), - "TSTATUS " *> (TokenStatus <$> strP_ <*> strP), - "TUPDATE " *> (UpdateToken <$> strP_ <*> strP_ <*> strP), - "TCRON " *> (TokenCron <$> strP_ <*> strP), - "TDELETE " *> (DeleteToken <$> strP), - "TTIME " *> (UpdateTokenTime <$> strP_ <*> strP), - "SCREATE " *> (CreateSubscription <$> strP), - "SSTATUS " *> (SubscriptionStatus <$> strP_ <*> strP <*> (fromMaybe False <$> optional (" service=" *> strP))), - "SDELETE " *> (DeleteSubscription <$> strP), - "SERVICE " *> (SetNtfService <$> strP <* " service=" <*> ("off" $> Nothing <|> strP)) - ] - -logNtfStoreRecord :: StoreLog 'WriteMode -> NtfStoreLogRecord -> IO () -logNtfStoreRecord = writeStoreLogRecord -{-# INLINE logNtfStoreRecord #-} - -logCreateToken :: StoreLog 'WriteMode -> NtfTknRec -> IO () -logCreateToken s = logNtfStoreRecord s . CreateToken - -logTokenStatus :: StoreLog 'WriteMode -> NtfTokenId -> NtfTknStatus -> IO () -logTokenStatus s tknId tknStatus = logNtfStoreRecord s $ TokenStatus tknId tknStatus - -logUpdateToken :: StoreLog 'WriteMode -> NtfTokenId -> DeviceToken -> NtfRegCode -> IO () -logUpdateToken s tknId token regCode = logNtfStoreRecord s $ UpdateToken tknId token regCode - -logTokenCron :: StoreLog 'WriteMode -> NtfTokenId -> Word16 -> IO () -logTokenCron s tknId cronInt = logNtfStoreRecord s $ TokenCron tknId cronInt - -logDeleteToken :: StoreLog 'WriteMode -> NtfTokenId -> IO () -logDeleteToken s tknId = logNtfStoreRecord s $ DeleteToken tknId - -logUpdateTokenTime :: StoreLog 'WriteMode -> NtfTokenId -> SystemDate -> IO () -logUpdateTokenTime s tknId t = logNtfStoreRecord s $ UpdateTokenTime tknId t - -logCreateSubscription :: StoreLog 'WriteMode -> NtfSubRec -> IO () -logCreateSubscription s = logNtfStoreRecord s . CreateSubscription - -logSubscriptionStatus :: StoreLog 'WriteMode -> (NtfSubscriptionId, NtfSubStatus, NtfAssociatedService) -> IO () -logSubscriptionStatus s (subId, subStatus, serviceAssoc) = logNtfStoreRecord s $ SubscriptionStatus subId subStatus serviceAssoc - -logDeleteSubscription :: StoreLog 'WriteMode -> NtfSubscriptionId -> IO () -logDeleteSubscription s subId = logNtfStoreRecord s $ DeleteSubscription subId - -logSetNtfService :: StoreLog 'WriteMode -> SMPServer -> Maybe ServiceId -> IO () -logSetNtfService s srv serviceId = logNtfStoreRecord s $ SetNtfService srv serviceId - -readWriteNtfSTMStore :: Bool -> FilePath -> NtfSTMStore -> IO (StoreLog 'WriteMode) -readWriteNtfSTMStore tty = readWriteStoreLog (readNtfStore tty) writeNtfStore - -readNtfStore :: Bool -> FilePath -> NtfSTMStore -> IO () -readNtfStore tty f st = readLogLines tty f $ \_ -> processLine - where - processLine s = either printError procNtfLogRecord (strDecode s) - where - printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> B.take 100 s - procNtfLogRecord = \case - CreateToken r@NtfTknRec {ntfTknId} -> do - tkn <- mkTknData r - atomically $ stmAddNtfToken st ntfTknId tkn - TokenStatus tknId status -> do - tkn_ <- stmGetNtfTokenIO st tknId - forM_ tkn_ $ \tkn@NtfTknData {tknStatus} -> do - atomically $ writeTVar tknStatus status - when (status == NTActive) $ void $ atomically $ stmRemoveInactiveTokenRegistrations st tkn - UpdateToken tknId token' tknRegCode -> do - stmGetNtfTokenIO st tknId - >>= mapM_ - ( \tkn@NtfTknData {tknStatus} -> do - atomically $ stmRemoveTokenRegistration st tkn - atomically $ writeTVar tknStatus NTRegistered - atomically $ stmAddNtfToken st tknId tkn {token = token', tknRegCode} - ) - TokenCron tknId cronInt -> - stmGetNtfTokenIO st tknId - >>= mapM_ (\NtfTknData {tknCronInterval} -> atomically $ writeTVar tknCronInterval cronInt) - DeleteToken tknId -> - atomically $ void $ stmDeleteNtfToken st tknId - UpdateTokenTime tknId t -> - stmGetNtfTokenIO st tknId - >>= mapM_ (\NtfTknData {tknUpdatedAt} -> atomically $ writeTVar tknUpdatedAt $ Just t) - CreateSubscription r@NtfSubRec {tokenId, ntfSubId} -> do - sub <- mkSubData r - atomically (stmAddNtfSubscription st ntfSubId sub) >>= \case - Just () -> pure () - Nothing -> B.putStrLn $ "Warning: no token " <> enc tokenId <> ", subscription " <> enc ntfSubId - where - enc = B64.encode . unEntityId - SubscriptionStatus subId status serviceAssoc -> do - stmGetNtfSubscriptionIO st subId >>= mapM_ update - where - update NtfSubData {subStatus, ntfServiceAssoc} = atomically $ do - writeTVar subStatus status - writeTVar ntfServiceAssoc serviceAssoc - DeleteSubscription subId -> - atomically $ stmDeleteNtfSubscription st subId - SetNtfService srv serviceId -> - atomically $ stmSetNtfService st srv serviceId - -writeNtfStore :: StoreLog 'WriteMode -> NtfSTMStore -> IO () -writeNtfStore s NtfSTMStore {tokens, subscriptions, ntfServices} = do - mapM_ (logCreateToken s <=< mkTknRec) =<< readTVarIO tokens - mapM_ (logCreateSubscription s <=< mkSubRec) =<< readTVarIO subscriptions - mapM_ (\(srv, serviceId) -> logSetNtfService s srv $ Just serviceId) . M.assocs =<< readTVarIO ntfServices diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 24247e781..21b03f3cf 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -46,6 +46,7 @@ module Simplex.Messaging.Server where import Control.Concurrent.STM (throwSTM) +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -1385,7 +1386,7 @@ client Just r -> Just <$> proxyServerResponse a r Nothing -> forkProxiedCmd $ - liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError)) + liftIO (runExceptT (getSMPServerClient'' a srv) `E.catch` (\(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e)) >>= proxyServerResponse a proxyServerResponse :: SMPClientAgent 'Sender -> Either SMPClientError (OwnServer, SMPClient) -> M s BrokerMsg proxyServerResponse a smp_ = do @@ -1422,7 +1423,7 @@ client inc own pRequests if v >= sendingProxySMPVersion then forkProxiedCmd $ do - liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case + liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `E.catch` (\(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e)) >>= \case Right r -> PRES r <$ inc own pSuccesses Left e -> ERR (smpProxyError e) <$ case e of PCEProtocolError {} -> inc own pSuccesses diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index e59cd5c0b..574111c15 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -706,7 +706,7 @@ mkJournalStoreConfig queueStoreCfg storePath msgQueueQuota maxJournalMsgCount ma newSMPProxyAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO ProxyAgent newSMPProxyAgent smpAgentCfg random = do - smpAgent <- newSMPClientAgent SSender smpAgentCfg random + smpAgent <- newSMPClientAgent SSender smpAgentCfg Nothing random pure ProxyAgent {smpAgent} readWriteQueueStore :: forall q. StoreQueueClass q => Bool -> (RecipientId -> QueueRec -> IO q) -> FilePath -> STMQueueStore q -> IO (StoreLog 'WriteMode) diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index 91a8bf0e5..e805fa86c 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -11,7 +11,6 @@ module Simplex.Messaging.Transport.HTTP2.Client where import Control.Concurrent.Async -import Control.Exception (IOException, try) import qualified Control.Exception as E import Control.Monad import Data.Functor (($>)) @@ -90,7 +89,7 @@ defaultHTTP2ClientConfig = suportedTLSParams = http2TLSParams } -data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError IOException +data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError String deriving (Show) getHTTP2Client :: HostName -> ServiceName -> Maybe XS.CertificateStore -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client) @@ -111,7 +110,7 @@ attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP getVerifiedHTTP2ClientWith :: forall p. TransportPeerI p => HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS p -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client) getVerifiedHTTP2ClientWith config host port disconnected setup = (mkHTTPS2Client >>= runClient) - `E.catch` \(e :: IOException) -> pure . Left $ HCIOError e + `E.catch` \(e :: E.SomeException) -> pure $ Left $ HCIOError $ E.displayException e where mkHTTPS2Client :: IO HClient mkHTTPS2Client = do @@ -177,9 +176,9 @@ sendRequest HTTP2Client {client_ = HClient {config, reqQ}} req reqTimeout_ = do sendRequestDirect :: HTTP2Client -> Request -> Maybe Int -> IO (Either HTTP2ClientError HTTP2Response) sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq} req reqTimeout_ = do let reqTimeout = http2RequestTimeout config reqTimeout_ - reqTimeout `timeout` try (sendReq req process) >>= \case + reqTimeout `timeout` E.try (sendReq req process) >>= \case Just (Right r) -> pure $ Right r - Just (Left e) -> disconnected $> Left (HCIOError e) + Just (Left (e :: E.SomeException)) -> disconnected $> Left (HCIOError $ E.displayException e) Nothing -> pure $ Left HCResponseTimeout where process r = do diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 34448fc10..18cdfd1fa 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -3677,6 +3677,7 @@ testClientServiceConnection ps = do exchangeGreetings service uId user sId pure conns withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do + liftIO $ threadDelay 250000 [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash)))] <- M.toList <$> subscribeClientServices service 1 ("", "", SERVICE_ALL _) <- nGet service subscribeConnection user sId @@ -3684,6 +3685,7 @@ testClientServiceConnection ps = do pure (conns, qIdHash) (uId', sId') <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + liftIO $ threadDelay 250000 subscribeAllConnections service False Nothing liftIO $ getInAnyOrder service [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))) -> qIdHash' == qIdHash; _ -> False, @@ -3708,6 +3710,7 @@ testClientServiceConnection ps = do pure conns' withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + liftIO $ threadDelay 250000 subscribeAllConnections service False Nothing liftIO $ getInAnyOrder service [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 2 _)))) -> True; _ -> False, From bafdbc1dec778021eacbf621f1467ca78287d2a4 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Thu, 25 Dec 2025 13:00:29 +0000 Subject: [PATCH 09/11] smp protocol: fix encoding for SOKS/ENDS responses (#1683) --- src/Simplex/Messaging/Protocol.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 4993aaac8..25b8ce357 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1948,7 +1948,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where e :: Encoding a => a -> ByteString e = smpEncode serviceResp tag n idsHash - | v >= serviceCertsSMPVersion = e (tag, ' ', n, idsHash) + | v >= rcvServiceSMPVersion = e (tag, ' ', n, idsHash) | otherwise = e (tag, ' ', n) protocolP v = \case @@ -1993,7 +1993,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where PONG_ -> pure PONG where serviceRespP resp - | v >= serviceCertsSMPVersion = resp <$> _smpP <*> smpP + | v >= rcvServiceSMPVersion = resp <$> _smpP <*> smpP | otherwise = resp <$> _smpP <*> pure mempty fromProtocolError = \case From db4b27e88a95af5b295d393b4c4483ffd220fafb Mon Sep 17 00:00:00 2001 From: Evgeny Date: Sat, 27 Dec 2025 09:12:22 +0000 Subject: [PATCH 10/11] agent: create user with option to enable client service (#1684) * agent: create user with option to enable client service * handle HTTP2 errors * do not catch async exceptions --- src/Simplex/FileTransfer/Client.hs | 16 ++++++------- src/Simplex/Messaging/Agent.hs | 23 +++++++++++++------ src/Simplex/Messaging/Client.hs | 12 ++++++++-- src/Simplex/Messaging/Client/Agent.hs | 6 ++--- .../Notifications/Server/Store/Postgres.hs | 2 +- src/Simplex/Messaging/Server.hs | 6 ++--- .../Messaging/Transport/HTTP2/Client.hs | 14 ++++++++--- tests/AgentTests/FunctionalAPITests.hs | 10 ++++---- 8 files changed, 57 insertions(+), 32 deletions(-) diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index a425138e5..d8ed04bc8 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -47,6 +47,7 @@ import Simplex.Messaging.Client transportClientConfig, clientSocksCredentials, unexpectedResponse, + clientHandlers, useWebPort, ) import qualified Simplex.Messaging.Crypto as C @@ -61,7 +62,6 @@ import Simplex.Messaging.Protocol SenderId, pattern NoEntity, NetworkError (..), - toNetworkError, ) import Simplex.Messaging.Transport (ALPN, CertChainPubKey (..), HandshakeError (..), THandleAuth (..), THandleParams (..), TransportError (..), TransportPeer (..), defaultSupportedParams) import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost) @@ -70,8 +70,10 @@ import Simplex.Messaging.Transport.HTTP2.Client import Simplex.Messaging.Transport.HTTP2.File import Simplex.Messaging.Util (liftEitherWith, liftError', tshow, whenM) import Simplex.Messaging.Version -import UnliftIO +import System.IO (IOMode (..), SeekMode (..), hSeek, withFile) +import System.Timeout (timeout) import UnliftIO.Directory +import UnliftIO.STM data XFTPClient = XFTPClient { http2Client :: HTTP2Client, @@ -261,13 +263,11 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec { let dhSecret = C.dh' sDhKey rpDhKey cbState <- liftEither . first PCECryptoError $ LC.cbInit dhSecret cbNonce let t = chunkTimeout config chunkSize - ExceptT (sequence <$> (t `timeout` (download cbState `catches` errors))) >>= maybe (throwE PCEResponseTimeout) pure + ExceptT (sequence <$> (t `timeout` (download cbState `E.catches` handlers))) >>= maybe (throwE PCEResponseTimeout) pure where - errors = - [ Handler $ \(e :: H.HTTP2Error) -> pure $ Left $ PCENetworkError $ NEConnectError $ displayException e, - Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError $ E.displayException e, - Handler $ \(e :: SomeException) -> pure $ Left $ PCENetworkError $ toNetworkError e - ] + handlers = + E.Handler (\(e :: H.HTTP2Error) -> pure $ Left $ PCENetworkError $ NEConnectError $ E.displayException e) + : clientHandlers download cbState = runExceptT . withExceptT PCEResponseError $ receiveEncFile chunkPart cbState chunkSpec `catchError` \e -> diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index e17c39a16..4acf880dd 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -337,8 +337,8 @@ resumeAgentClient :: AgentClient -> IO () resumeAgentClient c = atomically $ writeTVar (active c) True {-# INLINE resumeAgentClient #-} -createUser :: AgentClient -> NonEmpty (ServerCfg 'PSMP) -> NonEmpty (ServerCfg 'PXFTP) -> AE UserId -createUser c = withAgentEnv c .: createUser' c +createUser :: AgentClient -> Bool -> NonEmpty (ServerCfg 'PSMP) -> NonEmpty (ServerCfg 'PXFTP) -> AE UserId +createUser c = withAgentEnv c .:. createUser' c {-# INLINE createUser #-} -- | Delete user record optionally deleting all user's connections on SMP servers @@ -754,14 +754,23 @@ logConnection c connected = let event = if connected then "connected to" else "disconnected from" in logInfo $ T.unwords ["client", tshow (clientId c), event, "Agent"] -createUser' :: AgentClient -> NonEmpty (ServerCfg 'PSMP) -> NonEmpty (ServerCfg 'PXFTP) -> AM UserId -createUser' c smp xftp = do +createUser' :: AgentClient -> Bool -> NonEmpty (ServerCfg 'PSMP) -> NonEmpty (ServerCfg 'PXFTP) -> AM UserId +createUser' c useService smp xftp = do liftIO $ checkUserServers "createUser SMP" smp liftIO $ checkUserServers "createUser XFTP" xftp userId <- withStore' c createUserRecord - atomically $ TM.insert userId (mkUserServers smp) $ smpServers c - atomically $ TM.insert userId (mkUserServers xftp) $ xftpServers c - atomically $ TM.insert userId False $ useClientServices c + ok <- atomically $ do + (cfg, _) <- readTVar $ useNetworkConfig c + if useService && sessionMode cfg == TSMEntity + then pure False + else do + TM.insert userId (mkUserServers smp) $ smpServers c + TM.insert userId (mkUserServers xftp) $ xftpServers c + TM.insert userId useService $ useClientServices c + pure True + unless ok $ do + withStore c (`deleteUserRecord` userId) + throwE $ CMD PROHIBITED "createUser'" pure userId deleteUser' :: AgentClient -> UserId -> Bool -> AM () diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index ebc458c0e..bfd45f3a1 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -107,6 +107,7 @@ module Simplex.Messaging.Client smpProxyError, smpErrorClientNotice, textToHostMode, + clientHandlers, ServerTransmissionBatch, ServerTransmission (..), ClientCommand, @@ -129,7 +130,7 @@ import Control.Applicative ((<|>)) import Control.Concurrent (ThreadId, forkFinally, forkIO, killThread, mkWeakThreadId) import Control.Concurrent.Async import Control.Concurrent.STM -import Control.Exception (Exception, SomeException) +import Control.Exception (Exception, Handler (..), IOException, SomeAsyncException, SomeException) import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad @@ -567,7 +568,7 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS case chooseTransportHost networkConfig (host srv) of Right useHost -> (getCurrentTime >>= mkProtocolClient useHost >>= runClient useTransport useHost) - `E.catch` \(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e + `E.catches` clientHandlers Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig @@ -719,6 +720,13 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS Left e -> logError $ "SMP client error: " <> tshow e Right _ -> logWarn "SMP client unprocessed event" +clientHandlers :: [Handler (Either (ProtocolClientError e) a)] +clientHandlers = + [ Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError $ E.displayException e, + Handler $ \(e :: SomeAsyncException) -> E.throwIO e, + Handler $ \(e :: SomeException) -> pure $ Left $ PCENetworkError $ toNetworkError e + ] + useWebPort :: NetworkConfig -> [HostName] -> ProtocolServer p -> Bool useWebPort cfg presetDomains ProtocolServer {host = h :| _} = case smpWebPortServers cfg of SWPAll -> True diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 9739c19c7..d302ba237 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -37,6 +37,7 @@ where import Control.Concurrent (forkIO) import Control.Concurrent.Async (Async, uninterruptibleCancel) import Control.Concurrent.STM (retry) +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -83,7 +84,6 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Util (catchAll_, ifM, safeDecodeUtf8, toChunks, tshow, whenM, ($>>=), (<$$>)) import System.Timeout (timeout) import UnliftIO (async) -import qualified UnliftIO.Exception as E import UnliftIO.STM type SMPClientVar = SessionVar (Either (SMPClientError, Maybe UTCTime) (OwnServer, SMPClient)) @@ -226,7 +226,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke newSMPClient :: SMPClientVar -> IO (Either SMPClientError (OwnServer, SMPClient)) newSMPClient v = do - r <- connectClient ca srv v `E.catch` \(e :: E.SomeException) -> pure $ Left $ PCEIOError $ E.displayException e + r <- connectClient ca srv v `E.catches` clientHandlers case r of Right smp -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv @@ -324,7 +324,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s (Just <$> getSessVar workerSeq srv smpSubWorkers ts) newSubWorker :: SessionVar (Async ()) -> IO () newSubWorker v = do - a <- async $ void (E.tryAny runSubWorker) >> atomically (cleanup v) + a <- async $ void (E.try @E.SomeException runSubWorker) >> atomically (cleanup v) atomically $ putTMVar (sessionVar v) a runSubWorker = withRetryInterval (reconnectInterval agentCfg) $ \_ loop -> do diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 80ab45ca1..54668d45c 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -586,7 +586,7 @@ removeServiceAndAssociations st srv = do withDB "removeServiceAndAssociations" st $ \db -> runExceptT $ do srvId <- ExceptT $ getServerId db subsCount <- liftIO $ removeServiceAssociation_ db srvId - liftIO $ removeServerService db srvId + liftIO $ void $ removeServerService db srvId pure (srvId, fromIntegral subsCount) where getServerId db = diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 21b03f3cf..3d977dc8c 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -97,7 +97,7 @@ import Network.Socket (ServiceName, Socket, socketToHandle) import qualified Network.TLS as TLS import Numeric.Natural (Natural) import Simplex.Messaging.Agent.Lock -import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, smpProxyError, temporaryClientError) +import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, clientHandlers, forwardSMPTransmission, smpProxyError, temporaryClientError) import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -1386,7 +1386,7 @@ client Just r -> Just <$> proxyServerResponse a r Nothing -> forkProxiedCmd $ - liftIO (runExceptT (getSMPServerClient'' a srv) `E.catch` (\(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e)) + liftIO (runExceptT (getSMPServerClient'' a srv) `E.catches` clientHandlers) >>= proxyServerResponse a proxyServerResponse :: SMPClientAgent 'Sender -> Either SMPClientError (OwnServer, SMPClient) -> M s BrokerMsg proxyServerResponse a smp_ = do @@ -1423,7 +1423,7 @@ client inc own pRequests if v >= sendingProxySMPVersion then forkProxiedCmd $ do - liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `E.catch` (\(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e)) >>= \case + liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `E.catches` clientHandlers) >>= \case Right r -> PRES r <$ inc own pSuccesses Left e -> ERR (smpProxyError e) <$ case e of PCEProtocolError {} -> inc own pSuccesses diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index e805fa86c..09a1089ea 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -11,6 +11,7 @@ module Simplex.Messaging.Transport.HTTP2.Client where import Control.Concurrent.Async +import Control.Exception (Handler (..), IOException, SomeAsyncException, SomeException) import qualified Control.Exception as E import Control.Monad import Data.Functor (($>)) @@ -92,6 +93,13 @@ defaultHTTP2ClientConfig = data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError String deriving (Show) +httpClientHandlers :: [Handler (Either HTTP2ClientError a)] +httpClientHandlers = + [ Handler $ \(e :: IOException) -> pure $ Left $ HCIOError $ E.displayException e, + Handler $ \(e :: SomeAsyncException) -> E.throwIO e, + Handler $ \(e :: SomeException) -> pure $ Left $ HCNetworkError $ toNetworkError e + ] + getHTTP2Client :: HostName -> ServiceName -> Maybe XS.CertificateStore -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client) getHTTP2Client host port = getVerifiedHTTP2Client Nothing (THDomainName host) port Nothing @@ -110,7 +118,7 @@ attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP getVerifiedHTTP2ClientWith :: forall p. TransportPeerI p => HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS p -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client) getVerifiedHTTP2ClientWith config host port disconnected setup = (mkHTTPS2Client >>= runClient) - `E.catch` \(e :: E.SomeException) -> pure $ Left $ HCIOError $ E.displayException e + `E.catches` httpClientHandlers where mkHTTPS2Client :: IO HClient mkHTTPS2Client = do @@ -176,9 +184,9 @@ sendRequest HTTP2Client {client_ = HClient {config, reqQ}} req reqTimeout_ = do sendRequestDirect :: HTTP2Client -> Request -> Maybe Int -> IO (Either HTTP2ClientError HTTP2Response) sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq} req reqTimeout_ = do let reqTimeout = http2RequestTimeout config reqTimeout_ - reqTimeout `timeout` E.try (sendReq req process) >>= \case + reqTimeout `timeout` ((Right <$> sendReq req process) `E.catches` httpClientHandlers) >>= \case Just (Right r) -> pure $ Right r - Just (Left (e :: E.SomeException)) -> disconnected $> Left (HCIOError $ E.displayException e) + Just (Left e) -> disconnected $> Left e Nothing -> pure $ Left HCResponseTimeout where process r = do diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 18cdfd1fa..62f0facd3 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1018,7 +1018,7 @@ testUpdateConnectionUserId :: HasCallStack => IO () testUpdateConnectionUserId = withAgentClients2 $ \alice bob -> runRight_ $ do (connId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - newUserId <- createUser alice [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] + newUserId <- createUser alice False [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] _ <- changeConnectionUser alice 1 connId newUserId aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn sqSecured' <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe @@ -3001,7 +3001,7 @@ testUsers = withAgentClients2 $ \a b -> runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetings a bId b aId - auId <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] + auId <- createUser a False [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetings a bId' b aId' deleteUser a auId True @@ -3016,7 +3016,7 @@ testDeleteUserQuietly = withAgentClients2 $ \a b -> runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetings a bId b aId - auId <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] + auId <- createUser a False [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetings a bId' b aId' deleteUser a auId False @@ -3028,7 +3028,7 @@ testUsersNoServer ps = withAgentClientsCfg2 aCfg agentCfg $ \a b -> do (aId, bId, auId, _aId', bId') <- withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do (aId, bId) <- makeConnection a b exchangeGreetings a bId b aId - auId <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] + auId <- createUser a False [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetings a bId' b aId' pure (aId, bId, auId, aId', bId') @@ -3628,7 +3628,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do ("", "", UP _ _) <- nGet a a `hasClients` 1 - aUserId2 <- createUser a [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] + aUserId2 <- createUser a False [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] (aId2, bId2) <- makeConnectionForUsers a aUserId2 b 1 exchangeGreetings a bId2 b aId2 (aId2', bId2') <- makeConnectionForUsers a aUserId2 b 1 From 502d92381729d5f42ec88fe07d54d0913b50b7da Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sat, 17 Jan 2026 10:21:25 +0000 Subject: [PATCH 11/11] agent: minor fixes --- src/Simplex/Messaging/Agent/Client.hs | 3 +-- tests/AgentTests/FunctionalAPITests.hs | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index c42c0fa34..d8df98d1b 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1545,9 +1545,8 @@ processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do (Map SMP.RecipientId SMPClientError, ([RcvQueueSub], [RcvQueueSub]), [(RcvQueueSub, Maybe ClientNotice)], Int) partitionResults pendingSubs (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed@(qs, sQs), notices, ignored) = case r of Left e -> case smpErrorClientNotice e of - Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored) + Just notice_ -> (failed', subscribed, notices', ignored) where - -- TODO [certs rcv] not used? notices' = if isJust notice_ || isJust clientNoticeId then (rq, notice_) : notices else notices Nothing | temporaryClientError e -> acc diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 2aa5c0aca..11548c9e9 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -122,6 +122,8 @@ import XFTPClient (testXFTPServer) #if defined(dbPostgres) import Fixtures +import Simplex.Messaging.Agent.Store (RcvQueue, RcvQueueSub (..), ServiceAssoc) +import Simplex.Messaging.Agent.Store.AgentStore (deleteClientService, getSubscriptionService, getUserServerRcvQueueSubs, removeRcvServiceAssocs, setRcvServiceAssocs) #endif #if defined(dbServerPostgres) import qualified Database.PostgreSQL.Simple as PSQL @@ -786,7 +788,7 @@ runAgentClientStressTestOneWay n pqSupport sqSecured viaProxy alice bob baseId = msgId = subtract baseId . fst runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do +runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob _baseId = runRight_ $ do (aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob amId <- newTVarIO 0 bmId <- newTVarIO 0 @@ -803,7 +805,6 @@ runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = r liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice" liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob" where - msgId = subtract baseId . fst pqEnc = PQEncryption $ supportPQ pqSupport proxySrv = if viaProxy then Just testSMPServer else Nothing message i = "message " <> bshow i @@ -816,11 +817,11 @@ runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = r timeout 100000 (get a) >>= mapM_ (\case ("", _, QCONT) -> drain; r -> expectationFailure $ "unexpected: " <> show r) loop (0, 0, 0, 0) = pure () - loop acc@(!s, !m, !r, !o) = + loop acc@(s, !m, !r, !o) = timeout 3000000 (get a) >>= \case Nothing -> error $ "timeout " <> show acc Just evt -> case evt of - ("", c, A.SENT mId srv) -> do + ("", c, A.SENT _mId srv) -> do liftIO $ c == bId && srv == proxySrv `shouldBe` True unless (s > 0) $ error "unexpected SENT" loop (s - 1, m, r, o) @@ -834,7 +835,7 @@ runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = r ackMessageAsync a "123" bId mId (Just "") unless (m > 0) $ error "unexpected MSG" loop (s, m - 1, r, o) - ("", c, Rcvd' mId rcvdMsgId) -> do + ("", c, Rcvd' mId _rcvdMsgId) -> do liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True liftIO $ c == bId `shouldBe` True ackMessageAsync a "123" bId mId Nothing