diff --git a/src/csrc/dtype.c b/src/csrc/dtype.c index 2033f01..cb1406e 100644 --- a/src/csrc/dtype.c +++ b/src/csrc/dtype.c @@ -390,6 +390,25 @@ quadprec_fromstr(char *s, void *dptr, char **endptr, PyArray_Descr *descr_generi return 0; } +static npy_bool +quadprec_nonzero(void *data, void *arr) +{ + PyArrayObject *arr_obj = (PyArrayObject *)arr; + QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(arr_obj); + QuadBackendType backend = descr->backend; + + if (backend == BACKEND_SLEEF) { + Sleef_quad val; + memcpy(&val, data, sizeof(Sleef_quad)); + return !Sleef_icmpeqq1(val, QUAD_PRECISION_ZERO); + } + else { + long double val; + memcpy(&val, data, sizeof(long double)); + return val != 0.0L; + } +} + /* * Compare function for sorting operations (argsort, sort, etc.) * Implements PyArray_CompareFunc. @@ -464,6 +483,7 @@ static PyType_Slot QuadPrecDType_Slots[] = { {NPY_DT_getitem, &quadprec_getitem}, {NPY_DT_default_descr, &quadprec_default_descr}, {NPY_DT_get_constant, &quadprec_get_constant}, + {NPY_DT_PyArray_ArrFuncs_nonzero, &quadprec_nonzero}, {NPY_DT_PyArray_ArrFuncs_compare, &quadprec_compare}, {NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill}, {NPY_DT_PyArray_ArrFuncs_scanfunc, &quadprec_scanfunc}, diff --git a/tests/test_quaddtype.py b/tests/test_quaddtype.py index 7d65e2e..839047e 100644 --- a/tests/test_quaddtype.py +++ b/tests/test_quaddtype.py @@ -1531,6 +1531,30 @@ def test_unary_logical_not(x): assert isinstance(quad_result, (bool, np.bool_)), f"Result should be bool, got {type(quad_result)}" +@pytest.mark.parametrize("val,expected", [ + ("1.0", True), + ("0.5", True), + ("1e-100", True), + ("-1.0", True), + ("0.0", False), + ("-0.0", False), +]) +def test_bool_0d_array(val, expected): + """ + Test boolean conversion on 0-d QuadPrecision arrays. + + This tests that bool(np.array(QuadPrecision(x))) works correctly + and doesn't segfault due to missing dtype nonzero function. + """ + quad_scalar = QuadPrecision(val) + arr_0d = np.array(quad_scalar) + # Ensure it's actually a 0-d array + assert arr_0d.ndim == 0, f"Expected 0-d array, got {arr_0d.ndim}-d" + + # This should not segfault + result = bool(arr_0d) + assert result == expected, f"bool(np.array(QuadPrecision({val}))) should be {expected}, got {result}" + @pytest.mark.parametrize("val", [ # Small positive values that truncate to 0 when cast to int 0.5, 0.1, 0.01, 0.001, 0.0001,