Skip to content

Commit 1c55778

Browse files
Fixing two incorrect type comparisons
1 parent bedaab0 commit 1c55778

2 files changed

Lines changed: 30 additions & 33 deletions

File tree

src/replit_river/codegen/client.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ListTypeExpr,
3131
LiteralTypeExpr,
3232
ModuleName,
33+
NoneTypeExpr,
3334
OpenUnionTypeExpr,
3435
RenderedPath,
3536
TypeExpression,
@@ -170,7 +171,7 @@ def encode_type(
170171
encoder_name: TypeName | None = None # defining this up here to placate mypy
171172
chunks: List[FileContents] = []
172173
if isinstance(type, RiverNotType):
173-
return (TypeName("None"), [], [], set())
174+
return (NoneTypeExpr(), [], [], set())
174175
elif isinstance(type, RiverUnionType):
175176
typeddict_encoder = list[str]()
176177
encoder_names: set[TypeName] = set()
@@ -379,17 +380,17 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
379380
typeddict_encoder.append(
380381
f"encode_{render_literal_type(inner_type_name)}(x)"
381382
)
382-
case DictTypeExpr(_):
383-
raise ValueError(
384-
"What does it mean to try and encode a dict in"
385-
" this position?"
386-
)
387383
case LiteralTypeExpr(const):
388384
typeddict_encoder.append(repr(const))
385+
case TypeName(value):
386+
typeddict_encoder.append(f"encode_{value}(x)")
387+
case NoneTypeExpr():
388+
typeddict_encoder.append("None")
389389
case other:
390-
typeddict_encoder.append(
391-
f"encode_{render_literal_type(other)}(x)"
390+
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = (
391+
other
392392
)
393+
raise ValueError(f"What does it mean to have {_o2} here?")
393394
if permit_unknown_members:
394395
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
395396
else:
@@ -471,7 +472,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
471472
return (TypeName("bool"), [], [], set())
472473
elif type.type == "null" or type.type == "undefined":
473474
typeddict_encoder.append("None")
474-
return (TypeName("None"), [], [], set())
475+
return (NoneTypeExpr(), [], [], set())
475476
elif type.type == "Date":
476477
typeddict_encoder.append("TODO: dstewart")
477478
return (TypeName("datetime.datetime"), [], [], set())
@@ -511,8 +512,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
511512
)
512513
case LiteralTypeExpr(const):
513514
typeddict_encoder.append(repr(const))
515+
case TypeName(value):
516+
typeddict_encoder.append(f"encode_{value}(x)")
514517
case other:
515-
typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)")
518+
_o1: NoneTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = other
519+
raise ValueError(f"What does it mean to have {_o1} here?")
516520
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
517521
assert type.type == "object", type.type
518522

@@ -823,7 +827,7 @@ def __init__(self, client: river.Client[Any]):
823827
module_names,
824828
permit_unknown_members=True,
825829
)
826-
if error_type == "None":
830+
if isinstance(error_type, NoneTypeExpr):
827831
error_type = TypeName("RiverError")
828832
else:
829833
serdes.append(
@@ -916,7 +920,7 @@ def __init__(self, client: river.Client[Any]):
916920
f"Unable to derive the input encoder from: {input_type}"
917921
)
918922

919-
if output_type == "None":
923+
if isinstance(output_type, NoneTypeExpr):
920924
parse_output_method = "lambda x: None"
921925

922926
if procedure.type == "rpc":

src/replit_river/codegen/typing.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ def __str__(self) -> str:
1717
raise Exception("Complex type must be put through render_type_expr!")
1818

1919

20+
@dataclass(frozen=True)
21+
class NoneTypeExpr:
22+
def __str__(self) -> str:
23+
raise Exception("Complex type must be put through render_type_expr!")
24+
25+
2026
@dataclass(frozen=True)
2127
class DictTypeExpr:
2228
nested: "TypeExpression"
@@ -59,6 +65,7 @@ def __str__(self) -> str:
5965

6066
TypeExpression = (
6167
TypeName
68+
| NoneTypeExpr
6269
| DictTypeExpr
6370
| ListTypeExpr
6471
| LiteralTypeExpr
@@ -86,6 +93,8 @@ def render_type_expr(value: TypeExpression) -> str:
8693
)
8794
case TypeName(name):
8895
return name
96+
case NoneTypeExpr():
97+
return "None"
8998
case other:
9099
assert_never(other)
91100

@@ -112,33 +121,17 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
112121
)
113122
case TypeName(name):
114123
return TypeName(name)
124+
case NoneTypeExpr():
125+
raise ValueError(f"Attempting to extract from a literal 'None': {value}")
115126
case other:
116127
assert_never(other)
117128

118129

119130
def ensure_literal_type(value: TypeExpression) -> TypeName:
120131
match value:
121-
case DictTypeExpr(_):
122-
raise ValueError(
123-
f"Unexpected expression when expecting a type name: {value}"
124-
)
125-
case ListTypeExpr(_):
126-
raise ValueError(
127-
f"Unexpected expression when expecting a type name: {value}"
128-
)
129-
case LiteralTypeExpr(_):
130-
raise ValueError(
131-
f"Unexpected expression when expecting a type name: {value}"
132-
)
133-
case UnionTypeExpr(_):
134-
raise ValueError(
135-
f"Unexpected expression when expecting a type name: {value}"
136-
)
137-
case OpenUnionTypeExpr(_):
138-
raise ValueError(
139-
f"Unexpected expression when expecting a type name: {value}"
140-
)
141132
case TypeName(name):
142133
return TypeName(name)
143134
case other:
144-
assert_never(other)
135+
raise ValueError(
136+
f"Unexpected expression when expecting a type name: {other}"
137+
)

0 commit comments

Comments
 (0)