Skip to content
5 changes: 2 additions & 3 deletions src/replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
)

from .rpc import (
ErrorType,
InitType,
RequestType,
ResponseType,
Expand Down Expand Up @@ -129,7 +128,7 @@ async def send_subscription(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], Any],
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
Comment thread
airportyh marked this conversation as resolved.
with _trace_procedure(
"subscription", service_name, procedure_name
) as span_handle:
Expand Down Expand Up @@ -157,7 +156,7 @@ async def send_stream(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], Any],
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
session = await self._transport.get_or_create_session()
async for msg in session.send_stream(
Expand Down
40 changes: 23 additions & 17 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
RiverErrorTypeAdapter = TypeAdapter(RiverError)
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river

"""
Expand Down Expand Up @@ -763,26 +762,25 @@ def generate_individual_service(
schema: RiverService,
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []

def append_type_adapter_definition(
type_adapter_name: TypeName,
_type: TypeExpression,
module_info: list[ModuleName],
) -> None:
rendered_type_expr = render_type_expr(_type)
var_name = render_type_expr(type_adapter_name)
var_type = f"TypeAdapter[{rendered_type_expr}]"
var_value = f"TypeAdapter({rendered_type_expr})"
serdes.append(
(
[type_adapter_name],
module_info,
[
FileContents(
f"{type_adapter_name.value}: TypeAdapter[{rendered_type_expr}] = "
f"TypeAdapter({rendered_type_expr})"
)
],
[FileContents(f"{var_name}: {var_type} = {var_value}")],
Comment thread
airportyh marked this conversation as resolved.
Outdated
)
)

serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
class_name = ClassName(f"{schema_name.title()}Service")
current_chunks: List[str] = [
dedent(
Expand Down Expand Up @@ -819,7 +817,9 @@ def __init__(self, client: river.Client[Any]):
permit_unknown_members=False,
)
input_type_name = extract_inner_type(input_type)
input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter")
input_type_type_adapter_name = TypeName(
f"{render_literal_type(input_type_name)}TypeAdapter"
)
serdes.append(
(
[extract_inner_type(input_type), *encoder_names],
Expand All @@ -845,7 +845,9 @@ def __init__(self, client: river.Client[Any]):
output_chunks,
)
)
output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter")
output_type_type_adapter_name = TypeName(
f"{render_literal_type(output_type_name)}TypeAdapter"
)
append_type_adapter_definition(
output_type_type_adapter_name, output_type, module_info
)
Expand All @@ -869,7 +871,9 @@ def __init__(self, client: river.Client[Any]):
error_type_name = TypeName("RiverError")
error_type = error_type_name

error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter")
error_type_type_adapter_name = TypeName(
f"{render_literal_type(error_type_name)}TypeAdapter"
)
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
if len(module_info) == 0:
module_info = output_module_info
Expand All @@ -882,14 +886,16 @@ def __init__(self, client: river.Client[Any]):
# the function strings in the branches below, otherwise `dedent`
# will pick our indentation level for normalization, which will
# break the "def" indentation presuppositions.
ottd_name = render_literal_type(output_type_type_adapter_name)
parse_output_method = f"""\
lambda x: {output_type_type_adapter_name.value}
lambda x: {ottd_name}
Comment thread
airportyh marked this conversation as resolved.
Outdated
.validate_python(
x # type: ignore[arg-type]
)
"""
ettd_name = render_literal_type(error_type_type_adapter_name)
parse_error_method = f"""\
lambda x: {error_type_type_adapter_name.value}
lambda x: {ettd_name}
.validate_python(
x # type: ignore[arg-type]
)
Expand Down Expand Up @@ -920,8 +926,8 @@ def __init__(self, client: river.Client[Any]):
init_type_type_adapter_name, init_type, module_info
)
render_init_method = f"""\
lambda x: {init_type_type_adapter_name.value})
.validate_python
lambda x: {render_type_expr(init_type_type_adapter_name)}
.validate_python
"""

assert init_type is None or render_init_method, (
Expand All @@ -947,7 +953,7 @@ def __init__(self, client: river.Client[Any]):
render_input_method = f"encode_{render_literal_type(input_type)}"
else:
render_input_method = f"""\
lambda x: {input_type_type_adapter_name.value}
lambda x: {render_type_expr(input_type_type_adapter_name)}
.dump_python(
x, # type: ignore[arg-type]
by_alias=True,
Expand Down
5 changes: 4 additions & 1 deletion src/replit_river/error_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter

ERROR_CODE_STREAM_CLOSED = "stream_closed"
ERROR_HANDSHAKE = "handshake_failed"
Expand All @@ -25,6 +25,9 @@ class RiverError(BaseModel):
message: str


RiverErrorTypeAdapter = TypeAdapter(RiverError)


class RiverException(Exception):
"""Exception raised by the River server."""

Expand Down
4 changes: 1 addition & 3 deletions tests/codegen/rpc/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError

RiverErrorTypeAdapter = TypeAdapter(RiverError)
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError

RiverErrorTypeAdapter = TypeAdapter(RiverError)
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


Expand Down
4 changes: 1 addition & 3 deletions tests/codegen/stream/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError

RiverErrorTypeAdapter = TypeAdapter(RiverError)
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


Expand Down