Skip to content

Commit 39ecb15

Browse files
authored
Refactor treatment of versioned ApiMessage classes (#2739)
1 parent 0f8b1f2 commit 39ecb15

3 files changed

Lines changed: 33 additions & 13 deletions

File tree

kafka/protocol/new/api_message.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,26 @@ def __init__(cls, name, bases, attrs, **kw):
1818
# We also include cls[None] -> primary class to "exit" a version class
1919
if getattr(cls, '_class_version', None) is None:
2020
cls._class_version = None
21-
cls._VERSIONS = {None: weakref.proxy(cls)}
21+
cls._VERSIONS = {}
2222

2323
def __getitem__(cls, version):
2424
# Use [] lookups to move from primary class to "versioned" classes
2525
# which are simple wrappers around the primary class but with a _version attr
2626
if cls._class_version is not None:
27-
return cls._VERSIONS[None][version]
27+
primary_cls = cls.mro()[1]
28+
if version is None:
29+
return primary_cls
30+
return primary_cls[version]
2831
if cls._valid_versions is not None:
2932
if version < 0:
3033
version += 1 + cls.max_version # support negative index, e.g., [-1]
3134
if not cls.min_version <= version <= cls.max_version:
3235
raise ValueError('Invalid version! min=%d, max=%d' % (cls.min_version, cls.max_version))
36+
if version in cls._VERSIONS:
37+
return cls._VERSIONS[version]
3338
klass_name = cls.__name__ + '_v' + str(version)
34-
if klass_name in cls._VERSIONS:
35-
return cls._VERSIONS[klass_name]
36-
cls._VERSIONS[klass_name] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False)
37-
return cls._VERSIONS[klass_name]
39+
cls._VERSIONS[version] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False)
40+
return cls._VERSIONS[version]
3841

3942
def __len__(cls):
4043
# Maintain compatibility
@@ -84,10 +87,19 @@ def __init_subclass__(cls, **kw):
8487
if not cls.is_request():
8588
ResponseClassRegistry.register_response_class(weakref.proxy(cls))
8689

90+
def __new__(cls, *args, **kwargs):
91+
# Translate "versioned" classes back to primary w/ version= kwarg on construction
92+
if cls._class_version is not None:
93+
if kwargs.get('version', cls._class_version) != cls._class_version: # pylint: disable=E1101
94+
raise ValueError("Version has already been set by class")
95+
kwargs['version'] = cls._class_version
96+
instance = super().__new__(cls[None])
97+
instance.__init__(*args, **kwargs)
98+
return instance
99+
return super().__new__(cls)
100+
87101
def __init__(self, *args, **kwargs):
88102
self._header = None
89-
if 'version' not in kwargs:
90-
kwargs['version'] = self._class_version # pylint: disable=E1101
91103
super().__init__(*args, **kwargs)
92104

93105
@classproperty
@@ -141,8 +153,6 @@ def API_VERSION(self):
141153

142154
@API_VERSION.setter
143155
def API_VERSION(self, version):
144-
if self._class_version is not None and self._class_version != version: # pylint: disable=E1101
145-
raise ValueError("Version has already been set by class")
146156
if not 0 <= version <= self.max_version:
147157
raise ValueError('Invalid version %s (max version is %s).' % (version, self.max_version))
148158
self._version = version

kafka/protocol/new/data_container.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def __init__(self, *args, version=None, **field_vals):
5353
if field_vals:
5454
raise ValueError('Unrecognized fields for type %s: %s' % (self._struct.name, field_vals))
5555

56+
@property
57+
def version(self):
58+
return self._version
59+
5660
def encode(self, *args, **kwargs):
5761
"""Add version= to kwargs, otherwise pass-through to _struct"""
5862
return self._struct.encode(self, *args, **kwargs)
@@ -67,8 +71,14 @@ def fields(cls): # pylint: disable=E0213
6771
return cls._struct.fields
6872

6973
def __repr__(self):
70-
key_vals = ['%s=%s' % (field.name, repr(getattr(self, field.name)))
71-
for field in self._struct._fields]
74+
if self._version is not None:
75+
v_filter = lambda field: field.for_version_q(self._version)
76+
key_vals = ['version=%s' % self._version]
77+
else:
78+
v_filter = lambda field: True
79+
key_vals = []
80+
for field in filter(v_filter, self._struct._fields):
81+
key_vals.append('%s=%s' % (field.name, repr(getattr(self, field.name))))
7282
return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')'
7383

7484
def __eq__(self, other):

test/protocol/new/metadata/test_new_api_versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
def test_parse(msg, encoded):
6767
msg.with_header(correlation_id=1, client_id='_internal_client_kYVL')
6868
assert msg.encode(header=True, framed=True) == encoded
69-
assert msg.decode(encoded, header=True, framed=True) == msg
69+
assert msg.__class__.decode(encoded, version=msg.version, header=True, framed=True) == msg
7070

7171

7272
@pytest.mark.parametrize('version', [0, 1, 2, 3, 4])

0 commit comments

Comments
 (0)