diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index bd75ecbb2..f019b4ccb 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -6938,32 +6938,6 @@ public void Float64FFT() Assert.Equal(ScalarType.ComplexFloat64, inverted.dtype); } - [Fact] - [TestOf(nameof(fft.hfft))] - public void Float32HFFT() - { - var input = torch.arange(4); - var output = fft.hfft(input); - Assert.Equal(6, output.shape[0]); - Assert.Equal(ScalarType.Float32, output.dtype); - - var inverted = fft.ifft(output); - Assert.Equal(ScalarType.ComplexFloat32, inverted.dtype); - } - - [Fact] - [TestOf(nameof(fft.hfft))] - public void Float64HFFT() - { - var input = torch.arange(4, float64); - var output = fft.hfft(input); - Assert.Equal(6, output.shape[0]); - Assert.Equal(ScalarType.Float64, output.dtype); - - var inverted = fft.ifft(output); - Assert.Equal(ScalarType.ComplexFloat64, inverted.dtype); - } - [Fact] [TestOf(nameof(fft.rfft))] public void Float32RFFT() @@ -7201,12 +7175,12 @@ public void Float64RFFTN() [Fact] [TestOf(nameof(fft.hfft2))] - public void Float32HFFT2() + public void ComplexFloat32HFFT2() { - var input = torch.rand(new long[] { 5, 5, 5, 5 }); + var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex64); var output = fft.hfft2(input); Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape); - Assert.Equal(input.dtype, output.dtype); + Assert.Equal(ScalarType.Float32, output.dtype); var inverted = fft.ihfft2(output); Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape); @@ -7215,12 +7189,12 @@ public void Float32HFFT2() [Fact] [TestOf(nameof(fft.hfft2))] - public void Float64HFFT2() + public void ComplexFloat64HFFT2() { - var input = torch.rand(new long[] { 5, 5, 5, 5 }, float64); + var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex128); var output = fft.hfft2(input); Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape); - Assert.Equal(input.dtype, output.dtype); + Assert.Equal(ScalarType.Float64, output.dtype); var inverted = fft.ihfft2(output); Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape); @@ -7228,31 +7202,31 @@ public void Float64HFFT2() } [Fact] - [TestOf(nameof(fft.hfft2))] - public void Float32HFFTN() + [TestOf(nameof(fft.hfftn))] + public void ComplexFloat32HFFTN() { - var input = torch.rand(new long[] { 5, 5, 5, 5 }); - var output = fft.hfft2(input); + var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex64); + var output = fft.hfftn(input); Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape); - Assert.Equal(input.dtype, output.dtype); + Assert.Equal(ScalarType.Float32, output.dtype); - var inverted = fft.ihfft2(output); + var inverted = fft.ihfftn(output); Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape); Assert.Equal(ScalarType.ComplexFloat32, inverted.dtype); } [Fact(Skip = "Fails on all Release builds.")] [TestOf(nameof(fft.hfftn))] - public void Float64HFFTN() + public void ComplexFloat64HFFTN() { if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && !RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { // TODO: Something in this test makes if fail on Windows / Release and MacOS / Release - var input = torch.rand(new long[] { 5, 5, 5, 5 }, float64); + var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex128); var output = fft.hfftn(input); Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape); - Assert.Equal(input.dtype, output.dtype); + Assert.Equal(ScalarType.Float64, output.dtype); var inverted = fft.ihfftn(output); Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape);