diff --git a/kafka/protocol/new/api_message.py b/kafka/protocol/new/api_message.py index 6e4a689d2..8efb31e2b 100644 --- a/kafka/protocol/new/api_message.py +++ b/kafka/protocol/new/api_message.py @@ -18,23 +18,26 @@ def __init__(cls, name, bases, attrs, **kw): # We also include cls[None] -> primary class to "exit" a version class if getattr(cls, '_class_version', None) is None: cls._class_version = None - cls._VERSIONS = {None: weakref.proxy(cls)} + cls._VERSIONS = {} def __getitem__(cls, version): # Use [] lookups to move from primary class to "versioned" classes # which are simple wrappers around the primary class but with a _version attr if cls._class_version is not None: - return cls._VERSIONS[None][version] + primary_cls = cls.mro()[1] + if version is None: + return primary_cls + return primary_cls[version] if cls._valid_versions is not None: if version < 0: version += 1 + cls.max_version # support negative index, e.g., [-1] if not cls.min_version <= version <= cls.max_version: raise ValueError('Invalid version! min=%d, max=%d' % (cls.min_version, cls.max_version)) + if version in cls._VERSIONS: + return cls._VERSIONS[version] klass_name = cls.__name__ + '_v' + str(version) - if klass_name in cls._VERSIONS: - return cls._VERSIONS[klass_name] - cls._VERSIONS[klass_name] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False) - return cls._VERSIONS[klass_name] + cls._VERSIONS[version] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False) + return cls._VERSIONS[version] def __len__(cls): # Maintain compatibility @@ -84,10 +87,19 @@ def __init_subclass__(cls, **kw): if not cls.is_request(): ResponseClassRegistry.register_response_class(weakref.proxy(cls)) + def __new__(cls, *args, **kwargs): + # Translate "versioned" classes back to primary w/ version= kwarg on construction + if cls._class_version is not None: + if kwargs.get('version', cls._class_version) != cls._class_version: # pylint: disable=E1101 + raise ValueError("Version has already been set by class") + kwargs['version'] = cls._class_version + instance = super().__new__(cls[None]) + instance.__init__(*args, **kwargs) + return instance + return super().__new__(cls) + def __init__(self, *args, **kwargs): self._header = None - if 'version' not in kwargs: - kwargs['version'] = self._class_version # pylint: disable=E1101 super().__init__(*args, **kwargs) @classproperty @@ -141,8 +153,6 @@ def API_VERSION(self): @API_VERSION.setter def API_VERSION(self, version): - if self._class_version is not None and self._class_version != version: # pylint: disable=E1101 - raise ValueError("Version has already been set by class") if not 0 <= version <= self.max_version: raise ValueError('Invalid version %s (max version is %s).' % (version, self.max_version)) self._version = version diff --git a/kafka/protocol/new/data_container.py b/kafka/protocol/new/data_container.py index 6df3ce8dc..56ca3e296 100644 --- a/kafka/protocol/new/data_container.py +++ b/kafka/protocol/new/data_container.py @@ -53,6 +53,10 @@ def __init__(self, *args, version=None, **field_vals): if field_vals: raise ValueError('Unrecognized fields for type %s: %s' % (self._struct.name, field_vals)) + @property + def version(self): + return self._version + def encode(self, *args, **kwargs): """Add version= to kwargs, otherwise pass-through to _struct""" return self._struct.encode(self, *args, **kwargs) @@ -67,8 +71,14 @@ def fields(cls): # pylint: disable=E0213 return cls._struct.fields def __repr__(self): - key_vals = ['%s=%s' % (field.name, repr(getattr(self, field.name))) - for field in self._struct._fields] + if self._version is not None: + v_filter = lambda field: field.for_version_q(self._version) + key_vals = ['version=%s' % self._version] + else: + v_filter = lambda field: True + key_vals = [] + for field in filter(v_filter, self._struct._fields): + key_vals.append('%s=%s' % (field.name, repr(getattr(self, field.name)))) return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')' def __eq__(self, other): diff --git a/test/protocol/new/metadata/test_new_api_versions.py b/test/protocol/new/metadata/test_new_api_versions.py index 9aa33750c..6e34096c9 100644 --- a/test/protocol/new/metadata/test_new_api_versions.py +++ b/test/protocol/new/metadata/test_new_api_versions.py @@ -66,7 +66,7 @@ def test_parse(msg, encoded): msg.with_header(correlation_id=1, client_id='_internal_client_kYVL') assert msg.encode(header=True, framed=True) == encoded - assert msg.decode(encoded, header=True, framed=True) == msg + assert msg.__class__.decode(encoded, version=msg.version, header=True, framed=True) == msg @pytest.mark.parametrize('version', [0, 1, 2, 3, 4])