diff --git a/simplexmq.cabal b/simplexmq.cabal index f6ab07e083..d72d3f02c0 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -266,6 +266,7 @@ library Simplex.Messaging.Notifications.Server.Prometheus Simplex.Messaging.Notifications.Server.Push Simplex.Messaging.Notifications.Server.Push.APNS + Simplex.Messaging.Notifications.Server.Push.WebPush Simplex.Messaging.Notifications.Server.Push.APNS.Internal Simplex.Messaging.Notifications.Server.Stats Simplex.Messaging.Notifications.Server.Store @@ -298,6 +299,7 @@ library , attoparsec ==0.14.* , base >=4.14 && <5 , base64-bytestring >=1.0 && <1.3 + , binary ==0.8.* , composition ==1.0.* , constraints >=0.12 && <0.14 , containers ==0.6.* @@ -310,6 +312,7 @@ library , directory ==1.3.* , filepath ==1.4.* , hourglass ==0.2.* + , http-client ==0.7.* , http-types ==0.12.* , http2 >=4.2.2 && <4.3 , iproute ==1.7.* @@ -341,6 +344,7 @@ library case-insensitive ==1.2.* , hashable ==1.4.* , ini ==0.4.1 + , http-client-tls ==0.3.6.* , optparse-applicative >=0.15 && <0.17 , process ==1.6.* , temporary ==1.3.* @@ -510,6 +514,7 @@ test-suite simplexmq-test AgentTests.NotificationTests NtfClient NtfServerTests + NtfWPTests PostgresSchemaDump hs-source-dirs: tests diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 217a1682a0..e06d0d6371 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1322,7 +1322,7 @@ runNTFServerTest c@AgentClient {presetDomains} nm userId (ProtoServerWithAuth sr (nKey, npKey) <- atomically $ C.generateAuthKeyPair a g (dhKey, _) <- atomically $ C.generateKeyPair g r <- runExceptT $ do - let deviceToken = DeviceToken PPApnsNull "test_ntf_token" + let deviceToken = APNSDeviceToken PPApnsNull "test_ntf_token" (tknId, _) <- liftError (testErr TSCreateNtfToken) $ ntfRegisterToken ntf nm npKey (NewNtfTkn deviceToken nKey dhKey) liftError (testErr TSDeleteNtfToken) $ ntfDeleteToken ntf nm npKey tknId ok <- netTimeoutInt (tcpTimeout $ networkConfig cfg) nm `timeout` closeProtocolClient ntf diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index ef66eca38b..091b8826fb 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -294,7 +294,7 @@ import Simplex.Messaging.Crypto.Ratchet (PQEncryption (..), PQSupport (..), Ratc import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) +import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..), deviceTokenFields, deviceToken') import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol @@ -1420,7 +1420,8 @@ deleteCommand db cmdId = DB.execute db "DELETE FROM commands WHERE command_id = ?" (Only cmdId) createNtfToken :: DB.Connection -> NtfToken -> IO () -createNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} = do +createNtfToken db NtfToken {deviceToken, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} = do + let (provider, token) = deviceTokenFields deviceToken upsertNtfServer_ db srv DB.execute db @@ -1447,10 +1448,12 @@ getSavedNtfToken db = do let ntfServer = NtfServer host port keyHash ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) ntfMode = fromMaybe NMPeriodic ntfMode_ - in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} + deviceToken = deviceToken' provider dt + in NtfToken {deviceToken, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} updateNtfTokenRegistration :: DB.Connection -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> IO () -updateNtfTokenRegistration db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret = do +updateNtfTokenRegistration db NtfToken {deviceToken, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret = do + let (provider, token) = deviceTokenFields deviceToken updatedAt <- getCurrentTime DB.execute db @@ -1462,8 +1465,10 @@ updateNtfTokenRegistration db NtfToken {deviceToken = DeviceToken provider token (tknId, ntfDhSecret, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) updateDeviceToken :: DB.Connection -> NtfToken -> DeviceToken -> IO () -updateDeviceToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} (DeviceToken toProvider toToken) = do +updateDeviceToken db NtfToken {deviceToken, ntfServer = ProtocolServer {host, port}} toDt = do + let (provider, token) = deviceTokenFields deviceToken updatedAt <- getCurrentTime + let (toProvider, toToken) = deviceTokenFields toDt DB.execute db [sql| @@ -1474,7 +1479,8 @@ updateDeviceToken db NtfToken {deviceToken = DeviceToken provider token, ntfServ (toProvider, toToken, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) updateNtfMode :: DB.Connection -> NtfToken -> NotificationsMode -> IO () -updateNtfMode db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} ntfMode = do +updateNtfMode db NtfToken {deviceToken, ntfServer = ProtocolServer {host, port}} ntfMode = do + let (provider, token) = deviceTokenFields deviceToken updatedAt <- getCurrentTime DB.execute db @@ -1486,7 +1492,8 @@ updateNtfMode db NtfToken {deviceToken = DeviceToken provider token, ntfServer = (ntfMode, updatedAt, provider, token, host, port) updateNtfToken :: DB.Connection -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> IO () -updateNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction = do +updateNtfToken db NtfToken {deviceToken, ntfServer = ProtocolServer {host, port}} tknStatus tknAction = do + let (provider, token) = deviceTokenFields deviceToken updatedAt <- getCurrentTime DB.execute db @@ -1498,7 +1505,8 @@ updateNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer (tknStatus, tknAction, updatedAt, provider, token, host, port) removeNtfToken :: DB.Connection -> NtfToken -> IO () -removeNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} = +removeNtfToken db NtfToken {deviceToken, ntfServer = ProtocolServer {host, port}} = do + let (provider, token) = deviceTokenFields deviceToken DB.execute db [sql| @@ -1823,7 +1831,8 @@ getActiveNtfToken db = let ntfServer = NtfServer host port keyHash ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) ntfMode = fromMaybe NMPeriodic ntfMode_ - in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} + deviceToken = deviceToken' provider dt + in NtfToken {deviceToken, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} getNtfRcvQueue :: DB.Connection -> SMPQueueNtf -> IO (Either StoreError (ConnId, Int64, RcvNtfDhSecret, Maybe UTCTime)) getNtfRcvQueue db SMPQueueNtf {smpServer = (SMPServer host port _), notifierId} = diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 9cc78acb30..bf2a4ac3b5 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -87,6 +87,7 @@ module Simplex.Messaging.Crypto signatureKeyPair, publicToX509, encodeASNObj, + readECPrivateKey, -- * key encoding/decoding encodePubKey, @@ -94,6 +95,10 @@ module Simplex.Messaging.Crypto encodePrivKey, decodePrivKey, pubKeyBytes, + encodeBigInt, + uncompressEncodePoint, + uncompressDecodePoint, + uncompressDecodePrivateNumber, -- * sign/verify Signature (..), @@ -128,6 +133,7 @@ module Simplex.Messaging.Crypto encryptAEAD, decryptAEAD, encryptAESNoPad, + encryptAES128NoPad, decryptAESNoPad, authTagSize, randomAesKey, @@ -210,24 +216,29 @@ import Control.Exception (Exception) import Control.Monad import Control.Monad.Except import Control.Monad.Trans.Except -import Crypto.Cipher.AES (AES256) +import Crypto.Cipher.AES (AES128, AES256) import qualified Crypto.Cipher.Types as AES import qualified Crypto.Cipher.XSalsa as XSalsa import qualified Crypto.Error as CE -import Crypto.Hash (Digest, SHA3_256, SHA3_384, SHA256 (..), SHA512 (..), hash, hashDigestSize) +import Crypto.Hash (Digest, SHA256 (..), SHA3_256, SHA3_384, SHA512 (..), hash, hashDigestSize) import qualified Crypto.KDF.HKDF as H import qualified Crypto.MAC.Poly1305 as Poly1305 import qualified Crypto.PubKey.Curve25519 as X25519 import qualified Crypto.PubKey.Curve448 as X448 +import qualified Crypto.PubKey.ECC.ECDSA as ECDSA +import qualified Crypto.PubKey.ECC.Types as ECC import qualified Crypto.PubKey.Ed25519 as Ed25519 import qualified Crypto.PubKey.Ed448 as Ed448 import Crypto.Random (ChaChaDRG, MonadPseudoRandom, drgNew, randomBytesGenerate, withDRG) +import qualified Crypto.Store.PKCS8 as PK import Data.ASN1.BinaryEncoding import Data.ASN1.Encoding import Data.ASN1.Types import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (bimap, first) +import qualified Data.Binary as Bin +import qualified Data.Bits as Bits import Data.ByteArray (ByteArrayAccess) import qualified Data.ByteArray as BA import Data.ByteString.Base64 (decode, encode) @@ -235,13 +246,14 @@ import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.ByteString.Lazy (fromStrict, toStrict) +import qualified Data.ByteString.Lazy as LB import Data.Constraint (Dict (..)) import Data.Kind (Constraint, Type) import qualified Data.List.NonEmpty as L import Data.String import Data.Type.Equality import Data.Typeable (Proxy (Proxy), Typeable) -import Data.Word (Word32) +import Data.Word (Word32, Word64) import qualified Data.X509 as X import Data.X509.Validation (Fingerprint (..), getFingerprint) import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+)) @@ -1039,9 +1051,20 @@ encryptAESNoPad :: Key -> GCMIV -> ByteString -> ExceptT CryptoError IO (AuthTag encryptAESNoPad key iv = encryptAEADNoPad key iv "" {-# INLINE encryptAESNoPad #-} +-- Used to encrypt WebPush notifications +-- This function requires 12 bytes IV, it does not transform IV. +encryptAES128NoPad :: Key -> GCMIV -> ByteString -> ExceptT CryptoError IO (AuthTag, ByteString) +encryptAES128NoPad key iv = encryptAEAD128NoPad key iv "" +{-# INLINE encryptAES128NoPad #-} + encryptAEADNoPad :: Key -> GCMIV -> ByteString -> ByteString -> ExceptT CryptoError IO (AuthTag, ByteString) encryptAEADNoPad aesKey ivBytes ad msg = do - aead <- initAEADGCM aesKey ivBytes + aead <- initAEADGCM @AES256 aesKey ivBytes + pure . first AuthTag $ AES.aeadSimpleEncrypt aead ad msg authTagSize + +encryptAEAD128NoPad :: Key -> GCMIV -> ByteString -> ByteString -> ExceptT CryptoError IO (AuthTag, ByteString) +encryptAEAD128NoPad aesKey ivBytes ad msg = do + aead <- initAEADGCM @AES128 aesKey ivBytes pure . first AuthTag $ AES.aeadSimpleEncrypt aead ad msg authTagSize -- | AEAD-GCM decryption with associated data. @@ -1063,7 +1086,7 @@ decryptAESNoPad key iv = decryptAEADNoPad key iv "" decryptAEADNoPad :: Key -> GCMIV -> ByteString -> ByteString -> AuthTag -> ExceptT CryptoError IO ByteString decryptAEADNoPad aesKey iv ad msg (AuthTag tag) = do - aead <- initAEADGCM aesKey iv + aead <- initAEADGCM @AES256 aesKey iv maybeError AESDecryptError (AES.aeadSimpleDecrypt aead ad msg tag) maxMsgLen :: Int @@ -1138,7 +1161,7 @@ initAEAD (Key aesKey) (IV ivBytes) = do AES.aeadInit AES.AEAD_GCM cipher iv -- this function requires 12 bytes IV, it does not transforms IV. -initAEADGCM :: Key -> GCMIV -> ExceptT CryptoError IO (AES.AEAD AES256) +initAEADGCM :: forall c. AES.BlockCipher c => Key -> GCMIV -> ExceptT CryptoError IO (AES.AEAD c) initAEADGCM (Key aesKey) (GCMIV ivBytes) = cryptoFailable $ do cipher <- AES.cipherInit aesKey AES.aeadInit AES.AEAD_GCM cipher ivBytes @@ -1240,11 +1263,11 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where -- | A wrapper to marshall signed ASN1 objects, like certificates. newtype SignedObject a = SignedObject {getSignedExact :: X.SignedExact a} -instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where +instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) #if defined(dbPostgres) - fromField f dat = SignedObject <$> blobFieldDecoder X.decodeSignedObject f dat + where fromField f dat = SignedObject <$> blobFieldDecoder X.decodeSignedObject f dat #else - fromField = fmap SignedObject . blobFieldDecoder X.decodeSignedObject + where fromField = fmap SignedObject . blobFieldDecoder X.decodeSignedObject #endif instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where @@ -1530,3 +1553,54 @@ keyError :: (a, [ASN1]) -> Either String b keyError = \case (_, []) -> Left "unknown key algorithm" _ -> Left "more than one key" + +readECPrivateKey :: FilePath -> IO ECDSA.PrivateKey +readECPrivateKey f = do + -- this pattern match is specific to APNS key type, it may need to be extended for other push providers + [PK.Unprotected (X.PrivKeyEC X.PrivKeyEC_Named {privkeyEC_name, privkeyEC_priv})] <- PK.readKeyFile f + pure ECDSA.PrivateKey {private_curve = ECC.getCurveByName privkeyEC_name, private_d = privkeyEC_priv} + +-- | Elliptic-Curve-Point-to-Octet-String Conversion without compression +-- | as required by RFC8291 +-- | https://www.secg.org/sec1-v2.pdf#subsubsection.2.3.3 +uncompressEncodePoint :: ECC.Point -> ByteString +uncompressEncodePoint (ECC.Point x y) = "\x04" <> encodeBigInt x <> encodeBigInt y +uncompressEncodePoint ECC.PointO = "\0" + +uncompressDecodePoint :: ByteString -> Either String ECC.Point +uncompressDecodePoint "\0" = pure ECC.PointO +uncompressDecodePoint s + | B.take 1 s /= prefix = Left "PointFormatUnsupported" + | B.length s /= 65 = Left "KeySizeInvalid" + | otherwise = do + let s' = B.drop 1 s + x <- decodeBigInt $ B.take 32 s' + y <- decodeBigInt $ B.drop 32 s' + pure $ ECC.Point x y + where + prefix = "\x04" :: ByteString + +-- Used to test encryption against the RFC8291 Example - which gives the AS private key +uncompressDecodePrivateNumber :: ByteString -> Either String ECC.PrivateNumber +uncompressDecodePrivateNumber s + | B.length s /= 32 = Left "KeySizeInvalid" + | otherwise = decodeBigInt s + +encodeBigInt :: Integer -> ByteString +encodeBigInt i = + let s1 = Bits.shiftR i 64 + s2 = Bits.shiftR s1 64 + s3 = Bits.shiftR s2 64 + in LB.toStrict $ Bin.encode (w64 s3, w64 s2, w64 s1, w64 i) + where + w64 :: Integer -> Word64 + w64 = fromIntegral + +decodeBigInt :: ByteString -> Either String Integer +decodeBigInt s + | B.length s /= 32 = Left "PointSizeInvalid" + | otherwise = + let (w3, w2, w1, w0) = Bin.decode (LB.fromStrict s) :: (Bin.Word64, Bin.Word64, Bin.Word64, Bin.Word64) + in Right $ shift 3 w3 + shift 2 w2 + shift 1 w1 + fromIntegral w0 + where + shift i w = Bits.shiftL (fromIntegral w) (64 * i) diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 0b5889bb7f..e0ca4fc9e5 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -12,6 +12,7 @@ module Simplex.Messaging.Notifications.Protocol where import Control.Applicative (optional, (<|>)) +import qualified Crypto.PubKey.ECC.Types as ECC import Data.Aeson (FromJSON (..), ToJSON (..), (.:), (.=)) import qualified Data.Aeson as J import qualified Data.Aeson.Encoding as JE @@ -27,6 +28,7 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock.System import Data.Type.Equality import Data.Word (Word16) +import Network.HTTP.Client (Request, parseUrlThrow) import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts) import Simplex.Messaging.Agent.Store.DB (FromField (..), ToField (..), fromTextField_) import qualified Simplex.Messaging.Crypto as C @@ -372,14 +374,35 @@ instance StrEncoding SMPQueueNtf where notifierId <- A.char '/' *> strP pure SMPQueueNtf {smpServer, notifierId} -data PushProvider +data PushProvider = PPAPNS APNSProvider | PPWP WPProvider + deriving (Eq, Ord, Show) + +data APNSProvider = PPApnsDev -- provider for Apple development environment | PPApnsProd -- production environment, including TestFlight | PPApnsTest -- used for tests, to use APNS mock server | PPApnsNull -- used to test servers from the client - does not communicate with APNS deriving (Eq, Ord, Show) +newtype WPSrvLoc = WPSrvLoc SrvLoc + deriving (Eq, Ord, Show) + +newtype WPProvider = WPP WPSrvLoc + deriving (Eq, Ord, Show) + +wpAud :: WPProvider -> B.ByteString +wpAud (WPP (WPSrvLoc (SrvLoc aud _))) = B.pack aud + instance Encoding PushProvider where + smpEncode = \case + PPAPNS p -> smpEncode p + PPWP p -> smpEncode p + smpP = + A.peekChar' >>= \case + 'A' -> PPAPNS <$> smpP + _ -> PPWP <$> smpP + +instance Encoding APNSProvider where smpEncode = \case PPApnsDev -> "AD" PPApnsProd -> "AP" @@ -391,9 +414,18 @@ instance Encoding PushProvider where "AP" -> pure PPApnsProd "AT" -> pure PPApnsTest "AN" -> pure PPApnsNull - _ -> fail "bad PushProvider" + _ -> fail "bad APNSProvider" instance StrEncoding PushProvider where + strEncode = \case + PPAPNS p -> strEncode p + PPWP p -> strEncode p + strP = + A.peekChar' >>= \case + 'a' -> PPAPNS <$> strP + _ -> PPWP <$> strP + +instance StrEncoding APNSProvider where strEncode = \case PPApnsDev -> "apns_dev" PPApnsProd -> "apns_prod" @@ -405,38 +437,194 @@ instance StrEncoding PushProvider where "apns_prod" -> pure PPApnsProd "apns_test" -> pure PPApnsTest "apns_null" -> pure PPApnsNull - _ -> fail "bad PushProvider" + _ -> fail "bad APNSProvider" + +instance Encoding WPSrvLoc where + smpEncode (WPSrvLoc srv) = smpEncode srv + smpP = WPSrvLoc <$> smpP + +instance StrEncoding WPSrvLoc where + strEncode (WPSrvLoc srv) = "https://" <> strEncode srv + strP = WPSrvLoc <$> ("https://" *> strP) + +instance Encoding WPProvider where + smpEncode (WPP srv) = "WP" <> smpEncode srv + smpP = WPP <$> ("WP" *> smpP) + +instance StrEncoding WPProvider where + strEncode (WPP srv) = "webpush " <> strEncode srv + strP = WPP <$> ("webpush " *> strP) instance FromField PushProvider where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 instance ToField PushProvider where toField = toField . decodeLatin1 . strEncode -data DeviceToken = DeviceToken PushProvider ByteString +newtype WPAuth = WPAuth {unWPAuth :: ByteString} deriving (Eq, Ord, Show) + +toWPAuth :: ByteString -> Either String WPAuth +toWPAuth s + | B.length s == 16 = Right $ WPAuth s + | otherwise = Left "bad WPAuth" + +newtype WPP256dh = WPP256dh ECC.PublicPoint + deriving (Eq, Show) + +-- This Ord instance for ECC point is quite arbitrary, it is needed because token is used as Map key +instance Ord WPP256dh where + compare (WPP256dh p1) (WPP256dh p2) = case (p1, p2) of + (ECC.PointO, ECC.PointO) -> EQ + (ECC.PointO, _) -> GT + (_, ECC.PointO) -> LT + (ECC.Point x1 y1, ECC.Point x2 y2) -> compare (x1, y1) (x2, y2) + +data WPKey = WPKey + { wpAuth :: WPAuth, + wpP256dh :: WPP256dh + } + deriving (Eq, Ord, Show) + +uncompressEncode :: WPP256dh -> ByteString +uncompressEncode (WPP256dh p) = C.uncompressEncodePoint p +{-# INLINE uncompressEncode #-} + +uncompressDecode :: ByteString -> Either String WPP256dh +uncompressDecode bs = WPP256dh <$> C.uncompressDecodePoint bs +{-# INLINE uncompressDecode #-} + +data WPTokenParams = WPTokenParams + { wpPath :: ByteString, + wpKey :: WPKey + } deriving (Eq, Ord, Show) +instance Encoding WPAuth where + smpEncode = smpEncode . unWPAuth + smpP = toWPAuth <$?> smpP + +instance StrEncoding WPAuth where + strEncode = strEncode . unWPAuth + strP = toWPAuth <$?> strP + +instance Encoding WPP256dh where + smpEncode = smpEncode . uncompressEncode + {-# INLINE smpEncode #-} + smpP = uncompressDecode <$?> smpP + {-# INLINE smpP #-} + +instance StrEncoding WPP256dh where + strEncode = strEncode . uncompressEncode + {-# INLINE strEncode #-} + strP = uncompressDecode <$?> strP + {-# INLINE strP #-} + +instance Encoding WPKey where + smpEncode WPKey {wpAuth, wpP256dh} = smpEncode (wpAuth, wpP256dh) + smpP = do + wpAuth <- smpP + wpP256dh <- smpP + pure WPKey {wpAuth, wpP256dh} + +instance StrEncoding WPKey where + strEncode WPKey {wpAuth, wpP256dh} = strEncode (wpAuth, wpP256dh) + strP = do + (wpAuth, wpP256dh) <- strP + pure WPKey {wpAuth, wpP256dh} + +instance Encoding WPTokenParams where + smpEncode WPTokenParams {wpPath, wpKey} = smpEncode (wpPath, wpKey) + smpP = do + wpPath <- smpP + wpKey <- smpP + pure WPTokenParams {wpPath, wpKey} + +instance StrEncoding WPTokenParams where + strEncode WPTokenParams {wpPath, wpKey} = wpPath <> " " <> strEncode wpKey + strP = do + wpPath <- A.takeWhile (/= ' ') + _ <- A.char ' ' + wpKey <- strP + pure WPTokenParams {wpPath, wpKey} + +data DeviceToken + = APNSDeviceToken APNSProvider ByteString + | WPDeviceToken WPProvider WPTokenParams + deriving (Eq, Ord, Show) + +tokenPushProvider :: DeviceToken -> PushProvider +tokenPushProvider = \case + APNSDeviceToken pp _ -> PPAPNS pp + WPDeviceToken pp _ -> PPWP pp + instance Encoding DeviceToken where - smpEncode (DeviceToken p t) = smpEncode (p, t) - smpP = DeviceToken <$> smpP <*> smpP + smpEncode token = case token of + APNSDeviceToken p t -> smpEncode (p, t) + WPDeviceToken p t -> smpEncode (p, t) + smpP = + smpP >>= \case + PPAPNS p -> APNSDeviceToken p <$> smpP + PPWP p -> WPDeviceToken p <$> smpP instance StrEncoding DeviceToken where - strEncode (DeviceToken p t) = strEncode p <> " " <> t - strP = nullToken <|> hexToken + strEncode token = case token of + APNSDeviceToken p t -> strEncode p <> " " <> t + -- We don't do strEncode (p, t), because we don't want any space between + -- p (e.g. webpush https://localhost) and t.wpPath (e.g /random) + WPDeviceToken p t -> strEncode p <> strEncode t + strP = nullToken <|> deviceToken where - nullToken = "apns_null test_ntf_token" $> DeviceToken PPApnsNull "test_ntf_token" - hexToken = DeviceToken <$> strP <* A.space <*> hexStringP - hexStringP = + nullToken = "apns_null test_ntf_token" $> APNSDeviceToken PPApnsNull "test_ntf_token" + deviceToken = + strP >>= \case + PPAPNS p -> APNSDeviceToken p <$> hexStringP + PPWP p -> do + t <- WPDeviceToken p <$> strP + _ <- wpRequest t + pure t + hexStringP = do + _ <- A.space A.takeWhile (`B.elem` "0123456789abcdef") >>= \s -> if even (B.length s) then pure s else fail "odd number of hex characters" instance ToJSON DeviceToken where - toEncoding (DeviceToken pp t) = J.pairs $ "pushProvider" .= decodeLatin1 (strEncode pp) <> "token" .= decodeLatin1 t - toJSON (DeviceToken pp t) = J.object ["pushProvider" .= decodeLatin1 (strEncode pp), "token" .= decodeLatin1 t] + toEncoding token = case token of + APNSDeviceToken p t -> J.pairs $ "pushProvider" .= decodeLatin1 (strEncode p) <> "token" .= decodeLatin1 t + -- ToJSON/FromJSON isn't used for WPDeviceToken, we just include the pushProvider so it can fail properly if used to decrypt + WPDeviceToken p _ -> J.pairs $ "pushProvider" .= decodeLatin1 (strEncode p) + + -- WPDeviceToken p t -> J.pairs $ "pushProvider" .= decodeLatin1 (strEncode p) <> "token" .= toJSON t + toJSON token = case token of + APNSDeviceToken p t -> J.object ["pushProvider" .= decodeLatin1 (strEncode p), "token" .= decodeLatin1 t] + -- ToJSON/FromJSON isn't used for WPDeviceToken, we just include the pushProvider so it can fail properly if used to decrypt + WPDeviceToken p _ -> J.object ["pushProvider" .= decodeLatin1 (strEncode p)] + +-- WPDeviceToken p t -> J.object ["pushProvider" .= decodeLatin1 (strEncode p), "token" .= toJSON t] instance FromJSON DeviceToken where - parseJSON = J.withObject "DeviceToken" $ \o -> do - pp <- strDecode . encodeUtf8 <$?> o .: "pushProvider" - t <- encodeUtf8 <$> o .: "token" - pure $ DeviceToken pp t + parseJSON = J.withObject "DeviceToken" $ \o -> + (strDecode . encodeUtf8 <$?> o .: "pushProvider") >>= \case + PPAPNS p -> APNSDeviceToken p . encodeUtf8 <$> (o .: "token") + PPWP _ -> fail "FromJSON not implemented for WPDeviceToken" + +-- | Returns fields for the device token (pushProvider, token) +-- TODO [webpush] save token as separate fields +deviceTokenFields :: DeviceToken -> (PushProvider, ByteString) +deviceTokenFields dt = case dt of + APNSDeviceToken p t -> (PPAPNS p, t) + WPDeviceToken p t -> (PPWP p, strEncode t) + +-- | Returns the device token from the fields (pushProvider, token) +deviceToken' :: PushProvider -> ByteString -> DeviceToken +deviceToken' pp t = case pp of + PPAPNS p -> APNSDeviceToken p t + PPWP p -> WPDeviceToken p <$> either error id $ strDecode t + +wpRequest :: MonadFail m => DeviceToken -> m Request +wpRequest (APNSDeviceToken _ _) = fail "Invalid device token" +wpRequest (WPDeviceToken (WPP s) param) = do + let endpoint = strEncode s <> wpPath param + case parseUrlThrow $ B.unpack endpoint of + Left _ -> fail "Invalid URL" + Right r -> pure r -- List of PNMessageData uses semicolon-separated encoding instead of strEncode, -- because strEncode of NonEmpty list uses comma for separator, diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 7e8acac818..46258c7f71 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -629,18 +629,18 @@ showServer' = decodeLatin1 . strEncode . host ntfPush :: NtfPushServer -> M () ntfPush s@NtfPushServer {pushQ} = forever $ do - (srvHost_, tkn@NtfTknRec {ntfTknId, token = t@(DeviceToken pp _), tknStatus}, ntf) <- atomically (readTBQueue pushQ) - liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp) + (srvHost_, tkn@NtfTknRec {ntfTknId, token = t, tknStatus}, ntf) <- atomically (readTBQueue pushQ) + logDebug $ "sending push notification to " <> tshow (tokenPushProvider t) st <- asks store case ntf of PNVerification _ -> - liftIO (deliverNotification st pp tkn ntf) >>= \case + liftIO (deliverNotification st tkn ntf) >>= \case Right _ -> do void $ liftIO $ setTknStatusConfirmed st tkn incNtfStatT t ntfVrfDelivered Left _ -> incNtfStatT t ntfVrfFailed PNCheckMessages -> do - liftIO (deliverNotification st pp tkn ntf) >>= \case + liftIO (deliverNotification st tkn ntf) >>= \case Right _ -> do void $ liftIO $ updateTokenCronSentAt st ntfTknId . systemSeconds =<< getSystemTime incNtfStatT t ntfCronDelivered @@ -648,7 +648,7 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do PNMessage {} -> checkActiveTkn tknStatus $ do stats <- asks serverStats liftIO $ updatePeriodStats (activeTokens stats) ntfTknId - liftIO (deliverNotification st pp tkn ntf) >>= \case + liftIO (deliverNotification st tkn ntf) >>= \case Left _ -> do incNtfStatT t ntfFailed liftIO $ mapM_ (`incServerStat` ntfFailedOwn stats) srvHost_ @@ -661,8 +661,8 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do checkActiveTkn status action | status == NTActive = action | otherwise = liftIO $ logError "bad notification token status" - deliverNotification :: NtfPostgresStore -> PushProvider -> NtfTknRec -> PushNotification -> IO (Either PushProviderError ()) - deliverNotification st pp tkn@NtfTknRec {ntfTknId} ntf = do + deliverNotification :: NtfPostgresStore -> NtfTknRec -> PushNotification -> IO (Either PushProviderError ()) + deliverNotification st tkn@NtfTknRec {ntfTknId, token} ntf = do deliver <- getPushClient s pp runExceptT (deliver tkn ntf) >>= \case Right _ -> pure $ Right () @@ -675,7 +675,10 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do void $ updateTknStatus st tkn $ NTInvalid $ Just r err e PPPermanentError -> err e + PPInvalidPusher -> err e + _ -> err e where + pp = tokenPushProvider token retryDeliver :: IO (Either PushProviderError ()) retryDeliver = do deliver <- newPushClient s pp @@ -905,7 +908,7 @@ withNtfStore stAction continue = do Right a -> continue a incNtfStatT :: DeviceToken -> (NtfServerStats -> IORef Int) -> M () -incNtfStatT (DeviceToken PPApnsNull _) _ = pure () +incNtfStatT (APNSDeviceToken PPApnsNull _) _ = pure () incNtfStatT _ statSel = incNtfStat statSel {-# INLINE incNtfStatT #-} diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 7ed258b9a0..83f9994614 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} @@ -12,12 +12,15 @@ import Control.Concurrent (ThreadId) import Control.Logger.Simple import Control.Monad import Crypto.Random +import Data.IORef (newIORef) import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import qualified Data.Text as T import Data.Time.Clock (getCurrentTime) import Data.Time.Clock.System (SystemTime) import qualified Data.X509.Validation as XV +import Network.HTTP.Client (Manager, ManagerSettings (..), Request (..), newManager) +import Network.HTTP.Client.TLS (tlsManagerSettings) import Network.Socket import qualified Network.TLS as TLS import Numeric.Natural @@ -25,7 +28,9 @@ import Simplex.Messaging.Client (ProtocolClientConfig (..)) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Server.Push import Simplex.Messaging.Notifications.Server.Push.APNS +import Simplex.Messaging.Notifications.Server.Push.WebPush (WebPushClient (..), WebPushConfig, wpPushProviderClient) import Simplex.Messaging.Notifications.Server.Stats import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) import Simplex.Messaging.Notifications.Server.Store.Postgres @@ -45,7 +50,6 @@ import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials, Transport import System.Exit (exitFailure) import System.Mem.Weak (Weak) import UnliftIO.STM -import Simplex.Messaging.Notifications.Server.Push (PushNotification, PushProviderClient) data NtfServerConfig = NtfServerConfig { transports :: [(ServiceName, ASrvTransport, AddHTTP)], @@ -58,6 +62,7 @@ data NtfServerConfig = NtfServerConfig pushQSize :: Natural, smpAgentCfg :: SMPClientAgentConfig, apnsConfig :: APNSPushClientConfig, + wpConfig :: WebPushConfig, subsBatchSize :: Int, inactiveClientExpiration :: Maybe ExpirationConfig, dbStoreConfig :: PostgresStoreCfg, @@ -97,7 +102,7 @@ data NtfEnv = NtfEnv } newNtfServerEnv :: NtfServerConfig -> IO NtfEnv -newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, useServiceCreds, startOptions} = do +newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, wpConfig, dbStoreConfig, ntfCredentials, useServiceCreds, startOptions} = do when (compactLog startOptions) $ compactDbStoreLog $ dbStoreLogPath dbStoreConfig random <- C.newRandom store <- newNtfDbStore dbStoreConfig @@ -113,7 +118,7 @@ newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbSt pure smpAgentCfg {smpCfg = (smpCfg smpAgentCfg) {serviceCredentials = Just service}} else pure smpAgentCfg subscriber <- newNtfSubscriber smpAgentCfg' random - pushServer <- newNtfPushServer pushQSize apnsConfig + pushServer <- newNtfPushServer pushQSize apnsConfig wpConfig serverStats <- newNtfServerStats =<< getCurrentTime pure NtfEnv {config, subscriber, pushServer, store, random, tlsServerCreds, serverIdentity = C.KeyHash fp, serverStats} where @@ -150,22 +155,50 @@ data SMPSubscriber = SMPSubscriber data NtfPushServer = NtfPushServer { pushQ :: TBQueue (Maybe T.Text, NtfTknRec, PushNotification), -- Maybe Text is a hostname of "own" server pushClients :: TMap PushProvider PushProviderClient, - apnsConfig :: APNSPushClientConfig + apnsConfig :: APNSPushClientConfig, + wpConfig :: WebPushConfig } -newNtfPushServer :: Natural -> APNSPushClientConfig -> IO NtfPushServer -newNtfPushServer qSize apnsConfig = do +newNtfPushServer :: Natural -> APNSPushClientConfig -> WebPushConfig -> IO NtfPushServer +newNtfPushServer qSize apnsConfig wpConfig = do pushQ <- newTBQueueIO qSize pushClients <- TM.emptyIO - pure NtfPushServer {pushQ, pushClients, apnsConfig} + pure NtfPushServer {pushQ, pushClients, apnsConfig, wpConfig} newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient -newPushClient NtfPushServer {apnsConfig, pushClients} pp = do - c <- case apnsProviderHost pp of +newPushClient s pp = do + c <- case pp of + PPWP p -> newWPPushClient s p + PPAPNS p -> newAPNSPushClient s p + atomically $ TM.insert pp c $ pushClients s + pure c + +newAPNSPushClient :: NtfPushServer -> APNSProvider -> IO PushProviderClient +newAPNSPushClient NtfPushServer {apnsConfig, pushClients} pp = do + case apnsProviderHost pp of Nothing -> pure $ \_ _ -> pure () Just host -> apnsPushProviderClient <$> createAPNSPushClient host apnsConfig - atomically $ TM.insert pp c pushClients - pure c + +newWPPushClient :: NtfPushServer -> WPProvider -> IO PushProviderClient +newWPPushClient NtfPushServer {wpConfig, pushClients} pp = do + logDebug "New WP Client requested" + -- We use one http manager per push server (which may be used by different clients) + manager <- wpHTTPManager + cache <- newIORef Nothing + random <- C.newRandom + let client = WebPushClient {wpConfig, cache, manager, random} + pure $ wpPushProviderClient client + +wpHTTPManager :: IO Manager +wpHTTPManager = + newManager + tlsManagerSettings + { -- Ideally, we should be able to override the domain resolution to + -- disable requests to non-public IPs. The risk is very limited as + -- we allow https only, and the body is encrypted. Disabling redirections + -- avoids cross-protocol redir (https => http/unix) + managerModifyRequest = \r -> pure r {redirectCount = 0} + } getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient getPushClient s@NtfPushServer {pushClients} pp = diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index de12c33f89..d2c2d393bc 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -11,7 +11,7 @@ module Simplex.Messaging.Notifications.Server.Main where import Control.Logger.Simple (setLogLevel) -import Control.Monad ((<$!>)) +import Control.Monad (unless, void, (<$!>)) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) @@ -31,9 +31,10 @@ import Simplex.Messaging.Client (HostMode (..), NetworkConfig (..), ProtocolClie import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol (NtfTokenId) -import Simplex.Messaging.Notifications.Server (runNtfServer, restoreServerLastNtfs) +import Simplex.Messaging.Notifications.Server (restoreServerLastNtfs, runNtfServer) import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..), defaultInactiveClientExpiration) import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientConfig) +import Simplex.Messaging.Notifications.Server.Push.WebPush (VapidKey (..), WebPushConfig (..), mkVapid) import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) import Simplex.Messaging.Notifications.Server.Store.Postgres (exportNtfDbStore, importNtfSTMStore, newNtfDbStore) import Simplex.Messaging.Notifications.Server.StoreLog (readWriteNtfSTMStore) @@ -55,6 +56,7 @@ import System.Directory (createDirectoryIfMissing, doesFileExist, renameFile) import System.Exit (exitFailure) import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) +import System.Process (readCreateProcess, shell) import Text.Read (readMaybe) ntfServerCLI :: FilePath -> FilePath -> IO () @@ -146,6 +148,7 @@ ntfServerCLI cfgPath logPath = clearDirIfExists logPath createDirectoryIfMissing True cfgPath createDirectoryIfMissing True logPath + _ <- genVapidKey vapidKeyPath let x509cfg = defaultX509Config {commonName = fromMaybe ip fqdn, signAlgorithm} fp <- createServerX509 cfgPath x509cfg let host = fromMaybe (if ip == "127.0.0.1" then "" else ip) fqdn @@ -212,11 +215,13 @@ ntfServerCLI cfgPath logPath = hSetBuffering stdout LineBuffering hSetBuffering stderr LineBuffering fp <- checkSavedFingerprint cfgPath defaultX509Config + vapidKey@VapidKey {fp = vapidFp} <- getVapidKey vapidKeyPath let host = either (const "") T.unpack $ lookupValue "TRANSPORT" "host" ini port = T.unpack $ strictIni "TRANSPORT" "port" ini - cfg@NtfServerConfig {transports} = serverConfig + cfg@NtfServerConfig {transports} = serverConfig vapidKey srv = ProtoServerWithAuth (NtfServer [THDomainName host] (if port == "443" then "" else port) (C.KeyHash fp)) Nothing printServiceInfo serverVersion srv + B.putStrLn $ "VAPID: " <> vapidFp printNtfServerConfig transports dbStoreConfig runNtfServer cfg where @@ -230,7 +235,7 @@ ntfServerCLI cfgPath logPath = confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini } - serverConfig = + serverConfig vapidKey = NtfServerConfig { transports = iniTransports ini, controlPort = either (const Nothing) (Just . T.unpack) $ lookupValue "TRANSPORT" "control_port" ini, @@ -258,6 +263,11 @@ ntfServerCLI cfgPath logPath = persistErrorInterval = 0 -- seconds }, apnsConfig = defaultAPNSPushClientConfig, + wpConfig = + WebPushConfig + { vapidKey, + paddedNtfLength = 3072 + }, subsBatchSize = 900, inactiveClientExpiration = settingIsOn "INACTIVE_CLIENTS" "disconnect" ini @@ -294,6 +304,7 @@ ntfServerCLI cfgPath logPath = putStrLn $ "Error: both " <> storeLogFilePath <> " file and " <> B.unpack schema <> " schema are present (database: " <> B.unpack connstr <> ")." putStrLn "Configure notification server storage." exitFailure + vapidKeyPath = combine cfgPath "vapid.privkey" printNtfServerConfig :: [(ServiceName, ASrvTransport, AddHTTP)] -> PostgresStoreCfg -> IO () printNtfServerConfig transports PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}, dbStoreLogPath} = do @@ -350,18 +361,21 @@ cliCommandP cfgPath logPath iniFile = skipTokensP = option strParse - ( long "skip-tokens" - <> help "Skip tokens during import" - <> value S.empty - ) + ( long "skip-tokens" + <> help "Skip tokens during import" + <> value S.empty + ) initP :: Parser InitOptions initP = do enableStoreLog <- - flag' False + flag' + False ( long "disable-store-log" <> help "Disable store log for persistence (enabled by default)" ) - <|> flag True True + <|> flag + True + True ( long "store-log" <> short 'l' <> help "Enable store log for persistence (DEPRECATED, enabled by default)" @@ -395,3 +409,19 @@ cliCommandP cfgPath logPath iniFile = <> metavar "FQDN" ) pure InitOptions {enableStoreLog, dbOptions, signAlgorithm, ip, fqdn} + +genVapidKey :: FilePath -> IO VapidKey +genVapidKey file = do + cfgExists <- doesFileExist file + unless cfgExists $ run $ "openssl ecparam -name prime256v1 -genkey -noout -out " <> file + key <- C.readECPrivateKey file + pure $ mkVapid key + where + run cmd = void $ readCreateProcess (shell cmd) "" + +getVapidKey :: FilePath -> IO VapidKey +getVapidKey file = do + cfgExists <- doesFileExist file + unless cfgExists $ error $ "VAPID key not found: " <> file + key <- C.readECPrivateKey file + pure $ mkVapid key diff --git a/src/Simplex/Messaging/Notifications/Server/Push.hs b/src/Simplex/Messaging/Notifications/Server/Push.hs index edb671212b..ff21de2d4a 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push.hs @@ -10,6 +10,8 @@ module Simplex.Messaging.Notifications.Server.Push where +import Control.Exception (Exception) +import Control.Monad.Except (ExceptT) import Crypto.Hash.Algorithms (SHA256 (..)) import qualified Crypto.PubKey.ECC.ECDSA as EC import qualified Crypto.PubKey.ECC.Types as ECT @@ -28,24 +30,30 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Text (Text) import Data.Time.Clock.System import qualified Data.X509 as X +import GHC.Exception (SomeException) +import Network.HTTP.Types (Status) +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Server.Store.Types (NtfTknRec) import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.Transport.HTTP2.Client (HTTP2ClientError) -import qualified Simplex.Messaging.Crypto as C -import Network.HTTP.Types (Status) -import Control.Exception (Exception) -import Simplex.Messaging.Notifications.Server.Store.Types (NtfTknRec) -import Control.Monad.Except (ExceptT) data JWTHeader = JWTHeader - { alg :: Text, -- key algorithm, ES256 for APNS - kid :: Text -- key ID + { typ :: Text, -- "JWT" + alg :: Text, -- key algorithm, ES256 for APNS + kid :: Maybe Text -- key ID } deriving (Show) +mkJWTHeader :: Text -> Maybe Text -> JWTHeader +mkJWTHeader alg kid = JWTHeader {typ = "JWT", alg, kid} + data JWTClaims = JWTClaims - { iss :: Text, -- issuer, team ID for APNS - iat :: Int64 -- issue time, seconds from epoch + { iss :: Maybe Text, -- issuer, team ID for APNS + iat :: Maybe Int64, -- issue time, seconds from epoch for APNS + exp :: Maybe Int64, -- expired time, seconds from epoch for web push + aud :: Maybe Text, -- audience, for web push + sub :: Maybe Text -- subject, to be inform if there is an issue, for web push } deriving (Show) @@ -55,7 +63,16 @@ data JWTToken = JWTToken JWTHeader JWTClaims mkJWTToken :: JWTHeader -> Text -> IO JWTToken mkJWTToken hdr iss = do iat <- systemSeconds <$> getSystemTime - pure $ JWTToken hdr JWTClaims {iss, iat} + pure $ JWTToken hdr $ jwtClaims iat + where + jwtClaims iat = + JWTClaims + { iss = Just iss, + iat = Just iat, + exp = Nothing, + aud = Nothing, + sub = Nothing + } type SignedJWTToken = ByteString @@ -63,15 +80,23 @@ $(JQ.deriveToJSON defaultJSON ''JWTHeader) $(JQ.deriveToJSON defaultJSON ''JWTClaims) -signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken -signedJWTToken pk (JWTToken hdr claims) = do +signedJWTToken_ :: (EC.Signature -> ByteString) -> EC.PrivateKey -> JWTToken -> IO SignedJWTToken +signedJWTToken_ serialize pk (JWTToken hdr claims) = do let hc = jwtEncode hdr <> "." <> jwtEncode claims sig <- EC.sign pk SHA256 hc - pure $ hc <> "." <> serialize sig + pure $ hc <> "." <> U.encodeUnpadded (serialize sig) where jwtEncode :: ToJSON a => a -> ByteString jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode - serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence] + +signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken +signedJWTToken = signedJWTToken_ $ \sig -> + encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence] + +-- | Does it work with APNS ? +signedJWTTokenRaw :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken +signedJWTTokenRaw = signedJWTToken_ $ \sig -> + C.encodeBigInt (EC.sign_r sig) <> C.encodeBigInt (EC.sign_s sig) readECPrivateKey :: FilePath -> IO EC.PrivateKey readECPrivateKey f = do @@ -93,6 +118,11 @@ data PushProviderError | PPTokenInvalid NTInvalidReason | PPRetryLater | PPPermanentError + | PPInvalidPusher + | PPWPInvalidUrl + | PPWPRemovedEndpoint + | PPWPRequestTooLong + | PPWPOtherError SomeException deriving (Show, Exception) type PushProviderClient = NtfTknRec -> PushNotification -> ExceptT PushProviderError IO () diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs index 2337fa7fda..929360b53a 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs @@ -124,7 +124,7 @@ data APNSPushClientConfig = APNSPushClientConfig caStoreFile :: FilePath } -apnsProviderHost :: PushProvider -> Maybe HostName +apnsProviderHost :: APNSProvider -> Maybe HostName apnsProviderHost = \case PPApnsNull -> Nothing PPApnsTest -> Just "localhost" @@ -160,9 +160,9 @@ createAPNSPushClient :: HostName -> APNSPushClientConfig -> IO APNSPushClient createAPNSPushClient apnsHost apnsCfg@APNSPushClientConfig {authKeyFileEnv, authKeyAlg, authKeyIdEnv, appTeamId} = do https2Client <- newTVarIO Nothing void $ connectHTTPS2 apnsHost apnsCfg https2Client - privateKey <- readECPrivateKey =<< getEnv authKeyFileEnv + privateKey <- C.readECPrivateKey =<< getEnv authKeyFileEnv authKeyId <- T.pack <$> getEnv authKeyIdEnv - let jwtHeader = JWTHeader {alg = authKeyAlg, kid = authKeyId} + let jwtHeader = mkJWTHeader authKeyAlg (Just authKeyId) jwtToken <- newTVarIO =<< mkApnsJWTToken appTeamId jwtHeader privateKey nonceDrg <- C.newRandom pure APNSPushClient {https2Client, privateKey, jwtHeader, jwtToken, nonceDrg, apnsHost, apnsCfg} @@ -178,7 +178,8 @@ getApnsJWTToken APNSPushClient {apnsCfg = APNSPushClientConfig {appTeamId, token atomically $ writeTVar jwtToken t pure signedJWT' where - jwtTokenAge (JWTToken _ JWTClaims {iat}) = subtract iat . systemSeconds <$> getSystemTime + jwtTokenAge (JWTToken _ JWTClaims {iat = Just iat}) = subtract iat . systemSeconds <$> getSystemTime + jwtTokenAge (JWTToken _ JWTClaims {iat = Nothing}) = pure maxBound :: IO Int64 mkApnsJWTToken :: Text -> JWTHeader -> EC.PrivateKey -> IO (JWTToken, SignedJWTToken) mkApnsJWTToken appTeamId jwtHeader privateKey = do @@ -255,8 +256,10 @@ data APNSErrorResponse = APNSErrorResponse {reason :: Text} $(JQ.deriveFromJSON defaultJSON ''APNSErrorResponse) +-- TODO [webpush] change type accept token components so it only allows APNS token apnsPushProviderClient :: APNSPushClient -> PushProviderClient -apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknRec {token = DeviceToken _ tknStr} pn = do +apnsPushProviderClient _ NtfTknRec {token = WPDeviceToken _ _} _ = throwE PPInvalidPusher +apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknRec {token = APNSDeviceToken _ tknStr} pn = do http2 <- liftHTTPS2 $ getApnsHTTP2Client c nonce <- atomically $ C.randomCbNonce nonceDrg apnsNtf <- liftEither $ first PPCryptoError $ apnsNotification tkn nonce (paddedNtfLength apnsCfg) pn diff --git a/src/Simplex/Messaging/Notifications/Server/Push/WebPush.hs b/src/Simplex/Messaging/Notifications/Server/Push/WebPush.hs new file mode 100644 index 0000000000..d6a656d864 --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Server/Push/WebPush.hs @@ -0,0 +1,225 @@ +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +{-# HLINT ignore "Use newtype instead of data" #-} + +module Simplex.Messaging.Notifications.Server.Push.WebPush where + +import Control.Exception (SomeException, fromException, try) +import Control.Logger.Simple (logDebug) +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class (liftIO) +import Control.Monad.Trans.Except (throwE) +import qualified Crypto.Cipher.Types as CT +import Crypto.Hash.Algorithms (SHA256) +import qualified Crypto.MAC.HMAC as HMAC +import qualified Crypto.PubKey.ECC.DH as ECDH +import qualified Crypto.PubKey.ECC.ECDSA as ECDSA +import qualified Crypto.PubKey.ECC.Types as ECC +import Crypto.Random (ChaChaDRG, getRandomBytes) +import Data.Aeson ((.=)) +import qualified Data.Aeson as J +import qualified Data.Binary as Bin +import qualified Data.ByteArray as BA +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import qualified Data.ByteString.Base64.URL as B64 +import qualified Data.ByteString.Lazy as LB +import Data.IORef +import Data.Int (Int64) +import Data.Text (Text) +import qualified Data.Text.Encoding as T +import Data.Time.Clock.System (getSystemTime, systemSeconds) +import Network.HTTP.Client +import qualified Network.HTTP.Types as N +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfRegCode (..), WPAuth (..), WPKey (..), WPP256dh (..), WPTokenParams (..), encodePNMessages, wpAud, wpRequest) +import Simplex.Messaging.Notifications.Server.Push +import Simplex.Messaging.Notifications.Server.Store.Types +import Simplex.Messaging.Util (liftError', safeDecodeUtf8, tshow) +import UnliftIO.STM + +-- | Vapid +-- | fp: fingerprint, base64url encoded without padding +-- | key: privkey +data VapidKey = VapidKey + { key :: ECDSA.PrivateKey, + fp :: ByteString + } + deriving (Eq, Show) + +mkVapid :: ECDSA.PrivateKey -> VapidKey +mkVapid key = VapidKey {key, fp} + where + fp = B64.encodeUnpadded $ C.uncompressEncodePoint $ ECDH.calculatePublic (ECC.getCurveByName ECC.SEC_p256r1) $ ECDSA.private_d key + +data WebPushClient = WebPushClient + { wpConfig :: WebPushConfig, + cache :: IORef (Maybe WPCache), + manager :: Manager, + random :: TVar ChaChaDRG + } + +data WebPushConfig = WebPushConfig + { vapidKey :: VapidKey, + paddedNtfLength :: Int + } + +data WPCache = WPCache + { vapidHeader :: ByteString, + expire :: Int64 + } + +getVapidHeader :: VapidKey -> IORef (Maybe WPCache) -> ByteString -> IO ByteString +getVapidHeader vapidK cache uriAuthority = do + h <- readIORef cache + now <- systemSeconds <$> getSystemTime + case h of + Nothing -> newCacheEntry now + -- if it expires in 1 min, then we renew - for safety + Just entry -> + if expire entry > now + 60 + then pure $ vapidHeader entry + else newCacheEntry now + where + newCacheEntry :: Int64 -> IO ByteString + newCacheEntry now = do + -- The new entry expires in one hour + let expire = now + 3600 + vapidHeader <- mkVapidHeader vapidK uriAuthority expire + let entry = Just WPCache {vapidHeader, expire} + atomicWriteIORef cache entry + pure vapidHeader + +-- | With time in input for the tests +getVapidHeader' :: Int64 -> VapidKey -> IORef (Maybe WPCache) -> ByteString -> IO ByteString +getVapidHeader' now vapidK cache uriAuthority = do + h <- readIORef cache + case h of + Nothing -> newCacheEntry + Just entry -> + if expire entry > now + then pure $ vapidHeader entry + else newCacheEntry + where + newCacheEntry :: IO ByteString + newCacheEntry = do + -- The new entry expires in one hour + let expire = now + 3600 + vapidHeader <- mkVapidHeader vapidK uriAuthority expire + let entry = Just WPCache {vapidHeader, expire} + atomicWriteIORef cache entry + pure vapidHeader + +-- | mkVapidHeader -> vapid -> endpoint -> expire -> vapid header +mkVapidHeader :: VapidKey -> ByteString -> Int64 -> IO ByteString +mkVapidHeader VapidKey {key, fp} uriAuthority expire = do + let jwtHeader = mkJWTHeader "ES256" Nothing + jwtClaims = + JWTClaims + { iss = Nothing, + iat = Nothing, + exp = Just expire, + aud = Just $ T.decodeUtf8 $ "https://" <> uriAuthority, + sub = Just "https://github.com/simplex-chat/simplexmq/" + } + jwt = JWTToken jwtHeader jwtClaims + signedToken <- signedJWTTokenRaw key jwt + pure $ "vapid t=" <> signedToken <> ",k=" <> fp + +wpPushProviderClient :: WebPushClient -> PushProviderClient +wpPushProviderClient _ NtfTknRec {token = APNSDeviceToken _ _} _ = throwE PPInvalidPusher +wpPushProviderClient c@WebPushClient {wpConfig, cache, manager} tkn@NtfTknRec {token = token@(WPDeviceToken pp params)} pn = do + -- TODO [webpush] this function should accept type that is restricted to WP token (so, possibly WPProvider and WPTokenParams) + -- parsing will happen in DeviceToken parser, so it won't fail here + r <- wpRequest token + vapidH <- liftError' toPPWPError $ try $ getVapidHeader (vapidKey wpConfig) cache $ wpAud pp + logDebug $ "Web Push request to " <> tshow (host r) + encBody <- withExceptT PPCryptoError $ wpEncrypt c tkn params pn + let requestHeaders = + [ ("TTL", "2592000"), -- 30 days + ("Urgency", "high"), + ("Content-Encoding", "aes128gcm"), + ("Authorization", vapidH) + -- TODO: topic for pings and interval + ] + req = + r + { method = "POST", + requestHeaders, + requestBody = RequestBodyBS encBody, + redirectCount = 0 + } + void $ liftError' toPPWPError $ try $ httpNoBody req manager + +-- | encrypt :: UA key -> clear -> cipher +-- | https://www.rfc-editor.org/rfc/rfc8291#section-3.4 +wpEncrypt :: WebPushClient -> NtfTknRec -> WPTokenParams -> PushNotification -> ExceptT C.CryptoError IO ByteString +wpEncrypt WebPushClient {wpConfig, random} NtfTknRec {tknDhSecret} params pn = do + salt <- liftIO $ getRandomBytes 16 + asPrivK <- liftIO $ ECDH.generatePrivate $ ECC.getCurveByName ECC.SEC_p256r1 + pn' <- + LB.toStrict . J.encode <$> case pn of + PNVerification (NtfRegCode code) -> do + (nonce, code') <- encrypt code + pure $ J.object ["nonce" .= nonce, "verification" .= code'] + PNMessage msgData -> do + (nonce, msgData') <- encrypt $ encodePNMessages msgData + pure $ J.object ["nonce" .= nonce, "message" .= msgData'] + PNCheckMessages -> pure $ J.object ["checkMessages" .= True] + wpEncrypt' (wpKey params) asPrivK salt pn' + where + encrypt :: ByteString -> ExceptT C.CryptoError IO (C.CbNonce, Text) + encrypt ntfData = do + nonce <- atomically $ C.randomCbNonce random + encData <- liftEither $ C.cbEncrypt tknDhSecret nonce ntfData $ paddedNtfLength wpConfig + pure (nonce, safeDecodeUtf8 $ B64.encode encData) + +-- | encrypt :: UA key -> AS key -> salt -> clear -> cipher +-- | https://www.rfc-editor.org/rfc/rfc8291#section-3.4 +wpEncrypt' :: WPKey -> ECC.PrivateNumber -> ByteString -> ByteString -> ExceptT C.CryptoError IO ByteString +wpEncrypt' WPKey {wpAuth, wpP256dh = WPP256dh uaPubK} asPrivK salt clearT = do + let uaPubKS = C.uncompressEncodePoint uaPubK + let asPubKS = C.uncompressEncodePoint $ ECDH.calculatePublic (ECC.getCurveByName ECC.SEC_p256r1) asPrivK + ecdhSecret = ECDH.getShared (ECC.getCurveByName ECC.SEC_p256r1) asPrivK uaPubK + prkKey = hmac (unWPAuth wpAuth) ecdhSecret + keyInfo = "WebPush: info\0" <> uaPubKS <> asPubKS + ikm = hmac prkKey (keyInfo <> "\x01") + prk = hmac salt ikm + cekInfo = "Content-Encoding: aes128gcm\0" :: ByteString + cek = B.take 16 $ BA.convert $ hmac prk (cekInfo <> "\x01") + nonceInfo = "Content-Encoding: nonce\0" :: ByteString + nonce = B.take 12 $ BA.convert $ hmac prk (nonceInfo <> "\x01") + rs = LB.toStrict $ Bin.encode (4096 :: Bin.Word32) -- with RFC8291, it's ok to always use 4096 because there is only one single record and the final record can be smaller than rs (RFC8188) + idlen = LB.toStrict $ Bin.encode (65 :: Bin.Word8) -- with RFC8291, keyid is the pubkey, so always 65 bytes + header = salt <> rs <> idlen <> asPubKS + iv <- liftEither $ C.gcmIV nonce + -- The last record uses a padding delimiter octet set to the value 0x02 + (C.AuthTag (CT.AuthTag tag), cipherT) <- C.encryptAES128NoPad (C.Key cek) iv $ clearT <> "\x02" + -- Uncomment to see intermediate values, to compare with RFC8291 example + -- liftIO . print $ strEncode (BA.convert ecdhSecret :: ByteString) + -- liftIO . print . strEncode $ B.take 32 $ BA.convert prkKey + -- liftIO . print $ strEncode cek + -- liftIO . print $ strEncode cipherT + pure $ header <> cipherT <> BA.convert tag + where + hmac k v = HMAC.hmac k v :: HMAC.HMAC SHA256 + +toPPWPError :: SomeException -> PushProviderError +toPPWPError e = case fromException e of + Just (InvalidUrlException _ _) -> PPWPInvalidUrl + Just (HttpExceptionRequest _ (StatusCodeException resp _)) -> fromStatusCode (responseStatus resp) ("" :: String) + _ -> PPWPOtherError e + where + fromStatusCode status reason + | status == N.status200 = PPWPRemovedEndpoint + | status == N.status410 = PPWPRemovedEndpoint + | status == N.status413 = PPWPRequestTooLong + | status == N.status429 = PPRetryLater + | status >= N.status500 = PPRetryLater + | otherwise = PPResponseError (Just status) (tshow reason) diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs index 6a53ff4a22..b07efa101b 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -12,7 +12,8 @@ import Text.RawString.QQ (r) ntfServerSchemaMigrations :: [(String, Text, Maybe Text)] ntfServerSchemaMigrations = [ ("20250417_initial", m20250417_initial, Nothing), - ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert) + ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert), + ("20250916_webpush", m20250916_webpush, Just down_m20250916_webpush) ] -- | The list of migrations in ascending order by date @@ -78,7 +79,7 @@ CREATE INDEX idx_last_notifications_token_id_sent_at ON last_notifications(token CREATE INDEX idx_last_notifications_subscription_id ON last_notifications(subscription_id); CREATE UNIQUE INDEX idx_last_notifications_token_subscription ON last_notifications(token_id, subscription_id); - |] + |] m20250517_service_cert :: Text m20250517_service_cert = @@ -89,7 +90,7 @@ ALTER TABLE subscriptions ADD COLUMN ntf_service_assoc BOOLEAN NOT NULL DEFAULT DROP INDEX idx_subscriptions_smp_server_id_status; CREATE INDEX idx_subscriptions_smp_server_id_ntf_service_status ON subscriptions(smp_server_id, ntf_service_assoc, status); - |] + |] down_m20250517_service_cert :: Text down_m20250517_service_cert = @@ -100,4 +101,33 @@ CREATE INDEX idx_subscriptions_smp_server_id_status ON subscriptions(smp_server_ ALTER TABLE smp_servers DROP COLUMN ntf_service_id; ALTER TABLE subscriptions DROP COLUMN ntf_service_assoc; - |] + |] + +m20250916_webpush :: Text +m20250916_webpush = + [r| +CREATE TABLE webpush_servers( + wp_server_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + wp_host TEXT NOT NULL, + wp_port TEXT NOT NULL, + wp_keyhash BYTEA NOT NULL +); + +ALTER TABLE tokens + ADD COLUMN wp_server_id BIGINT REFERENCES webpush_servers ON DELETE RESTRICT ON UPDATE RESTRICT, + ADD COLUMN wp_path TEXT, + ADD COLUMN wp_auth BYTEA, + ADD COLUMN wp_key BYTEA; + |] + +down_m20250916_webpush :: Text +down_m20250916_webpush = + [r| +ALTER TABLE tokens + DROP COLUMN wp_server_id, + DROP COLUMN wp_path, + DROP COLUMN wp_auth, + DROP COLUMN wp_key; + +DROP TABLE webpush_servers; + |] diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 80d946c8b3..c1065ce11a 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -128,8 +128,9 @@ insertNtfTknQuery = |] replaceNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) -replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), tknStatus, tknRegCode = code@(NtfRegCode regCode)} = +replaceNtfToken st NtfTknRec {ntfTknId, token, tknStatus, tknRegCode = code@(NtfRegCode regCode)} = withFastDB "replaceNtfToken" st $ \db -> runExceptT $ do + let (pp, ppToken) = deviceTokenFields token ExceptT $ assertUpdated <$> DB.execute db @@ -143,7 +144,7 @@ replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), ntfTknToRow :: NtfTknRec -> NtfTknRow ntfTknToRow NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} = - let DeviceToken pp ppToken = token + let (pp, ppToken) = deviceTokenFields token NtfRegCode regCode = tknRegCode in (ntfTknId, pp, Binary ppToken, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, Binary regCode, tknCronInterval, tknUpdatedAt) @@ -153,7 +154,8 @@ getNtfToken st tknId = getNtfToken_ st " WHERE token_id = ?" (Only tknId) findNtfTokenRegistration :: NtfPostgresStore -> NewNtfEntity 'Token -> IO (Either ErrorType (Maybe NtfTknRec)) -findNtfTokenRegistration st (NewNtfTkn (DeviceToken pp ppToken) tknVerifyKey _) = +findNtfTokenRegistration st (NewNtfTkn token tknVerifyKey _) = do + let (pp, ppToken) = deviceTokenFields token getNtfToken_ st " WHERE push_provider = ? AND push_provider_token = ? AND verify_key = ?" (pp, Binary ppToken, tknVerifyKey) getNtfToken_ :: ToRow q => NtfPostgresStore -> Query -> q -> IO (Either ErrorType (Maybe NtfTknRec)) @@ -181,7 +183,7 @@ ntfTknQuery = rowToNtfTkn :: NtfTknRow -> NtfTknRec rowToNtfTkn (ntfTknId, pp, Binary ppToken, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, Binary regCode, tknCronInterval, tknUpdatedAt) = - let token = DeviceToken pp ppToken + let token = deviceToken' pp ppToken tknRegCode = NtfRegCode regCode in NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} @@ -376,8 +378,9 @@ setTknStatusConfirmed st NtfTknRec {ntfTknId} = when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId NTConfirmed setTokenActive :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) -setTokenActive st tkn@NtfTknRec {ntfTknId, token = DeviceToken pp ppToken} = +setTokenActive st tkn@NtfTknRec {ntfTknId, token} = withFastDB' "setTokenActive" st $ \db -> do + let (pp, ppToken) = deviceTokenFields token updateTknStatus_ st db tkn NTActive -- this removes other instances of the same token, e.g. because of repeated token registration attempts tknIds <- diff --git a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql index 3b155fa1a9..535652b682 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql +++ b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql @@ -92,7 +92,31 @@ CREATE TABLE ntf_server.tokens ( reg_code bytea NOT NULL, cron_interval bigint NOT NULL, cron_sent_at bigint, - updated_at bigint + updated_at bigint, + wp_server_id bigint, + wp_path text, + wp_auth bytea, + wp_key bytea +); + + + +CREATE TABLE ntf_server.webpush_servers ( + wp_server_id bigint NOT NULL, + wp_host text NOT NULL, + wp_port text NOT NULL, + wp_keyhash bytea NOT NULL +); + + + +ALTER TABLE ntf_server.webpush_servers ALTER COLUMN wp_server_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME ntf_server.webpush_servers_wp_server_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 ); @@ -122,6 +146,11 @@ ALTER TABLE ONLY ntf_server.tokens +ALTER TABLE ONLY ntf_server.webpush_servers + ADD CONSTRAINT webpush_servers_pkey PRIMARY KEY (wp_server_id); + + + CREATE INDEX idx_last_notifications_subscription_id ON ntf_server.last_notifications USING btree (subscription_id); @@ -178,3 +207,8 @@ ALTER TABLE ONLY ntf_server.subscriptions +ALTER TABLE ONLY ntf_server.tokens + ADD CONSTRAINT tokens_wp_server_id_fkey FOREIGN KEY (wp_server_id) REFERENCES ntf_server.webpush_servers(wp_server_id) ON UPDATE RESTRICT ON DELETE RESTRICT; + + + diff --git a/src/Simplex/Messaging/ServiceScheme.hs b/src/Simplex/Messaging/ServiceScheme.hs index 3cd828aa75..1f9fe22e19 100644 --- a/src/Simplex/Messaging/ServiceScheme.hs +++ b/src/Simplex/Messaging/ServiceScheme.hs @@ -9,6 +9,7 @@ import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Network.Socket (HostName, ServiceName) import Simplex.Messaging.Encoding.String (StrEncoding (..)) +import Simplex.Messaging.Encoding (Encoding(..)) data ServiceScheme = SSSimplex | SSAppServer SrvLoc deriving (Eq, Show) @@ -24,6 +25,12 @@ instance StrEncoding ServiceScheme where data SrvLoc = SrvLoc HostName ServiceName deriving (Eq, Ord, Show) +instance Encoding SrvLoc where + smpEncode (SrvLoc h s) = smpEncode (h, s) + smpP = do + (h, s) <- smpP + pure $ SrvLoc h s + instance StrEncoding SrvLoc where strEncode (SrvLoc host port) = B.pack $ host <> if null port then "" else ':' : port strP = SrvLoc <$> host <*> (port <|> pure "") diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 6a1c5cef99..5b495c7834 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -205,8 +205,9 @@ checkNtfToken c = A.checkNtfToken c NRMInteractive verifyNtfToken :: AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> AE () verifyNtfToken c = A.verifyNtfToken c NRMInteractive -runNtfTestCfg :: HasCallStack => (ASrvTransport, AStoreType) -> AgentMsgId -> AServerConfig -> NtfServerConfig -> AgentConfig -> AgentConfig -> (APNSMockServer -> AgentMsgId -> AgentClient -> AgentClient -> IO ()) -> IO () -runNtfTestCfg (t, msType) baseId smpCfg ntfCfg aCfg bCfg runTest = do +runNtfTestCfg :: HasCallStack => (ASrvTransport, AStoreType) -> AgentMsgId -> AServerConfig -> IO NtfServerConfig -> AgentConfig -> AgentConfig -> (APNSMockServer -> AgentMsgId -> AgentClient -> AgentClient -> IO ()) -> IO () +runNtfTestCfg (t, msType) baseId smpCfg ntfCfg' aCfg bCfg runTest = do + ntfCfg <- ntfCfg' ASSCfg qt mt serverStoreCfg <- pure $ testServerStoreConfig msType let smpCfg' = withServerCfg smpCfg $ \cfg_ -> ASrvCfg qt mt cfg_ {serverStoreCfg} withSmpServerConfigOn t smpCfg' testPort $ \_ -> @@ -218,7 +219,7 @@ runNtfTestCfg (t, msType) baseId smpCfg ntfCfg aCfg bCfg runTest = do testNotificationToken :: APNSMockServer -> IO () testNotificationToken apns = do withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight_ $ do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -242,7 +243,7 @@ v .-> key = do testNtfTokenRepeatRegistration :: APNSMockServer -> IO () testNtfTokenRepeatRegistration apns = do withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight_ $ do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -261,7 +262,7 @@ testNtfTokenRepeatRegistration apns = do testNtfTokenSecondRegistration :: APNSMockServer -> IO () testNtfTokenSecondRegistration apns = withAgentClients2 $ \a a' -> runRight_ $ do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -290,7 +291,7 @@ testNtfTokenSecondRegistration apns = testNtfTokenServerRestart :: ASrvTransport -> APNSMockServer -> IO () testNtfTokenServerRestart t apns = do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" ntfData <- withAgent 1 agentCfg initAgentServers testDB $ \a -> withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic @@ -311,7 +312,7 @@ testNtfTokenServerRestart t apns = do testNtfTokenServerRestartReverify :: ASrvTransport -> APNSMockServer -> IO () testNtfTokenServerRestartReverify t apns = do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a -> do ntfData <- withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic @@ -334,7 +335,7 @@ testNtfTokenServerRestartReverify t apns = do testNtfTokenServerRestartReverifyTimeout :: ASrvTransport -> APNSMockServer -> IO () testNtfTokenServerRestartReverifyTimeout t apns = do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a@AgentClient {agentEnv = Env {store}} -> do (nonce, verification) <- withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic @@ -355,7 +356,7 @@ testNtfTokenServerRestartReverifyTimeout t apns = do SET tkn_status = ?, tkn_action = ? WHERE provider = ? AND device_token = ? |] - (NTConfirmed, Just (NTAVerify code), PPApnsTest, "abcd" :: ByteString) + (NTConfirmed, Just (NTAVerify code), PPAPNS PPApnsTest, "abcd" :: ByteString) Just NtfToken {ntfTknStatus = NTConfirmed, ntfTknAction = Just (NTAVerify _)} <- withTransaction store getSavedNtfToken pure () threadDelay 1500000 @@ -369,7 +370,7 @@ testNtfTokenServerRestartReverifyTimeout t apns = do testNtfTokenServerRestartReregister :: ASrvTransport -> APNSMockServer -> IO () testNtfTokenServerRestartReregister t apns = do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a -> withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic @@ -393,7 +394,7 @@ testNtfTokenServerRestartReregister t apns = do testNtfTokenServerRestartReregisterTimeout :: ASrvTransport -> APNSMockServer -> IO () testNtfTokenServerRestartReregisterTimeout t apns = do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a@AgentClient {agentEnv = Env {store}} -> do withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic @@ -409,7 +410,7 @@ testNtfTokenServerRestartReregisterTimeout t apns = do SET tkn_id = NULL, tkn_dh_secret = NULL, tkn_status = ?, tkn_action = ? WHERE provider = ? AND device_token = ? |] - (NTNew, Just NTARegister, PPApnsTest, "abcd" :: ByteString) + (NTNew, Just NTARegister, PPAPNS PPApnsTest, "abcd" :: ByteString) Just NtfToken {ntfTokenId = Nothing, ntfTknStatus = NTNew, ntfTknAction = Just NTARegister} <- withTransaction store getSavedNtfToken pure () threadDelay 1000000 @@ -434,7 +435,7 @@ getTestNtfTokenPort a = testNtfTokenMultipleServers :: ASrvTransport -> APNSMockServer -> IO () testNtfTokenMultipleServers t apns = do - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers2 testDB $ \a -> withNtfServerThreadOn t ntfTestPort ntfTestDBCfg $ \ntf -> withNtfServerThreadOn t ntfTestPort2 ntfTestDBCfg2 $ \ntf2 -> runRight_ $ do @@ -554,7 +555,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag get alice ##> ("", bobId, CON) get bob ##> ("", aliceId, CON) -- register notification token - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken alice tkn NMInstant APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -607,9 +608,9 @@ testNotificationSubscriptionNewConnection :: HasCallStack => APNSMockServer -> A testNotificationSubscriptionNewConnection apns baseId alice bob = runRight_ $ do -- alice registers notification token - DeviceToken {} <- registerTestToken alice "abcd" NMInstant apns + APNSDeviceToken {} <- registerTestToken alice "abcd" NMInstant apns -- bob registers notification token - DeviceToken {} <- registerTestToken bob "bcde" NMInstant apns + APNSDeviceToken {} <- registerTestToken bob "bcde" NMInstant apns -- establish connection liftIO $ threadDelay 50000 (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe @@ -645,7 +646,7 @@ testNotificationSubscriptionNewConnection apns baseId alice bob = registerTestToken :: AgentClient -> ByteString -> NotificationsMode -> APNSMockServer -> ExceptT AgentErrorType IO DeviceToken registerTestToken a token mode apns = do - let tkn = DeviceToken PPApnsTest token + let tkn = APNSDeviceToken PPApnsTest token NTRegistered <- registerNtfToken a tkn mode Just APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}} <- timeout 1000000 $ getMockNotification apns tkn @@ -931,7 +932,8 @@ testMigrateToServiceSubscriptions :: HasCallStack => (ASrvTransport, AStoreType) testMigrateToServiceSubscriptions ps@(t, msType) = withAgentClients2 $ \a b -> do (c1, c2, c3) <- withSmpServerConfigOn t cfgNoService testPort $ \_ -> do (c1, c2) <- withAPNSMockServer $ \apns -> do - withNtfServerCfg ntfCfgNoService $ \_ -> runRight $ do + cfg' <- ntfCfgNoService + withNtfServerCfg cfg' $ \_ -> runRight $ do _tkn <- registerTestToken a "abcd" NMInstant apns -- create 2 connections with ntfs, test delivery c1 <- testConnectMsg apns a b "hello" @@ -970,27 +972,31 @@ testMigrateToServiceSubscriptions ps@(t, msType) = withAgentClients2 $ \a b -> d serverDOWN a b 5 -- Ntf server does not use server, subscriptions downgrade - c6 <- withAPNSMockServer $ \apns -> withSmpServer ps $ withNtfServerCfg ntfCfgNoService $ \_ -> do - serverUP a b 5 - runRight $ do - testSendMsg apns a b c1 "msg 1" - testSendMsg apns a b c2 "msg 2" - testSendMsg apns a b c3 "msg 3" - testSendMsg apns a b c4 "msg 4" - testSendMsg apns a b c5 "msg 5" - testConnectMsg apns a b "msg 6" + c6 <- withAPNSMockServer $ \apns -> do + cfg' <- ntfCfgNoService + withSmpServer ps $ withNtfServerCfg cfg' $ \_ -> do + serverUP a b 5 + runRight $ do + testSendMsg apns a b c1 "msg 1" + testSendMsg apns a b c2 "msg 2" + testSendMsg apns a b c3 "msg 3" + testSendMsg apns a b c4 "msg 4" + testSendMsg apns a b c5 "msg 5" + testConnectMsg apns a b "msg 6" serverDOWN a b 6 - withAPNSMockServer $ \apns -> withSmpServerConfigOn t cfgNoService testPort $ \_ -> withNtfServerCfg ntfCfgNoService $ \_ -> do - serverUP a b 6 - runRight_ $ do - testSendMsg apns a b c1 "1" - testSendMsg apns a b c2 "2" - testSendMsg apns a b c3 "3" - testSendMsg apns a b c4 "4" - testSendMsg apns a b c5 "5" - testSendMsg apns a b c6 "6" - void $ testConnectMsg apns a b "7" + withAPNSMockServer $ \apns -> do + cfg' <- ntfCfgNoService + withSmpServerConfigOn t cfgNoService testPort $ \_ -> withNtfServerCfg cfg' $ \_ -> do + serverUP a b 6 + runRight_ $ do + testSendMsg apns a b c1 "1" + testSendMsg apns a b c2 "2" + testSendMsg apns a b c3 "3" + testSendMsg apns a b c4 "4" + testSendMsg apns a b c5 "5" + testSendMsg apns a b c6 "6" + void $ testConnectMsg apns a b "7" serverDOWN a b 7 where testConnectMsg apns a b msg = do @@ -1013,7 +1019,9 @@ testMigrateToServiceSubscriptions ps@(t, msType) = withAgentClients2 $ \a b -> d cfgNoService = updateCfg (cfgMS msType) $ \(cfg' :: ServerConfig s) -> let ServerConfig {transportConfig} = cfg' in cfg' {transportConfig = transportConfig {askClientCert = False}} :: ServerConfig s - ntfCfgNoService = ntfServerCfg {useServiceCreds = False, transports = [(ntfTestPort, t, False)]} + ntfCfgNoService = do + cfg' <- ntfServerCfg + pure cfg' {useServiceCreds = False, transports = [(ntfTestPort, t, False)]} testMessage_ :: HasCallStack => APNSMockServer -> AgentClient -> ConnId -> AgentClient -> ConnId -> SMP.MsgBody -> ExceptT AgentErrorType IO () testMessage_ apns a aId b bId msg = do diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 30b648401c..cb59b3ec6d 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -16,6 +16,7 @@ module NtfClient where import Control.Concurrent.STM (retry) +import Control.Exception (throwIO) import Control.Monad import Control.Monad.Except (runExceptT) import Control.Monad.IO.Class @@ -44,8 +45,10 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfResponse) import Simplex.Messaging.Notifications.Server (runNtfServerBlocking) import Simplex.Messaging.Notifications.Server.Env +import Simplex.Messaging.Notifications.Server.Main (getVapidKey) import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Push.APNS.Internal +import Simplex.Messaging.Notifications.Server.Push.WebPush (WebPushConfig (..)) import Simplex.Messaging.Notifications.Transport import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) @@ -124,55 +127,64 @@ testNtfClient client = do Right th -> client th Left e -> error $ show e -ntfServerCfg :: NtfServerConfig -ntfServerCfg = - NtfServerConfig - { transports = [], - controlPort = Nothing, - controlPortUserAuth = Nothing, - controlPortAdminAuth = Nothing, - subIdBytes = 24, - regCodeBytes = 32, - clientQSize = 2, - pushQSize = 2, - smpAgentCfg = defaultSMPClientAgentConfig {persistErrorInterval = 0}, - apnsConfig = - defaultAPNSPushClientConfig - { apnsPort = apnsTestPort, - caStoreFile = "tests/fixtures/ca.crt" - }, - subsBatchSize = 900, - inactiveClientExpiration = Just defaultInactiveClientExpiration, - dbStoreConfig = ntfTestDBCfg, - ntfCredentials = ntfTestServerCredentials, - useServiceCreds = True, - periodicNtfsInterval = 1, - -- stats config - logStatsInterval = Nothing, - logStatsStartTime = 0, - serverStatsLogFile = "tests/ntf-server-stats.daily.log", - serverStatsBackupFile = Nothing, - prometheusInterval = Nothing, - prometheusMetricsFile = ntfTestPrometheusMetricsFile, - ntfServerVRange = supportedServerNTFVRange, - transportConfig = mkTransportServerConfig True (Just alpnSupportedNTFHandshakes) False, - startOptions = defaultStartOptions - } - -ntfServerCfgVPrev :: NtfServerConfig +ntfServerCfg :: IO NtfServerConfig +ntfServerCfg = do + vapidKey <- getVapidKey "tests/fixtures/vapid.privkey" + pure + NtfServerConfig + { transports = [], + controlPort = Nothing, + controlPortUserAuth = Nothing, + controlPortAdminAuth = Nothing, + subIdBytes = 24, + regCodeBytes = 32, + clientQSize = 2, + pushQSize = 2, + smpAgentCfg = defaultSMPClientAgentConfig {persistErrorInterval = 0}, + apnsConfig = + defaultAPNSPushClientConfig + { apnsPort = apnsTestPort, + caStoreFile = "tests/fixtures/ca.crt" + }, + wpConfig = WebPushConfig {vapidKey, paddedNtfLength = 3072}, + subsBatchSize = 900, + inactiveClientExpiration = Just defaultInactiveClientExpiration, + dbStoreConfig = ntfTestDBCfg, + ntfCredentials = ntfTestServerCredentials, + useServiceCreds = True, + periodicNtfsInterval = 1, + -- stats config + logStatsInterval = Nothing, + logStatsStartTime = 0, + serverStatsLogFile = "tests/ntf-server-stats.daily.log", + serverStatsBackupFile = Nothing, + prometheusInterval = Nothing, + prometheusMetricsFile = ntfTestPrometheusMetricsFile, + ntfServerVRange = supportedServerNTFVRange, + transportConfig = mkTransportServerConfig True (Just alpnSupportedNTFHandshakes) False, + startOptions = defaultStartOptions + } + +ntfServerCfgVPrev :: IO NtfServerConfig ntfServerCfgVPrev = ntfServerCfg - { ntfServerVRange = prevRange $ ntfServerVRange ntfServerCfg, + >>= \cfg -> pure $ ntfServerCfgVPrev' cfg + +ntfServerCfgVPrev' :: NtfServerConfig -> NtfServerConfig +ntfServerCfgVPrev' cfg = + cfg + { ntfServerVRange = prevRange $ ntfServerVRange cfg, smpAgentCfg = smpAgentCfg' {smpCfg = smpCfg' {serverVRange = prevRange serverVRange'}} } where - smpAgentCfg' = smpAgentCfg ntfServerCfg + smpAgentCfg' = smpAgentCfg cfg smpCfg' = smpCfg smpAgentCfg' serverVRange' = serverVRange smpCfg' withNtfServerThreadOn :: HasCallStack => ASrvTransport -> ServiceName -> PostgresStoreCfg -> (HasCallStack => ThreadId -> IO a) -> IO a -withNtfServerThreadOn t port' dbStoreConfig = - withNtfServerCfg ntfServerCfg {transports = [(port', t, False)], dbStoreConfig} +withNtfServerThreadOn t port' dbStoreConfig a = + ntfServerCfg >>= \cfg -> + withNtfServerCfg cfg {transports = [(port', t, False)], dbStoreConfig} a withNtfServerCfg :: HasCallStack => NtfServerConfig -> (ThreadId -> IO a) -> IO a withNtfServerCfg cfg@NtfServerConfig {transports} = @@ -293,7 +305,8 @@ getAPNSMockServer config@HTTP2ServerConfig {qSize} = do sendApnsResponse $ APNSRespError N.badRequest400 "bad_request_body" getMockNotification :: MonadIO m => APNSMockServer -> DeviceToken -> m APNSMockRequest -getMockNotification APNSMockServer {notifications} (DeviceToken _ token) = do +getMockNotification _ (WPDeviceToken _ _) = liftIO . throwIO $ userError "Invalid pusher" +getMockNotification APNSMockServer {notifications} (APNSDeviceToken _ token) = do atomically $ TM.lookup token notifications >>= maybe retry readTBQueue getAnyMockNotification :: MonadIO m => APNSMockServer -> m APNSMockRequest diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index a4f0a7d626..c4dd72b24b 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -107,7 +107,7 @@ testNotificationSubscription (ATransport t, msType) createQueue = (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g (tknPub, tknKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g (dhPub, dhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g - let tkn = DeviceToken PPApnsTest "abcd" + let tkn = APNSDeviceToken PPApnsTest "abcd" withAPNSMockServer $ \apns -> smpTest2 t msType $ \rh sh -> ntfTest t $ \nh -> do @@ -160,7 +160,7 @@ testNotificationSubscription (ATransport t, msType) createQueue = (msgBody, "hello") #== "delivered from queue" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, ACK mId1) -- replace token - let tkn' = DeviceToken PPApnsTest "efgh" + let tkn' = APNSDeviceToken PPApnsTest "efgh" RespNtf "7" tId' NROk <- signSendRecvNtf nh tknKey ("7", tId, TRPL tkn') tId `shouldBe` tId' APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData2}} <- @@ -237,7 +237,7 @@ registerToken nh apns token = do g <- C.newRandom (tknPub, tknKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g (dhPub, dhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g - let tkn = DeviceToken PPApnsTest token + let tkn = APNSDeviceToken PPApnsTest token RespNtf "1" NoEntity (NRTknId tId ntfDh) <- signSendRecvNtf nh tknKey ("1", NoEntity, TNEW $ NewNtfTkn tkn tknPub dhPub) APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn diff --git a/tests/NtfWPTests.hs b/tests/NtfWPTests.hs new file mode 100644 index 0000000000..1323eafa1a --- /dev/null +++ b/tests/NtfWPTests.hs @@ -0,0 +1,101 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module NtfWPTests where + +import Control.Monad (unless) +import Control.Monad.Except (runExceptT) +import qualified Crypto.PubKey.ECC.Types as ECC +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import Data.Either (isLeft) +import Data.IORef (newIORef) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Encoding.String (StrEncoding (..)) +import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Server.Main (getVapidKey) +import Simplex.Messaging.Notifications.Server.Push.WebPush (getVapidHeader', wpEncrypt') +import Test.Hspec hiding (fit, it) +import Util + +ntfWPTests :: Spec +ntfWPTests = describe "NTF Protocol" $ do + it "decode WPDeviceToken from string" testWPDeviceTokenStrEncoding + it "decode invalid WPDeviceToken" testInvalidWPDeviceTokenStrEncoding + it "Encrypt RFC8291 example" testWPEncryption + it "Vapid header cache" testVapidCache + +testWPDeviceTokenStrEncoding :: Expectation +testWPDeviceTokenStrEncoding = do + let ts = "webpush https://localhost/secret AQ3VfRX3_F38J3ltcmMVRg BKuw4WxupnnrZHqk6vCwoms4tOpitZMvFdR9eAn54yOPY4q9jpXOpl-Ui_FwbIy8ZbFCnuaS7RnO02ahuL4XxIM" + -- let ts = "apns_null test_ntf_token" + -- let ts = "apns_test 11111111222222223333333344444444" + + let auth = either error id $ strDecode "AQ3VfRX3_F38J3ltcmMVRg" + let pk = either error id $ strDecode "BKuw4WxupnnrZHqk6vCwoms4tOpitZMvFdR9eAn54yOPY4q9jpXOpl-Ui_FwbIy8ZbFCnuaS7RnO02ahuL4XxIM" + let params :: WPTokenParams = either error id $ strDecode "/secret AQ3VfRX3_F38J3ltcmMVRg BKuw4WxupnnrZHqk6vCwoms4tOpitZMvFdR9eAn54yOPY4q9jpXOpl-Ui_FwbIy8ZbFCnuaS7RnO02ahuL4XxIM" + wpPath params `shouldBe` "/secret" + let key = wpKey params + wpAuth key `shouldBe` auth + wpP256dh key `shouldBe` pk + + let pp@(WPP s) :: WPProvider = either error id $ strDecode "webpush https://localhost" + + let parsed = either error id $ strDecode ts + parsed `shouldBe` WPDeviceToken pp params + -- TODO: strEncoding should be base64url _without padding_ + -- strEncode parsed `shouldBe` ts + + strEncode s <> wpPath params `shouldBe` "https://localhost/secret" + +testInvalidWPDeviceTokenStrEncoding :: Expectation +testInvalidWPDeviceTokenStrEncoding = do + -- http-client parser parseUrlThrow is very very lax, + -- e.g "https://#1" is a valid URL. But that is the same parser + -- we use to send the requests, so that's fine. + let ts = "webpush https://localhost:/ AQ3VfRX3_F38J3ltcmMVRg BKuw4WxupnnrZHqk6vCwoms4tOpitZMvFdR9eAn54yOPY4q9jpXOpl-Ui_FwbIy8ZbFCnuaS7RnO02ahuL4XxIM" + t = strDecode ts :: Either String DeviceToken + t `shouldSatisfy` isLeft + +-- | Example from RFC8291 +testWPEncryption :: Expectation +testWPEncryption = do + let clearT :: ByteString = "When I grow up, I want to be a watermelon" + pParams :: WPTokenParams = either error id $ strDecode "/push/JzLQ3raZJfFBR0aqvOMsLrt54w4rJUsV BTBZMqHH6r4Tts7J_aSIgg BCVxsr7N_eNgVRqvHtD0zTZsEc6-VV-JvLexhqUzORcxaOzi6-AYWXvTBHm4bjyPjs7Vd8pZGH6SRpkNtoIAiw4" + salt :: ByteString = either error id $ strDecode "DGv6ra1nlYgDCS1FRnbzlw" + privBS :: ByteString = either error id $ strDecode "yfWPiYE-n46HLnH0KqZOF1fJJU3MYrct3AELtAQ-oRw" + asPriv :: ECC.PrivateNumber <- case C.uncompressDecodePrivateNumber privBS of + Left e -> fail $ "Cannot decode PrivateNumber from b64 " <> show e + Right p -> pure p + mCip <- runExceptT $ wpEncrypt' (wpKey pParams) asPriv salt clearT + cipher <- case mCip of + Left _ -> fail "Cannot encrypt clear text" + Right c -> pure c + strEncode cipher `shouldBe` "DGv6ra1nlYgDCS1FRnbzlwAAEABBBP4z9KsN6nGRTbVYI_c7VJSPQTBtkgcy27mlmlMoZIIgDll6e3vCYLocInmYWAmS6TlzAC8wEqKK6PBru3jl7A_yl95bQpu6cVPTpK4Mqgkf1CXztLVBSt2Ks3oZwbuwXPXLWyouBWLVWGNWQexSgSxsj_Qulcy4a-fN" + +testVapidCache :: Expectation +testVapidCache = do + let wpaud = "https://localhost" + let now = 1761900906 + cache <- newIORef Nothing + vapidKey <- getVapidKey "tests/fixtures/vapid.privkey" + v1 <- getVapidHeader' now vapidKey cache wpaud + v2 <- getVapidHeader' now vapidKey cache wpaud + v1 `shouldBe` v2 + -- we just don't test the signature here + v1 `shouldContainBS` "vapid t=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJleHAiOjE3NjE5MDQ1MDYsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0Iiwic3ViIjoiaHR0cHM6Ly9naXRodWIuY29tL3NpbXBsZXgtY2hhdC9zaW1wbGV4bXEvIn0." + v1 `shouldContainBS` ",k=BIk7ASkEr1A1rJRGXMKi77tAGj3dRouSgZdW6S5pee7a3h7fkvd0OYQixy4yj35UFZt8hd9TwAQiybDK_HJLwJA" + v3 <- getVapidHeader' (now + 3600) vapidKey cache wpaud + v1 `shouldNotBe` v3 + v3 `shouldContainBS` "vapid t=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9." + v3 `shouldContainBS` ",k=BIk7ASkEr1A1rJRGXMKi77tAGj3dRouSgZdW6S5pee7a3h7fkvd0OYQixy4yj35UFZt8hd9TwAQiybDK_HJLwJA" + +shouldContainBS :: ByteString -> ByteString -> Expectation +shouldContainBS actual expected = + unless (expected `B.isInfixOf` actual) $ + expectationFailure $ + "Expected ByteString to contain:\n" + ++ show expected + ++ "\nBut got:\n" + ++ show actual diff --git a/tests/Test.hs b/tests/Test.hs index 3e36e192d6..611a6e2413 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -45,6 +45,7 @@ import AgentTests.SchemaDump (schemaDumpTest) #if defined(dbServerPostgres) import NtfServerTests (ntfServerTests) import NtfClient (ntfTestServerDBConnectInfo, ntfTestStoreDBOpts) +import NtfWPTests (ntfWPTests) import PostgresSchemaDump (postgresSchemaDumpTest) import SMPClient (testServerDBConnectInfo, testStoreDBOpts) import Simplex.Messaging.Notifications.Server.Store.Migrations (ntfServerMigrations) @@ -139,6 +140,7 @@ main = do -- before (pure $ ASType SQSPostgres SMSJournal) smpProxyTests describe "SMP proxy, postgres-only message store" $ before (pure $ ASType SQSPostgres SMSPostgres) smpProxyTests + describe "NTF WP tests" ntfWPTests #endif -- xdescribe "SMP client agent, server jornal message store" $ agentTests (transport @TLS, ASType SQSMemory SMSJournal) describe "SMP client agent, server memory message store" $ agentTests (transport @TLS, ASType SQSMemory SMSMemory) diff --git a/tests/fixtures/vapid.privkey b/tests/fixtures/vapid.privkey new file mode 100644 index 0000000000..294260c2d6 --- /dev/null +++ b/tests/fixtures/vapid.privkey @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIMTAncBq2I7G3KvW4C8Y8Heg2cbcDTobbGFQFnBiA5M/oAoGCCqGSM49 +AwEHoUQDQgAEiTsBKQSvUDWslEZcwqLvu0AaPd1Gi5KBl1bpLml57treHt+S93Q5 +hCLHLjKPflQVm3yF31PABCLJsMr8ckvAkA== +-----END EC PRIVATE KEY-----