Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Native/LibTorchSharp/THSTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,14 @@ Tensor THSTensor_gather(
CATCH_TENSOR(torch::gather(*tensor, dim, *index));
}

// Applies the Gaussian Error Linear Unit to `tensor`.
// `approximate` selects the computation: "none" for the exact erf-based
// form, "tanh" for the tanh approximation (forwarded to torch::gelu).
Tensor THSTensor_gelu(const Tensor tensor, const char* approximate)
{
    CATCH_TENSOR(torch::gelu(*tensor, approximate));
}

// In-place variant of THSTensor_gelu: mutates `tensor` via torch::gelu_.
// `approximate` is "none" (exact) or "tanh" (tanh approximation).
Tensor THSTensor_gelu_(const Tensor tensor, const char* approximate)
{
    CATCH_TENSOR(torch::gelu_(*tensor, approximate));
}

Tensor THSTensor_get1(const Tensor tensor, int64_t index)
Expand Down
4 changes: 2 additions & 2 deletions src/Native/LibTorchSharp/THSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ EXPORT_API(Tensor) THSTensor_ge_scalar(const Tensor left, const Scalar right);

EXPORT_API(void) THSTensor_ge_scalar_(const Tensor left, const Scalar right);

// GELU activation; `approximate` is "none" (exact erf form) or "tanh".
EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor, const char* approximate);
EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor, const char* approximate);

EXPORT_API(Tensor) THSTensor_glu(const Tensor tensor, const int64_t dim);

Expand Down
28 changes: 26 additions & 2 deletions src/TorchSharp/NN/Activation/GELU.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@ namespace Modules
/// </summary>
public sealed class GELU : ParameterLessModule<Tensor, Tensor>
{
    /// <summary>
    /// Creates a GELU activation module.
    /// </summary>
    /// <param name="inplace">Do the operation in-place.</param>
    /// <param name="approximate">The approximation method to use. Default: none (exact erf-based form).</param>
    internal GELU(bool inplace, GELUApproximate approximate = GELUApproximate.none) : base(nameof(GELU))
    {
        this.inplace = inplace;
        this.approximate = approximate;
    }

    /// <summary>
    /// Applies GELU to the input, honoring the configured approximation and in-place flag.
    /// </summary>
    public override Tensor forward(Tensor tensor)
    {
        return torch.nn.functional.gelu(tensor, approximate, inplace);
    }

    // Whether the activation is applied in-place on the input tensor.
    public bool inplace { get; set; }

    // Which GELU formulation the native layer should use.
    public GELUApproximate approximate { get; set; }
}
}

Expand All @@ -49,6 +52,16 @@ public static GELU GELU(bool inplace)
return new GELU(inplace);
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
/// <param name="approximate">The approximation method to use. Default: none</param>
/// <param name="inplace">Do the operation in-place. Default: False</param>
public static GELU GELU(GELUApproximate approximate, bool inplace = false)
    => new GELU(inplace: inplace, approximate: approximate);

public static partial class functional
{
/// <summary>
Expand All @@ -61,6 +74,17 @@ public static Tensor gelu(Tensor x, bool inplace)
return inplace ? x.gelu_().alias() : x.gelu();
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
/// <param name="x">The input tensor</param>
/// <param name="approximate">The approximation method to use.</param>
/// <param name="inplace">Do the operation in-place. Default: False</param>
public static Tensor gelu(Tensor x, GELUApproximate approximate, bool inplace = false)
{
    if (inplace) {
        // The in-place op returns the mutated input; alias() keeps the
        // returned handle distinct, matching the out-of-place overload.
        return x.gelu_(approximate).alias();
    }
    return x.gelu(approximate);
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
Expand Down
8 changes: 4 additions & 4 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,11 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern void THSTensor_elu_(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale);

// Native GELU entry points. `approximate` is marshaled as an ANSI C string
// ("none" or "tanh"); ThrowOnUnmappableChar guards against silent mangling.
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);

[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_glu(IntPtr tensor, long dim);
Expand Down
18 changes: 18 additions & 0 deletions src/TorchSharp/Tensor/Enums/GELUApproximate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
namespace TorchSharp
{
    /// <summary>
    /// Specifies the approximation method for the GELU activation function.
    /// </summary>
    public enum GELUApproximate
    {
        /// <summary>
        /// Exact GELU computation.
        /// </summary>
        none,

        /// <summary>
        /// Tanh-based approximation.
        /// </summary>
        tanh
    }
}
20 changes: 18 additions & 2 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2977,15 +2977,31 @@ public Tensor elu_(Scalar alpha, Scalar scale, Scalar input_scale)

/// <summary>
/// Computes the Gaussian Error Linear Unit of this tensor using the exact
/// (erf-based) formulation.
/// </summary>
/// <returns>A new tensor containing the element-wise GELU.</returns>
public Tensor gelu()
{
    // "none" selects the exact computation in the native layer.
    var res = NativeMethods.THSTensor_gelu(Handle, "none");
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}

/// <summary>
/// Computes the Gaussian Error Linear Unit of this tensor.
/// </summary>
/// <param name="approximate">The approximation method to use.</param>
/// <returns>A new tensor containing the element-wise GELU.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="approximate"/> is not a defined value.</exception>
public Tensor gelu(GELUApproximate approximate)
{
    // Map the enum explicitly and fail fast on out-of-range casts instead of
    // silently falling back to "none".
    string approx;
    switch (approximate) {
    case GELUApproximate.none: approx = "none"; break;
    case GELUApproximate.tanh: approx = "tanh"; break;
    default: throw new ArgumentOutOfRangeException(nameof(approximate), approximate, "Unsupported GELU approximation method.");
    }
    var res = NativeMethods.THSTensor_gelu(Handle, approx);
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}

/// <summary>
/// Computes the Gaussian Error Linear Unit of this tensor in place, using
/// the exact (erf-based) formulation.
/// </summary>
/// <returns>The mutated tensor (wrapped in a new handle).</returns>
public Tensor gelu_()
{
    // "none" selects the exact computation in the native layer.
    var res = NativeMethods.THSTensor_gelu_(Handle, "none");
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}

/// <summary>
/// Computes the Gaussian Error Linear Unit of this tensor in place.
/// </summary>
/// <param name="approximate">The approximation method to use.</param>
/// <returns>The mutated tensor (wrapped in a new handle).</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="approximate"/> is not a defined value.</exception>
public Tensor gelu_(GELUApproximate approximate)
{
    // Map the enum explicitly and fail fast on out-of-range casts instead of
    // silently falling back to "none".
    string approx;
    switch (approximate) {
    case GELUApproximate.none: approx = "none"; break;
    case GELUApproximate.tanh: approx = "tanh"; break;
    default: throw new ArgumentOutOfRangeException(nameof(approximate), approximate, "Unsupported GELU approximation method.");
    }
    var res = NativeMethods.THSTensor_gelu_(Handle, approx);
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}
Expand Down
22 changes: 22 additions & 0 deletions test/TorchSharpTest/NN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,28 @@ public void EvaluateGELU()
}
}

[Fact]
public void EvaluateGELUWithTanhApproximate()
{
    var rel = GELU(GELUApproximate.tanh);

    foreach (var device in TestUtils.AvailableDevices()) {
        var input = torch.randn(new long[] { 64, 8 }, device: device) * 25.0;
        var output = rel.call(input);
        Assert.Equal(device.type, output.device_type);

        var values = output.data<float>().ToArray();
        Assert.Equal(input.shape, output.shape);
        // GELU is bounded below by roughly -0.17; -0.2 gives some slack.
        Assert.All(values, val => Assert.True(val >= -0.2));
    }

    // Verify that the tanh approximation produces different results from the exact form.
    var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
    var exact = torch.nn.functional.gelu(x);
    var approx = torch.nn.functional.gelu(x, GELUApproximate.tanh);
    Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));

    // Verify that the in-place tanh path is wired through the native layer
    // and matches the out-of-place approximate result.
    var xInPlace = x.clone();
    xInPlace.gelu_(GELUApproximate.tanh);
    Assert.True(approx.allclose(xInPlace, rtol: 1e-5, atol: 1e-5));
}

[Fact]
public void EvaluatePReLU()
{
Expand Down
Loading