Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 38 additions & 40 deletions services/streaming_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
to simplify streaming implementations across services.
"""

from typing import Optional, Any, Dict
from dataclasses import dataclass
import uuid
import json
import uuid
from dataclasses import dataclass
from typing import Any


@dataclass
Expand All @@ -32,7 +32,7 @@ class StreamManager:
manager.end_stream()
"""

def __init__(self, model: str = "claude-3-7-sonnet-20250219", stream = True):
def __init__(self, model: str = "claude-3-7-sonnet-20250219", stream: bool = True):
"""
Initialize the stream manager.

Expand All @@ -41,15 +41,15 @@ def __init__(self, model: str = "claude-3-7-sonnet-20250219", stream = True):
"""
self.stream = stream
self.model = model
self.message_id: Optional[str] = None
self.message_id: str | None = None
self.stream_started = False
self.stream_ended = False

# Track all content blocks
self.blocks: list[ContentBlock] = []
self.current_index = -1

def _emit_event(self, event_type: str, data: Dict[str, Any]) -> None:
def _emit_event(self, event_type: str, data: dict[str, Any]) -> None:
"""
Send event through bridge to be forwarded as SSE.

Expand All @@ -60,16 +60,16 @@ def _emit_event(self, event_type: str, data: Dict[str, Any]) -> None:
# Use EVENT: prefix format that bridge.ts expects
# Bridge will convert this to proper SSE format
if self.stream:
print(f"EVENT:{event_type}:{json.dumps(data)}", flush=True)
print(f"EVENT:{event_type}:{json.dumps(data)}", flush=True) # noqa: T201

def start_stream(self) -> None:
"""
Start a new stream by sending message_start event.
Should be called once at the beginning of streaming.
"""
if self.stream_started:
raise RuntimeError("Stream already started")

self.message_id = f"msg_{uuid.uuid4().hex[:24]}"
self.stream_started = True

Expand All @@ -83,59 +83,59 @@ def start_stream(self) -> None:
"model": self.model,
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0}
}
"usage": {"input_tokens": 0, "output_tokens": 0},
},
})

def send_thinking(
self,
self,
thinking_text: str,
signature: Optional[str] = "signature_filler"
signature: str | None = "signature_filler",
) -> None:
"""
Send a thinking block with the given text.
Creates a new content block, sends the thinking, and closes it.

Args:
thinking_text: The thinking content to send
signature: Optional signature to include with the thinking
"""
if self.stream_ended:
raise RuntimeError("Stream already ended")

if not self.stream_started:
self.start_stream()

self._close_open_blocks()

# Create new thinking block
index = self._next_index()
block = ContentBlock(index=index, block_type="thinking")
self.blocks.append(block)

# Send block start
self._emit_event('content_block_start', {
"type": "content_block_start",
"index": index,
"content_block": {"type": "thinking", "thinking": ""}
"content_block": {"type": "thinking", "thinking": ""},
})

# Send thinking delta
self._emit_event('content_block_delta', {
"type": "content_block_delta",
"index": index,
"delta": {"type": "thinking_delta", "thinking": thinking_text}
"delta": {"type": "thinking_delta", "thinking": thinking_text},
})

# Send an Anthropic signature string
self._emit_event('content_block_delta', {
"type": "content_block_delta",
"index": index,
"delta": {"type": "signature_delta", "signature": signature}
"delta": {"type": "signature_delta", "signature": signature},
})

self._close_block(block)

def send_text(self, text_chunk: str) -> None:
"""
Send a text chunk. Automatically manages text content block lifecycle.
Expand All @@ -150,33 +150,31 @@ def send_text(self, text_chunk: str) -> None:

# Check if we have an open text block
current_text_block = self._get_current_text_block()

if current_text_block is None:
# Create new text block
index = self._next_index()
current_text_block = ContentBlock(index=index, block_type="text")
self.blocks.append(current_text_block)

# Send block start
self._emit_event('content_block_start', {
"type": "content_block_start",
"index": index,
"content_block": {"type": "text", "text": ""}
"content_block": {"type": "text", "text": ""},
})

# Send text delta
self._emit_event('content_block_delta', {
"type": "content_block_delta",
"index": current_text_block.index,
"delta": {"type": "text_delta", "text": text_chunk}
"delta": {"type": "text_delta", "text": text_chunk},
})



def end_stream(self, stop_reason: str = "end_turn") -> None:
"""
End the stream by closing all open blocks and sending final events.
"""

if self.stream_ended or not self.stream_started:
return

Expand All @@ -187,41 +185,41 @@ def end_stream(self, stop_reason: str = "end_turn") -> None:
self._emit_event('message_delta', {
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
"usage": {"output_tokens": 0}
"usage": {"output_tokens": 0},
})

# Send message_stop
self._emit_event('message_stop', {
"type": "message_stop"
"type": "message_stop",
})

self.stream_ended = True

def _next_index(self) -> int:
"""Get the next content block index."""
self.current_index += 1
return self.current_index
def _get_current_text_block(self) -> Optional[ContentBlock]:

def _get_current_text_block(self) -> ContentBlock | None:
"""Get the currently open text block, if any."""
for block in reversed(self.blocks):
if block.is_open and block.block_type == "text":
return block
return None

def _close_block(self, block: ContentBlock) -> None:
"""Close a specific content block."""
if not block.is_open:
return

self._emit_event('content_block_stop', {
"type": "content_block_stop",
"index": block.index
"index": block.index,
})
block.is_open = False

def _close_open_blocks(self) -> None:
"""Close all currently open content blocks."""
for block in self.blocks:
if block.is_open:
self._close_block(block)
self._close_block(block)
50 changes: 26 additions & 24 deletions services/util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
import logging
import os
import sys
import requests
import psycopg2
from dataclasses import dataclass
from typing import Optional, Any, Dict
from typing import Any

import psycopg2
import requests

# Adaptor parsing constants
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine to accept this but for future reference it's really not necessary. These aren't what I call "magic numbers" and are fine to leave in-situ.

The structure of adaptor specifiers won't change any time soon and I don't particularly think these constants make the code easier to read?

@hanna-paasivirta we can accept this, unless you have a strong opinion, but this isn't a pattern I want to see enforced across the codebase

SCOPED_ADAPTOR_MIN_PARTS = 3
SHORTHAND_ADAPTOR_PARTS = 2

class DictObj:
"""
Expand All @@ -22,13 +25,13 @@ def __init__(self, in_dict: dict):
else:
setattr(self, key, DictObj(val) if isinstance(val, dict) else val)

def get(self, key):
def get(self, key: str) -> Any: # noqa: ANN401
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this # noqa comment mean?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a Ruff rule suppression comment.
It tells the linter to ignore a specific rule on that line.
So ANN401 is a Ruff rule code.

So the comment tells Ruff to ignore the ANN401 rule for this line.
So that's why it is there.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! And what is it we're ignoring here - the Any return type?

We don't use Ruff, as you can tell. I thought we were running Black or Prettier (or both?!) but I'll have to look into that.

Happy to accept code improvements but unless we're going to adopt Ruff as a standard (and I'll speak to the team about this) I don't think we should have Ruff-specific comments in the code.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for the feedback.

Yeah, we are suppressing the Any return type — that's why the comment is there, to reduce the noise when we run `ruff check`. However, the project does use both Ruff as the linter and Black as the Python formatter, with a comprehensive rule set configured in the project dependencies in pyproject.toml, including ANN (flake8-annotations).

So I would say the # noqa comments are appropriate for this codebase according to modern Python standards.
If they should not be there, I'm happy to adjust them.
And I'm also happy and open to whatever standard we decide to adopt.

return self._dict.get(key)

def has(self, key):
def has(self, key: str) -> bool:
return key in self._dict

def to_dict(self):
def to_dict(self) -> dict:
return self._dict


Expand All @@ -38,7 +41,7 @@ class ApolloError(Exception):
code: int
message: str
type: str = "APOLLO_ERROR"
details: Optional[dict[str, Any]] = None
details: dict[str, Any] | None = None

def to_dict(self) -> dict:
"""Serialize the error to a dictionary format"""
Expand All @@ -53,21 +56,21 @@ def to_dict(self) -> dict:


filename = None
loggers = {}
loggers: dict[str, logging.Logger] = {}
apollo_port = 3000


def set_log_output(f):
def set_log_output(f: str | None) -> None:
"""Set the output file for logging."""
global filename
global filename # noqa: PLW0603

if f is not None:
print(f"[entry.py] writing logs to {f}")
print(f"[entry.py] writing logs to {f}") # noqa: T201

filename = f


def create_logger(name):
def create_logger(name: str) -> logging.Logger:
"""
Create or retrieve a logger with the given name.
Logs to stdout by default.
Expand All @@ -79,26 +82,25 @@ def create_logger(name):
return loggers[name]


def set_apollo_port(p):
def set_apollo_port(p: int) -> None:
"""Set the port for Apollo services."""
global apollo_port
global apollo_port # noqa: PLW0603
apollo_port = p


def apollo(name, payload):
def apollo(name: str, payload: dict) -> dict:
"""
Call out to an Apollo service through HTTP.
:param name: Name of the service.
:param payload: Payload to send in the POST request.
:return: JSON response.
"""
global apollo_port
url = f"http://127.0.0.1:{apollo_port}/services/{name}"
r = requests.post(url, json = payload)
r = requests.post(url, json=payload)
return r.json()


def get_db_connection():
def get_db_connection() -> "psycopg2.extensions.connection":
"""Get database connection from POSTGRES_URL environment variable.

Returns:
Expand Down Expand Up @@ -138,24 +140,24 @@ def __init__(self, adaptor_input: str):

# Handle format: "@openfn/language-http@3.1.11"
if adaptor_input.startswith("@"):
if len(adaptor_parts) >= 3:
if len(adaptor_parts) >= SCOPED_ADAPTOR_MIN_PARTS:
self.name = "@" + adaptor_parts[1]
self.version = adaptor_parts[2]
else:
raise ApolloError(
400,
f"Version must be specified in adaptor string. Expected format: '@openfn/language-http@3.1.11', got: '{adaptor_input}'",
type="BAD_REQUEST"
type="BAD_REQUEST",
)
# Handle format: "http@3.1.11"
elif len(adaptor_parts) == 2:
elif len(adaptor_parts) == SHORTHAND_ADAPTOR_PARTS:
self.name = f"@openfn/language-{adaptor_parts[0]}"
self.version = adaptor_parts[1]
else:
raise ApolloError(
400,
f"Version must be specified in adaptor string. Expected format: 'http@3.1.11' or '@openfn/language-http@3.1.11', got: '{adaptor_input}'",
type="BAD_REQUEST"
type="BAD_REQUEST",
)

@property
Expand All @@ -169,7 +171,7 @@ def short_name(self) -> str:
return self.name.split("/")[-1].replace("language-", "")


def add_page_prefix(content: str, page: Optional[dict]) -> str:
def add_page_prefix(content: str, page: dict | None) -> str:
"""
Add [pg:...] prefix to message for page navigation tracking.

Expand Down