Skip to content

Commit 35b63f5

Browse files
authored
Merge pull request #332 from flucoma/feature/RTPGHI-random-seeding
Add random seed to RTPGHI and a test
2 parents b9d444e + 922e7e0 commit 35b63f5

5 files changed

Lines changed: 64 additions & 7 deletions

File tree

include/flucoma/algorithms/public/NMFMorph.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ class NMFMorph
102102

103103
bool initialized() const { return mInitialized; }
104104

105-
void processFrame(ComplexVectorView v, double interpolation, Allocator& alloc)
105+
void processFrame(ComplexVectorView v, double interpolation, index seed,
106+
Allocator& alloc)
106107
{
107108
using namespace Eigen;
108109
using namespace _impl;
@@ -120,7 +121,8 @@ class NMFMorph
120121
ScopedEigenMap<VectorXd> frame(W.rows(), alloc);
121122
frame = W * hFrame;
122123
RealVectorView mag1 = asFluid(frame);
123-
mRTPGHI.processFrame(mag1, v, mWindowSize, mFFTSize, mHopSize, 1e-6, alloc);
124+
mRTPGHI.processFrame(mag1, v, mWindowSize, mFFTSize, mHopSize, 1e-6, seed,
125+
alloc);
124126
mPos = (mPos + 1) % mH.cols();
125127
}
126128

include/flucoma/algorithms/util/RTPGHI.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ under the European Union’s Horizon 2020 research and innovation programme
1717
#include "AlgorithmUtils.hpp"
1818
#include "FluidEigenMappings.hpp"
1919
#include "../public/STFT.hpp"
20+
#include "../util/EigenRandom.hpp"
2021
#include "../../data/FluidIndex.hpp"
2122
#include "../../data/FluidMemory.hpp"
2223
#include "../../data/TensorTypes.hpp"
@@ -59,7 +60,7 @@ class RTPGHI
5960
}
6061

6162
void processFrame(RealVectorView in, ComplexVectorView out, index winSize,
62-
index fftSize, index hopSize, double tolerance, Allocator& alloc)
63+
index fftSize, index hopSize, double tolerance, index seed, Allocator& alloc)
6364
{
6465
using namespace Eigen;
6566
using namespace _impl;
@@ -88,8 +89,8 @@ class RTPGHI
8889
todo = (currentLogMag > absTol).cast<double>();
8990
index numTodo = static_cast<index>(todo.sum());
9091
ScopedEigenMap<ArrayXd> phaseEst(mBins, alloc);
91-
phaseEst = pi + ArrayXd::Random(mBins) * pi;
92-
92+
phaseEst =
93+
EigenRandom<ArrayXd>(mBins, RandomSeed{seed}, Range{0.0, 2.0 * pi});
9394
rt::vector<pair<double, index>> heap(alloc);
9495
heap.reserve(asUnsigned(mBins));
9596
for (index i = 0; i < mBins; i++)

include/flucoma/clients/rt/NMFMorphClient.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ enum NMFFilterIndex {
2727
kActBuf,
2828
kAutoAssign,
2929
kInterpolation,
30+
kRandomSeed,
3031
kFFT
3132
};
3233

@@ -36,6 +37,7 @@ constexpr auto NMFMorphParams = defineParameters(
3637
InputBufferParam("activations", "Activations"),
3738
EnumParam("autoassign", "Automatic assign", 1, "No", "Yes"),
3839
FloatParam("interpolation", "Interpolation", 0, Min(0.0), Max(1.0)),
40+
LongParam("seed", "Random Seed", -1),
3941
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));
4042

4143
class NMFMorphClient : public FluidBaseClient, public AudioOut
@@ -116,7 +118,7 @@ class NMFMorphClient : public FluidBaseClient, public AudioOut
116118
mSTFTProcessor.processOutput(
117119
get<kFFT>(), output, c, [&](ComplexMatrixView out) {
118120
mNMFMorph.processFrame(out.row(0), get<kInterpolation>(),
119-
c.allocator());
121+
get<kRandomSeed>(), c.allocator());
120122
});
121123
}
122124
}

tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ target_link_libraries(TestEnvelopeGate PRIVATE TestSignals)
128128
target_link_libraries(TestTransientSlice PRIVATE TestSignals)
129129

130130
add_test_executable(TestEigenRandom algorithms/util/TestEigenRandom.cpp)
131-
131+
add_test_executable(TestRTPGHI algorithms/util/TestRTPGHI.cpp)
132132

133133
include(CTest)
134134
include(Catch)
@@ -157,6 +157,7 @@ catch_discover_tests(TestBufferedProcess WORKING_DIRECTORY "${CMAKE_BINARY_DIR}"
157157
catch_discover_tests(TestMLP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
158158
catch_discover_tests(TestKMeans WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
159159
catch_discover_tests(TestEigenRandom WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
160+
catch_discover_tests(TestRTPGHI WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
160161
catch_discover_tests(TestUMAP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
161162

162163
catch_discover_tests(TestDataSampler)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#define CATCH_CONFIG_MAIN
2+
#include <catch2/catch_all.hpp>
3+
#include <flucoma/algorithms/util/RTPGHI.hpp>
4+
#include <flucoma/data/FluidMemory.hpp>
5+
#include <flucoma/data/FluidTensor.hpp>
6+
#include <complex>
7+
#include <vector>
8+
9+
namespace fluid {
10+
11+
12+
TEST_CASE("RTPGHI is repeatable with manually set random seed")
13+
{
14+
using Tensor = fluid::FluidTensor<double, 1>;
15+
using ComplexTensor = fluid::FluidTensor<std::complex<double>, 1>;
16+
using fluid::algorithm::RTPGHI;
17+
18+
index win = 64;
19+
index fft = 64;
20+
index hop = 64;
21+
index bins = fft / 2 + 1;
22+
23+
double mag = 1.0;
24+
// to stop algo converging, bypass loop by setting massive tolerence
25+
double tol = 2.0 * mag;
26+
27+
RTPGHI algo(fft, FluidDefaultAllocator());
28+
29+
Tensor input(bins);
30+
input[index(bins / 2)] = mag;
31+
std::vector results(3, ComplexTensor(bins));
32+
33+
// algo has memory, so re-init after each call to test repeatability, and call
34+
// twice to actually generate some action
35+
auto runit = [&](size_t run, index seed) {
36+
algo.init(fft);
37+
algo.processFrame(input, results[run], win, fft, hop, tol, seed,
38+
FluidDefaultAllocator());
39+
algo.processFrame(input, results[run], win, fft, hop, 2.0, seed,
40+
FluidDefaultAllocator());
41+
};
42+
43+
for (size_t run = 0; run < results.size(); ++run)
44+
runit(run, run < 2 ? 42 : 8347);
45+
46+
using Catch::Matchers::RangeEquals;
47+
48+
REQUIRE_THAT(results[0], RangeEquals(results[1]));
49+
REQUIRE_THAT(results[0], !RangeEquals(results[2]));
50+
}
51+
} // namespace fluid

0 commit comments

Comments
 (0)