From 44db7eb8437d1f28a6cd86af02dea418547dc14c Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Wed, 29 Oct 2025 16:20:08 +0100 Subject: [PATCH] feat: `ON DUPLICATE UPDATE x=IF(x < VALUES(x), VALUES(x), x)` --- .../Database/Persist/MySQL.hs | 33 ++++++++++++++++--- .../persistent-mysql-haskell.cabal | 2 +- .../test/InsertDuplicateUpdate.hs | 12 +++++++ persistent-mysql-haskell/test/main.hs | 1 + 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/persistent-mysql-haskell/Database/Persist/MySQL.hs b/persistent-mysql-haskell/Database/Persist/MySQL.hs index bf48ff7f3..b0c52ce25 100644 --- a/persistent-mysql-haskell/Database/Persist/MySQL.hs +++ b/persistent-mysql-haskell/Database/Persist/MySQL.hs @@ -33,6 +33,7 @@ module Database.Persist.MySQL , copyUnlessNull , copyUnlessEmpty , copyUnlessEq + , copyWhenGreater -- * TLS configuration , setMySQLConnectInfoTLS , MySQLTLS.TrustedCAStore(..) @@ -1424,6 +1425,12 @@ data HandleUpdateCollision record where CopyField :: EntityField record typ -> HandleUpdateCollision record -- | Only copy the field if it is not equal to the provided value. CopyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record + -- | Only copy the field if the new value is greater than the existing value. + CopyWhenGreater :: PersistField typ => EntityField record typ -> HandleUpdateCollision record + +-- | Copy the field into the database only if its new value is greater than existing one. +copyWhenGreater :: PersistField typ => EntityField record typ -> HandleUpdateCollision record +copyWhenGreater = CopyWhenGreater -- | Copy the field into the database only if the value in the -- corresponding record is non-@NULL@. @@ -1596,10 +1603,18 @@ mkBulkInsertQuery mkBulkInsertQuery records fieldValues updates = (q, recordValues <> updsValues <> copyUnlessValues) where - mfieldDef x = case x of - CopyField rec -> Right (fieldDbToText (persistFieldDef rec)) - CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val) - (fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues + collectFieldDef !x (compareCopyList, toMaybeCopyList, fieldNamesList) = case x of + CopyField rec -> + let !fieldText = fieldDbToText (persistFieldDef rec) + in (compareCopyList, toMaybeCopyList, fieldText : fieldNamesList) + CopyUnlessEq rec val -> + let !fieldText = (fieldDbToText (persistFieldDef rec), toPersistValue val) + in (compareCopyList, fieldText : toMaybeCopyList, fieldNamesList) + CopyWhenGreater rec -> + let !fieldText = fieldDbToText (persistFieldDef rec) + in (fieldText : compareCopyList, toMaybeCopyList, fieldNamesList) + (fieldsCompareCopy, fieldsToMaybeCopy, updateFieldNames) + = foldr collectFieldDef ([], [], []) fieldValues fieldDbToText = T.pack . escapeF . fieldDB entityDef' = entityDef $ either id (map entityVal) records firstField = case entityFieldNames of @@ -1623,11 +1638,19 @@ mkBulkInsertQuery records fieldValues updates = , n , ")" ] + mkCompareFieldSet n = T.concat + [ n + , "=IF(" + , n, " < VALUES(", n, ")," + , "VALUES(", n, "),", n + , ")" + ] condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy + compareFieldSets = map mkCompareFieldSet fieldsCompareCopy fieldSets = map (\n -> T.concat [n, "=VALUES(", n, ")"]) updateFieldNames upds = map (Util.mkUpdateText' (pack . escapeF) id) updates updsValues = map (\(Update _ val _) -> toPersistValue val) updates - updateText = case fieldSets <> upds <> condFieldSets of + updateText = case fieldSets <> upds <> condFieldSets <> compareFieldSets of [] -> T.concat [firstField, "=", firstField] xs -> Util.commaSeparated xs q = T.concat diff --git a/persistent-mysql-haskell/persistent-mysql-haskell.cabal b/persistent-mysql-haskell/persistent-mysql-haskell.cabal index bcf8f99da..fbd6e00d1 100644 --- a/persistent-mysql-haskell/persistent-mysql-haskell.cabal +++ b/persistent-mysql-haskell/persistent-mysql-haskell.cabal @@ -38,7 +38,7 @@ library , resourcet >= 1.1 , monad-logger , resource-pool - , mysql-haskell >= 0.8.0.0 && < 1.0 + , mysql-haskell >= 0.8.0.0 && < 1.2 -- keep the following in sync with @mysql-haskell@ .cabal , io-streams >= 1.2 && < 2.0 , time >= 1.5.0 diff --git a/persistent-mysql-haskell/test/InsertDuplicateUpdate.hs b/persistent-mysql-haskell/test/InsertDuplicateUpdate.hs index e98a12db5..15a69f54a 100644 --- a/persistent-mysql-haskell/test/InsertDuplicateUpdate.hs +++ b/persistent-mysql-haskell/test/InsertDuplicateUpdate.hs @@ -101,6 +101,18 @@ specs = describe "DuplicateKeyUpdate" $ do [] dbItems <- sort . fmap entityVal <$> selectList [] [] dbItems @== sort postUpdate + it "only copies when newer value is greater than existing one" $ db $ do + deleteWhere ([] :: [Filter Item]) + insertMany_ items + let newItems = map (\i -> i { itemQuantity = Just 0, itemPrice = fmap (*2) (itemPrice i) }) items + postUpdate = map (\i -> i { itemPrice = fmap (*2) (itemPrice i) }) items + insertManyOnDuplicateKeyUpdate + newItems + [ copyWhenGreater ItemPrice + ] + [] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort postUpdate it "inserts without modifying existing records if no updates specified" $ db $ do let newItem = Item "item3" "hi friends!" Nothing Nothing deleteWhere ([] :: [Filter Item]) diff --git a/persistent-mysql-haskell/test/main.hs b/persistent-mysql-haskell/test/main.hs index c89e42173..08551a22c 100644 --- a/persistent-mysql-haskell/test/main.hs +++ b/persistent-mysql-haskell/test/main.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-}