diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 0e2d4a954..0f571bfad 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -38,9 +38,13 @@ class ResponseHeaderV2(Struct): ) -class Request(Struct, metaclass=abc.ABCMeta): +class RequestResponse(Struct, metaclass=abc.ABCMeta): FLEXIBLE_VERSION = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._header = None + @abc.abstractproperty def API_KEY(self): """Integer identifier for api request""" @@ -51,106 +55,104 @@ def API_VERSION(self): """Integer of api request version""" pass - @abc.abstractproperty - def RESPONSE_TYPE(self): - """The Response class associated with the api request""" - pass - - def expect_response(self): - """Override this method if an api request does not always generate a response""" - return True - def to_object(self): return _to_object(self.SCHEMA, self) - def build_header(self, correlation_id=0, client_id='kafka-python'): - if self.FLEXIBLE_VERSION: - return RequestHeaderV2(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) - return RequestHeader(self.API_KEY, self.API_VERSION, correlation_id, client_id) - @classmethod - def parse_header(cls, read_buffer): - if cls.FLEXIBLE_VERSION: - return RequestHeaderV2.decode(read_buffer) - return RequestHeader.decode(read_buffer) + @abc.abstractmethod + def is_request(cls): + pass + + @property + def header(self): + return self._header - def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs): + def encode(self, header=False, framed=False): data = super().encode() if not framed and not header: return data bits = [data] if header: - bits.insert(0, self.build_header(correlation_id, client_id).encode()) + bits.insert(0, self.header.encode()) if framed: bits.insert(0, Int32.encode(sum(map(len, bits)))) return b''.join(bits) + @classmethod + @abc.abstractmethod + def header_class(cls): + pass + + @classmethod + def parse_header(cls, read_buffer): + return cls.header_class().decode(read_buffer) + @classmethod def decode(cls, data, header=False, framed=False): if not framed and not header: return super().decode(data) if isinstance(data, bytes): data = BytesIO(data) - ret = [] if framed: - ret.append(Int32.decode(data)) + size = Int32.decode(data) if header: - ret.append(cls.parse_header(data)) - ret.append(super().decode(data)) - return tuple(ret) + hdr = cls.parse_header(data) + else: + hdr = None + ret = super().decode(data) + if hdr is not None: + ret._header = hdr + return ret + def __eq__(self, other): + return self._header == other._header and super().__eq__(other) -class Response(Struct, metaclass=abc.ABCMeta): - FLEXIBLE_VERSION = False +class Request(RequestResponse): @abc.abstractproperty - def API_KEY(self): - """Integer identifier for api request/response""" + def RESPONSE_TYPE(self): + """The Response class associated with the api request""" pass - @abc.abstractproperty - def API_VERSION(self): - """Integer of api request/response version""" - pass + @classmethod + def is_request(cls): + return True - def to_object(self): - return _to_object(self.SCHEMA, self) + def expect_response(self): + """Override this method if an api request does not always generate a response""" + return True - def build_header(self, correlation_id=0): + def with_header(self, correlation_id=0, client_id='kafka-python'): if self.FLEXIBLE_VERSION: - return ResponseHeaderV2(correlation_id=correlation_id, tags=None) - return ResponseHeader(correlation_id=correlation_id) + self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id, {}) + else: + self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id) @classmethod - def parse_header(cls, read_buffer): + def header_class(cls): if cls.FLEXIBLE_VERSION: - return ResponseHeaderV2.decode(read_buffer) - return ResponseHeader.decode(read_buffer) + return RequestHeaderV2 + else: + return RequestHeader - def encode(self, header=False, framed=False, correlation_id=None, **kwargs): - data = super().encode() - if not framed and not header: - return data - bits = [data] - if header: - bits.insert(0, self.build_header(correlation_id).encode()) - if framed: - bits.insert(0, Int32.encode(sum(map(len, bits)))) - return b''.join(bits) +class Response(RequestResponse): @classmethod - def decode(cls, data, header=False, framed=False): - if not framed and not header: - return super().decode(data) - if isinstance(data, bytes): - data = BytesIO(data) - ret = [] - if framed: - ret.append(Int32.decode(data)) - if header: - ret.append(cls.parse_header(data)) - ret.append(super().decode(data)) - return tuple(ret) + def is_request(cls): + return False + + def with_header(self, correlation_id=0): + if self.FLEXIBLE_VERSION: + self._header = self.header_class()(correlation_id, {}) + else: + self._header = self.header_class()(correlation_id) + + @classmethod + def header_class(cls): + if cls.FLEXIBLE_VERSION: + return ResponseHeaderV2 + else: + return ResponseHeader def _to_object(schema, data): diff --git a/kafka/protocol/parser.py b/kafka/protocol/parser.py index 96066806a..0ebe5ace7 100644 --- a/kafka/protocol/parser.py +++ b/kafka/protocol/parser.py @@ -58,8 +58,8 @@ def send_request(self, request, correlation_id=None): correlation_id = self._next_correlation_id() log.debug('%s Sending request %d %s', self._ident, correlation_id, request) - data = request.encode(correlation_id=correlation_id, client_id=self._client_id, - framed=True, header=True) + request.with_header(correlation_id=correlation_id, client_id=self._client_id) + data = request.encode(framed=True, header=True) self.bytes_to_send.append(data) if request.expect_response(): ifr = (correlation_id, request) diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index f66170c60..c3015aa00 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -54,6 +54,8 @@ def __hash__(self): return hash(self.encode()) def __eq__(self, other): + if not isinstance(other, Struct): + return False if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: diff --git a/test/protocol/test_api_versions.py b/test/protocol/test_api_versions.py index 8ca1284cc..ece60035e 100644 --- a/test/protocol/test_api_versions.py +++ b/test/protocol/test_api_versions.py @@ -65,5 +65,9 @@ @pytest.mark.parametrize('msg, encoded', TEST_CASES) def test_parse(msg, encoded): - assert msg.encode(correlation_id=1, client_id='_internal_client_kYVL', header=True, framed=True) == encoded - assert msg.decode(encoded, header=True, framed=True)[2] == msg + if msg.is_request(): + msg.with_header(correlation_id=1, client_id='_internal_client_kYVL') + else: + msg.with_header(correlation_id=1) + assert msg.encode(header=True, framed=True) == encoded + assert msg.decode(encoded, header=True, framed=True) == msg