Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions kafka/protocol/new/api_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import io
import weakref

from kafka.util import classproperty

from .data_container import DataContainer, SlotsBuilder
from .schemas import BaseField, StructField, load_json
from .schemas.fields.codecs import Int16, Int32


class JsonSchemaData(SlotsBuilder):
def __new__(metacls, name, bases, attrs, **kw):
if kw.get('init', True):
json = load_json(name)
if 'json_patch' in attrs:
json = attrs['json_patch'].__func__(metacls, json)
attrs['_json'] = json
attrs['_struct'] = StructField(json)
if 'doc' in json:
attrs['__doc__'] = attrs.get('__doc__', '') + "\nNotes from json schema:\n" + json.get('doc')
attrs['__license__'] = json.get('license')
return super().__new__(metacls, name, bases, attrs, **kw)

def __init__(cls, name, bases, attrs, **kw):
super().__init__(name, bases, attrs, **kw)
if kw.get('init', True):
cls._struct.set_data_class(weakref.proxy(cls))


class ApiData(DataContainer, metaclass=JsonSchemaData, init=False):
def __init_subclass__(cls, **kw):
super().__init_subclass__(**kw)
if kw.get('init', True):
# pylint: disable=E1101
assert cls._json is not None
assert cls._json['type'] == 'data'
cls._flexible_versions = BaseField.parse_versions(cls._json['flexibleVersions'])
cls._valid_versions = BaseField.parse_versions(cls._json['validVersions'])

def __init__(self, *args, **kwargs):
if len(args) > 0 and isinstance(args[0], int) and 'version' not in kwargs:
kwargs['version'] = args[0]
args = tuple(args[1:])
super().__init__(*args, **kwargs)

@classproperty
def name(cls): # pylint: disable=E0213
return cls._json['name'] # pylint: disable=E1101

@classproperty
def type(cls): # pylint: disable=E0213
return cls._json['type'] # pylint: disable=E1101

@classproperty
def json(cls): # pylint: disable=E0213
return cls._json # pylint: disable=E1101

@classproperty
def valid_versions(cls): # pylint: disable=E0213
return cls._valid_versions

@classproperty
def min_version(cls): # pylint: disable=E0213
return 0

@classproperty
def max_version(cls): # pylint: disable=E0213
if cls._valid_versions is not None:
return cls._valid_versions[1] # pylint: disable=E1136
return None

@classmethod
def flexible_version_q(cls, version):
if cls._flexible_versions is not None:
if cls._flexible_versions[0] <= version <= cls._flexible_versions[1]: # pylint: disable=E1136
return True
return False

@classproperty
def header_class(cls): # pylint: disable=E0213
return Int16

def encode_header(self, flexible=False):
assert self._version is not None
return self.header_class.encode(self._version)

@classmethod
def parse_header(cls, data):
return cls.header_class.decode(data) # pylint: disable-msg=no-member

def encode(self, version=None, header=True, framed=False):
if version is not None:
self._version = version
elif self._version is None:
raise ValueError('Version required to encode data')
flexible = self.flexible_version_q(self._version)
encoded = self._struct.encode(self, version=self._version, compact=flexible, tagged=flexible)
if not header and not framed:
return encoded
bits = [encoded]
if header:
bits.insert(0, self.encode_header(flexible=flexible))
if framed:
bits.insert(0, Int32.encode(sum(map(len, bits))))
return b''.join(bits)

@classmethod
def decode(cls, data, version=None, header=True, framed=False):
if not header:
if version is None:
raise ValueError('Version required to decode data')
elif not 0 <= version <= cls.max_version:
raise ValueError('Invalid version %s (max version is %s).' % (version, cls.max_version))
if isinstance(data, bytes):
data = io.BytesIO(data)
if framed:
size = Int32.decode(data)
if header:
decoded_version = cls.parse_header(data)
if version is not None:
if version > decoded_version:
raise ValueError('Version mismatch: found v%d, expected v%d' % (decoded_version, version))
version = min(decoded_version, cls.max_version)
flexible = cls.flexible_version_q(version)
return cls._struct.decode(data, version=version, compact=flexible, tagged=flexible)
12 changes: 2 additions & 10 deletions kafka/protocol/new/api_header.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from .api_data import JsonSchemaData
from .data_container import DataContainer, SlotsBuilder
from .schemas import BaseField, StructField, load_json


class ApiHeaderMeta(SlotsBuilder):
def __new__(metacls, name, bases, attrs, **kw):
if kw.get('init', True):
json = load_json(name)
attrs['_json'] = json
attrs['_struct'] = StructField(json)
return super().__new__(metacls, name, bases, attrs, **kw)


class ApiHeader(DataContainer, metaclass=ApiHeaderMeta, init=False):
class ApiHeader(DataContainer, metaclass=JsonSchemaData, init=False):
__slots__ = ()

def __init_subclass__(cls, **kw):
Expand Down
22 changes: 4 additions & 18 deletions kafka/protocol/new/api_message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import io
import weakref

from .api_data import JsonSchemaData
from .api_header import RequestHeader, ResponseHeader, ResponseClassRegistry
from .data_container import DataContainer, SlotsBuilder
from .data_container import DataContainer
from .schemas import BaseField, StructField, load_json
from .schemas.fields.codecs import Int32

Expand Down Expand Up @@ -48,32 +49,17 @@ def __len__(cls):
return cls._valid_versions[1] + 1


class ApiMessageMeta(VersionSubscriptable, SlotsBuilder):
def __new__(metacls, name, bases, attrs, **kw):
# Pass init=False from base classes
if kw.get('init', True):
json = load_json(name)
if 'json_patch' in attrs:
json = attrs['json_patch'].__func__(metacls, json)
attrs['_json'] = json
attrs['_struct'] = StructField(json)
attrs['__doc__'] = json.get('doc')
attrs['__license__'] = json.get('license')
return super().__new__(metacls, name, bases, attrs, **kw)

class ApiMessageData(VersionSubscriptable, JsonSchemaData):
def __init__(cls, name, bases, attrs, **kw):
super().__init__(name, bases, attrs, **kw)
if kw.get('init', True):
# Ignore min valid version on request/response schemas
# We'll get the brokers supported versions via ApiVersionsRequest
if cls._struct._versions[0] > 0:
cls._struct._versions = (0, cls._struct._versions[1])
# Configure the StructField to use our ApiMessage wrapper
# and not construct a default DataContainer
cls._struct.set_data_class(weakref.proxy(cls))


class ApiMessage(DataContainer, metaclass=ApiMessageMeta, init=False):
class ApiMessage(DataContainer, metaclass=ApiMessageData, init=False):
__slots__ = ('_header')

def __init_subclass__(cls, **kw):
Expand Down
Loading