From 2c11a4c680420a45b5927c6cce7019bb756dae12 Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Fri, 13 Feb 2026 15:31:49 +0100 Subject: [PATCH 1/5] Fix bfloat16 tensor printing (issue #1469) Added BFloat16 case to PrintValue() and ToCSharpString() so that bfloat16 (and float16) tensors can be printed in all string styles. --- src/TorchSharp/Tensor/Tensor.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index ea70b83e1..5b7266051 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -6886,6 +6886,14 @@ private static string ToCSharpString(Tensor t, long mdim, bool isFCreate, string case ScalarType.Float64: if (top) sb.Append("double "); break; + case ScalarType.BFloat16: + if (top) sb.Append("bfloat16 "); + appendChar = "f"; + break; + case ScalarType.Float16: + if (top) sb.Append("float16 "); + appendChar = "f"; + break; case ScalarType.ComplexFloat32: if (top) sb.Append("complex32 "); break; @@ -7166,6 +7174,7 @@ private static void PrintValue(StringBuilder builder, ScalarType type, Scalar va case ScalarType.Bool: builder.Append(value.ToBoolean().ToString(cultureInfo)); break; + case ScalarType.BFloat16: case ScalarType.Float16: builder.Append(value.ToSingle().ToString(fltFormat, cultureInfo)); break; From fb459e7bcf030425664428dc9a949ea827203b6f Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Mon, 16 Feb 2026 10:57:20 +0100 Subject: [PATCH 2/5] Add tests for BFloat16 and Float16 tensor printing Add test coverage for the bfloat16/float16 printing fix: - TestBFloat16ScalarToString: scalar tensor in jlstr, npstr, cstr formats - TestBFloat16TensorToString: 1-D and 2-D tensors in numpy, julia, csharp formats plus print() - TestFloat16TensorToString: 1-D tensors in csharp and numpy formats Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/TorchSharpTest/TestTorchTensor.cs | 85 ++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index bd75ecbb2..7663d9ef0 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -102,6 +102,91 @@ public void ScalarToString() } } + [Fact] + [TestOf(nameof(Tensor.ToString))] + public void TestBFloat16ScalarToString() + { + // Scalar (0-d tensor) + { + var t = torch.tensor(3.14f, torch.bfloat16); + var str = t.jlstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("[], type = BFloat16, device = cpu, value = 3.1406", str); + } + { + var t = torch.tensor(3.14f, torch.bfloat16); + var str = t.npstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("3.1406", str); + } + { + var t = torch.tensor(3.14f, torch.bfloat16); + var str = t.cstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("[], type = BFloat16, device = cpu, value = 3.1406", str); + } + } + + [Fact] + [TestOf(nameof(Tensor.ToString))] + public void TestBFloat16TensorToString() + { + // 1-D tensor + { + var t = torch.zeros(4, torch.bfloat16); + var str = t.ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("[0, 0, 0, 0]", str); + } + // 1-D Julia + { + var t = torch.zeros(4, torch.bfloat16); + var str = t.jlstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal($"[4], type = BFloat16, device = cpu{Environment.NewLine} 0 0 0 0{Environment.NewLine}", str); + } + // 1-D CSharp + { + var t = torch.zeros(4, torch.bfloat16); + var str = t.cstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("[4], type = BFloat16, device = cpu, value = bfloat16 [] {0f, 0f, 0f, 0f}", str); + } + // 2-D tensor + { + var t = torch.ones(2, 3, torch.bfloat16); + var str = t.ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal($"[[1, 1, 1]{Environment.NewLine} [1, 1, 1]]", str); + } + // print() should not throw + { + var t = torch.randn(3, 3).to(torch.bfloat16); + var originalOut = Console.Out; + using (var sw = new StringWriter()) { + try { + Console.SetOut(sw); + t.print(cultureInfo: CultureInfo.InvariantCulture); + var result = sw.ToString(); + Assert.False(string.IsNullOrEmpty(result)); + } finally { + Console.SetOut(originalOut); + } + } + } + } + + [Fact] + [TestOf(nameof(Tensor.ToString))] + public void TestFloat16TensorToString() + { + // 1-D CSharp + { + var t = torch.zeros(4, torch.float16); + var str = t.cstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("[4], type = Float16, device = cpu, value = float16 [] {0f, 0f, 0f, 0f}", str); + } + // 1-D Numpy + { + var t = torch.zeros(4, torch.float16); + var str = t.npstr(cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal("[0, 0, 0, 0]", str); + } + } + private string _sep = Environment.NewLine; [Fact] From 388a93ba0d07d011acce75b00cdc2b9149186cbe Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Wed, 25 Feb 2026 12:11:52 +0100 Subject: [PATCH 3/5] Add standalone BFloat16 test scripts Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- TestBFloat.cs | 22 ++++++++++++++++++++++ test_bfloat.csx | 9 +++++++++ 2 files changed, 31 insertions(+) create mode 100644 TestBFloat.cs create mode 100644 test_bfloat.csx diff --git a/TestBFloat.cs b/TestBFloat.cs new file mode 100644 index 000000000..294101c3c --- /dev/null +++ b/TestBFloat.cs @@ -0,0 +1,22 @@ +using System; +using System.Globalization; +using TorchSharp; + +class Program +{ + static void Main() + { + var t = torch.tensor(3.14f, torch.bfloat16); + Console.WriteLine($"BFloat16 tensor value: {t.item()}"); + Console.WriteLine($"Julia string: {t.jlstr(cultureInfo: CultureInfo.InvariantCulture)}"); + Console.WriteLine($"Numpy string: {t.npstr(cultureInfo: CultureInfo.InvariantCulture)}"); + Console.WriteLine($"CSharp string: {t.cstr(cultureInfo: CultureInfo.InvariantCulture)}"); + + // Test a few more values to understand precision + var t2 = torch.tensor(1.0f, torch.bfloat16); + Console.WriteLine($"\n1.0 as bfloat16: {t2.item()}"); + + var t3 = torch.tensor(0.1f, torch.bfloat16); + Console.WriteLine($"0.1 as bfloat16: {t3.item()}"); + } +} diff --git a/test_bfloat.csx b/test_bfloat.csx new file mode 100644 index 000000000..12a284653 --- /dev/null +++ b/test_bfloat.csx @@ -0,0 +1,9 @@ +using System; +using System.Globalization; +using TorchSharp; + +var t = torch.tensor(3.14f, torch.bfloat16); +Console.WriteLine($"BFloat16 tensor value: {t.item()}"); +Console.WriteLine($"Julia string: {t.jlstr(cultureInfo: CultureInfo.InvariantCulture)}"); +Console.WriteLine($"Numpy string: {t.npstr(cultureInfo: CultureInfo.InvariantCulture)}"); +Console.WriteLine($"CSharp string: {t.cstr(cultureInfo: CultureInfo.InvariantCulture)}"); From 6beec00a0fcabae786004ca338480a4513d1a483 Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Wed, 25 Feb 2026 12:13:31 +0100 Subject: [PATCH 4/5] Remove standalone BFloat16 test scripts Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- TestBFloat.cs | 22 ---------------------- test_bfloat.csx | 9 --------- 2 files changed, 31 deletions(-) delete mode 100644 TestBFloat.cs delete mode 100644 test_bfloat.csx diff --git a/TestBFloat.cs b/TestBFloat.cs deleted file mode 100644 index 294101c3c..000000000 --- a/TestBFloat.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; -using System.Globalization; -using TorchSharp; - -class Program -{ - static void Main() - { - var t = torch.tensor(3.14f, torch.bfloat16); - Console.WriteLine($"BFloat16 tensor value: {t.item()}"); - Console.WriteLine($"Julia string: {t.jlstr(cultureInfo: CultureInfo.InvariantCulture)}"); - Console.WriteLine($"Numpy string: {t.npstr(cultureInfo: CultureInfo.InvariantCulture)}"); - Console.WriteLine($"CSharp string: {t.cstr(cultureInfo: CultureInfo.InvariantCulture)}"); - - // Test a few more values to understand precision - var t2 = torch.tensor(1.0f, torch.bfloat16); - Console.WriteLine($"\n1.0 as bfloat16: {t2.item()}"); - - var t3 = torch.tensor(0.1f, torch.bfloat16); - Console.WriteLine($"0.1 as bfloat16: {t3.item()}"); - } -} diff --git a/test_bfloat.csx b/test_bfloat.csx deleted file mode 100644 index 12a284653..000000000 --- a/test_bfloat.csx +++ /dev/null @@ -1,9 +0,0 @@ -using System; -using System.Globalization; -using TorchSharp; - -var t = torch.tensor(3.14f, torch.bfloat16); -Console.WriteLine($"BFloat16 tensor value: {t.item()}"); -Console.WriteLine($"Julia string: {t.jlstr(cultureInfo: CultureInfo.InvariantCulture)}"); -Console.WriteLine($"Numpy string: {t.npstr(cultureInfo: CultureInfo.InvariantCulture)}"); -Console.WriteLine($"CSharp string: {t.cstr(cultureInfo: CultureInfo.InvariantCulture)}"); From bd8ded682c8fcb684c8386d6a8a6ed373ded8524 Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Wed, 25 Feb 2026 13:42:59 +0100 Subject: [PATCH 5/5] Add full BFloat16 managed type support Implement a custom BFloat16 struct (2 bytes, binary-compatible with c10::BFloat16) with round-to-nearest-even conversion matching PyTorch. New features: - BFloat16 struct with arithmetic, comparison, and conversion operators - data() and item() for zero-copy tensor data access - torch.tensor(BFloat16[], ...) factory overloads for 1D-4D arrays - Scalar implicit conversion and ToBFloat16() extraction - ToBFloat16() tensor extension method - ToTensor support for BFloat16 - as_tensor() BFloat16 overloads - frombuffer offset fix for BFloat16 and Float16 (2-byte types) Fixes #1469 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/TorchSharp/BFloat16.cs | 127 +++++++++++++++ src/TorchSharp/Scalar.cs | 29 ++++ .../Tensor/Factories/Tensor.Factories.cs | 2 + src/TorchSharp/Tensor/Factories/as_tensor.cs | 10 ++ .../Tensor/Factories/tensor_BFloat16.cs | 125 +++++++++++++++ src/TorchSharp/Tensor/Tensor.cs | 5 +- .../Tensor/TensorExtensionMethods.cs | 13 +- test/TorchSharpTest/TestTorchTensor.cs | 148 ++++++++++++++++++ 8 files changed, 457 insertions(+), 2 deletions(-) create mode 100644 src/TorchSharp/BFloat16.cs create mode 100644 src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs diff --git a/src/TorchSharp/BFloat16.cs b/src/TorchSharp/BFloat16.cs new file mode 100644 index 000000000..3c96eb130 --- /dev/null +++ b/src/TorchSharp/BFloat16.cs @@ -0,0 +1,127 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Globalization; +using System.Runtime.InteropServices; + +#nullable enable +namespace TorchSharp +{ + /// + /// Represents a 16-bit brain floating-point number (BFloat16). + /// Binary layout: 1 sign bit, 8 exponent bits, 7 mantissa bits — the upper 16 bits of IEEE 754 float32. + /// Binary-compatible with c10::BFloat16 in LibTorch. + /// + [StructLayout(LayoutKind.Sequential)] + public readonly struct BFloat16 : IComparable, IEquatable, IComparable, IFormattable + { + internal readonly ushort value; + + internal BFloat16(ushort rawValue, bool _) + { + value = rawValue; + } + + /// + /// Creates a BFloat16 from a float value using round-to-nearest-even (matching PyTorch c10::BFloat16). + /// + public BFloat16(float f) + { + value = FloatToBFloat16Bits(f); + } + + /// + /// Creates a BFloat16 from the raw 16-bit representation. + /// + public static BFloat16 FromRawValue(ushort rawValue) => new BFloat16(rawValue, false); + + // --- Conversion to/from float --- + + private static unsafe ushort FloatToBFloat16Bits(float f) + { + uint bits = *(uint*)&f; + // NaN: preserve payload, just truncate + if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0) + return (ushort)(bits >> 16 | 0x0040u); // quiet NaN + // Round-to-nearest-even (matching PyTorch c10::BFloat16) + uint lsb = (bits >> 16) & 1u; + uint roundingBias = 0x7FFFu + lsb; + bits += roundingBias; + return (ushort)(bits >> 16); + } + + private static unsafe float BFloat16BitsToFloat(ushort raw) + { + int bits = raw << 16; + return *(float*)&bits; + } + + /// + /// Converts this BFloat16 to a float. + /// + public float ToSingle() => BFloat16BitsToFloat(value); + + // --- Conversion operators --- + + public static explicit operator float(BFloat16 bf) => bf.ToSingle(); + public static explicit operator double(BFloat16 bf) => bf.ToSingle(); + public static explicit operator BFloat16(float f) => new BFloat16(f); + public static explicit operator BFloat16(double d) => new BFloat16((float)d); + + // --- Arithmetic operators (promote to float, truncate back) --- + + public static BFloat16 operator +(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() + b.ToSingle()); + public static BFloat16 operator -(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() - b.ToSingle()); + public static BFloat16 operator *(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() * b.ToSingle()); + public static BFloat16 operator /(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() / b.ToSingle()); + public static BFloat16 operator %(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() % b.ToSingle()); + public static BFloat16 operator -(BFloat16 a) => new BFloat16(-a.ToSingle()); + + // --- Comparison operators --- + + public static bool operator ==(BFloat16 a, BFloat16 b) => a.ToSingle() == b.ToSingle(); + public static bool operator !=(BFloat16 a, BFloat16 b) => a.ToSingle() != b.ToSingle(); + public static bool operator <(BFloat16 a, BFloat16 b) => a.ToSingle() < b.ToSingle(); + public static bool operator >(BFloat16 a, BFloat16 b) => a.ToSingle() > b.ToSingle(); + public static bool operator <=(BFloat16 a, BFloat16 b) => a.ToSingle() <= b.ToSingle(); + public static bool operator >=(BFloat16 a, BFloat16 b) => a.ToSingle() >= b.ToSingle(); + + // --- IEquatable / IComparable --- + + public bool Equals(BFloat16 other) => value == other.value; + public override bool Equals(object? obj) => obj is BFloat16 other && Equals(other); + public override int GetHashCode() => value.GetHashCode(); + + public int CompareTo(BFloat16 other) => ToSingle().CompareTo(other.ToSingle()); + public int CompareTo(object? obj) + { + if (obj is null) return 1; + if (obj is BFloat16 other) return CompareTo(other); + throw new ArgumentException("Object must be of type BFloat16."); + } + + // --- Formatting --- + + public override string ToString() => ToSingle().ToString(); + public string ToString(string? format, IFormatProvider? formatProvider) => ToSingle().ToString(format, formatProvider); + + // --- Constants --- + + public static readonly BFloat16 Zero = FromRawValue(0x0000); + public static readonly BFloat16 One = FromRawValue(0x3F80); + public static readonly BFloat16 NaN = FromRawValue(0x7FC0); + public static readonly BFloat16 PositiveInfinity = FromRawValue(0x7F80); + public static readonly BFloat16 NegativeInfinity = FromRawValue(0xFF80); + public static readonly BFloat16 MaxValue = FromRawValue(0x7F7F); // ~3.39e+38 + public static readonly BFloat16 MinValue = FromRawValue(0xFF7F); // ~-3.39e+38 + public static readonly BFloat16 Epsilon = FromRawValue(0x0080); // smallest normal + public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001); + + // --- Static helpers --- + + public static bool IsNaN(BFloat16 bf) => float.IsNaN(bf.ToSingle()); + public static bool IsInfinity(BFloat16 bf) => float.IsInfinity(bf.ToSingle()); + public static bool IsPositiveInfinity(BFloat16 bf) => float.IsPositiveInfinity(bf.ToSingle()); + public static bool IsNegativeInfinity(BFloat16 bf) => float.IsNegativeInfinity(bf.ToSingle()); + public static bool IsFinite(BFloat16 bf) => !IsInfinity(bf) && !IsNaN(bf); + } +} diff --git a/src/TorchSharp/Scalar.cs b/src/TorchSharp/Scalar.cs index 972039c0c..610333e68 100644 --- a/src/TorchSharp/Scalar.cs +++ b/src/TorchSharp/Scalar.cs @@ -81,6 +81,15 @@ public static implicit operator Scalar(Half value) } #endif + /// + /// Implicitly convert a BFloat16 value to Scalar + /// + /// The scalar value. + public static implicit operator Scalar(BFloat16 value) + { + return value.ToScalar(); + } + /// /// Implicitly convert a .NET scalar value to Scalar /// @@ -282,6 +291,16 @@ public static Scalar ToBFloat16Scalar(this float value) return new Scalar(THSTorch_bfloat16_to_scalar(value)); } + /// + /// Explicitly construct a Scalar from a BFloat16 value. + /// + /// The input scalar value + public static Scalar ToScalar(this BFloat16 value) + { + torch.InitializeDeviceType(DeviceType.CPU); + return new Scalar(THSTorch_bfloat16_to_scalar(value.ToSingle())); + } + #if NET6_0_OR_GREATER /// /// Explicitly convert a Scalar value to a .NET scalar @@ -295,6 +314,16 @@ public static Half ToHalf(this Scalar value) } #endif + /// + /// Explicitly convert a Scalar value to a BFloat16. + /// + /// The input value. + public static BFloat16 ToBFloat16(this Scalar value) + { + THSTorch_scalar_to_bfloat16(value.Handle, out ushort res); + return BFloat16.FromRawValue(res); + } + /// /// Explicitly convert a Scalar value to a .NET scalar /// diff --git a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs index b306c0cd7..572e0da4f 100644 --- a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs +++ b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs @@ -370,6 +370,8 @@ public static Tensor frombuffer(Array rawArray, ScalarType dtype, long count = - switch (origType) { case ScalarType.Int16: + case ScalarType.Float16: + case ScalarType.BFloat16: offset *= 2; break; case ScalarType.Int32: diff --git a/src/TorchSharp/Tensor/Factories/as_tensor.cs b/src/TorchSharp/Tensor/Factories/as_tensor.cs index 162eda077..349db243f 100644 --- a/src/TorchSharp/Tensor/Factories/as_tensor.cs +++ b/src/TorchSharp/Tensor/Factories/as_tensor.cs @@ -121,5 +121,15 @@ public static Tensor as_tensor(System.Numerics.Complex[] rawArray, torch.ScalarT { return torch.from_array(rawArray, dtype, device); } + + public static Tensor as_tensor(IList rawArray, torch.ScalarType? dtype = null, torch.Device? device = null) + { + return torch.from_array(rawArray.ToArray(), dtype, device); + } + + public static Tensor as_tensor(BFloat16[] rawArray, torch.ScalarType? dtype = null, torch.Device? device = null) + { + return torch.from_array(rawArray, dtype, device); + } } } diff --git a/src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs b/src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs new file mode 100644 index 000000000..0ab1e331d --- /dev/null +++ b/src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs @@ -0,0 +1,125 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Linq; + +#nullable enable +namespace TorchSharp +{ + public static partial class torch + { + /// + /// Create a tensor from an array of values, shaping it based on the shape passed in. + /// + /// The Torch runtime does not take ownership of the data, so there is no device argument. + [Pure] + public static Tensor tensor(IList rawArray, ReadOnlySpan dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray.ToArray(), dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, false, names); + } + + /// + /// Create a tensor from an array of values, shaping it based on the shape passed in. + /// + [Pure] + public static Tensor tensor(BFloat16[] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray, stackalloc long[] { rawArray.LongLength }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names); + } + + /// + /// Create a tensor from an array of values, shaping it based on the shape passed in. + /// + [Pure] + public static Tensor tensor(BFloat16[] rawArray, ReadOnlySpan dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names); + } + + /// + /// Create a 1-D tensor from an array of values, shaping it based on the input array. + /// + /// The Torch runtime does not take ownership of the data, so there is no device argument. + [Pure] + public static Tensor tensor(IList rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return tensor(rawArray, stackalloc long[] { (long)rawArray.Count }, dtype, device, requires_grad, names: names); + } + + /// + /// Create a tensor from an array of values, organizing it as a two-dimensional tensor. + /// + /// + /// The Torch runtime does not take ownership of the data, so there is no device argument. + /// The input array must have rows * columns elements. + /// + [Pure] + public static Tensor tensor(IList rawArray, long rows, long columns, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return tensor(rawArray, stackalloc long[] { rows, columns }, dtype, device, requires_grad, names: names); + } + + /// + /// Create a tensor from an array of values, organizing it as a three-dimensional tensor. + /// + /// + /// The Torch runtime does not take ownership of the data, so there is no device argument. + /// The input array must have dim0*dim1*dim2 elements. + /// + [Pure] + public static Tensor tensor(IList rawArray, long dim0, long dim1, long dim2, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return tensor(rawArray, stackalloc long[] { dim0, dim1, dim2 }, dtype, device, requires_grad, names: names); + } + + /// + /// Create a tensor from an array of values, organizing it as a four-dimensional tensor. + /// + /// + /// The Torch runtime does not take ownership of the data, so there is no device argument. + /// The input array must have dim0*dim1*dim2*dim3 elements. + /// + [Pure] + public static Tensor tensor(IList rawArray, long dim0, long dim1, long dim2, long dim3, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return tensor(rawArray, stackalloc long[] { dim0, dim1, dim2, dim3 }, dtype, device, requires_grad, names: names); + } + + /// + /// Create a two-dimensional tensor from a two-dimensional array of values. + /// + [Pure] + public static Tensor tensor(BFloat16[,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names); + } + + /// + /// Create a three-dimensional tensor from a three-dimensional array of values. + /// + [Pure] + public static Tensor tensor(BFloat16[,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names); + } + + /// + /// Create a four-dimensional tensor from a four-dimensional array of values. + /// + [Pure] + public static Tensor tensor(BFloat16[,,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names); + } + + /// + /// Create a tensor from an array of values, shaping it based on the shape passed in. + /// + [Pure] + public static Tensor tensor(Memory rawArray, ReadOnlySpan dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null) + { + return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names); + } + } +} diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 5b7266051..59a31a551 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -475,7 +475,9 @@ internal void ValidateType(Type dotnetType) throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}"); break; case ScalarType.BFloat16: - throw new ArgumentException($"No support for {dtype.ToString()} in TorchSharp"); + if (dotnetType != typeof(BFloat16)) + throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}"); + break; case ScalarType.Float16: #if NET6_0_OR_GREATER if (dotnetType != typeof(Half)) @@ -7471,6 +7473,7 @@ public enum ScalarType : sbyte { typeof(short), ScalarType.Int16 }, { typeof(int), ScalarType.Int32 }, { typeof(long), ScalarType.Int64 }, + { typeof(BFloat16), ScalarType.BFloat16 }, #if NET6_0_OR_GREATER { typeof(Half), ScalarType.Float16 }, #endif diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs index cc4fc6b3c..d58b66624 100644 --- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs +++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs @@ -578,6 +578,8 @@ true when typeof(T) == typeof(float) => tensor((array as float[])!, dimensions, requires_grad: requires_grad), true when typeof(T) == typeof(bool) => tensor((array as bool[])!, dimensions, requires_grad: requires_grad), + true when typeof(T) == typeof(BFloat16) => tensor((array as BFloat16[])!, dimensions, + requires_grad: requires_grad), _ => throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.") }; } @@ -592,7 +594,8 @@ true when typeof(T) == typeof(bool) => tensor((array as bool[])!, dimensions, /// public static Tensor ToTensor(this T scalar, Device? device = null, bool requires_grad = false) where T : struct { - if (requires_grad && typeof(T) != typeof(float) && typeof(T) != typeof(double)) { + if (requires_grad && typeof(T) != typeof(float) && typeof(T) != typeof(double) + && typeof(T) != typeof(BFloat16)) { throw new ArgumentException(nameof(requires_grad), "Only floating point types support gradients."); } @@ -610,6 +613,8 @@ public static Tensor ToTensor(this T scalar, Device? device = null, bool requ return tensor((float)(object)scalar, float32, device, requires_grad); if (typeof(T) == typeof(double)) return tensor((double)(object)scalar, float64, device, requires_grad); + if (typeof(T) == typeof(BFloat16)) + return tensor(new BFloat16[] { (BFloat16)(object)scalar }, bfloat16, device, requires_grad); throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported."); } @@ -633,6 +638,12 @@ public static Tensor ToTensor(this T scalar, Device? device = null, bool requ public static Half ToHalf(this Tensor value) => value.ToScalar().ToHalf(); #endif + /// + /// Explicitly convert a singleton tensor to a BFloat16 value. + /// + /// The input tensor + public static BFloat16 ToBFloat16(this Tensor value) => value.ToScalar().ToBFloat16(); + /// /// Explicitly convert a singleton tensor to a .NET scalar value. /// diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index 7663d9ef0..294a04935 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -696,6 +696,154 @@ public void DataFloat16() #endif } + [Fact] + [TestOf(nameof(torch.ones))] + public void DataBFloat16() + { + var x = torch.ones(5, torch.bfloat16); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data()); + Assert.Throws(() => x.data<(float, float)>()); + Assert.Throws(() => x.data()); + var accessor = x.data(); + Assert.Equal(5, accessor.Count); + Assert.Equal((BFloat16)1.0f, accessor[0]); + } + + [Fact] + [TestOf(nameof(Tensor))] + public void DataBFloat16Item() + { + // item on scalar tensor + var t = torch.tensor(3.14f, torch.bfloat16); + var val = t.item(); + Assert.Equal(3.140625f, val.ToSingle()); + } + + [Fact] + [TestOf(nameof(Tensor))] + public void DataBFloat16RoundTrip() + { + // Create tensor from BFloat16 array, read back + var input = new BFloat16[] { (BFloat16)1.0f, (BFloat16)2.0f, (BFloat16)3.0f, (BFloat16)0.5f }; + var t = torch.tensor(input); + Assert.Equal(ScalarType.BFloat16, t.dtype); + Assert.Equal(4, t.NumberOfElements); + var output = t.data().ToArray(); + Assert.Equal(input, output); + } + + [Fact] + [TestOf(nameof(torch.tensor))] + public void MDTensorFactoryBFloat16() + { + { + var array = new BFloat16[8]; + var t = torch.tensor(array); + Assert.Equal(1, t.ndim); + Assert.Equal(ScalarType.BFloat16, t.dtype); + } + + { + var array = new BFloat16[8]; + var t = torch.tensor(array, new long[] { 8 }); + Assert.Equal(1, t.ndim); + Assert.Equal(ScalarType.BFloat16, t.dtype); + } + + { + var array = new BFloat16[1, 2]; + var t = torch.tensor(array); + Assert.Equal(2, t.ndim); + Assert.Equal(new long[] { 1, 2 }, t.shape); + Assert.Equal(ScalarType.BFloat16, t.dtype); + } + + { + var array = new BFloat16[1, 2, 3]; + var t = torch.tensor(array); + Assert.Equal(3, t.ndim); + Assert.Equal(new long[] { 1, 2, 3 }, t.shape); + Assert.Equal(ScalarType.BFloat16, t.dtype); + } + + { + var array = new BFloat16[1, 2, 3, 4]; + var t = torch.tensor(array); + Assert.Equal(4, t.ndim); + Assert.Equal(new long[] { 1, 2, 3, 4 }, t.shape); + Assert.Equal(ScalarType.BFloat16, t.dtype); + } + + { + var array = new BFloat16[,,] { { { (BFloat16)1f, (BFloat16)2f }, { (BFloat16)3f, (BFloat16)4f } }, { { (BFloat16)5f, (BFloat16)6f }, { (BFloat16)7f, (BFloat16)8f } } }; + var t = torch.tensor(array); + Assert.Equal(3, t.ndim); + Assert.Equal(new long[] { 2, 2, 2 }, t.shape); + Assert.Equal(ScalarType.BFloat16, t.dtype); + Assert.Equal(array.Cast().ToArray(), t.data().ToArray()); + } + } + + [Fact] + [TestOf(nameof(TensorExtensionMethods.ToBFloat16))] + public void TensorToBFloat16Extension() + { + var t = torch.tensor(3.14f, torch.bfloat16); + var bf = t.ToBFloat16(); + Assert.Equal(3.140625f, bf.ToSingle()); + } + + [Fact] + [TestOf(nameof(BFloat16))] + public void BFloat16StructBasics() + { + // Size must be 2 bytes + Assert.Equal(2, Marshal.SizeOf()); + + // Precision loss: 333.0f -> BFloat16 -> float (matching PyTorch) + var bf = (BFloat16)333.0f; + Assert.Equal(332.0f, bf.ToSingle()); + + // Round-to-nearest-even: 3.14f → BFloat16 matches PyTorch native conversion + var bf_rne = (BFloat16)3.14f; + Assert.Equal(3.140625f, bf_rne.ToSingle()); + + // Round-trip for a value that fits exactly + var bf2 = (BFloat16)1.0f; + Assert.Equal(1.0f, bf2.ToSingle()); + + // Special values + Assert.True(BFloat16.IsNaN(BFloat16.NaN)); + Assert.True(BFloat16.IsPositiveInfinity(BFloat16.PositiveInfinity)); + Assert.True(BFloat16.IsNegativeInfinity(BFloat16.NegativeInfinity)); + Assert.True(BFloat16.IsFinite(BFloat16.One)); + Assert.False(BFloat16.IsFinite(BFloat16.NaN)); + + // Arithmetic + var a = (BFloat16)2.0f; + var b = (BFloat16)3.0f; + Assert.Equal(5.0f, (a + b).ToSingle()); + Assert.Equal(-1.0f, (a - b).ToSingle()); + Assert.Equal(6.0f, (a * b).ToSingle()); + + // Comparison + Assert.True(a < b); + Assert.True(b > a); + Assert.True(a == (BFloat16)2.0f); + Assert.True(a != b); + + // Equality + Assert.Equal((BFloat16)1.5f, (BFloat16)1.5f); + Assert.NotEqual((BFloat16)1.5f, (BFloat16)2.0f); + } + [Fact] [TestOf(nameof(torch.ones))] public void DataFloat32()