Skip to content

Commit f3b559a

Browse files
Address PR #1529 review comments
- Add eps clamping documentation to torch.logit_ XML docs - Add torch.i0_() static wrapper method with docs - Add test for torch.i0_() static method Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 0e12f38 commit f3b559a

2 files changed

Lines changed: 21 additions & 0 deletions

File tree

src/TorchSharp/Tensor/torch.PointwiseOps.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,10 @@ public static ICollection<Tensor> gradient(Tensor input, int spacing = 1, long[]
11091109
// https://pytorch.org/docs/stable/generated/torch.logit
11101110
/// <summary>
11111111
/// Returns the logit of the elements of input, in-place.
1112+
/// input is clamped to [eps, 1 - eps] when eps is not null.
11121113
/// </summary>
1114+
/// <param name="input">The input tensor.</param>
1115+
/// <param name="eps">The epsilon for input clamp bound.</param>
11131116
public static Tensor logit_(Tensor input, double? eps = null) => input.logit_(eps);
11141117

11151118
// https://pytorch.org/docs/stable/generated/torch.hypot
@@ -1128,6 +1131,13 @@ public static ICollection<Tensor> gradient(Tensor input, int spacing = 1, long[]
11281131
/// <returns></returns>
11291132
[Pure]public static Tensor i0(Tensor input) => special.i0(input);
11301133

1134+
// https://pytorch.org/docs/stable/generated/torch.i0
1135+
/// <summary>
1136+
/// Computes the zeroth order modified Bessel function of the first kind for each element of input, in-place.
1137+
/// </summary>
1138+
/// <param name="input">The input tensor.</param>
1139+
public static Tensor i0_(Tensor input) => input.i0_();
1140+
11311141
// https://pytorch.org/docs/stable/generated/torch.igamma
11321142
/// <summary>
11331143
/// Alias for <see cref="torch.special.gammainc"/>.

test/TorchSharpTest/PointwiseTensorMath.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,17 @@ public void I0InPlaceTest()
929929
Assert.True(res.allclose(torch.tensor(expected)));
930930
}
931931

932+
[Fact]
933+
[TestOf(nameof(torch.i0_))]
934+
public void I0InPlaceTorchTest()
935+
{
936+
var data = torch.arange(0, 5, 1, float32);
937+
var expected = new float[] { 0.99999994f, 1.266066f, 2.27958512f, 4.88079262f, 11.3019209f };
938+
var res = torch.i0_(data);
939+
Assert.Same(data, res);
940+
Assert.True(res.allclose(torch.tensor(expected)));
941+
}
942+
932943
[Fact]
933944
[TestOf(nameof(Tensor.hypot))]
934945
public void HypotTest()

0 commit comments

Comments
 (0)