Skip to content

Commit 875be6a

Browse files
committed
Fix mypy type handling in partial function and add regression tests
1 parent ff4b2d4 commit 875be6a

6 files changed

Lines changed: 190 additions & 83 deletions

File tree

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python
12
# Configuration file for the Sphinx documentation builder.
23
#
34
# This file does only contain a selection of the most common options. For a

returns/contrib/mypy/_features/partial.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from mypy.nodes import ARG_STAR, ARG_STAR2
66
from mypy.plugin import FunctionContext
77
from mypy.types import (
8+
AnyType,
89
CallableType,
910
FunctionLike,
1011
Instance,
1112
Overloaded,
1213
ProperType,
14+
TypeOfAny,
1315
TypeType,
1416
get_proper_type,
1517
)
@@ -51,30 +53,55 @@ def analyze(ctx: FunctionContext) -> ProperType:
5153
default_return = get_proper_type(ctx.default_return_type)
5254
if not isinstance(default_return, CallableType):
5355
return default_return
56+
return _analyze_partial(ctx, default_return)
57+
58+
59+
def _analyze_partial(
60+
ctx: FunctionContext,
61+
default_return: CallableType,
62+
) -> ProperType:
63+
if not ctx.arg_types or not ctx.arg_types[0]:
64+
# No function passed: treat as decorator factory and fallback to Any.
65+
return AnyType(TypeOfAny.implementation_artifact)
5466

5567
function_def = get_proper_type(ctx.arg_types[0][0])
5668
func_args = _AppliedArgs(ctx)
5769

5870
if len(list(filter(len, ctx.arg_types))) == 1:
5971
return function_def # this means, that `partial(func)` is called
60-
if not isinstance(function_def, _SUPPORTED_TYPES):
72+
callable_def = _coerce_to_callable(function_def, func_args)
73+
if callable_def is None:
6174
return default_return
62-
if isinstance(function_def, Instance | TypeType):
63-
# We force `Instance` and similar types to coercse to callable:
64-
function_def = func_args.get_callable_from_context()
6575

6676
is_valid, applied_args = func_args.build_from_context()
67-
if not isinstance(function_def, CallableType | Overloaded) or not is_valid:
77+
if not is_valid:
6878
return default_return
6979

7080
return _PartialFunctionReducer(
7181
default_return,
72-
function_def,
82+
callable_def,
7383
applied_args,
7484
ctx,
7585
).new_partial()
7686

7787

88+
def _coerce_to_callable(
89+
function_def: ProperType,
90+
func_args: '_AppliedArgs',
91+
) -> CallableType | Overloaded | None:
92+
if not isinstance(function_def, _SUPPORTED_TYPES):
93+
return None
94+
if isinstance(function_def, Instance | TypeType):
95+
# We force `Instance` and similar types to coerce to callable:
96+
from_context = func_args.get_callable_from_context()
97+
return (
98+
from_context
99+
if isinstance(from_context, CallableType | Overloaded)
100+
else None
101+
)
102+
return function_def
103+
104+
78105
@final
79106
class _PartialFunctionReducer:
80107
"""

returns/curry.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,41 @@
22
from functools import partial as _partial
33
from functools import wraps
44
from inspect import BoundArguments, Signature
5-
from typing import Any, TypeAlias, TypeVar
5+
from typing import Any, Generic, TypeAlias, TypeVar, overload
66

77
_ReturnType = TypeVar('_ReturnType')
8+
_Decorator: TypeAlias = Callable[
9+
[Callable[..., _ReturnType]],
10+
Callable[..., _ReturnType],
11+
]
812

913

14+
class _PartialDecorator(Generic[_ReturnType]):
15+
"""Wraps ``functools.partial`` into a decorator without nesting."""
16+
__slots__ = ('_args', '_kwargs')
17+
18+
def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
19+
self._args = args
20+
self._kwargs = kwargs
21+
22+
def __call__(self, inner: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
23+
return _partial(inner, *self._args, **self._kwargs)
24+
25+
26+
@overload
1027
def partial(
1128
func: Callable[..., _ReturnType],
29+
/,
1230
*args: Any,
1331
**kwargs: Any,
14-
) -> Callable[..., _ReturnType]:
32+
) -> Callable[..., _ReturnType]: ...
33+
34+
35+
@overload
36+
def partial(*args: Any, **kwargs: Any) -> _Decorator: ...
37+
38+
39+
def partial(*args: Any, **kwargs: Any) -> Any:
1540
"""
1641
Typed partial application.
1742
@@ -35,7 +60,11 @@ def partial(
3560
- https://docs.python.org/3/library/functools.html#functools.partial
3661
3762
"""
38-
return _partial(func, *args, **kwargs)
63+
if args and callable(args[0]):
64+
return _partial(args[0], *args[1:], **kwargs)
65+
if args and args[0] is None:
66+
args = args[1:]
67+
return _PartialDecorator(args, kwargs)
3968

4069

4170
def curry(function: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:

tests/test_curry/test_partial.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Callable, TypeAlias, TypeVar, cast
2+
3+
from returns.curry import partial
4+
5+
_ReturnType = TypeVar('_ReturnType')
6+
_Decorator: TypeAlias = Callable[
7+
[Callable[..., _ReturnType]],
8+
Callable[..., _ReturnType],
9+
]
10+
11+
12+
def add(first: int, second: int) -> int:
13+
return first + second
14+
15+
16+
def test_partial_direct_call() -> None:
17+
add_one = partial(add, 1)
18+
assert add_one(2) == 3
19+
20+
21+
def test_partial_as_decorator_factory() -> None:
22+
decorator = cast(_Decorator[int], partial())
23+
add_with_decorator = decorator(add)
24+
assert add_with_decorator(1, 2) == 3
25+
26+
27+
def test_partial_with_none_placeholder() -> None:
28+
decorator = cast(_Decorator[int], partial(None, 1))
29+
add_with_none_decorator = decorator(add)
30+
assert add_with_none_decorator(2) == 3

typesafety/test_curry/test_partial/test_partial.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

typesafety/test_curry/test_partial/test_partial.yml

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,97 @@
150150
function: Callable[[_SecondType, _FirstType], _SecondType],
151151
):
152152
reveal_type(partial(function, default)) # N: Revealed type is "def (_FirstType`-2) -> _SecondType`-1"
153+
154+
155+
- case: partial_regression1711
156+
disable_cache: false
157+
main: |
158+
from returns.curry import partial
159+
160+
def foo(x: int, y: int, z: int) -> int:
161+
...
162+
163+
def bar(x: int) -> int:
164+
...
165+
166+
baz = partial(foo, bar(1))
167+
reveal_type(baz) # N: Revealed type is "def (y: builtins.int, z: builtins.int) -> builtins.int"
168+
169+
170+
- case: partial_optional_arg
171+
disable_cache: false
172+
main: |
173+
from returns.curry import partial
174+
175+
def test_partial_fn(
176+
first_arg: int,
177+
optional_arg: str | None,
178+
) -> tuple[int, str | None]:
179+
...
180+
181+
bound = partial(test_partial_fn, 1)
182+
reveal_type(bound) # N: Revealed type is "def (optional_arg: builtins.str | None) -> tuple[builtins.int, builtins.str | None]"
183+
184+
185+
- case: partial_decorator
186+
disable_cache: false
187+
main: |
188+
from returns.curry import partial
189+
190+
@partial(first=1)
191+
def _decorated(first: int, second: str) -> float:
192+
...
193+
194+
reveal_type(_decorated) # N: Revealed type is "Any"
195+
out: |
196+
main:3: error: Untyped decorator makes function "_decorated" untyped [misc]
197+
198+
199+
- case: partial_keyword_arg
200+
disable_cache: false
201+
main: |
202+
from returns.curry import partial
203+
204+
def test_partial_fn(
205+
first_arg: int,
206+
optional_arg: str | None,
207+
) -> tuple[int, str | None]:
208+
...
209+
210+
bound = partial(test_partial_fn, optional_arg='a')
211+
reveal_type(bound) # N: Revealed type is "def (first_arg: builtins.int) -> tuple[builtins.int, builtins.str | None]"
212+
213+
214+
- case: partial_keyword_only
215+
disable_cache: false
216+
main: |
217+
from returns.curry import partial
218+
219+
def _target(*, arg: int) -> int:
220+
...
221+
222+
bound = partial(_target, arg=1)
223+
reveal_type(bound) # N: Revealed type is "def () -> builtins.int"
224+
225+
226+
- case: partial_keyword_mixed
227+
disable_cache: false
228+
main: |
229+
from returns.curry import partial
230+
231+
def _target(arg1: int, *, arg2: int) -> int:
232+
...
233+
234+
bound = partial(_target, arg2=1)
235+
reveal_type(bound) # N: Revealed type is "def (arg1: builtins.int) -> builtins.int"
236+
237+
238+
- case: partial_wrong_signature_any
239+
disable_cache: false
240+
main: |
241+
from returns.curry import partial
242+
243+
reveal_type(partial(len, 1))
244+
out: |
245+
main:3: error: Argument 1 to "len" has incompatible type "int"; expected "Sized" [arg-type]
246+
main:3: note: Revealed type is "def (*Any, **Any) -> builtins.int"

0 commit comments

Comments
 (0)