From 942e6f601443fc080a19caa8f357fab9379e7fb9 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 8 Mar 2026 15:35:23 -0700 Subject: [PATCH 1/5] Use RequestResponse class for common bits; add header attr and with_header() to construct --- kafka/protocol/api.py | 100 ++++++++++++++++----------------------- kafka/protocol/parser.py | 4 +- 2 files changed, 43 insertions(+), 61 deletions(-) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 0e2d4a954..9c4d54627 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -38,9 +38,13 @@ class ResponseHeaderV2(Struct): ) -class Request(Struct, metaclass=abc.ABCMeta): +class RequestResponse(Struct, metaclass=abc.ABCMeta): FLEXIBLE_VERSION = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._header = None + @abc.abstractproperty def API_KEY(self): """Integer identifier for api request""" @@ -51,36 +55,20 @@ def API_VERSION(self): """Integer of api request version""" pass - @abc.abstractproperty - def RESPONSE_TYPE(self): - """The Response class associated with the api request""" - pass - - def expect_response(self): - """Override this method if an api request does not always generate a response""" - return True - def to_object(self): return _to_object(self.SCHEMA, self) - def build_header(self, correlation_id=0, client_id='kafka-python'): - if self.FLEXIBLE_VERSION: - return RequestHeaderV2(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) - return RequestHeader(self.API_KEY, self.API_VERSION, correlation_id, client_id) - - @classmethod - def parse_header(cls, read_buffer): - if cls.FLEXIBLE_VERSION: - return RequestHeaderV2.decode(read_buffer) - return RequestHeader.decode(read_buffer) + @property + def header(self): + return self._header - def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs): + def encode(self, header=False, framed=False): data = super().encode() if not framed and not header: return data bits = [data] if header: - bits.insert(0, self.build_header(correlation_id, client_id).encode()) + bits.insert(0, self.header.encode()) if framed: bits.insert(0, Int32.encode(sum(map(len, bits)))) return b''.join(bits) @@ -100,26 +88,40 @@ def decode(cls, data, header=False, framed=False): return tuple(ret) -class Response(Struct, metaclass=abc.ABCMeta): - FLEXIBLE_VERSION = False - +class Request(RequestResponse): @abc.abstractproperty - def API_KEY(self): - """Integer identifier for api request/response""" + def RESPONSE_TYPE(self): + """The Response class associated with the api request""" pass - @abc.abstractproperty - def API_VERSION(self): - """Integer of api request/response version""" - pass + def expect_response(self): + """Override this method if an api request does not always generate a response""" + return True - def to_object(self): - return _to_object(self.SCHEMA, self) + def with_header(self, correlation_id=0, client_id='kafka-python'): + if self.FLEXIBLE_VERSION: + self._header = RequestHeaderV2(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) + else: + self._header = RequestHeader(self.API_KEY, self.API_VERSION, correlation_id, client_id) - def build_header(self, correlation_id=0): + @classmethod + def parse_header(cls, read_buffer): + if cls.FLEXIBLE_VERSION: + return RequestHeaderV2.decode(read_buffer) + return RequestHeader.decode(read_buffer) + + def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs): + if header and self.header is None: + self.with_header(correlation_id=correlation_id, client_id=client_id) + return super().encode(header=header, framed=framed) + + +class Response(RequestResponse): + def with_header(self, correlation_id=0): if self.FLEXIBLE_VERSION: - return ResponseHeaderV2(correlation_id=correlation_id, tags=None) - return ResponseHeader(correlation_id=correlation_id) + self._header = ResponseHeaderV2(correlation_id=correlation_id, tags={}) + else: + self._header = ResponseHeader(correlation_id=correlation_id) @classmethod def parse_header(cls, read_buffer): @@ -128,29 +130,9 @@ def parse_header(cls, read_buffer): return ResponseHeader.decode(read_buffer) def encode(self, header=False, framed=False, correlation_id=None, **kwargs): - data = super().encode() - if not framed and not header: - return data - bits = [data] - if header: - bits.insert(0, self.build_header(correlation_id).encode()) - if framed: - bits.insert(0, Int32.encode(sum(map(len, bits)))) - return b''.join(bits) - - @classmethod - def decode(cls, data, header=False, framed=False): - if not framed and not header: - return super().decode(data) - if isinstance(data, bytes): - data = BytesIO(data) - ret = [] - if framed: - ret.append(Int32.decode(data)) - if header: - ret.append(cls.parse_header(data)) - ret.append(super().decode(data)) - return tuple(ret) + if header and self.header is None: + self.with_header(correlation_id=correlation_id) + return super().encode(header=header, framed=framed) def _to_object(schema, data): diff --git a/kafka/protocol/parser.py b/kafka/protocol/parser.py index 96066806a..0ebe5ace7 100644 --- a/kafka/protocol/parser.py +++ b/kafka/protocol/parser.py @@ -58,8 +58,8 @@ def send_request(self, request, correlation_id=None): correlation_id = self._next_correlation_id() log.debug('%s Sending request %d %s', self._ident, correlation_id, request) - data = request.encode(correlation_id=correlation_id, client_id=self._client_id, - framed=True, header=True) + request.with_header(correlation_id=correlation_id, client_id=self._client_id) + data = request.encode(framed=True, header=True) self.bytes_to_send.append(data) if request.expect_response(): ifr = (correlation_id, request) From 97d36cedb566ea2cfa25e3145fa17d1a59529784 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 8 Mar 2026 16:00:32 -0700 Subject: [PATCH 2/5] move parse_header to shared Request/Response; add header_class abstract property --- kafka/protocol/api.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 9c4d54627..cebdc78eb 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -73,6 +73,15 @@ def encode(self, header=False, framed=False): bits.insert(0, Int32.encode(sum(map(len, bits)))) return b''.join(bits) + @classmethod + @abc.abstractproperty + def header_class(cls): + pass + + @classmethod + def parse_header(cls, read_buffer): + return cls.header_class.decode(read_buffer) + @classmethod def decode(cls, data, header=False, framed=False): if not framed and not header: @@ -100,15 +109,17 @@ def expect_response(self): def with_header(self, correlation_id=0, client_id='kafka-python'): if self.FLEXIBLE_VERSION: - self._header = RequestHeaderV2(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) + self._header = self.header_class(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) else: - self._header = RequestHeader(self.API_KEY, self.API_VERSION, correlation_id, client_id) + self._header = self.header_class(self.API_KEY, self.API_VERSION, correlation_id, client_id) @classmethod - def parse_header(cls, read_buffer): + @property + def header_class(cls): if cls.FLEXIBLE_VERSION: - return RequestHeaderV2.decode(read_buffer) - return RequestHeader.decode(read_buffer) + return RequestHeaderV2 + else: + return RequestHeader def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs): if header and self.header is None: @@ -119,15 +130,17 @@ def encode(self, header=False, framed=False, correlation_id=None, client_id=None class Response(RequestResponse): def with_header(self, correlation_id=0): if self.FLEXIBLE_VERSION: - self._header = ResponseHeaderV2(correlation_id=correlation_id, tags={}) + self._header = self.header_class(correlation_id, {}) else: - self._header = ResponseHeader(correlation_id=correlation_id) + self._header = self.header_class(correlation_id) @classmethod - def parse_header(cls, read_buffer): + @property + def header_class(cls): if cls.FLEXIBLE_VERSION: - return ResponseHeaderV2.decode(read_buffer) - return ResponseHeader.decode(read_buffer) + return ResponseHeaderV2 + else: + return ResponseHeader def encode(self, header=False, framed=False, correlation_id=None, **kwargs): if header and self.header is None: From 964bcc417c58db91100d0bb2d78dbb97a26301a8 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 8 Mar 2026 16:14:50 -0700 Subject: [PATCH 3/5] Avoid python 3.14 TypeError: property not callable --- kafka/protocol/api.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index cebdc78eb..5f7faec84 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -80,7 +80,8 @@ def header_class(cls): @classmethod def parse_header(cls, read_buffer): - return cls.header_class.decode(read_buffer) + klass = cls.header_class + return klass.decode(read_buffer) @classmethod def decode(cls, data, header=False, framed=False): @@ -108,10 +109,11 @@ def expect_response(self): return True def with_header(self, correlation_id=0, client_id='kafka-python'): + klass = self.header_class if self.FLEXIBLE_VERSION: - self._header = self.header_class(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) + self._header = klass(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) else: - self._header = self.header_class(self.API_KEY, self.API_VERSION, correlation_id, client_id) + self._header = klass(self.API_KEY, self.API_VERSION, correlation_id, client_id) @classmethod @property @@ -129,10 +131,11 @@ def encode(self, header=False, framed=False, correlation_id=None, client_id=None class Response(RequestResponse): def with_header(self, correlation_id=0): + klass = self.header_class if self.FLEXIBLE_VERSION: - self._header = self.header_class(correlation_id, {}) + self._header = klass(correlation_id, {}) else: - self._header = self.header_class(correlation_id) + self._header = klass(correlation_id) @classmethod @property From 3e5b214a89a9b876564e83cac1b97ef9bff0e0ba Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 8 Mar 2026 16:43:42 -0700 Subject: [PATCH 4/5] header_class() abstractmethod to avoid the insanity --- kafka/protocol/api.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 5f7faec84..e3afe1a10 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -74,14 +74,13 @@ def encode(self, header=False, framed=False): return b''.join(bits) @classmethod - @abc.abstractproperty + @abc.abstractmethod def header_class(cls): pass @classmethod def parse_header(cls, read_buffer): - klass = cls.header_class - return klass.decode(read_buffer) + return cls.header_class().decode(read_buffer) @classmethod def decode(cls, data, header=False, framed=False): @@ -109,14 +108,12 @@ def expect_response(self): return True def with_header(self, correlation_id=0, client_id='kafka-python'): - klass = self.header_class if self.FLEXIBLE_VERSION: - self._header = klass(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) + self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) else: - self._header = klass(self.API_KEY, self.API_VERSION, correlation_id, client_id) + self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id) @classmethod - @property def header_class(cls): if cls.FLEXIBLE_VERSION: return RequestHeaderV2 @@ -131,14 +128,12 @@ def encode(self, header=False, framed=False, correlation_id=None, client_id=None class Response(RequestResponse): def with_header(self, correlation_id=0): - klass = self.header_class if self.FLEXIBLE_VERSION: - self._header = klass(correlation_id, {}) + self._header = self.header_class()(correlation_id, {}) else: - self._header = klass(correlation_id) + self._header = self.header_class()(correlation_id) @classmethod - @property def header_class(cls): if cls.FLEXIBLE_VERSION: return ResponseHeaderV2 From d1f2750ee10c4213b3140399e9343874d1b2b458 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 8 Mar 2026 17:04:54 -0700 Subject: [PATCH 5/5] Retain header on decode; test header+msg equality --- kafka/protocol/api.py | 39 ++++++++++++++++++------------ kafka/protocol/struct.py | 2 ++ test/protocol/test_api_versions.py | 8 ++++-- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index e3afe1a10..0f571bfad 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -58,6 +58,11 @@ def API_VERSION(self): def to_object(self): return _to_object(self.SCHEMA, self) + @classmethod + @abc.abstractmethod + def is_request(cls): + pass + @property def header(self): return self._header @@ -88,13 +93,19 @@ def decode(cls, data, header=False, framed=False): return super().decode(data) if isinstance(data, bytes): data = BytesIO(data) - ret = [] if framed: - ret.append(Int32.decode(data)) + size = Int32.decode(data) if header: - ret.append(cls.parse_header(data)) - ret.append(super().decode(data)) - return tuple(ret) + hdr = cls.parse_header(data) + else: + hdr = None + ret = super().decode(data) + if hdr is not None: + ret._header = hdr + return ret + + def __eq__(self, other): + return self._header == other._header and super().__eq__(other) class Request(RequestResponse): @@ -103,6 +114,10 @@ def RESPONSE_TYPE(self): """The Response class associated with the api request""" pass + @classmethod + def is_request(cls): + return True + def expect_response(self): """Override this method if an api request does not always generate a response""" return True @@ -120,13 +135,12 @@ def header_class(cls): else: return RequestHeader - def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs): - if header and self.header is None: - self.with_header(correlation_id=correlation_id, client_id=client_id) - return super().encode(header=header, framed=framed) - class Response(RequestResponse): + @classmethod + def is_request(cls): + return False + def with_header(self, correlation_id=0): if self.FLEXIBLE_VERSION: self._header = self.header_class()(correlation_id, {}) @@ -140,11 +154,6 @@ def header_class(cls): else: return ResponseHeader - def encode(self, header=False, framed=False, correlation_id=None, **kwargs): - if header and self.header is None: - self.with_header(correlation_id=correlation_id) - return super().encode(header=header, framed=framed) - def _to_object(schema, data): obj = {} diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index f66170c60..c3015aa00 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -54,6 +54,8 @@ def __hash__(self): return hash(self.encode()) def __eq__(self, other): + if not isinstance(other, Struct): + return False if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: diff --git a/test/protocol/test_api_versions.py b/test/protocol/test_api_versions.py index 8ca1284cc..ece60035e 100644 --- a/test/protocol/test_api_versions.py +++ b/test/protocol/test_api_versions.py @@ -65,5 +65,9 @@ @pytest.mark.parametrize('msg, encoded', TEST_CASES) def test_parse(msg, encoded): - assert msg.encode(correlation_id=1, client_id='_internal_client_kYVL', header=True, framed=True) == encoded - assert msg.decode(encoded, header=True, framed=True)[2] == msg + if msg.is_request(): + msg.with_header(correlation_id=1, client_id='_internal_client_kYVL') + else: + msg.with_header(correlation_id=1) + assert msg.encode(header=True, framed=True) == encoded + assert msg.decode(encoded, header=True, framed=True) == msg