diff --git a/src/frequenz/channels/__init__.py b/src/frequenz/channels/__init__.py index 87c86a34..63b04faf 100644 --- a/src/frequenz/channels/__init__.py +++ b/src/frequenz/channels/__init__.py @@ -80,7 +80,7 @@ """ from ._anycast import Anycast -from ._broadcast import Broadcast +from ._broadcast import Broadcast, broadcast from ._exceptions import ChannelClosedError, ChannelError, Error from ._generic import ( ChannelMessageT, @@ -92,6 +92,7 @@ ) from ._latest_value_cache import LatestValueCache from ._merge import Merger, merge +from ._one_shot import oneshot from ._receiver import Receiver, ReceiverError, ReceiverStoppedError from ._select import ( Selected, @@ -100,7 +101,14 @@ select, selected_from, ) -from ._sender import Sender, SenderError +from ._sender import ( + ClonableSender, + ClonableSubscribableSender, + Sender, + SenderClosedError, + SenderError, + SubscribableSender, +) __all__ = [ "Anycast", @@ -108,6 +116,8 @@ "ChannelClosedError", "ChannelError", "ChannelMessageT", + "ClonableSender", + "ClonableSubscribableSender", "Error", "ErroredChannelT_co", "LatestValueCache", @@ -120,11 +130,15 @@ "SelectError", "Selected", "Sender", + "SenderClosedError", "SenderError", "SenderMessageT_co", "SenderMessageT_contra", + "SubscribableSender", "UnhandledSelectedError", + "broadcast", "merge", + "oneshot", "select", "selected_from", ] diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index b5184a3f..d8259956 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -15,7 +15,7 @@ from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError -from ._sender import Sender, SenderError +from ._sender import Sender, SenderClosedError, SenderError _logger = logging.getLogger(__name__) @@ -327,6 +327,9 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._channel: Anycast[_T] = channel """The channel that this sender belongs to.""" + self._closed: bool = False + """Whether the sender is closed.""" + @override async def send(self, message: _T, /) -> None: """Send a message across the channel. @@ -343,7 +346,11 @@ async def send(self, message: _T, /) -> None: SenderError: If the underlying channel was closed. A [ChannelClosedError][frequenz.channels.ChannelClosedError] is set as the cause. + SenderClosedError: If this sender was closed. """ + if self._closed: + raise SenderClosedError(self) + # pylint: disable=protected-access if self._channel._closed: raise SenderError("The channel was closed", self) from ChannelClosedError( @@ -367,6 +374,16 @@ async def send(self, message: _T, /) -> None: self._channel._recv_cv.notify(1) # pylint: enable=protected-access + @override + def close(self) -> None: + """Close this sender. + + After closing, the sender will not be able to send any more messages. Any + attempt to send a message through a closed sender will raise a + [SenderError][frequenz.channels.SenderError]. + """ + self._closed = True + def __str__(self) -> str: """Return a string representation of this sender.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index 2c167d5e..df84bc59 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -16,12 +16,51 @@ from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError -from ._sender import Sender, SenderError +from ._sender import ClonableSubscribableSender, SenderClosedError, SenderError _logger = logging.getLogger(__name__) -class Broadcast(Generic[ChannelMessageT]): +def broadcast( + message_type: type[ChannelMessageT], # pylint: disable=unused-argument + /, + *, + name: str, + resend_latest: bool = False, +) -> tuple[ClonableSubscribableSender[ChannelMessageT], Receiver[ChannelMessageT]]: + """Create a new Broadcast channel and return a sender and a receiver attached to it. + + The channel will be automatically closed when all senders or all receivers + are closed. + + Args: + message_type: The type of messages that will be sent through this channel. This + is only for type checking purposes, it is not used at runtime. + name: The name of the channel. This is for logging purposes, and it will be + shown in the string representation of the channel. + resend_latest: When True, every time a new receiver is created with + `new_receiver`, the last message seen by the channel will be sent to the + new receiver automatically. This allows new receivers on slow streams to + get the latest message as soon as they are created, without having to + wait for the next message on the channel to arrive. It is safe to be + set in data/reporting channels, but is not recommended for use in + channels that stream control instructions. + + Returns: + A tuple of a sender and a receiver attached to the created channel. + """ + channel = Broadcast[ChannelMessageT]( + name=name, resend_latest=resend_latest, auto_close=True + ) + return channel.new_sender(), channel.new_receiver() + + +@deprecated( + "Please use the `broadcast` function to create a Broadcast channel instead." +) +class Broadcast( # pylint: disable=too-many-instance-attributes + Generic[ChannelMessageT] +): """A channel that deliver all messages to all receivers. # Description @@ -184,7 +223,13 @@ async def main() -> None: ``` """ - def __init__(self, *, name: str, resend_latest: bool = False) -> None: + def __init__( + self, + *, + name: str, + resend_latest: bool = False, + auto_close: bool = False, + ) -> None: """Initialize this channel. Args: @@ -197,6 +242,8 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None: wait for the next message on the channel to arrive. It is safe to be set in data/reporting channels, but is not recommended for use in channels that stream control instructions. + auto_close: If True, the channel will be closed when all senders or all + receivers are closed. """ self._name: str = name """The name of the broadcast channel. @@ -207,6 +254,9 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None: self._recv_cv: Condition = Condition() """The condition to wait for data in the channel's buffer.""" + self._sender_count: int = 0 + """The number of senders attached to this channel.""" + self._receivers: dict[ int, weakref.ReferenceType[_Receiver[ChannelMessageT]] ] = {} @@ -218,6 +268,9 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None: self._latest: ChannelMessageT | None = None """The latest message sent to the channel.""" + self._auto_close: bool = auto_close + """Whether to close the channel when all senders or all receivers are closed.""" + self.resend_latest: bool = resend_latest """Whether to resend the latest message to new receivers. @@ -269,7 +322,7 @@ async def close(self) -> None: # noqa: D402 """Close the channel, deprecated alias for `aclose()`.""" # noqa: D402 return await self.aclose() - def new_sender(self) -> Sender[ChannelMessageT]: + def new_sender(self) -> ClonableSubscribableSender[ChannelMessageT]: """Return a new sender attached to this channel.""" return _Sender(self) @@ -317,7 +370,7 @@ def __repr__(self) -> str: _T = TypeVar("_T") -class _Sender(Sender[_T]): +class _Sender(ClonableSubscribableSender[_T]): """A sender to send messages to the broadcast channel. Should not be created directly, but through the @@ -334,6 +387,11 @@ def __init__(self, channel: Broadcast[_T], /) -> None: self._channel: Broadcast[_T] = channel """The broadcast channel this sender belongs to.""" + self._closed: bool = False + """Whether this sender is closed.""" + + self._channel._sender_count += 1 + @override async def send(self, message: _T, /) -> None: """Send a message to all broadcast receivers. @@ -345,12 +403,22 @@ async def send(self, message: _T, /) -> None: SenderError: If the underlying channel was closed. A [ChannelClosedError][frequenz.channels.ChannelClosedError] is set as the cause. + SenderClosedError: If this sender was closed. """ # pylint: disable=protected-access if self._channel._closed: raise SenderError("The channel was closed", self) from ChannelClosedError( self._channel ) + if self._channel._auto_close and ( + self._channel._sender_count == 0 or len(self._channel._receivers) == 0 + ): + await self._channel.aclose() + raise SenderError("The channel was closed", self) from ChannelClosedError( + self._channel + ) + if self._closed: + raise SenderClosedError(self) self._channel._latest = message stale_refs = [] for _hash, recv_ref in self._channel._receivers.items(): @@ -365,6 +433,27 @@ async def send(self, message: _T, /) -> None: self._channel._recv_cv.notify_all() # pylint: enable=protected-access + @override + def close(self) -> None: + """Close this sender. + + After a sender is closed, it can no longer be used to send messages. Any + attempt to send a message through a closed sender will raise a + [SenderError][frequenz.channels.SenderError]. + """ + self._closed = True + self._channel._sender_count -= 1 + + @override + def clone(self) -> _Sender[_T]: + """Return a clone of this sender.""" + return _Sender(self._channel) + + @override + def subscribe(self) -> Receiver[_T]: + """Return a new receiver attached to this sender's channel.""" + return self._channel.new_receiver() + def __str__(self) -> str: """Return a string representation of this sender.""" return f"{self._channel}:{type(self).__name__}" @@ -476,6 +565,11 @@ async def ready(self) -> bool: while len(self._q) == 0: if self._channel._closed or self._closed: return False + if self._channel._auto_close and ( + self._channel._sender_count == 0 or len(self._channel._receivers) == 0 + ): + await self._channel.aclose() + return False async with self._channel._recv_cv: await self._channel._recv_cv.wait() return True diff --git a/src/frequenz/channels/_one_shot.py b/src/frequenz/channels/_one_shot.py new file mode 100644 index 00000000..f3cb4340 --- /dev/null +++ b/src/frequenz/channels/_one_shot.py @@ -0,0 +1,90 @@ +# License: MIT +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH + +"""A channel that can send a single message.""" + +import typing +from asyncio import Condition + +from ._generic import ChannelMessageT +from ._receiver import Receiver, ReceiverStoppedError +from ._sender import Sender, SenderClosedError + + +def oneshot( + message_type: type[ChannelMessageT], # pylint: disable=unused-argument +) -> tuple[Sender[ChannelMessageT], Receiver[ChannelMessageT]]: + """Create a one-shot channel. + + A one-shot channel is a channel that can only send one message. After the first + message is sent, the sender is closed and any further attempts to send a message + will raise a `SenderClosedError`. + + Args: + message_type: The type of messages that can be sent through this channel. + + Returns: + A tuple of a sender and a receiver for this channel. + """ + channel = _OneShot[ChannelMessageT]() + return _OneShotSender(channel), _OneShotReceiver(channel) + + +class _Empty: + pass + + +_EMPTY = _Empty() + + +class _OneShot(typing.Generic[ChannelMessageT]): + """A one-shot channel. + + A one-shot channel is a channel that can only send one message. After the first + message is sent, the sender is closed and any further attempts to send a message + will raise a `SenderClosedError`. + """ + + def __init__(self) -> None: + """Create a new one-shot channel.""" + self.message: ChannelMessageT | _Empty = _EMPTY + self.sent = False + self.drained = False + self.condition = Condition() + + +class _OneShotSender(Sender[ChannelMessageT]): + def __init__(self, channel: _OneShot[ChannelMessageT]) -> None: + self._channel = channel + + async def send(self, message: ChannelMessageT, /) -> None: + if self._channel.sent: + raise SenderClosedError(self) + self._channel.message = message + self._channel.sent = True + if self._channel.condition.locked(): + self._channel.condition.notify() + + def close(self) -> None: + self._channel.sent = True + + +class _OneShotReceiver(Receiver[ChannelMessageT]): + def __init__(self, channel: _OneShot[ChannelMessageT]) -> None: + self._channel = channel + + async def ready(self) -> bool: + while not self._channel.sent: + await self._channel.condition.wait() + if self._channel.drained: + return False + return True + + def consume(self) -> ChannelMessageT: + if self._channel.drained: + raise ReceiverStoppedError(self) + if isinstance(self._channel.message, _Empty): + raise ReceiverStoppedError(self) + + self._channel.drained = True + return self._channel.message diff --git a/src/frequenz/channels/_sender.py b/src/frequenz/channels/_sender.py index e225e94c..a2a3c14f 100644 --- a/src/frequenz/channels/_sender.py +++ b/src/frequenz/channels/_sender.py @@ -49,11 +49,14 @@ ``` """ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Generic from ._exceptions import Error from ._generic import SenderMessageT_co, SenderMessageT_contra +from ._receiver import Receiver class Sender(ABC, Generic[SenderMessageT_contra]): @@ -70,6 +73,15 @@ async def send(self, message: SenderMessageT_contra, /) -> None: SenderError: If there was an error sending the message. """ + @abstractmethod + def close(self) -> None: + """Close this sender. + + After a sender is closed, it can no longer be used to send messages. Any + attempt to send a message through a closed sender will raise a + [SenderError][frequenz.channels.SenderError]. + """ + class SenderError(Error, Generic[SenderMessageT_co]): """An error that originated in a [Sender][frequenz.channels.Sender]. @@ -88,3 +100,47 @@ def __init__(self, message: str, sender: Sender[SenderMessageT_co]): super().__init__(message) self.sender: Sender[SenderMessageT_co] = sender """The sender where the error happened.""" + + +class SenderClosedError(SenderError[SenderMessageT_co]): + """An error indicating that a send operation was attempted a closed sender.""" + + def __init__(self, sender: Sender[SenderMessageT_co]): + """Initialize this error. + + Args: + sender: The [Sender][frequenz.channels.Sender] that was closed. + """ + super().__init__("Sender is closed", sender) + + +class SubscribableSender(Sender[SenderMessageT_contra], ABC): + """A [Sender][frequenz.channels.Sender] that can be subscribed to.""" + + @abstractmethod + def subscribe(self) -> Receiver[SenderMessageT_contra]: + """Subscribe to this sender. + + Returns: + A new sender that sends messages to the same channel as this sender. + """ + + +class ClonableSender(Sender[SenderMessageT_contra], ABC): + """A [Sender][frequenz.channels.Sender] that can be cloned.""" + + @abstractmethod + def clone(self) -> ClonableSender[SenderMessageT_contra]: + """Clone this sender. + + Returns: + A new sender that sends messages to the same channel as this sender. + """ + + +class ClonableSubscribableSender( + SubscribableSender[SenderMessageT_contra], + ClonableSender[SenderMessageT_contra], + ABC, +): + """A [Sender][frequenz.channels.Sender] that can be both cloned and subscribed to.""" diff --git a/src/frequenz/channels/experimental/_relay_sender.py b/src/frequenz/channels/experimental/_relay_sender.py index 398ba8d5..b173637a 100644 --- a/src/frequenz/channels/experimental/_relay_sender.py +++ b/src/frequenz/channels/experimental/_relay_sender.py @@ -7,15 +7,13 @@ to the senders it was created with. """ -import typing - from typing_extensions import override from .._generic import SenderMessageT_contra from .._sender import Sender -class RelaySender(typing.Generic[SenderMessageT_contra], Sender[SenderMessageT_contra]): +class RelaySender(Sender[SenderMessageT_contra]): """A Sender for sending messages to multiple senders. The `RelaySender` class takes multiple senders and forwards all the messages sent to @@ -57,3 +55,9 @@ async def send(self, message: SenderMessageT_contra, /) -> None: """ for sender in self._senders: await sender.send(message) + + @override + def close(self) -> None: + """Close this sender.""" + for sender in self._senders: + sender.close() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 0cc89f33..cbd81373 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -17,6 +17,7 @@ ReceiverStoppedError, Sender, SenderError, + broadcast, ) @@ -425,3 +426,50 @@ async def test_broadcast_close_receiver() -> None: with pytest.raises(ReceiverStoppedError): _ = await receiver_2.receive() + + +async def test_broadcast_auto_close_1() -> None: + """Ensure broadcast auto close works when all receivers are closed.""" + sender, receiver = broadcast(int, name="auto-close-test") + + receiver_2 = sender.subscribe() + + await sender.send(1) + + assert (await receiver.receive()) == 1 + assert (await receiver_2.receive()) == 1 + + receiver.close() + + await sender.send(2) + + assert (await receiver_2.receive()) == 2 + + receiver_2.close() + + with pytest.raises(SenderError) as excinfo: + await sender.send(3) + assert isinstance(excinfo.value.__cause__, ChannelClosedError) + + +async def test_broadcast_auto_close_2() -> None: + """Ensure broadcast auto close works when all senders are closed.""" + sender, receiver = broadcast(int, name="auto-close-test") + + await sender.send(1) + + assert (await receiver.receive()) == 1 + + sender_2 = sender.clone() + + sender.close() + + await sender_2.send(2) + + sender_2.close() + + assert (await receiver.receive()) == 2 + + with pytest.raises(ReceiverStoppedError) as excinfo: + await receiver.receive() + assert isinstance(excinfo.value.__cause__, ChannelClosedError) diff --git a/tests/test_oneshot.py b/tests/test_oneshot.py new file mode 100644 index 00000000..03462511 --- /dev/null +++ b/tests/test_oneshot.py @@ -0,0 +1,34 @@ +# License: MIT +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH + +"""Tests for the oneshot channel.""" + +import asyncio + +import pytest + +from frequenz.channels import ReceiverStoppedError, SenderClosedError, oneshot + + +async def test_oneshot() -> None: + """Test the oneshot function.""" + sender, receiver = oneshot(int) + + received: int | None = None + + async def receive_in_background() -> None: + nonlocal received + received = await receiver.receive() + + task = asyncio.create_task(receive_in_background()) + + await sender.send(42) + await task + + assert received == 42 + + with pytest.raises(SenderClosedError): + await sender.send(43) + + with pytest.raises(ReceiverStoppedError): + await receiver.receive()