Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions kafka/protocol/new/api_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@ def __init__(cls, name, bases, attrs, **kw):
# We also include cls[None] -> primary class to "exit" a version class
if getattr(cls, '_class_version', None) is None:
cls._class_version = None
cls._VERSIONS = {None: weakref.proxy(cls)}
cls._VERSIONS = {}

def __getitem__(cls, version):
# Use [] lookups to move from primary class to "versioned" classes
# which are simple wrappers around the primary class but with a _version attr
if cls._class_version is not None:
return cls._VERSIONS[None][version]
primary_cls = cls.mro()[1]
if version is None:
return primary_cls
return primary_cls[version]
if cls._valid_versions is not None:
if version < 0:
version += 1 + cls.max_version # support negative index, e.g., [-1]
if not cls.min_version <= version <= cls.max_version:
raise ValueError('Invalid version! min=%d, max=%d' % (cls.min_version, cls.max_version))
if version in cls._VERSIONS:
return cls._VERSIONS[version]
klass_name = cls.__name__ + '_v' + str(version)
if klass_name in cls._VERSIONS:
return cls._VERSIONS[klass_name]
cls._VERSIONS[klass_name] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False)
return cls._VERSIONS[klass_name]
cls._VERSIONS[version] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False)
return cls._VERSIONS[version]

def __len__(cls):
# Maintain compatibility
Expand Down Expand Up @@ -84,10 +87,19 @@ def __init_subclass__(cls, **kw):
if not cls.is_request():
ResponseClassRegistry.register_response_class(weakref.proxy(cls))

def __new__(cls, *args, **kwargs):
# Translate "versioned" classes back to primary w/ version= kwarg on construction
if cls._class_version is not None:
if kwargs.get('version', cls._class_version) != cls._class_version: # pylint: disable=E1101
raise ValueError("Version has already been set by class")
kwargs['version'] = cls._class_version
instance = super().__new__(cls[None])
instance.__init__(*args, **kwargs)
return instance
return super().__new__(cls)

def __init__(self, *args, **kwargs):
self._header = None
if 'version' not in kwargs:
kwargs['version'] = self._class_version # pylint: disable=E1101
super().__init__(*args, **kwargs)

@classproperty
Expand Down Expand Up @@ -141,8 +153,6 @@ def API_VERSION(self):

@API_VERSION.setter
def API_VERSION(self, version):
if self._class_version is not None and self._class_version != version: # pylint: disable=E1101
raise ValueError("Version has already been set by class")
if not 0 <= version <= self.max_version:
raise ValueError('Invalid version %s (max version is %s).' % (version, self.max_version))
self._version = version
Expand Down
14 changes: 12 additions & 2 deletions kafka/protocol/new/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def __init__(self, *args, version=None, **field_vals):
if field_vals:
raise ValueError('Unrecognized fields for type %s: %s' % (self._struct.name, field_vals))

@property
def version(self):
return self._version

def encode(self, *args, **kwargs):
"""Add version= to kwargs, otherwise pass-through to _struct"""
return self._struct.encode(self, *args, **kwargs)
Expand All @@ -67,8 +71,14 @@ def fields(cls): # pylint: disable=E0213
return cls._struct.fields

def __repr__(self):
key_vals = ['%s=%s' % (field.name, repr(getattr(self, field.name)))
for field in self._struct._fields]
if self._version is not None:
v_filter = lambda field: field.for_version_q(self._version)
key_vals = ['version=%s' % self._version]
else:
v_filter = lambda field: True
key_vals = []
for field in filter(v_filter, self._struct._fields):
key_vals.append('%s=%s' % (field.name, repr(getattr(self, field.name))))
return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')'

def __eq__(self, other):
Expand Down
2 changes: 1 addition & 1 deletion test/protocol/new/metadata/test_new_api_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
def test_parse(msg, encoded):
msg.with_header(correlation_id=1, client_id='_internal_client_kYVL')
assert msg.encode(header=True, framed=True) == encoded
assert msg.decode(encoded, header=True, framed=True) == msg
assert msg.__class__.decode(encoded, version=msg.version, header=True, framed=True) == msg


@pytest.mark.parametrize('version', [0, 1, 2, 3, 4])
Expand Down
Loading