From 9a56aa29a1f4a31d90b98c353659dfef0d75e576 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 26 Jun 2026 17:31:07 +0800 Subject: [PATCH] test: add `addcdiv` operator coverage --- tests/test_addcdiv.py | 81 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 tests/test_addcdiv.py diff --git a/tests/test_addcdiv.py b/tests/test_addcdiv.py new file mode 100644 index 000000000..290c6fc62 --- /dev/null +++ b/tests/test_addcdiv.py @@ -0,0 +1,81 @@ +import pytest +import torch + +import infini.ops + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +_SHAPE_CASES = ( + ((13, 4), None, None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1), (10, 1)), + ((13, 4, 4), None, None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((16, 5632), None, None, None, None), + ((4, 4, 5632), None, None, None, None), +) + +_FLOAT_DTYPE_CASES = ( + (torch.float32, 1e-6, 1e-6), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, tensor1_strides, tensor2_strides, out_strides", + _SHAPE_CASES, +) +@pytest.mark.parametrize(("dtype", "rtol", "atol"), _FLOAT_DTYPE_CASES) +def test_addcdiv( + shape, + input_strides, + tensor1_strides, + tensor2_strides, + out_strides, + dtype, + device, + implementation_index, + rtol, + atol, +): + if device == "cpu" and implementation_index == 8 and dtype == torch.float16: + pytest.skip("ATen `addcdiv.out` does not support float16 on CPU") + + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + tensor1 = randn_strided(shape, tensor1_strides, dtype=dtype, device=device) + tensor2 = randn_strided(shape, tensor2_strides, dtype=dtype, device=device) + tensor2 = tensor2.sign() * tensor2.abs().clamp_min(0.5) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload( + lambda input, tensor1, tensor2, out: _addcdiv( + input, tensor1, tensor2, out, implementation_index + ), + _torch_addcdiv, + (input, tensor1, tensor2, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _addcdiv(input, tensor1, tensor2, out, implementation_index): + infini.ops.addcdiv( + input, + tensor1, + tensor2, + 0.5, + out, + stream=get_stream(input.device), + implementation_index=implementation_index, + ) + + return out + + +def _torch_addcdiv(input, tensor1, tensor2, out): + out.copy_(torch.addcdiv(input, tensor1, tensor2, value=0.5)) + + return out