Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions ably/realtime/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from ably.rest.annotations import RestAnnotations, construct_validate_annotation
from ably.transport.websockettransport import ProtocolMessageAction
from ably.types.annotation import AnnotationAction
from ably.types.channelstate import ChannelState
from ably.types.flags import Flag
from ably.util.eventemitter import EventEmitter
from ably.util.exceptions import AblyException
from ably.util.helper import is_callable_or_coroutine

if TYPE_CHECKING:
from ably.realtime.channel import RealtimeChannel
from ably.realtime.connectionmanager import ConnectionManager

log = logging.getLogger(__name__)


class RealtimeAnnotations:
"""
Provides realtime methods for managing annotations on messages,
including publishing annotations and subscribing to annotation events.
"""

__connection_manager: ConnectionManager
__channel: RealtimeChannel

def __init__(self, channel: RealtimeChannel, connection_manager: ConnectionManager):
"""
Initialize RealtimeAnnotations.

Args:
channel: The Realtime Channel this annotations instance belongs to
"""
self.__channel = channel
self.__connection_manager = connection_manager
self.__subscriptions = EventEmitter()
self.__rest_annotations = RestAnnotations(channel)

async def publish(self, msg_or_serial, annotation: dict, params: dict | None = None):
"""
Publish an annotation on a message via the realtime connection.

Args:
msg_or_serial: Either a message serial (string) or a Message object
annotation: Dict containing annotation properties (type, name, data, etc.)
params: Optional dict of query parameters

Returns:
None

Raises:
AblyException: If the request fails, inputs are invalid, or channel is in unpublishable state
"""
annotation = construct_validate_annotation(msg_or_serial, annotation)

# Check if channel and connection are in publishable state
self.__channel._throw_if_unpublishable_state()

log.info(
f'RealtimeAnnotations.publish(), channelName = {self.__channel.name}, '
f'sending annotation with messageSerial = {annotation.message_serial}, '
f'type = {annotation.type}'
)

# Convert to wire format (array of annotations)
wire_annotation = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol)

# Build protocol message
protocol_message = {
"action": ProtocolMessageAction.ANNOTATION,
"channel": self.__channel.name,
"annotations": [wire_annotation],
}

if params:
# Stringify boolean params
stringified_params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()}
protocol_message["params"] = stringified_params

# Send via WebSocket
await self.__connection_manager.send_protocol_message(protocol_message)

async def delete(
self,
msg_or_serial,
annotation: dict,
params: dict | None = None,
):
"""
Delete an annotation on a message.

This is a convenience method that sets the action to 'annotation.delete'
and calls publish().

Args:
msg_or_serial: Either a message serial (string) or a Message object
annotation: Dict containing annotation properties
params: Optional dict of query parameters

Returns:
None

Raises:
AblyException: If the request fails or inputs are invalid
"""
annotation_values = annotation.copy()
annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE
return await self.publish(msg_or_serial, annotation_values, params)

async def subscribe(self, *args):
"""
Subscribe to annotation events on this channel.

Parameters
----------
*args: type, listener
Subscribe type and listener

arg1(type): str, optional
Subscribe to annotations of the given type

arg2(listener): callable
Subscribe to all annotations on the channel

When no type is provided, arg1 is used as the listener.

Raises
------
AblyException
If unable to subscribe due to invalid channel state or missing ANNOTATION_SUBSCRIBE mode
ValueError
If no valid subscribe arguments are passed
"""
# Parse arguments similar to channel.subscribe
if len(args) == 0:
raise ValueError("annotations.subscribe called without arguments")

if len(args) >= 2 and isinstance(args[0], str):
annotation_type = args[0]
if not args[1]:
raise ValueError("annotations.subscribe called without listener")
if not is_callable_or_coroutine(args[1]):
raise ValueError("subscribe listener must be function or coroutine function")
listener = args[1]
elif is_callable_or_coroutine(args[0]):
listener = args[0]
annotation_type = None
else:
raise ValueError('invalid subscribe arguments')

# Register subscription
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably validate channel mode before adding the subscription - in this code the listener would be attached even if the annotation_subscribe mode isn't enabled

if annotation_type is not None:
self.__subscriptions.on(annotation_type, listener)
else:
self.__subscriptions.on(listener)

await self.__channel.attach()

# Check if ANNOTATION_SUBSCRIBE mode is enabled
if self.__channel.state == ChannelState.ATTACHED:
if Flag.ANNOTATION_SUBSCRIBE not in self.__channel.modes:
raise AblyException(
message="You are trying to add an annotation listener, but you haven't requested the "
"annotation_subscribe channel mode in ChannelOptions, so this won't do anything "
"(we only deliver annotations to clients who have explicitly requested them)",
code=93001,
status_code=400,
)

def unsubscribe(self, *args):
"""
Unsubscribe from annotation events on this channel.

Parameters
----------
*args: type, listener
Unsubscribe type and listener

arg1(type): str, optional
Unsubscribe from annotations of the given type

arg2(listener): callable
Unsubscribe from all annotations on the channel

When no type is provided, arg1 is used as the listener.

Raises
------
ValueError
If no valid unsubscribe arguments are passed
"""
if len(args) == 0:
raise ValueError("annotations.unsubscribe called without arguments")

if len(args) >= 2 and isinstance(args[0], str):
annotation_type = args[0]
listener = args[1]
self.__subscriptions.off(annotation_type, listener)
elif is_callable_or_coroutine(args[0]):
listener = args[0]
self.__subscriptions.off(listener)
else:
raise ValueError('invalid unsubscribe arguments')

def _process_incoming(self, incoming_annotations):
"""
Process incoming annotations from the server.

This is called internally when ANNOTATION protocol messages are received.

Args:
incoming_annotations: List of Annotation objects received from the server
"""
for annotation in incoming_annotations:
# Emit to type-specific listeners and catch-all listeners
annotation_type = annotation.type or ''
self.__subscriptions._emit(annotation_type, annotation)

async def get(self, msg_or_serial, params: dict | None = None):
"""
Retrieve annotations for a message with pagination support.

This delegates to the REST implementation.

Args:
msg_or_serial: Either a message serial (string) or a Message object
params: Optional dict of query parameters (limit, start, end, direction)

Returns:
PaginatedResult: A paginated result containing Annotation objects

Raises:
AblyException: If the request fails or serial is invalid
"""
# Delegate to REST implementation
return await self.__rest_annotations.get(msg_or_serial, params)
58 changes: 51 additions & 7 deletions ably/realtime/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
import logging
from typing import TYPE_CHECKING

from ably.realtime.annotations import RealtimeAnnotations
from ably.realtime.connection import ConnectionState
from ably.realtime.presence import RealtimePresence
from ably.rest.channel import Channel
from ably.rest.channel import Channels as RestChannels
from ably.transport.websockettransport import ProtocolMessageAction
from ably.types.annotation import Annotation
from ably.types.channelmode import ChannelMode, decode_channel_mode, encode_channel_mode
from ably.types.channeloptions import ChannelOptions
from ably.types.channelstate import ChannelState, ChannelStateChange
from ably.types.flags import Flag, has_flag
Expand Down Expand Up @@ -64,6 +68,7 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp
self.__error_reason: AblyException | None = None
self.__channel_options = channel_options or ChannelOptions()
self.__params: dict[str, str] | None = None
self.__modes: list[ChannelMode] = [] # Channel mode flags from ATTACHED message

# Delta-specific fields for RTL19/RTL20 compliance
vcdiff_decoder = self.__realtime.options.vcdiff_decoder if self.__realtime.options.vcdiff_decoder else None
Expand All @@ -74,12 +79,15 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp
# will be disrupted if the user called .off() to remove all listeners
self.__internal_state_emitter = EventEmitter()

# Pass channel options as dictionary to parent Channel class
Channel.__init__(self, realtime, name, self.__channel_options.to_dict())

# Initialize presence for this channel
from ably.realtime.presence import RealtimePresence

self.__presence = RealtimePresence(self)

# Pass channel options as dictionary to parent Channel class
Channel.__init__(self, realtime, name, self.__channel_options.to_dict())
# Initialize realtime annotations for this channel (override REST annotations)
self._Channel__annotations = RealtimeAnnotations(self, realtime.connection.connection_manager)

async def set_options(self, channel_options: ChannelOptions) -> None:
"""Set channel options"""
Expand Down Expand Up @@ -149,8 +157,10 @@ def _attach_impl(self):
"channel": self.name,
}

if self.__attach_resume:
attach_msg["flags"] = Flag.ATTACH_RESUME
flags = self._encode_flags()

if flags:
attach_msg["flags"] = flags
if self.__channel_serial:
attach_msg["channelSerial"] = self.__channel_serial

Expand Down Expand Up @@ -491,8 +501,8 @@ async def _send_update(
if not message.serial:
raise AblyException(
"Message serial is required for update/delete/append operations",
400,
40003
status_code=400,
code=40003,
)

# Check connection and channel state
Expand Down Expand Up @@ -702,6 +712,8 @@ def _on_message(self, proto_msg: dict) -> None:
resumed = has_flag(flags, Flag.RESUMED)
# RTP1: Check for HAS_PRESENCE flag
has_presence = has_flag(flags, Flag.HAS_PRESENCE)
# Store channel attach flags
self.__modes = decode_channel_mode(flags)

# RTL12
if self.state == ChannelState.ATTACHED:
Expand Down Expand Up @@ -744,6 +756,15 @@ def _on_message(self, proto_msg: dict) -> None:
decoded_presence = PresenceMessage.from_encoded_array(presence_messages, cipher=self.cipher)
sync_channel_serial = proto_msg.get('channelSerial')
self.__presence.set_presence(decoded_presence, is_sync=True, sync_channel_serial=sync_channel_serial)
elif action == ProtocolMessageAction.ANNOTATION:
# Handle ANNOTATION messages
annotation_data = proto_msg.get('annotations', [])
try:
annotations = Annotation.from_encoded_array(annotation_data, cipher=self.cipher)
# Process annotations through the annotations handler
self.annotations._process_incoming(annotations)
except Exception as e:
log.error(f"Annotation processing error {e}. Skip annotations {annotation_data}")
elif action == ProtocolMessageAction.ERROR:
error = AblyException.from_dict(proto_msg.get('error'))
self._notify_state(ChannelState.FAILED, reason=error)
Expand Down Expand Up @@ -890,6 +911,15 @@ def presence(self):
"""Get the RealtimePresence object for this channel"""
return self.__presence

@property
def annotations(self) -> RealtimeAnnotations:
return self._Channel__annotations

@property
def modes(self):
"""Get the list of channel modes"""
return self.__modes

def _start_decode_failure_recovery(self, error: AblyException) -> None:
"""Start RTL18 decode failure recovery procedure"""

Expand All @@ -908,6 +938,20 @@ def _start_decode_failure_recovery(self, error: AblyException) -> None:
self._notify_state(ChannelState.ATTACHING, reason=error)
self._check_pending_state()

def _encode_flags(self) -> int | None:
if not self.__channel_options.modes and not self.__attach_resume:
return None

flags = 0

if self.__attach_resume:
flags |= Flag.ATTACH_RESUME

if self.__channel_options.modes:
flags |= encode_channel_mode(self.__channel_options.modes)

return flags


class Channels(RestChannels):
"""Creates and destroys RealtimeChannel objects.
Expand Down
Loading
Loading