diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp
index 13cd5787e..f313c56a1 100644
--- a/src/Native/LibTorchSharp/THSTensor.cpp
+++ b/src/Native/LibTorchSharp/THSTensor.cpp
@@ -2278,3 +2278,48 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_
return nullptr;
}
+
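+// PyTorch's affine quantization maps a float x to an integer q via
+// q = clamp(round(x / scale) + zero_point, qmin, qmax), where qmin/qmax are the limits
+// of the target quantized type; dequantization reconstructs x as (q - zero_point) * scale.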
+Tensor THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type)
+{
+ CATCH_TENSOR(torch::quantize_per_tensor(*tensor, scale, zero_point, at::ScalarType(scalar_type)));
+}
+
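+// Per-channel variant: slice i along 'axis' is quantized with scales[i] and zero_points[i].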
+Tensor THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type)
+{
+ CATCH_TENSOR(torch::quantize_per_channel(*tensor, *scales, *zero_points, axis, at::ScalarType(scalar_type)));
+}
+
+Tensor THSTensor_dequantize(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->dequantize());
+}
+
+double THSTensor_q_scale(const Tensor tensor)
+{
+ CATCH_RETURN(double, 0.0, tensor->q_scale());
+}
+
+int64_t THSTensor_q_zero_point(const Tensor tensor)
+{
+ CATCH_RETURN(int64_t, 0, tensor->q_zero_point());
+}
+
+Tensor THSTensor_int_repr(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->int_repr());
+}
+
+Tensor THSTensor_q_per_channel_scales(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->q_per_channel_scales());
+}
+
+Tensor THSTensor_q_per_channel_zero_points(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->q_per_channel_zero_points());
+}
+
+int64_t THSTensor_q_per_channel_axis(const Tensor tensor)
+{
+ CATCH_RETURN(int64_t, 0, tensor->q_per_channel_axis());
+}
diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h
index 4ddfffb49..4437db84e 100644
--- a/src/Native/LibTorchSharp/THSTensor.h
+++ b/src/Native/LibTorchSharp/THSTensor.h
@@ -1790,3 +1790,15 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou
EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);
+
+// Quantization Ops
+
+EXPORT_API(Tensor) THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type);
+EXPORT_API(Tensor) THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type);
+EXPORT_API(Tensor) THSTensor_dequantize(const Tensor tensor);
+EXPORT_API(double) THSTensor_q_scale(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_q_zero_point(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_int_repr(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_q_per_channel_scales(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_q_per_channel_zero_points(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_q_per_channel_axis(const Tensor tensor);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
index bb5d2dbd9..a0597698a 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -2176,6 +2176,33 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern IntPtr THSTensor_histogram_out_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_histogram_out_i(IntPtr input, long bins, IntPtr range, int length, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_quantize_per_tensor(IntPtr tensor, double scale, long zero_point, sbyte scalar_type);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_quantize_per_channel(IntPtr tensor, IntPtr scales, IntPtr zero_points, long axis, sbyte scalar_type);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_dequantize(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern double THSTensor_q_scale(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern long THSTensor_q_zero_point(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_int_repr(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_q_per_channel_scales(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_q_per_channel_zero_points(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern long THSTensor_q_per_channel_axis(IntPtr tensor);
}
#pragma warning restore CA2101
}
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index 46a35db04..c6ccb88b7 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -271,6 +271,95 @@ internal IntPtr MoveHandle()
/// </summary>
public bool is_complex() => torch.is_complex(dtype);
+ /// <summary>
+ /// Returns True if the data type of input is a quantized data type, i.e., one of torch.qint8, torch.quint8, or torch.qint32.
+ /// </summary>
+ public bool is_quantized() => torch.is_quantized(dtype);
+
+ /// <summary>
+ /// Given a quantized Tensor, returns a dequantized (float) Tensor.
+ /// </summary>
+ public Tensor dequantize()
+ {
+ var res = NativeMethods.THSTensor_dequantize(Handle);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ /// <summary>
+ /// Given a quantized Tensor, returns the scale of the quantization as a double.
+ /// </summary>
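+ /// <remarks>Only valid for tensors quantized per tensor; per-channel tensors expose their parameters through q_per_channel_scales() and q_per_channel_zero_points().</remarks>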
+ public double q_scale()
+ {
+ var res = NativeMethods.THSTensor_q_scale(Handle);
+ CheckForErrors();
+ return res;
+ }
+
+ /// <summary>
+ /// Given a quantized Tensor, returns the zero_point of the quantization as a long.
+ /// </summary>
+ public long q_zero_point()
+ {
+ var res = NativeMethods.THSTensor_q_zero_point(Handle);
+ CheckForErrors();
+ return res;
+ }
+
+ /// <summary>
+ /// Given a quantized Tensor, returns a Tensor of the underlying integer representation.
+ /// </summary>
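+ /// <remarks>The result uses the corresponding integer dtype: Int8 for qint8, Byte for quint8, and Int32 for qint32.</remarks>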
+ public Tensor int_repr()
+ {
+ var res = NativeMethods.THSTensor_int_repr(Handle);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ /// <summary>
+ /// Given a quantized Tensor quantized per channel, returns a Tensor of the scales of the quantization for each channel.
+ /// </summary>
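+ /// <remarks>The scales come back as a Float64 tensor with one element per slice along the quantization axis.</remarks>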
+ public Tensor q_per_channel_scales()
+ {
+ var res = NativeMethods.THSTensor_q_per_channel_scales(Handle);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ /// <summary>
+ /// Given a quantized Tensor quantized per channel, returns a Tensor of the zero points of the quantization for each channel.
+ /// </summary>
+ public Tensor q_per_channel_zero_points()
+ {
+ var res = NativeMethods.THSTensor_q_per_channel_zero_points(Handle);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ /// <summary>
+ /// Given a quantized Tensor quantized per channel, returns the axis along which per-channel quantization is applied.
+ /// </summary>
+ public long q_per_channel_axis()
+ {
+ var res = NativeMethods.THSTensor_q_per_channel_axis(Handle);
+ CheckForErrors();
+ return res;
+ }
+
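+ // Internal wrappers over the native quantization entry points; the public
+ // torch.quantize_per_tensor / torch.quantize_per_channel overloads validate the dtype first.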
+ internal Tensor _quantize_per_tensor(double scale, long zero_point, ScalarType dtype)
+ {
+ var res = NativeMethods.THSTensor_quantize_per_tensor(Handle, scale, zero_point, (sbyte)dtype);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ internal Tensor _quantize_per_channel(Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
+ {
+ var res = NativeMethods.THSTensor_quantize_per_channel(Handle, scales.Handle, zero_points.Handle, axis, (sbyte)dtype);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
/// <summary>
/// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
/// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
@@ -7279,9 +7368,9 @@ public enum ScalarType : sbyte
ComplexFloat32 = 9,
ComplexFloat64 = 10,
Bool = 11,
- //QInt8 = 12,
- //QUInt8 = 13,
- //QUInt32 = 14,
+ QInt8 = 12,
+ QUInt8 = 13,
+ QInt32 = 14,
BFloat16 = 15
}
@@ -7413,6 +7502,18 @@ public static bool is_complex(ScalarType type)
}
}
+ public static bool is_quantized(ScalarType type)
+ {
+ switch (type) {
+ case ScalarType.QInt8:
+ case ScalarType.QUInt8:
+ case ScalarType.QInt32:
+ return true;
+ default:
+ return false;
+ }
+ }
+
public static long max_int_value(ScalarType type)
{
switch (type) {
@@ -7463,6 +7564,10 @@ public static long max_int_value(ScalarType type)
public static ScalarType cfloat = ScalarType.ComplexFloat32;
public static ScalarType cdouble = ScalarType.ComplexFloat64;
+ public static ScalarType qint8 = ScalarType.QInt8;
+ public static ScalarType quint8 = ScalarType.QUInt8;
+ public static ScalarType qint32 = ScalarType.QInt32;
+
/// <summary>
/// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
/// be automatically disposed once the dispose scope is disposed.
diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
index 2f4fa81dc..91b777660 100644
--- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs
+++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -368,6 +368,23 @@ internal static bool IsComplex(this ScalarType type)
}
}
+ /// <summary>
+ /// Indicates whether a given element type is quantized.
+ /// </summary>
+ /// <param name="type">The input type.</param>
+ /// <returns>True if the type is one of the quantized types; otherwise false.</returns>
+ internal static bool IsQuantized(this ScalarType type)
+ {
+ switch (type) {
+ case ScalarType.QInt8:
+ case ScalarType.QUInt8:
+ case ScalarType.QInt32:
+ return true;
+ default:
+ return false;
+ }
+ }
+
/// <summary>
/// Save the tensor in a .NET-specific format.
/// </summary>
diff --git a/src/TorchSharp/Tensor/torch.PointwiseOps.cs b/src/TorchSharp/Tensor/torch.PointwiseOps.cs
index 0fccbd8ce..cc7dffbd6 100644
--- a/src/TorchSharp/Tensor/torch.PointwiseOps.cs
+++ b/src/TorchSharp/Tensor/torch.PointwiseOps.cs
@@ -761,6 +761,47 @@ public static Tensor fake_quantize_per_channel_affine(Tensor input, Tensor scale
public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, Tensor zero_point, long quant_min, long quant_max)
=> throw new NotImplementedException();
+ // https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor
+ /// <summary>
+ /// Converts a float tensor to a quantized tensor with given scale and zero point.
+ /// </summary>
+ /// <param name="input">Float tensor to quantize</param>
+ /// <param name="scale">Scale to apply in quantization formula</param>
+ /// <param name="zero_point">Offset in integer value that maps to float zero</param>
+ /// <param name="dtype">The desired data type of returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
+ /// <returns>A newly quantized tensor</returns>
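+ /// <remarks>
+ /// Worked example: with scale = 0.1 and zero_point = 10, quantizing { -1.0f, 0.0f, 1.0f }
+ /// to torch.qint8 stores the integers 0, 10, and 20 (q = round(x / scale) + zero_point);
+ /// dequantizing maps them back to -1.0, 0.0, and 1.0.
+ /// </remarks>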
+ public static Tensor quantize_per_tensor(Tensor input, double scale, long zero_point, ScalarType dtype)
+ {
+ if (!is_quantized(dtype))
+ throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
+ return input._quantize_per_tensor(scale, zero_point, dtype);
+ }
+
+ // https://pytorch.org/docs/stable/generated/torch.quantize_per_channel
+ /// <summary>
+ /// Converts a float tensor to a per-channel quantized tensor with given scales and zero points.
+ /// </summary>
+ /// <param name="input">Float tensor to quantize</param>
+ /// <param name="scales">Float 1D tensor of scales to use, size should match input.size(axis)</param>
+ /// <param name="zero_points">Integer 1D tensor of offsets to use, size should match input.size(axis)</param>
+ /// <param name="axis">Dimension on which to apply per-channel quantization</param>
+ /// <param name="dtype">The desired data type of returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
+ /// <returns>A newly quantized tensor</returns>
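+ /// <remarks>
+ /// For example, with axis = 0 on a 2x3 input, row 0 is quantized with scales[0] and
+ /// zero_points[0], and row 1 with scales[1] and zero_points[1].
+ /// </remarks>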
+ public static Tensor quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
+ {
+ if (!is_quantized(dtype))
+ throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
+ return input._quantize_per_channel(scales, zero_points, axis, dtype);
+ }
+
+ // https://pytorch.org/docs/stable/generated/torch.dequantize
+ /// <summary>
+ /// Returns an fp32 Tensor by dequantizing a quantized Tensor.
+ /// </summary>
+ /// <param name="input">A quantized tensor</param>
+ /// <returns>A dequantized (float) tensor</returns>
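+ /// <remarks>Each element is reconstructed as (q - zero_point) * scale, using per-tensor or per-channel parameters as appropriate.</remarks>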
+ public static Tensor dequantize(Tensor input) => input.dequantize();
+
// https://pytorch.org/docs/stable/generated/torch.fix
/// <summary>
/// Returns a new tensor with the truncated integer values of the elements of input.
diff --git a/test/TorchSharpTest/TestTorchSharp.cs b/test/TorchSharpTest/TestTorchSharp.cs
index 549b8f131..b78e48b70 100644
--- a/test/TorchSharpTest/TestTorchSharp.cs
+++ b/test/TorchSharpTest/TestTorchSharp.cs
@@ -466,5 +466,223 @@ public void CheckVersionStrings()
// Because some of the tests mess with global state, and are run in parallel, we need to
// acquire a lock before testing setting the default RNG seed.
private static object _lock = new object();
+
+ [Fact]
+ [TestOf(nameof(ScalarType))]
+ public void QIntScalarTypeEnumValues()
+ {
+ // Verify the enum values match PyTorch's ScalarType ordinals
+ Assert.Equal(12, (int)ScalarType.QInt8);
+ Assert.Equal(13, (int)ScalarType.QUInt8);
+ Assert.Equal(14, (int)ScalarType.QInt32);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.is_quantized))]
+ public void IsQuantizedScalarType()
+ {
+ // Quantized types should return true
+ Assert.True(torch.is_quantized(ScalarType.QInt8));
+ Assert.True(torch.is_quantized(ScalarType.QUInt8));
+ Assert.True(torch.is_quantized(ScalarType.QInt32));
+
+ // Non-quantized types should return false
+ Assert.False(torch.is_quantized(ScalarType.Float32));
+ Assert.False(torch.is_quantized(ScalarType.Float64));
+ Assert.False(torch.is_quantized(ScalarType.Int8));
+ Assert.False(torch.is_quantized(ScalarType.Int32));
+ Assert.False(torch.is_quantized(ScalarType.Bool));
+ Assert.False(torch.is_quantized(ScalarType.Byte));
+ Assert.False(torch.is_quantized(ScalarType.ComplexFloat32));
+ Assert.False(torch.is_quantized(ScalarType.BFloat16));
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.qint8))]
+ public void QIntDtypeAliases()
+ {
+ // Verify dtype aliases map to the correct ScalarType values
+ Assert.Equal(ScalarType.QInt8, torch.qint8);
+ Assert.Equal(ScalarType.QUInt8, torch.quint8);
+ Assert.Equal(ScalarType.QInt32, torch.qint32);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.is_quantized))]
+ public void IsQuantizedNotIntegralOrFloating()
+ {
+ // Quantized types should not be classified as integral, floating, or complex
+ Assert.False(torch.is_integral(ScalarType.QInt8));
+ Assert.False(torch.is_integral(ScalarType.QUInt8));
+ Assert.False(torch.is_integral(ScalarType.QInt32));
+
+ Assert.False(torch.is_floating_point(ScalarType.QInt8));
+ Assert.False(torch.is_floating_point(ScalarType.QUInt8));
+ Assert.False(torch.is_floating_point(ScalarType.QInt32));
+
+ Assert.False(torch.is_complex(ScalarType.QInt8));
+ Assert.False(torch.is_complex(ScalarType.QUInt8));
+ Assert.False(torch.is_complex(ScalarType.QInt32));
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.quantize_per_tensor))]
+ public void QuantizePerTensorQInt8()
+ {
+ var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f, 4.0f });
+
+ var qTensor = torch.quantize_per_tensor(floatTensor, 0.1, 0, ScalarType.QInt8);
+ Assert.True(qTensor.is_quantized());
+ Assert.Equal(ScalarType.QInt8, qTensor.dtype);
+ Assert.False(qTensor.is_floating_point());
+ Assert.False(qTensor.is_integral());
+ Assert.False(qTensor.is_complex());
+
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.quantize_per_tensor))]
+ public void QuantizePerTensorQUInt8()
+ {
+ var floatTensor = torch.tensor(new float[] { 0.5f, 1.5f, 2.5f });
+
+ var qTensor = torch.quantize_per_tensor(floatTensor, 0.1, 128, ScalarType.QUInt8);
+ Assert.True(qTensor.is_quantized());
+ Assert.Equal(ScalarType.QUInt8, qTensor.dtype);
+
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.quantize_per_tensor))]
+ public void QuantizePerTensorQInt32()
+ {
+ var floatTensor = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
+
+ var qTensor = torch.quantize_per_tensor(floatTensor, 0.01, 0, ScalarType.QInt32);
+ Assert.True(qTensor.is_quantized());
+ Assert.Equal(ScalarType.QInt32, qTensor.dtype);
+
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.quantize_per_tensor))]
+ public void QuantizePerTensorInvalidDtypeThrows()
+ {
+ var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f });
+ Assert.Throws<ArgumentException>(() => torch.quantize_per_tensor(floatTensor, 0.1, 0, ScalarType.Float32));
+ Assert.Throws<ArgumentException>(() => torch.quantize_per_tensor(floatTensor, 0.1, 0, ScalarType.Int32));
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.dequantize))]
+ public void DequantizeRoundtrip()
+ {
+ var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f });
+
+ var qTensor = torch.quantize_per_tensor(floatTensor, 1.0, 0, ScalarType.QInt8);
+ Assert.True(qTensor.is_quantized());
+
+ var dequantized = qTensor.dequantize();
+ Assert.False(dequantized.is_quantized());
+ Assert.True(dequantized.is_floating_point());
+ Assert.Equal(ScalarType.Float32, dequantized.dtype);
+
+ // With scale=1.0 and zero_point=0, values should roundtrip exactly
+ Assert.Equal(1.0f, dequantized[0].ToSingle());
+ Assert.Equal(2.0f, dequantized[1].ToSingle());
+ Assert.Equal(3.0f, dequantized[2].ToSingle());
+
+ dequantized.Dispose();
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.dequantize))]
+ public void DequantizeStaticMethod()
+ {
+ var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f });
+ var qTensor = torch.quantize_per_tensor(floatTensor, 1.0, 0, ScalarType.QInt8);
+
+ var dequantized = torch.dequantize(qTensor);
+ Assert.False(dequantized.is_quantized());
+ Assert.Equal(ScalarType.Float32, dequantized.dtype);
+
+ dequantized.Dispose();
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.q_scale))]
+ public void QScaleAndZeroPoint()
+ {
+ var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f });
+ double scale = 0.5;
+ long zeroPoint = 10;
+
+ var qTensor = torch.quantize_per_tensor(floatTensor, scale, zeroPoint, ScalarType.QInt8);
+ Assert.Equal(scale, qTensor.q_scale());
+ Assert.Equal(zeroPoint, qTensor.q_zero_point());
+
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.int_repr))]
+ public void IntReprReturnsUnderlyingIntegers()
+ {
+ var floatTensor = torch.tensor(new float[] { 0.0f, 1.0f, 2.0f });
+
+ // scale=1.0, zero_point=0: quantized values should be 0, 1, 2
+ var qTensor = torch.quantize_per_tensor(floatTensor, 1.0, 0, ScalarType.QInt8);
+ var intRepr = qTensor.int_repr();
+
+ Assert.False(intRepr.is_quantized());
+ Assert.Equal(ScalarType.Int8, intRepr.dtype);
+ Assert.Equal(0, intRepr[0].ToSByte());
+ Assert.Equal(1, intRepr[1].ToSByte());
+ Assert.Equal(2, intRepr[2].ToSByte());
+
+ intRepr.Dispose();
+ qTensor.Dispose();
+ floatTensor.Dispose();
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.quantize_per_channel))]
+ public void QuantizePerChannel()
+ {
+ // Create a 2D tensor: 2 channels x 3 elements
+ var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }).reshape(2, 3);
+ var scales = torch.tensor(new double[] { 0.1, 0.2 });
+ var zeroPoints = torch.tensor(new long[] { 0, 0 });
+
+ var qTensor = torch.quantize_per_channel(floatTensor, scales, zeroPoints, 0, ScalarType.QInt8);
+ Assert.True(qTensor.is_quantized());
+ Assert.Equal(ScalarType.QInt8, qTensor.dtype);
+
+ // Verify per-channel quantization parameters
+ var channelScales = qTensor.q_per_channel_scales();
+ var channelZeroPoints = qTensor.q_per_channel_zero_points();
+ Assert.Equal(0, qTensor.q_per_channel_axis());
+ Assert.Equal(0.1, channelScales[0].ToDouble(), 5);
+ Assert.Equal(0.2, channelScales[1].ToDouble(), 5);
+
+ channelScales.Dispose();
+ channelZeroPoints.Dispose();
+ qTensor.Dispose();
+ scales.Dispose();
+ zeroPoints.Dispose();
+ floatTensor.Dispose();
+ }
}
}