From f60c0a97a61dd3cfabe745c9f51341c6325118ab Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Wed, 28 Jan 2026 13:38:07 +0000 Subject: [PATCH] BUG: reject non-scalar `fill_value` in `full{_like}` closes gh-55 --- array_api_strict/_creation_functions.py | 9 +++++++-- array_api_strict/tests/test_creation_functions.py | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 64c51ce..99c4cde 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -240,8 +240,9 @@ def full( _check_valid_dtype(dtype) _check_device(device) - if isinstance(fill_value, Array) and fill_value.ndim == 0: - fill_value = fill_value._array + if not isinstance(fill_value, bool | int | float | complex): + msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" + raise TypeError(msg) res = np.full(shape, fill_value, dtype=_np_dtype(dtype)) if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy @@ -270,6 +271,10 @@ def full_like( if device is None: device = x.device + if not isinstance(fill_value, bool | int | float | complex): + msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" + raise TypeError(msg) + res = np.full_like(x._array, fill_value, dtype=_np_dtype(dtype)) if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 573fc7f..6736826 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -161,6 +161,7 @@ def test_full_errors(): assert_raises(ValueError, lambda: full((1,), 0, device="gpu")) assert_raises(ValueError, lambda: full((1,), 0, dtype=int)) assert_raises(ValueError, lambda: full((1,), 0, dtype="i")) + assert_raises(TypeError, lambda: full((1,), asarray(0))) def test_full_like_errors(): @@ -169,6 +170,7 @@ def test_full_like_errors(): assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu")) assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int)) assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i")) + assert_raises(TypeError, lambda: full(asarray(1), asarray(0))) def test_linspace_errors():