diff --git a/benchmarks/test_deserializer_cache_benchmark.py b/benchmarks/test_deserializer_cache_benchmark.py new file mode 100644 index 0000000000..d2ed1585d1 --- /dev/null +++ b/benchmarks/test_deserializer_cache_benchmark.py @@ -0,0 +1,212 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmarks for find_deserializer / make_deserializers with and without caching. + +Run with: pytest benchmarks/test_deserializer_cache_benchmark.py -v + +Requires the ``pytest-benchmark`` plugin and Cython extensions to be built. +Skipped automatically when either dependency is unavailable. +""" + +import pytest + +pytest.importorskip("pytest_benchmark") +pytest.importorskip("cassandra.deserializers") + +from cassandra import cqltypes +from cassandra.deserializers import ( + find_deserializer, + make_deserializers, +) + + +# --------------------------------------------------------------------------- +# Reference: original uncached implementations (copied from master) +# --------------------------------------------------------------------------- + +_classes = {} + + +def _init_classes(): + """Lazily initialize the class lookup dict from deserializers module.""" + if not _classes: + from cassandra import deserializers as mod + + for name in dir(mod): + obj = getattr(mod, name) + if isinstance(obj, type): + _classes[name] = obj + + +def find_deserializer_uncached(cqltype): + """Original implementation without caching.""" + _init_classes() + + name = "Des" + cqltype.__name__ + if name in _classes: + cls = _classes[name] + elif issubclass(cqltype, cqltypes.ListType): + from cassandra.deserializers import DesListType + + cls = DesListType + elif issubclass(cqltype, cqltypes.SetType): + from cassandra.deserializers import DesSetType + + cls = DesSetType + elif issubclass(cqltype, cqltypes.MapType): + from cassandra.deserializers import DesMapType + + cls = DesMapType + elif issubclass(cqltype, cqltypes.UserType): + from cassandra.deserializers import DesUserType + + cls = DesUserType + elif issubclass(cqltype, cqltypes.TupleType): + from cassandra.deserializers import DesTupleType + + cls = DesTupleType + elif issubclass(cqltype, cqltypes.DynamicCompositeType): + from cassandra.deserializers import DesDynamicCompositeType + + cls = DesDynamicCompositeType + elif issubclass(cqltype, cqltypes.CompositeType): + from cassandra.deserializers import DesCompositeType + + cls = DesCompositeType + elif issubclass(cqltype, cqltypes.ReversedType): + from cassandra.deserializers import DesReversedType + + cls = DesReversedType + elif issubclass(cqltype, cqltypes.FrozenType): + from cassandra.deserializers import DesFrozenType + + cls = DesFrozenType + else: + from cassandra.deserializers import GenericDeserializer + + cls = GenericDeserializer + + return cls(cqltype) + + +def make_deserializers_uncached(ctypes): + """Original implementation without caching.""" + from cassandra.deserializers import obj_array + + return obj_array([find_deserializer_uncached(ct) for ct in ctypes]) + + +# --------------------------------------------------------------------------- +# Test type sets +# --------------------------------------------------------------------------- + +SIMPLE_TYPES = [ + cqltypes.Int32Type, + cqltypes.UTF8Type, + cqltypes.BooleanType, + cqltypes.DoubleType, + cqltypes.LongType, +] + +MIXED_TYPES = [ + cqltypes.Int32Type, + cqltypes.UTF8Type, + cqltypes.BooleanType, + cqltypes.DoubleType, + cqltypes.LongType, + cqltypes.FloatType, + cqltypes.TimestampType, + cqltypes.UUIDType, + cqltypes.InetAddressType, + cqltypes.DecimalType, +] + + +# --------------------------------------------------------------------------- +# Correctness tests +# --------------------------------------------------------------------------- + + +class TestDeserializerCacheCorrectness: + """Verify the cached implementation returns equivalent deserializers.""" + + @pytest.mark.parametrize("cqltype", SIMPLE_TYPES + MIXED_TYPES) + def test_find_deserializer_returns_correct_type(self, cqltype): + cached = find_deserializer(cqltype) + uncached = find_deserializer_uncached(cqltype) + assert type(cached).__name__ == type(uncached).__name__ + + def test_find_deserializer_cache_hit_same_object(self): + d1 = find_deserializer(cqltypes.Int32Type) + d2 = find_deserializer(cqltypes.Int32Type) + assert d1 is d2 + + def test_make_deserializers_returns_correct_length(self): + result = make_deserializers(SIMPLE_TYPES) + assert len(result) == len(SIMPLE_TYPES) + + def test_make_deserializers_cache_hit_same_object(self): + r1 = make_deserializers(SIMPLE_TYPES) + r2 = make_deserializers(SIMPLE_TYPES) + # Should be the exact same cached object + assert r1 is r2 + + +# --------------------------------------------------------------------------- +# Benchmarks +# --------------------------------------------------------------------------- + + +class TestFindDeserializerBenchmark: + """Benchmark find_deserializer cached vs uncached.""" + + # --- Single simple type --- + + @pytest.mark.benchmark(group="find_deser_simple") + def test_uncached_simple(self, benchmark): + benchmark(find_deserializer_uncached, cqltypes.Int32Type) + + @pytest.mark.benchmark(group="find_deser_simple") + def test_cached_simple(self, benchmark): + # Cache is already warm from correctness tests or previous iterations + find_deserializer(cqltypes.Int32Type) # ensure warm + benchmark(find_deserializer, cqltypes.Int32Type) + + +class TestMakeDeserializersBenchmark: + """Benchmark make_deserializers cached vs uncached.""" + + # --- 5 simple types --- + + @pytest.mark.benchmark(group="make_deser_5types") + def test_uncached_5types(self, benchmark): + benchmark(make_deserializers_uncached, SIMPLE_TYPES) + + @pytest.mark.benchmark(group="make_deser_5types") + def test_cached_5types(self, benchmark): + make_deserializers(SIMPLE_TYPES) # ensure warm + benchmark(make_deserializers, SIMPLE_TYPES) + + # --- 10 mixed types --- + + @pytest.mark.benchmark(group="make_deser_10types") + def test_uncached_10types(self, benchmark): + benchmark(make_deserializers_uncached, MIXED_TYPES) + + @pytest.mark.benchmark(group="make_deser_10types") + def test_cached_10types(self, benchmark): + make_deserializers(MIXED_TYPES) # ensure warm + benchmark(make_deserializers, MIXED_TYPES) diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 98e8676bbc..fd4308a02c 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -440,16 +440,54 @@ cdef class GenericDeserializer(Deserializer): #-------------------------------------------------------------------------- # Helper utilities +# Maximum number of entries in each deserializer cache. In practice the +# caches are bounded by the number of distinct column-type signatures in +# the schema (typically dozens to low hundreds), but parameterized types +# created via apply_parameters() for unprepared queries are *not* +# interned, so repeated simple queries could accumulate entries. The cap +# prevents unbounded growth in such edge cases. +cdef int _CACHE_MAX_SIZE = 256 + +# Cache make_deserializers results keyed on the tuple of cqltype objects. +# Using the cqltype objects themselves (rather than id()) as keys ensures +# the dict holds strong references, preventing GC and id() reuse issues +# with non-singleton parameterized types. +cdef dict _make_deserializers_cache = {} + def make_deserializers(cqltypes): """Create an array of Deserializers for each given cqltype in cqltypes""" - cdef Deserializer[::1] deserializers - return obj_array([find_deserializer(ct) for ct in cqltypes]) + cdef tuple key = tuple(cqltypes) + try: + return _make_deserializers_cache[key] + except KeyError: + pass + result = obj_array([find_deserializer(ct) for ct in cqltypes]) + if len(_make_deserializers_cache) >= _CACHE_MAX_SIZE: + _make_deserializers_cache.clear() + _make_deserializers_cache[key] = result + return result cdef dict classes = globals() +# Cache deserializer instances keyed on the cqltype object itself to avoid +# repeated class lookups and object creation on every result set. +# Using the object as key (rather than id()) holds a strong reference, +# preventing GC and id() reuse issues with parameterized types. +# +# Note: if a Des* class is overridden at runtime (e.g. DesBytesType = +# DesBytesTypeByteArray for cqlsh), callers must invoke +# clear_deserializer_caches() to flush stale entries so that subsequent +# find_deserializer() calls pick up the new class. +cdef dict _deserializer_cache = {} + cpdef Deserializer find_deserializer(cqltype): """Find a deserializer for a cqltype""" + try: + return _deserializer_cache[cqltype] + except KeyError: + pass + name = 'Des' + cqltype.__name__ if name in globals(): @@ -477,7 +515,28 @@ cpdef Deserializer find_deserializer(cqltype): else: cls = GenericDeserializer - return cls(cqltype) + cdef Deserializer result = cls(cqltype) + if len(_deserializer_cache) >= _CACHE_MAX_SIZE: + _deserializer_cache.clear() + _deserializer_cache[cqltype] = result + return result + + +def clear_deserializer_caches(): + """Clear the find_deserializer and make_deserializers caches. + + Call this after overriding a Des* class at runtime (e.g. + ``deserializers.DesBytesType = deserializers.DesBytesTypeByteArray``) + so that subsequent lookups pick up the new class instead of returning + stale cached instances. + """ + _deserializer_cache.clear() + _make_deserializers_cache.clear() + + +def get_deserializer_cache_sizes(): + """Return ``(find_cache_size, make_cache_size)`` for diagnostic use.""" + return len(_deserializer_cache), len(_make_deserializers_cache) def obj_array(list objs): diff --git a/tests/unit/cython/test_deserializer_cache.py b/tests/unit/cython/test_deserializer_cache.py new file mode 100644 index 0000000000..5924a71636 --- /dev/null +++ b/tests/unit/cython/test_deserializer_cache.py @@ -0,0 +1,190 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the deserializer caches in deserializers.pyx. + +Validates cache hit/miss behaviour, bounded eviction, the +clear_deserializer_caches() API (needed after runtime Des* overrides), +and the get_deserializer_cache_sizes() diagnostic helper. +""" + +import unittest + +from tests.unit.cython.utils import cythontest + +try: + from cassandra.deserializers import ( + clear_deserializer_caches, + find_deserializer, + get_deserializer_cache_sizes, + make_deserializers, + ) + + _HAS_DESERIALIZERS = True +except ImportError: + _HAS_DESERIALIZERS = False + +from cassandra import cqltypes + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class DeserializerCacheTest(unittest.TestCase): + """Tests for find_deserializer / make_deserializers caching.""" + + def setUp(self): + if _HAS_DESERIALIZERS: + clear_deserializer_caches() + + def tearDown(self): + if _HAS_DESERIALIZERS: + clear_deserializer_caches() + + # -- find_deserializer cache ------------------------------------------- + + @cythontest + def test_find_cache_hit_same_object(self): + """Repeated calls for the same cqltype return the same instance.""" + d1 = find_deserializer(cqltypes.Int32Type) + d2 = find_deserializer(cqltypes.Int32Type) + self.assertIs(d1, d2) + + @cythontest + def test_find_cache_miss_different_types(self): + """Different cqltypes produce different deserializer instances.""" + d_int = find_deserializer(cqltypes.Int32Type) + d_utf = find_deserializer(cqltypes.UTF8Type) + self.assertIsNot(d_int, d_utf) + + @cythontest + def test_find_returns_correct_deserializer_class(self): + """The returned deserializer class name matches the cqltype.""" + d = find_deserializer(cqltypes.Int32Type) + self.assertEqual(type(d).__name__, "DesInt32Type") + + # -- make_deserializers cache ------------------------------------------ + + @cythontest + def test_make_cache_hit_same_object(self): + """Repeated calls with the same type list return the same array.""" + types = [cqltypes.Int32Type, cqltypes.UTF8Type] + r1 = make_deserializers(types) + r2 = make_deserializers(types) + self.assertIs(r1, r2) + + @cythontest + def test_make_cache_correct_length(self): + """Returned array has the right number of entries.""" + types = [cqltypes.Int32Type, cqltypes.UTF8Type, cqltypes.BooleanType] + result = make_deserializers(types) + self.assertEqual(len(result), 3) + + # -- clear_deserializer_caches ----------------------------------------- + + @cythontest + def test_clear_invalidates_find_cache(self): + """After clearing, find_deserializer returns a new instance.""" + d1 = find_deserializer(cqltypes.Int32Type) + clear_deserializer_caches() + d2 = find_deserializer(cqltypes.Int32Type) + # New instance, but same deserializer class + self.assertIsNot(d1, d2) + self.assertEqual(type(d1).__name__, type(d2).__name__) + + @cythontest + def test_clear_invalidates_make_cache(self): + """After clearing, make_deserializers returns a new array.""" + types = [cqltypes.Int32Type, cqltypes.UTF8Type] + r1 = make_deserializers(types) + clear_deserializer_caches() + r2 = make_deserializers(types) + self.assertIsNot(r1, r2) + + # -- get_deserializer_cache_sizes -------------------------------------- + + @cythontest + def test_cache_sizes_empty_after_clear(self): + """Sizes are (0, 0) immediately after clearing.""" + find_size, make_size = get_deserializer_cache_sizes() + self.assertEqual(find_size, 0) + self.assertEqual(make_size, 0) + + @cythontest + def test_cache_sizes_increment(self): + """Sizes reflect the number of cached entries.""" + find_deserializer(cqltypes.Int32Type) + find_deserializer(cqltypes.UTF8Type) + make_deserializers([cqltypes.Int32Type, cqltypes.UTF8Type]) + + find_size, make_size = get_deserializer_cache_sizes() + self.assertEqual(find_size, 2) + self.assertEqual(make_size, 1) + + # -- bounded eviction -------------------------------------------------- + + @cythontest + def test_find_cache_bounded_size(self): + """find_deserializer cache should not exceed 256 entries.""" + # Create 300 distinct cqltype objects via apply_parameters. + # Each ListType.apply_parameters() call creates a fresh class. + inner_types = [ + cqltypes.Int32Type, + cqltypes.UTF8Type, + cqltypes.BooleanType, + cqltypes.DoubleType, + cqltypes.LongType, + ] + distinct_types = [] + for i in range(300): + # Create ListType(inner) — each apply_parameters returns a new + # class object, so these are all distinct cache keys. + inner = inner_types[i % len(inner_types)] + ct = cqltypes.ListType.apply_parameters([inner]) + distinct_types.append(ct) + + for ct in distinct_types: + find_deserializer(ct) + + find_size, _ = get_deserializer_cache_sizes() + self.assertLessEqual( + find_size, + 256, + "find_deserializer cache should be bounded to 256, got %d" % find_size, + ) + + @cythontest + def test_make_cache_bounded_size(self): + """make_deserializers cache should not exceed 256 entries.""" + inner_types = [ + cqltypes.Int32Type, + cqltypes.UTF8Type, + cqltypes.BooleanType, + cqltypes.DoubleType, + cqltypes.LongType, + ] + for i in range(300): + inner = inner_types[i % len(inner_types)] + ct = cqltypes.ListType.apply_parameters([inner]) + make_deserializers([ct]) + + _, make_size = get_deserializer_cache_sizes() + self.assertLessEqual( + make_size, + 256, + "make_deserializers cache should be bounded to 256, got %d" % make_size, + )