diff --git a/src/TorchSharp/BFloat16.cs b/src/TorchSharp/BFloat16.cs
new file mode 100644
index 000000000..3c96eb130
--- /dev/null
+++ b/src/TorchSharp/BFloat16.cs
@@ -0,0 +1,127 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using System.Globalization;
+using System.Runtime.InteropServices;
+
+#nullable enable
+namespace TorchSharp
+{
+ /// <summary>
+ /// Represents a 16-bit brain floating-point number (BFloat16).
+ /// Binary layout: 1 sign bit, 8 exponent bits, 7 mantissa bits — the upper 16 bits of IEEE 754 float32.
+ /// Binary-compatible with c10::BFloat16 in LibTorch.
+ /// </summary>
+ [StructLayout(LayoutKind.Sequential)]
+ public readonly struct BFloat16 : IComparable, IEquatable<BFloat16>, IComparable<BFloat16>, IFormattable
+ {
+ internal readonly ushort value;
+
+ internal BFloat16(ushort rawValue, bool _)
+ {
+ value = rawValue;
+ }
+
+ /// <summary>
+ /// Creates a BFloat16 from a float value using round-to-nearest-even (matching PyTorch c10::BFloat16).
+ /// </summary>
+ public BFloat16(float f)
+ {
+ value = FloatToBFloat16Bits(f);
+ }
+
+ /// <summary>
+ /// Creates a BFloat16 from the raw 16-bit representation.
+ /// </summary>
+ public static BFloat16 FromRawValue(ushort rawValue) => new BFloat16(rawValue, false);
+
+ // --- Conversion to/from float ---
+
+ private static unsafe ushort FloatToBFloat16Bits(float f)
+ {
+ uint bits = *(uint*)&f;
+ // NaN: preserve payload, just truncate
+ if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0)
+ return (ushort)(bits >> 16 | 0x0040u); // quiet NaN
+ // Round-to-nearest-even (matching PyTorch c10::BFloat16)
+ uint lsb = (bits >> 16) & 1u;
+ uint roundingBias = 0x7FFFu + lsb;
+ bits += roundingBias;
+ return (ushort)(bits >> 16);
+ }
+
+ private static unsafe float BFloat16BitsToFloat(ushort raw)
+ {
+ int bits = raw << 16;
+ return *(float*)&bits;
+ }
+
+ /// <summary>
+ /// Converts this BFloat16 to a float.
+ /// </summary>
+ public float ToSingle() => BFloat16BitsToFloat(value);
+
+ // --- Conversion operators ---
+
+ public static explicit operator float(BFloat16 bf) => bf.ToSingle();
+ public static explicit operator double(BFloat16 bf) => bf.ToSingle();
+ public static explicit operator BFloat16(float f) => new BFloat16(f);
+ public static explicit operator BFloat16(double d) => new BFloat16((float)d);
+
+ // --- Arithmetic operators (promote to float, truncate back) ---
+
+ public static BFloat16 operator +(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() + b.ToSingle());
+ public static BFloat16 operator -(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() - b.ToSingle());
+ public static BFloat16 operator *(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() * b.ToSingle());
+ public static BFloat16 operator /(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() / b.ToSingle());
+ public static BFloat16 operator %(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() % b.ToSingle());
+ public static BFloat16 operator -(BFloat16 a) => new BFloat16(-a.ToSingle());
+
+ // --- Comparison operators ---
+
+ public static bool operator ==(BFloat16 a, BFloat16 b) => a.ToSingle() == b.ToSingle();
+ public static bool operator !=(BFloat16 a, BFloat16 b) => a.ToSingle() != b.ToSingle();
+ public static bool operator <(BFloat16 a, BFloat16 b) => a.ToSingle() < b.ToSingle();
+ public static bool operator >(BFloat16 a, BFloat16 b) => a.ToSingle() > b.ToSingle();
+ public static bool operator <=(BFloat16 a, BFloat16 b) => a.ToSingle() <= b.ToSingle();
+ public static bool operator >=(BFloat16 a, BFloat16 b) => a.ToSingle() >= b.ToSingle();
+
+ // --- IEquatable / IComparable ---
+
+ public bool Equals(BFloat16 other) => value == other.value;
+ public override bool Equals(object? obj) => obj is BFloat16 other && Equals(other);
+ public override int GetHashCode() => value.GetHashCode();
+
+ public int CompareTo(BFloat16 other) => ToSingle().CompareTo(other.ToSingle());
+ public int CompareTo(object? obj)
+ {
+ if (obj is null) return 1;
+ if (obj is BFloat16 other) return CompareTo(other);
+ throw new ArgumentException("Object must be of type BFloat16.");
+ }
+
+ // --- Formatting ---
+
+ public override string ToString() => ToSingle().ToString();
+ public string ToString(string? format, IFormatProvider? formatProvider) => ToSingle().ToString(format, formatProvider);
+
+ // --- Constants ---
+
+ public static readonly BFloat16 Zero = FromRawValue(0x0000);
+ public static readonly BFloat16 One = FromRawValue(0x3F80);
+ public static readonly BFloat16 NaN = FromRawValue(0x7FC0);
+ public static readonly BFloat16 PositiveInfinity = FromRawValue(0x7F80);
+ public static readonly BFloat16 NegativeInfinity = FromRawValue(0xFF80);
+ public static readonly BFloat16 MaxValue = FromRawValue(0x7F7F); // ~3.39e+38
+ public static readonly BFloat16 MinValue = FromRawValue(0xFF7F); // ~-3.39e+38
+ public static readonly BFloat16 Epsilon = FromRawValue(0x0080); // smallest normal
+ public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001);
+
+ // --- Static helpers ---
+
+ public static bool IsNaN(BFloat16 bf) => float.IsNaN(bf.ToSingle());
+ public static bool IsInfinity(BFloat16 bf) => float.IsInfinity(bf.ToSingle());
+ public static bool IsPositiveInfinity(BFloat16 bf) => float.IsPositiveInfinity(bf.ToSingle());
+ public static bool IsNegativeInfinity(BFloat16 bf) => float.IsNegativeInfinity(bf.ToSingle());
+ public static bool IsFinite(BFloat16 bf) => !IsInfinity(bf) && !IsNaN(bf);
+ }
+}
diff --git a/src/TorchSharp/Scalar.cs b/src/TorchSharp/Scalar.cs
index 972039c0c..610333e68 100644
--- a/src/TorchSharp/Scalar.cs
+++ b/src/TorchSharp/Scalar.cs
@@ -81,6 +81,15 @@ public static implicit operator Scalar(Half value)
}
#endif
+ /// <summary>
+ /// Implicitly convert a BFloat16 value to Scalar
+ /// </summary>
+ /// <param name="value">The scalar value.</param>
+ public static implicit operator Scalar(BFloat16 value)
+ {
+ return value.ToScalar();
+ }
+
///
/// Implicitly convert a .NET scalar value to Scalar
///
@@ -282,6 +291,16 @@ public static Scalar ToBFloat16Scalar(this float value)
return new Scalar(THSTorch_bfloat16_to_scalar(value));
}
+ /// <summary>
+ /// Explicitly construct a Scalar from a BFloat16 value.
+ /// </summary>
+ /// <param name="value">The input scalar value</param>
+ public static Scalar ToScalar(this BFloat16 value)
+ {
+ torch.InitializeDeviceType(DeviceType.CPU);
+ return new Scalar(THSTorch_bfloat16_to_scalar(value.ToSingle()));
+ }
+
#if NET6_0_OR_GREATER
///
/// Explicitly convert a Scalar value to a .NET scalar
@@ -295,6 +314,16 @@ public static Half ToHalf(this Scalar value)
}
#endif
+ /// <summary>
+ /// Explicitly convert a Scalar value to a BFloat16.
+ /// </summary>
+ /// <param name="value">The input value.</param>
+ public static BFloat16 ToBFloat16(this Scalar value)
+ {
+ THSTorch_scalar_to_bfloat16(value.Handle, out ushort res);
+ return BFloat16.FromRawValue(res);
+ }
+
///
/// Explicitly convert a Scalar value to a .NET scalar
///
diff --git a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
index b306c0cd7..572e0da4f 100644
--- a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
+++ b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
@@ -370,6 +370,8 @@ public static Tensor frombuffer(Array rawArray, ScalarType dtype, long count = -
switch (origType) {
case ScalarType.Int16:
+ case ScalarType.Float16:
+ case ScalarType.BFloat16:
offset *= 2;
break;
case ScalarType.Int32:
diff --git a/src/TorchSharp/Tensor/Factories/as_tensor.cs b/src/TorchSharp/Tensor/Factories/as_tensor.cs
index 162eda077..349db243f 100644
--- a/src/TorchSharp/Tensor/Factories/as_tensor.cs
+++ b/src/TorchSharp/Tensor/Factories/as_tensor.cs
@@ -121,5 +121,15 @@ public static Tensor as_tensor(System.Numerics.Complex[] rawArray, torch.ScalarT
{
return torch.from_array(rawArray, dtype, device);
}
+
+ public static Tensor as_tensor(IList<BFloat16> rawArray, torch.ScalarType? dtype = null, torch.Device? device = null)
+ {
+ return torch.from_array(rawArray.ToArray(), dtype, device);
+ }
+
+ public static Tensor as_tensor(BFloat16[] rawArray, torch.ScalarType? dtype = null, torch.Device? device = null)
+ {
+ return torch.from_array(rawArray, dtype, device);
+ }
}
}
diff --git a/src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs b/src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs
new file mode 100644
index 000000000..0ab1e331d
--- /dev/null
+++ b/src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs
@@ -0,0 +1,125 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using System.Collections.Generic;
+using System.Diagnostics.Contracts;
+using System.Linq;
+
+#nullable enable
+namespace TorchSharp
+{
+ public static partial class torch
+ {
+ /// <summary>
+ /// Create a tensor from an array of values, shaping it based on the shape passed in.
+ /// </summary>
+ /// <remarks>The Torch runtime does not take ownership of the data, so there is no device argument.</remarks>
+ [Pure]
+ public static Tensor tensor(IList<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray.ToArray(), dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, false, names);
+ }
+
+ ///
+ /// Create a tensor from an array of values, shaping it based on the shape passed in.
+ ///
+ [Pure]
+ public static Tensor tensor(BFloat16[] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray, stackalloc long[] { rawArray.LongLength }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a tensor from an array of values, shaping it based on the shape passed in.
+ ///
+ [Pure]
+ public static Tensor tensor(BFloat16[] rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a 1-D tensor from an array of values, shaping it based on the input array.
+ ///
+ /// The Torch runtime does not take ownership of the data, so there is no device argument.
+ [Pure]
+ public static Tensor tensor(IList<BFloat16> rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return tensor(rawArray, stackalloc long[] { (long)rawArray.Count }, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a tensor from an array of values, organizing it as a two-dimensional tensor.
+ ///
+ ///
+ /// The Torch runtime does not take ownership of the data, so there is no device argument.
+ /// The input array must have rows * columns elements.
+ ///
+ [Pure]
+ public static Tensor tensor(IList<BFloat16> rawArray, long rows, long columns, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return tensor(rawArray, stackalloc long[] { rows, columns }, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a tensor from an array of values, organizing it as a three-dimensional tensor.
+ ///
+ ///
+ /// The Torch runtime does not take ownership of the data, so there is no device argument.
+ /// The input array must have dim0*dim1*dim2 elements.
+ ///
+ [Pure]
+ public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return tensor(rawArray, stackalloc long[] { dim0, dim1, dim2 }, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a tensor from an array of values, organizing it as a four-dimensional tensor.
+ ///
+ ///
+ /// The Torch runtime does not take ownership of the data, so there is no device argument.
+ /// The input array must have dim0*dim1*dim2*dim3 elements.
+ ///
+ [Pure]
+ public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, long dim3, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return tensor(rawArray, stackalloc long[] { dim0, dim1, dim2, dim3 }, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a two-dimensional tensor from a two-dimensional array of values.
+ ///
+ [Pure]
+ public static Tensor tensor(BFloat16[,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a three-dimensional tensor from a three-dimensional array of values.
+ ///
+ [Pure]
+ public static Tensor tensor(BFloat16[,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a four-dimensional tensor from a four-dimensional array of values.
+ ///
+ [Pure]
+ public static Tensor tensor(BFloat16[,,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
+ }
+
+ ///
+ /// Create a tensor from an array of values, shaping it based on the shape passed in.
+ ///
+ [Pure]
+ public static Tensor tensor(Memory<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
+ {
+ return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
+ }
+ }
+}
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index ea70b83e1..59a31a551 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -475,7 +475,9 @@ internal void ValidateType(Type dotnetType)
throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
break;
case ScalarType.BFloat16:
- throw new ArgumentException($"No support for {dtype.ToString()} in TorchSharp");
+ if (dotnetType != typeof(BFloat16))
+ throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
+ break;
case ScalarType.Float16:
#if NET6_0_OR_GREATER
if (dotnetType != typeof(Half))
@@ -6886,6 +6888,14 @@ private static string ToCSharpString(Tensor t, long mdim, bool isFCreate, string
case ScalarType.Float64:
if (top) sb.Append("double ");
break;
+ case ScalarType.BFloat16:
+ if (top) sb.Append("bfloat16 ");
+ appendChar = "f";
+ break;
+ case ScalarType.Float16:
+ if (top) sb.Append("float16 ");
+ appendChar = "f";
+ break;
case ScalarType.ComplexFloat32:
if (top) sb.Append("complex32 ");
break;
@@ -7166,6 +7176,7 @@ private static void PrintValue(StringBuilder builder, ScalarType type, Scalar va
case ScalarType.Bool:
builder.Append(value.ToBoolean().ToString(cultureInfo));
break;
+ case ScalarType.BFloat16:
case ScalarType.Float16:
builder.Append(value.ToSingle().ToString(fltFormat, cultureInfo));
break;
@@ -7462,6 +7473,7 @@ public enum ScalarType : sbyte
{ typeof(short), ScalarType.Int16 },
{ typeof(int), ScalarType.Int32 },
{ typeof(long), ScalarType.Int64 },
+ { typeof(BFloat16), ScalarType.BFloat16 },
#if NET6_0_OR_GREATER
{ typeof(Half), ScalarType.Float16 },
#endif
diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
index cc4fc6b3c..d58b66624 100644
--- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs
+++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -578,6 +578,8 @@ true when typeof(T) == typeof(float) => tensor((array as float[])!, dimensions,
requires_grad: requires_grad),
true when typeof(T) == typeof(bool) => tensor((array as bool[])!, dimensions,
requires_grad: requires_grad),
+ true when typeof(T) == typeof(BFloat16) => tensor((array as BFloat16[])!, dimensions,
+ requires_grad: requires_grad),
_ => throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.")
};
}
@@ -592,7 +594,8 @@ true when typeof(T) == typeof(bool) => tensor((array as bool[])!, dimensions,
///
public static Tensor ToTensor(this T scalar, Device? device = null, bool requires_grad = false) where T : struct
{
- if (requires_grad && typeof(T) != typeof(float) && typeof(T) != typeof(double)) {
+ if (requires_grad && typeof(T) != typeof(float) && typeof(T) != typeof(double)
+ && typeof(T) != typeof(BFloat16)) {
throw new ArgumentException(nameof(requires_grad), "Only floating point types support gradients.");
}
@@ -610,6 +613,8 @@ public static Tensor ToTensor(this T scalar, Device? device = null, bool requ
return tensor((float)(object)scalar, float32, device, requires_grad);
if (typeof(T) == typeof(double))
return tensor((double)(object)scalar, float64, device, requires_grad);
+ if (typeof(T) == typeof(BFloat16))
+ return tensor(new BFloat16[] { (BFloat16)(object)scalar }, bfloat16, device, requires_grad);
throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
}
@@ -633,6 +638,12 @@ public static Tensor ToTensor(this T scalar, Device? device = null, bool requ
public static Half ToHalf(this Tensor value) => value.ToScalar().ToHalf();
#endif
+ /// <summary>
+ /// Explicitly convert a singleton tensor to a BFloat16 value.
+ /// </summary>
+ /// <param name="value">The input tensor</param>
+ public static BFloat16 ToBFloat16(this Tensor value) => value.ToScalar().ToBFloat16();
+
///
/// Explicitly convert a singleton tensor to a .NET scalar value.
///
diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs
index bd75ecbb2..294a04935 100644
--- a/test/TorchSharpTest/TestTorchTensor.cs
+++ b/test/TorchSharpTest/TestTorchTensor.cs
@@ -102,6 +102,91 @@ public void ScalarToString()
}
}
+ [Fact]
+ [TestOf(nameof(Tensor.ToString))]
+ public void TestBFloat16ScalarToString()
+ {
+ // Scalar (0-d tensor)
+ {
+ var t = torch.tensor(3.14f, torch.bfloat16);
+ var str = t.jlstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("[], type = BFloat16, device = cpu, value = 3.1406", str);
+ }
+ {
+ var t = torch.tensor(3.14f, torch.bfloat16);
+ var str = t.npstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("3.1406", str);
+ }
+ {
+ var t = torch.tensor(3.14f, torch.bfloat16);
+ var str = t.cstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("[], type = BFloat16, device = cpu, value = 3.1406", str);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.ToString))]
+ public void TestBFloat16TensorToString()
+ {
+ // 1-D tensor
+ {
+ var t = torch.zeros(4, torch.bfloat16);
+ var str = t.ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("[0, 0, 0, 0]", str);
+ }
+ // 1-D Julia
+ {
+ var t = torch.zeros(4, torch.bfloat16);
+ var str = t.jlstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal($"[4], type = BFloat16, device = cpu{Environment.NewLine} 0 0 0 0{Environment.NewLine}", str);
+ }
+ // 1-D CSharp
+ {
+ var t = torch.zeros(4, torch.bfloat16);
+ var str = t.cstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("[4], type = BFloat16, device = cpu, value = bfloat16 [] {0f, 0f, 0f, 0f}", str);
+ }
+ // 2-D tensor
+ {
+ var t = torch.ones(2, 3, torch.bfloat16);
+ var str = t.ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal($"[[1, 1, 1]{Environment.NewLine} [1, 1, 1]]", str);
+ }
+ // print() should not throw
+ {
+ var t = torch.randn(3, 3).to(torch.bfloat16);
+ var originalOut = Console.Out;
+ using (var sw = new StringWriter()) {
+ try {
+ Console.SetOut(sw);
+ t.print(cultureInfo: CultureInfo.InvariantCulture);
+ var result = sw.ToString();
+ Assert.False(string.IsNullOrEmpty(result));
+ } finally {
+ Console.SetOut(originalOut);
+ }
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.ToString))]
+ public void TestFloat16TensorToString()
+ {
+ // 1-D CSharp
+ {
+ var t = torch.zeros(4, torch.float16);
+ var str = t.cstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("[4], type = Float16, device = cpu, value = float16 [] {0f, 0f, 0f, 0f}", str);
+ }
+ // 1-D Numpy
+ {
+ var t = torch.zeros(4, torch.float16);
+ var str = t.npstr(cultureInfo: CultureInfo.InvariantCulture);
+ Assert.Equal("[0, 0, 0, 0]", str);
+ }
+ }
+
private string _sep = Environment.NewLine;
[Fact]
@@ -611,6 +696,154 @@ public void DataFloat16()
#endif
}
+ [Fact]
+ [TestOf(nameof(torch.ones))]
+ public void DataBFloat16()
+ {
+ var x = torch.ones(5, torch.bfloat16);
+ Assert.Throws<ArgumentException>(() => x.data<byte>());
+ Assert.Throws<ArgumentException>(() => x.data<sbyte>());
+ Assert.Throws<ArgumentException>(() => x.data<short>());
+ Assert.Throws<ArgumentException>(() => x.data<int>());
+ Assert.Throws<ArgumentException>(() => x.data<long>());
+ Assert.Throws<ArgumentException>(() => x.data<bool>());
+ Assert.Throws<ArgumentException>(() => x.data<float>());
+ Assert.Throws<ArgumentException>(() => x.data<double>());
+ Assert.Throws<ArgumentException>(() => x.data<(float, float)>());
+ Assert.Throws<ArgumentException>(() => x.data<System.Numerics.Complex>());
+ var accessor = x.data<BFloat16>();
+ Assert.Equal(5, accessor.Count);
+ Assert.Equal((BFloat16)1.0f, accessor[0]);
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void DataBFloat16Item()
+ {
+ // item on scalar tensor
+ var t = torch.tensor(3.14f, torch.bfloat16);
+ var val = t.item<BFloat16>();
+ Assert.Equal(3.140625f, val.ToSingle());
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void DataBFloat16RoundTrip()
+ {
+ // Create tensor from BFloat16 array, read back
+ var input = new BFloat16[] { (BFloat16)1.0f, (BFloat16)2.0f, (BFloat16)3.0f, (BFloat16)0.5f };
+ var t = torch.tensor(input);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ Assert.Equal(4, t.NumberOfElements);
+ var output = t.data<BFloat16>().ToArray();
+ Assert.Equal(input, output);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.tensor))]
+ public void MDTensorFactoryBFloat16()
+ {
+ {
+ var array = new BFloat16[8];
+ var t = torch.tensor(array);
+ Assert.Equal(1, t.ndim);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ }
+
+ {
+ var array = new BFloat16[8];
+ var t = torch.tensor(array, new long[] { 8 });
+ Assert.Equal(1, t.ndim);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ }
+
+ {
+ var array = new BFloat16[1, 2];
+ var t = torch.tensor(array);
+ Assert.Equal(2, t.ndim);
+ Assert.Equal(new long[] { 1, 2 }, t.shape);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ }
+
+ {
+ var array = new BFloat16[1, 2, 3];
+ var t = torch.tensor(array);
+ Assert.Equal(3, t.ndim);
+ Assert.Equal(new long[] { 1, 2, 3 }, t.shape);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ }
+
+ {
+ var array = new BFloat16[1, 2, 3, 4];
+ var t = torch.tensor(array);
+ Assert.Equal(4, t.ndim);
+ Assert.Equal(new long[] { 1, 2, 3, 4 }, t.shape);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ }
+
+ {
+ var array = new BFloat16[,,] { { { (BFloat16)1f, (BFloat16)2f }, { (BFloat16)3f, (BFloat16)4f } }, { { (BFloat16)5f, (BFloat16)6f }, { (BFloat16)7f, (BFloat16)8f } } };
+ var t = torch.tensor(array);
+ Assert.Equal(3, t.ndim);
+ Assert.Equal(new long[] { 2, 2, 2 }, t.shape);
+ Assert.Equal(ScalarType.BFloat16, t.dtype);
+ Assert.Equal(array.Cast<BFloat16>().ToArray(), t.data<BFloat16>().ToArray());
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(TensorExtensionMethods.ToBFloat16))]
+ public void TensorToBFloat16Extension()
+ {
+ var t = torch.tensor(3.14f, torch.bfloat16);
+ var bf = t.ToBFloat16();
+ Assert.Equal(3.140625f, bf.ToSingle());
+ }
+
+ [Fact]
+ [TestOf(nameof(BFloat16))]
+ public void BFloat16StructBasics()
+ {
+ // Size must be 2 bytes
+ Assert.Equal(2, Marshal.SizeOf<BFloat16>());
+
+ // Precision loss: 333.0f -> BFloat16 -> float (matching PyTorch)
+ var bf = (BFloat16)333.0f;
+ Assert.Equal(332.0f, bf.ToSingle());
+
+ // Round-to-nearest-even: 3.14f → BFloat16 matches PyTorch native conversion
+ var bf_rne = (BFloat16)3.14f;
+ Assert.Equal(3.140625f, bf_rne.ToSingle());
+
+ // Round-trip for a value that fits exactly
+ var bf2 = (BFloat16)1.0f;
+ Assert.Equal(1.0f, bf2.ToSingle());
+
+ // Special values
+ Assert.True(BFloat16.IsNaN(BFloat16.NaN));
+ Assert.True(BFloat16.IsPositiveInfinity(BFloat16.PositiveInfinity));
+ Assert.True(BFloat16.IsNegativeInfinity(BFloat16.NegativeInfinity));
+ Assert.True(BFloat16.IsFinite(BFloat16.One));
+ Assert.False(BFloat16.IsFinite(BFloat16.NaN));
+
+ // Arithmetic
+ var a = (BFloat16)2.0f;
+ var b = (BFloat16)3.0f;
+ Assert.Equal(5.0f, (a + b).ToSingle());
+ Assert.Equal(-1.0f, (a - b).ToSingle());
+ Assert.Equal(6.0f, (a * b).ToSingle());
+
+ // Comparison
+ Assert.True(a < b);
+ Assert.True(b > a);
+ Assert.True(a == (BFloat16)2.0f);
+ Assert.True(a != b);
+
+ // Equality
+ Assert.Equal((BFloat16)1.5f, (BFloat16)1.5f);
+ Assert.NotEqual((BFloat16)1.5f, (BFloat16)2.0f);
+ }
+
[Fact]
[TestOf(nameof(torch.ones))]
public void DataFloat32()