diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 43ca0d1..9d13520 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -6,6 +6,7 @@ Changes from 2.14.1 to 2.14.2 ----------------------------- * **Under development.** +* Avoid keeping arrays passed as ``out=`` alive in the ``re_evaluate`` cache. Changes from 2.14.0 to 2.14.1 ----------------------------- diff --git a/numexpr/necompiler.py b/numexpr/necompiler.py index 96c66f6..d4a1145 100644 --- a/numexpr/necompiler.py +++ b/numexpr/necompiler.py @@ -14,6 +14,7 @@ import re import sys import threading +import weakref from typing import Dict, Optional import numpy @@ -794,6 +795,26 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2): evaluate_lock = threading.Lock() +def _cache_last_kwargs(out: numpy.ndarray, + order: str, + casting: str, + ex_uses_vml: bool) -> Dict: + return { + 'out': None if out is None else weakref.ref(out), + 'order': order, + 'casting': casting, + 'ex_uses_vml': ex_uses_vml, + } + + +def _resolve_last_kwargs(kwargs: Dict) -> Dict: + kwargs = kwargs.copy() + out = kwargs.get('out') + if isinstance(out, weakref.ReferenceType): + kwargs['out'] = out() + return kwargs + + def validate(ex: str, local_dict: Optional[Dict] = None, global_dict: Optional[Dict] = None, @@ -905,8 +926,7 @@ def validate(ex: str, compiled_ex = _numexpr_cache.c[numexpr_key] except KeyError: compiled_ex = _numexpr_cache.c[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context) - kwargs = {'out': out, 'order': order, 'casting': casting, - 'ex_uses_vml': ex_uses_vml} + kwargs = _cache_last_kwargs(out, order, casting, ex_uses_vml) _numexpr_last.l.set(ex=compiled_ex, argnames=names, kwargs=kwargs) except Exception as e: return e @@ -1049,6 +1069,6 @@ def re_evaluate(local_dict: Optional[Dict] = None, raise RuntimeError("A previous evaluate() execution was not found, please call `validate` or `evaluate` once before `re_evaluate`") argnames = _numexpr_last.l['argnames'] args = getArguments(argnames, local_dict, global_dict, _frame_depth=_frame_depth) - kwargs = _numexpr_last.l['kwargs'] + kwargs = _resolve_last_kwargs(_numexpr_last.l['kwargs']) # with evaluate_lock: return compiled_ex(*args, **kwargs) diff --git a/numexpr/tests/test_numexpr.py b/numexpr/tests/test_numexpr.py index 46fad29..8133de3 100644 --- a/numexpr/tests/test_numexpr.py +++ b/numexpr/tests/test_numexpr.py @@ -9,13 +9,14 @@ # rights to use. #################################################################### - +import gc import os import platform import subprocess import sys import unittest import warnings +import weakref from contextlib import contextmanager from unittest.mock import MagicMock @@ -412,6 +413,30 @@ def test_re_evaluate_dict(self): x = re_evaluate(local_dict=local_dict) assert_array_equal(x, array([86., 124., 168.])) + def test_evaluate_out_is_not_kept_alive(self): + a = arange(1000.0) + out = zeros(a.shape) + out_ref = weakref.ref(out) + + evaluate("a + 1", local_dict={"a": a}, out=out) + del out + gc.collect() + + assert out_ref() is None + + def test_re_evaluate_reuses_live_out(self): + a = array([1., 2., 3.]) + out = zeros(a.shape) + + x = evaluate("a + 1", local_dict={"a": a}, out=out) + assert x is out + assert_array_equal(out, array([2., 3., 4.])) + + a = array([4., 5., 6.]) + x = re_evaluate(local_dict={"a": a}) + assert x is out + assert_array_equal(out, array([5., 6., 7.])) + def test_validate(self): a = array([1., 2., 3.]) b = array([4., 5., 6.])