Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions kafka/protocol/new/api_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,26 @@ def __getitem__(cls, version):
# Use [] lookups to move from primary class to "versioned" classes
# which are simple wrappers around the primary class but with a _version attr
if cls._class_version is not None:
return cls._VERSIONS[None].__getitem__(version)
return cls._VERSIONS[None][version]
if cls._valid_versions is not None:
if version < 0:
version += 1 + cls.max_version # support negative index, e.g., [-1]
if not cls.min_version <= version <= cls.max_version:
raise ValueError('Invalid version! min=%d, max=%d' % (cls.min_version, cls.max_version))
klass_name = cls.__name__ + '_v' + str(version)
if klass_name in cls._VERSIONS:
return cls._VERSIONS[klass_name]
cls._VERSIONS[klass_name] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False)
return cls._VERSIONS[klass_name]

def __len__(cls):
# Maintain compatibility
if cls._valid_versions is None:
raise RuntimeError('Unable to calculate __len__ for class without valid_versions')
elif cls._class_version is not None:
raise TypeError('len() only supported on primary message class (not versioned)')
return cls._valid_versions[1] + 1


class ApiMessageMeta(VersionSubscriptable, SlotsBuilder):
def __new__(metacls, name, bases, attrs, **kw):
Expand Down Expand Up @@ -58,7 +71,7 @@ def __init__(cls, name, bases, attrs, **kw):


class ApiMessage(DataContainer, metaclass=ApiMessageMeta, init=False):
__slots__ = ('_header', '_version')
__slots__ = ('_header')

def __init_subclass__(cls, **kw):
super().__init_subclass__(**kw)
Expand All @@ -72,11 +85,17 @@ def __init_subclass__(cls, **kw):
ResponseClassRegistry.register_response_class(weakref.proxy(cls))

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._header = None
self._version = None
if 'version' in kwargs:
self.API_VERSION = kwargs['version']
if len(args) > 0:
untagged_fields = self._struct.untagged_fields(self.API_VERSION)
if len(args) != len(untagged_fields):
raise RuntimeError('Unable to init ApiMessage via positional args: unexpected len')
kwargs.update({field.name: args[i] for i, field in enumerate(untagged_fields)})
args = ()
super().__init__(*args, **kwargs)

@classproperty
def name(cls): # pylint: disable=E0213
Expand Down Expand Up @@ -171,7 +190,13 @@ def encode_header(self, flexible=False):
return self._header.encode(flexible=flexible) # pylint: disable=E1120

@classmethod
def parse_header(cls, data, flexible=False):
def parse_header(cls, data, version=None):
version = cls._class_version if version is None else version
if version is None:
raise ValueError('Version required to decode data')
elif not 0 <= version <= cls.max_version:
raise ValueError('Invalid version %s (max version is %s).' % (version, cls.max_version))
flexible = cls.flexible_version_q(version)
return cls.header_class.decode(data, flexible=flexible) # pylint: disable=E1101

def encode(self, version=None, header=False, framed=False):
Expand Down Expand Up @@ -206,15 +231,15 @@ def decode(cls, data, version=None, header=False, framed=False):
else:
data_class = cls

flexible = cls.flexible_version_q(version)
if isinstance(data, bytes):
data = io.BytesIO(data)
if framed:
size = Int32.decode(data)
if header:
hdr = cls.parse_header(data, flexible=flexible)
hdr = cls.parse_header(data, version=version)
else:
hdr = None
flexible = cls.flexible_version_q(version)
ret = cls._struct.decode(data, version=version, compact=flexible, tagged=flexible, data_class=data_class)
if hdr is not None:
ret._header = hdr
Expand Down
23 changes: 21 additions & 2 deletions kafka/protocol/new/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __new__(metacls, name, bases, attrs, **kw):


class DataContainer(metaclass=SlotsBuilder):
__slots__ = ('tags', 'unknown_tags')
__slots__ = ('tags', 'unknown_tags', '_version')
_struct = None

def __init_subclass__(cls, **kwargs):
Expand All @@ -22,8 +22,9 @@ def __init_subclass__(cls, **kwargs):
field.set_data_class(type(field.type_str, (DataContainer,), {'_struct': field}))
setattr(cls, field.type_str, field.data_class)

def __init__(self, **field_vals):
def __init__(self, version=None, **field_vals):
assert self._struct is not None
self._version = version
self.tags = None
self.unknown_tags = None
for field in self._struct._fields:
Expand Down Expand Up @@ -83,3 +84,21 @@ def __eq__(self, other):
if getattr(self, field.name) != getattr(other, field.name):
return False
return True

def __iter__(self):
if self._version is None:
raise RuntimeError('DataContainer Iteration not supported without _version')
return iter([getattr(self, field.name) for field in self._struct.untagged_fields(self._version)])

def __getitem__(self, key):
if self._version is None:
raise RuntimeError('DataContainer subscript not supported without _version')
elif isinstance(key, int):
field = self._struct.untagged_fields(self._version)[key]
return getattr(self, field.name)
elif isinstance(key, slice):
fields = self._struct.untagged_fields(self._version)
start, stop, step = key.indices(len(fields))
return [getattr(self, fields[i].name) for i in range(start, stop, step)]
else:
raise TypeError('DataContainer subscript supports int or slices only: %s' % type(key).__name__)
4 changes: 2 additions & 2 deletions kafka/protocol/new/metadata/api_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class ApiVersionsRequest(ApiMessage): pass
class ApiVersionsResponse(ApiMessage):
# ApiVersionsResponse header never uses flexible formats, even if body does
@classmethod
def parse_header(cls, data, flexible=False):
return super().parse_header(data, flexible=False)
def parse_header(cls, data, version=None):
return cls.header_class.decode(data, flexible=False) # pylint: disable=E1101

def encode_header(self, flexible=False):
return super().encode_header(flexible=False)
Expand Down
16 changes: 15 additions & 1 deletion kafka/protocol/new/metadata/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from ..api_message import ApiMessage

from kafka.util import classproperty


class MetadataRequest(ApiMessage):
@classproperty
def ALL_TOPICS(cls): # pylint: disable=E0213
if cls._class_version == 0: # pylint: disable=E1101
return []
else:
return None

@classproperty
def NO_TOPICS(cls): # pylint: disable=E0213
return []


class MetadataRequest(ApiMessage): pass
class MetadataResponse(ApiMessage):
@classmethod
def json_patch(cls, json):
Expand Down
3 changes: 3 additions & 0 deletions kafka/protocol/new/schemas/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,6 @@ def decode(self, data, version=None, compact=False, tagged=False):
return None
return [self.array_of.decode(data, version=version, compact=compact, tagged=tagged)
for _ in range(size)]

def __repr__(self):
return 'ArrayField(%s)' % self._json
4 changes: 3 additions & 1 deletion kafka/protocol/new/schemas/fields/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ def decode(self, data, version=None, compact=False, tagged=False):
assert version is not None, 'version is required to decode Field'
if not self.for_version_q(version):
return None
print("decoding", self.name)
if compact and self._type is Bytes:
return CompactBytes.decode(data)
elif compact and isinstance(self._type, String):
return CompactString(self._type.encoding).decode(data)
else:
return self._type.decode(data)

def __repr__(self):
return 'SimpleField(%s)' % self._json
17 changes: 13 additions & 4 deletions kafka/protocol/new/schemas/fields/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,26 @@ def tagged_fields(self, version):
if field.for_version_q(version)
and field.tagged_field_q(version)])

def untagged_fields(self, version):
return [field for field in self._fields
if field.for_version_q(version)
and not field.tagged_field_q(version)]

def encode(self, item, version=None, compact=False, tagged=False):
assert version is not None, 'version required to encode StructField'
if not self.for_version_q(version):
return b''
fields = [field for field in self._fields if field.for_version_q(version) and not field.tagged_field_q(version)]
fields = self.untagged_fields(version)
if isinstance(item, tuple):
getter = lambda item, i, field: item[i]
tags = {} if len(item) == len(fields) else item[-1]
elif isinstance(item, dict):
getter = lambda item, i, field: item.get(field.name) # defaults?
tags = item
elif isinstance(item, (str, int, float)):
assert len(fields) == 1, "Encoding single value item (str/int/float) requires single field struct"
getter = lambda item, i, field: item
tags = {}
else:
getter = lambda item, i, field: getattr(item, field.name)
tags = item
Expand All @@ -75,7 +84,7 @@ def encode(self, item, version=None, compact=False, tagged=False):
if tagged:
# TaggedFields are always compact and never include nested tagged fields
encoded.append(self.tagged_fields(version).encode(tags, version=version,
compact=True, tagged=False))
compact=True, tagged=False))
return b''.join(encoded)

def decode(self, data, version=None, compact=False, tagged=False, data_class=None):
Expand All @@ -91,7 +100,7 @@ def decode(self, data, version=None, compact=False, tagged=False, data_class=Non
}
if tagged:
decoded.update(self.tagged_fields(version).decode(data, version=version, compact=True, tagged=False))
return data_class(**decoded)
return data_class(version=version, **decoded)

def __len__(self):
return len(self._fields)
Expand All @@ -104,4 +113,4 @@ def __eq__(self, other):
return True

def __repr__(self):
return '%s(%s, %s)' % (self.__class__.__name__, self._name, self._fields)
return 'StructField(%s)' % self._json
9 changes: 9 additions & 0 deletions kafka/protocol/new/schemas/fields/struct_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def is_struct_array(self):
def fields(self):
return self.array_of.fields

def tagged_fields(self, version):
return self.array_of.tagged_fields(version)

def untagged_fields(self, version):
return self.array_of.untagged_fields(version)

def has_data_class(self):
return self.array_of.has_data_class()

Expand All @@ -52,3 +58,6 @@ def data_class(self):

def __call__(self, *args, **kw):
return self.data_class(*args, **kw) # pylint: disable=E1102

def __repr__(self):
return 'StructArrayField(%s)' % self._json
Loading
Loading