Skip to content

Commit d4637e5

Browse files
committed
Use RequestResponse class for common bits; add header attr and with_header() to construct (#2724)
1 parent 445fa2d commit d4637e5

4 files changed

Lines changed: 76 additions & 68 deletions

File tree

kafka/protocol/api.py

Lines changed: 66 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ class ResponseHeaderV2(Struct):
3838
)
3939

4040

41-
class Request(Struct, metaclass=abc.ABCMeta):
41+
class RequestResponse(Struct, metaclass=abc.ABCMeta):
4242
FLEXIBLE_VERSION = False
4343

44+
def __init__(self, *args, **kwargs):
45+
super().__init__(*args, **kwargs)
46+
self._header = None
47+
4448
@abc.abstractproperty
4549
def API_KEY(self):
4650
"""Integer identifier for api request"""
@@ -51,106 +55,104 @@ def API_VERSION(self):
5155
"""Integer of api request version"""
5256
pass
5357

54-
@abc.abstractproperty
55-
def RESPONSE_TYPE(self):
56-
"""The Response class associated with the api request"""
57-
pass
58-
59-
def expect_response(self):
60-
"""Override this method if an api request does not always generate a response"""
61-
return True
62-
6358
def to_object(self):
6459
return _to_object(self.SCHEMA, self)
6560

66-
def build_header(self, correlation_id=0, client_id='kafka-python'):
67-
if self.FLEXIBLE_VERSION:
68-
return RequestHeaderV2(self.API_KEY, self.API_VERSION, correlation_id, client_id, {})
69-
return RequestHeader(self.API_KEY, self.API_VERSION, correlation_id, client_id)
70-
7161
@classmethod
72-
def parse_header(cls, read_buffer):
73-
if cls.FLEXIBLE_VERSION:
74-
return RequestHeaderV2.decode(read_buffer)
75-
return RequestHeader.decode(read_buffer)
62+
@abc.abstractmethod
63+
def is_request(cls):
64+
pass
65+
66+
@property
67+
def header(self):
68+
return self._header
7669

77-
def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs):
70+
def encode(self, header=False, framed=False):
7871
data = super().encode()
7972
if not framed and not header:
8073
return data
8174
bits = [data]
8275
if header:
83-
bits.insert(0, self.build_header(correlation_id, client_id).encode())
76+
bits.insert(0, self.header.encode())
8477
if framed:
8578
bits.insert(0, Int32.encode(sum(map(len, bits))))
8679
return b''.join(bits)
8780

81+
@classmethod
82+
@abc.abstractmethod
83+
def header_class(cls):
84+
pass
85+
86+
@classmethod
87+
def parse_header(cls, read_buffer):
88+
return cls.header_class().decode(read_buffer)
89+
8890
@classmethod
8991
def decode(cls, data, header=False, framed=False):
9092
if not framed and not header:
9193
return super().decode(data)
9294
if isinstance(data, bytes):
9395
data = BytesIO(data)
94-
ret = []
9596
if framed:
96-
ret.append(Int32.decode(data))
97+
size = Int32.decode(data)
9798
if header:
98-
ret.append(cls.parse_header(data))
99-
ret.append(super().decode(data))
100-
return tuple(ret)
99+
hdr = cls.parse_header(data)
100+
else:
101+
hdr = None
102+
ret = super().decode(data)
103+
if hdr is not None:
104+
ret._header = hdr
105+
return ret
101106

107+
def __eq__(self, other):
108+
return self._header == other._header and super().__eq__(other)
102109

103-
class Response(Struct, metaclass=abc.ABCMeta):
104-
FLEXIBLE_VERSION = False
105110

111+
class Request(RequestResponse):
106112
@abc.abstractproperty
107-
def API_KEY(self):
108-
"""Integer identifier for api request/response"""
113+
def RESPONSE_TYPE(self):
114+
"""The Response class associated with the api request"""
109115
pass
110116

111-
@abc.abstractproperty
112-
def API_VERSION(self):
113-
"""Integer of api request/response version"""
114-
pass
117+
@classmethod
118+
def is_request(cls):
119+
return True
115120

116-
def to_object(self):
117-
return _to_object(self.SCHEMA, self)
121+
def expect_response(self):
122+
"""Override this method if an api request does not always generate a response"""
123+
return True
118124

119-
def build_header(self, correlation_id=0):
125+
def with_header(self, correlation_id=0, client_id='kafka-python'):
120126
if self.FLEXIBLE_VERSION:
121-
return ResponseHeaderV2(correlation_id=correlation_id, tags=None)
122-
return ResponseHeader(correlation_id=correlation_id)
127+
self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id, {})
128+
else:
129+
self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id)
123130

124131
@classmethod
125-
def parse_header(cls, read_buffer):
132+
def header_class(cls):
126133
if cls.FLEXIBLE_VERSION:
127-
return ResponseHeaderV2.decode(read_buffer)
128-
return ResponseHeader.decode(read_buffer)
134+
return RequestHeaderV2
135+
else:
136+
return RequestHeader
129137

130-
def encode(self, header=False, framed=False, correlation_id=None, **kwargs):
131-
data = super().encode()
132-
if not framed and not header:
133-
return data
134-
bits = [data]
135-
if header:
136-
bits.insert(0, self.build_header(correlation_id).encode())
137-
if framed:
138-
bits.insert(0, Int32.encode(sum(map(len, bits))))
139-
return b''.join(bits)
140138

139+
class Response(RequestResponse):
141140
@classmethod
142-
def decode(cls, data, header=False, framed=False):
143-
if not framed and not header:
144-
return super().decode(data)
145-
if isinstance(data, bytes):
146-
data = BytesIO(data)
147-
ret = []
148-
if framed:
149-
ret.append(Int32.decode(data))
150-
if header:
151-
ret.append(cls.parse_header(data))
152-
ret.append(super().decode(data))
153-
return tuple(ret)
141+
def is_request(cls):
142+
return False
143+
144+
def with_header(self, correlation_id=0):
145+
if self.FLEXIBLE_VERSION:
146+
self._header = self.header_class()(correlation_id, {})
147+
else:
148+
self._header = self.header_class()(correlation_id)
149+
150+
@classmethod
151+
def header_class(cls):
152+
if cls.FLEXIBLE_VERSION:
153+
return ResponseHeaderV2
154+
else:
155+
return ResponseHeader
154156

155157

156158
def _to_object(schema, data):

kafka/protocol/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def send_request(self, request, correlation_id=None):
5858
correlation_id = self._next_correlation_id()
5959

6060
log.debug('%s Sending request %d %s', self._ident, correlation_id, request)
61-
data = request.encode(correlation_id=correlation_id, client_id=self._client_id,
62-
framed=True, header=True)
61+
request.with_header(correlation_id=correlation_id, client_id=self._client_id)
62+
data = request.encode(framed=True, header=True)
6363
self.bytes_to_send.append(data)
6464
if request.expect_response():
6565
ifr = (correlation_id, request)

kafka/protocol/struct.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __hash__(self):
5454
return hash(self.encode())
5555

5656
def __eq__(self, other):
57+
if not isinstance(other, Struct):
58+
return False
5759
if self.SCHEMA != other.SCHEMA:
5860
return False
5961
for attr in self.SCHEMA.names:

test/protocol/test_api_versions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,9 @@
6565

6666
@pytest.mark.parametrize('msg, encoded', TEST_CASES)
6767
def test_parse(msg, encoded):
68-
assert msg.encode(correlation_id=1, client_id='_internal_client_kYVL', header=True, framed=True) == encoded
69-
assert msg.decode(encoded, header=True, framed=True)[2] == msg
68+
if msg.is_request():
69+
msg.with_header(correlation_id=1, client_id='_internal_client_kYVL')
70+
else:
71+
msg.with_header(correlation_id=1)
72+
assert msg.encode(header=True, framed=True) == encoded
73+
assert msg.decode(encoded, header=True, framed=True) == msg

0 commit comments

Comments
 (0)