From 2f560a6761e0153b6702e308bd2983d1978d4412 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 23 Mar 2026 13:27:08 -0700 Subject: [PATCH] Fix TaggedFields encoding/decoding of struct fields --- .../schemas/fields/codecs/tagged_fields.py | 22 +++++++++++------- kafka/protocol/new/schemas/fields/struct.py | 14 +++++++---- .../new/metadata/test_new_api_versions.py | 23 ++++++++++++++++++- .../new/schemas/test_new_tagged_fields.py | 22 ++++++++++++++++-- 4 files changed, 66 insertions(+), 15 deletions(-) diff --git a/kafka/protocol/new/schemas/fields/codecs/tagged_fields.py b/kafka/protocol/new/schemas/fields/codecs/tagged_fields.py index b15f24c6a..f2bba5f23 100644 --- a/kafka/protocol/new/schemas/fields/codecs/tagged_fields.py +++ b/kafka/protocol/new/schemas/fields/codecs/tagged_fields.py @@ -7,8 +7,7 @@ def __init__(self, fields): self._tags = {field.tag: field for field in self._fields} self._names = {field.name: field for field in self._fields} - def encode(self, item, version=None, compact=True, tagged=False): - assert compact and not tagged + def encode(self, item, version=None): if isinstance(item, dict): tags = [(self._names[name].tag, val) for name, val in item.items() @@ -22,15 +21,13 @@ def encode(self, item, version=None, compact=True, tagged=False): ret = [UnsignedVarInt32.encode(len(tags))] for tag, val in tags: ret.append(UnsignedVarInt32.encode(tag)) - # Tags that are structs never include nested tagged fields - encoded_val = self._tags[tag].encode(val, version=version, - compact=True, tagged=False) + # struct tags have an empty nested tagged fields (tagged=None) + encoded_val = self._tags[tag].encode(val, version=version, compact=True, tagged=None) ret.append(UnsignedVarInt32.encode(len(encoded_val))) ret.append(encoded_val) return b''.join(ret) - def decode(self, data, version=None, compact=True, tagged=False): - assert compact and not tagged + def decode(self, data, version=None): num_fields = UnsignedVarInt32.decode(data) ret = {} for i in range(num_fields): @@ -38,10 +35,19 @@ def decode(self, data, version=None, compact=True, tagged=False): size = UnsignedVarInt32.decode(data) if tag in self._tags: field = self._tags[tag] - ret[field.name] = field.decode(data, version=version, compact=compact, tagged=tagged) + # struct tags have an empty nested tagged fields (tagged=None) + ret[field.name] = field.decode(data, version=version, compact=True, tagged=None) else: ret['_%d' % tag] = data.read(size) return ret + @classmethod + def decode_empty(cls, data): + assert UnsignedVarInt32.decode(data) == 0 + + @classmethod + def encode_empty(cls): + return UnsignedVarInt32.encode(0) + def __repr__(self): return 'TaggedFields(%s)' % list(self._names.keys()) diff --git a/kafka/protocol/new/schemas/fields/struct.py b/kafka/protocol/new/schemas/fields/struct.py index aefcd6a36..5ff7c70c1 100644 --- a/kafka/protocol/new/schemas/fields/struct.py +++ b/kafka/protocol/new/schemas/fields/struct.py @@ -83,8 +83,9 @@ def encode(self, item, version=None, compact=False, tagged=False): for i, field in enumerate(fields)] if tagged: # TaggedFields are always compact and never include nested tagged fields - encoded.append(self.tagged_fields(version).encode(tags, version=version, - compact=True, tagged=False)) + encoded.append(self.tagged_fields(version).encode(tags, version=version)) + elif tagged is None: + encoded.append(TaggedFields.encode_empty()) return b''.join(encoded) def decode(self, data, version=None, compact=False, tagged=False, data_class=None): @@ -99,8 +100,13 @@ def decode(self, data, version=None, compact=False, tagged=False, data_class=Non if field.for_version_q(version) and not field.tagged_field_q(version) } if tagged: - decoded.update(self.tagged_fields(version).decode(data, version=version, compact=True, tagged=False)) - return data_class(version=version, **decoded) + decoded.update(self.tagged_fields(version).decode(data, version=version)) + elif tagged is None: + TaggedFields.decode_empty(data) + + if data_class is not None: + return data_class(version=version, **decoded) + return decoded def __len__(self): return len(self._fields) diff --git a/test/protocol/new/metadata/test_new_api_versions.py b/test/protocol/new/metadata/test_new_api_versions.py index 055d0e188..1d235b18c 100644 --- a/test/protocol/new/metadata/test_new_api_versions.py +++ b/test/protocol/new/metadata/test_new_api_versions.py @@ -1,5 +1,6 @@ -import pytest +import io +import pytest from kafka.protocol.new.api_header import ResponseHeader from kafka.protocol.new.metadata import ApiVersionsRequest, ApiVersionsResponse @@ -121,3 +122,23 @@ def test_api_versions_response_roundtrip(version): encoded = response.encode(version=version) decoded = ApiVersionsResponse.decode(encoded, version=version) assert decoded == response + + +def test_supported_features(): + encoded = b'\x00\x00\x00\x01\x00\x00>\x00\x00\x00\x00\x00\x0c\x00\x00\x01\x00\x04\x00\x11\x00\x00\x02\x00\x01\x00\n\x00\x00\x03\x00\x00\x00\r\x00\x00\x08\x00\x02\x00\t\x00\x00\t\x00\x01\x00\t\x00\x00\n\x00\x00\x00\x06\x00\x00\x0b\x00\x02\x00\t\x00\x00\x0c\x00\x00\x00\x04\x00\x00\r\x00\x00\x00\x05\x00\x00\x0e\x00\x00\x00\x05\x00\x00\x0f\x00\x00\x00\x06\x00\x00\x10\x00\x00\x00\x05\x00\x00\x11\x00\x00\x00\x01\x00\x00\x12\x00\x00\x00\x04\x00\x00\x13\x00\x02\x00\x07\x00\x00\x14\x00\x01\x00\x06\x00\x00\x15\x00\x00\x00\x02\x00\x00\x16\x00\x00\x00\x05\x00\x00\x17\x00\x02\x00\x04\x00\x00\x18\x00\x00\x00\x05\x00\x00\x19\x00\x00\x00\x04\x00\x00\x1a\x00\x00\x00\x05\x00\x00\x1b\x00\x01\x00\x01\x00\x00\x1c\x00\x00\x00\x05\x00\x00\x1d\x00\x01\x00\x03\x00\x00\x1e\x00\x01\x00\x03\x00\x00\x1f\x00\x01\x00\x03\x00\x00 \x00\x01\x00\x04\x00\x00!\x00\x00\x00\x02\x00\x00"\x00\x01\x00\x02\x00\x00#\x00\x01\x00\x04\x00\x00$\x00\x00\x00\x02\x00\x00%\x00\x00\x00\x03\x00\x00&\x00\x01\x00\x03\x00\x00\'\x00\x01\x00\x02\x00\x00(\x00\x01\x00\x02\x00\x00)\x00\x01\x00\x03\x00\x00*\x00\x00\x00\x02\x00\x00+\x00\x00\x00\x02\x00\x00,\x00\x00\x00\x01\x00\x00-\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00/\x00\x00\x00\x00\x00\x000\x00\x00\x00\x01\x00\x001\x00\x00\x00\x01\x00\x002\x00\x00\x00\x00\x00\x003\x00\x00\x00\x00\x00\x007\x00\x00\x00\x02\x00\x009\x00\x00\x00\x02\x00\x00<\x00\x00\x00\x02\x00\x00=\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00A\x00\x00\x00\x00\x00\x00B\x00\x00\x00\x01\x00\x00D\x00\x00\x00\x01\x00\x00E\x00\x00\x00\x01\x00\x00J\x00\x00\x00\x00\x00\x00K\x00\x00\x00\x00\x00\x00P\x00\x00\x00\x00\x00\x00Q\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x00|\x06\x0egroup.version\x00\x00\x00\x01\x00\x0ekraft.version\x00\x00\x00\x01\x00\x11metadata.version\x00\x07\x00\x19\x00\x14transaction.version\x00\x00\x00\x02\x00!eligible.leader.replicas.version\x00\x00\x00\x01\x00\x01\x08\x00\x00\x00\x00\x00\x00\x00\n\x02V\x05\x0egroup.version\x00\x01\x00\x01\x00\x14transaction.version\x00\x02\x00\x02\x00\x0ekraft.version\x00\x01\x00\x01\x00\x11metadata.version\x00\x19\x00\x19\x00' + + supported_features = b'\x06\x0egroup.version\x00\x00\x00\x01\x00\x0ekraft.version\x00\x00\x00\x01\x00\x11metadata.version\x00\x07\x00\x19\x00\x14transaction.version\x00\x00\x00\x02\x00!eligible.leader.replicas.version\x00\x00\x00\x01\x00' + + data = io.BytesIO(supported_features) + features = ApiVersionsResponse.fields['supported_features'].decode(data, version=4, compact=True, tagged=None) + assert len(features) == 5 + + data = io.BytesIO(encoded) + ApiVersionsResponse[4].parse_header(data) + decoded = ApiVersionsResponse.decode(data, version=4) + assert decoded.version == 4 + assert len(decoded.supported_features) == 5 + assert set([feature.name for feature in decoded.supported_features]) == set([ + 'group.version', 'kraft.version', 'metadata.version', + 'transaction.version', 'eligible.leader.replicas.version', + ]) diff --git a/test/protocol/new/schemas/test_new_tagged_fields.py b/test/protocol/new/schemas/test_new_tagged_fields.py index 1de678af4..d1ddfebca 100644 --- a/test/protocol/new/schemas/test_new_tagged_fields.py +++ b/test/protocol/new/schemas/test_new_tagged_fields.py @@ -2,7 +2,7 @@ import pytest -from kafka.protocol.new.schemas.fields import SimpleField +from kafka.protocol.new.schemas.fields import SimpleField, StructField from kafka.protocol.new.schemas.fields.codecs import TaggedFields, UnsignedVarInt32 @@ -13,10 +13,28 @@ def test_tagged_fields(): ]) val = {'foo': 2, 'bar': 'foobar'} encoded = tags.encode(val, version=0) - # length(2), tag(0), size(2), b'\x00\x02', tag(1), size(7), len(7), 'foobar' + # num_tags(2), tag(0), size(2), b'\x00\x02', tag(1), size(7), len(6+1), 'foobar' expected = (UnsignedVarInt32.encode(2) + UnsignedVarInt32.encode(0) + UnsignedVarInt32.encode(2) + b'\x00\x02' + UnsignedVarInt32.encode(1) + UnsignedVarInt32.encode(7) + UnsignedVarInt32.encode(7) + b'foobar') assert encoded == expected decoded = tags.decode(io.BytesIO(encoded), version=0) assert decoded == val + + +def test_tagged_fields_struct(): + tags = TaggedFields([ + StructField({'name': 'foo', 'tag': 0, 'type': 'Bar', 'versions': "0+", "fields": [ + {'name': 'bar', 'tag': 1, 'type': 'string', 'versions': "0+"}, + ]}), + ]) + val = {'foo': {'bar': 'foobar'}} + encoded = tags.encode(val, version=0) + # num_tags(1), tag(0), size(8), len(6+1), 'foobar', empty tags(\x00) + expected = (UnsignedVarInt32.encode(1) + UnsignedVarInt32.encode(0) + UnsignedVarInt32.encode(8) + + UnsignedVarInt32.encode(7) + b'foobar' + + UnsignedVarInt32.encode(0)) + assert encoded == expected + decoded = tags.decode(io.BytesIO(encoded), version=0) + assert decoded == val +