Description
Caught in data-apis/array-api-tests#410:
when clip receives a uint64 array and a Python scalar max that is within the range of uint64 but overflows int64, it raises an internal OverflowError:
(array-api) ev-br@qgpu3:~$ JAX_ENABLE_X64=true ipython
Python 3.11.0 | packaged by conda-forge | (main, Jan 14 2023, 12:27:40) [GCC 11.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.32.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax.numpy as jnp
In [2]: kw = {'min': 0, 'max': 9223372036854775808}
In [3]: jnp.clip(jnp.asarray(0, dtype=jnp.uint64), **kw)
WARNING:2026-01-17 20:43:18,604:jax._src.xla_bridge:850: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
---------------------------------------------------------------------------
OverflowError Traceback (most recent call last)
Cell In[3], line 1
----> 1 jnp.clip(jnp.asarray(0, dtype=jnp.uint64), **kw)
[... skipping hidden 5 frame]
File ~/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/pjit.py:671, in _infer_input_type(fun, dbg_fn, explicit_args)
669 dbg = dbg_fn()
670 arg_path = f"argument path is {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}" # pytype: disable=name-error
--> 671 raise OverflowError(
672 "An overflow was encountered while parsing an argument to a jitted "
673 f"computation, whose {arg_path}."
674 ) from None
675 except TypeError:
676 dbg = dbg_fn()
OverflowError: An overflow was encountered while parsing an argument to a jitted computation, whose argument path is max.
In [4]: jnp.iinfo(jnp.uint64).max - kw['max']
Out[4]: 9223372036854775807
In [5]: jnp.iinfo(jnp.int64).max - kw['max']
Out[5]: -1
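The two subtractions above show that max is exactly 2**63. That is one past the int64 maximum but well inside uint64. The error is raised in pjit's argument parsing (_infer_input_type), not in clip itself. This suggests the Python int is converted to the default signed 64-bit integer type before it ever meets the uint64 array. That is an inference from the traceback and has not been checked against the JAX source.

The Array API standard instead specifies that a Python scalar combined with an array is converted to the array's dtype. The plain-Python sketch below contrasts the two paths. resolve_scalar_dtype is a hypothetical helper for illustration only. It is not a JAX API and not a proposed patch.

```python
# Inclusive bounds of the two 64-bit integer types involved in the report.
INT64_MIN, INT64_MAX = -(2 ** 63), 2 ** 63 - 1
UINT64_MIN, UINT64_MAX = 0, 2 ** 64 - 1


def resolve_scalar_dtype(value: int, array_dtype: str) -> str:
    """Hypothetical helper: give a Python int the array's dtype (the Array API
    rule for mixing scalars with arrays), raising only if it is out of range."""
    bounds = {"int64": (INT64_MIN, INT64_MAX), "uint64": (UINT64_MIN, UINT64_MAX)}
    lo, hi = bounds[array_dtype]
    if not lo <= value <= hi:
        raise OverflowError(f"{value} is out of range for {array_dtype}")
    return array_dtype


max_value = 9223372036854775808  # the `max` from the report, i.e. 2**63

# Same headroom as In [4] and In [5] above.
print(UINT64_MAX - max_value)  # 9223372036854775807
print(INT64_MAX - max_value)   # -1

# Resolved against the uint64 array, the scalar is in range...
print(resolve_scalar_dtype(max_value, "uint64"))  # uint64

# ...but if it is routed through int64 first, it overflows.
try:
    resolve_scalar_dtype(max_value, "int64")
except OverflowError as exc:
    print(exc)
```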
System info (python version, jaxlib version, accelerator, etc.)
In [6]: import jax; jax.print_environment_info()
jax: 0.8.0
jaxlib: 0.8.0
numpy: 2.2.5
python: 3.11.0 | packaged by conda-forge | (main, Jan 14 2023, 12:27:40) [GCC 11.3.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='qgpu3', release='6.8.0-71-generic', version='#71-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 16:52:38 UTC 2025', machine='x86_64')
JAX_ENABLE_X64=true
$ nvidia-smi
Sat Jan 17 20:47:16 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 2060 Off | 00000000:01:00.0 Off | N/A |
| 31% 38C P0 36W / 160W | 1MiB / 6144MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 2060 Off | 00000000:21:00.0 Off | N/A |
| 33% 41C P0 37W / 160W | 1MiB / 6144MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+