diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp
index 7b4a0e55e..3b3e57276 100644
--- a/src/Native/LibTorchSharp/THSTensor.cpp
+++ b/src/Native/LibTorchSharp/THSTensor.cpp
@@ -1894,6 +1894,16 @@ Tensor THSTensor_to_dense(Tensor tensor)
CATCH_TENSOR(tensor->to_dense());
}
+Tensor THSTensor_to_sparse(Tensor tensor)
+{
+ CATCH_TENSOR(tensor->to_sparse());
+}
+
+Tensor THSTensor_to_sparse_with_dims(Tensor tensor, const int64_t sparse_dim)
+{
+ CATCH_TENSOR(tensor->to_sparse(sparse_dim));
+}
+
void THSTensor_set_(Tensor tensor, const Tensor source)
{
CATCH(tensor->set_(*source););
diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h
index 73bff0403..7cc3ffc04 100644
--- a/src/Native/LibTorchSharp/THSTensor.h
+++ b/src/Native/LibTorchSharp/THSTensor.h
@@ -1396,6 +1396,9 @@ EXPORT_API(Tensor) THSTensor_trapezoid_dx(const Tensor y, const double dx, int64
EXPORT_API(Tensor) THSTensor_to_dense(Tensor tensor);
+EXPORT_API(Tensor) THSTensor_to_sparse(Tensor tensor);
+EXPORT_API(Tensor) THSTensor_to_sparse_with_dims(Tensor tensor, const int64_t sparse_dim);
+
EXPORT_API(Tensor) THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy, const bool non_blocking);
EXPORT_API(Tensor) THSTensor_to_type(const Tensor tensor, int8_t scalar_type, const bool copy, const bool non_blocking);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
index e8db2c2cb..36827bd61 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -376,6 +376,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_to_dense(IntPtr handle);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_to_sparse(IntPtr handle);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_to_sparse_with_dims(IntPtr handle, long sparse_dim);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_clone(IntPtr handle);
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index ea70b83e1..ee0067b6f 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -1376,6 +1376,29 @@ public Tensor to_dense()
return new Tensor(res);
}
+ ///
+ /// Converts a dense tensor to a sparse COO tensor.
+ ///
+ public Tensor to_sparse()
+ {
+ var res = NativeMethods.THSTensor_to_sparse(Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Converts a dense tensor to a sparse COO tensor with the specified number of sparse dimensions.
+ ///
+ /// The number of sparse dimensions.
+ public Tensor to_sparse(int sparse_dim)
+ {
+ var res = NativeMethods.THSTensor_to_sparse_with_dims(Handle, sparse_dim);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Returns a copy of the tensor input.
///