diff --git a/kafka/protocol/new/api_data.py b/kafka/protocol/new/api_data.py new file mode 100644 index 000000000..7ecbe0a28 --- /dev/null +++ b/kafka/protocol/new/api_data.py @@ -0,0 +1,125 @@ +import io +import weakref + +from kafka.util import classproperty + +from .data_container import DataContainer, SlotsBuilder +from .schemas import BaseField, StructField, load_json +from .schemas.fields.codecs import Int16, Int32 + + +class JsonSchemaData(SlotsBuilder): + def __new__(metacls, name, bases, attrs, **kw): + if kw.get('init', True): + json = load_json(name) + if 'json_patch' in attrs: + json = attrs['json_patch'].__func__(metacls, json) + attrs['_json'] = json + attrs['_struct'] = StructField(json) + if 'doc' in json: + attrs['__doc__'] = attrs.get('__doc__', '') + "\nNotes from json schema:\n" + json.get('doc') + attrs['__license__'] = json.get('license') + return super().__new__(metacls, name, bases, attrs, **kw) + + def __init__(cls, name, bases, attrs, **kw): + super().__init__(name, bases, attrs, **kw) + if kw.get('init', True): + cls._struct.set_data_class(weakref.proxy(cls)) + + +class ApiData(DataContainer, metaclass=JsonSchemaData, init=False): + def __init_subclass__(cls, **kw): + super().__init_subclass__(**kw) + if kw.get('init', True): + # pylint: disable=E1101 + assert cls._json is not None + assert cls._json['type'] == 'data' + cls._flexible_versions = BaseField.parse_versions(cls._json['flexibleVersions']) + cls._valid_versions = BaseField.parse_versions(cls._json['validVersions']) + + def __init__(self, *args, **kwargs): + if len(args) > 0 and isinstance(args[0], int) and 'version' not in kwargs: + kwargs['version'] = args[0] + args = tuple(args[1:]) + super().__init__(*args, **kwargs) + + @classproperty + def name(cls): # pylint: disable=E0213 + return cls._json['name'] # pylint: disable=E1101 + + @classproperty + def type(cls): # pylint: disable=E0213 + return cls._json['type'] # pylint: disable=E1101 + + @classproperty + def json(cls): # pylint: disable=E0213 + return cls._json # pylint: disable=E1101 + + @classproperty + def valid_versions(cls): # pylint: disable=E0213 + return cls._valid_versions + + @classproperty + def min_version(cls): # pylint: disable=E0213 + return 0 + + @classproperty + def max_version(cls): # pylint: disable=E0213 + if cls._valid_versions is not None: + return cls._valid_versions[1] # pylint: disable=E1136 + return None + + @classmethod + def flexible_version_q(cls, version): + if cls._flexible_versions is not None: + if cls._flexible_versions[0] <= version <= cls._flexible_versions[1]: # pylint: disable=E1136 + return True + return False + + @classproperty + def header_class(cls): # pylint: disable=E0213 + return Int16 + + def encode_header(self, flexible=False): + assert self._version is not None + return self.header_class.encode(self._version) + + @classmethod + def parse_header(cls, data): + return cls.header_class.decode(data) # pylint: disable-msg=no-member + + def encode(self, version=None, header=True, framed=False): + if version is not None: + self._version = version + elif self._version is None: + raise ValueError('Version required to encode data') + flexible = self.flexible_version_q(self._version) + encoded = self._struct.encode(self, version=self._version, compact=flexible, tagged=flexible) + if not header and not framed: + return encoded + bits = [encoded] + if header: + bits.insert(0, self.encode_header(flexible=flexible)) + if framed: + bits.insert(0, Int32.encode(sum(map(len, bits)))) + return b''.join(bits) + + @classmethod + def decode(cls, data, version=None, header=True, framed=False): + if not header: + if version is None: + raise ValueError('Version required to decode data') + elif not 0 <= version <= cls.max_version: + raise ValueError('Invalid version %s (max version is %s).' % (version, cls.max_version)) + if isinstance(data, bytes): + data = io.BytesIO(data) + if framed: + size = Int32.decode(data) + if header: + decoded_version = cls.parse_header(data) + if version is not None: + if version > decoded_version: + raise ValueError('Version mismatch: found v%d, expected v%d' % (decoded_version, version)) + version = min(decoded_version, cls.max_version) + flexible = cls.flexible_version_q(version) + return cls._struct.decode(data, version=version, compact=flexible, tagged=flexible) diff --git a/kafka/protocol/new/api_header.py b/kafka/protocol/new/api_header.py index 4acbc8198..cc4e4c710 100644 --- a/kafka/protocol/new/api_header.py +++ b/kafka/protocol/new/api_header.py @@ -1,17 +1,9 @@ +from .api_data import JsonSchemaData from .data_container import DataContainer, SlotsBuilder from .schemas import BaseField, StructField, load_json -class ApiHeaderMeta(SlotsBuilder): - def __new__(metacls, name, bases, attrs, **kw): - if kw.get('init', True): - json = load_json(name) - attrs['_json'] = json - attrs['_struct'] = StructField(json) - return super().__new__(metacls, name, bases, attrs, **kw) - - -class ApiHeader(DataContainer, metaclass=ApiHeaderMeta, init=False): +class ApiHeader(DataContainer, metaclass=JsonSchemaData, init=False): __slots__ = () def __init_subclass__(cls, **kw): diff --git a/kafka/protocol/new/api_message.py b/kafka/protocol/new/api_message.py index 8efb31e2b..7a3169689 100644 --- a/kafka/protocol/new/api_message.py +++ b/kafka/protocol/new/api_message.py @@ -1,8 +1,9 @@ import io import weakref +from .api_data import JsonSchemaData from .api_header import RequestHeader, ResponseHeader, ResponseClassRegistry -from .data_container import DataContainer, SlotsBuilder +from .data_container import DataContainer from .schemas import BaseField, StructField, load_json from .schemas.fields.codecs import Int32 @@ -48,19 +49,7 @@ def __len__(cls): return cls._valid_versions[1] + 1 -class ApiMessageMeta(VersionSubscriptable, SlotsBuilder): - def __new__(metacls, name, bases, attrs, **kw): - # Pass init=False from base classes - if kw.get('init', True): - json = load_json(name) - if 'json_patch' in attrs: - json = attrs['json_patch'].__func__(metacls, json) - attrs['_json'] = json - attrs['_struct'] = StructField(json) - attrs['__doc__'] = json.get('doc') - attrs['__license__'] = json.get('license') - return super().__new__(metacls, name, bases, attrs, **kw) - +class ApiMessageData(VersionSubscriptable, JsonSchemaData): def __init__(cls, name, bases, attrs, **kw): super().__init__(name, bases, attrs, **kw) if kw.get('init', True): @@ -68,12 +57,9 @@ def __init__(cls, name, bases, attrs, **kw): # We'll get the brokers supported versions via ApiVersionsRequest if cls._struct._versions[0] > 0: cls._struct._versions = (0, cls._struct._versions[1]) - # Configure the StructField to use our ApiMessage wrapper - # and not construct a default DataContainer - cls._struct.set_data_class(weakref.proxy(cls)) -class ApiMessage(DataContainer, metaclass=ApiMessageMeta, init=False): +class ApiMessage(DataContainer, metaclass=ApiMessageData, init=False): __slots__ = ('_header') def __init_subclass__(cls, **kw):