From 1ea4b6a8c6d32928741dda04ac33d572917f0b67 Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Fri, 27 Feb 2026 13:30:27 +0200 Subject: [PATCH 1/2] Add approximate parameter to GELU activation function Add support for the 'approximate' parameter in GELU, matching PyTorch's torch.nn.GELU(approximate='tanh') functionality. Changes: - Add GELU.Approximate enum with 'none' and 'tanh' values - Thread approximate parameter through all layers: native C++, PInvoke, Tensor methods, functional API, and module factory - Add new overloads (no breaking changes to existing API) - Add test for tanh approximation mode Fixes dotnet/TorchSharp#1368 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Native/LibTorchSharp/THSTensor.cpp | 8 ++-- src/Native/LibTorchSharp/THSTensor.h | 4 +- src/TorchSharp/NN/Activation/GELU.cs | 43 ++++++++++++++++++- .../PInvoke/LibTorchSharp.THSTensor.cs | 4 +- src/TorchSharp/Tensor/Tensor.cs | 20 ++++++++- test/TorchSharpTest/NN.cs | 22 ++++++++++ 6 files changed, 89 insertions(+), 12 deletions(-) diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp index 7b4a0e55e..33599be06 100644 --- a/src/Native/LibTorchSharp/THSTensor.cpp +++ b/src/Native/LibTorchSharp/THSTensor.cpp @@ -576,14 +576,14 @@ Tensor THSTensor_gather( CATCH_TENSOR(torch::gather(*tensor, dim, *index)); } -Tensor THSTensor_gelu(const Tensor tensor) +Tensor THSTensor_gelu(const Tensor tensor, const char* approximate) { - CATCH_TENSOR(torch::gelu(*tensor)); + CATCH_TENSOR(torch::gelu(*tensor, approximate)); } -Tensor THSTensor_gelu_(const Tensor tensor) +Tensor THSTensor_gelu_(const Tensor tensor, const char* approximate) { - CATCH_TENSOR(torch::gelu_(*tensor)); + CATCH_TENSOR(torch::gelu_(*tensor, approximate)); } Tensor THSTensor_get1(const Tensor tensor, int64_t index) diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h index 73bff0403..40edc85ba 100644 --- 
a/src/Native/LibTorchSharp/THSTensor.h +++ b/src/Native/LibTorchSharp/THSTensor.h @@ -600,8 +600,8 @@ EXPORT_API(Tensor) THSTensor_ge_scalar(const Tensor left, const Scalar right); EXPORT_API(void) THSTensor_ge_scalar_(const Tensor left, const Scalar right); -EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor); -EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor, const char* approximate); +EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor, const char* approximate); EXPORT_API(Tensor) THSTensor_glu(const Tensor tensor, const int64_t dim); diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs index 90c314b99..0b57c17e0 100644 --- a/src/TorchSharp/NN/Activation/GELU.cs +++ b/src/TorchSharp/NN/Activation/GELU.cs @@ -14,17 +14,35 @@ namespace Modules /// public sealed class GELU : ParameterLessModule { - internal GELU(bool inplace) : base(nameof(GELU)) + /// + /// Specifies the approximation method for GELU. + /// + public enum Approximate + { + /// + /// Exact GELU computation. + /// + none, + /// + /// Tanh-based approximation. + /// + tanh + } + + internal GELU(bool inplace, Approximate approximate = Approximate.none) : base(nameof(GELU)) { this.inplace = inplace; + this.approximate = approximate; } public override Tensor forward(Tensor tensor) { - return torch.nn.functional.gelu(tensor, inplace); + return torch.nn.functional.gelu(tensor, approximate, inplace); } public bool inplace {get; set; } + + public Approximate approximate { get; set; } } } @@ -49,6 +67,16 @@ public static GELU GELU(bool inplace) return new GELU(inplace); } + /// + /// Gaussian Error Linear Units + /// + /// The approximation method to use. Default: none + /// Do the operation in-place. 
Default: False + public static GELU GELU(GELU.Approximate approximate, bool inplace = false) + { + return new GELU(inplace, approximate); + } + public static partial class functional { /// @@ -61,6 +89,17 @@ public static Tensor gelu(Tensor x, bool inplace) return inplace ? x.gelu_().alias() : x.gelu(); } + /// + /// Gaussian Error Linear Units + /// + /// The input tensor + /// The approximation method to use. + /// Do the operation in-place. Default: False + public static Tensor gelu(Tensor x, GELU.Approximate approximate, bool inplace = false) + { + return inplace ? x.gelu_(approximate).alias() : x.gelu(approximate); + } + /// /// Gaussian Error Linear Units /// diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs index e8db2c2cb..108e1b740 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs @@ -707,10 +707,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern void THSTensor_elu_(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_gelu(IntPtr tensor); + internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_gelu_(IntPtr tensor); + internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_glu(IntPtr tensor, long dim); diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index ea70b83e1..7093a5a90 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -2977,7 +2977,15 @@ public Tensor elu_(Scalar alpha, Scalar scale, Scalar input_scale) public Tensor gelu() { - var res = 
NativeMethods.THSTensor_gelu(Handle); + var res = NativeMethods.THSTensor_gelu(Handle, "none"); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor gelu(TorchSharp.Modules.GELU.Approximate approximate) + { + var res = NativeMethods.THSTensor_gelu(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none"); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); @@ -2985,7 +2993,15 @@ public Tensor gelu() public Tensor gelu_() { - var res = NativeMethods.THSTensor_gelu_(Handle); + var res = NativeMethods.THSTensor_gelu_(Handle, "none"); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor gelu_(TorchSharp.Modules.GELU.Approximate approximate) + { + var res = NativeMethods.THSTensor_gelu_(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none"); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index f2ed50db3..b9b5a93bb 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -618,6 +618,28 @@ public void EvaluateGELU() } } + [Fact] + public void EvaluateGELUWithTanhApproximate() + { + var rel = GELU(Modules.GELU.Approximate.tanh); + + foreach (var device in TestUtils.AvailableDevices()) { + var input = torch.randn(new long[] { 64, 8 }, device: device) * 25.0; + var output = rel.call(input); + Assert.Equal(device.type, output.device_type); + + var values = output.data().ToArray(); + Assert.Equal(input.shape, output.shape); + Assert.All(values, val => Assert.True(val >= -0.2)); + } + + // Verify that tanh approximate produces different results from exact + var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f }); + var exact = torch.nn.functional.gelu(x); + var approx = torch.nn.functional.gelu(x, Modules.GELU.Approximate.tanh); + Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5)); + } + [Fact] public void 
EvaluatePReLU() { From a098ea4381e58f42b0f528d5622678ebfac3a56e Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Wed, 11 Mar 2026 15:54:36 +0100 Subject: [PATCH 2/2] Address PR review comments for GELU approximate parameter - Move the Approximate enum from the GELU module class to the neutral TorchSharp namespace as GELUApproximate, removing the Tensor/functional layer's dependency on the Modules layer - Add CharSet, BestFitMapping, ThrowOnUnmappableChar attributes to the THSTensor_gelu/gelu_ DllImport declarations to match the pattern used by existing LPStr-based imports - Update all references in Tensor.cs, GELU.cs, and tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/TorchSharp/NN/Activation/GELU.cs | 23 ++++--------------- .../PInvoke/LibTorchSharp.THSTensor.cs | 4 ++-- .../Tensor/Enums/GELUApproximate.cs | 18 +++++++++++++++ src/TorchSharp/Tensor/Tensor.cs | 8 +++---- test/TorchSharpTest/NN.cs | 4 ++-- 5 files changed, 30 insertions(+), 27 deletions(-) create mode 100644 src/TorchSharp/Tensor/Enums/GELUApproximate.cs diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs index 0b57c17e0..9deb7f5d8 100644 --- a/src/TorchSharp/NN/Activation/GELU.cs +++ b/src/TorchSharp/NN/Activation/GELU.cs @@ -14,22 +14,7 @@ namespace Modules /// public sealed class GELU : ParameterLessModule { - /// - /// Specifies the approximation method for GELU. - /// - public enum Approximate - { - /// - /// Exact GELU computation. - /// - none, - /// - /// Tanh-based approximation. 
- /// - tanh - } - - internal GELU(bool inplace, Approximate approximate = Approximate.none) : base(nameof(GELU)) + internal GELU(bool inplace, GELUApproximate approximate = GELUApproximate.none) : base(nameof(GELU)) { this.inplace = inplace; this.approximate = approximate; @@ -42,7 +27,7 @@ public override Tensor forward(Tensor tensor) public bool inplace {get; set; } - public Approximate approximate { get; set; } + public GELUApproximate approximate { get; set; } } } @@ -72,7 +57,7 @@ public static GELU GELU(bool inplace) /// /// The approximation method to use. Default: none /// Do the operation in-place. Default: False - public static GELU GELU(GELU.Approximate approximate, bool inplace = false) + public static GELU GELU(GELUApproximate approximate, bool inplace = false) { return new GELU(inplace, approximate); } @@ -95,7 +80,7 @@ public static Tensor gelu(Tensor x, bool inplace) /// The input tensor /// The approximation method to use. /// Do the operation in-place. Default: False - public static Tensor gelu(Tensor x, GELU.Approximate approximate, bool inplace = false) + public static Tensor gelu(Tensor x, GELUApproximate approximate, bool inplace = false) { return inplace ? 
x.gelu_(approximate).alias() : x.gelu(approximate); } diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs index 108e1b740..141ba55a2 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs @@ -706,10 +706,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern void THSTensor_elu_(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale); - [DllImport("LibTorchSharp")] + [DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)] internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate); - [DllImport("LibTorchSharp")] + [DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)] internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate); [DllImport("LibTorchSharp")] diff --git a/src/TorchSharp/Tensor/Enums/GELUApproximate.cs b/src/TorchSharp/Tensor/Enums/GELUApproximate.cs new file mode 100644 index 000000000..2a247a3e7 --- /dev/null +++ b/src/TorchSharp/Tensor/Enums/GELUApproximate.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +namespace TorchSharp +{ + /// + /// Specifies the approximation method for the GELU activation function. + /// + public enum GELUApproximate + { + /// + /// Exact GELU computation. + /// + none, + /// + /// Tanh-based approximation. 
+ /// + tanh + } +} diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 7093a5a90..2eb1c579c 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -2983,9 +2983,9 @@ public Tensor gelu() return new Tensor(res); } - public Tensor gelu(TorchSharp.Modules.GELU.Approximate approximate) + public Tensor gelu(GELUApproximate approximate) { - var res = NativeMethods.THSTensor_gelu(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none"); + var res = NativeMethods.THSTensor_gelu(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none"); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); @@ -2999,9 +2999,9 @@ public Tensor gelu_() return new Tensor(res); } - public Tensor gelu_(TorchSharp.Modules.GELU.Approximate approximate) + public Tensor gelu_(GELUApproximate approximate) { - var res = NativeMethods.THSTensor_gelu_(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none"); + var res = NativeMethods.THSTensor_gelu_(Handle, approximate == GELUApproximate.tanh ? 
"tanh" : "none"); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index b9b5a93bb..f3dc089eb 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -621,7 +621,7 @@ public void EvaluateGELU() [Fact] public void EvaluateGELUWithTanhApproximate() { - var rel = GELU(Modules.GELU.Approximate.tanh); + var rel = GELU(GELUApproximate.tanh); foreach (var device in TestUtils.AvailableDevices()) { var input = torch.randn(new long[] { 64, 8 }, device: device) * 25.0; @@ -636,7 +636,7 @@ public void EvaluateGELUWithTanhApproximate() // Verify that tanh approximate produces different results from exact var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f }); var exact = torch.nn.functional.gelu(x); - var approx = torch.nn.functional.gelu(x, Modules.GELU.Approximate.tanh); + var approx = torch.nn.functional.gelu(x, GELUApproximate.tanh); Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5)); }