Skip to content

Commit 09a4d26

Browse files
committed
Fixed init bug of tensors
1 parent 553853a commit 09a4d26

3 files changed

Lines changed: 21 additions & 13 deletions

File tree

src/data_modeling/dim_type.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ class Dimension final {
3030
std::array<tensorDim_t, MAX_TENSOR_DIMS> dims; // assumption: maximum dimension of Tensor is 4
3131

3232
public:
33+
/**
34+
* @brief Explicit default ctor, so that dims is zero initialized.
35+
* Otherwise we will encounter undefined behavior.
36+
*/
37+
Dimension() : dims{} {}
38+
3339
tensorDim_t& operator[](int idx){
3440
assert(idx < MAX_TENSOR_DIMS);
3541
return dims[idx];

src/data_modeling/tensor.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,11 @@ void printValuesCpu(std::ostream& os, const Tensor& t) {
293293
const auto& dims = t.getDims();
294294
const auto MAX_IDX = static_cast<tensorDim_t>(5);
295295

296+
#ifndef NDEBUG
296297
for(int i=0; i<4; i++){
297298
cout << "Dim " << i << ": " << dims.get(i) << endl;
298299
}
300+
#endif // NDEBUG
299301

300302
if(dims.get(3)>0){
301303
std::__throw_invalid_argument("Printing 4D tensor not implemented");

src/data_modeling/tensor.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ struct Tensor final {
148148
// base case
149149
template<size_t idx>
150150
void populateDims() {}
151+
152+
Tensor multiplyScalar(const Tensor& scalar, const Tensor& other) const;
153+
Tensor multiply2D(const Tensor& left, const Tensor& right) const;
154+
155+
friend void printValuesCpu(std::ostream& os, const Tensor& t);
151156

152157
template<size_t idx, typename First, typename... Rest>
153158
requires (is_valid_dim<First>)
@@ -156,14 +161,9 @@ struct Tensor final {
156161
dims[idx] = static_cast<tensorDim_t>(first);
157162
populateDims<idx+1>(rest...);
158163
}
159-
160-
Tensor multiplyScalar(const Tensor& scalar, const Tensor& other) const;
161-
Tensor multiply2D(const Tensor& left, const Tensor& right) const;
162-
163-
friend void printValuesCpu(std::ostream& os, const Tensor& t);
164164

165165
template<typename... T>
166-
void constructTensor(Device d, T... dims) {
166+
void constructTensor(Device d, T... dimensions) {
167167
if constexpr(sizeof...(T)==4){
168168
type = TensorType::FourD;
169169
}
@@ -180,27 +180,27 @@ struct Tensor final {
180180
type = TensorType::Scalar;
181181
}
182182

183-
populateDims<0>(dims...);
183+
populateDims<0>(dimensions...);
184184

185185
values = std::make_shared<tensorValues_t>(d);
186186
if constexpr (sizeof...(T)==0){
187187
values->resize(1);
188188
}
189189
else {
190-
values->resize(varProduct(dims...));
190+
values->resize(varProduct(dimensions...));
191191
}
192192
}
193193

194194
public:
195195
template<typename... T>
196-
explicit Tensor(T... dims) {
197-
constructTensor(tensorValues_t::getDefaultDevice(), dims...);
196+
explicit Tensor(Device d, T... dimensions) {
197+
constructTensor(d, dimensions...);
198198
}
199199

200200
template<typename... T>
201-
explicit Tensor(Device d, T... dims) {
202-
constructTensor(d, dims...);
203-
}
201+
explicit Tensor(T... dimensions) :
202+
Tensor(tensorValues_t::getDefaultDevice(), dimensions...)
203+
{}
204204

205205
/**
206206
* @brief Copying.

0 commit comments

Comments
 (0)