Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 66 additions & 64 deletions kafka/protocol/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ class ResponseHeaderV2(Struct):
)


class Request(Struct, metaclass=abc.ABCMeta):
class RequestResponse(Struct, metaclass=abc.ABCMeta):
FLEXIBLE_VERSION = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._header = None

@abc.abstractproperty
def API_KEY(self):
"""Integer identifier for api request"""
Expand All @@ -51,106 +55,104 @@ def API_VERSION(self):
"""Integer of api request version"""
pass

@abc.abstractproperty
def RESPONSE_TYPE(self):
"""The Response class associated with the api request"""
pass

def expect_response(self):
"""Override this method if an api request does not always generate a response"""
return True

def to_object(self):
return _to_object(self.SCHEMA, self)

def build_header(self, correlation_id=0, client_id='kafka-python'):
if self.FLEXIBLE_VERSION:
return RequestHeaderV2(self.API_KEY, self.API_VERSION, correlation_id, client_id, {})
return RequestHeader(self.API_KEY, self.API_VERSION, correlation_id, client_id)

@classmethod
def parse_header(cls, read_buffer):
if cls.FLEXIBLE_VERSION:
return RequestHeaderV2.decode(read_buffer)
return RequestHeader.decode(read_buffer)
@abc.abstractmethod
def is_request(cls):
pass

@property
def header(self):
return self._header

def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs):
def encode(self, header=False, framed=False):
data = super().encode()
if not framed and not header:
return data
bits = [data]
if header:
bits.insert(0, self.build_header(correlation_id, client_id).encode())
bits.insert(0, self.header.encode())
if framed:
bits.insert(0, Int32.encode(sum(map(len, bits))))
return b''.join(bits)

@classmethod
@abc.abstractmethod
def header_class(cls):
pass

@classmethod
def parse_header(cls, read_buffer):
return cls.header_class().decode(read_buffer)

@classmethod
def decode(cls, data, header=False, framed=False):
if not framed and not header:
return super().decode(data)
if isinstance(data, bytes):
data = BytesIO(data)
ret = []
if framed:
ret.append(Int32.decode(data))
size = Int32.decode(data)
if header:
ret.append(cls.parse_header(data))
ret.append(super().decode(data))
return tuple(ret)
hdr = cls.parse_header(data)
else:
hdr = None
ret = super().decode(data)
if hdr is not None:
ret._header = hdr
return ret

def __eq__(self, other):
return self._header == other._header and super().__eq__(other)

class Response(Struct, metaclass=abc.ABCMeta):
FLEXIBLE_VERSION = False

class Request(RequestResponse):
@abc.abstractproperty
def API_KEY(self):
"""Integer identifier for api request/response"""
def RESPONSE_TYPE(self):
"""The Response class associated with the api request"""
pass

@abc.abstractproperty
def API_VERSION(self):
"""Integer of api request/response version"""
pass
@classmethod
def is_request(cls):
return True

def to_object(self):
return _to_object(self.SCHEMA, self)
def expect_response(self):
"""Override this method if an api request does not always generate a response"""
return True

def build_header(self, correlation_id=0):
def with_header(self, correlation_id=0, client_id='kafka-python'):
if self.FLEXIBLE_VERSION:
return ResponseHeaderV2(correlation_id=correlation_id, tags=None)
return ResponseHeader(correlation_id=correlation_id)
self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id, {})
else:
self._header = self.header_class()(self.API_KEY, self.API_VERSION, correlation_id, client_id)

@classmethod
def parse_header(cls, read_buffer):
def header_class(cls):
if cls.FLEXIBLE_VERSION:
return ResponseHeaderV2.decode(read_buffer)
return ResponseHeader.decode(read_buffer)
return RequestHeaderV2
else:
return RequestHeader

def encode(self, header=False, framed=False, correlation_id=None, **kwargs):
data = super().encode()
if not framed and not header:
return data
bits = [data]
if header:
bits.insert(0, self.build_header(correlation_id).encode())
if framed:
bits.insert(0, Int32.encode(sum(map(len, bits))))
return b''.join(bits)

class Response(RequestResponse):
@classmethod
def decode(cls, data, header=False, framed=False):
if not framed and not header:
return super().decode(data)
if isinstance(data, bytes):
data = BytesIO(data)
ret = []
if framed:
ret.append(Int32.decode(data))
if header:
ret.append(cls.parse_header(data))
ret.append(super().decode(data))
return tuple(ret)
def is_request(cls):
return False

def with_header(self, correlation_id=0):
if self.FLEXIBLE_VERSION:
self._header = self.header_class()(correlation_id, {})
else:
self._header = self.header_class()(correlation_id)

@classmethod
def header_class(cls):
if cls.FLEXIBLE_VERSION:
return ResponseHeaderV2
else:
return ResponseHeader


def _to_object(schema, data):
Expand Down
4 changes: 2 additions & 2 deletions kafka/protocol/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def send_request(self, request, correlation_id=None):
correlation_id = self._next_correlation_id()

log.debug('%s Sending request %d %s', self._ident, correlation_id, request)
data = request.encode(correlation_id=correlation_id, client_id=self._client_id,
framed=True, header=True)
request.with_header(correlation_id=correlation_id, client_id=self._client_id)
data = request.encode(framed=True, header=True)
self.bytes_to_send.append(data)
if request.expect_response():
ifr = (correlation_id, request)
Expand Down
2 changes: 2 additions & 0 deletions kafka/protocol/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __hash__(self):
return hash(self.encode())

def __eq__(self, other):
if not isinstance(other, Struct):
return False
if self.SCHEMA != other.SCHEMA:
return False
for attr in self.SCHEMA.names:
Expand Down
8 changes: 6 additions & 2 deletions test/protocol/test_api_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,9 @@

@pytest.mark.parametrize('msg, encoded', TEST_CASES)
def test_parse(msg, encoded):
assert msg.encode(correlation_id=1, client_id='_internal_client_kYVL', header=True, framed=True) == encoded
assert msg.decode(encoded, header=True, framed=True)[2] == msg
if msg.is_request():
msg.with_header(correlation_id=1, client_id='_internal_client_kYVL')
else:
msg.with_header(correlation_id=1)
assert msg.encode(header=True, framed=True) == encoded
assert msg.decode(encoded, header=True, framed=True) == msg
Loading