Skip to content

Commit f908d0f

Browse files
tremblapweefuzzy
andauthored
NMFCross seeding random (#330)
* first attempt at a first UT in c++ * wait, I forgot to add them here * thanks to @weefuzzy it actually works * Griffith Lim in progress * merges main in * argghhhhh griffinlim in line with difficult i/o * Test for GriffinLim repeatability * Clang-format TestGriffinLim * minimum clang format --------- Co-authored-by: Owen Green <gungwho@gmail.com>
1 parent abc07b8 commit f908d0f

6 files changed

Lines changed: 99 additions & 9 deletions

File tree

include/flucoma/algorithms/public/GriffinLim.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ under the European Union’s Horizon 2020 research and innovation programme
1212

1313
#include "STFT.hpp"
1414
#include "../util/AlgorithmUtils.hpp"
15+
#include "../util/EigenRandom.hpp"
1516
#include "../util/FluidEigenMappings.hpp"
1617
#include "../../data/FluidIndex.hpp"
1718
#include "../../data/TensorTypes.hpp"
@@ -26,7 +27,7 @@ class GriffinLim
2627

2728
public:
2829
void process(ComplexMatrixView in, index nSamples, index nIter, index winSize,
29-
index fftSize, index hopSize)
30+
index fftSize, index hopSize, index seed = -1)
3031
{
3132
using namespace Eigen;
3233
using namespace _impl;
@@ -36,9 +37,8 @@ class GriffinLim
3637
auto istft = ISTFT(winSize, fftSize, hopSize);
3738
ArrayXd tmp = ArrayXd::Zero(nSamples);
3839
ArrayXXcd magnitude = asEigen<Array>(in).abs();
39-
ArrayXXcd phase =
40-
ArrayXXcd::Random(magnitude.rows(), magnitude.cols()) * 2 * 1i * pi;
41-
phase = phase.exp();
40+
ArrayXXcd phase = EigenRandomPhase<ArrayXXcd>(
41+
magnitude.rows(), magnitude.cols(), RandomSeed{seed});
4242
ArrayXXcd estimate = ArrayXXcd::Zero(magnitude.rows(), magnitude.cols());
4343
ArrayXXcd prev = ArrayXXcd::Zero(magnitude.rows(), magnitude.cols());
4444
for (index i = 0; i < nIter; i++)

include/flucoma/algorithms/public/NMFCross.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ under the European Union’s Horizon 2020 research and innovation programme
1515
#pragma once
1616

1717
#include "STFT.hpp"
18+
#include "../util/EigenRandom.hpp"
1819
#include "../util/FluidEigenMappings.hpp"
1920
#include "../../data/FluidIndex.hpp"
2021
#include "../../data/TensorTypes.hpp"
@@ -57,16 +58,16 @@ class NMFCross
5758
}
5859

5960
void process(const RealMatrixView X, RealMatrixView H1, RealMatrixView W0,
60-
index r, index p, index c) const
61+
index r, index p, index c, index randomSeed = -1) const
6162
{
6263
index nFrames = X.extent(0);
6364
index nBins = X.extent(1);
6465
index rank = W0.extent(0);
6566
nBins = W0.extent(1);
6667
MatrixXd W = asEigen<Matrix>(W0).transpose();
6768
MatrixXd H;
68-
H = MatrixXd::Random(rank, nFrames) * 0.5 +
69-
MatrixXd::Constant(rank, nFrames, 0.5);
69+
H = EigenRandom<MatrixXd>(rank, nFrames, RandomSeed{randomSeed},
70+
Range{0.0, 1.0});
7071
MatrixXd V = asEigen<Matrix>(X).transpose();
7172
multiplicativeUpdates(V, W, H, r, p, c);
7273
MatrixXd HT = H.transpose();

include/flucoma/clients/nrt/NMFCrossClient.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum NMFCrossParamIndex {
3232
kPolyphony,
3333
kContinuity,
3434
kIterations,
35+
kRandomSeed,
3536
kFFT
3637
};
3738

@@ -44,6 +45,7 @@ constexpr auto NMFCrossParams = defineParameters(
4445
FrameSizeUpperLimit<kFFT>()),
4546
LongParam("continuity", "Continuity", 7, Min(1), Odd()),
4647
LongParam("iterations", "Number of Iterations", 50, Min(1)),
48+
LongParam("seed", "Random Seed", -1),
4749
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));
4850

4951
class NMFCrossClient : public FluidBaseClient,
@@ -154,7 +156,8 @@ class NMFCrossClient : public FluidBaseClient,
154156
});
155157

156158
nmf.process(tgtMag, outputEnvelopes, W, get<kTimeSparsity>(),
157-
std::min(srcWindows, get<kPolyphony>()), get<kContinuity>());
159+
std::min(srcWindows, get<kPolyphony>()), get<kContinuity>(),
160+
get<kRandomSeed>());
158161

159162
r = checkTask(c, progressCount, progressTotal);
160163
if (!r.ok()) return r;
@@ -166,7 +169,7 @@ class NMFCrossClient : public FluidBaseClient,
166169

167170
GriffinLim gl;
168171
gl.process(result, tgtFrames, 50, fftParams.winSize(), fftParams.fftSize(),
169-
fftParams.hopSize());
172+
fftParams.hopSize(), get<kRandomSeed>());
170173

171174
r = checkTask(c, ++progressCount, progressTotal);
172175
if (!r.ok()) return r;

tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ add_test_executable(TestTransientSlice algorithms/public/TestTransientSlice.cpp)
116116

117117
add_test_executable(TestMLP algorithms/public/TestMLP.cpp)
118118
add_test_executable(TestKMeans algorithms/public/TestKMeans.cpp)
119+
add_test_executable(TestNMFCross algorithms/public/TestNMFCross.cpp)
120+
add_test_executable(TestGriffinLim algorithms/public/TestGriffinLim.cpp)
119121
add_test_executable(TestNMF algorithms/public/TestNMF.cpp)
120122
add_test_executable(TestUMAP algorithms/public/TestUMAP.cpp)
121123

@@ -159,6 +161,8 @@ catch_discover_tests(TestBufferedProcess WORKING_DIRECTORY "${CMAKE_BINARY_DIR}"
159161
catch_discover_tests(TestMLP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
160162
catch_discover_tests(TestKMeans WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
161163
catch_discover_tests(TestEigenRandom WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
164+
catch_discover_tests(TestNMFCross WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
165+
catch_discover_tests(TestGriffinLim WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
162166
catch_discover_tests(TestNNDSVD WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
163167
catch_discover_tests(TestNMF WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
164168
catch_discover_tests(TestRTPGHI WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#define CATCH_CONFIG_MAIN
2+
3+
#include <catch2/catch_all.hpp>
4+
#include <flucoma/algorithms/public/GriffinLim.hpp>
5+
#include <flucoma/data/FluidIndex.hpp>
6+
#include <flucoma/data/FluidTensor.hpp>
7+
#include <complex>
8+
#include <vector>
9+
10+
namespace fluid {
11+
TEST_CASE("GriffinLim is repeatable with user-supplied random seed")
12+
{
13+
14+
using algorithm::GriffinLim;
15+
using Tensor = FluidTensor<std::complex<double>, 2>;
16+
17+
index win = 64;
18+
index fft = 64;
19+
index hop = 64;
20+
index bins = fft / 2 + 1;
21+
22+
// only actually interested in 1 frame of results, but need padding in algo
23+
Tensor raw_input(2, bins);
24+
raw_input(0, index(bins / 2)) = std::polar(1.0, 0.0);
25+
26+
std::vector<Tensor> inouts(3, raw_input);
27+
28+
GriffinLim algo;
29+
30+
algo.process(inouts[0], win, 1, win, fft, hop, 42);
31+
algo.process(inouts[1], win, 1, win, fft, hop, 42);
32+
algo.process(inouts[2], win, 1, win, fft, hop, 987234);
33+
34+
using Catch::Matchers::RangeEquals;
35+
36+
SECTION("Calls with the same seed have the same output")
37+
{
38+
REQUIRE_THAT(inouts[1], RangeEquals(inouts[0]));
39+
}
40+
SECTION("Calls with different seeds have different outputs")
41+
{
42+
REQUIRE_THAT(inouts[1], !RangeEquals(inouts[2]));
43+
}
44+
}
45+
} // namespace fluid
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#define CATCH_CONFIG_MAIN
2+
#include <catch2/catch_all.hpp>
3+
#include <flucoma/algorithms/public/NMFCross.hpp>
4+
#include <flucoma/data/FluidTensor.hpp>
5+
#include <algorithm>
6+
#include <iostream>
7+
#include <vector>
8+
9+
TEST_CASE("NMFCross is repeatable with user-supplied random seed")
10+
{
11+
12+
using fluid::algorithm::NMFCross;
13+
using Tensor = fluid::FluidTensor<double, 2>;
14+
NMFCross algo(3);
15+
16+
Tensor targetMag{{0.5, 0.4}, {0.1, 1.1}, {0.7, 0.8},
17+
{0.3, 0.0}, {1.0, 0.9}, {0.2, 0.6}};
18+
Tensor sourceMag{{0.0, 0.4}, {0.6, 0.7}, {0.8, 0.1},
19+
{1.0, 0.5}, {1.1, 0.2}, {0.9, 0.3}};
20+
21+
std::vector Hs(3, Tensor(6, 6));
22+
23+
algo.process(targetMag, Hs[0], sourceMag, 3, 2, 7, 42);
24+
algo.process(targetMag, Hs[1], sourceMag, 3, 2, 7, 42);
25+
algo.process(targetMag, Hs[2], sourceMag, 3, 2, 7, 5063);
26+
27+
using Catch::Matchers::RangeEquals;
28+
29+
SECTION("Calls with the same seed have the same output")
30+
{
31+
REQUIRE_THAT(Hs[1], RangeEquals(Hs[0]));
32+
}
33+
SECTION("Calls with different seeds have different outputs")
34+
{
35+
REQUIRE_THAT(Hs[1], !RangeEquals(Hs[2]));
36+
}
37+
}

0 commit comments

Comments
 (0)