-
Notifications
You must be signed in to change notification settings - Fork 218
Add support for bfloat16 type #1544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2c11a4c
fb459e7
388a93b
6beec00
bd8ded6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,127 @@ | ||||||||||||
| // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. | ||||||||||||
| using System; | ||||||||||||
| using System.Globalization; | ||||||||||||
| using System.Runtime.InteropServices; | ||||||||||||
|
|
||||||||||||
| #nullable enable | ||||||||||||
| namespace TorchSharp | ||||||||||||
| { | ||||||||||||
| /// <summary> | ||||||||||||
| /// Represents a 16-bit brain floating-point number (BFloat16). | ||||||||||||
| /// Binary layout: 1 sign bit, 8 exponent bits, 7 mantissa bits — the upper 16 bits of IEEE 754 float32. | ||||||||||||
| /// Binary-compatible with c10::BFloat16 in LibTorch. | ||||||||||||
| /// </summary> | ||||||||||||
| [StructLayout(LayoutKind.Sequential)] | ||||||||||||
| public readonly struct BFloat16 : IComparable<BFloat16>, IEquatable<BFloat16>, IComparable, IFormattable | ||||||||||||
| { | ||||||||||||
// Raw 16-bit pattern backing this number (1 sign, 8 exponent, 7 mantissa bits).
internal readonly ushort value;

// Stores the supplied bit pattern verbatim, with no numeric conversion.
// The second parameter is never read; presumably it exists only to keep this
// overload distinct from BFloat16(float), since a ushort converts implicitly to float.
internal BFloat16(ushort rawValue, bool _) => value = rawValue;
|
|
||||||||||||
/// <summary>
/// Initializes a BFloat16 by converting <paramref name="f"/> with round-to-nearest-even,
/// the same rounding used by PyTorch's c10::BFloat16.
/// </summary>
/// <param name="f">The single-precision value to convert.</param>
public BFloat16(float f) => value = FloatToBFloat16Bits(f);
|
|
||||||||||||
/// <summary>
/// Wraps an existing 16-bit pattern as a BFloat16 without converting or rounding it.
/// </summary>
/// <param name="rawValue">The bit pattern to reinterpret as a BFloat16.</param>
/// <returns>A BFloat16 whose raw bits equal <paramref name="rawValue"/>.</returns>
public static BFloat16 FromRawValue(ushort rawValue)
{
    return new BFloat16(rawValue, false);
}
|
|
||||||||||||
| // --- Conversion to/from float --- | ||||||||||||
|
|
||||||||||||
// Converts a float32 to the 16-bit BFloat16 pattern (its upper half), rounding
// to nearest with ties to even. NaN inputs always produce a NaN pattern.
private static unsafe ushort FloatToBFloat16Bits(float f)
{
    const uint ExponentMask = 0x7F800000u;
    const uint MantissaMask = 0x007FFFFFu;

    uint raw = *(uint*)&f;
    uint upperHalf = raw >> 16;

    bool isNaN = (raw & ExponentMask) == ExponentMask && (raw & MantissaMask) != 0;
    if (isNaN)
    {
        // Keep the upper payload bits, but force the top mantissa bit on so the
        // result is still a NaN even if the payload lived only in the dropped low bits.
        return (ushort)(upperHalf | 0x0040u);
    }

    // Adding 0x7FFF plus the current least-significant kept bit rounds halfway
    // cases toward the even neighbour.
    uint rounded = raw + (0x7FFFu + (upperHalf & 1u));
    return (ushort)(rounded >> 16);
}
|
|
||||||||||||
// Expands a BFloat16 bit pattern to float32 by placing it in the upper 16 bits
// and leaving the lower 16 bits zero; the conversion is exact.
private static unsafe float BFloat16BitsToFloat(ushort raw)
{
    uint widened = (uint)raw << 16;
    return *(float*)&widened;
}
|
|
||||||||||||
/// <summary>
/// Returns the single-precision value that this BFloat16 represents.
/// </summary>
/// <returns>The float obtained by widening the stored 16-bit pattern.</returns>
public float ToSingle()
{
    return BFloat16BitsToFloat(value);
}
|
|
||||||||||||
| // --- Conversion operators --- | ||||||||||||
|
|
||||||||||||
// Widening to float/double goes through the stored bit pattern.
public static explicit operator float(BFloat16 bf) => BFloat16BitsToFloat(bf.value);
public static explicit operator double(BFloat16 bf) => (double)BFloat16BitsToFloat(bf.value);

// Narrowing rounds to nearest-even; a double is first cast to float.
public static explicit operator BFloat16(float f) => FromRawValue(FloatToBFloat16Bits(f));
public static explicit operator BFloat16(double d) => FromRawValue(FloatToBFloat16Bits((float)d));
|
|
||||||||||||
// --- Arithmetic operators (computed in float, then rounded back to BFloat16) ---

public static BFloat16 operator +(BFloat16 a, BFloat16 b) => (BFloat16)((float)a + (float)b);
public static BFloat16 operator -(BFloat16 a, BFloat16 b) => (BFloat16)((float)a - (float)b);
public static BFloat16 operator *(BFloat16 a, BFloat16 b) => (BFloat16)((float)a * (float)b);
public static BFloat16 operator /(BFloat16 a, BFloat16 b) => (BFloat16)((float)a / (float)b);
public static BFloat16 operator %(BFloat16 a, BFloat16 b) => (BFloat16)((float)a % (float)b);
public static BFloat16 operator -(BFloat16 a) => (BFloat16)(-(float)a);
|
|
||||||||||||
| // --- Comparison operators --- | ||||||||||||
|
|
||||||||||||
// All comparisons are performed on the widened float values, so they follow
// float comparison rules (for example, a NaN operand makes == false and != true).
public static bool operator ==(BFloat16 a, BFloat16 b) => (float)a == (float)b;
public static bool operator !=(BFloat16 a, BFloat16 b) => (float)a != (float)b;
public static bool operator <(BFloat16 a, BFloat16 b) => (float)a < (float)b;
public static bool operator >(BFloat16 a, BFloat16 b) => (float)a > (float)b;
public static bool operator <=(BFloat16 a, BFloat16 b) => (float)a <= (float)b;
public static bool operator >=(BFloat16 a, BFloat16 b) => (float)a >= (float)b;
|
|
||||||||||||
| // --- IEquatable / IComparable --- | ||||||||||||
|
|
||||||||||||
/// <summary>
/// Determines whether two BFloat16 values are numerically equal.
/// </summary>
/// <remarks>
/// Equality is decided on the numeric value rather than the raw bits, so it agrees
/// with <c>operator ==</c> and <c>CompareTo</c> for +0 versus -0. Unlike <c>operator ==</c>,
/// any NaN is considered equal to any other NaN (the convention used by
/// <see cref="float.Equals(float)"/>), which keeps NaN usable as a dictionary or set key.
/// </remarks>
/// <param name="other">The value to compare with this instance.</param>
/// <returns><c>true</c> when both values are numerically equal or both are NaN.</returns>
public bool Equals(BFloat16 other)
{
    float self = ToSingle();
    float that = other.ToSingle();
    return self == that || (float.IsNaN(self) && float.IsNaN(that));
}

/// <summary>
/// Determines whether <paramref name="obj"/> is a BFloat16 equal to this instance.
/// </summary>
/// <param name="obj">The object to compare with; may be <c>null</c>.</param>
/// <returns><c>true</c> only when <paramref name="obj"/> is a BFloat16 that <see cref="Equals(BFloat16)"/> accepts.</returns>
public override bool Equals(object? obj) => obj is BFloat16 other && Equals(other);

/// <summary>
/// Returns a hash code consistent with <see cref="Equals(BFloat16)"/>.
/// </summary>
/// <returns>
/// 0 for both zeros, a single shared code for every NaN, and the raw bit pattern otherwise.
/// </returns>
public override int GetHashCode()
{
    int magnitude = value & 0x7FFF;

    // +0 and -0 compare equal, so they must hash identically.
    if (magnitude == 0)
    {
        return 0;
    }

    // Exponent all ones with a non-zero mantissa is a NaN; all NaNs compare equal.
    if (magnitude > 0x7F80)
    {
        return 0x7FC0;
    }

    return value;
}
|
|
||||||||||||
/// <summary>
/// Compares this value with <paramref name="other"/> by comparing their float equivalents.
/// </summary>
/// <param name="other">The value to compare against.</param>
/// <returns>The result of <see cref="float.CompareTo(float)"/> on the two widened values.</returns>
public int CompareTo(BFloat16 other)
{
    float mine = ToSingle();
    float theirs = other.ToSingle();
    return mine.CompareTo(theirs);
}

/// <summary>
/// Compares this value with an arbitrary object.
/// </summary>
/// <param name="obj">A boxed BFloat16, or <c>null</c>.</param>
/// <returns>1 when <paramref name="obj"/> is <c>null</c>; otherwise the typed comparison result.</returns>
/// <exception cref="ArgumentException"><paramref name="obj"/> is neither <c>null</c> nor a BFloat16.</exception>
public int CompareTo(object? obj) => obj switch
{
    null => 1,
    BFloat16 other => CompareTo(other),
    _ => throw new ArgumentException("Object must be of type BFloat16."),
};
|
|
||||||||||||
| // --- Formatting --- | ||||||||||||
|
|
||||||||||||
/// <summary>
/// Formats the value exactly as its float equivalent would be formatted by default.
/// </summary>
public override string ToString()
{
    float asSingle = ToSingle();
    return asSingle.ToString();
}

/// <summary>
/// Formats the value as its float equivalent using the given format and provider.
/// </summary>
/// <param name="format">A numeric format string, or <c>null</c>.</param>
/// <param name="formatProvider">The culture-specific formatting provider, or <c>null</c>.</param>
public string ToString(string? format, IFormatProvider? formatProvider)
{
    float asSingle = ToSingle();
    return asSingle.ToString(format, formatProvider);
}
|
|
||||||||||||
| // --- Constants --- | ||||||||||||
|
|
||||||||||||
// Positive zero.
public static readonly BFloat16 Zero = new BFloat16(0x0000, false);
// Exactly 1.0.
public static readonly BFloat16 One = new BFloat16(0x3F80, false);
// A quiet NaN with the sign bit clear.
public static readonly BFloat16 NaN = new BFloat16(0x7FC0, false);
public static readonly BFloat16 PositiveInfinity = new BFloat16(0x7F80, false);
public static readonly BFloat16 NegativeInfinity = new BFloat16(0xFF80, false);
// Largest finite value, ~3.39e+38.
public static readonly BFloat16 MaxValue = new BFloat16(0x7F7F, false);
// Most negative finite value, ~-3.39e+38.
public static readonly BFloat16 MinValue = new BFloat16(0xFF7F, false);
// Smallest positive NORMAL value (2^-126). Note that this is not the smallest
// positive value overall, which is how System.Single.Epsilon is defined.
public static readonly BFloat16 Epsilon = new BFloat16(0x0080, false);
// Smallest positive value overall (subnormal, 2^-133).
public static readonly BFloat16 SmallestSubnormal = new BFloat16(0x0001, false);
|
Comment on lines
+116
to
+117
|
||||||||||||
| public static readonly BFloat16 Epsilon = FromRawValue(0x0080); // smallest normal | |
| public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001); | |
| public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001); // smallest positive (subnormal) | |
| public static readonly BFloat16 Epsilon = SmallestSubnormal; // .NET-style epsilon: smallest positive > 0 | |
| public static readonly BFloat16 MinNormal = FromRawValue(0x0080); // smallest normal |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Diagnostics.Contracts; | ||
| using System.Linq; | ||
|
|
||
| #nullable enable | ||
| namespace TorchSharp | ||
| { | ||
public static partial class torch
{
    /// <summary>
    /// Builds a tensor with the requested shape from a list of BFloat16 values.
    /// </summary>
    /// <remarks>
    /// The list is copied into a new array with <c>ToArray()</c>, and that copy is what is
    /// handed to <c>_tensor_generic</c>.
    /// </remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        BFloat16[] copy = rawArray.ToArray();
        return _tensor_generic(copy, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, false, names);
    }

    /// <summary>
    /// Builds a one-dimensional tensor whose length equals the length of the array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { rawArray.LongLength };
        return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a tensor with the requested shape from an array of BFloat16 values.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[] rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Builds a one-dimensional tensor whose length equals the number of items in the list.
    /// </summary>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { rawArray.Count };
        return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a two-dimensional (rows x columns) tensor from a flat list of values.
    /// </summary>
    /// <remarks>The input list must have rows * columns elements.</remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, long rows, long columns, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { rows, columns };
        return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a three-dimensional tensor from a flat list of values.
    /// </summary>
    /// <remarks>The input list must have dim0*dim1*dim2 elements.</remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { dim0, dim1, dim2 };
        return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a four-dimensional tensor from a flat list of values.
    /// </summary>
    /// <remarks>The input list must have dim0*dim1*dim2*dim3 elements.</remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, long dim3, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { dim0, dim1, dim2, dim3 };
        return tensor(rawArray, shape, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a two-dimensional tensor whose shape matches the given rectangular array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1) };
        return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a three-dimensional tensor whose shape matches the given rank-3 array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2) };
        return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a four-dimensional tensor whose shape matches the given rank-4 array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[,,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
    {
        Span<long> shape = stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) };
        return _tensor_generic(rawArray, shape, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
    }

    /// <summary>
    /// Builds a tensor with the requested shape from a <see cref="Memory{T}"/> of BFloat16 values.
    /// </summary>
    [Pure]
    public static Tensor tensor(Memory<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
}
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BFloat16 equality is internally inconsistent: ==/!= compare via ToSingle(), but Equals/GetHashCode compare the raw 16-bit payload. This can produce cases where (a == b) is true but a.Equals(b) is false (e.g., +0 vs -0), which breaks dictionary/set behavior and general .NET equality expectations. Align operators and Equals/GetHashCode to use the same equality definition (either bitwise or float-like) and keep them consistent.