Skip to content

Commit 0e3d63b

Browse files
committed
MLP clients: add missing random seed details
1 parent f908d0f commit 0e3d63b

2 files changed

Lines changed: 8 additions & 7 deletions

File tree

include/flucoma/clients/nrt/MLPClassifierClient.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,10 @@ class MLPClassifierClient : public FluidBaseClient,
157157

158158
mAlgorithm.encoder.fit(targetDataSet);
159159

160-
if (mTracker.changed(sourceDataSet.pointSize(),
160+
if (!mAlgorithm.initialized() ||
161+
mTracker.changed(sourceDataSet.pointSize(),
161162
mAlgorithm.encoder.numLabels(), get<kHidden>(),
162-
get<kActivation>()))
163+
get<kActivation>(), get<kRandomSeed>()))
163164
{
164165
mAlgorithm.mlp.init(sourceDataSet.pointSize(),
165166
mAlgorithm.encoder.numLabels(), get<kHidden>(),
@@ -183,7 +184,7 @@ class MLPClassifierClient : public FluidBaseClient,
183184
}
184185

185186
FluidDataSetSampler sampler(sourceDataSet, targetDataSet,
186-
get<kBatchSize>(), get<kVal>(), true);
187+
get<kBatchSize>(), get<kVal>(), true, get<kRandomSeed>());
187188

188189
algorithm::SGD sgd;
189190
double error = sgd.train(mAlgorithm.mlp, data, oneHot, sampler, get<kIter>(),
@@ -263,7 +264,7 @@ class MLPClassifierClient : public FluidBaseClient,
263264

264265
private:
265266

266-
ParameterTrackChanges<index, index, IndexVector, index> mTracker;
267+
ParameterTrackChanges<index, index, IndexVector, index, index> mTracker;
267268

268269
MessageResult<ParamValues> updateParameters()
269270
{

include/flucoma/clients/nrt/MLPRegressorClient.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class MLPRegressorClient : public FluidBaseClient,
118118
: get<kOutputActivation>();
119119
if (!mAlgorithm.initialized() ||
120120
mTracker.changed(sourceDataSet.pointSize(), targetDataSet.pointSize(),
121-
get<kHidden>(), get<kActivation>(), outputAct))
121+
get<kHidden>(), get<kActivation>(), outputAct, get<kRandomSeed>()))
122122
{
123123

124124
mAlgorithm.init(sourceDataSet.pointSize(), targetDataSet.pointSize(),
@@ -134,7 +134,7 @@ class MLPRegressorClient : public FluidBaseClient,
134134
auto data = sourceDataSet.getData();
135135
auto tgt = targetDataSet.getData();
136136
FluidDataSetSampler sampler(sourceDataSet, targetDataSet,
137-
get<kBatchSize>(), get<kVal>(), true);
137+
get<kBatchSize>(), get<kVal>(), true, get<kRandomSeed>());
138138
algorithm::SGD sgd;
139139
double error = sgd.train(mAlgorithm, data, tgt, sampler, get<kIter>(),
140140
get<kRate>(), get<kMomentum>());
@@ -243,7 +243,7 @@ class MLPRegressorClient : public FluidBaseClient,
243243

244244
private:
245245

246-
ParameterTrackChanges<index, index, IndexVector, index, index> mTracker;
246+
ParameterTrackChanges<index, index, IndexVector, index, index, index> mTracker;
247247

248248
MessageResult<ParamValues> updateParameters()
249249
{

0 commit comments

Comments
 (0)