@@ -21,35 +21,35 @@ using namespace std;
 shared_ptr<Tensor> graph::mul (const shared_ptr<Tensor> left, const shared_ptr<Tensor> right) {
     auto res = make_shared<Tensor>((*left) * (*right));
     if (left->getRequiresGrad () || right->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (make_shared<graph::ElementwiseMulNode>(left, right));
+        assert (res->getRequiresGrad ());
     }
     return res;
 }
 
 shared_ptr<Tensor> graph::add (const shared_ptr<Tensor> left, const shared_ptr<Tensor> right) {
     auto res = make_shared<Tensor>(*left + *right);
     if (left->getRequiresGrad () || right->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (make_shared<graph::AddNode>(left, right));
+        assert (res->getRequiresGrad ());
     }
     return res;
 }
 
 shared_ptr<Tensor> graph::matmul (const shared_ptr<Tensor> left, const shared_ptr<Tensor> right) {
     auto res = make_shared<Tensor>(left->matmul (*right));
     if (left->getRequiresGrad () || right->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (make_shared<graph::MatMulNode>(left, right));
+        assert (res->getRequiresGrad ());
     }
     return res;
 }
 
 shared_ptr<Tensor> graph::mul (const shared_ptr<Tensor> t, ftype scalar) {
     auto res = make_shared<Tensor>((*t) * scalar);
     if (t->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (std::make_shared<graph::ScalarMulNode>(t, scalar));
+        assert (res->getRequiresGrad ());
     }
     return res;
 }
@@ -61,8 +61,8 @@ shared_ptr<Tensor> graph::mul(ftype scalar, const shared_ptr<Tensor> t) {
 shared_ptr<Tensor> graph::add (const shared_ptr<Tensor> t, ftype scalar) {
     auto res = make_shared<Tensor>((*t) + scalar);
     if (t->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (std::make_shared<graph::ScalarAddNode>(t));
+        assert (res->getRequiresGrad ());
     }
     return res;
 }
@@ -74,17 +74,67 @@ shared_ptr<Tensor> graph::add(ftype scalar, const shared_ptr<Tensor> t) {
 shared_ptr<Tensor> graph::sub (const shared_ptr<Tensor> t, ftype scalar) {
     auto res = make_shared<Tensor>((*t) - scalar);
     if (t->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (std::make_shared<graph::ScalarAddNode>(t));
+        assert (res->getRequiresGrad ());
     }
     return res;
 }
 
 shared_ptr<Tensor> graph::div (const shared_ptr<Tensor> t, ftype scalar) {
     auto res = make_shared<Tensor>((*t) / scalar);
     if (t->getRequiresGrad ()){
-        assert (res->getRequiresGrad ());
         res->setCgNode (std::make_shared<graph::ScalarMulNode>(t, 1 / scalar));
+        assert (res->getRequiresGrad ());
     }
     return res;
+}
+
+/**
+ * @brief Special linear indexing, see getItem() overloads in tensor.
+ * Used to keep the computational graph intact.
+ * E.g. if we have something like
+ *
+ * loss = loss + other.get(i), we need to make sure get(i) can map to the computational graph.
+ */
+shared_ptr<Tensor> graph::getAsShared (const shared_ptr<Tensor>& t, tensorSize_t idx) {
+    ftype val = t->getItem (idx);
+    return make_shared<Tensor>(std::vector<tensorDim_t>{1}, std::vector<ftype>{val},
+                               t->getDevice (), t->getRequiresGrad ());
+}
+
+/**
+ * @brief Special linear indexing, see getItem() overloads in tensor.
+ * Used to keep the computational graph intact.
+ * E.g. if we have something like
+ *
+ * loss = loss + other.get(i), we need to make sure get(i) can map to the computational graph.
+ */
+std::shared_ptr<Tensor> graph::getAsShared (const Tensor& t, tensorSize_t idx) {
+    ftype val = t.getItem (idx);
+    return make_shared<Tensor>(std::vector<tensorDim_t>{1}, std::vector<ftype>{val},
+                               t.getDevice (), t.getRequiresGrad ());
+}
+
+/**
+ * @brief Used to keep the computational graph intact.
+ * E.g. if we have something like
+ *
+ * loss = loss + other.get(i), we need to make sure get(i) can map to the computational graph.
+ */
+shared_ptr<Tensor> graph::getAsShared (const shared_ptr<Tensor>& t, vector<tensorDim_t>&& idx) {
+    ftype val = t->getItem (std::move (idx));
+    return make_shared<Tensor>(std::vector<tensorDim_t>{1}, std::vector<ftype>{val},
+                               t->getDevice (), t->getRequiresGrad ());
+}
+
+/**
+ * @brief Used to keep the computational graph intact.
+ * E.g. if we have something like
+ *
+ * loss = loss + other.get(i), we need to make sure get(i) can map to the computational graph.
+ */
+std::shared_ptr<Tensor> graph::getAsShared (const Tensor& t, std::vector<tensorDim_t>&& idx) {
+    ftype val = t.getItem (std::move (idx));
+    return make_shared<Tensor>(std::vector<tensorDim_t>{1}, std::vector<ftype>{val},
+                               t.getDevice (), t.getRequiresGrad ());
+}
 }
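
As a usage note, not part of the commit itself: a minimal sketch of the pattern the new getAsShared() overloads are meant to support, following the "loss = loss + other.get(i)" example from the doc comments above. Tensor, ftype, tensorSize_t, graph::add and graph::getAsShared are taken from this diff; the function name accumulateElements and the variables loss, other and count are illustrative assumptions only.

// Sketch only -- assumes the Tensor/graph API shown in this diff plus the project's own
// headers (hypothetically "tensor.h" / "graph.h") providing Tensor, ftype, tensorSize_t.
#include <memory>   // std::shared_ptr

// Accumulate selected elements of `other` into `loss` without breaking the computational graph.
// graph::getAsShared() wraps the element in a 1-element Tensor (same device, same requiresGrad
// flag), so graph::add() can attach an AddNode whenever gradients are required.
std::shared_ptr<Tensor> accumulateElements(std::shared_ptr<Tensor> loss,
                                           const std::shared_ptr<Tensor>& other,
                                           tensorSize_t count) {
    for (tensorSize_t i = 0; i < count; ++i) {
        // loss = loss + other.get(i), expressed through the graph API:
        loss = graph::add(loss, graph::getAsShared(other, i));
    }
    return loss;
}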