Skip to content

Commit a566e92

Browse files
committed
Refactored graph creation to provide cleaner interface, updated unit tests
1 parent eb0560c commit a566e92

19 files changed

Lines changed: 250 additions & 126 deletions

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
gcc-version: [12.3]
14+
gcc-version: [13]
1515
cmake-version: ['3.31.3']
1616
fail-fast: true
1717

src/backend/computational_graph/add_node.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ using namespace graph;
1717
vector< shared_ptr<Tensor> > AddNode::backward(const Tensor& upstreamGrad) {
1818
assert(!upstreamGrad.getRequiresGrad());
1919
auto res = make_shared<Tensor>(upstreamGrad.createDeepCopy());
20-
return {res, res}; // TODO: make sure that this works as intended
20+
return {res, res};
2121
}

src/backend/computational_graph/add_node.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
namespace graph {
1717
class AddNode final : public GraphNode {
1818
public:
19-
explicit AddNode(Tensor* t1, Tensor* t2) : GraphNode({t1, t2}) {}
19+
explicit AddNode(std::shared_ptr<Tensor> t1, std::shared_ptr<Tensor> t2)
20+
: GraphNode({std::move(t1), std::move(t2)}) {}
2021

2122
AddNode(const AddNode& other) = delete;
2223
AddNode& operator=(const AddNode& other) = delete;

src/backend/computational_graph/elementwise_mul_node.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
namespace graph {
1717
class ElementwiseMulNode final : public GraphNode {
1818
public:
19-
explicit ElementwiseMulNode(Tensor* t1, Tensor* t2) : GraphNode({t1, t2}) {}
19+
explicit ElementwiseMulNode(std::shared_ptr<Tensor> t1, std::shared_ptr<Tensor> t2)
20+
: GraphNode({std::move(t1), std::move(t2)}) {}
2021

2122
ElementwiseMulNode(const ElementwiseMulNode& other) = delete;
2223
ElementwiseMulNode& operator=(const ElementwiseMulNode& other) = delete;
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/**
2+
* @file graph_creation.cpp
3+
* @author Robert Baumgartner (r.baumgartner-1@tudelft.nl)
4+
* @brief
5+
* @version 0.1
6+
* @date 2026-02-17
7+
*
8+
* @copyright Copyright (c) 2026
9+
*
10+
*/
11+
12+
#include "graph_creation.h"
13+
14+
#include "add_node.h"
15+
#include "matmul_node.h"
16+
#include "elementwise_mul_node.h"
17+
#include "scalar_op_nodes.h"
18+
19+
using namespace std;
20+
21+
shared_ptr<Tensor> graph::mul(const shared_ptr<Tensor> left, const shared_ptr<Tensor> right) {
22+
auto res = make_shared<Tensor>((*left) * (*right));
23+
if(left->getRequiresGrad() || right->getRequiresGrad()){
24+
assert(res->getRequiresGrad());
25+
res->setCgNode(make_shared<graph::ElementwiseMulNode>(left, right));
26+
}
27+
return res;
28+
}
29+
30+
shared_ptr<Tensor> graph::add(const shared_ptr<Tensor> left, const shared_ptr<Tensor> right) {
31+
auto res = make_shared<Tensor>(*left + *right);
32+
if(left->getRequiresGrad() || right->getRequiresGrad()){
33+
assert(res->getRequiresGrad());
34+
res->setCgNode(make_shared<graph::AddNode>(left, right));
35+
}
36+
return res;
37+
}
38+
39+
shared_ptr<Tensor> graph::matmul(const shared_ptr<Tensor> left, const shared_ptr<Tensor> right) {
40+
auto res = make_shared<Tensor>(left->matmul(*right));
41+
if(left->getRequiresGrad() || right->getRequiresGrad()){
42+
assert(res->getRequiresGrad());
43+
res->setCgNode(make_shared<graph::MatMulNode>(left, right));
44+
}
45+
return res;
46+
}
47+
48+
shared_ptr<Tensor> graph::mul(const shared_ptr<Tensor> t, ftype scalar) {
49+
auto res = make_shared<Tensor>((*t) * scalar);
50+
if(t->getRequiresGrad()){
51+
assert(res->getRequiresGrad());
52+
res->setCgNode(std::make_shared<graph::ScalarMulNode>(t, scalar));
53+
}
54+
return res;
55+
}
56+
57+
shared_ptr<Tensor> graph::mul(ftype scalar, const shared_ptr<Tensor> t) {
58+
return graph::mul(t, scalar);
59+
}
60+
61+
shared_ptr<Tensor> graph::add(const shared_ptr<Tensor> t, ftype scalar) {
62+
auto res = make_shared<Tensor>((*t) + scalar);
63+
if(t->getRequiresGrad()){
64+
assert(res->getRequiresGrad());
65+
res->setCgNode(std::make_shared<graph::ScalarAddNode>(t));
66+
}
67+
return res;
68+
}
69+
70+
shared_ptr<Tensor> graph::add(ftype scalar, const shared_ptr<Tensor> t) {
71+
return graph::add(t, scalar);
72+
}
73+
74+
shared_ptr<Tensor> graph::sub(const shared_ptr<Tensor> t, ftype scalar) {
75+
auto res = make_shared<Tensor>((*t) - scalar);
76+
if(t->getRequiresGrad()){
77+
assert(res->getRequiresGrad());
78+
res->setCgNode(std::make_shared<graph::ScalarAddNode>(t));
79+
}
80+
return res;
81+
}
82+
83+
shared_ptr<Tensor> graph::div(const shared_ptr<Tensor> t, ftype scalar) {
84+
auto res = make_shared<Tensor>((*t) / scalar);
85+
if(t->getRequiresGrad()){
86+
assert(res->getRequiresGrad());
87+
res->setCgNode(std::make_shared<graph::ScalarMulNode>(t, 1 / scalar));
88+
}
89+
return res;
90+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* @file graph_creation.h
3+
* @author Robert Baumgartner (r.baumgartner-1@tudelft.nl)
4+
* @brief Tensor operations that actually create the computational graph.
5+
* @version 0.1
6+
* @date 2026-02-17
7+
*
8+
* @copyright Copyright (c) 2026
9+
*
10+
*/
11+
12+
#pragma once
13+
14+
#include "data_modeling/tensor.h"
15+
16+
#include <memory>
17+
18+
namespace graph {
19+
std::shared_ptr<Tensor> mul(const std::shared_ptr<Tensor> left, const std::shared_ptr<Tensor> right);
20+
21+
std::shared_ptr<Tensor> add(const std::shared_ptr<Tensor> left, const std::shared_ptr<Tensor> right);
22+
23+
std::shared_ptr<Tensor> matmul(const std::shared_ptr<Tensor> left, const std::shared_ptr<Tensor> right);
24+
25+
std::shared_ptr<Tensor> mul(const std::shared_ptr<Tensor> left, ftype scalar);
26+
std::shared_ptr<Tensor> mul(ftype scalar, const std::shared_ptr<Tensor> left);
27+
28+
std::shared_ptr<Tensor> add(const std::shared_ptr<Tensor> left, ftype scalar);
29+
std::shared_ptr<Tensor> add(ftype scalar, const std::shared_ptr<Tensor> left);
30+
31+
std::shared_ptr<Tensor> sub(const std::shared_ptr<Tensor> left, ftype scalar);
32+
std::shared_ptr<Tensor> div(const std::shared_ptr<Tensor> left, ftype scalar);
33+
}
34+

src/backend/computational_graph/graph_node.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
namespace graph {
2222
class GraphNode {
2323
protected:
24-
std::vector<Tensor*> parents;
25-
explicit GraphNode(std::vector<Tensor*> parents) : parents{std::move(parents)}{}
24+
std::vector< std::shared_ptr<Tensor> > parents;
25+
explicit GraphNode(std::vector< std::shared_ptr<Tensor> > parents) : parents{std::move(parents)}{}
2626

2727
public:
2828
virtual std::vector<std::shared_ptr<Tensor>> backward(const Tensor& upstreamGrad) = 0;

src/backend/computational_graph/matmul_node.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
namespace graph {
1919
class MatMulNode final : public GraphNode {
2020
public:
21-
explicit MatMulNode(Tensor* t1, Tensor* t2): GraphNode({t1, t2}) {}
21+
explicit MatMulNode(std::shared_ptr<Tensor> t1, std::shared_ptr<Tensor> t2)
22+
: GraphNode({std::move(t1), std::move(t2)}) {}
2223

2324
MatMulNode(const MatMulNode& other) = delete;
2425
MatMulNode& operator=(const MatMulNode& other) = delete;

src/backend/computational_graph/relu_node.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
namespace graph {
1919
class ReLuNode final : public GraphNode {
2020
public:
21-
explicit ReLuNode(Tensor* t): GraphNode({t}) {}
21+
explicit ReLuNode(std::shared_ptr<Tensor> t)
22+
: GraphNode({std::move(t)}) {}
2223

2324
ReLuNode(const ReLuNode& other) = delete;
2425
ReLuNode& operator=(const ReLuNode& other) = delete;

src/backend/computational_graph/scalar_op_nodes.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
namespace graph {
1717
class ScalarAddNode final : public GraphNode {
1818
public:
19-
explicit ScalarAddNode(Tensor* t) : GraphNode({t}) {}
19+
explicit ScalarAddNode(std::shared_ptr<Tensor> t)
20+
: GraphNode({std::move(t)}) {}
2021

2122
ScalarAddNode(const ScalarAddNode& other) = delete;
2223
ScalarAddNode& operator=(const ScalarAddNode& other) = delete;
@@ -34,7 +35,8 @@ namespace graph {
3435
const ftype factor;
3536

3637
public:
37-
explicit ScalarMulNode(Tensor* t, const ftype factor) : GraphNode({t}), factor{factor} {}
38+
explicit ScalarMulNode(std::shared_ptr<Tensor> t, ftype factor)
39+
: GraphNode({std::move(t)}), factor{factor} {}
3840

3941
ScalarMulNode(const ScalarMulNode& other) = delete;
4042
ScalarMulNode& operator=(const ScalarMulNode& other) = delete;

0 commit comments

Comments
 (0)