Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changes from 2.14.1 to 2.14.2
-----------------------------

* **Under development.**
* Avoid keeping arrays passed as ``out=`` alive in the ``re_evaluate`` cache.

Changes from 2.14.0 to 2.14.1
-----------------------------
Expand Down
26 changes: 23 additions & 3 deletions numexpr/necompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import re
import sys
import threading
import weakref
from typing import Dict, Optional

import numpy
Expand Down Expand Up @@ -794,6 +795,26 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
evaluate_lock = threading.Lock()


def _cache_last_kwargs(out: numpy.ndarray,
order: str,
casting: str,
ex_uses_vml: bool) -> Dict:
return {
'out': None if out is None else weakref.ref(out),
'order': order,
'casting': casting,
'ex_uses_vml': ex_uses_vml,
}


def _resolve_last_kwargs(kwargs: Dict) -> Dict:
kwargs = kwargs.copy()
out = kwargs.get('out')
if isinstance(out, weakref.ReferenceType):
kwargs['out'] = out()
return kwargs


def validate(ex: str,
local_dict: Optional[Dict] = None,
global_dict: Optional[Dict] = None,
Expand Down Expand Up @@ -905,8 +926,7 @@ def validate(ex: str,
compiled_ex = _numexpr_cache.c[numexpr_key]
except KeyError:
compiled_ex = _numexpr_cache.c[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
kwargs = {'out': out, 'order': order, 'casting': casting,
'ex_uses_vml': ex_uses_vml}
kwargs = _cache_last_kwargs(out, order, casting, ex_uses_vml)
_numexpr_last.l.set(ex=compiled_ex, argnames=names, kwargs=kwargs)
except Exception as e:
return e
Expand Down Expand Up @@ -1049,6 +1069,6 @@ def re_evaluate(local_dict: Optional[Dict] = None,
raise RuntimeError("A previous evaluate() execution was not found, please call `validate` or `evaluate` once before `re_evaluate`")
argnames = _numexpr_last.l['argnames']
args = getArguments(argnames, local_dict, global_dict, _frame_depth=_frame_depth)
kwargs = _numexpr_last.l['kwargs']
kwargs = _resolve_last_kwargs(_numexpr_last.l['kwargs'])
# with evaluate_lock:
return compiled_ex(*args, **kwargs)
27 changes: 26 additions & 1 deletion numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
# rights to use.
####################################################################


import gc
import os
import platform
import subprocess
import sys
import unittest
import warnings
import weakref
from contextlib import contextmanager
from unittest.mock import MagicMock

Expand Down Expand Up @@ -412,6 +413,30 @@ def test_re_evaluate_dict(self):
x = re_evaluate(local_dict=local_dict)
assert_array_equal(x, array([86., 124., 168.]))

def test_evaluate_out_is_not_kept_alive(self):
a = arange(1000.0)
out = zeros(a.shape)
out_ref = weakref.ref(out)

evaluate("a + 1", local_dict={"a": a}, out=out)
del out
gc.collect()

assert out_ref() is None

def test_re_evaluate_reuses_live_out(self):
a = array([1., 2., 3.])
out = zeros(a.shape)

x = evaluate("a + 1", local_dict={"a": a}, out=out)
assert x is out
assert_array_equal(out, array([2., 3., 4.]))

a = array([4., 5., 6.])
x = re_evaluate(local_dict={"a": a})
assert x is out
assert_array_equal(out, array([5., 6., 7.]))

def test_validate(self):
a = array([1., 2., 3.])
b = array([4., 5., 6.])
Expand Down
Loading