Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions persistent-mysql-haskell/Database/Persist/MySQL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ module Database.Persist.MySQL
, copyUnlessNull
, copyUnlessEmpty
, copyUnlessEq
, copyWhenGreater
-- * TLS configuration
, setMySQLConnectInfoTLS
, MySQLTLS.TrustedCAStore(..)
Expand Down Expand Up @@ -1424,6 +1425,12 @@ data HandleUpdateCollision record where
CopyField :: EntityField record typ -> HandleUpdateCollision record
-- | Only copy the field if it is not equal to the provided value.
CopyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record
-- | Only copy the field if the new value is greater than the existing value.
CopyWhenGreater :: PersistField typ => EntityField record typ -> HandleUpdateCollision record

-- | Copy the field into the database only if its new value is greater than existing one.
copyWhenGreater :: PersistField typ => EntityField record typ -> HandleUpdateCollision record
copyWhenGreater = CopyWhenGreater

-- | Copy the field into the database only if the value in the
-- corresponding record is non-@NULL@.
Expand Down Expand Up @@ -1596,10 +1603,18 @@ mkBulkInsertQuery
mkBulkInsertQuery records fieldValues updates =
(q, recordValues <> updsValues <> copyUnlessValues)
where
mfieldDef x = case x of
CopyField rec -> Right (fieldDbToText (persistFieldDef rec))
CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val)
(fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues
collectFieldDef !x (compareCopyList, toMaybeCopyList, fieldNamesList) = case x of
CopyField rec ->
let !fieldText = fieldDbToText (persistFieldDef rec)
in (compareCopyList, toMaybeCopyList, fieldText : fieldNamesList)
CopyUnlessEq rec val ->
let !fieldText = (fieldDbToText (persistFieldDef rec), toPersistValue val)
in (compareCopyList, fieldText : toMaybeCopyList, fieldNamesList)
CopyWhenGreater rec ->
let !fieldText = fieldDbToText (persistFieldDef rec)
in (fieldText : compareCopyList, toMaybeCopyList, fieldNamesList)
(fieldsCompareCopy, fieldsToMaybeCopy, updateFieldNames)
= foldr collectFieldDef ([], [], []) fieldValues
fieldDbToText = T.pack . escapeF . fieldDB
entityDef' = entityDef $ either id (map entityVal) records
firstField = case entityFieldNames of
Expand All @@ -1623,11 +1638,19 @@ mkBulkInsertQuery records fieldValues updates =
, n
, ")"
]
mkCompareFieldSet n = T.concat
[ n
, "=IF("
, n, " < VALUES(", n, "),"
, "VALUES(", n, "),", n
, ")"
]
condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy
compareFieldSets = map mkCompareFieldSet fieldsCompareCopy
fieldSets = map (\n -> T.concat [n, "=VALUES(", n, ")"]) updateFieldNames
upds = map (Util.mkUpdateText' (pack . escapeF) id) updates
updsValues = map (\(Update _ val _) -> toPersistValue val) updates
updateText = case fieldSets <> upds <> condFieldSets of
updateText = case fieldSets <> upds <> condFieldSets <> compareFieldSets of
[] -> T.concat [firstField, "=", firstField]
xs -> Util.commaSeparated xs
q = T.concat
Expand Down
2 changes: 1 addition & 1 deletion persistent-mysql-haskell/persistent-mysql-haskell.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ library
, resourcet >= 1.1
, monad-logger
, resource-pool
, mysql-haskell >= 0.8.0.0 && < 1.0
, mysql-haskell >= 0.8.0.0 && < 1.2
-- keep the following in sync with @mysql-haskell@ .cabal
, io-streams >= 1.2 && < 2.0
, time >= 1.5.0
Expand Down
12 changes: 12 additions & 0 deletions persistent-mysql-haskell/test/InsertDuplicateUpdate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ specs = describe "DuplicateKeyUpdate" $ do
[]
dbItems <- sort . fmap entityVal <$> selectList [] []
dbItems @== sort postUpdate
it "only copies when newer value is greater than existing one" $ db $ do
deleteWhere ([] :: [Filter Item])
insertMany_ items
let newItems = map (\i -> i { itemQuantity = Just 0, itemPrice = fmap (*2) (itemPrice i) }) items
postUpdate = map (\i -> i { itemPrice = fmap (*2) (itemPrice i) }) items
insertManyOnDuplicateKeyUpdate
newItems
[ copyWhenGreater ItemPrice
]
[]
dbItems <- sort . fmap entityVal <$> selectList [] []
dbItems @== sort postUpdate
it "inserts without modifying existing records if no updates specified" $ db $ do
let newItem = Item "item3" "hi friends!" Nothing Nothing
deleteWhere ([] :: [Filter Item])
Expand Down
1 change: 1 addition & 0 deletions persistent-mysql-haskell/test/main.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
Expand Down
Loading