Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

CHANGED

- Allow entities with custom names

## v1.2.0

ADDED:
Expand Down
8 changes: 8 additions & 0 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,14 @@ def retry_timeout(self) -> Optional[timedelta]:
return self._retry_timeout


def get_entity_name(fn: Entity) -> str:
if hasattr(fn, "__durable_entity_name__"):
return getattr(fn, "__durable_entity_name__")
if isinstance(fn, type) and issubclass(fn, DurableEntity):
return fn.__name__
return get_name(fn)


def get_name(fn: Callable) -> str:
"""Returns the name of the provided function"""
name = fn.__name__
Expand Down
16 changes: 7 additions & 9 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,14 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None:
def get_activity(self, name: str) -> Optional[task.Activity]:
return self.activities.get(name)

def add_entity(self, fn: task.Entity) -> str:
def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
if fn is None:
raise ValueError("An entity function argument is required.")

if isinstance(fn, type) and issubclass(fn, DurableEntity):
name = fn.__name__
self.add_named_entity(name, fn)
else:
name = task.get_name(fn)
self.add_named_entity(name, fn)
if name is None:
name = task.get_entity_name(fn)

self.add_named_entity(name, fn)
return name

def add_named_entity(self, name: str, fn: task.Entity) -> None:
Expand Down Expand Up @@ -378,13 +376,13 @@ def add_activity(self, fn: task.Activity) -> str:
)
return self._registry.add_activity(fn)

def add_entity(self, fn: task.Entity) -> str:
def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
"""Registers an entity function with the worker."""
if self._is_running:
raise RuntimeError(
"Entities cannot be added while the worker is running."
)
return self._registry.add_entity(fn)
return self._registry.add_entity(fn, name)

def use_versioning(self, version: VersioningOptions) -> None:
"""Initializes versioning options for sub-orchestrators and activities."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")


def test_client_signal_class_entity():
def test_client_signal_class_entity_and_custom_name():
invoked = False

class EmptyEntity(entities.DurableEntity):
Expand All @@ -28,12 +28,12 @@ def do_nothing(self, _):
# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_entity(EmptyEntity)
w.add_entity(EmptyEntity, name="EntityNameCustom")
w.start()

c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None)
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed

Expand Down Expand Up @@ -70,7 +70,7 @@ def do_nothing(self, _):
assert invoked


def test_orchestration_signal_class_entity():
def test_orchestration_signal_class_entity_and_custom_name():
invoked = False

class EmptyEntity(entities.DurableEntity):
Expand All @@ -79,14 +79,14 @@ def do_nothing(self, _):
invoked = True

def empty_orchestrator(ctx: task.OrchestrationContext, _):
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
ctx.signal_entity(entity_id, "do_nothing")

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_orchestrator(empty_orchestrator)
w.add_entity(EmptyEntity)
w.add_entity(EmptyEntity, name="EntityNameCustom")
w.start()

c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")


def test_client_signal_entity():
def test_client_signal_entity_and_custom_name():
invoked = False

def empty_entity(ctx: entities.EntityContext, _):
Expand All @@ -28,12 +28,12 @@ def empty_entity(ctx: entities.EntityContext, _):
# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_entity(empty_entity)
w.add_entity(empty_entity, name="EntityNameCustom")
w.start()

c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None)
entity_id = entities.EntityInstanceId("empty_entity", "testEntity")
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed

Expand Down Expand Up @@ -70,7 +70,7 @@ def empty_entity(ctx: entities.EntityContext, _):
assert invoked


def test_orchestration_signal_entity():
def test_orchestration_signal_entity_and_custom_name():
invoked = False

def empty_entity(ctx: entities.EntityContext, _):
Expand All @@ -79,14 +79,14 @@ def empty_entity(ctx: entities.EntityContext, _):
invoked = True

def empty_orchestrator(ctx: task.OrchestrationContext, _):
entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity")
entity_id = entities.EntityInstanceId("EntityNameCustom", f"{ctx.instance_id}_testEntity")
ctx.signal_entity(entity_id, "do_nothing")

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_orchestrator(empty_orchestrator)
w.add_entity(empty_entity)
w.add_entity(empty_entity, name="EntityNameCustom")
w.start()

c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
Expand Down
Loading