Skip to content
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.11
3.12
4 changes: 2 additions & 2 deletions .replit
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
run = "poetry run pytest tests"

modules = ["python-3.11"]
modules = ["python-3.12"]

[nix]
channel = "stable-23_11"
channel = "stable-24_11"
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib";
};
packages = replitNixDeps ++ [
pkgs.python311
pkgs.python312
pkgs.uv
];
};
Expand Down
74 changes: 61 additions & 13 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
RiverErrorTypeAdapter = TypeAdapter(RiverError)
import replit_river as river

"""
Comment thread
airportyh marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -762,6 +763,25 @@ def generate_individual_service(
schema: RiverService,
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
def append_type_adapter_definition(
type_adapter_name: TypeName,
_type: TypeExpression,
module_info: list[ModuleName],
) -> None:
serdes.append(
(
[type_adapter_name],
module_info,
[
FileContents(
f"{type_adapter_name.value} = "
f"TypeAdapter({render_type_expr(_type)})"
" # type: ignore"
Comment thread
airportyh marked this conversation as resolved.
Outdated
)
],
)
)

serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
class_name = ClassName(f"{schema_name.title()}Service")
current_chunks: List[str] = [
Expand Down Expand Up @@ -798,27 +818,38 @@ def __init__(self, client: river.Client[Any]):
module_names,
permit_unknown_members=False,
)
input_type_name = extract_inner_type(input_type)
input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter")
Comment thread
airportyh marked this conversation as resolved.
Outdated
serdes.append(
(
[extract_inner_type(input_type), *encoder_names],
module_info,
input_chunks,
)
)
append_type_adapter_definition(
input_type_type_adapter_name, input_type, module_info
)
output_type, module_info, output_chunks, encoder_names = encode_type(
procedure.output,
TypeName(f"{name.title()}Output"),
"BaseModel",
module_names,
permit_unknown_members=True,
)
output_type_name = extract_inner_type(output_type)
serdes.append(
(
[extract_inner_type(output_type), *encoder_names],
[output_type_name, *encoder_names],
module_info,
output_chunks,
)
)
output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter")
Comment thread
airportyh marked this conversation as resolved.
Outdated
append_type_adapter_definition(
output_type_type_adapter_name, output_type, module_info
)
output_module_info = module_info
if procedure.errors:
error_type, module_info, errors_chunks, encoder_names = encode_type(
procedure.errors,
Expand All @@ -828,27 +859,37 @@ def __init__(self, client: river.Client[Any]):
permit_unknown_members=True,
)
if isinstance(error_type, NoneTypeExpr):
error_type = TypeName("RiverError")
error_type_name = TypeName("RiverError")
error_type = error_type_name
else:
serdes.append(
([extract_inner_type(error_type)], module_info, errors_chunks)
)
error_type_name = extract_inner_type(error_type)
serdes.append(([error_type_name], module_info, errors_chunks))

else:
error_type = TypeName("RiverError")
output_or_error_type = UnionTypeExpr([output_type, error_type])
error_type_name = TypeName("RiverError")
error_type = error_type_name

error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter")
Comment thread
airportyh marked this conversation as resolved.
Outdated
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
if len(module_info) == 0:
module_info = output_module_info
append_type_adapter_definition(
error_type_type_adapter_name, error_type, module_info
)
output_or_error_type = UnionTypeExpr([output_type, error_type_name])

# NB: These strings must be indented to at least the same level of
# the function strings in the branches below, otherwise `dedent`
# will pick our indentation level for normalization, which will
# break the "def" indentation presuppositions.
parse_output_method = f"""\
lambda x: TypeAdapter({render_type_expr(output_type)})
lambda x: {output_type_type_adapter_name.value}
Comment thread
airportyh marked this conversation as resolved.
Outdated
.validate_python(
x # type: ignore[arg-type]
)
"""
parse_error_method = f"""\
lambda x: TypeAdapter({render_type_expr(error_type)})
lambda x: {error_type_type_adapter_name.value}
Comment thread
airportyh marked this conversation as resolved.
Outdated
.validate_python(
x # type: ignore[arg-type]
)
Expand All @@ -871,8 +912,15 @@ def __init__(self, client: river.Client[Any]):
else:
render_init_method = f"encode_{render_literal_type(init_type)}"
else:
init_type_name = extract_inner_type(init_type)
init_type_type_adapter_name = TypeName(
f"{init_type_name.value}TypeAdapter"
Comment thread
blast-hardcheese marked this conversation as resolved.
)
append_type_adapter_definition(
init_type_type_adapter_name, init_type, module_info
)
render_init_method = f"""\
lambda x: TypeAdapter({render_type_expr(init_type)})
lambda x: {init_type_type_adapter_name.value})
Comment thread
airportyh marked this conversation as resolved.
Outdated
.validate_python
"""

Expand All @@ -889,17 +937,17 @@ def __init__(self, client: river.Client[Any]):
procedure.input, RiverConcreteType
) and procedure.input.type in ["array"]:
match input_type:
case ListTypeExpr(input_type_name):
case ListTypeExpr(list_type):
render_input_method = f"""\
lambda xs: [
encode_{render_literal_type(input_type_name)}(x) for x in xs
encode_{render_literal_type(list_type)}(x) for x in xs
]
"""
else:
render_input_method = f"encode_{render_literal_type(input_type)}"
else:
render_input_method = f"""\
lambda x: TypeAdapter({render_type_expr(input_type)})
lambda x: {input_type_type_adapter_name.value}
Comment thread
airportyh marked this conversation as resolved.
Outdated
.dump_python(
x, # type: ignore[arg-type]
by_alias=True,
Expand Down
14 changes: 11 additions & 3 deletions tests/codegen/rpc/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
from pydantic import TypeAdapter

from replit_river.error_schema import RiverError

RiverErrorTypeAdapter = TypeAdapter(RiverError)
import replit_river as river


from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput
from .rpc_method import (
Rpc_MethodInput,
Rpc_MethodInputTypeAdapter,
Rpc_MethodOutput,
Rpc_MethodOutputTypeAdapter,
encode_Rpc_MethodInput,
)


class Test_ServiceService:
Expand All @@ -26,10 +34,10 @@ async def rpc_method(
"rpc_method",
input,
encode_Rpc_MethodInput,
lambda x: TypeAdapter(Rpc_MethodOutput).validate_python(
lambda x: Rpc_MethodOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(RiverError).validate_python(
lambda x: RiverErrorTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
Expand Down
6 changes: 6 additions & 0 deletions tests/codegen/rpc/generated/test_service/rpc_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,11 @@ class Rpc_MethodInput(TypedDict):
data: str


Rpc_MethodInputTypeAdapter = TypeAdapter(Rpc_MethodInput) # type: ignore


class Rpc_MethodOutput(BaseModel):
data: str


Rpc_MethodOutputTypeAdapter = TypeAdapter(Rpc_MethodOutput) # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,27 @@
from pydantic import TypeAdapter

from replit_river.error_schema import RiverError

RiverErrorTypeAdapter = TypeAdapter(RiverError)
import replit_river as river


from .needsEnum import (
NeedsenumErrors,
NeedsenumErrorsTypeAdapter,
NeedsenumInput,
NeedsenumInputTypeAdapter,
NeedsenumOutput,
NeedsenumOutputTypeAdapter,
encode_NeedsenumInput,
)
from .needsEnumObject import (
NeedsenumobjectErrors,
NeedsenumobjectErrorsTypeAdapter,
NeedsenumobjectInput,
NeedsenumobjectInputTypeAdapter,
NeedsenumobjectOutput,
NeedsenumobjectOutputTypeAdapter,
encode_NeedsenumobjectInput,
)

Expand All @@ -37,10 +45,10 @@ async def needsEnum(
"needsEnum",
input,
lambda x: x,
lambda x: TypeAdapter(NeedsenumOutput).validate_python(
lambda x: NeedsenumOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(NeedsenumErrors).validate_python(
lambda x: NeedsenumErrorsTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
Expand All @@ -56,10 +64,10 @@ async def needsEnumObject(
"needsEnumObject",
input,
encode_NeedsenumobjectInput,
lambda x: TypeAdapter(NeedsenumobjectOutput).validate_python(
lambda x: NeedsenumobjectOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(NeedsenumobjectErrors).validate_python(
lambda x: NeedsenumobjectErrorsTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@

NeedsenumInput = Literal["in_first"] | Literal["in_second"]
encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x
NeedsenumInputTypeAdapter = TypeAdapter(NeedsenumInput) # type: ignore
NeedsenumOutput = Annotated[
Literal["out_first"] | Literal["out_second"] | RiverUnknownValue,
WrapValidator(translate_unknown_value),
]
NeedsenumOutputTypeAdapter = TypeAdapter(NeedsenumOutput) # type: ignore
NeedsenumErrors = Annotated[
Literal["err_first"] | Literal["err_second"] | RiverUnknownValue,
WrapValidator(translate_unknown_value),
]
NeedsenumErrorsTypeAdapter = TypeAdapter(NeedsenumErrors) # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict):
if x["kind"] == "in_first"
else encode_NeedsenumobjectInputOneOf_in_second(x)
)
NeedsenumobjectInputTypeAdapter = TypeAdapter(NeedsenumobjectInput) # type: ignore


class NeedsenumobjectOutputFooOneOf_out_first(BaseModel):
Expand Down Expand Up @@ -103,6 +104,9 @@ class NeedsenumobjectOutput(BaseModel):
foo: Optional[NeedsenumobjectOutputFoo] = None


NeedsenumobjectOutputTypeAdapter = TypeAdapter(NeedsenumobjectOutput) # type: ignore


class NeedsenumobjectErrorsFooAnyOf_0(RiverError):
beep: Optional[Literal["err_first"]] = None

Expand All @@ -121,3 +125,6 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError):

class NeedsenumobjectErrors(RiverError):
foo: Optional[NeedsenumobjectErrorsFoo] = None


NeedsenumobjectErrorsTypeAdapter = TypeAdapter(NeedsenumobjectErrors) # type: ignore
8 changes: 6 additions & 2 deletions tests/codegen/stream/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
from pydantic import TypeAdapter

from replit_river.error_schema import RiverError

RiverErrorTypeAdapter = TypeAdapter(RiverError)
import replit_river as river


from .stream_method import (
Stream_MethodInput,
Stream_MethodInputTypeAdapter,
Stream_MethodOutput,
Stream_MethodOutputTypeAdapter,
encode_Stream_MethodInput,
)

Expand All @@ -31,10 +35,10 @@ async def stream_method(
inputStream,
None,
encode_Stream_MethodInput,
lambda x: TypeAdapter(Stream_MethodOutput).validate_python(
lambda x: Stream_MethodOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(RiverError).validate_python(
lambda x: RiverErrorTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,11 @@ class Stream_MethodInput(TypedDict):
data: str


Stream_MethodInputTypeAdapter = TypeAdapter(Stream_MethodInput) # type: ignore


class Stream_MethodOutput(BaseModel):
data: str


Stream_MethodOutputTypeAdapter = TypeAdapter(Stream_MethodOutput) # type: ignore
Loading