Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/TorchSharp/BFloat16.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Globalization;
using System.Runtime.InteropServices;

#nullable enable
namespace TorchSharp
{
    /// <summary>
    /// Represents a 16-bit brain floating-point number (BFloat16).
    /// Binary layout: 1 sign bit, 8 exponent bits, 7 mantissa bits — the upper 16 bits of IEEE 754 float32.
    /// Binary-compatible with c10::BFloat16 in LibTorch.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public readonly struct BFloat16 : IComparable<BFloat16>, IEquatable<BFloat16>, IComparable, IFormattable
    {
        // Raw bit pattern: sign(1) | exponent(8) | mantissa(7).
        internal readonly ushort value;

        // Raw-bits constructor; the unused bool disambiguates it from BFloat16(float).
        internal BFloat16(ushort rawValue, bool _)
        {
            value = rawValue;
        }

        /// <summary>
        /// Creates a BFloat16 from a float value using round-to-nearest-even (matching PyTorch c10::BFloat16).
        /// </summary>
        /// <param name="f">The float value to convert.</param>
        public BFloat16(float f)
        {
            value = FloatToBFloat16Bits(f);
        }

        /// <summary>
        /// Creates a BFloat16 from the raw 16-bit representation.
        /// </summary>
        /// <param name="rawValue">The raw bit pattern (sign | exponent | mantissa).</param>
        public static BFloat16 FromRawValue(ushort rawValue) => new BFloat16(rawValue, false);

        // --- Conversion to/from float ---

        private static unsafe ushort FloatToBFloat16Bits(float f)
        {
            uint bits = *(uint*)&f;
            // NaN: preserve payload, just truncate, and force the quiet bit so the
            // truncated mantissa cannot collapse to an infinity bit pattern.
            if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0)
                return (ushort)(bits >> 16 | 0x0040u); // quiet NaN
            // Round-to-nearest-even (matching PyTorch c10::BFloat16): add 0x7FFF plus
            // the LSB of the retained mantissa, then truncate the low 16 bits.
            uint lsb = (bits >> 16) & 1u;
            uint roundingBias = 0x7FFFu + lsb;
            bits += roundingBias;
            return (ushort)(bits >> 16);
        }

        private static unsafe float BFloat16BitsToFloat(ushort raw)
        {
            // Widening is exact: the 16 bits become the high half of a float32.
            int bits = raw << 16;
            return *(float*)&bits;
        }

        /// <summary>
        /// Converts this BFloat16 to a float. The conversion is exact.
        /// </summary>
        public float ToSingle() => BFloat16BitsToFloat(value);

        // --- Conversion operators ---

        public static explicit operator float(BFloat16 bf) => bf.ToSingle();
        public static explicit operator double(BFloat16 bf) => bf.ToSingle();
        public static explicit operator BFloat16(float f) => new BFloat16(f);
        public static explicit operator BFloat16(double d) => new BFloat16((float)d);

        // --- Arithmetic operators (promote to float, round back) ---

        public static BFloat16 operator +(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() + b.ToSingle());
        public static BFloat16 operator -(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() - b.ToSingle());
        public static BFloat16 operator *(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() * b.ToSingle());
        public static BFloat16 operator /(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() / b.ToSingle());
        public static BFloat16 operator %(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() % b.ToSingle());
        public static BFloat16 operator -(BFloat16 a) => new BFloat16(-a.ToSingle());

        // --- Comparison operators (IEEE semantics, like float: NaN compares false) ---

        public static bool operator ==(BFloat16 a, BFloat16 b) => a.ToSingle() == b.ToSingle();
        public static bool operator !=(BFloat16 a, BFloat16 b) => a.ToSingle() != b.ToSingle();
        public static bool operator <(BFloat16 a, BFloat16 b) => a.ToSingle() < b.ToSingle();
        public static bool operator >(BFloat16 a, BFloat16 b) => a.ToSingle() > b.ToSingle();
        public static bool operator <=(BFloat16 a, BFloat16 b) => a.ToSingle() <= b.ToSingle();
        public static bool operator >=(BFloat16 a, BFloat16 b) => a.ToSingle() >= b.ToSingle();

        // --- IEquatable / IComparable ---

        // Equality follows .NET floating-point Equals semantics (System.Single):
        // NaN.Equals(NaN) is true and +0 equals -0, keeping Equals and GetHashCode
        // consistent with each other and safe for dictionaries/sets. The ==/!=
        // operators follow IEEE semantics instead (NaN == NaN is false), exactly
        // as float and double behave in .NET.
        public bool Equals(BFloat16 other) => ToSingle().Equals(other.ToSingle());
        public override bool Equals(object? obj) => obj is BFloat16 other && Equals(other);
        // float.GetHashCode normalizes NaNs and signed zero, matching Equals above.
        public override int GetHashCode() => ToSingle().GetHashCode();

        public int CompareTo(BFloat16 other) => ToSingle().CompareTo(other.ToSingle());
        public int CompareTo(object? obj)
        {
            if (obj is null) return 1;
            if (obj is BFloat16 other) return CompareTo(other);
            throw new ArgumentException("Object must be of type BFloat16.");
        }

        // --- Formatting ---

        public override string ToString() => ToSingle().ToString();
        public string ToString(string? format, IFormatProvider? formatProvider) => ToSingle().ToString(format, formatProvider);

        // --- Constants ---

        public static readonly BFloat16 Zero = FromRawValue(0x0000);
        public static readonly BFloat16 One = FromRawValue(0x3F80);
        public static readonly BFloat16 NaN = FromRawValue(0x7FC0);
        public static readonly BFloat16 PositiveInfinity = FromRawValue(0x7F80);
        public static readonly BFloat16 NegativeInfinity = FromRawValue(0xFF80);
        public static readonly BFloat16 MaxValue = FromRawValue(0x7F7F); // ~3.39e+38
        public static readonly BFloat16 MinValue = FromRawValue(0xFF7F); // ~-3.39e+38
        public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001); // smallest positive value > 0 (2^-133)
        // Matches the .NET convention (float.Epsilon / double.Epsilon): the smallest
        // positive representable value, i.e. the smallest subnormal.
        public static readonly BFloat16 Epsilon = SmallestSubnormal;
        public static readonly BFloat16 MinNormal = FromRawValue(0x0080); // smallest positive normal value (2^-126)

        // --- Static helpers ---

        public static bool IsNaN(BFloat16 bf) => float.IsNaN(bf.ToSingle());
        public static bool IsInfinity(BFloat16 bf) => float.IsInfinity(bf.ToSingle());
        public static bool IsPositiveInfinity(BFloat16 bf) => float.IsPositiveInfinity(bf.ToSingle());
        public static bool IsNegativeInfinity(BFloat16 bf) => float.IsNegativeInfinity(bf.ToSingle());
        public static bool IsFinite(BFloat16 bf) => !IsInfinity(bf) && !IsNaN(bf);
    }
}
29 changes: 29 additions & 0 deletions src/TorchSharp/Scalar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ public static implicit operator Scalar(Half value)
}
#endif

/// <summary>
/// Implicitly convert a BFloat16 value to Scalar
/// </summary>
/// <param name="value">The scalar value.</param>
public static implicit operator Scalar(BFloat16 value) => value.ToScalar();

/// <summary>
/// Implicitly convert a .NET scalar value to Scalar
/// </summary>
Expand Down Expand Up @@ -282,6 +291,16 @@ public static Scalar ToBFloat16Scalar(this float value)
return new Scalar(THSTorch_bfloat16_to_scalar(value));
}

/// <summary>
/// Explicitly construct a Scalar from a BFloat16 value.
/// </summary>
/// <param name="value">The input scalar value</param>
public static Scalar ToScalar(this BFloat16 value)
{
    // The native backend must be loaded before any Scalar can be created.
    torch.InitializeDeviceType(DeviceType.CPU);
    var single = value.ToSingle();
    return new Scalar(THSTorch_bfloat16_to_scalar(single));
}

#if NET6_0_OR_GREATER
/// <summary>
/// Explicitly convert a Scalar value to a .NET scalar
Expand All @@ -295,6 +314,16 @@ public static Half ToHalf(this Scalar value)
}
#endif

/// <summary>
/// Explicitly convert a Scalar value to a BFloat16.
/// </summary>
/// <param name="value">The input value.</param>
public static BFloat16 ToBFloat16(this Scalar value)
{
    ushort raw;
    THSTorch_scalar_to_bfloat16(value.Handle, out raw);
    return BFloat16.FromRawValue(raw);
}

/// <summary>
/// Explicitly convert a Scalar value to a .NET scalar
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ public static Tensor frombuffer(Array rawArray, ScalarType dtype, long count = -

switch (origType) {
case ScalarType.Int16:
case ScalarType.Float16:
case ScalarType.BFloat16:
offset *= 2;
break;
case ScalarType.Int32:
Expand Down
10 changes: 10 additions & 0 deletions src/TorchSharp/Tensor/Factories/as_tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,15 @@ public static Tensor as_tensor(System.Numerics.Complex[] rawArray, torch.ScalarT
{
return torch.from_array(rawArray, dtype, device);
}

/// <summary>
/// Convert a list of BFloat16 values into a 1-D tensor.
/// </summary>
public static Tensor as_tensor(IList<BFloat16> rawArray, torch.ScalarType? dtype = null, torch.Device? device = null)
{
    var buffer = rawArray.ToArray();
    return torch.from_array(buffer, dtype, device);
}

/// <summary>
/// Convert an array of BFloat16 values into a 1-D tensor.
/// </summary>
public static Tensor as_tensor(BFloat16[] rawArray, torch.ScalarType? dtype = null, torch.Device? device = null)
    => torch.from_array(rawArray, dtype, device);
}
}
125 changes: 125 additions & 0 deletions src/TorchSharp/Tensor/Factories/tensor_BFloat16.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;

#nullable enable
namespace TorchSharp
{
    public static partial class torch
    {
        /// <summary>
        /// Create a tensor from a list of values, shaping it based on the shape passed in.
        /// </summary>
        /// <remarks>The Torch runtime does not take ownership of the data, so there is no device argument.</remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            var buffer = rawArray.ToArray();
            return _tensor_generic(buffer, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, false, names);
        }

        /// <summary>
        /// Create a 1-D tensor from an array of values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { rawArray.LongLength };
            return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a tensor from an array of values, shaping it based on the shape passed in.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[] rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a 1-D tensor from a list of values, shaping it based on the input list.
        /// </summary>
        /// <remarks>The Torch runtime does not take ownership of the data, so there is no device argument.</remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { rawArray.Count };
            return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a tensor from a list of values, organizing it as a two-dimensional tensor.
        /// </summary>
        /// <remarks>
        /// The Torch runtime does not take ownership of the data, so there is no device argument.
        /// The input array must have rows * columns elements.
        /// </remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, long rows, long columns, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { rows, columns };
            return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a tensor from a list of values, organizing it as a three-dimensional tensor.
        /// </summary>
        /// <remarks>
        /// The Torch runtime does not take ownership of the data, so there is no device argument.
        /// The input array must have dim0*dim1*dim2 elements.
        /// </remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { dim0, dim1, dim2 };
            return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a tensor from a list of values, organizing it as a four-dimensional tensor.
        /// </summary>
        /// <remarks>
        /// The Torch runtime does not take ownership of the data, so there is no device argument.
        /// The input array must have dim0*dim1*dim2*dim3 elements.
        /// </remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, long dim3, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { dim0, dim1, dim2, dim3 };
            return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a two-dimensional tensor from a two-dimensional array of values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1) };
            return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a three-dimensional tensor from a three-dimensional array of values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2) };
            return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a four-dimensional tensor from a four-dimensional array of values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[,,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        {
            Span<long> shape = stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) };
            return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
        }

        /// <summary>
        /// Create a tensor from a block of memory, shaping it based on the shape passed in.
        /// </summary>
        [Pure]
        public static Tensor tensor(Memory<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
    }
}
14 changes: 13 additions & 1 deletion src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,9 @@ internal void ValidateType(Type dotnetType)
throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
break;
case ScalarType.BFloat16:
throw new ArgumentException($"No support for {dtype.ToString()} in TorchSharp");
if (dotnetType != typeof(BFloat16))
throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
break;
case ScalarType.Float16:
#if NET6_0_OR_GREATER
if (dotnetType != typeof(Half))
Expand Down Expand Up @@ -6886,6 +6888,14 @@ private static string ToCSharpString(Tensor t, long mdim, bool isFCreate, string
case ScalarType.Float64:
if (top) sb.Append("double ");
break;
case ScalarType.BFloat16:
if (top) sb.Append("bfloat16 ");
appendChar = "f";
break;
case ScalarType.Float16:
if (top) sb.Append("float16 ");
appendChar = "f";
break;
case ScalarType.ComplexFloat32:
if (top) sb.Append("complex32 ");
break;
Expand Down Expand Up @@ -7166,6 +7176,7 @@ private static void PrintValue(StringBuilder builder, ScalarType type, Scalar va
case ScalarType.Bool:
builder.Append(value.ToBoolean().ToString(cultureInfo));
break;
case ScalarType.BFloat16:
case ScalarType.Float16:
builder.Append(value.ToSingle().ToString(fltFormat, cultureInfo));
break;
Expand Down Expand Up @@ -7462,6 +7473,7 @@ public enum ScalarType : sbyte
{ typeof(short), ScalarType.Int16 },
{ typeof(int), ScalarType.Int32 },
{ typeof(long), ScalarType.Int64 },
{ typeof(BFloat16), ScalarType.BFloat16 },
#if NET6_0_OR_GREATER
{ typeof(Half), ScalarType.Float16 },
#endif
Expand Down
Loading
Loading