@@ -23,6 +23,10 @@ using namespace std;
2323*************************** tensorValues_t *********************************
2424********************************************************************/
2525
26+ Tensor::tensorValues_t::tensorValues_t () {
27+ device = defaultDevice;
28+ }
29+
2630Tensor::tensorValues_t::tensorValues_t (Device d) : device(d) {}
2731
2832Tensor::tensorValues_t::tensorValues_t (tensorValues_t&& other) noexcept {
@@ -78,6 +82,14 @@ Device Tensor::tensorValues_t::getDevice() const noexcept {
7882 return this ->device ;
7983}
8084
85+ void Tensor::tensorValues_t::setDefaultDevice (const Device d) noexcept {
86+ defaultDevice = d;
87+ }
88+
89+ Device Tensor::tensorValues_t::getDefaultDevice () noexcept {
90+ return defaultDevice;
91+ }
92+
8193tensorSize_t Tensor::tensorValues_t::getSize () const noexcept {
8294 return this ->size ;
8395}
@@ -96,6 +108,9 @@ ftype& Tensor::tensorValues_t::operator[](int idx) {
96108 case Device::CUDA:
97109 __throw_invalid_argument (" Cuda operator[] not implemented" );
98110 }
111+
112+ __throw_invalid_argument (" Unexpected device encountered" );
113+ return values[0 ]; // never reached, suppress warning
99114}
100115
101116ftype Tensor::tensorValues_t::get (const int idx) const {
@@ -108,6 +123,9 @@ ftype Tensor::tensorValues_t::get(const int idx) const {
108123 case Device::CUDA:
109124 __throw_invalid_argument (" Cuda getter not implemented" );
110125 }
126+
127+ __throw_invalid_argument (" Unexpected device encountered" );
128+ return 0 ; // never reached, suppress warning
111129}
112130
113131/* *******************************************************************
@@ -152,7 +170,7 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
152170 */
153171Tensor Tensor::multiplyScalar (const Tensor& scalar, const Tensor& right) const {
154172 Tensor res (right);
155- for (int i=0 ; i<right.dims . getTotalSize (); ++i){
173+ for (int i=0 ; i<right.getSize (); ++i){
156174 (*res.values )[i] = (*this ->values )[0 ] * (*right.values )[i];
157175 }
158176 return res;
@@ -167,7 +185,7 @@ Tensor Tensor::multiplyScalar(const Tensor& scalar, const Tensor& right) const {
167185 * network class object instance upon construction.
168186 */
169187Tensor Tensor::multiply2D (const Tensor& left, const Tensor& right) const {
170- Tensor res (left.dims .get (0 ), right.dims .get (1 ), this -> values -> getDevice ( ));
188+ Tensor res (this -> values -> getDevice (), left.dims .get (0 ), right.dims .get (1 ));
171189
172190 for (uint16_t row=0 ; row<left.dims .get (0 ); row++){
173191 const uint32_t leftRowOffset = row * left.dims .get (1 );
@@ -228,10 +246,20 @@ Tensor Tensor::multiply(const Tensor& left, const Tensor& right) {
228246 return left * right;
229247}
230248
249+ /* *
250+ * @brief Populates the tensor with value.
251+ */
252+ void Tensor::reset (const ftype x) {
253+ for (tensorSize_t i=0 ; i<values->getSize (); i++){
254+ (*values)[i] = x;
255+ }
256+ }
257+
231258/* *
232259 * @brief Populates the tensor with values drawn according to initializer.
233260 */
234- void Tensor::initialize (const unique_ptr<utility::InitializerBase>& init) {
261+ void Tensor::reset (const utility::InitClass ic) {
262+ const auto init = utility::InitializerFactory::getInitializer (ic);
235263 for (tensorSize_t i=0 ; i<values->getSize (); i++){
236264 (*values)[i] = init->drawNumber ();
237265 }
@@ -241,10 +269,34 @@ const Dimension& Tensor::getDims() const noexcept {
241269 return dims;
242270}
243271
272+ tensorSize_t Tensor::getSize () const noexcept {
273+ return values->getSize ();
274+ }
275+
276+ void Tensor::setDefaultDevice (const Device d) noexcept {
277+ tensorValues_t::setDefaultDevice (d);
278+ }
279+
280+ Device Tensor::getDefaultDevice () noexcept {
281+ return tensorValues_t::getDefaultDevice ();
282+ }
283+
284+ void Tensor::setDevice (const Device d) noexcept {
285+ values->setDevice (d);
286+ }
287+
288+ Device Tensor::getDevice () const noexcept {
289+ return values->getDevice ();
290+ }
291+
244292void printValuesCpu (std::ostream& os, const Tensor& t) {
245293 const auto & dims = t.getDims ();
246294 const auto MAX_IDX = static_cast <tensorDim_t>(5 );
247295
296+ for (int i=0 ; i<4 ; i++){
297+ cout << " Dim " << i << " : " << dims.get (i) << endl;
298+ }
299+
248300 if (dims.get (3 )>0 ){
249301 std::__throw_invalid_argument (" Printing 4D tensor not implemented" );
250302 }
0 commit comments