Commit 7c62e0c

More bugfixing, unit testing in C++
1 parent 195a057 commit 7c62e0c

3 files changed

Lines changed: 55 additions & 3 deletions

src/python/py_nn/py_nn.cpp

Lines changed: 7 additions & 3 deletions
@@ -25,9 +25,10 @@ BOOST_PYTHON_MODULE(_nn)
     /**
      * Return values, so BP knows how to wrap them. Example: parameters(), see FfLayer
      * Omitting these steps will result in crashes when working with tensors returned by
-     * those functions
+     * those functions.
      */
     boost::python::object coreModule = boost::python::import("dl_lib._compiled._core");
+    // this works when function returns a single shared_ptr<Tensor>
     boost::python::register_ptr_to_python<std::shared_ptr<Tensor>>();
 
     using namespace Py_Util;
@@ -77,9 +78,10 @@ BOOST_PYTHON_MODULE(_nn)
         .add_property("bias", &module::FfLayer::getBias)
         // methods
         .def("parameters", +[](const module::FfLayer& f) -> boost::python::list {
+            // we get a vector of shared_ptr, therefore need to give instructions on conversion
             boost::python::list result;
             for(auto& t : f.parameters())
-                result.append(t);
+                result.append(t); // forces conversion through Object*
             return result;
         })
         // operators
@@ -92,7 +94,9 @@ BOOST_PYTHON_MODULE(_nn)
         .def("__str__", &toString<module::ReLu>)
     ;
 
-    class_<module::LeakyReLu, std::shared_ptr<module::LeakyReLu>, boost::noncopyable>("LeakyReLU", init<ftype>())
+    class_<module::LeakyReLu, std::shared_ptr<module::LeakyReLu>, boost::noncopyable>("LeakyReLU")
+        .def(init<>()) // default epsilon
+        .def(init<ftype>())
         .def("__call__", WRAP_METHOD_ONE_TENSORARG(module::LeakyReLu, Py_nn::leakyReluF))
         .def("__str__", &toString<module::LeakyReLu>)
     ;
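
For readers unfamiliar with the converter calls above: the diff relies on two things when handing tensors back to Python — a registered to-python conversion for std::shared_ptr<Tensor>, and a manual copy of the std::vector returned by parameters() into a boost::python::list. Below is a minimal, self-contained sketch of that same pattern; the Thing and Owner types are hypothetical stand-ins for the project's Tensor and FfLayer, not part of this repository.

    // Hypothetical illustration of the binding pattern used above.
    // Build as a Python extension module against Boost.Python (exact flags depend on the install).
    #include <memory>
    #include <vector>
    #include <boost/noncopyable.hpp>
    #include <boost/python.hpp>
    #include <boost/python/register_ptr_to_python.hpp>

    namespace bp = boost::python;

    struct Thing
    {
        explicit Thing(int v) : value(v) {}
        int value;
    };

    // returns smart pointers, like FfLayer::parameters() in the diff
    struct Owner
    {
        std::vector<std::shared_ptr<Thing>> items() const
        {
            return { std::make_shared<Thing>(1), std::make_shared<Thing>(2) };
        }
    };

    BOOST_PYTHON_MODULE(_demo)
    {
        bp::class_<Thing, boost::noncopyable>("Thing", bp::init<int>())
            .def_readonly("value", &Thing::value);

        // tells Boost.Python how to turn a returned shared_ptr<Thing> into a Python object;
        // this is the role register_ptr_to_python plays for Tensor in the diff
        bp::register_ptr_to_python<std::shared_ptr<Thing>>();

        bp::class_<Owner>("Owner")
            // the std::vector has no automatic converter, so copy it into a Python list;
            // appending each shared_ptr goes through the converter registered above
            .def("items", +[](const Owner& o) -> bp::list {
                bp::list result;
                for(auto& t : o.items())
                    result.append(t);
                return result;
            });
    }

From Python, Owner().items() then comes back as an ordinary list of Thing objects, which mirrors how FfLayer.parameters() is exposed here.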

src/python/py_train/py_train.cpp

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,10 @@
 
 BOOST_PYTHON_MODULE(_train)
 {
+    // enable conversion from Tensor registered in _core
+    boost::python::object coreModule = boost::python::import("dl_lib._compiled._core");
+    boost::python::register_ptr_to_python<std::shared_ptr<Tensor>>();
+
     using namespace boost::python;
 
     // Loss functions

tests/backend/test_train_loop.cpp

Lines changed: 44 additions & 0 deletions
@@ -231,4 +231,48 @@ TEST(OverfitTest, CrossEntropyRMSPropOverfitsSmallDataset_OptimizedLoss) {
     EXPECT_LT((*finalLoss)[0], 0.05f)
         << "Network failed to overfit multiclass dataset"
         << "Final prediction: " << softmax(*pred) << "\nFinal loss: " << *finalLoss;
+}
+
+TEST(OptimizerTest, ZeroGrad_ClearsAllGradients) {
+    auto x = TensorFunctions::makeSharedTensor(
+        {4, 2}, {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0}, false);
+    auto y = TensorFunctions::makeSharedTensor(
+        {4, 1}, {0.0, 1.0, 1.0, 0.0}, false);
+
+    auto net = makeBinaryNet();
+    auto loss = std::make_shared<train::BceSigmoidLoss>();
+    auto optim = std::make_shared<train::SgdOptimizer>(
+        net->parameters(), 0.01f);
+
+    // one forward/backward pass to populate gradients
+    auto pred = (*net)(x);
+    auto l = (*loss)(y, pred);
+    l->backward();
+
+    // verify gradients are non-zero before zeroing
+    bool anyNonZero = false;
+    for(auto& p : net->parameters()) {
+        if(p->getGrads()) {
+            for(tensorSize_t i = 0; i < p->getGrads()->getSize(); i++) {
+                if((*p->getGrads())[i] != 0.0f) {
+                    anyNonZero = true;
+                    break;
+                }
+            }
+        }
+    }
+    EXPECT_TRUE(anyNonZero) << "Expected some non-zero gradients before zeroGrad";
+
+    // zero gradients
+    optim->zeroGrad();
+
+    // verify all gradients are zero after zeroing
+    for(auto& p : net->parameters()) {
+        if(p->getGrads()) {
+            for(tensorSize_t i = 0; i < p->getGrads()->getSize(); i++) {
+                EXPECT_FLOAT_EQ((*p->getGrads())[i], 0.0f)
+                    << "Gradient not zeroed at index " << i;
+            }
+        }
+    }
 }
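
The new test pins down the observable contract of zeroGrad() — after the call, every gradient buffer reads back as zero — without showing the optimizer code itself, which is not part of this commit. As a rough, self-contained sketch of what such a method usually does, with hypothetical Param/Optimizer stand-ins rather than the project's Tensor and SgdOptimizer:

    #include <algorithm>
    #include <cassert>
    #include <memory>
    #include <vector>

    // hypothetical stand-in for a trainable parameter: values plus an optional gradient buffer
    struct Param
    {
        std::vector<float> data;
        std::shared_ptr<std::vector<float>> grads; // may stay null until backward() has run
    };

    class Optimizer
    {
    public:
        explicit Optimizer(std::vector<std::shared_ptr<Param>> params)
            : params_(std::move(params)) {}

        // reset every existing gradient buffer to zero so the next backward pass
        // starts clean instead of accumulating on top of stale gradients
        void zeroGrad()
        {
            for(auto& p : params_)
                if(p->grads)
                    std::fill(p->grads->begin(), p->grads->end(), 0.0f);
        }

    private:
        std::vector<std::shared_ptr<Param>> params_;
    };

    int main()
    {
        auto w = std::make_shared<Param>();
        w->data  = {0.5f, -0.3f};
        w->grads = std::make_shared<std::vector<float>>(std::vector<float>{0.1f, -0.2f});

        Optimizer optim({w});
        optim.zeroGrad();

        for(float g : *w->grads)
            assert(g == 0.0f); // same property the GoogleTest above checks with EXPECT_FLOAT_EQ
        return 0;
    }

The real optimizer presumably walks the same parameters() list and calls whatever zero-fill operation the Tensor gradient type provides.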
