Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/ezmsg/core/backendprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any

from .stream import Stream, InputStream, OutputStream
from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR
from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR
from .messagechannel import LeakyQueue

from .graphcontext import GraphContext
Expand Down Expand Up @@ -374,8 +374,6 @@ async def wrapped_task(msg: Any = None) -> None:
result = call_fn(msg)
if inspect.isasyncgen(result):
async for stream, obj in result:
if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg:
obj = deepcopy(obj)
await pub_fn(stream, obj)

elif asyncio.iscoroutine(result):
Expand Down
26 changes: 19 additions & 7 deletions src/ezmsg/core/unit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import inspect
import functools
import warnings
from .stream import InputStream, OutputStream
from .component import ComponentMeta, Component
from .settings import Settings
Expand All @@ -23,6 +24,8 @@
LEAKY_ATTR = "__ez_leaky__"
MAX_QUEUE_ATTR = "__ez_max_queue__"

_ZERO_COPY_SENTINEL = object()


class UnitMeta(ComponentMeta):
def __init__(
Expand Down Expand Up @@ -162,17 +165,18 @@ def pub_factory(func):
return pub_factory


def subscriber(stream: InputStream, zero_copy: bool = False):
def subscriber(stream: InputStream, zero_copy: Any = _ZERO_COPY_SENTINEL):
"""
A decorator for a method that subscribes to a stream in the task/messaging thread.

An async function will run once per message received from the :obj:`InputStream`
it subscribes to. A function can have both ``@subscriber`` and ``@publisher`` decorators.

The ``zero_copy`` argument is deprecated and ignored. Subscribers always receive
zero-copy messages, so callers can omit it.

:param stream: The input stream to subscribe to
:type stream: InputStream
:param zero_copy: Whether to use zero-copy message passing (default: False)
:type zero_copy: bool
:return: Decorated function that can subscribe to the stream
:rtype: collections.abc.Callable
:raises ValueError: If stream is not an InputStream
Expand All @@ -183,20 +187,28 @@ def subscriber(stream: InputStream, zero_copy: bool = False):

INPUT = ez.InputStream(Message)

@subscriber(INPUT)
async def print_message(self, message: Message) -> None:
print(message)
@subscriber(INPUT)
async def print_message(self, message: Message) -> None:
print(message)
"""

if not isinstance(stream, InputStream):
raise ValueError(f"Cannot subscribe to object of type {type(stream)}")

if zero_copy is not _ZERO_COPY_SENTINEL:
warnings.warn(
"The `zero_copy` argument to @subscriber is deprecated and ignored. "
"Subscribers always receive zero-copy messages, so remove `zero_copy=True`.",
DeprecationWarning,
stacklevel=2,
)

def sub_factory(func):
subscribed_streams: InputStream | None = getattr(func, SUBSCRIBES_ATTR, None)
if subscribed_streams is not None:
raise Exception(f"{func} cannot subscribe to more than one stream")
setattr(func, SUBSCRIBES_ATTR, stream)
setattr(func, ZERO_COPY_ATTR, zero_copy)
setattr(func, ZERO_COPY_ATTR, True)
setattr(func, LEAKY_ATTR, stream.leaky)
setattr(func, MAX_QUEUE_ATTR, stream.max_queue)
return task(func)
Expand Down
2 changes: 1 addition & 1 deletion src/ezmsg/util/debuglog.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DebugLog(ez.Unit):
OUTPUT = ez.OutputStream(Any)
"""Send messages back out to continue through the graph."""

@ez.subscriber(INPUT, zero_copy=True)
@ez.subscriber(INPUT)
@ez.publisher(OUTPUT)
async def log(self, msg: Any) -> AsyncGenerator:
logstr = f"{self.SETTINGS.name} - {msg=}"
Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/util/messages/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def on_settings(self, msg: ez.Settings) -> None:
self.apply_settings(msg)
self.construct_generator()

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.subscriber(INPUT_SIGNAL)
@ez.publisher(OUTPUT_SIGNAL)
async def on_message(self, message: AxisArray) -> AsyncGenerator:
"""
Expand Down Expand Up @@ -125,7 +125,7 @@ class FilterOnKey(ez.Unit):
INPUT_SIGNAL = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.subscriber(INPUT_SIGNAL)
@ez.publisher(OUTPUT_SIGNAL)
async def on_message(self, message: AxisArray) -> AsyncGenerator:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/ezmsg/util/messages/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def on_settings(self, msg: ez.Settings) -> None:
self.apply_settings(msg)
self._transformer = ModifyAxisTransformer(self.SETTINGS)

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.subscriber(INPUT_SIGNAL)
@ez.publisher(OUTPUT_SIGNAL)
async def on_message(self, message: AxisArray) -> AsyncGenerator:
ret = self._transformer(message)
Expand Down
6 changes: 3 additions & 3 deletions src/ezmsg/util/perf/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class LoadTestRelay(ez.Unit):
INPUT = ez.InputStream(LoadTestSample)
OUTPUT = ez.OutputStream(LoadTestSample)

@ez.subscriber(INPUT, zero_copy=True)
@ez.subscriber(INPUT)
@ez.publisher(OUTPUT)
async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator:
yield self.OUTPUT, msg
Expand All @@ -152,7 +152,7 @@ class LoadTestReceiver(ez.Unit):
async def initialize(self) -> None:
ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})")

@ez.subscriber(INPUT, zero_copy=True)
@ez.subscriber(INPUT)
async def receive(self, sample: LoadTestSample) -> None:
counter = self.STATE.counters.get(sample.key, -1)
if sample.counter != counter + 1:
Expand All @@ -166,7 +166,7 @@ async def receive(self, sample: LoadTestSample) -> None:
class LoadTestSink(LoadTestReceiver):
INPUT = ez.InputStream(LoadTestSample)

@ez.subscriber(INPUT, zero_copy=True)
@ez.subscriber(INPUT)
async def receive(self, sample: LoadTestSample) -> None:
await super().receive(sample)
if len(self.STATE.received_data) == self.SETTINGS.num_msgs:
Expand Down