diff --git a/src/tensor_view.h b/src/tensor_view.h index e747d19..0b6fc59 100644 --- a/src/tensor_view.h +++ b/src/tensor_view.h @@ -23,6 +23,14 @@ class TensorView { using Strides = std::vector; + template + TensorView(const TensorLike& tensor) + : data_{const_cast(static_cast(tensor.data()))}, + shape_{tensor.shape()}, + dtype_{tensor.dtype()}, + device_{tensor.device()}, + strides_{tensor.strides()} {} + template TensorView(void* data, const Shape& shape) : data_{data},