From a726e4194ac6e0c6cca3fb85e6ca9821d160ffc0 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 18 Mar 2026 14:48:34 -0700 Subject: [PATCH 1/4] ApiMessage._VERSIONS dict refactor: int keys; drop None; use mro()[1] for primary_cls --- kafka/protocol/new/api_message.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/kafka/protocol/new/api_message.py b/kafka/protocol/new/api_message.py index 6e4a689d2..f1f9ad5f1 100644 --- a/kafka/protocol/new/api_message.py +++ b/kafka/protocol/new/api_message.py @@ -18,23 +18,26 @@ def __init__(cls, name, bases, attrs, **kw): # We also include cls[None] -> primary class to "exit" a version class if getattr(cls, '_class_version', None) is None: cls._class_version = None - cls._VERSIONS = {None: weakref.proxy(cls)} + cls._VERSIONS = {} def __getitem__(cls, version): # Use [] lookups to move from primary class to "versioned" classes # which are simple wrappers around the primary class but with a _version attr if cls._class_version is not None: - return cls._VERSIONS[None][version] + primary_cls = cls.mro()[1] + if version is None: + return primary_cls + return primary_cls[version] if cls._valid_versions is not None: if version < 0: version += 1 + cls.max_version # support negative index, e.g., [-1] if not cls.min_version <= version <= cls.max_version: raise ValueError('Invalid version! min=%d, max=%d' % (cls.min_version, cls.max_version)) + if version in cls._VERSIONS: + return cls._VERSIONS[version] klass_name = cls.__name__ + '_v' + str(version) - if klass_name in cls._VERSIONS: - return cls._VERSIONS[klass_name] - cls._VERSIONS[klass_name] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False) - return cls._VERSIONS[klass_name] + cls._VERSIONS[version] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False) + return cls._VERSIONS[version] def __len__(cls): # Maintain compatibility From 26b3fe9e30890cb16555548988e954af9d32c7cb Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 18 Mar 2026 14:49:15 -0700 Subject: [PATCH 2/4] Always return primary class instances; translate _class_version -> version kwarg for init --- kafka/protocol/new/api_message.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/kafka/protocol/new/api_message.py b/kafka/protocol/new/api_message.py index f1f9ad5f1..8efb31e2b 100644 --- a/kafka/protocol/new/api_message.py +++ b/kafka/protocol/new/api_message.py @@ -87,10 +87,19 @@ def __init_subclass__(cls, **kw): if not cls.is_request(): ResponseClassRegistry.register_response_class(weakref.proxy(cls)) + def __new__(cls, *args, **kwargs): + # Translate "versioned" classes back to primary w/ version= kwarg on construction + if cls._class_version is not None: + if kwargs.get('version', cls._class_version) != cls._class_version: # pylint: disable=E1101 + raise ValueError("Version has already been set by class") + kwargs['version'] = cls._class_version + instance = super().__new__(cls[None]) + instance.__init__(*args, **kwargs) + return instance + return super().__new__(cls) + def __init__(self, *args, **kwargs): self._header = None - if 'version' not in kwargs: - kwargs['version'] = self._class_version # pylint: disable=E1101 super().__init__(*args, **kwargs) @classproperty @@ -144,8 +153,6 @@ def API_VERSION(self): @API_VERSION.setter def API_VERSION(self, version): - if self._class_version is not None and self._class_version != version: # pylint: disable=E1101 - raise ValueError("Version has already been set by class") if not 0 <= version <= self.max_version: raise ValueError('Invalid version %s (max version is %s).' % (version, self.max_version)) self._version = version From db70e5c93270aaa8bd04bed946d633e05acbfeee Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 18 Mar 2026 14:52:24 -0700 Subject: [PATCH 3/4] Dont use instance.decode() in test_new_api_versions --- kafka/protocol/new/data_container.py | 4 ++++ test/protocol/new/metadata/test_new_api_versions.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/kafka/protocol/new/data_container.py b/kafka/protocol/new/data_container.py index 6df3ce8dc..baeca943d 100644 --- a/kafka/protocol/new/data_container.py +++ b/kafka/protocol/new/data_container.py @@ -53,6 +53,10 @@ def __init__(self, *args, version=None, **field_vals): if field_vals: raise ValueError('Unrecognized fields for type %s: %s' % (self._struct.name, field_vals)) + @property + def version(self): + return self._version + def encode(self, *args, **kwargs): """Add version= to kwargs, otherwise pass-through to _struct""" return self._struct.encode(self, *args, **kwargs) diff --git a/test/protocol/new/metadata/test_new_api_versions.py b/test/protocol/new/metadata/test_new_api_versions.py index 9aa33750c..6e34096c9 100644 --- a/test/protocol/new/metadata/test_new_api_versions.py +++ b/test/protocol/new/metadata/test_new_api_versions.py @@ -66,7 +66,7 @@ def test_parse(msg, encoded): msg.with_header(correlation_id=1, client_id='_internal_client_kYVL') assert msg.encode(header=True, framed=True) == encoded - assert msg.decode(encoded, header=True, framed=True) == msg + assert msg.__class__.decode(encoded, version=msg.version, header=True, framed=True) == msg @pytest.mark.parametrize('version', [0, 1, 2, 3, 4]) From 3802a388fbbbb7b30a5cd3336b11a306841fff5e Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 18 Mar 2026 14:53:58 -0700 Subject: [PATCH 4/4] version aware DataContainer __repr__ --- kafka/protocol/new/data_container.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/kafka/protocol/new/data_container.py b/kafka/protocol/new/data_container.py index baeca943d..56ca3e296 100644 --- a/kafka/protocol/new/data_container.py +++ b/kafka/protocol/new/data_container.py @@ -71,8 +71,14 @@ def fields(cls): # pylint: disable=E0213 return cls._struct.fields def __repr__(self): - key_vals = ['%s=%s' % (field.name, repr(getattr(self, field.name))) - for field in self._struct._fields] + if self._version is not None: + v_filter = lambda field: field.for_version_q(self._version) + key_vals = ['version=%s' % self._version] + else: + v_filter = lambda field: True + key_vals = [] + for field in filter(v_filter, self._struct._fields): + key_vals.append('%s=%s' % (field.name, repr(getattr(self, field.name)))) return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')' def __eq__(self, other):