diff --git a/dataframe.cabal b/dataframe.cabal index c10646b..11e3794 100644 --- a/dataframe.cabal +++ b/dataframe.cabal @@ -105,7 +105,8 @@ library hashable >= 1.2 && < 2, process ^>= 1.6, snappy-hs ^>= 0.1, - random >= 1.3 && < 2, + random >= 1.2 && < 1.3, + random-shuffle >= 0.0.4 && < 1, regex-tdfa >= 1.3.0 && < 2, scientific >=0.3.1 && <0.4, template-haskell >= 2.0 && < 3, @@ -198,6 +199,7 @@ test-suite tests Operations.Join, Operations.Merge, Operations.ReadCsv, + Operations.Shuffle, Operations.Sort, Operations.Subset, Operations.Statistics, @@ -209,7 +211,6 @@ test-suite tests directory >= 1.3.0.0 && < 2, HUnit ^>= 1.6, QuickCheck >= 2 && < 3, - random >= 1 && < 2, random-shuffle >= 0.0.4 && < 1, random >= 1 && < 2, text >= 2.0 && < 3, diff --git a/examples/examples.cabal b/examples/examples.cabal index c53d878..2026919 100644 --- a/examples/examples.cabal +++ b/examples/examples.cabal @@ -85,7 +85,7 @@ executable chipotle hashable >= 1.2 && < 2, process ^>= 1.6, snappy-hs ^>= 0.1, - random >= 1.3 && < 2, + random >= 1.2 && < 1.3, regex-tdfa >= 1.3.0 && < 2, scientific >=0.3.1 && <0.4, template-haskell >= 2.0 && < 3, @@ -172,7 +172,7 @@ executable california_housing hasktorch >= 0.2.1.6 && < 0.3, process ^>= 1.6, snappy-hs ^>= 0.1, - random >= 1.3 && < 2, + random >= 1.2 && < 1.3, regex-tdfa >= 1.3.0 && < 2, scientific >=0.3.1 && <0.4, template-haskell >= 2.0 && < 3, @@ -259,7 +259,7 @@ executable one_billion_row_challenge hasktorch >= 0.2.1.6 && < 0.3, process ^>= 1.6, snappy-hs ^>= 0.1, - random >= 1.3 && < 2, + random >= 1.2 && < 1.3, regex-tdfa >= 1.3.0 && < 2, scientific >=0.3.1 && <0.4, template-haskell >= 2.0 && < 3, @@ -346,7 +346,7 @@ executable iris hasktorch >= 0.2.1.6 && < 0.3, process ^>= 1.6, snappy-hs ^>= 0.1, - random >= 1.3 && < 2, + random >= 1.2 && < 1.3, regex-tdfa >= 1.3.0 && < 2, scientific >=0.3.1 && <0.4, template-haskell >= 2.0 && < 3, diff --git a/flake.nix b/flake.nix index 0eae984..7c10fab 100644 --- a/flake.nix +++ b/flake.nix @@ -13,8 +13,6 @@ hsPkgs = pkgs.haskellPackages.extend (self: super: { dataframe = self.callCabal2nix "dataframe" ./. { }; - random = pkgs.haskellPackages.callHackage "random" "1.3.1" { }; - time-compat = pkgs.haskell.lib.dontCheck super.time-compat; }); in { diff --git a/src/DataFrame/Operations/Permutation.hs b/src/DataFrame/Operations/Permutation.hs index c2b31da..a381f98 100644 --- a/src/DataFrame/Operations/Permutation.hs +++ b/src/DataFrame/Operations/Permutation.hs @@ -18,6 +18,7 @@ import DataFrame.Internal.Expression import DataFrame.Internal.Row import DataFrame.Operations.Core import System.Random +import System.Random.Shuffle (shuffle') -- | Sort order taken as a parameter by the 'sortBy' function. data SortOrder where @@ -75,4 +76,4 @@ shuffle pureGen df = df{columns = V.map (atIndicesStable indexes) (columns df)} shuffledIndices :: (RandomGen g) => g -> Int -> VU.Vector Int -shuffledIndices pureGen k = VU.fromList (fst (uniformShuffleList [0 .. (k - 1)] pureGen)) +shuffledIndices pureGen k = VU.fromList (shuffle' [0 .. (k - 1)] k pureGen) diff --git a/tests/Main.hs b/tests/Main.hs index 37820ed..426602a 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -29,6 +29,7 @@ import qualified Operations.InsertColumn import qualified Operations.Join import qualified Operations.Merge import qualified Operations.ReadCsv +import qualified Operations.Shuffle import qualified Operations.Sort import qualified Operations.Statistics import qualified Operations.Subset @@ -5120,6 +5121,7 @@ tests = ++ Operations.Join.tests ++ Operations.Merge.tests ++ Operations.ReadCsv.tests + ++ Operations.Shuffle.tests ++ Operations.Sort.tests ++ Operations.Statistics.tests ++ Operations.Take.tests diff --git a/tests/Operations/Shuffle.hs b/tests/Operations/Shuffle.hs new file mode 100644 index 0000000..f1c52f4 --- /dev/null +++ b/tests/Operations/Shuffle.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeApplications #-} + +module Operations.Shuffle where + +import qualified DataFrame as D + +import DataFrame.Operations.Permutation (shuffle) +import System.Random (mkStdGen) +import Test.HUnit (Test (..), assertEqual) + +testDataFrame :: D.DataFrame +testDataFrame = + D.fromNamedColumns + [ ("numbers", D.fromList @Int [1 .. 26]) + ] + +-- Test that shuffling does anything at all +shuffleShuffles :: Test +shuffleShuffles = + let gen = mkStdGen 1234 + shuffled = shuffle gen testDataFrame + initialNumbers = D.extractNumericColumn "numbers" testDataFrame + shuffledNumbers = D.extractNumericColumn "numbers" shuffled + in TestCase + ( assertEqual + "Shuffled column unequal to initial column" + False + (initialNumbers == shuffledNumbers) + ) + +shufflePreservesColumnNames :: Test +shufflePreservesColumnNames = + let gen = mkStdGen 837 + shuffled = shuffle gen testDataFrame + in TestCase + ( assertEqual + "Column names are unchanged" + (D.columnNames shuffled) + (D.columnNames testDataFrame) + ) + +-- Test that un-shuffling restores the original dataframe +-- which is known to be sorted in this case +shufflePreservesData :: Test +shufflePreservesData = + let gen = mkStdGen 1234 + shuffled = shuffle gen testDataFrame + sortedShuffled = D.sortBy [D.Asc (D.col @Int "numbers")] shuffled + in TestCase + (assertEqual "sort recovers initial numbers" testDataFrame sortedShuffled) + +-- Test that shuffling isn't doing anything sneaky with summoning +-- random numbers somehow +shuffleSameSeedIsSameShuffle :: Test +shuffleSameSeedIsSameShuffle = + let gen = mkStdGen 1234 + shuffled1 = shuffle gen testDataFrame + shuffled2 = shuffle gen testDataFrame + in TestCase + (assertEqual "shuffle with same seed gives same result" shuffled1 shuffled2) + +-- Test that different seeds give different results +shuffleDifferentSeedIsDifferent :: Test +shuffleDifferentSeedIsDifferent = + let gen1 = mkStdGen 1234 + gen2 = mkStdGen 4321 + shuffled1 = shuffle gen1 testDataFrame + shuffled2 = shuffle gen2 testDataFrame + in TestCase + ( assertEqual + "shuffle with different seeds gives different results" + False + (shuffled1 == shuffled2) + ) + +tests :: [Test] +tests = + [ TestLabel "shuffleShuffles" shuffleShuffles + , TestLabel "shufflePreservesData" shufflePreservesData + , TestLabel "shufflePreservesColumnNames" shufflePreservesColumnNames + , TestLabel "shuffleSameSeedIsSameShuffle" shuffleSameSeedIsSameShuffle + , TestLabel "shuffleDifferentSeedIsDifferent" shuffleDifferentSeedIsDifferent + ]