From 3d2c3c486f6be5b85a390869148eff5fd723c7dc Mon Sep 17 00:00:00 2001 From: evgeny Date: Thu, 29 Jan 2026 10:51:04 +0000 Subject: [PATCH 1/2] [AIT-316] feat: introduce support for message annotations - Added `RealtimeAnnotations` class to manage annotation creation, deletion, and subscription on realtime channels. - Introduced `Annotation` and `AnnotationAction` types to encapsulate annotation details and actions. - Extended flags to include `ANNOTATION_PUBLISH` and `ANNOTATION_SUBSCRIBE`. - Refactored data encoding logic into `ably.util.encoding`. - Integrated annotation handling into `RealtimeChannel` and `RestChannel`. --- ably/realtime/annotations.py | 239 ++++++++++++ ably/realtime/channel.py | 54 ++- ably/rest/annotations.py | 202 ++++++++++ ably/rest/channel.py | 18 +- ably/types/annotation.py | 226 +++++++++++ ably/types/channelmode.py | 50 +++ ably/types/channeloptions.py | 19 +- ably/types/flags.py | 2 + ably/types/message.py | 48 +-- ably/types/presence.py | 38 +- ably/util/encoding.py | 33 ++ ably/util/helper.py | 10 + .../ably/realtime/realtimeannotations_test.py | 350 ++++++++++++++++++ test/ably/rest/restannotations_test.py | 242 ++++++++++++ uv.lock | 2 +- 15 files changed, 1434 insertions(+), 99 deletions(-) create mode 100644 ably/realtime/annotations.py create mode 100644 ably/rest/annotations.py create mode 100644 ably/types/annotation.py create mode 100644 ably/types/channelmode.py create mode 100644 ably/util/encoding.py create mode 100644 test/ably/realtime/realtimeannotations_test.py create mode 100644 test/ably/rest/restannotations_test.py diff --git a/ably/realtime/annotations.py b/ably/realtime/annotations.py new file mode 100644 index 00000000..96775b2c --- /dev/null +++ b/ably/realtime/annotations.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ably.rest.annotations import RestAnnotations, construct_validate_annotation +from ably.transport.websockettransport import ProtocolMessageAction +from ably.types.annotation import Annotation, AnnotationAction +from ably.types.channelstate import ChannelState +from ably.types.flags import Flag +from ably.util.eventemitter import EventEmitter +from ably.util.exceptions import AblyException +from ably.util.helper import is_callable_or_coroutine + +if TYPE_CHECKING: + from ably.realtime.channel import RealtimeChannel + from ably.realtime.connectionmanager import ConnectionManager + +log = logging.getLogger(__name__) + + +class RealtimeAnnotations: + """ + Provides realtime methods for managing annotations on messages, + including publishing annotations and subscribing to annotation events. + """ + + __connection_manager: ConnectionManager + __channel: RealtimeChannel + + def __init__(self, channel: RealtimeChannel, connection_manager: ConnectionManager): + """ + Initialize RealtimeAnnotations. + + Args: + channel: The Realtime Channel this annotations instance belongs to + """ + self.__channel = channel + self.__connection_manager = connection_manager + self.__subscriptions = EventEmitter() + self.__rest_annotations = RestAnnotations(channel) + + async def publish(self, msg_or_serial, annotation: dict | Annotation, params: dict=None): + """ + Publish an annotation on a message via the realtime connection. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Dict containing annotation properties (type, name, data, etc.) or Annotation object + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails, inputs are invalid, or channel is in unpublishable state + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # Check if channel and connection are in publishable state + self.__channel._throw_if_unpublishable_state() + + log.info( + f'RealtimeAnnotations.publish(), channelName = {self.__channel.name}, ' + f'sending annotation with messageSerial = {annotation.message_serial}, ' + f'type = {annotation.type}' + ) + + # Convert to wire format (array of annotations) + wire_annotation = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) + + # Build protocol message + protocol_message = { + "action": ProtocolMessageAction.ANNOTATION, + "channel": self.__channel.name, + "annotations": [wire_annotation], + } + + if params: + # Stringify boolean params + stringified_params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} + protocol_message["params"] = stringified_params + + # Send via WebSocket + await self.__connection_manager.send_protocol_message(protocol_message) + + async def delete(self, msg_or_serial, annotation: dict | Annotation, params=None, timeout=None): + """ + Delete an annotation on a message. + + This is a convenience method that sets the action to 'annotation.delete' + and calls publish(). + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Dict containing annotation properties or Annotation object + params: Optional dict of query parameters + timeout: Optional timeout (not used for realtime, kept for compatibility) + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + if isinstance(annotation, Annotation): + annotation_values = annotation.as_dict() + else: + annotation_values = annotation.copy() + annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE + return await self.publish(msg_or_serial, annotation_values, params) + + async def subscribe(self, *args): + """ + Subscribe to annotation events on this channel. + + Parameters + ---------- + *args: type, listener + Subscribe type and listener + + arg1(type): str, optional + Subscribe to annotations of the given type + + arg2(listener): callable + Subscribe to all annotations on the channel + + When no type is provided, arg1 is used as the listener. + + Raises + ------ + AblyException + If unable to subscribe due to invalid channel state or missing ANNOTATION_SUBSCRIBE mode + ValueError + If no valid subscribe arguments are passed + """ + # Parse arguments similar to channel.subscribe + if len(args) == 0: + raise ValueError("annotations.subscribe called without arguments") + + if len(args) >= 2 and isinstance(args[0], str): + annotation_type = args[0] + if not args[1]: + raise ValueError("annotations.subscribe called without listener") + if not is_callable_or_coroutine(args[1]): + raise ValueError("subscribe listener must be function or coroutine function") + listener = args[1] + elif is_callable_or_coroutine(args[0]): + listener = args[0] + annotation_type = None + else: + raise ValueError('invalid subscribe arguments') + + # Register subscription + if annotation_type is not None: + self.__subscriptions.on(annotation_type, listener) + else: + self.__subscriptions.on(listener) + + await self.__channel.attach() + + # Check if ANNOTATION_SUBSCRIBE mode is enabled + if self.__channel.state == ChannelState.ATTACHED: + if not Flag.ANNOTATION_SUBSCRIBE in self.__channel.modes: + raise AblyException( + "You are trying to add an annotation listener, but you haven't requested the " + "annotation_subscribe channel mode in ChannelOptions, so this won't do anything " + "(we only deliver annotations to clients who have explicitly requested them)", + 93001, + 400 + ) + + def unsubscribe(self, *args): + """ + Unsubscribe from annotation events on this channel. + + Parameters + ---------- + *args: type, listener + Unsubscribe type and listener + + arg1(type): str, optional + Unsubscribe from annotations of the given type + + arg2(listener): callable + Unsubscribe from all annotations on the channel + + When no type is provided, arg1 is used as the listener. + + Raises + ------ + ValueError + If no valid unsubscribe arguments are passed + """ + if len(args) == 0: + raise ValueError("annotations.unsubscribe called without arguments") + + if len(args) >= 2 and isinstance(args[0], str): + annotation_type = args[0] + listener = args[1] + self.__subscriptions.off(annotation_type, listener) + elif is_callable_or_coroutine(args[0]): + listener = args[0] + self.__subscriptions.off(listener) + else: + raise ValueError('invalid unsubscribe arguments') + + def _process_incoming(self, incoming_annotations): + """ + Process incoming annotations from the server. + + This is called internally when ANNOTATION protocol messages are received. + + Args: + incoming_annotations: List of Annotation objects received from the server + """ + for annotation in incoming_annotations: + # Emit to type-specific listeners and catch-all listeners + annotation_type = annotation.type or '' + self.__subscriptions._emit(annotation_type, annotation) + + async def get(self, msg_or_serial, params=None): + """ + Retrieve annotations for a message with pagination support. + + This delegates to the REST implementation. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + params: Optional dict of query parameters (limit, start, end, direction) + + Returns: + PaginatedResult: A paginated result containing Annotation objects + + Raises: + AblyException: If the request fails or serial is invalid + """ + # Delegate to REST implementation + return await self.__rest_annotations.get(msg_or_serial, params) diff --git a/ably/realtime/channel.py b/ably/realtime/channel.py index e0fd6251..4830132a 100644 --- a/ably/realtime/channel.py +++ b/ably/realtime/channel.py @@ -4,10 +4,13 @@ import logging from typing import TYPE_CHECKING +from ably.realtime.annotations import RealtimeAnnotations from ably.realtime.connection import ConnectionState +from ably.realtime.presence import RealtimePresence from ably.rest.channel import Channel from ably.rest.channel import Channels as RestChannels from ably.transport.websockettransport import ProtocolMessageAction +from ably.types.annotation import Annotation from ably.types.channeloptions import ChannelOptions from ably.types.channelstate import ChannelState, ChannelStateChange from ably.types.flags import Flag, has_flag @@ -18,6 +21,7 @@ from ably.util.eventemitter import EventEmitter from ably.util.exceptions import AblyException, IncompatibleClientIdException from ably.util.helper import Timer, is_callable_or_coroutine, validate_message_size +from ably.types.channelmode import ChannelMode, decode_channel_mode, encode_channel_mode if TYPE_CHECKING: from ably.realtime.realtime import AblyRealtime @@ -64,6 +68,7 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp self.__error_reason: AblyException | None = None self.__channel_options = channel_options or ChannelOptions() self.__params: dict[str, str] | None = None + self.__modes: list[ChannelMode] = list() # Channel mode flags from ATTACHED message # Delta-specific fields for RTL19/RTL20 compliance vcdiff_decoder = self.__realtime.options.vcdiff_decoder if self.__realtime.options.vcdiff_decoder else None @@ -74,12 +79,15 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp # will be disrupted if the user called .off() to remove all listeners self.__internal_state_emitter = EventEmitter() + # Pass channel options as dictionary to parent Channel class + Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + # Initialize presence for this channel - from ably.realtime.presence import RealtimePresence + self.__presence = RealtimePresence(self) - # Pass channel options as dictionary to parent Channel class - Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + # Initialize realtime annotations for this channel (override REST annotations) + self._Channel__annotations = RealtimeAnnotations(self, realtime.connection.connection_manager) async def set_options(self, channel_options: ChannelOptions) -> None: """Set channel options""" @@ -149,8 +157,10 @@ def _attach_impl(self): "channel": self.name, } - if self.__attach_resume: - attach_msg["flags"] = Flag.ATTACH_RESUME + flags = self._encode_flags() + + if flags: + attach_msg["flags"] = flags if self.__channel_serial: attach_msg["channelSerial"] = self.__channel_serial @@ -491,8 +501,8 @@ async def _send_update( if not message.serial: raise AblyException( "Message serial is required for update/delete/append operations", - 400, - 40003 + status_code=400, + code=40003, ) # Check connection and channel state @@ -702,6 +712,8 @@ def _on_message(self, proto_msg: dict) -> None: resumed = has_flag(flags, Flag.RESUMED) # RTP1: Check for HAS_PRESENCE flag has_presence = has_flag(flags, Flag.HAS_PRESENCE) + # Store channel attach flags + self.__modes = decode_channel_mode(flags) # RTL12 if self.state == ChannelState.ATTACHED: @@ -744,6 +756,15 @@ def _on_message(self, proto_msg: dict) -> None: decoded_presence = PresenceMessage.from_encoded_array(presence_messages, cipher=self.cipher) sync_channel_serial = proto_msg.get('channelSerial') self.__presence.set_presence(decoded_presence, is_sync=True, sync_channel_serial=sync_channel_serial) + elif action == ProtocolMessageAction.ANNOTATION: + # Handle ANNOTATION messages + annotation_data = proto_msg.get('annotations', []) + try: + annotations = Annotation.from_encoded_array(annotation_data, cipher=self.cipher) + # Process annotations through the annotations handler + self.annotations._process_incoming(annotations) + except Exception as e: + log.error(f"Annotation processing error {e}. Skip annotations {annotation_data}") elif action == ProtocolMessageAction.ERROR: error = AblyException.from_dict(proto_msg.get('error')) self._notify_state(ChannelState.FAILED, reason=error) @@ -890,6 +911,11 @@ def presence(self): """Get the RealtimePresence object for this channel""" return self.__presence + @property + def modes(self): + """Get the list of channel modes""" + return self.__modes + def _start_decode_failure_recovery(self, error: AblyException) -> None: """Start RTL18 decode failure recovery procedure""" @@ -908,6 +934,20 @@ def _start_decode_failure_recovery(self, error: AblyException) -> None: self._notify_state(ChannelState.ATTACHING, reason=error) self._check_pending_state() + def _encode_flags(self) -> int | None: + if not self.__channel_options.modes and not self.__attach_resume: + return None + + flags = 0 + + if self.__attach_resume: + flags |= Flag.ATTACH_RESUME + + if self.__channel_options.modes: + flags |= encode_channel_mode(self.__channel_options.modes) + + return flags + class Channels(RestChannels): """Creates and destroys RealtimeChannel objects. diff --git a/ably/rest/annotations.py b/ably/rest/annotations.py new file mode 100644 index 00000000..7f20fb3d --- /dev/null +++ b/ably/rest/annotations.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import json +import logging +from urllib import parse + +import msgpack + +from ably.http.paginatedresult import PaginatedResult, format_params +from ably.types.annotation import ( + Annotation, + make_annotation_response_handler, +) +from ably.types.message import Message +from ably.util.exceptions import AblyException + +log = logging.getLogger(__name__) + + +def serial_from_msg_or_serial(msg_or_serial): + """ + Extract the message serial from either a string serial or a Message object. + + Args: + msg_or_serial: Either a string serial or a Message object with a serial property + + Returns: + str: The message serial + + Raises: + AblyException: If the input is invalid or serial is missing + """ + if isinstance(msg_or_serial, str): + message_serial = msg_or_serial + elif isinstance(msg_or_serial, Message): + message_serial = msg_or_serial.serial + else: + message_serial = None + + if not message_serial or not isinstance(message_serial, str): + raise AblyException( + message='First argument of annotations.publish() must be either a Message ' + '(or at least an object with a string `serial` property) or a message serial (string)', + status_code=400, + code=40003, + ) + + return message_serial + + +def construct_validate_annotation(msg_or_serial, annotation: dict | Annotation): + """ + Construct and validate an Annotation from input values. + + Args: + msg_or_serial: Either a string serial or a Message object + annotation: Dict of annotation properties or Annotation object + + Returns: + Annotation: The constructed annotation + + Raises: + AblyException: If the inputs are invalid + """ + message_serial = serial_from_msg_or_serial(msg_or_serial) + + if not annotation or (not isinstance(annotation, dict) and not isinstance(annotation, Annotation)): + raise AblyException( + message='Second argument of annotations.publish() must be a dict or Annotation ' + '(the intended annotation to publish)', + status_code=400, + code=40003, + ) + elif isinstance(annotation, Annotation): + annotation_values = annotation.as_dict() + else: + annotation_values = annotation + + annotation_values['message_serial'] = message_serial + + return Annotation.from_values(annotation_values) + + +class RestAnnotations: + """ + Provides REST API methods for managing annotations on messages. + """ + + def __init__(self, channel): + """ + Initialize RestAnnotations. + + Args: + channel: The REST Channel this annotations instance belongs to + """ + self.__channel = channel + + def __base_path_for_serial(self, serial): + """ + Build the base API path for a message serial's annotations. + + Args: + serial: The message serial + + Returns: + str: The API path + """ + channel_path = '/channels/{}/'.format(parse.quote_plus(self.__channel.name, safe=':')) + return channel_path + 'messages/' + parse.quote_plus(serial, safe=':') + '/annotations' + + async def publish(self, msg_or_serial, annotation_values, params=None, timeout=None): + """ + Publish an annotation on a message. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation_values: Dict containing annotation properties (type, name, data, etc.) + params: Optional dict of query parameters + timeout: Optional timeout for the HTTP request + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + annotation = construct_validate_annotation(msg_or_serial, annotation_values) + + # Convert to wire format + request_body = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) + + # Wrap in array as API expects array of annotations + request_body = [request_body] + + # Encode based on protocol + if not self.__channel.ably.options.use_binary_protocol: + request_body = json.dumps(request_body, separators=(',', ':')) + else: + request_body = msgpack.packb(request_body, use_bin_type=True) + + # Build path + path = self.__base_path_for_serial(annotation.message_serial) + if params: + params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} + path += '?' + parse.urlencode(params) + + # Send request + await self.__channel.ably.http.post(path, body=request_body, timeout=timeout) + + async def delete(self, msg_or_serial, annotation_values, params=None, timeout=None): + """ + Delete an annotation on a message. + + This is a convenience method that sets the action to 'annotation.delete' + and calls publish(). + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation_values: Dict containing annotation properties + params: Optional dict of query parameters + timeout: Optional timeout for the HTTP request + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + # Set action to delete + annotation_values = annotation_values.copy() + annotation_values['action'] = 'annotation.delete' + return await self.publish(msg_or_serial, annotation_values, params, timeout) + + async def get(self, msg_or_serial, params=None): + """ + Retrieve annotations for a message with pagination support. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + params: Optional dict of query parameters (limit, start, end, direction) + + Returns: + PaginatedResult: A paginated result containing Annotation objects + + Raises: + AblyException: If the request fails or serial is invalid + """ + message_serial = serial_from_msg_or_serial(msg_or_serial) + + # Build path + params_str = format_params({}, **params) if params else '' + path = self.__base_path_for_serial(message_serial) + params_str + + # Create annotation response handler + annotation_handler = make_annotation_response_handler(cipher=None) + + # Return paginated result + return await PaginatedResult.paginated_query( + self.__channel.ably.http, + url=path, + response_processor=annotation_handler + ) diff --git a/ably/rest/channel.py b/ably/rest/channel.py index 2c1c0246..f5a3e894 100644 --- a/ably/rest/channel.py +++ b/ably/rest/channel.py @@ -9,6 +9,7 @@ import msgpack from ably.http.paginatedresult import PaginatedResult, format_params +from ably.rest.annotations import RestAnnotations from ably.types.channeldetails import ChannelDetails from ably.types.message import ( Message, @@ -37,6 +38,7 @@ def __init__(self, ably, name, options): self.__cipher = None self.options = options self.__presence = Presence(self) + self.__annotations = RestAnnotations(self) @catch_all async def history(self, direction=None, limit: int = None, start=None, end=None): @@ -169,8 +171,8 @@ async def _send_update( if not message.serial: raise AblyException( "Message serial is required for update/delete/append operations", - 400, - 40003 + status_code=400, + code=40003, ) if not operation: @@ -282,8 +284,8 @@ async def get_message(self, serial_or_message, timeout=None): raise AblyException( 'This message lacks a serial. Make sure you have enabled "Message annotations, ' 'updates, and deletes" in channel settings on your dashboard.', - 400, - 40003 + status_code=400, + code=40003, ) # Build the path @@ -321,8 +323,8 @@ async def get_message_versions(self, serial_or_message, params=None): raise AblyException( 'This message lacks a serial. Make sure you have enabled "Message annotations, ' 'updates, and deletes" in channel settings on your dashboard.', - 400, - 40003 + status_code=400, + code=40003, ) # Build the path @@ -363,6 +365,10 @@ def options(self): def presence(self): return self.__presence + @property + def annotations(self): + return self.__annotations + @options.setter def options(self, options): self.__options = options diff --git a/ably/types/annotation.py b/ably/types/annotation.py new file mode 100644 index 00000000..a3aded28 --- /dev/null +++ b/ably/types/annotation.py @@ -0,0 +1,226 @@ +import logging +from enum import IntEnum + +from ably.types.mixins import EncodeDataMixin +from ably.util.encoding import encode_data +from ably.util.helper import to_text + +log = logging.getLogger(__name__) + + +class AnnotationAction(IntEnum): + """Annotation action types""" + ANNOTATION_CREATE = 0 + ANNOTATION_DELETE = 1 + + +class Annotation(EncodeDataMixin): + """ + Represents an annotation on a message, such as a reaction or other metadata. + + Annotations are not encrypted as they need to be parsed by the server for summarization. + """ + + def __init__(self, + action=None, + serial=None, + message_serial=None, + type=None, + name=None, + count=None, + data=None, + encoding='', + client_id=None, + timestamp=None, + extras=None): + """ + Args: + action: The action type - either 'annotation.create' or 'annotation.delete' + serial: A unique identifier for the annotation + message_serial: The serial of the message this annotation is for + type: The type of annotation (e.g., 'reaction', 'like', etc.) + name: The name/value of the annotation (e.g., specific emoji) + count: Count associated with the annotation + data: Optional data payload for the annotation + encoding: Encoding format for the data + client_id: The client ID that created this annotation + timestamp: Timestamp of the annotation + extras: Additional metadata + """ + super().__init__(encoding) + + self.__serial = to_text(serial) if serial is not None else None + self.__message_serial = to_text(message_serial) if message_serial is not None else None + self.__type = to_text(type) if type is not None else None + self.__name = to_text(name) if name is not None else None + self.__action = action if action is not None else AnnotationAction.ANNOTATION_CREATE + self.__count = count + self.__data = data + self.__client_id = to_text(client_id) if client_id is not None else None + self.__timestamp = timestamp + self.__extras = extras + + def __eq__(self, other): + if isinstance(other, Annotation): + return (self.serial == other.serial + and self.message_serial == other.message_serial + and self.type == other.type + and self.name == other.name + and self.action == other.action) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Annotation): + result = self.__eq__(other) + if result != NotImplemented: + return not result + return NotImplemented + + @property + def action(self): + return self.__action + + @property + def serial(self): + return self.__serial + + @property + def message_serial(self): + return self.__message_serial + + @property + def type(self): + return self.__type + + @property + def name(self): + return self.__name + + @property + def count(self): + return self.__count + + @property + def data(self): + return self.__data + + @property + def client_id(self): + return self.__client_id + + @property + def timestamp(self): + return self.__timestamp + + @property + def extras(self): + return self.__extras + + def as_dict(self, binary=False): + """ + Convert annotation to dictionary format for API communication. + + Note: Annotations are not encrypted as they need to be parsed by the server. + """ + # Encode data + encoded = encode_data(self.data, self._encoding_array, binary) + + request_body = { + 'action': int(self.action) if self.action is not None else None, + 'serial': self.serial, + 'messageSerial': self.message_serial, + 'type': self.type, # Annotation type (not data type) + 'name': self.name, + 'count': self.count, + 'data': encoded.get('data'), + 'encoding': encoded.get('encoding', ''), + 'dataType': encoded.get('type'), # Data type (not annotation type) + 'clientId': self.client_id or None, + 'timestamp': self.timestamp or None, + 'extras': self.extras, + } + + # None values aren't included + request_body = {k: v for k, v in request_body.items() if v is not None} + + return request_body + + @staticmethod + def from_encoded(obj, cipher=None, context=None): + """ + Create an Annotation from an encoded object received from the API. + + Note: cipher parameter is accepted for consistency but annotations are not encrypted. + """ + action = obj.get('action') + serial = obj.get('serial') + message_serial = obj.get('messageSerial') + type_val = obj.get('type') + name = obj.get('name') + count = obj.get('count') + data = obj.get('data') + encoding = obj.get('encoding', '') + client_id = obj.get('clientId') + timestamp = obj.get('timestamp') + extras = obj.get('extras', None) + + # Decode data if present + decoded_data = Annotation.decode(data, encoding, cipher, context) if data is not None else {} + + # Convert action from int to enum + if action is not None: + try: + action = AnnotationAction(action) + except ValueError: + # If it's not a valid action value, store as None + action = None + else: + action = None + + return Annotation( + action=action, + serial=serial, + message_serial=message_serial, + type=type_val, + name=name, + count=count, + client_id=client_id, + timestamp=timestamp, + extras=extras, + **decoded_data + ) + + @staticmethod + def from_encoded_array(obj_array, cipher=None, context=None): + """Create an array of Annotations from encoded objects""" + return [Annotation.from_encoded(obj, cipher, context) for obj in obj_array] + + @staticmethod + def from_values(values): + """Create an Annotation from a dict of values""" + return Annotation(**values) + + def __str__(self): + return ( + f"Annotation(action={self.action}, messageSerial={self.message_serial}, " + f"type={self.type}, name={self.name})" + ) + + def __repr__(self): + return self.__str__() + + +def make_annotation_response_handler(cipher=None): + """Create a response handler for annotation API responses""" + def annotation_response_handler(response): + annotations = response.to_native() + return Annotation.from_encoded_array(annotations, cipher=cipher) + return annotation_response_handler + + +def make_single_annotation_response_handler(cipher=None): + """Create a response handler for single annotation API responses""" + def single_annotation_response_handler(response): + annotation = response.to_native() + return Annotation.from_encoded(annotation, cipher=cipher) + return single_annotation_response_handler diff --git a/ably/types/channelmode.py b/ably/types/channelmode.py new file mode 100644 index 00000000..6ba95f08 --- /dev/null +++ b/ably/types/channelmode.py @@ -0,0 +1,50 @@ +from enum import Enum + +from ably.types.flags import Flag + + +class ChannelMode(int, Enum): + PRESENCE = Flag.PRESENCE + PUBLISH = Flag.PUBLISH + SUBSCRIBE = Flag.SUBSCRIBE + PRESENCE_SUBSCRIBE = Flag.PRESENCE_SUBSCRIBE + ANNOTATION_PUBLISH = Flag.ANNOTATION_PUBLISH + ANNOTATION_SUBSCRIBE = Flag.ANNOTATION_SUBSCRIBE + + +def encode_channel_mode(modes: list[ChannelMode]) -> int: + """ + Encode a list of ChannelMode values into a bitmask. + + Args: + modes: List of ChannelMode values to encode + + Returns: + Integer bitmask with the corresponding flags set + """ + flags = 0 + + for mode in modes: + flags |= mode.value + + return flags + + +def decode_channel_mode(flags: int) -> list[ChannelMode]: + """ + Decode channel mode flags from a bitmask into a list of ChannelMode values. + + Args: + flags: Integer bitmask containing channel mode flags + + Returns: + List of ChannelMode values that are set in the flags + """ + modes = [] + + # Check each channel mode flag + for mode in ChannelMode: + if flags & mode.value: + modes.append(mode) + + return modes diff --git a/ably/types/channeloptions.py b/ably/types/channeloptions.py index 48e34dfe..b745a3e8 100644 --- a/ably/types/channeloptions.py +++ b/ably/types/channeloptions.py @@ -4,6 +4,7 @@ from ably.util.crypto import CipherParams from ably.util.exceptions import AblyException +from ably.types.channelmode import ChannelMode class ChannelOptions: @@ -17,36 +18,43 @@ class ChannelOptions: Channel parameters that configure the behavior of the channel. """ - def __init__(self, cipher: CipherParams | None = None, params: dict | None = None): + def __init__(self, cipher: CipherParams | None = None, params: dict | None = None, modes: list[ChannelMode] | None = None): self.__cipher = cipher self.__params = params + self.__modes = modes # Validate params if self.__params and not isinstance(self.__params, dict): raise AblyException("params must be a dictionary", 40000, 400) @property - def cipher(self): + def cipher(self) -> CipherParams | None: """Get cipher configuration""" return self.__cipher @property - def params(self) -> dict[str, str]: + def params(self) -> dict[str, str] | None: """Get channel parameters""" return self.__params + @property + def modes(self) -> list[ChannelMode] | None: + """Get channel parameters""" + return self.__modes + def __eq__(self, other): """Check equality with another ChannelOptions instance""" if not isinstance(other, ChannelOptions): return False return (self.__cipher == other.__cipher and - self.__params == other.__params) + self.__params == other.__params and self.__modes == other.__modes) def __hash__(self): """Make ChannelOptions hashable""" return hash(( self.__cipher, tuple(sorted(self.__params.items())) if self.__params else None, + tuple(sorted(self.__modes)) if self.__modes else None )) def to_dict(self) -> dict[str, Any]: @@ -56,6 +64,8 @@ def to_dict(self) -> dict[str, Any]: result['cipher'] = self.__cipher if self.__params: result['params'] = self.__params + if self.__modes: + result['modes'] = self.__modes return result @classmethod @@ -67,4 +77,5 @@ def from_dict(cls, options_dict: dict[str, Any]) -> ChannelOptions: return cls( cipher=options_dict.get('cipher'), params=options_dict.get('params'), + modes=options_dict.get('modes'), ) diff --git a/ably/types/flags.py b/ably/types/flags.py index 1666434c..86666019 100644 --- a/ably/types/flags.py +++ b/ably/types/flags.py @@ -13,6 +13,8 @@ class Flag(int, Enum): PUBLISH = 1 << 17 SUBSCRIBE = 1 << 18 PRESENCE_SUBSCRIBE = 1 << 19 + ANNOTATION_PUBLISH = 1 << 21 + ANNOTATION_SUBSCRIBE = 1 << 22 def has_flag(message_flags: int, flag: Flag): diff --git a/ably/types/message.py b/ably/types/message.py index 11caba57..81043608 100644 --- a/ably/types/message.py +++ b/ably/types/message.py @@ -1,27 +1,16 @@ -import base64 -import json import logging from enum import IntEnum from ably.types.mixins import DeltaExtras, EncodeDataMixin from ably.types.typedbuffer import TypedBuffer from ably.util.crypto import CipherData +from ably.util.encoding import encode_data from ably.util.exceptions import AblyException +from ably.util.helper import to_text log = logging.getLogger(__name__) -def to_text(value): - if value is None: - return value - elif isinstance(value, str): - return value - elif isinstance(value, bytes): - return value.decode() - else: - raise TypeError(f"expected string or bytes, not {type(value)}") - - class MessageVersion: """ Contains the details regarding the current version of the message - including when it was updated and by whom. @@ -234,38 +223,9 @@ def decrypt(self, channel_cipher): self.__data = decrypted_data def as_dict(self, binary=False): - data = self.data - data_type = None - encoding = self._encoding_array[:] - - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) - request_body = { 'name': self.name, - 'data': data, 'timestamp': self.timestamp or None, - 'type': data_type or None, 'clientId': self.client_id or None, 'id': self.id or None, 'connectionId': self.connection_id or None, @@ -274,11 +234,9 @@ def as_dict(self, binary=False): 'version': self.version.as_dict() if self.version else None, 'serial': self.serial, 'action': int(self.action) if self.action is not None else None, + **encode_data(self.data, self._encoding_array, binary), } - if encoding: - request_body['encoding'] = '/'.join(encoding).strip('/') - # None values aren't included request_body = {k: v for k, v in request_body.items() if v is not None} diff --git a/ably/types/presence.py b/ably/types/presence.py index 723ceacc..7d1a3c05 100644 --- a/ably/types/presence.py +++ b/ably/types/presence.py @@ -1,5 +1,3 @@ -import base64 -import json from datetime import datetime, timedelta from urllib import parse @@ -7,7 +5,7 @@ from ably.types.mixins import EncodeDataMixin from ably.types.typedbuffer import TypedBuffer from ably.util.crypto import CipherData -from ably.util.exceptions import AblyException +from ably.util.encoding import encode_data def _ms_since_epoch(dt): @@ -151,36 +149,10 @@ def to_encoded(self, binary=False): Handles proper encoding of data including JSON serialization, base64 encoding for binary data, and encryption support. """ - data = self.data - data_type = None - encoding = self._encoding_array[:] - - # Handle different data types and build encoding string - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) result = { 'action': self.action, + **encode_data(self.data, self._encoding_array, binary), } if self.id: @@ -189,12 +161,6 @@ def to_encoded(self, binary=False): result['clientId'] = self.client_id if self.connection_id: result['connectionId'] = self.connection_id - if data is not None: - result['data'] = data - if data_type: - result['type'] = data_type - if encoding: - result['encoding'] = '/'.join(encoding).strip('/') if self.extras: result['extras'] = self.extras if self.timestamp: diff --git a/ably/util/encoding.py b/ably/util/encoding.py new file mode 100644 index 00000000..b0af9620 --- /dev/null +++ b/ably/util/encoding.py @@ -0,0 +1,33 @@ +import base64 +import json +from typing import Any + +from ably.util.crypto import CipherData + + +def encode_data(data: Any, encoding_array: list, binary: bool = False): + encoding = encoding_array[:] + + if isinstance(data, (dict, list)): + encoding.append('json') + data = json.dumps(data) + data = str(data) + elif isinstance(data, str) and not binary: + pass + elif not binary and isinstance(data, (bytearray, bytes)): + data = base64.b64encode(data).decode('ascii') + encoding.append('base64') + elif isinstance(data, CipherData): + encoding.append(data.encoding_str) + if not binary: + data = base64.b64encode(data.buffer).decode('ascii') + encoding.append('base64') + else: + data = data.buffer + elif binary and isinstance(data, bytearray): + data = bytes(data) + + return { + 'data': data, + 'encoding': '/'.join(encoding).strip('/') + } diff --git a/ably/util/helper.py b/ably/util/helper.py index 53226f27..a35ebe6e 100644 --- a/ably/util/helper.py +++ b/ably/util/helper.py @@ -98,3 +98,13 @@ def validate_message_size(encoded_messages: list, use_binary_protocol: bool, max 400, 40009, ) + +def to_text(value): + if value is None: + return value + elif isinstance(value, str): + return value + elif isinstance(value, bytes): + return value.decode() + else: + raise TypeError(f"expected string or bytes, not {type(value)}") diff --git a/test/ably/realtime/realtimeannotations_test.py b/test/ably/realtime/realtimeannotations_test.py new file mode 100644 index 00000000..5e502380 --- /dev/null +++ b/test/ably/realtime/realtimeannotations_test.py @@ -0,0 +1,350 @@ +import asyncio +import logging + +import pytest + +from ably import AblyException +from ably.types.annotation import AnnotationAction +from ably.types.channeloptions import ChannelOptions +from ably.types.message import MessageAction +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, assert_waiter +from ably.types.channelmode import ChannelMode + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRealtimeAnnotations(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + self.ably = await TestApp.get_ably_realtime( + use_binary_protocol=True if transport == 'msgpack' else False, + ) + self.rest = await TestApp.get_ably_rest( + use_binary_protocol=True if transport == 'msgpack' else False, + ) + + async def test_publish_and_subscribe_annotations(self): + """Test publishing and subscribing to annotations (matches JS test)""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.ably.channels.get( + self.get_channel_name('mutable:publish_subscribe_annotation'), + channel_options + ) + rest_channel = self.rest.channels[channel.name] + await channel.attach() + + # Setup annotation listener + annotation_future = asyncio.Future() + + async def on_annotation(annotation): + if not annotation_future.done(): + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish a message + publish_result = await channel.publish('message', 'foobar') + + # Reset for next message (summary) + message_summary = asyncio.Future() + + def on_message(msg): + if not message_summary.done(): + message_summary.set_result(msg) + + await channel.subscribe('message', on_message) + + # Publish annotation using realtime + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + # Wait for annotation + annotation = await annotation_future + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + assert annotation.message_serial == publish_result.serials[0] + assert annotation.type == 'reaction:multiple.v1' + assert annotation.name == '👍' + assert annotation.serial > annotation.message_serial + + # Wait for summary message + # summary = await message_summary + # assert summary.action == MessageAction.META + # assert summary.serial == publish_result.serials[0] + # + # # Try again but with REST publish + # annotation_future2 = asyncio.Future() + # + # async def on_annotation2(annotation): + # if not annotation_future2.done(): + # annotation_future2.set_result(annotation) + # + # await channel.annotations.subscribe(on_annotation2) + # + # await rest_channel.annotations.publish(publish_result.serials[0], { + # 'type': 'reaction:multiple.v1', + # 'name': '😕' + # }) + # + # annotation = await annotation_future2 + # assert annotation.action == AnnotationAction.ANNOTATION_CREATE + # assert annotation.message_serial == publish_result.serials[0] + # assert annotation.type == 'reaction:multiple.v1' + # assert annotation.name == '😕' + # assert annotation.serial > annotation.message_serial + + async def test_get_all_annotations_for_a_message(self): + """Test retrieving all annotations with pagination (matches JS test)""" + channel_options = ChannelOptions(params={ + 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' + }) + channel = self.ably.channels.get( + self.get_channel_name('mutable:get_all_annotations_for_a_message'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + # Publish a message + await channel.publish('message', 'foobar') + message = await message_future + + # Publish multiple annotations + emojis = ['👍', '😕', '👎', '👍👍', '😕😕', '👎👎'] + for emoji in emojis: + await channel.annotations.publish(message.serial, { + 'type': 'reaction:multiple.v1', + 'name': emoji + }) + + # Wait for all annotations to appear + annotations = [] + + async def check_annotations(): + nonlocal annotations + res = await channel.annotations.get(message.serial, {}) + annotations = res.items + return len(annotations) == 6 + + await assert_waiter(check_annotations, timeout=10) + + # Verify annotations + assert annotations[0].action == AnnotationAction.ANNOTATION_CREATE + assert annotations[0].message_serial == message.serial + assert annotations[0].type == 'reaction:multiple.v1' + assert annotations[0].name == '👍' + assert annotations[1].name == '😕' + assert annotations[2].name == '👎' + assert annotations[1].serial > annotations[0].serial + assert annotations[2].serial > annotations[1].serial + + # Test pagination + res = await channel.annotations.get(message.serial, {'limit': 2}) + assert len(res.items) == 2 + assert [a.name for a in res.items] == ['👍', '😕'] + assert res.has_next() + + res = await res.next() + assert res is not None + assert len(res.items) == 2 + assert [a.name for a in res.items] == ['👎', '👍👍'] + assert res.has_next() + + res = await res.next() + assert res is not None + assert len(res.items) == 2 + assert [a.name for a in res.items] == ['😕😕', '👎👎'] + assert not res.has_next() + + async def test_subscribe_by_annotation_type(self): + """Test subscribing to specific annotation types""" + channel_options = ChannelOptions(params={ + 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' + }) + channel = self.ably.channels.get( + self.get_channel_name('mutable:subscribe_by_type'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + # Subscribe to specific annotation type + reaction_future = asyncio.Future() + + async def on_reaction(annotation): + if not reaction_future.done(): + reaction_future.set_result(annotation) + + await channel.annotations.subscribe('reaction:multiple.v1', on_reaction) + + # Publish message and annotation + await channel.publish('message', 'test') + message = await message_future + + # Temporary anti-flake measure (matches JS test) + await asyncio.sleep(1) + + await channel.annotations.publish(message.serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + # Should receive the annotation + annotation = await reaction_future + assert annotation.type == 'reaction:multiple.v1' + assert annotation.name == '👍' + + async def test_unsubscribe_annotations(self): + """Test unsubscribing from annotations""" + channel_options = ChannelOptions(params={ + 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' + }) + channel = self.ably.channels.get( + self.get_channel_name('mutable:unsubscribe_annotations'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + annotations_received = [] + + async def on_annotation(annotation): + annotations_received.append(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish message and first annotation + await channel.publish('message', 'test') + message = await message_future + + # Temporary anti-flake measure (matches JS test) + await asyncio.sleep(1) + + await channel.annotations.publish(message.serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + # Wait for first annotation + assert len(annotations_received) == 1 + + # Unsubscribe + channel.annotations.unsubscribe(on_annotation) + + # Publish another annotation + await channel.annotations.publish(message.serial, { + 'type': 'reaction:multiple.v1', + 'name': '😕' + }) + + # Wait and verify we didn't receive it + assert len(annotations_received) == 1 + + async def test_delete_annotation(self): + """Test deleting annotations""" + channel_options = ChannelOptions(params={ + 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' + }) + channel = self.ably.channels.get( + self.get_channel_name('mutable:delete_annotation'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + annotations_received = [] + + async def on_annotation(annotation): + annotations_received.append(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish message and annotation + await channel.publish('message', 'test') + message = await message_future + + # Temporary anti-flake measure (matches JS test) + await asyncio.sleep(1) + + await channel.annotations.publish(message.serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + # Wait for create annotation + assert len(annotations_received) == 1 + assert annotations_received[0].action == AnnotationAction.ANNOTATION_CREATE + + # Delete the annotation + await channel.annotations.delete(message.serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + # Wait for delete annotation + assert len(annotations_received) == 2 + assert annotations_received[1].action == AnnotationAction.ANNOTATION_DELETE + + async def test_subscribe_without_annotation_mode_fails(self): + """Test that subscribing without annotation_subscribe mode raises an error""" + # Create channel without annotation_subscribe mode + channel_options = ChannelOptions(params={ + 'modes': 'publish,subscribe' + }) + channel = self.ably.channels.get( + self.get_channel_name('mutable:no_annotation_mode'), + channel_options + ) + await channel.attach() + + async def on_annotation(annotation): + pass + + # Should raise error about missing annotation_subscribe mode + with pytest.raises(AblyException) as exc_info: + await channel.annotations.subscribe(on_annotation) + + assert exc_info.value.status_code == 400 + assert 'annotation_subscribe' in str(exc_info.value).lower() diff --git a/test/ably/rest/restannotations_test.py b/test/ably/rest/restannotations_test.py new file mode 100644 index 00000000..6756c7ec --- /dev/null +++ b/test/ably/rest/restannotations_test.py @@ -0,0 +1,242 @@ +import logging + +import pytest + +from ably import AblyException +from ably.types.message import Message +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, assert_waiter + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRestAnnotations(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + self.ably = await TestApp.get_ably_rest( + use_binary_protocol=True if transport == 'msgpack' else False, + ) + + async def test_publish_annotation_success(self): + """Test successfully publishing an annotation on a message""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_publish_test')] + + # First publish a message + result = await channel.publish('test-event', 'test data') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # Publish an annotation + await channel.annotations.publish(serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + annotations_result = None + + # Wait for annotations to appear + async def check_annotations(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 1 + + await assert_waiter(check_annotations, timeout=10) + + # Get annotations to verify + annotations = annotations_result.items + assert len(annotations) >= 1 + assert annotations[0].message_serial == serial + assert annotations[0].type == 'reaction:multiple.v1' + assert annotations[0].name == '👍' + + async def test_publish_annotation_with_message_object(self): + """Test publishing an annotation using a Message object""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_publish_msg_obj')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Create a message object + message = Message(serial=serial) + + # Publish annotation with message object + await channel.annotations.publish(message, { + 'type': 'reaction:multiple.v1', + 'name': '😕' + }) + + annotations_result = None + + # Wait for annotations to appear + async def check_annotations(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 1 + + await assert_waiter(check_annotations, timeout=10) + + # Verify + annotations_result = await channel.annotations.get(serial) + annotations = annotations_result.items + assert len(annotations) >= 1 + assert annotations[0].name == '😕' + + async def test_publish_annotation_without_serial_fails(self): + """Test that publishing without a serial raises an exception""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_no_serial')] + + with pytest.raises(AblyException) as exc_info: + await channel.annotations.publish(None, {'type': 'reaction', 'name': '👍'}) + + assert exc_info.value.status_code == 400 + assert exc_info.value.code == 40003 + + async def test_delete_annotation_success(self): + """Test successfully deleting an annotation""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_delete_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish an annotation + await channel.annotations.publish(serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + annotations_result = None + + # Wait for annotation to appear + async def check_annotation(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) >= 1 + + await assert_waiter(check_annotation, timeout=10) + + # Delete the annotation + await channel.annotations.delete(serial, { + 'type': 'reaction:multiple.v1', + 'name': '👍' + }) + + # Wait for annotation to appear + async def check_deleted_annotation(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 0 + + await assert_waiter(check_deleted_annotation, timeout=10) + + async def test_get_annotations_pagination(self): + """Test retrieving annotations with pagination""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_pagination_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish multiple annotations + emojis = ['👍', '😕', '👎', '👍👍', '😕😕', '👎👎'] + for emoji in emojis: + await channel.annotations.publish(serial, { + 'type': 'reaction:multiple.v1', + 'name': emoji + }) + + # Wait for annotations to appear + async def check_annotations(): + res = await channel.annotations.get(serial) + return len(res.items) == 6 + + await assert_waiter(check_annotations, timeout=10) + + # Test pagination with limit + result = await channel.annotations.get(serial, {'limit': 2}) + assert len(result.items) == 2 + assert result.items[0].name == '👍' + assert result.items[1].name == '😕' + assert result.has_next() + + # Get next page + result = await result.next() + assert result is not None + assert len(result.items) == 2 + assert result.items[0].name == '👎' + assert result.items[1].name == '👍👍' + assert result.has_next() + + # Get last page + result = await result.next() + assert result is not None + assert len(result.items) == 2 + assert result.items[0].name == '😕😕' + assert result.items[1].name == '👎👎' + assert not result.has_next() + + async def test_get_all_annotations(self): + """Test retrieving all annotations for a message""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_get_all_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish annotations + await channel.annotations.publish(serial, {'type': 'reaction:multiple.v1', 'name': '👍'}) + await channel.annotations.publish(serial, {'type': 'reaction:multiple.v1', 'name': '😕'}) + await channel.annotations.publish(serial, {'type': 'reaction:multiple.v1', 'name': '👎'}) + + # Wait and get all annotations + async def check_annotations(): + res = await channel.annotations.get(serial) + return len(res.items) >= 3 + + await assert_waiter(check_annotations, timeout=10) + + annotations_result = await channel.annotations.get(serial) + annotations = annotations_result.items + assert len(annotations) >= 3 + assert annotations[0].type == 'reaction:multiple.v1' + assert annotations[0].message_serial == serial + # Verify serials are in order + if len(annotations) > 1: + assert annotations[1].serial > annotations[0].serial + if len(annotations) > 2: + assert annotations[2].serial > annotations[1].serial + + async def test_annotation_properties(self): + """Test that annotation properties are correctly set""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_properties_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish annotation with various properties + await channel.annotations.publish(serial, { + 'type': 'reaction:multiple.v1', + 'name': '❤️', + 'data': {'count': 5} + }) + + # Retrieve and verify + async def check_annotation(): + res = await channel.annotations.get(serial) + return len(res.items) > 0 + + await assert_waiter(check_annotation, timeout=10) + + annotations_result = await channel.annotations.get(serial) + annotation = annotations_result.items[0] + assert annotation.message_serial == serial + assert annotation.type == 'reaction:multiple.v1' + assert annotation.name == '❤️' + assert annotation.serial is not None + assert annotation.serial > serial diff --git a/uv.lock b/uv.lock index 1b196ab7..5b48323d 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ [[package]] name = "ably" -version = "2.1.3" +version = "3.0.0" source = { editable = "." } dependencies = [ { name = "h2", version = "4.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, From 20288a6af35f9a6e23fb800d188158ce03a19463 Mon Sep 17 00:00:00 2001 From: evgeny Date: Thu, 29 Jan 2026 16:15:31 +0000 Subject: [PATCH 2/2] [AIT-316] feat: introduce support for message annotations - Added `RealtimeAnnotations` class to manage annotation creation, deletion, and subscription on realtime channels. - Introduced `Annotation` and `AnnotationAction` types to encapsulate annotation details and actions. - Extended flags to include `ANNOTATION_PUBLISH` and `ANNOTATION_SUBSCRIBE`. - Refactored data encoding logic into `ably.util.encoding`. - Integrated annotation handling into `RealtimeChannel` and `RestChannel`. --- ably/realtime/annotations.py | 31 +-- ably/realtime/channel.py | 8 +- ably/rest/annotations.py | 45 ++-- ably/rest/auth.py | 2 +- ably/rest/channel.py | 4 +- ably/transport/websockettransport.py | 1 + ably/types/annotation.py | 7 +- ably/types/channelmode.py | 2 + ably/types/channeloptions.py | 9 +- ably/util/encoding.py | 10 +- .../ably/realtime/realtimeannotations_test.py | 238 ++++++++---------- test/ably/realtime/realtimeconnection_test.py | 2 +- test/ably/rest/restannotations_test.py | 77 ++---- test/ably/utils.py | 20 ++ 14 files changed, 221 insertions(+), 235 deletions(-) diff --git a/ably/realtime/annotations.py b/ably/realtime/annotations.py index 96775b2c..13f9a17d 100644 --- a/ably/realtime/annotations.py +++ b/ably/realtime/annotations.py @@ -5,7 +5,7 @@ from ably.rest.annotations import RestAnnotations, construct_validate_annotation from ably.transport.websockettransport import ProtocolMessageAction -from ably.types.annotation import Annotation, AnnotationAction +from ably.types.annotation import AnnotationAction from ably.types.channelstate import ChannelState from ably.types.flags import Flag from ably.util.eventemitter import EventEmitter @@ -40,13 +40,13 @@ def __init__(self, channel: RealtimeChannel, connection_manager: ConnectionManag self.__subscriptions = EventEmitter() self.__rest_annotations = RestAnnotations(channel) - async def publish(self, msg_or_serial, annotation: dict | Annotation, params: dict=None): + async def publish(self, msg_or_serial, annotation: dict, params: dict | None = None): """ Publish an annotation on a message via the realtime connection. Args: msg_or_serial: Either a message serial (string) or a Message object - annotation: Dict containing annotation properties (type, name, data, etc.) or Annotation object + annotation: Dict containing annotation properties (type, name, data, etc.) params: Optional dict of query parameters Returns: @@ -84,7 +84,12 @@ async def publish(self, msg_or_serial, annotation: dict | Annotation, params: di # Send via WebSocket await self.__connection_manager.send_protocol_message(protocol_message) - async def delete(self, msg_or_serial, annotation: dict | Annotation, params=None, timeout=None): + async def delete( + self, + msg_or_serial, + annotation: dict, + params: dict | None = None, + ): """ Delete an annotation on a message. @@ -93,9 +98,8 @@ async def delete(self, msg_or_serial, annotation: dict | Annotation, params=None Args: msg_or_serial: Either a message serial (string) or a Message object - annotation: Dict containing annotation properties or Annotation object + annotation: Dict containing annotation properties params: Optional dict of query parameters - timeout: Optional timeout (not used for realtime, kept for compatibility) Returns: None @@ -103,10 +107,7 @@ async def delete(self, msg_or_serial, annotation: dict | Annotation, params=None Raises: AblyException: If the request fails or inputs are invalid """ - if isinstance(annotation, Annotation): - annotation_values = annotation.as_dict() - else: - annotation_values = annotation.copy() + annotation_values = annotation.copy() annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE return await self.publish(msg_or_serial, annotation_values, params) @@ -161,13 +162,13 @@ async def subscribe(self, *args): # Check if ANNOTATION_SUBSCRIBE mode is enabled if self.__channel.state == ChannelState.ATTACHED: - if not Flag.ANNOTATION_SUBSCRIBE in self.__channel.modes: + if Flag.ANNOTATION_SUBSCRIBE not in self.__channel.modes: raise AblyException( - "You are trying to add an annotation listener, but you haven't requested the " + message="You are trying to add an annotation listener, but you haven't requested the " "annotation_subscribe channel mode in ChannelOptions, so this won't do anything " "(we only deliver annotations to clients who have explicitly requested them)", - 93001, - 400 + code=93001, + status_code=400, ) def unsubscribe(self, *args): @@ -219,7 +220,7 @@ def _process_incoming(self, incoming_annotations): annotation_type = annotation.type or '' self.__subscriptions._emit(annotation_type, annotation) - async def get(self, msg_or_serial, params=None): + async def get(self, msg_or_serial, params: dict | None = None): """ Retrieve annotations for a message with pagination support. diff --git a/ably/realtime/channel.py b/ably/realtime/channel.py index 4830132a..801f4c6a 100644 --- a/ably/realtime/channel.py +++ b/ably/realtime/channel.py @@ -11,6 +11,7 @@ from ably.rest.channel import Channels as RestChannels from ably.transport.websockettransport import ProtocolMessageAction from ably.types.annotation import Annotation +from ably.types.channelmode import ChannelMode, decode_channel_mode, encode_channel_mode from ably.types.channeloptions import ChannelOptions from ably.types.channelstate import ChannelState, ChannelStateChange from ably.types.flags import Flag, has_flag @@ -21,7 +22,6 @@ from ably.util.eventemitter import EventEmitter from ably.util.exceptions import AblyException, IncompatibleClientIdException from ably.util.helper import Timer, is_callable_or_coroutine, validate_message_size -from ably.types.channelmode import ChannelMode, decode_channel_mode, encode_channel_mode if TYPE_CHECKING: from ably.realtime.realtime import AblyRealtime @@ -68,7 +68,7 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp self.__error_reason: AblyException | None = None self.__channel_options = channel_options or ChannelOptions() self.__params: dict[str, str] | None = None - self.__modes: list[ChannelMode] = list() # Channel mode flags from ATTACHED message + self.__modes: list[ChannelMode] = [] # Channel mode flags from ATTACHED message # Delta-specific fields for RTL19/RTL20 compliance vcdiff_decoder = self.__realtime.options.vcdiff_decoder if self.__realtime.options.vcdiff_decoder else None @@ -911,6 +911,10 @@ def presence(self): """Get the RealtimePresence object for this channel""" return self.__presence + @property + def annotations(self) -> RealtimeAnnotations: + return self._Channel__annotations + @property def modes(self): """Get the list of channel modes""" diff --git a/ably/rest/annotations.py b/ably/rest/annotations.py index 7f20fb3d..7f97cf7c 100644 --- a/ably/rest/annotations.py +++ b/ably/rest/annotations.py @@ -9,6 +9,7 @@ from ably.http.paginatedresult import PaginatedResult, format_params from ably.types.annotation import ( Annotation, + AnnotationAction, make_annotation_response_handler, ) from ably.types.message import Message @@ -48,7 +49,7 @@ def serial_from_msg_or_serial(msg_or_serial): return message_serial -def construct_validate_annotation(msg_or_serial, annotation: dict | Annotation): +def construct_validate_annotation(msg_or_serial, annotation: dict): """ Construct and validate an Annotation from input values. @@ -71,11 +72,8 @@ def construct_validate_annotation(msg_or_serial, annotation: dict | Annotation): status_code=400, code=40003, ) - elif isinstance(annotation, Annotation): - annotation_values = annotation.as_dict() - else: - annotation_values = annotation + annotation_values = annotation.copy() annotation_values['message_serial'] = message_serial return Annotation.from_values(annotation_values) @@ -108,15 +106,19 @@ def __base_path_for_serial(self, serial): channel_path = '/channels/{}/'.format(parse.quote_plus(self.__channel.name, safe=':')) return channel_path + 'messages/' + parse.quote_plus(serial, safe=':') + '/annotations' - async def publish(self, msg_or_serial, annotation_values, params=None, timeout=None): + async def publish( + self, + msg_or_serial, + annotation: dict | Annotation, + params: dict | None = None, + ): """ Publish an annotation on a message. Args: msg_or_serial: Either a message serial (string) or a Message object - annotation_values: Dict containing annotation properties (type, name, data, etc.) + annotation: Dict containing annotation properties (type, name, data, etc.) or Annotation object params: Optional dict of query parameters - timeout: Optional timeout for the HTTP request Returns: None @@ -124,7 +126,7 @@ async def publish(self, msg_or_serial, annotation_values, params=None, timeout=N Raises: AblyException: If the request fails or inputs are invalid """ - annotation = construct_validate_annotation(msg_or_serial, annotation_values) + annotation = construct_validate_annotation(msg_or_serial, annotation) # Convert to wire format request_body = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) @@ -145,9 +147,14 @@ async def publish(self, msg_or_serial, annotation_values, params=None, timeout=N path += '?' + parse.urlencode(params) # Send request - await self.__channel.ably.http.post(path, body=request_body, timeout=timeout) - - async def delete(self, msg_or_serial, annotation_values, params=None, timeout=None): + await self.__channel.ably.http.post(path, body=request_body) + + async def delete( + self, + msg_or_serial, + annotation: dict | Annotation, + params: dict | None = None, + ): """ Delete an annotation on a message. @@ -156,9 +163,8 @@ async def delete(self, msg_or_serial, annotation_values, params=None, timeout=No Args: msg_or_serial: Either a message serial (string) or a Message object - annotation_values: Dict containing annotation properties + annotation: Dict containing annotation properties or Annotation object params: Optional dict of query parameters - timeout: Optional timeout for the HTTP request Returns: None @@ -167,11 +173,14 @@ async def delete(self, msg_or_serial, annotation_values, params=None, timeout=No AblyException: If the request fails or inputs are invalid """ # Set action to delete - annotation_values = annotation_values.copy() - annotation_values['action'] = 'annotation.delete' - return await self.publish(msg_or_serial, annotation_values, params, timeout) + if isinstance(annotation, Annotation): + annotation_values = annotation.as_dict() + else: + annotation_values = annotation.copy() + annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE + return await self.publish(msg_or_serial, annotation_values, params) - async def get(self, msg_or_serial, params=None): + async def get(self, msg_or_serial, params: dict | None = None): """ Retrieve annotations for a message with pagination support. diff --git a/ably/rest/auth.py b/ably/rest/auth.py index 2aaa4b12..2dc5d497 100644 --- a/ably/rest/auth.py +++ b/ably/rest/auth.py @@ -90,7 +90,7 @@ def __init__(self, ably: AblyRest | AblyRealtime, options: Options): async def get_auth_transport_param(self): auth_credentials = {} if self.auth_options.client_id: - auth_credentials["client_id"] = self.auth_options.client_id + auth_credentials["clientId"] = self.auth_options.client_id if self.__auth_mechanism == Auth.Method.BASIC: key_name = self.__auth_options.key_name key_secret = self.__auth_options.key_secret diff --git a/ably/rest/channel.py b/ably/rest/channel.py index f5a3e894..e16f209d 100644 --- a/ably/rest/channel.py +++ b/ably/rest/channel.py @@ -31,6 +31,8 @@ class Channel: + __annotations: RestAnnotations + def __init__(self, ably, name, options): self.__ably = ably self.__name = name @@ -366,7 +368,7 @@ def presence(self): return self.__presence @property - def annotations(self): + def annotations(self) -> RestAnnotations: return self.__annotations @options.setter diff --git a/ably/transport/websockettransport.py b/ably/transport/websockettransport.py index 4f6f9fe0..be13d096 100644 --- a/ably/transport/websockettransport.py +++ b/ably/transport/websockettransport.py @@ -189,6 +189,7 @@ async def on_protocol_message(self, msg): ProtocolMessageAction.DETACHED, ProtocolMessageAction.MESSAGE, ProtocolMessageAction.PRESENCE, + ProtocolMessageAction.ANNOTATION, ProtocolMessageAction.SYNC ): self.connection_manager.on_channel_message(msg) diff --git a/ably/types/annotation.py b/ably/types/annotation.py index a3aded28..e099d00d 100644 --- a/ably/types/annotation.py +++ b/ably/types/annotation.py @@ -122,9 +122,6 @@ def as_dict(self, binary=False): Note: Annotations are not encrypted as they need to be parsed by the server. """ - # Encode data - encoded = encode_data(self.data, self._encoding_array, binary) - request_body = { 'action': int(self.action) if self.action is not None else None, 'serial': self.serial, @@ -132,12 +129,10 @@ def as_dict(self, binary=False): 'type': self.type, # Annotation type (not data type) 'name': self.name, 'count': self.count, - 'data': encoded.get('data'), - 'encoding': encoded.get('encoding', ''), - 'dataType': encoded.get('type'), # Data type (not annotation type) 'clientId': self.client_id or None, 'timestamp': self.timestamp or None, 'extras': self.extras, + **encode_data(self.data, self._encoding_array, binary) } # None values aren't included diff --git a/ably/types/channelmode.py b/ably/types/channelmode.py index 6ba95f08..23ed735c 100644 --- a/ably/types/channelmode.py +++ b/ably/types/channelmode.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from ably.types.flags import Flag diff --git a/ably/types/channeloptions.py b/ably/types/channeloptions.py index b745a3e8..02f2bd5d 100644 --- a/ably/types/channeloptions.py +++ b/ably/types/channeloptions.py @@ -2,9 +2,9 @@ from typing import Any +from ably.types.channelmode import ChannelMode from ably.util.crypto import CipherParams from ably.util.exceptions import AblyException -from ably.types.channelmode import ChannelMode class ChannelOptions: @@ -18,7 +18,12 @@ class ChannelOptions: Channel parameters that configure the behavior of the channel. """ - def __init__(self, cipher: CipherParams | None = None, params: dict | None = None, modes: list[ChannelMode] | None = None): + def __init__( + self, + cipher: CipherParams | None = None, + params: dict | None = None, + modes: list[ChannelMode] | None = None + ): self.__cipher = cipher self.__params = params self.__modes = modes diff --git a/ably/util/encoding.py b/ably/util/encoding.py index b0af9620..3b3858b4 100644 --- a/ably/util/encoding.py +++ b/ably/util/encoding.py @@ -27,7 +27,9 @@ def encode_data(data: Any, encoding_array: list, binary: bool = False): elif binary and isinstance(data, bytearray): data = bytes(data) - return { - 'data': data, - 'encoding': '/'.join(encoding).strip('/') - } + result = { 'data': data } + + if encoding: + result['encoding'] = '/'.join(encoding).strip('/') + + return result diff --git a/test/ably/realtime/realtimeannotations_test.py b/test/ably/realtime/realtimeannotations_test.py index 5e502380..6852adaa 100644 --- a/test/ably/realtime/realtimeannotations_test.py +++ b/test/ably/realtime/realtimeannotations_test.py @@ -1,15 +1,17 @@ import asyncio import logging +import random +import string import pytest from ably import AblyException from ably.types.annotation import AnnotationAction +from ably.types.channelmode import ChannelMode from ably.types.channeloptions import ChannelOptions from ably.types.message import MessageAction from test.ably.testapp import TestApp -from test.ably.utils import BaseAsyncTestCase, assert_waiter -from ably.types.channelmode import ChannelMode +from test.ably.utils import BaseAsyncTestCase, ReusableFuture, assert_waiter log = logging.getLogger(__name__) @@ -20,26 +22,31 @@ class TestRealtimeAnnotations(BaseAsyncTestCase): @pytest.fixture(autouse=True) async def setup(self, transport): self.test_vars = await TestApp.get_test_vars() - self.ably = await TestApp.get_ably_realtime( + + client_id = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + self.realtime_client = await TestApp.get_ably_realtime( use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, ) - self.rest = await TestApp.get_ably_rest( + self.rest_client = await TestApp.get_ably_rest( use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, ) async def test_publish_and_subscribe_annotations(self): - """Test publishing and subscribing to annotations (matches JS test)""" + """Test publishing and subscribing to annotations""" channel_options = ChannelOptions(modes=[ ChannelMode.PUBLISH, ChannelMode.SUBSCRIBE, ChannelMode.ANNOTATION_PUBLISH, ChannelMode.ANNOTATION_SUBSCRIBE ]) - channel = self.ably.channels.get( - self.get_channel_name('mutable:publish_subscribe_annotation'), - channel_options + channel_name = self.get_channel_name('mutable:publish_and_subscribe_annotations') + channel = self.realtime_client.channels.get( + channel_name, + channel_options, ) - rest_channel = self.rest.channels[channel.name] + rest_channel = self.rest_client.channels.get(channel_name) await channel.attach() # Setup annotation listener @@ -65,7 +72,7 @@ def on_message(msg): # Publish annotation using realtime await channel.annotations.publish(publish_result.serials[0], { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '👍' }) @@ -73,65 +80,58 @@ def on_message(msg): annotation = await annotation_future assert annotation.action == AnnotationAction.ANNOTATION_CREATE assert annotation.message_serial == publish_result.serials[0] - assert annotation.type == 'reaction:multiple.v1' + assert annotation.type == 'reaction:distinct.v1' assert annotation.name == '👍' assert annotation.serial > annotation.message_serial # Wait for summary message - # summary = await message_summary - # assert summary.action == MessageAction.META - # assert summary.serial == publish_result.serials[0] - # - # # Try again but with REST publish - # annotation_future2 = asyncio.Future() - # - # async def on_annotation2(annotation): - # if not annotation_future2.done(): - # annotation_future2.set_result(annotation) - # - # await channel.annotations.subscribe(on_annotation2) - # - # await rest_channel.annotations.publish(publish_result.serials[0], { - # 'type': 'reaction:multiple.v1', - # 'name': '😕' - # }) - # - # annotation = await annotation_future2 - # assert annotation.action == AnnotationAction.ANNOTATION_CREATE - # assert annotation.message_serial == publish_result.serials[0] - # assert annotation.type == 'reaction:multiple.v1' - # assert annotation.name == '😕' - # assert annotation.serial > annotation.message_serial + summary = await message_summary + assert summary.action == MessageAction.MESSAGE_SUMMARY + assert summary.serial == publish_result.serials[0] - async def test_get_all_annotations_for_a_message(self): - """Test retrieving all annotations with pagination (matches JS test)""" - channel_options = ChannelOptions(params={ - 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' + # Try again but with REST publish + annotation_future2 = asyncio.Future() + + async def on_annotation2(annotation): + if not annotation_future2.done(): + annotation_future2.set_result(annotation) + + await channel.annotations.subscribe(on_annotation2) + + await rest_channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': '😕' }) - channel = self.ably.channels.get( + + annotation = await annotation_future2 + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + assert annotation.message_serial == publish_result.serials[0] + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '😕' + assert annotation.serial > annotation.message_serial + + async def test_get_all_annotations_for_a_message(self): + """Test retrieving all annotations with pagination""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( self.get_channel_name('mutable:get_all_annotations_for_a_message'), channel_options ) await channel.attach() - # Setup message listener - message_future = asyncio.Future() - - def on_message(msg): - if not message_future.done(): - message_future.set_result(msg) - - await channel.subscribe('message', on_message) - # Publish a message - await channel.publish('message', 'foobar') - message = await message_future + publish_result = await channel.publish('message', 'foobar') # Publish multiple annotations - emojis = ['👍', '😕', '👎', '👍👍', '😕😕', '👎👎'] + emojis = ['👍', '😕', '👎'] for emoji in emojis: - await channel.annotations.publish(message.serial, { - 'type': 'reaction:multiple.v1', + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', 'name': emoji }) @@ -140,46 +140,31 @@ def on_message(msg): async def check_annotations(): nonlocal annotations - res = await channel.annotations.get(message.serial, {}) + res = await channel.annotations.get(publish_result.serials[0], {}) annotations = res.items - return len(annotations) == 6 + return len(annotations) == 3 await assert_waiter(check_annotations, timeout=10) # Verify annotations assert annotations[0].action == AnnotationAction.ANNOTATION_CREATE - assert annotations[0].message_serial == message.serial - assert annotations[0].type == 'reaction:multiple.v1' + assert annotations[0].message_serial == publish_result.serials[0] + assert annotations[0].type == 'reaction:distinct.v1' assert annotations[0].name == '👍' assert annotations[1].name == '😕' assert annotations[2].name == '👎' assert annotations[1].serial > annotations[0].serial assert annotations[2].serial > annotations[1].serial - # Test pagination - res = await channel.annotations.get(message.serial, {'limit': 2}) - assert len(res.items) == 2 - assert [a.name for a in res.items] == ['👍', '😕'] - assert res.has_next() - - res = await res.next() - assert res is not None - assert len(res.items) == 2 - assert [a.name for a in res.items] == ['👎', '👍👍'] - assert res.has_next() - - res = await res.next() - assert res is not None - assert len(res.items) == 2 - assert [a.name for a in res.items] == ['😕😕', '👎👎'] - assert not res.has_next() - async def test_subscribe_by_annotation_type(self): """Test subscribing to specific annotation types""" - channel_options = ChannelOptions(params={ - 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' - }) - channel = self.ably.channels.get( + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( self.get_channel_name('mutable:subscribe_by_type'), channel_options ) @@ -201,85 +186,81 @@ async def on_reaction(annotation): if not reaction_future.done(): reaction_future.set_result(annotation) - await channel.annotations.subscribe('reaction:multiple.v1', on_reaction) + await channel.annotations.subscribe('reaction:distinct.v1', on_reaction) # Publish message and annotation - await channel.publish('message', 'test') - message = await message_future + publish_result = await channel.publish('message', 'test') - # Temporary anti-flake measure (matches JS test) - await asyncio.sleep(1) - - await channel.annotations.publish(message.serial, { - 'type': 'reaction:multiple.v1', + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', 'name': '👍' }) # Should receive the annotation annotation = await reaction_future - assert annotation.type == 'reaction:multiple.v1' + assert annotation.type == 'reaction:distinct.v1' assert annotation.name == '👍' async def test_unsubscribe_annotations(self): """Test unsubscribing from annotations""" - channel_options = ChannelOptions(params={ - 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' - }) - channel = self.ably.channels.get( + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( self.get_channel_name('mutable:unsubscribe_annotations'), channel_options ) await channel.attach() - # Setup message listener - message_future = asyncio.Future() - - def on_message(msg): - if not message_future.done(): - message_future.set_result(msg) - - await channel.subscribe('message', on_message) - annotations_received = [] + annotation_future = ReusableFuture() async def on_annotation(annotation): annotations_received.append(annotation) + annotation_future.set_result(annotation) await channel.annotations.subscribe(on_annotation) # Publish message and first annotation - await channel.publish('message', 'test') - message = await message_future - - # Temporary anti-flake measure (matches JS test) - await asyncio.sleep(1) + publish_result = await channel.publish('message', 'test') - await channel.annotations.publish(message.serial, { - 'type': 'reaction:multiple.v1', + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', 'name': '👍' }) - # Wait for first annotation + # Wait for the first annotation to appear + await annotation_future.get() assert len(annotations_received) == 1 # Unsubscribe channel.annotations.unsubscribe(on_annotation) + await channel.annotations.subscribe(lambda annotation: annotation_future.set_result(annotation)) + # Publish another annotation - await channel.annotations.publish(message.serial, { - 'type': 'reaction:multiple.v1', + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', 'name': '😕' }) - # Wait and verify we didn't receive it + # Wait for the second annotation to appear in another listener + await annotation_future.get() + assert len(annotations_received) == 1 async def test_delete_annotation(self): """Test deleting annotations""" - channel_options = ChannelOptions(params={ - 'modes': 'publish,subscribe,annotation_publish,annotation_subscribe' - }) - channel = self.ably.channels.get( + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( self.get_channel_name('mutable:delete_annotation'), channel_options ) @@ -295,9 +276,10 @@ def on_message(msg): await channel.subscribe('message', on_message) annotations_received = [] - + annotation_future = ReusableFuture() async def on_annotation(annotation): annotations_received.append(annotation) + annotation_future.set_result(annotation) await channel.annotations.subscribe(on_annotation) @@ -305,35 +287,37 @@ async def on_annotation(annotation): await channel.publish('message', 'test') message = await message_future - # Temporary anti-flake measure (matches JS test) - await asyncio.sleep(1) - await channel.annotations.publish(message.serial, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '👍' }) + await annotation_future.get() + # Wait for create annotation assert len(annotations_received) == 1 assert annotations_received[0].action == AnnotationAction.ANNOTATION_CREATE # Delete the annotation await channel.annotations.delete(message.serial, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '👍' }) # Wait for delete annotation + await annotation_future.get() + assert len(annotations_received) == 2 assert annotations_received[1].action == AnnotationAction.ANNOTATION_DELETE async def test_subscribe_without_annotation_mode_fails(self): """Test that subscribing without annotation_subscribe mode raises an error""" # Create channel without annotation_subscribe mode - channel_options = ChannelOptions(params={ - 'modes': 'publish,subscribe' - }) - channel = self.ably.channels.get( + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( self.get_channel_name('mutable:no_annotation_mode'), channel_options ) diff --git a/test/ably/realtime/realtimeconnection_test.py b/test/ably/realtime/realtimeconnection_test.py index b38c5aaf..f1eb9003 100644 --- a/test/ably/realtime/realtimeconnection_test.py +++ b/test/ably/realtime/realtimeconnection_test.py @@ -369,7 +369,7 @@ async def test_connection_client_id_query_params(self): ably = await TestApp.get_ably_realtime(client_id=client_id) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) - assert ably.connection.connection_manager.transport.params["client_id"] == client_id + assert ably.connection.connection_manager.transport.params["clientId"] == client_id assert ably.auth.client_id == client_id await ably.close() diff --git a/test/ably/rest/restannotations_test.py b/test/ably/rest/restannotations_test.py index 6756c7ec..8969e84d 100644 --- a/test/ably/rest/restannotations_test.py +++ b/test/ably/rest/restannotations_test.py @@ -1,8 +1,11 @@ import logging +import random +import string import pytest from ably import AblyException +from ably.types.annotation import AnnotationAction from ably.types.message import Message from test.ably.testapp import TestApp from test.ably.utils import BaseAsyncTestCase, assert_waiter @@ -16,8 +19,10 @@ class TestRestAnnotations(BaseAsyncTestCase): @pytest.fixture(autouse=True) async def setup(self, transport): self.test_vars = await TestApp.get_test_vars() + client_id = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) self.ably = await TestApp.get_ably_rest( use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, ) async def test_publish_annotation_success(self): @@ -32,7 +37,7 @@ async def test_publish_annotation_success(self): # Publish an annotation await channel.annotations.publish(serial, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '👍' }) @@ -50,7 +55,7 @@ async def check_annotations(): annotations = annotations_result.items assert len(annotations) >= 1 assert annotations[0].message_serial == serial - assert annotations[0].type == 'reaction:multiple.v1' + assert annotations[0].type == 'reaction:distinct.v1' assert annotations[0].name == '👍' async def test_publish_annotation_with_message_object(self): @@ -66,7 +71,7 @@ async def test_publish_annotation_with_message_object(self): # Publish annotation with message object await channel.annotations.publish(message, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '😕' }) @@ -106,7 +111,7 @@ async def test_delete_annotation_success(self): # Publish an annotation await channel.annotations.publish(serial, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '👍' }) @@ -122,7 +127,7 @@ async def check_annotation(): # Delete the annotation await channel.annotations.delete(serial, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '👍' }) @@ -130,55 +135,11 @@ async def check_annotation(): async def check_deleted_annotation(): nonlocal annotations_result annotations_result = await channel.annotations.get(serial) - return len(annotations_result.items) == 0 + return len(annotations_result.items) >= 2 await assert_waiter(check_deleted_annotation, timeout=10) - - async def test_get_annotations_pagination(self): - """Test retrieving annotations with pagination""" - channel = self.ably.channels[self.get_channel_name('mutable:annotation_pagination_test')] - - # Publish a message - result = await channel.publish('test-event', 'test data') - serial = result.serials[0] - - # Publish multiple annotations - emojis = ['👍', '😕', '👎', '👍👍', '😕😕', '👎👎'] - for emoji in emojis: - await channel.annotations.publish(serial, { - 'type': 'reaction:multiple.v1', - 'name': emoji - }) - - # Wait for annotations to appear - async def check_annotations(): - res = await channel.annotations.get(serial) - return len(res.items) == 6 - - await assert_waiter(check_annotations, timeout=10) - - # Test pagination with limit - result = await channel.annotations.get(serial, {'limit': 2}) - assert len(result.items) == 2 - assert result.items[0].name == '👍' - assert result.items[1].name == '😕' - assert result.has_next() - - # Get next page - result = await result.next() - assert result is not None - assert len(result.items) == 2 - assert result.items[0].name == '👎' - assert result.items[1].name == '👍👍' - assert result.has_next() - - # Get last page - result = await result.next() - assert result is not None - assert len(result.items) == 2 - assert result.items[0].name == '😕😕' - assert result.items[1].name == '👎👎' - assert not result.has_next() + assert annotations_result.items[-1].type == 'reaction:distinct.v1' + assert annotations_result.items[-1].action == AnnotationAction.ANNOTATION_DELETE async def test_get_all_annotations(self): """Test retrieving all annotations for a message""" @@ -189,9 +150,9 @@ async def test_get_all_annotations(self): serial = result.serials[0] # Publish annotations - await channel.annotations.publish(serial, {'type': 'reaction:multiple.v1', 'name': '👍'}) - await channel.annotations.publish(serial, {'type': 'reaction:multiple.v1', 'name': '😕'}) - await channel.annotations.publish(serial, {'type': 'reaction:multiple.v1', 'name': '👎'}) + await channel.annotations.publish(serial, {'type': 'reaction:distinct.v1', 'name': '👍'}) + await channel.annotations.publish(serial, {'type': 'reaction:distinct.v1', 'name': '😕'}) + await channel.annotations.publish(serial, {'type': 'reaction:distinct.v1', 'name': '👎'}) # Wait and get all annotations async def check_annotations(): @@ -203,7 +164,7 @@ async def check_annotations(): annotations_result = await channel.annotations.get(serial) annotations = annotations_result.items assert len(annotations) >= 3 - assert annotations[0].type == 'reaction:multiple.v1' + assert annotations[0].type == 'reaction:distinct.v1' assert annotations[0].message_serial == serial # Verify serials are in order if len(annotations) > 1: @@ -221,7 +182,7 @@ async def test_annotation_properties(self): # Publish annotation with various properties await channel.annotations.publish(serial, { - 'type': 'reaction:multiple.v1', + 'type': 'reaction:distinct.v1', 'name': '❤️', 'data': {'count': 5} }) @@ -236,7 +197,7 @@ async def check_annotation(): annotations_result = await channel.annotations.get(serial) annotation = annotations_result.items[0] assert annotation.message_serial == serial - assert annotation.type == 'reaction:multiple.v1' + assert annotation.type == 'reaction:distinct.v1' assert annotation.name == '❤️' assert annotation.serial is not None assert annotation.serial > serial diff --git a/test/ably/utils.py b/test/ably/utils.py index 09658fc0..eb75d3e6 100644 --- a/test/ably/utils.py +++ b/test/ably/utils.py @@ -229,6 +229,9 @@ def assert_waiter_sync(block: Callable[[], bool], timeout: float = 10) -> None: class WaitableEvent: + """ + Replacement for asyncio.Future that will work with autogenerated sync tests. + """ def __init__(self): self._finished = False @@ -243,3 +246,20 @@ async def wait(self, timeout=10): def finish(self): self._finished = True + +class ReusableFuture: + """ + A reusable future that after each wait() resets itself and wait for the next value. + """ + def __init__(self): + self.__future = asyncio.Future() + + async def get(self, timeout=10): + await asyncio.wait_for(self.__future, timeout=timeout) + self.__future = asyncio.Future() + + def set_result(self, result): + self.__future.set_result(result) + + def set_exception(self, exception): + self.__future.set_exception(exception)