diff --git a/kafka/conn.py b/kafka/conn.py index 14fd22b42..c864d9685 100755 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -18,7 +18,7 @@ ) from kafka.protocol.parser import KafkaProtocol from kafka.protocol.new.sasl import SaslAuthenticateRequest, SaslHandshakeRequest -from kafka.protocol.types import Int32 +from kafka.protocol.new.schemas.fields.codecs import Int32 from kafka.sasl import get_sasl_mechanism from kafka.socks5_wrapper import Socks5Wrapper from kafka.version import __version__ diff --git a/kafka/protocol/new/schemas/fields/codecs/__init__.py b/kafka/protocol/new/schemas/fields/codecs/__init__.py index c98d1570a..034d3d83e 100644 --- a/kafka/protocol/new/schemas/fields/codecs/__init__.py +++ b/kafka/protocol/new/schemas/fields/codecs/__init__.py @@ -1,5 +1,4 @@ -from .....types import ( - Array, CompactArray, +from .types import ( BitField, Boolean, UUID, Int8, Int16, Int32, Int64, UnsignedVarInt32, Float64, Bytes, CompactBytes, String, CompactString, @@ -7,7 +6,6 @@ from .tagged_fields import TaggedFields __all__ = [ - 'Array', 'CompactArray', 'BitField', 'Boolean', 'UUID', 'Int8', 'Int16', 'Int32', 'Int64', 'UnsignedVarInt32', 'Float64', 'Bytes', 'CompactBytes', 'String', 'CompactString', diff --git a/kafka/protocol/new/schemas/fields/codecs/types.py b/kafka/protocol/new/schemas/fields/codecs/types.py new file mode 100644 index 000000000..3e33ee9b4 --- /dev/null +++ b/kafka/protocol/new/schemas/fields/codecs/types.py @@ -0,0 +1,306 @@ +from struct import error, pack, unpack +import uuid + +class Int8: + fmt = 'b' + size = 1 + + @classmethod + def encode(cls, value): + return pack('>b', value) + + @classmethod + def decode(cls, data): + return unpack('>b', data.read(1))[0] + + +class Int16: + fmt = 'h' + size = 2 + + @classmethod + def encode(cls, value): + return pack('>h', value) + + @classmethod + def decode(cls, data): + return unpack('>h', data.read(2))[0] + + +class Int32: + fmt = 'i' + size = 4 + + @classmethod + def encode(cls, value): + return pack('>i', value) + + @classmethod + def decode(cls, data): + return unpack('>i', data.read(4))[0] + + +class Int64: + fmt = 'q' + size = 8 + + @classmethod + def encode(cls, value): + return pack('>q', value) + + @classmethod + def decode(cls, data): + return unpack('>q', data.read(8))[0] + + +class Float64: + fmt = 'd' + size = 8 + + @classmethod + def encode(cls, value): + return pack('>d', value) + + @classmethod + def decode(cls, data): + return unpack('>d', data.read(8))[0] + + +class UUID: + fmt = '16B' + size = 16 + ZERO_UUID = uuid.UUID(int=0) + + @classmethod + def encode(cls, value): + if value is None: + value = cls.ZERO_UUID + if isinstance(value, uuid.UUID): + return value.bytes + return uuid.UUID(value).bytes + + @classmethod + def decode(cls, data): + val = uuid.UUID(bytes=data.read(16)) + if val == cls.ZERO_UUID: + return None + return val + + +class String: + fmt = Int16.fmt + size = 'variable' + + def __init__(self, encoding='utf-8'): + self.encoding = encoding + + def encode(self, value): + if value is None: + return Int16.encode(-1) + value = str(value).encode(self.encoding) + return Int16.encode(len(value)) + value + + def decode(self, data): + length = Int16.decode(data) + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding string') + return value.decode(self.encoding) + + +class Bytes: + fmt = Int32.fmt + size = 'variable' + + @classmethod + def encode(cls, value): + if value is None: + return Int32.encode(-1) + elif not isinstance(value, bytes): + value = value.encode() + return Int32.encode(len(value)) + value + + @classmethod + def decode(cls, data): + length = Int32.decode(data) + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding Bytes') + return value + + +class Boolean: + fmt = '?' + size = 1 + + @classmethod + def encode(cls, value): + return pack('>?', value) + + @classmethod + def decode(cls, data): + return unpack('>?', data.read(1))[0] + + +class UnsignedVarInt32: + fmt = 'B' + size = 'variable' + + @classmethod + def decode(cls, data): + value = VarInt32.decode(data) + return (value << 1) ^ (value >> 31) + + @classmethod + def encode(cls, value): + return VarInt32.encode((value >> 1) ^ -(value & 1)) + + +class VarInt32: + fmt = 'B' + size = 'variable' + + @classmethod + def decode(cls, data): + value, i = 0, 0 + while True: + b, = unpack('B', data.read(1)) + if not (b & 0x80): + break + value |= (b & 0x7f) << i + i += 7 + if i > 28: + raise ValueError('Invalid value {}'.format(value)) + value |= b << i + return (value >> 1) ^ -(value & 1) + + @classmethod + def encode(cls, value): + # bring it in line with the java binary repr + value = (value << 1) ^ (value >> 31) + value &= 0xffffffff + ret = b'' + while (value & 0xffffff80) != 0: + b = (value & 0x7f) | 0x80 + ret += pack('B', b) + value >>= 7 + ret += pack('B', value) + return ret + + +class VarInt64: + fmt = 'B' + size = 'variable' + + @classmethod + def decode(cls, data): + value, i = 0, 0 + while True: + b, = unpack('B', data.read(1)) + if not (b & 0x80): + break + value |= (b & 0x7f) << i + i += 7 + if i > 63: + raise ValueError('Invalid value {}'.format(value)) + value |= b << i + return (value >> 1) ^ -(value & 1) + + @classmethod + def encode(cls, value): + # bring it in line with the java binary repr + value = (value << 1) ^ (value >> 63) + value &= 0xffffffffffffffff + ret = b'' + while (value & 0xffffffffffffff80) != 0: + b = (value & 0x7f) | 0x80 + ret += pack('B', b) + value >>= 7 + ret += pack('B', value) + return ret + + +class CompactString(String): + fmt = 'B' + size = 'variable' + + def decode(self, data): + length = UnsignedVarInt32.decode(data) - 1 + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding string') + return value.decode(self.encoding) + + def encode(self, value): + if value is None: + return UnsignedVarInt32.encode(0) + value = str(value).encode(self.encoding) + return UnsignedVarInt32.encode(len(value) + 1) + value + + +class CompactBytes: + fmt = 'B' + size = 'variable' + + @classmethod + def decode(cls, data): + length = UnsignedVarInt32.decode(data) - 1 + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding Bytes') + return value + + @classmethod + def encode(cls, value): + if value is None: + return UnsignedVarInt32.encode(0) + else: + return UnsignedVarInt32.encode(len(value) + 1) + value + + +class BitField: + fmt = 'I' + size = 4 + + @classmethod + def decode(cls, data): + vals = cls.from_32_bit_field(unpack('>I', data.read(4))[0]) + if vals == {31}: + vals = None + return vals + + @classmethod + def encode(cls, vals): + if vals is None: + vals = {31} + # to_32_bit_field returns unsigned val, so we need to + # encode >I to avoid crash if/when byte 31 is set + # (note that decode as signed still works fine) + return pack('>I', cls.to_32_bit_field(vals)) + + @classmethod + def to_32_bit_field(cls, vals): + value = 0 + for b in vals: + assert 0 <= b < 32 + value |= 1 << b + return value + + @classmethod + def from_32_bit_field(cls, value): + result = set() + count = 0 + while value != 0: + if (value & 1) != 0: + result.add(count) + count += 1 + value = (value & 0xFFFFFFFF) >> 1 + return result diff --git a/kafka/protocol/parser.py b/kafka/protocol/parser.py index 274673cf0..a2f137b29 100644 --- a/kafka/protocol/parser.py +++ b/kafka/protocol/parser.py @@ -1,5 +1,6 @@ import collections import logging +import struct import kafka.errors as Errors from kafka.protocol.find_coordinator import FindCoordinatorResponse @@ -163,7 +164,7 @@ def _process_response(self, read_buffer): # decode response try: response = response_type.decode(read_buffer) - except ValueError: + except (ValueError, struct.error): read_buffer.seek(0) buf = read_buffer.read() log.error('Response %d [ResponseType: %s RequestHeader: %s]:' diff --git a/test/protocol/new/schemas/test_field.py b/test/protocol/new/schemas/test_field.py index 2a069ea99..2861b08fc 100644 --- a/test/protocol/new/schemas/test_field.py +++ b/test/protocol/new/schemas/test_field.py @@ -4,7 +4,7 @@ from kafka.protocol.new.schemas.fields.struct import StructField from kafka.protocol.new.schemas.fields.base import BaseField from kafka.protocol.new.schemas.fields.simple import SimpleField -from kafka.protocol.types import Int16, Int32, Boolean, String, UUID +from kafka.protocol.new.schemas.fields.codecs import Int16, Int32, Boolean, String, UUID def test_parse_versions(): diff --git a/test/protocol/new/test_api_compatibility.py b/test/protocol/new/test_api_compatibility.py index 1c86d9da4..d47c7538d 100644 --- a/test/protocol/new/test_api_compatibility.py +++ b/test/protocol/new/test_api_compatibility.py @@ -12,7 +12,7 @@ ApiVersionsRequest as NewApiVersionsRequest, ApiVersionsResponse as NewApiVersionsResponse ) -from kafka.protocol.types import Int16 +from kafka.protocol.new.schemas.fields.codecs import Int16 # --- Golden Samples (Generated from existing working system) --- diff --git a/test/protocol/new/test_new_parser.py b/test/protocol/new/test_new_parser.py index 4ac034ea9..cb5b89ffd 100644 --- a/test/protocol/new/test_new_parser.py +++ b/test/protocol/new/test_new_parser.py @@ -1,7 +1,7 @@ import pytest +from kafka.errors import KafkaProtocolError, CorrelationIdError from kafka.protocol.parser import KafkaProtocol - from kafka.protocol.new.metadata import ( ApiVersionsRequest, ApiVersionsResponse, FindCoordinatorRequest, FindCoordinatorResponse, @@ -241,3 +241,27 @@ def test_parser(test_case): assert len(responses) == 1 assert responses[0][0] == correlation_id assert responses[0][1] == resp + + +def test_correlation_id_error(): + parser = KafkaProtocol(client_id='test-parser-error') + parser.send_request(MetadataRequest[0]()) + sent_bytes = parser.send_bytes() + resp = MetadataResponse[0]() + resp.with_header(correlation_id=99) + with pytest.raises(CorrelationIdError): + parser.receive_bytes(resp.encode(header=True, framed=True)) + + +def test_parser_error(): + parser = KafkaProtocol(client_id='test-parser-error') + + parser.send_request(MetadataRequest[0]()) + sent_bytes = parser.send_bytes() + resp = MetadataResponse[0]() + resp.with_header(correlation_id=0) + resp_bytes = resp.encode(header=True, framed=True) + # frame header too short -> response decode buffer underrun + bad_bytes = b''.join([resp_bytes[0:3], b'\x0a', resp_bytes[4:]]) + with pytest.raises(KafkaProtocolError): + responses = parser.receive_bytes(bad_bytes)