Skip to content

Commit 50d0aa7

Browse files
committed
Retain header on decode; test header+msg equality
1 parent 53fec88 commit 50d0aa7

3 files changed

Lines changed: 32 additions & 17 deletions

File tree

kafka/protocol/api.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def API_VERSION(self):
5858
def to_object(self):
5959
return _to_object(self.SCHEMA, self)
6060

61+
@classmethod
62+
@abc.abstractmethod
63+
def is_request(cls):
64+
pass
65+
6166
@property
6267
def header(self):
6368
return self._header
@@ -88,16 +93,26 @@ def decode(cls, data, header=False, framed=False):
8893
return super().decode(data)
8994
if isinstance(data, bytes):
9095
data = BytesIO(data)
91-
ret = []
9296
if framed:
93-
ret.append(Int32.decode(data))
97+
size = Int32.decode(data)
9498
if header:
95-
ret.append(cls.parse_header(data))
96-
ret.append(super().decode(data))
97-
return tuple(ret)
99+
hdr = cls.parse_header(data)
100+
else:
101+
hdr = None
102+
ret = super().decode(data)
103+
if hdr is not None:
104+
ret._header = hdr
105+
return ret
106+
107+
def __eq__(self, other):
108+
return self._header == other._header and super().__eq__(other)
98109

99110

100111
class Request(RequestResponse):
112+
@classmethod
113+
def is_request(cls):
114+
return True
115+
101116
def expect_response(self):
102117
"""Override this method if an api request does not always generate a response"""
103118
return True
@@ -115,13 +130,12 @@ def header_class(cls):
115130
else:
116131
return RequestHeader
117132

118-
def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs):
119-
if header and self.header is None:
120-
self.with_header(correlation_id=correlation_id, client_id=client_id)
121-
return super().encode(header=header, framed=framed)
122-
123133

124134
class Response(RequestResponse):
135+
@classmethod
136+
def is_request(cls):
137+
return False
138+
125139
def with_header(self, correlation_id=0):
126140
if self.FLEXIBLE_VERSION:
127141
self._header = self.header_class()(correlation_id, {})
@@ -135,11 +149,6 @@ def header_class(cls):
135149
else:
136150
return ResponseHeader
137151

138-
def encode(self, header=False, framed=False, correlation_id=None, **kwargs):
139-
if header and self.header is None:
140-
self.with_header(correlation_id=correlation_id)
141-
return super().encode(header=header, framed=framed)
142-
143152

144153
def _to_object(schema, data):
145154
obj = {}

kafka/protocol/struct.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __hash__(self):
5454
return hash(self.encode())
5555

5656
def __eq__(self, other):
57+
if not isinstance(other, Struct):
58+
return False
5759
if self.SCHEMA != other.SCHEMA:
5860
return False
5961
for attr in self.SCHEMA.names:

test/protocol/test_api_versions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,9 @@
6565

6666
@pytest.mark.parametrize('msg, encoded', TEST_CASES)
6767
def test_parse(msg, encoded):
68-
assert msg.encode(correlation_id=1, client_id='_internal_client_kYVL', header=True, framed=True) == encoded
69-
assert msg.decode(encoded, header=True, framed=True)[2] == msg
68+
if msg.is_request():
69+
msg.with_header(correlation_id=1, client_id='_internal_client_kYVL')
70+
else:
71+
msg.with_header(correlation_id=1)
72+
assert msg.encode(header=True, framed=True) == encoded
73+
assert msg.decode(encoded, header=True, framed=True) == msg

0 commit comments

Comments
 (0)