diff --git a/eng/scripts/generate_exceptions.py b/eng/scripts/generate_exceptions.py index b57a5d82f..570caeb93 100644 --- a/eng/scripts/generate_exceptions.py +++ b/eng/scripts/generate_exceptions.py @@ -85,6 +85,7 @@ def MakeNewException(self): ExceptionInfo('Exception', 'IronPython.Runtime.Exceptions.PythonException', None, (), ( ExceptionInfo('StopIteration', 'IronPython.Runtime.Exceptions.StopIterationException', None, ('value',), ()), ExceptionInfo('StopAsyncIteration', 'IronPython.Runtime.Exceptions.StopAsyncIterationException', None, ('value',), ()), + ExceptionInfo('CancelledError', 'System.OperationCanceledException', None, (), ()), ExceptionInfo('ArithmeticError', 'System.ArithmeticException', None, (), ( ExceptionInfo('FloatingPointError', 'IronPython.Runtime.Exceptions.FloatingPointException', None, (), ()), ExceptionInfo('OverflowError', 'System.OverflowException', None, (), ()), @@ -261,7 +262,13 @@ def gen_topython_helper(cw): cw.exit_block() +_clr_name_overrides = { + 'CancelledError': 'OperationCanceledException', +} + def get_clr_name(e): + if e in _clr_name_overrides: + return _clr_name_overrides[e] return e.replace('Error', '') + 'Exception' FACTORY = """ @@ -269,8 +276,12 @@ def get_clr_name(e): public static Exception %(name)s(string format, params object?[] args) => new %(clrname)s(string.Format(format, args)); """.rstrip() +# Exceptions that map to existing CLR types (no generated CLR class needed), +# but still need factory methods in PythonOps. +_factory_only_exceptions = ['CancelledError'] + def factory_gen(cw): - for e in pythonExcs: + for e in pythonExcs + _factory_only_exceptions: cw.write(FACTORY, name=e, clrname=get_clr_name(e)) CLASS1 = """\ diff --git a/eng/scripts/generate_ops.py b/eng/scripts/generate_ops.py index 7805590aa..8f86d7765 100755 --- a/eng/scripts/generate_ops.py +++ b/eng/scripts/generate_ops.py @@ -10,7 +10,7 @@ kwlist = [ 'and', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', - 'raise', 'return', 'try', 'while', 'yield', 'as', 'with', 'async', 'nonlocal' + 'raise', 'return', 'try', 'while', 'yield', 'as', 'with', 'async', 'nonlocal', 'await' ] class Symbol: diff --git a/src/core/IronPython.StdLib/lib/test/exception_hierarchy.txt b/src/core/IronPython.StdLib/lib/test/exception_hierarchy.txt index 763a6c899..2b57a851a 100644 --- a/src/core/IronPython.StdLib/lib/test/exception_hierarchy.txt +++ b/src/core/IronPython.StdLib/lib/test/exception_hierarchy.txt @@ -5,6 +5,7 @@ BaseException +-- Exception +-- StopIteration +-- StopAsyncIteration + +-- CancelledError +-- ArithmeticError | +-- FloatingPointError | +-- OverflowError diff --git a/src/core/IronPython/Compiler/Ast/AstMethods.cs b/src/core/IronPython/Compiler/Ast/AstMethods.cs index da27a5ab0..611bf3bb1 100644 --- a/src/core/IronPython/Compiler/Ast/AstMethods.cs +++ b/src/core/IronPython/Compiler/Ast/AstMethods.cs @@ -79,6 +79,7 @@ internal static class AstMethods { public static readonly MethodInfo PushFrame = GetMethod((Func>)PythonOps.PushFrame); public static readonly MethodInfo FormatString = GetMethod((Func)PythonOps.FormatString); public static readonly MethodInfo GeneratorCheckThrowableAndReturnSendValue = GetMethod((Func)PythonOps.GeneratorCheckThrowableAndReturnSendValue); + public static readonly MethodInfo MakeCoroutine = GetMethod((Func)PythonOps.MakeCoroutine); // builtins public static readonly MethodInfo Format = GetMethod((Func)PythonOps.Format); diff --git a/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs b/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs new file mode 100644 index 000000000..16def8c82 --- /dev/null +++ b/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs @@ -0,0 +1,146 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System.Threading; + +using Microsoft.Scripting; +using MSAst = System.Linq.Expressions; + +namespace IronPython.Compiler.Ast { + + /// + /// Represents an async for statement. + /// Desugared to Python AST that uses __aiter__ and await __anext__(). + /// + public class AsyncForStatement : Statement, ILoopStatement { + private static int _counter; + private Statement? _desugared; + + public AsyncForStatement(Expression left, Expression list, Statement body, Statement? @else) { + Left = left; + List = list; + Body = body; + Else = @else; + } + + public int HeaderIndex { private get; set; } + + public Expression Left { get; } + + public Expression List { get; set; } + + public Statement Body { get; set; } + + public Statement? Else { get; } + + MSAst.LabelTarget ILoopStatement.BreakLabel { get; set; } = null!; + + MSAst.LabelTarget ILoopStatement.ContinueLabel { get; set; } = null!; + + /// + /// Build the desugared tree. Called during Walk when Parent and IndexSpan are available. + /// + private Statement BuildDesugared() { + var parent = Parent; + var span = IndexSpan; + var id = Interlocked.Increment(ref _counter); + + // async for TARGET in ITER: + // BLOCK + // else: + // ELSE_BLOCK + // + // desugars to: + // + // __aiter = ITER.__aiter__() + // __running = True + // while __running: + // try: + // TARGET = await __aiter.__anext__() + // except StopAsyncIteration: + // __running = False + // else: + // BLOCK + // else: + // ELSE_BLOCK + + var iterName = $"__asyncfor_iter{id}"; + var runningName = $"__asyncfor_running{id}"; + + // Helper to create nodes with proper parent and span + NameExpression MakeName(string name) { + var n = new NameExpression(name) { Parent = parent }; + n.IndexSpan = span; + return n; + } + + T WithSpan(T node) where T : Node { + node.IndexSpan = span; + return node; + } + + // _iter = ITER.__aiter__() + var aiterAttr = WithSpan(new MemberExpression(List, "__aiter__") { Parent = parent }); + var aiterCall = WithSpan(new CallExpression(aiterAttr, null, null) { Parent = parent }); + var assignIter = WithSpan(new AssignmentStatement(new Expression[] { MakeName(iterName) }, aiterCall) { Parent = parent }); + + // running = True + var trueConst = new ConstantExpression(true) { Parent = parent }; trueConst.IndexSpan = span; + var assignRunning = WithSpan(new AssignmentStatement(new Expression[] { MakeName(runningName) }, trueConst) { Parent = parent }); + + // TARGET = await __aiter.__anext__() + var anextAttr = WithSpan(new MemberExpression(MakeName(iterName), "__anext__") { Parent = parent }); + var anextCall = WithSpan(new CallExpression(anextAttr, null, null) { Parent = parent }); + var awaitNext = new AwaitExpression(anextCall); + var assignTarget = WithSpan(new AssignmentStatement(new Expression[] { Left }, awaitNext) { Parent = parent }); + + // except StopAsyncIteration: __running = False + var falseConst = new ConstantExpression(false) { Parent = parent }; falseConst.IndexSpan = span; + var stopRunning = WithSpan(new AssignmentStatement( + new Expression[] { MakeName(runningName) }, falseConst) { Parent = parent }); + var handler = WithSpan(new TryStatementHandler( + MakeName("StopAsyncIteration"), + null!, + WithSpan(new SuiteStatement(new Statement[] { stopRunning }) { Parent = parent }) + ) { Parent = parent }); + handler.HeaderIndex = span.End; + + // try/except/else block + var tryExcept = WithSpan(new TryStatement( + assignTarget, + new[] { handler }, + WithSpan(new SuiteStatement(new Statement[] { Body }) { Parent = parent }), + null! + ) { Parent = parent }); + tryExcept.HeaderIndex = span.End; + + // while __running: try/except/else + var whileStmt = new WhileStatement(MakeName(runningName), tryExcept, Else); + whileStmt.SetLoc(GlobalParent, span.Start, span.End, span.End); + whileStmt.Parent = parent; + + var suite = WithSpan(new SuiteStatement(new Statement[] { assignIter, assignRunning, whileStmt }) { Parent = parent }); + return suite; + } + + public override MSAst.Expression Reduce() { + return _desugared!.Reduce(); + } + + public override void Walk(PythonWalker walker) { + if (walker.Walk(this)) { + // Build the desugared tree on first walk (when Parent and IndexSpan are set) + if (_desugared == null) { + _desugared = BuildDesugared(); + } + _desugared.Walk(walker); + } + walker.PostWalk(this); + } + + internal override bool CanThrow => true; + } +} diff --git a/src/core/IronPython/Compiler/Ast/AsyncWithStatement.cs b/src/core/IronPython/Compiler/Ast/AsyncWithStatement.cs new file mode 100644 index 000000000..804926de8 --- /dev/null +++ b/src/core/IronPython/Compiler/Ast/AsyncWithStatement.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using Microsoft.Scripting; +using MSAst = System.Linq.Expressions; + +using AstUtils = Microsoft.Scripting.Ast.Utils; + +namespace IronPython.Compiler.Ast { + using Ast = MSAst.Expression; + + /// + /// Represents an async with statement. + /// Desugared to Python AST that uses await on __aenter__ and __aexit__. + /// + public class AsyncWithStatement : Statement { + private Statement? _desugared; + + public AsyncWithStatement(Expression contextManager, Expression? var, Statement body) { + ContextManager = contextManager; + Variable = var; + Body = body; + } + + public int HeaderIndex { private get; set; } + + public Expression ContextManager { get; } + + public new Expression? Variable { get; } + + public Statement Body { get; } + + /// + /// Build the desugared tree. Called during Walk when Parent and IndexSpan are available. + /// + private Statement BuildDesugared() { + var parent = Parent; + var span = IndexSpan; + + // async with EXPR as VAR: + // BLOCK + // + // desugars to: + // + // mgr = EXPR + // try: + // VAR = await mgr.__aenter__() (or just await mgr.__aenter__()) + // BLOCK + // finally: + // await mgr.__aexit__(None, None, None) + + // Helper to create nodes with proper parent and span + NameExpression MakeName(string name) { + var n = new NameExpression(name) { Parent = parent }; + n.IndexSpan = span; + return n; + } + + // mgr = EXPR + var assignMgr = new AssignmentStatement(new Expression[] { MakeName("__asyncwith_mgr") }, ContextManager) { Parent = parent }; + assignMgr.IndexSpan = span; + + // await mgr.__aenter__() + var aenterAttr = new MemberExpression(MakeName("__asyncwith_mgr"), "__aenter__") { Parent = parent }; + aenterAttr.IndexSpan = span; + var aenterCall = new CallExpression(aenterAttr, null, null) { Parent = parent }; + aenterCall.IndexSpan = span; + var awaitEnter = new AwaitExpression(aenterCall); + + Statement bodyStmt; + if (Variable != null) { + // VAR = await value; BLOCK + var assignVar = new AssignmentStatement(new Expression[] { Variable }, awaitEnter) { Parent = parent }; + assignVar.IndexSpan = span; + bodyStmt = new SuiteStatement(new Statement[] { assignVar, Body }) { Parent = parent }; + } else { + var exprStmt = new ExpressionStatement(awaitEnter) { Parent = parent }; + exprStmt.IndexSpan = span; + bodyStmt = new SuiteStatement(new Statement[] { exprStmt, Body }) { Parent = parent }; + } + + // await mgr.__aexit__(None, None, None) + var aexitAttr = new MemberExpression(MakeName("__asyncwith_mgr"), "__aexit__") { Parent = parent }; + aexitAttr.IndexSpan = span; + var none1 = new ConstantExpression(null) { Parent = parent }; none1.IndexSpan = span; + var none2 = new ConstantExpression(null) { Parent = parent }; none2.IndexSpan = span; + var none3 = new ConstantExpression(null) { Parent = parent }; none3.IndexSpan = span; + var aexitCallNormal = new CallExpression(aexitAttr, + new Expression[] { none1, none2, none3 }, null) { Parent = parent }; + aexitCallNormal.IndexSpan = span; + var awaitExitNormal = new AwaitExpression(aexitCallNormal); + + // try/finally: await __aexit__ on normal exit + var finallyExprStmt = new ExpressionStatement(awaitExitNormal) { Parent = parent }; + finallyExprStmt.IndexSpan = span; + var tryFinally = new TryStatement(bodyStmt, null, null, finallyExprStmt) { Parent = parent }; + tryFinally.IndexSpan = span; + tryFinally.HeaderIndex = span.End; + + var suite = new SuiteStatement(new Statement[] { assignMgr, tryFinally }) { Parent = parent }; + suite.IndexSpan = span; + return suite; + } + + public override MSAst.Expression Reduce() { + return _desugared!.Reduce(); + } + + public override void Walk(PythonWalker walker) { + if (walker.Walk(this)) { + // Build the desugared tree on first walk (when Parent and IndexSpan are set) + if (_desugared == null) { + _desugared = BuildDesugared(); + } + _desugared.Walk(walker); + } + walker.PostWalk(this); + } + } +} diff --git a/src/core/IronPython/Compiler/Ast/AwaitExpression.cs b/src/core/IronPython/Compiler/Ast/AwaitExpression.cs new file mode 100644 index 000000000..94b960aae --- /dev/null +++ b/src/core/IronPython/Compiler/Ast/AwaitExpression.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using MSAst = System.Linq.Expressions; + +using AstUtils = Microsoft.Scripting.Ast.Utils; + +namespace IronPython.Compiler.Ast { + using Ast = MSAst.Expression; + + /// + /// Represents an await expression. Implemented as yield from expr.__await__(). + /// + public class AwaitExpression : Expression { + private readonly Statement _statement; + private readonly NameExpression _result; + + public AwaitExpression(Expression expression) { + Expression = expression; + + // await expr is equivalent to yield from expr.__await__() + // We build: __awaitprefix_EXPR = expr; yield from __awaitprefix_EXPR.__await__(); __awaitprefix_r = __yieldfromprefix_r + var parent = expression.Parent; + + var awaitableExpr = new NameExpression("__awaitprefix_EXPR") { Parent = parent }; + var getAwait = new MemberExpression(awaitableExpr, "__await__") { Parent = parent }; + var callAwait = new CallExpression(getAwait, null, null) { Parent = parent }; + var yieldFrom = new YieldFromExpression(callAwait); + + Statement s1 = new AssignmentStatement(new Expression[] { new NameExpression("__awaitprefix_EXPR") { Parent = parent } }, expression) { Parent = parent }; + Statement s2 = new ExpressionStatement(yieldFrom) { Parent = parent }; + Statement s3 = new AssignmentStatement( + new Expression[] { new NameExpression("__awaitprefix_r") { Parent = parent } }, + new NameExpression("__yieldfromprefix_r") { Parent = parent } + ) { Parent = parent }; + + _statement = new SuiteStatement(new Statement[] { s1, s2, s3 }) { Parent = parent }; + + _result = new NameExpression("__awaitprefix_r") { Parent = parent }; + } + + public Expression Expression { get; } + + public override MSAst.Expression Reduce() { + return Ast.Block( + typeof(object), + _statement, + AstUtils.Convert(_result, typeof(object)) + ).Reduce(); + } + + public override void Walk(PythonWalker walker) { + if (walker.Walk(this)) { + Expression?.Walk(walker); + _statement.Walk(walker); + _result.Walk(walker); + } + walker.PostWalk(this); + } + + public override string NodeName => "await expression"; + } +} diff --git a/src/core/IronPython/Compiler/Ast/FunctionDefinition.cs b/src/core/IronPython/Compiler/Ast/FunctionDefinition.cs index bb5d18d35..900ecb86f 100644 --- a/src/core/IronPython/Compiler/Ast/FunctionDefinition.cs +++ b/src/core/IronPython/Compiler/Ast/FunctionDefinition.cs @@ -117,7 +117,7 @@ internal override int KwOnlyArgCount { public Expression ReturnAnnotation { get; internal set; } - internal override bool IsGeneratorMethod => IsGenerator; + internal override bool IsGeneratorMethod => IsGenerator || IsAsync; /// /// The function is a generator @@ -182,10 +182,14 @@ internal override FunctionAttributes Flags { fa |= FunctionAttributes.ContainsTryFinally; } - if (IsGenerator) { + if (IsGenerator || IsAsync) { fa |= FunctionAttributes.Generator; } + if (IsAsync) { + fa |= FunctionAttributes.Coroutine; + } + if (GeneratorStop) { fa |= FunctionAttributes.GeneratorStop; } @@ -353,8 +357,8 @@ internal MSAst.Expression MakeFunctionExpression() { annotations ) ), - IsGenerator ? - (MSAst.Expression)new PythonGeneratorExpression(code, GlobalParent.PyContext.Options.CompilationThreshold) : + (IsGenerator || IsAsync) ? + (MSAst.Expression)new PythonGeneratorExpression(code, GlobalParent.PyContext.Options.CompilationThreshold, IsAsync) : (MSAst.Expression)code ); } else { @@ -652,10 +656,10 @@ private LightLambdaExpression CreateFunctionLambda() { new SourceSpan(new SourceLocation(0, start.Line, start.Column), new SourceLocation(0, start.Line, int.MaxValue)))); - // For generators, we need to do a check before the first statement for Generator.Throw() / Generator.Close(). + // For generators/coroutines, we need to do a check before the first statement for Generator.Throw() / Generator.Close(). // The exception traceback needs to come from the generator's method body, and so we must do the check and throw // from inside the generator. - if (IsGenerator) { + if (IsGenerator || IsAsync) { MSAst.Expression s1 = YieldExpression.CreateCheckThrowExpression(SourceSpan.None); statements.Add(s1); } diff --git a/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs b/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs index 68e5eb2c4..2110c69db 100644 --- a/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs +++ b/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs @@ -346,11 +346,26 @@ public override bool Walk(AssertStatement node) { node.Parent = _currentScope; return base.Walk(node); } + // AsyncForStatement + public override bool Walk(AsyncForStatement node) { + node.Parent = _currentScope; + return base.Walk(node); + } // AsyncStatement public override bool Walk(AsyncStatement node) { node.Parent = _currentScope; return base.Walk(node); } + // AsyncWithStatement + public override bool Walk(AsyncWithStatement node) { + node.Parent = _currentScope; + return base.Walk(node); + } + // AwaitExpression + public override bool Walk(AwaitExpression node) { + node.Parent = _currentScope; + return base.Walk(node); + } // BinaryExpression public override bool Walk(BinaryExpression node) { node.Parent = _currentScope; diff --git a/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs b/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs index 936c34425..c97de6a5c 100644 --- a/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs +++ b/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs @@ -20,6 +20,10 @@ public class PythonWalker { public virtual bool Walk(AndExpression node) { return true; } public virtual void PostWalk(AndExpression node) { } + // AwaitExpression + public virtual bool Walk(AwaitExpression node) { return true; } + public virtual void PostWalk(AwaitExpression node) { } + // BinaryExpression public virtual bool Walk(BinaryExpression node) { return true; } public virtual void PostWalk(BinaryExpression node) { } @@ -136,10 +140,18 @@ public virtual void PostWalk(AssertStatement node) { } public virtual bool Walk(AssignmentStatement node) { return true; } public virtual void PostWalk(AssignmentStatement node) { } + // AsyncForStatement + public virtual bool Walk(AsyncForStatement node) { return true; } + public virtual void PostWalk(AsyncForStatement node) { } + // AsyncStatement public virtual bool Walk(AsyncStatement node) { return true; } public virtual void PostWalk(AsyncStatement node) { } + // AsyncWithStatement + public virtual bool Walk(AsyncWithStatement node) { return true; } + public virtual void PostWalk(AsyncWithStatement node) { } + // AugmentedAssignStatement public virtual bool Walk(AugmentedAssignStatement node) { return true; } public virtual void PostWalk(AugmentedAssignStatement node) { } @@ -279,6 +291,10 @@ public class PythonWalkerNonRecursive : PythonWalker { public override bool Walk(AndExpression node) { return false; } public override void PostWalk(AndExpression node) { } + // AwaitExpression + public override bool Walk(AwaitExpression node) { return false; } + public override void PostWalk(AwaitExpression node) { } + // BinaryExpression public override bool Walk(BinaryExpression node) { return false; } public override void PostWalk(BinaryExpression node) { } @@ -395,10 +411,18 @@ public override void PostWalk(AssertStatement node) { } public override bool Walk(AssignmentStatement node) { return false; } public override void PostWalk(AssignmentStatement node) { } + // AsyncForStatement + public override bool Walk(AsyncForStatement node) { return false; } + public override void PostWalk(AsyncForStatement node) { } + // AsyncStatement public override bool Walk(AsyncStatement node) { return false; } public override void PostWalk(AsyncStatement node) { } + // AsyncWithStatement + public override bool Walk(AsyncWithStatement node) { return false; } + public override void PostWalk(AsyncWithStatement node) { } + // AugmentedAssignStatement public override bool Walk(AugmentedAssignStatement node) { return false; } public override void PostWalk(AugmentedAssignStatement node) { } diff --git a/src/core/IronPython/Compiler/GeneratorRewriter.cs b/src/core/IronPython/Compiler/GeneratorRewriter.cs index ee48035eb..69784249c 100644 --- a/src/core/IronPython/Compiler/GeneratorRewriter.cs +++ b/src/core/IronPython/Compiler/GeneratorRewriter.cs @@ -57,9 +57,12 @@ internal sealed class GeneratorRewriter : DynamicExpressionVisitor { internal const int Finished = 0; internal static ParameterExpression _generatorParam = Expression.Parameter(typeof(PythonGenerator), "$generator"); - internal GeneratorRewriter(string name, Expression body) { + private readonly bool _isCoroutine; + + internal GeneratorRewriter(string name, Expression body, bool isCoroutine = false) { _body = body; _name = name; + _isCoroutine = isCoroutine; _returnLabels.Push(Expression.Label("retLabel")); _gotoRouter = Expression.Variable(typeof(int), "$gotoRouter"); } @@ -133,27 +136,47 @@ internal Expression Reduce(bool shouldInterpret, bool emitDebugSymbols, int comp new ParameterExpression[] { tupleArg } ); - // Generate a call to PythonOps.MakeGeneratorClosure(Tuple data, object generatorCode) + // Generate a call to PythonOps.MakeGenerator(Tuple data, object generatorCode) + // For coroutines, we wrap the result in PythonOps.MakeCoroutineWrapper after creating the generator + Expression generatorExpr = Expression.Call( + typeof(PythonOps).GetMethod(nameof(PythonOps.MakeGenerator)), + parameters[0], + Expression.Assign(tupleTmp, newTuple), + emitDebugSymbols ? + (Expression)bodyConverter(innerLambda) : + (Expression)Expression.Constant( + new LazyCode>( + bodyConverter(innerLambda), + shouldInterpret, + compilationThreshold + ), + typeof(object) + ) + ); + + if (_isCoroutine) { + ParameterExpression coroutineRet = Expression.Parameter(typeof(object), "coroutineRet"); + return Expression.Block( + new[] { tupleTmp, ret, coroutineRet }, + Expression.Assign(ret, generatorExpr), + new DelayedTupleAssign( + new DelayedTupleExpression(liftedGen.Index, new StrongBox(tupleTmp), _tupleType, _tupleSize, typeof(PythonGenerator)), + ret + ), + Expression.Assign( + coroutineRet, + Expression.Call( + typeof(PythonOps).GetMethod(nameof(PythonOps.MakeCoroutineWrapper)), + ret + ) + ), + coroutineRet + ); + } + return Expression.Block( new[] { tupleTmp, ret }, - Expression.Assign( - ret, - Expression.Call( - typeof(PythonOps).GetMethod(nameof(PythonOps.MakeGenerator)), - parameters[0], - Expression.Assign(tupleTmp, newTuple), - emitDebugSymbols ? - (Expression)bodyConverter(innerLambda) : - (Expression)Expression.Constant( - new LazyCode>( - bodyConverter(innerLambda), - shouldInterpret, - compilationThreshold - ), - typeof(object) - ) - ) - ), + Expression.Assign(ret, generatorExpr), new DelayedTupleAssign( new DelayedTupleExpression(liftedGen.Index, new StrongBox(tupleTmp), _tupleType, _tupleSize, typeof(PythonGenerator)), ret @@ -589,11 +612,15 @@ protected override Expression VisitExtension(Expression node) { return VisitYield(yield); } - if (node is FinallyFlowControlExpression ffc) { - return Visit(node.ReduceExtensions()); + // Reduce one level and re-visit so that extension nodes produced + // during reduction (e.g. YieldExpression from ReturnStatement + // inside DebugInfoRemovalExpression) are properly intercepted + // by this visitor instead of being reduced again by ReduceExtensions(). + var reduced = node.Reduce(); + if (reduced == node) { + throw new InvalidOperationException("node must be reducible"); } - - return Visit(node.ReduceExtensions()); + return Visit(reduced); } private Expression VisitYield(YieldExpression node) { @@ -1065,14 +1092,16 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) { internal sealed class PythonGeneratorExpression : Expression { private readonly LightLambdaExpression _lambda; private readonly int _compilationThreshold; + private readonly bool _isCoroutine; - public PythonGeneratorExpression(LightLambdaExpression lambda, int compilationThreshold) { + public PythonGeneratorExpression(LightLambdaExpression lambda, int compilationThreshold, bool isCoroutine = false) { _lambda = lambda; _compilationThreshold = compilationThreshold; + _isCoroutine = isCoroutine; } public override Expression Reduce() { - return _lambda.ToGenerator(false, true, _compilationThreshold); + return _lambda.ToGenerator(false, true, _compilationThreshold, _isCoroutine); } public sealed override ExpressionType NodeType { diff --git a/src/core/IronPython/Compiler/Parser.cs b/src/core/IronPython/Compiler/Parser.cs index 519e4b8da..8853fe6f9 100644 --- a/src/core/IronPython/Compiler/Parser.cs +++ b/src/core/IronPython/Compiler/Parser.cs @@ -1457,15 +1457,63 @@ private WithItem ParseWithItem() { // async_stmt: 'async' (funcdef | with_stmt | for_stmt) private Statement ParseAsyncStmt() { + var start = GetStart(); Eat(TokenKind.KeywordAsync); - ReportSyntaxError("invalid syntax"); - if (PeekToken().Kind == TokenKind.KeywordDef) { - FunctionDefinition def = ParseFuncDef(true); - return def; + switch (PeekToken().Kind) { + case TokenKind.KeywordDef: + return ParseFuncDef(true); + case TokenKind.KeywordWith: + return ParseAsyncWithStmt(start); + case TokenKind.KeywordFor: + return ParseAsyncForStmt(start); + default: + ReportSyntaxError("invalid syntax"); + return null; } + } - return null; + private AsyncWithStatement ParseAsyncWithStmt(int asyncStart) { + FunctionDefinition current = CurrentFunction; + if (current == null || !current.IsAsync) { + ReportSyntaxError("'async with' outside async function"); + } + + Eat(TokenKind.KeywordWith); + var withItem = ParseWithItem(); + var header = GetEnd(); + Statement body = ParseSuite(); + AsyncWithStatement ret = new AsyncWithStatement(withItem.ContextManager, withItem.Variable, body); + ret.HeaderIndex = header; + ret.SetLoc(_globalParent, asyncStart, GetEnd()); + return ret; + } + + private AsyncForStatement ParseAsyncForStmt(int asyncStart) { + FunctionDefinition current = CurrentFunction; + if (current == null || !current.IsAsync) { + ReportSyntaxError("'async for' outside async function"); + } + + Eat(TokenKind.KeywordFor); + var start = GetStart(); + + bool trailingComma; + List l = ParseExprList(out trailingComma); + + Expression lhs = MakeTupleOrExpr(l, trailingComma); + Eat(TokenKind.KeywordIn); + Expression list = ParseTestList(); + var header = GetEnd(); + Statement body = ParseLoopSuite(); + Statement else_ = null; + if (MaybeEat(TokenKind.KeywordElse)) { + else_ = ParseSuite(); + } + AsyncForStatement ret = new AsyncForStatement(lhs, list, body, else_); + ret.HeaderIndex = header; + ret.SetLoc(_globalParent, asyncStart, GetEnd()); + return ret; } // for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite] @@ -1910,8 +1958,11 @@ private Expression FinishUnaryNegate() { return new UnaryExpression(PythonOperator.Negate, ParseFactor()); } - // power: atom trailer* ['**' factor] + // power: ['await'] atom trailer* ['**' factor] private Expression ParsePower() { + if (MaybeEat(TokenKind.KeywordAwait)) { + return ParseAwaitExpression(); + } Expression ret = ParseAtom(); ret = AddTrailers(ret); if (MaybeEat(TokenKind.Power)) { @@ -1922,6 +1973,28 @@ private Expression ParsePower() { return ret; } + // await_expr: 'await' unary_expr (essentially power level) + private Expression ParseAwaitExpression() { + FunctionDefinition current = CurrentFunction; + if (current == null || !current.IsAsync) { + ReportSyntaxError("'await' outside async function"); + } + + if (current != null) { + current.IsGenerator = true; + current.GeneratorStop = GeneratorStop; + } + + var start = GetStart(); + + // Parse the awaitable expression at the unary level + Expression expr = ParsePower(); + + var ret = new AwaitExpression(expr); + ret.SetLoc(_globalParent, start, GetEnd()); + return ret; + } + //atom: ('(' [yield_expr|testlist_comp] ')' | // '[' [testlist_comp] ']' | // '{' [dictorsetmaker] '}' | diff --git a/src/core/IronPython/Compiler/TokenKind.Generated.cs b/src/core/IronPython/Compiler/TokenKind.Generated.cs index ee622564d..25af2cb76 100644 --- a/src/core/IronPython/Compiler/TokenKind.Generated.cs +++ b/src/core/IronPython/Compiler/TokenKind.Generated.cs @@ -73,7 +73,7 @@ public enum TokenKind { ReturnAnnotation = 76, FirstKeyword = KeywordAnd, - LastKeyword = KeywordNonlocal, + LastKeyword = KeywordAwait, KeywordAnd = 77, KeywordAssert = 78, KeywordBreak = 79, @@ -105,6 +105,7 @@ public enum TokenKind { KeywordWith = 105, KeywordAsync = 106, KeywordNonlocal = 107, + KeywordAwait = 108, // *** END GENERATED CODE *** @@ -193,6 +194,7 @@ public static class Tokens { public static Token KeywordAsToken { get; } = new SymbolToken(TokenKind.KeywordAs, "as"); public static Token KeywordAssertToken { get; } = new SymbolToken(TokenKind.KeywordAssert, "assert"); public static Token KeywordAsyncToken { get; } = new SymbolToken(TokenKind.KeywordAsync, "async"); + public static Token KeywordAwaitToken { get; } = new SymbolToken(TokenKind.KeywordAwait, "await"); public static Token KeywordBreakToken { get; } = new SymbolToken(TokenKind.KeywordBreak, "break"); public static Token KeywordClassToken { get; } = new SymbolToken(TokenKind.KeywordClass, "class"); public static Token KeywordContinueToken { get; } = new SymbolToken(TokenKind.KeywordContinue, "continue"); diff --git a/src/core/IronPython/Compiler/Tokenizer.cs b/src/core/IronPython/Compiler/Tokenizer.cs index a5c1b5a05..36a586a68 100644 --- a/src/core/IronPython/Compiler/Tokenizer.cs +++ b/src/core/IronPython/Compiler/Tokenizer.cs @@ -1059,6 +1059,11 @@ private Token ReadName() { return Tokens.KeywordAsyncToken; } } + } else if (ch == 'w') { + if (NextChar() == 'a' && NextChar() == 'i' && NextChar() == 't' && !IsNamePart(Peek())) { + MarkTokenEnd(); + return Tokens.KeywordAwaitToken; + } } } else if (ch == 'b') { if (NextChar() == 'r' && NextChar() == 'e' && NextChar() == 'a' && NextChar() == 'k' && !IsNamePart(Peek())) { diff --git a/src/core/IronPython/Modules/Builtin.Generated.cs b/src/core/IronPython/Modules/Builtin.Generated.cs index 8a8fda044..6c3c359b1 100644 --- a/src/core/IronPython/Modules/Builtin.Generated.cs +++ b/src/core/IronPython/Modules/Builtin.Generated.cs @@ -20,6 +20,7 @@ public static partial class Builtin { public static PythonType Exception => PythonExceptions.Exception; public static PythonType StopIteration => PythonExceptions.StopIteration; public static PythonType StopAsyncIteration => PythonExceptions.StopAsyncIteration; + public static PythonType CancelledError => PythonExceptions.CancelledError; public static PythonType ArithmeticError => PythonExceptions.ArithmeticError; public static PythonType FloatingPointError => PythonExceptions.FloatingPointError; public static PythonType OverflowError => PythonExceptions.OverflowError; diff --git a/src/core/IronPython/Modules/_ast.cs b/src/core/IronPython/Modules/_ast.cs index 32edce63a..294bf144d 100755 --- a/src/core/IronPython/Modules/_ast.cs +++ b/src/core/IronPython/Modules/_ast.cs @@ -231,6 +231,8 @@ internal static stmt Convert(Statement stmt) { GlobalStatement s => new Global(s), NonlocalStatement s => new Nonlocal(s), ClassDefinition s => new ClassDef(s), + AsyncForStatement s => new AsyncFor(s), + AsyncWithStatement s => new AsyncWith(s), BreakStatement _ => new Break(), ContinueStatement _ => new Continue(), EmptyStatement _ => new Pass(), @@ -295,6 +297,7 @@ internal static expr Convert(AstExpression expr, expr_context ctx) { MemberExpression x => new Attribute(x, ctx), YieldExpression x => new Yield(x), YieldFromExpression x => new YieldFrom(x), + AwaitExpression x => new Await(x), ConditionalExpression x => new IfExp(x), IndexExpression x => new Subscript(x, ctx), SetExpression x => new Set(x), @@ -3036,5 +3039,107 @@ internal override AstExpression Revert() { public expr value { get; set; } } + + [PythonType] + public class Await : expr { + public Await() { + _fields = PythonTuple.MakeTuple(new[] { nameof(value), }); + } + + public Await([Optional] expr value, [Optional] int? lineno, [Optional] int? col_offset) + : this() { + this.value = value; + _lineno = lineno; + _col_offset = col_offset; + } + + internal Await(AwaitExpression expr) + : this() { + value = Convert(expr.Expression); + } + + internal override AstExpression Revert() { + _containsYield = true; + return new AwaitExpression(expr.Revert(value)); + } + + public expr value { get; set; } + } + + [PythonType] + public class AsyncFor : stmt { + public AsyncFor() { + _fields = PythonTuple.MakeTuple(new[] { nameof(target), nameof(iter), nameof(body), nameof(orelse) }); + } + + public AsyncFor(expr target, expr iter, PythonList body, [Optional] PythonList orelse, + [Optional] int? lineno, [Optional] int? col_offset) + : this() { + this.target = target; + this.iter = iter; + this.body = body; + if (null == orelse) + this.orelse = new PythonList(); + else + this.orelse = orelse; + _lineno = lineno; + _col_offset = col_offset; + } + + internal AsyncFor(AsyncForStatement stmt) + : this() { + target = Convert(stmt.Left, Store.Instance); + iter = Convert(stmt.List); + body = ConvertStatements(stmt.Body); + orelse = ConvertStatements(stmt.Else, true); + } + + internal override Statement Revert() { + return new AsyncForStatement(expr.Revert(target), expr.Revert(iter), RevertStmts(body), RevertStmts(orelse)); + } + + public expr target { get; set; } + + public expr iter { get; set; } + + public PythonList body { get; set; } + + public PythonList orelse { get; set; } + } + + [PythonType] + public class AsyncWith : stmt { + public AsyncWith() { + _fields = PythonTuple.MakeTuple(new[] { nameof(items), nameof(body) }); + } + + public AsyncWith(PythonList items, PythonList body, + [Optional] int? lineno, [Optional] int? col_offset) + : this() { + this.items = items; + this.body = body; + _lineno = lineno; + _col_offset = col_offset; + } + + internal AsyncWith(AsyncWithStatement with) + : this() { + items = new PythonList(1); + items.AddNoLock(new withitem(Convert(with.ContextManager), with.Variable == null ? null : Convert(with.Variable, Store.Instance))); + body = ConvertStatements(with.Body); + } + + internal override Statement Revert() { + Statement statement = RevertStmts(this.body); + foreach (withitem item in items) { + statement = new AsyncWithStatement(expr.Revert(item.context_expr), expr.Revert(item.optional_vars), statement); + } + return statement; + } + + public PythonList items { get; set; } + + public PythonList body { get; set; } + } } } diff --git a/src/core/IronPython/Runtime/AsyncEnumerableWrapper.cs b/src/core/IronPython/Runtime/AsyncEnumerableWrapper.cs new file mode 100644 index 000000000..ba585bbcd --- /dev/null +++ b/src/core/IronPython/Runtime/AsyncEnumerableWrapper.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +// IAsyncEnumerable / IAsyncEnumerator require .NET Core 3.0+ +#if NET + +#nullable enable + +using System.Collections.Generic; +using System.Threading.Tasks; + +using Microsoft.Scripting.Runtime; + +using IronPython.Runtime.Exceptions; +using IronPython.Runtime.Operations; +using IronPython.Runtime.Types; + +namespace IronPython.Runtime { + /// + /// Wraps to implement the Python + /// async iterator protocol (__aiter__, __anext__). + /// Returned by . + /// + [PythonType("async_enumerator_wrapper")] + public sealed class AsyncEnumeratorWrapper { + private readonly IAsyncEnumerator _enumerator; + + internal AsyncEnumeratorWrapper(IAsyncEnumerator enumerator) { + _enumerator = enumerator; + } + + public AsyncEnumeratorWrapper __aiter__() => this; + + /// + /// Returns an awaitable that, when awaited, advances the async enumerator + /// and returns the next value or raises StopAsyncIteration. + /// + public object __anext__() { + return new AsyncEnumeratorAwaitable(_enumerator); + } + } + + /// + /// The awaitable object returned by . + /// Implements both __await__ and __iter__/__next__ (the yield-from protocol). + /// Non-blocking: yields the Task back to the runner if MoveNextAsync is not yet completed. + /// + [PythonType("async_enumerator_awaitable")] + public sealed class AsyncEnumeratorAwaitable { + private readonly IAsyncEnumerator _enumerator; + private Task? _moveNextTask; + + internal AsyncEnumeratorAwaitable(IAsyncEnumerator enumerator) { + _enumerator = enumerator; + } + + public AsyncEnumeratorAwaitable __await__() => this; + + public AsyncEnumeratorAwaitable __iter__() => this; + + [LightThrowing] + public object __next__() { + var task = _moveNextTask ??= _enumerator.MoveNextAsync().AsTask(); + if (!task.IsCompleted) return (Task)task; // yield Task to runner + bool hasNext = task.GetAwaiter().GetResult(); + _moveNextTask = null; // reset for next call + if (!hasNext) { + return LightExceptions.Throw( + new PythonExceptions._StopAsyncIteration().InitAndGetClrException()); + } + return LightExceptions.Throw( + new PythonExceptions._StopIteration().InitAndGetClrException(_enumerator.Current!)); + } + } +} + +#endif diff --git a/src/core/IronPython/Runtime/Coroutine.cs b/src/core/IronPython/Runtime/Coroutine.cs new file mode 100644 index 000000000..e7cfdf575 --- /dev/null +++ b/src/core/IronPython/Runtime/Coroutine.cs @@ -0,0 +1,169 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +using Microsoft.Scripting.Runtime; + +using IronPython.Runtime.Exceptions; +using IronPython.Runtime.Operations; +using IronPython.Runtime.Types; + +namespace IronPython.Runtime { + [PythonType("coroutine")] + [DontMapIDisposableToContextManager, DontMapIEnumerableToContains] + public sealed class PythonCoroutine : ICodeFormattable, IWeakReferenceable { + private readonly PythonGenerator _generator; + private WeakRefTracker? _tracker; + + internal PythonCoroutine(PythonGenerator generator) { + _generator = generator; + } + + [LightThrowing] + public object send(object? value) { + return _generator.send(value); + } + + [LightThrowing] + public object @throw(object? type) { + return _generator.@throw(type); + } + + [LightThrowing] + public object @throw(object? type, object? value) { + return _generator.@throw(type, value); + } + + [LightThrowing] + public object @throw(object? type, object? value, object? traceback) { + return _generator.@throw(type, value, traceback); + } + + [LightThrowing] + public object? close() { + return _generator.close(); + } + + public object __await__() { + return new CoroutineWrapper(this); + } + + public FunctionCode cr_code => _generator.gi_code; + + public int cr_running => _generator.gi_running; + + public TraceBackFrame cr_frame => _generator.gi_frame; + + public string __name__ => _generator.__name__; + + public string __qualname__ { + get => _generator.__name__; + } + + /// + /// Converts this coroutine into a .NET , + /// allowing C# code to await an IronPython async method. + /// The coroutine is driven on a single thread to avoid issues with + /// thread-local state in the Python generator runtime. + /// + public Task AsTask() { + return Task.Run(() => { + while (true) { + object result = send(null); + + if (LightExceptions.IsLightException(result)) { + var clrExc = LightExceptions.GetLightException(result); + if (clrExc is StopIterationException) { + var pyExc = ((IPythonAwareException)clrExc).PythonException; + return pyExc is PythonExceptions._StopIteration si ? si.value : null; + } + throw clrExc; + } + + if (result is Task task) { + task.Wait(); + } + } + }); + } + + /// + /// Enables await coroutine from C# code. + /// + public TaskAwaiter GetAwaiter() => AsTask().GetAwaiter(); + + internal PythonGenerator Generator => _generator; + + #region ICodeFormattable Members + + public string __repr__(CodeContext context) { + return $""; + } + + #endregion + + #region IWeakReferenceable Members + + WeakRefTracker? IWeakReferenceable.GetWeakRef() { + return _tracker; + } + + bool IWeakReferenceable.SetWeakRef(WeakRefTracker value) { + _tracker = value; + return true; + } + + void IWeakReferenceable.SetFinalizer(WeakRefTracker value) { + _tracker = value; + } + + #endregion + } + + [PythonType("coroutine_wrapper")] + public sealed class CoroutineWrapper { + private readonly PythonCoroutine _coroutine; + + internal CoroutineWrapper(PythonCoroutine coroutine) { + _coroutine = coroutine; + } + + [LightThrowing] + public object __next__() { + return _coroutine.send(null); + } + + [LightThrowing] + public object send(object? value) { + return _coroutine.send(value); + } + + [LightThrowing] + public object @throw(object? type) { + return _coroutine.@throw(type); + } + + [LightThrowing] + public object @throw(object? type, object? value) { + return _coroutine.@throw(type, value); + } + + [LightThrowing] + public object @throw(object? type, object? value, object? traceback) { + return _coroutine.@throw(type, value, traceback); + } + + public object? close() { + return _coroutine.close(); + } + + public CoroutineWrapper __iter__() { + return this; + } + } +} diff --git a/src/core/IronPython/Runtime/Exceptions/PythonExceptions.Generated.cs b/src/core/IronPython/Runtime/Exceptions/PythonExceptions.Generated.cs index 86f04c43a..4bf06b1ab 100644 --- a/src/core/IronPython/Runtime/Exceptions/PythonExceptions.Generated.cs +++ b/src/core/IronPython/Runtime/Exceptions/PythonExceptions.Generated.cs @@ -114,6 +114,17 @@ public _StopAsyncIteration(PythonType type) : base(type) { } public object value { get; set; } } + [MultiRuntimeAware] + private static PythonType CancelledErrorStorage; + public static PythonType CancelledError { + get { + if (CancelledErrorStorage == null) { + Interlocked.CompareExchange(ref CancelledErrorStorage, CreateSubType(Exception, "CancelledError", (msg, innerException) => new OperationCanceledException(msg, innerException)), null); + } + return CancelledErrorStorage; + } + } + [MultiRuntimeAware] private static PythonType ArithmeticErrorStorage; public static PythonType ArithmeticError { @@ -912,6 +923,7 @@ public static PythonType ResourceWarning { if (clrException is ModuleNotFoundException) return new PythonExceptions._ImportError(PythonExceptions.ModuleNotFoundError); if (clrException is NotADirectoryException) return new PythonExceptions._OSError(PythonExceptions.NotADirectoryError); if (clrException is NotImplementedException) return new PythonExceptions.BaseException(PythonExceptions.NotImplementedError); + if (clrException is OperationCanceledException) return new PythonExceptions.BaseException(PythonExceptions.CancelledError); if (clrException is OutOfMemoryException) return new PythonExceptions.BaseException(PythonExceptions.MemoryError); if (clrException is ProcessLookupException) return new PythonExceptions._OSError(PythonExceptions.ProcessLookupError); if (clrException is RecursionException) return new PythonExceptions.BaseException(PythonExceptions.RecursionError); diff --git a/src/core/IronPython/Runtime/FunctionAttributes.cs b/src/core/IronPython/Runtime/FunctionAttributes.cs index 45edf190e..8e7f667e5 100644 --- a/src/core/IronPython/Runtime/FunctionAttributes.cs +++ b/src/core/IronPython/Runtime/FunctionAttributes.cs @@ -23,6 +23,10 @@ public enum FunctionAttributes { /// Generator = 0x20, /// + /// Set if the function is a coroutine (async def). + /// + Coroutine = 0x100, + /// /// IronPython specific: Set if the function includes nested exception handling and therefore can alter /// sys.exc_info(). /// diff --git a/src/core/IronPython/Runtime/FunctionCode.cs b/src/core/IronPython/Runtime/FunctionCode.cs index b0310185e..a790be7e9 100644 --- a/src/core/IronPython/Runtime/FunctionCode.cs +++ b/src/core/IronPython/Runtime/FunctionCode.cs @@ -750,15 +750,17 @@ private LambdaExpression GetGeneratorOrNormalLambdaTracing(PythonContext context debugProperties // custom payload ); - if ((Flags & FunctionAttributes.Generator) == 0) { + if ((Flags & (FunctionAttributes.Generator | FunctionAttributes.Coroutine)) == 0) { return context.DebugContext.TransformLambda((LambdaExpression)Compiler.Ast.Node.RemoveFrame(_lambda.GetLambda()), debugInfo); } + bool isCoroutine = (Flags & FunctionAttributes.Coroutine) != 0; return Expression.Lambda( Code.Type, new GeneratorRewriter( _lambda.Name, - Compiler.Ast.Node.RemoveFrame(Code.Body) + Compiler.Ast.Node.RemoveFrame(Code.Body), + isCoroutine ).Reduce( _lambda.ShouldInterpret, _lambda.EmitDebugSymbols, @@ -779,13 +781,15 @@ private LambdaExpression GetGeneratorOrNormalLambdaTracing(PythonContext context /// private LightLambdaExpression GetGeneratorOrNormalLambda() { LightLambdaExpression finalCode; - if ((Flags & FunctionAttributes.Generator) == 0) { + if ((Flags & (FunctionAttributes.Generator | FunctionAttributes.Coroutine)) == 0) { finalCode = Code; } else { + bool isCoroutine = (Flags & FunctionAttributes.Coroutine) != 0; finalCode = Code.ToGenerator( _lambda.ShouldInterpret, _lambda.EmitDebugSymbols, - _lambda.GlobalParent.PyContext.Options.CompilationThreshold + _lambda.GlobalParent.PyContext.Options.CompilationThreshold, + isCoroutine ); } return finalCode; diff --git a/src/core/IronPython/Runtime/Operations/InstanceOps.cs b/src/core/IronPython/Runtime/Operations/InstanceOps.cs index 1d6e9233e..5c7edca0b 100644 --- a/src/core/IronPython/Runtime/Operations/InstanceOps.cs +++ b/src/core/IronPython/Runtime/Operations/InstanceOps.cs @@ -205,6 +205,47 @@ public static object NextMethod(object self) { #endregion + #region Async Interop + + /// + /// Provides the implementation of __await__ for . + /// + public static object TaskAwaitMethod(System.Threading.Tasks.Task self) { + return new TaskAwaitable(self); + } + + /// + /// Provides the implementation of __await__ for . + /// + public static object TaskAwaitMethodGeneric(System.Threading.Tasks.Task self) { + return new TaskAwaitable(self); + } + +#if NET + /// + /// Provides the implementation of __await__ for . + /// + public static object ValueTaskAwaitMethod(System.Threading.Tasks.ValueTask self) { + return new ValueTaskAwaitable(self); + } + + /// + /// Provides the implementation of __await__ for . + /// + public static object ValueTaskAwaitMethodGeneric(System.Threading.Tasks.ValueTask self) { + return new ValueTaskAwaitable(self); + } + + /// + /// Provides the implementation of __aiter__ for . + /// + public static object AsyncIterMethod(System.Collections.Generic.IAsyncEnumerable self) { + return new AsyncEnumeratorWrapper(self.GetAsyncEnumerator()); + } +#endif + + #endregion + /// /// __dir__(self) -> Returns the list of members defined on a foreign IDynamicMetaObjectProvider. /// diff --git a/src/core/IronPython/Runtime/Operations/PythonOps.Generated.cs b/src/core/IronPython/Runtime/Operations/PythonOps.Generated.cs index 87b537945..6f84e6297 100644 --- a/src/core/IronPython/Runtime/Operations/PythonOps.Generated.cs +++ b/src/core/IronPython/Runtime/Operations/PythonOps.Generated.cs @@ -125,6 +125,9 @@ public static partial class PythonOps { internal static Exception ModuleNotFoundError(string message) => new ModuleNotFoundException(message); public static Exception ModuleNotFoundError(string format, params object?[] args) => new ModuleNotFoundException(string.Format(format, args)); + internal static Exception CancelledError(string message) => new OperationCanceledException(message); + public static Exception CancelledError(string format, params object?[] args) => new OperationCanceledException(string.Format(format, args)); + // *** END GENERATED CODE *** #endregion diff --git a/src/core/IronPython/Runtime/Operations/PythonOps.cs b/src/core/IronPython/Runtime/Operations/PythonOps.cs index 1c871d283..9a88b06e7 100644 --- a/src/core/IronPython/Runtime/Operations/PythonOps.cs +++ b/src/core/IronPython/Runtime/Operations/PythonOps.cs @@ -3196,6 +3196,14 @@ public static PythonGenerator MakeGenerator(PythonFunction function, MutableTupl return new PythonGenerator(function, next, data); } + public static PythonCoroutine MakeCoroutine(PythonFunction function, MutableTuple data, object generatorCode) { + return new PythonCoroutine(MakeGenerator(function, data, generatorCode)); + } + + public static object MakeCoroutineWrapper(PythonGenerator generator) { + return new PythonCoroutine(generator); + } + public static object MakeGeneratorExpression(object function, object input) { PythonFunction func = (PythonFunction)function; return ((Func)func.__code__.Target)(func, input); @@ -4267,12 +4275,12 @@ public static List PushFrame(CodeContext/*!*/ context, FunctionCo return stack; } - internal static LightLambdaExpression ToGenerator(this LightLambdaExpression code, bool shouldInterpret, bool debuggable, int compilationThreshold) { + internal static LightLambdaExpression ToGenerator(this LightLambdaExpression code, bool shouldInterpret, bool debuggable, int compilationThreshold, bool isCoroutine = false) { #pragma warning disable CA2263 // Prefer generic overload when type is known return Utils.LightLambda( typeof(object), code.Type, - new GeneratorRewriter(code.Name, code.Body).Reduce(shouldInterpret, debuggable, compilationThreshold, code.Parameters, x => x), + new GeneratorRewriter(code.Name, code.Body, isCoroutine).Reduce(shouldInterpret, debuggable, compilationThreshold, code.Parameters, x => x), code.Name, code.Parameters ); diff --git a/src/core/IronPython/Runtime/TaskAwaitable.cs b/src/core/IronPython/Runtime/TaskAwaitable.cs new file mode 100644 index 000000000..f255be2f2 --- /dev/null +++ b/src/core/IronPython/Runtime/TaskAwaitable.cs @@ -0,0 +1,122 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System.Threading.Tasks; + +using Microsoft.Scripting.Runtime; + +using IronPython.Runtime.Exceptions; +using IronPython.Runtime.Types; + +namespace IronPython.Runtime { + /// + /// Provides an __await__ protocol wrapper for , + /// enabling await task from Python async code. + /// Non-blocking: yields the Task back to the runner if not yet completed. + /// + [PythonType("task_awaitable")] + public sealed class TaskAwaitable { + private readonly Task _task; + + internal TaskAwaitable(Task task) { + _task = task; + } + + public TaskAwaitable __await__() => this; + + public TaskAwaitable __iter__() => this; + + [LightThrowing] + public object __next__() { + var task = _task; + if (!task.IsCompleted) return task; // yield Task to runner + task.GetAwaiter().GetResult(); // propagate exceptions + return LightExceptions.Throw(new PythonExceptions._StopIteration().InitAndGetClrException()); + } + } + + /// + /// Provides an __await__ protocol wrapper for , + /// enabling result = await task from Python async code. + /// Non-blocking: yields the Task back to the runner if not yet completed. + /// + [PythonType("task_awaitable")] + public sealed class TaskAwaitable { + private readonly Task _task; + + internal TaskAwaitable(Task task) { + _task = task; + } + + public TaskAwaitable __await__() => this; + + public TaskAwaitable __iter__() => this; + + [LightThrowing] + public object __next__() { + var task = _task; + if (!task.IsCompleted) return task; // yield Task to runner + T result = task.GetAwaiter().GetResult(); + return LightExceptions.Throw(new PythonExceptions._StopIteration().InitAndGetClrException(result!)); + } + } + +#if NET + /// + /// Provides an __await__ protocol wrapper for , + /// enabling await valuetask from Python async code. + /// Non-blocking: yields the Task back to the runner if not yet completed. + /// + [PythonType("task_awaitable")] + public sealed class ValueTaskAwaitable { + private readonly ValueTask _task; + private Task? _asTask; + + internal ValueTaskAwaitable(ValueTask task) { + _task = task; + } + + public ValueTaskAwaitable __await__() => this; + + public ValueTaskAwaitable __iter__() => this; + + [LightThrowing] + public object __next__() { + var task = _asTask ??= _task.AsTask(); + if (!task.IsCompleted) return task; // yield Task to runner + task.GetAwaiter().GetResult(); // propagate exceptions + return LightExceptions.Throw(new PythonExceptions._StopIteration().InitAndGetClrException()); + } + } + + /// + /// Provides an __await__ protocol wrapper for , + /// enabling result = await valuetask from Python async code. + /// Non-blocking: yields the Task back to the runner if not yet completed. + /// + [PythonType("task_awaitable")] + public sealed class ValueTaskAwaitable { + private readonly ValueTask _task; + private Task? _asTask; + + internal ValueTaskAwaitable(ValueTask task) { + _task = task; + } + + public ValueTaskAwaitable __await__() => this; + + public ValueTaskAwaitable __iter__() => this; + + [LightThrowing] + public object __next__() { + var task = _asTask ??= _task.AsTask(); + if (!task.IsCompleted) return (Task)task; // yield Task to runner + T result = task.GetAwaiter().GetResult(); + return LightExceptions.Throw(new PythonExceptions._StopIteration().InitAndGetClrException(result!)); + } + } +#endif +} diff --git a/src/core/IronPython/Runtime/Types/PythonTypeInfo.cs b/src/core/IronPython/Runtime/Types/PythonTypeInfo.cs index 87c537fd7..515e81402 100644 --- a/src/core/IronPython/Runtime/Types/PythonTypeInfo.cs +++ b/src/core/IronPython/Runtime/Types/PythonTypeInfo.cs @@ -10,6 +10,7 @@ using System.Linq; using System.Numerics; using System.Reflection; +using System.Threading.Tasks; using Microsoft.Scripting; using Microsoft.Scripting.Actions; @@ -672,6 +673,9 @@ private class ProtectedMemberResolver : MemberResolver { new OneOffResolver("__len__", LengthResolver), new OneOffResolver("__format__", FormatResolver), new OneOffResolver("__next__", NextResolver), + new OneOffResolver("__await__", AwaitResolver), + new OneOffResolver("__aiter__", AsyncIterResolver), + new OneOffResolver("__anext__", AsyncNextResolver), new OneOffResolver("__complex__", ComplexResolver), new OneOffResolver("__float__", FloatResolver), @@ -965,6 +969,89 @@ internal static MemberGroup GetExtensionMemberGroup(Type type, MemberInfo[] news return MemberGroup.EmptyGroup; } + /// + /// Provides a resolution for __await__ on Task, Task<T>, ValueTask and ValueTask<T>. + /// + private static MemberGroup/*!*/ AwaitResolver(MemberBinder/*!*/ binder, Type/*!*/ type) { + foreach (Type t in binder.GetContributingTypes(type)) { + if (t.GetMember("__await__").Length > 0) { + return MemberGroup.EmptyGroup; + } + } + + if (typeof(Task).IsAssignableFrom(type)) { + // Walk up the type hierarchy to find Task (the runtime type may be + // a subclass such as AsyncStateMachineBox). + Type taskType = type; + while (taskType != null && !(taskType.IsGenericType && taskType.GetGenericTypeDefinition() == typeof(Task<>))) { + taskType = taskType.BaseType; + } + if (taskType != null) { + // Only use the generic TaskAwaitable if the result type is visible + // (e.g. Task.CompletedTask is Task at runtime where + // VoidTaskResult is internal — fall back to non-generic TaskAwaitable) + Type resultType = taskType.GetGenericArguments()[0]; + if (resultType.IsVisible) { + MethodInfo genMeth = typeof(InstanceOps).GetMethod(nameof(InstanceOps.TaskAwaitMethodGeneric)); + return new MemberGroup( + MethodTracker.FromMemberInfo(genMeth.MakeGenericMethod(taskType.GetGenericArguments()), type) + ); + } + } + return GetInstanceOpsMethod(type, nameof(InstanceOps.TaskAwaitMethod)); + } + +#if NET + if (type.IsGenericType) { + Type genDef = type.GetGenericTypeDefinition(); + if (genDef == typeof(ValueTask<>)) { + MethodInfo genMeth = typeof(InstanceOps).GetMethod(nameof(InstanceOps.ValueTaskAwaitMethodGeneric)); + return new MemberGroup( + MethodTracker.FromMemberInfo(genMeth.MakeGenericMethod(type.GetGenericArguments()), type) + ); + } + } + + if (type == typeof(ValueTask)) { + return GetInstanceOpsMethod(type, nameof(InstanceOps.ValueTaskAwaitMethod)); + } +#endif + + return MemberGroup.EmptyGroup; + } + + /// + /// Provides a resolution for __aiter__ on IAsyncEnumerable<T>. + /// + private static MemberGroup/*!*/ AsyncIterResolver(MemberBinder/*!*/ binder, Type/*!*/ type) { +#if NET + foreach (Type t in binder.GetContributingTypes(type)) { + if (t.GetMember("__aiter__").Length > 0) { + return MemberGroup.EmptyGroup; + } + } + + foreach (Type t in binder.GetInterfaces(type)) { + if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) { + MethodInfo genMeth = typeof(InstanceOps).GetMethod(nameof(InstanceOps.AsyncIterMethod)); + return new MemberGroup( + MethodTracker.FromMemberInfo(genMeth.MakeGenericMethod(t.GetGenericArguments()), type) + ); + } + } +#endif + + return MemberGroup.EmptyGroup; + } + + /// + /// Provides a resolution for __anext__ on AsyncEnumeratorWrapper<T>. + /// Not auto-mapped from interfaces; the wrapper class provides __anext__ directly. + /// + private static MemberGroup/*!*/ AsyncNextResolver(MemberBinder/*!*/ binder, Type/*!*/ type) { + return MemberGroup.EmptyGroup; + } + /// /// Provides a resolution for __len__ /// diff --git a/tests/IronPython.Tests/AsyncInteropHelpers.cs b/tests/IronPython.Tests/AsyncInteropHelpers.cs new file mode 100644 index 000000000..d5c6382a0 --- /dev/null +++ b/tests/IronPython.Tests/AsyncInteropHelpers.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#if NET + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace IronPythonTest { + /// + /// Provides IAsyncEnumerable test helpers accessible from Python via clr.AddReference('IronPythonTest'). + /// + public static class AsyncInteropHelpers { + /// + /// Returns an IAsyncEnumerable<int> that yields the given values. + /// + public static IAsyncEnumerable GetAsyncInts(params int[] values) { + return YieldInts(values); + } + + private static async IAsyncEnumerable YieldInts( + int[] values, + [EnumeratorCancellation] CancellationToken ct = default) { + foreach (var v in values) { + await Task.Yield(); + ct.ThrowIfCancellationRequested(); + yield return v; + } + } + + /// + /// Returns an IAsyncEnumerable<string> that yields the given values. + /// + public static IAsyncEnumerable GetAsyncStrings(params string[] values) { + return YieldStrings(values); + } + + private static async IAsyncEnumerable YieldStrings( + string[] values, + [EnumeratorCancellation] CancellationToken ct = default) { + foreach (var v in values) { + await Task.Yield(); + yield return v; + } + } + + /// + /// Returns a real async Task<int> with a delay. + /// The runtime type will be AsyncStateMachineBox, not Task<int> directly. + /// + public static async Task GetAsyncInt(int value, int delayMs = 50) { + await Task.Delay(delayMs); + return value; + } + + /// + /// Returns a real async Task<string> with a delay. + /// + public static async Task GetAsyncString(string value, int delayMs = 50) { + await Task.Delay(delayMs); + return value; + } + + /// + /// Returns a real async Task (void result) with a delay. + /// + public static async Task DoAsync(int delayMs = 50) { + await Task.Delay(delayMs); + } + + /// + /// Async Task<int> that respects a CancellationToken. + /// Throws OperationCanceledException if token is cancelled during the delay. + /// + public static async Task GetAsyncIntWithCancellation(int value, CancellationToken token, int delayMs = 5000) { + await Task.Delay(delayMs, token); + return value; + } + + /// + /// Async Task that respects a CancellationToken. + /// + public static async Task DoAsyncWithCancellation(CancellationToken token, int delayMs = 5000) { + await Task.Delay(delayMs, token); + } + + /// + /// IAsyncEnumerable<int> that yields values with delay and respects cancellation. + /// + public static IAsyncEnumerable GetAsyncIntsWithCancellation(CancellationToken token, params int[] values) { + return YieldIntsWithCancellation(values, token); + } + + private static async IAsyncEnumerable YieldIntsWithCancellation( + int[] values, + CancellationToken token, + [EnumeratorCancellation] CancellationToken ct = default) { + using var linked = CancellationTokenSource.CreateLinkedTokenSource(token, ct); + foreach (var v in values) { + await Task.Delay(50, linked.Token); + yield return v; + } + } + + /// + /// Returns an empty IAsyncEnumerable<int>. + /// + public static IAsyncEnumerable GetEmptyAsyncInts() { + return EmptyAsyncEnumerable(); + } + + private static async IAsyncEnumerable EmptyAsyncEnumerable( + [EnumeratorCancellation] CancellationToken ct = default) { + await Task.CompletedTask; + yield break; + } + } +} + +#endif diff --git a/tests/IronPython.Tests/Cases/CommonCases.cs b/tests/IronPython.Tests/Cases/CommonCases.cs index b8bd7d6c3..d8d5ac163 100644 --- a/tests/IronPython.Tests/Cases/CommonCases.cs +++ b/tests/IronPython.Tests/Cases/CommonCases.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.IO; using System.Threading; using NUnit.Framework; @@ -36,6 +37,42 @@ protected int TestImpl(TestInfo testcase) { return -1; } finally { m?.ReleaseMutex(); + CleanupTempFiles(testcase); + } + } + + /// + /// Removes @test_*_tmp files/directories left behind by test.support.TESTFN. + /// + private static void CleanupTempFiles(TestInfo testcase) { + var testDir = Path.GetDirectoryName(testcase.Path); + if (testDir is null) return; + + // Clean test directory and also the StdLib test directory + CleanupTempFilesInDir(testDir); + var stdlibTestDir = Path.Combine(CaseExecuter.FindRoot(), "src", "core", "IronPython.StdLib", "lib", "test"); + if (stdlibTestDir != testDir) { + CleanupTempFilesInDir(stdlibTestDir); + } + } + + private static void CleanupTempFilesInDir(string dir) { + if (!Directory.Exists(dir)) return; + + try { + foreach (var entry in Directory.EnumerateFileSystemEntries(dir, "@test_*_tmp*")) { + try { + if (File.GetAttributes(entry).HasFlag(FileAttributes.Directory)) { + Directory.Delete(entry, recursive: true); + } else { + File.Delete(entry); + } + } catch { + // ignore locked/in-use files + } + } + } catch { + // ignore enumeration errors } } } diff --git a/tests/IronPython.Tests/Cases/IronPythonCasesManifest.ini b/tests/IronPython.Tests/Cases/IronPythonCasesManifest.ini index 07fe270f8..c6106d718 100644 --- a/tests/IronPython.Tests/Cases/IronPythonCasesManifest.ini +++ b/tests/IronPython.Tests/Cases/IronPythonCasesManifest.ini @@ -4,6 +4,9 @@ WorkingDirectory=$(TEST_FILE_DIR) Redirect=false Timeout=120000 # 2 minute timeout +[IronPython.test_async] +IsolationLevel=PROCESS # loads IronPythonTest assembly, causes a failure in IronPython.test_attrinjector + [IronPython.test_builtin_stdlib] RunCondition=NOT $(IS_MONO) Reason=Exception on adding DocTestSuite diff --git a/tests/IronPython.Tests/CoroutineAsTaskTest.cs b/tests/IronPython.Tests/CoroutineAsTaskTest.cs new file mode 100644 index 000000000..c5b17982b --- /dev/null +++ b/tests/IronPython.Tests/CoroutineAsTaskTest.cs @@ -0,0 +1,169 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +#if NET + +using System; +using System.Threading.Tasks; + +using IronPython.Hosting; +using IronPython.Runtime; + +using Microsoft.Scripting.Hosting; + +using NUnit.Framework; + +namespace IronPythonTest { + public class CoroutineAsTaskTest { + private readonly ScriptEngine _engine; + private readonly ScriptScope _scope; + + public CoroutineAsTaskTest() { + _engine = Python.CreateEngine(); + _scope = _engine.CreateScope(); + } + + [Test] + public void AsTask_SimpleReturn() { + _engine.Execute(@" +async def foo(): + return 42 +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = coro.AsTask().GetAwaiter().GetResult(); + Assert.That(result, Is.EqualTo(42)); + } + + [Test] + public void AsTask_StringReturn() { + _engine.Execute(@" +async def foo(): + return 'hello' +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = coro.AsTask().GetAwaiter().GetResult(); + Assert.That(result, Is.EqualTo("hello")); + } + + [Test] + public void AsTask_AwaitCompletedTask() { + _engine.Execute(@" +from System.Threading.Tasks import Task +async def foo(): + val = await Task.FromResult(10) + return val + 5 +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = coro.AsTask().GetAwaiter().GetResult(); + Assert.That(result, Is.EqualTo(15)); + } + + [Test] + public void AsTask_AwaitRealAsync() { + _engine.Execute(@" +from System.Threading.Tasks import Task +async def foo(): + await Task.Delay(50) + return 'delayed' +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = coro.AsTask().GetAwaiter().GetResult(); + Assert.That(result, Is.EqualTo("delayed")); + } + + [Test] + public async Task AsTask_CanBeAwaited() { + _engine.Execute(@" +from System.Threading.Tasks import Task +async def foo(): + await Task.Delay(50) + return 99 +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = await coro.AsTask(); + Assert.That(result, Is.EqualTo(99)); + } + + [Test] + public void AsTask_PropagatesException() { + _engine.Execute(@" +async def foo(): + raise ValueError('boom') +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + Assert.Throws(() => coro.AsTask().Wait()); + } + + [Test] + public void AsTask_MultipleAwaits() { + _engine.Execute(@" +from System.Threading.Tasks import Task +async def foo(): + a = await Task.FromResult(10) + b = await Task.FromResult(20) + return a + b +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = coro.AsTask().GetAwaiter().GetResult(); + Assert.That(result, Is.EqualTo(30)); + } + + [Test] + public void AsTask_NoneReturn() { + _engine.Execute(@" +async def foo(): + pass +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = coro.AsTask().GetAwaiter().GetResult(); + Assert.That(result, Is.Null); + } + + [Test] + public async Task DirectAwait_Simple() { + _engine.Execute(@" +async def foo(): + return 42 +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = await coro; + Assert.That(result, Is.EqualTo(42)); + } + + [Test] + public async Task DirectAwait_WithRealAsync() { + _engine.Execute(@" +from System.Threading.Tasks import Task +async def foo(): + await Task.Delay(50) + return 'done' +coro = foo() +", _scope); + + var coro = (PythonCoroutine)_scope.GetVariable("coro"); + var result = await coro; + Assert.That(result, Is.EqualTo("done")); + } + } +} + +#endif diff --git a/tests/suite/test_async.py b/tests/suite/test_async.py new file mode 100644 index 000000000..93236db82 --- /dev/null +++ b/tests/suite/test_async.py @@ -0,0 +1,638 @@ +# Licensed to the .NET Foundation under one or more agreements. +# The .NET Foundation licenses this file to you under the Apache 2.0 License. +# See the LICENSE file in the project root for more information. + +"""Tests for PEP 492: async/await support.""" + +import unittest + +from iptest import run_test + + +def run_coro(coro): + """Run a coroutine to completion, blocking on yielded .NET Tasks.""" + value = None + while True: + try: + task = coro.send(value) + # .NET Task yielded — block on it (test runner is synchronous) + task.Wait() + value = None + except StopIteration as e: + return e.value + + +class AsyncIter: + """Async iterator for testing async for.""" + def __init__(self, items): + self.items = list(items) + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.items): + raise StopAsyncIteration + val = self.items[self.index] + self.index += 1 + return val + + +class AsyncDefTest(unittest.TestCase): + """Tests for basic async def and coroutine type.""" + + def test_basic_return(self): + async def foo(): + return 42 + self.assertEqual(run_coro(foo()), 42) + + def test_coroutine_type(self): + async def foo(): + return 1 + coro = foo() + self.assertEqual(type(coro).__name__, 'coroutine') + try: + coro.send(None) + except StopIteration: + pass + + def test_coroutine_properties(self): + async def named_coro(): + return 1 + coro = named_coro() + self.assertTrue(hasattr(coro, 'cr_code')) + self.assertTrue(hasattr(coro, 'cr_running')) + self.assertTrue(hasattr(coro, 'cr_frame')) + self.assertTrue(hasattr(coro, '__name__')) + self.assertTrue(hasattr(coro, '__qualname__')) + self.assertEqual(coro.__name__, 'named_coro') + try: + coro.send(None) + except StopIteration: + pass + + def test_async_def_no_await(self): + async def foo(): + x = 1 + y = 2 + return x + y + self.assertEqual(run_coro(foo()), 3) + + def test_coroutine_close(self): + async def foo(): + return 1 + coro = foo() + coro.close() # should not raise + + def test_coroutine_throw(self): + async def foo(): + return 42 + coro = foo() + with self.assertRaises(ValueError) as cm: + coro.throw(ValueError('boom')) + self.assertEqual(str(cm.exception), 'boom') + + +class AwaitTest(unittest.TestCase): + """Tests for await expression.""" + + def test_basic_await(self): + async def inner(): + return 10 + + async def outer(): + val = await inner() + return val + 5 + + self.assertEqual(run_coro(outer()), 15) + + def test_multiple_awaits(self): + async def val(x): + return x + + async def test(): + return await val(1) + await val(2) + await val(3) + + self.assertEqual(run_coro(test()), 6) + + def test_await_protocol(self): + async def foo(): + return 99 + coro = foo() + wrapper = coro.__await__() + self.assertEqual(type(wrapper).__name__, 'coroutine_wrapper') + try: + wrapper.__next__() + except StopIteration as e: + self.assertEqual(e.value, 99) + + def test_custom_awaitable(self): + class MyAwaitable: + def __await__(self): + return iter([]) + + async def test(): + await MyAwaitable() + return 'done' + + self.assertEqual(run_coro(test()), 'done') + + +class AsyncWithTest(unittest.TestCase): + """Tests for async with statement.""" + + def test_basic_async_with(self): + class CM: + def __init__(self): + self.entered = False + self.exited = False + async def __aenter__(self): + self.entered = True + return self + async def __aexit__(self, *args): + self.exited = True + + async def test(): + cm = CM() + async with cm: + self.assertTrue(cm.entered) + self.assertTrue(cm.exited) + return 'ok' + + self.assertEqual(run_coro(test()), 'ok') + + def test_async_with_as(self): + class CM: + async def __aenter__(self): + return 'value' + async def __aexit__(self, *args): + pass + + async def test(): + async with CM() as v: + return v + + self.assertEqual(run_coro(test()), 'value') + + def test_async_with_order(self): + class CM: + def __init__(self): + self.log = [] + async def __aenter__(self): + self.log.append('enter') + return self + async def __aexit__(self, *args): + self.log.append('exit') + + async def test(): + cm = CM() + async with cm: + cm.log.append('body') + return cm.log + + self.assertEqual(run_coro(test()), ['enter', 'body', 'exit']) + + def test_async_with_no_as(self): + class CM: + async def __aenter__(self): + return 'unused' + async def __aexit__(self, *args): + pass + + async def test(): + async with CM(): + return 'ok' + + self.assertEqual(run_coro(test()), 'ok') + + +class AsyncForTest(unittest.TestCase): + """Tests for async for statement.""" + + def test_basic_async_for(self): + async def test(): + result = [] + async for x in AsyncIter([1, 2, 3]): + result.append(x) + return result + + self.assertEqual(run_coro(test()), [1, 2, 3]) + + def test_async_for_empty(self): + async def test(): + result = [] + async for x in AsyncIter([]): + result.append(x) + return result + + self.assertEqual(run_coro(test()), []) + + def test_async_for_else(self): + async def test(): + result = [] + async for x in AsyncIter([]): + result.append(x) + else: + result.append('else') + return result + + self.assertEqual(run_coro(test()), ['else']) + + def test_async_for_else_on_completion(self): + async def test(): + result = [] + async for x in AsyncIter([1, 2]): + result.append(x) + else: + result.append('else') + return result + + self.assertEqual(run_coro(test()), [1, 2, 'else']) + + def test_async_for_break(self): + async def test(): + result = [] + async for x in AsyncIter([1, 2, 3, 4, 5]): + if x == 3: + break + result.append(x) + return result + + self.assertEqual(run_coro(test()), [1, 2]) + + def test_async_for_break_skips_else(self): + async def test(): + result = [] + async for x in AsyncIter([1, 2, 3]): + if x == 2: + break + result.append(x) + else: + result.append('else') + return result + + self.assertEqual(run_coro(test()), [1]) + + def test_async_for_continue(self): + async def test(): + result = [] + async for x in AsyncIter([1, 2, 3, 4, 5]): + if x % 2 == 0: + continue + result.append(x) + return result + + self.assertEqual(run_coro(test()), [1, 3, 5]) + + def test_nested_async_for(self): + async def test(): + result = [] + async for x in AsyncIter([1, 2]): + async for y in AsyncIter([10, 20]): + result.append(x * 100 + y) + return result + + self.assertEqual(run_coro(test()), [110, 120, 210, 220]) + + +class AsyncCombinedTest(unittest.TestCase): + """Tests combining async with and async for.""" + + def test_async_with_and_for(self): + class CM: + def __init__(self): + self.log = [] + async def __aenter__(self): + self.log.append('enter') + return self + async def __aexit__(self, *args): + self.log.append('exit') + + async def test(): + cm = CM() + async with cm: + async for x in AsyncIter([1, 2]): + cm.log.append(x) + return cm.log + + self.assertEqual(run_coro(test()), ['enter', 1, 2, 'exit']) + + +class DotNetAsyncInteropTest(unittest.TestCase): + """Tests for .NET async interop (await Task, async for IAsyncEnumerable, CancelledError).""" + + def test_await_completed_task(self): + """await a Task that is already completed (Task.CompletedTask).""" + from System.Threading.Tasks import Task + async def test(): + await Task.CompletedTask + return 'done' + self.assertEqual(run_coro(test()), 'done') + + def test_await_task_delay(self): + """await Task.Delay -a real async .NET operation.""" + from System.Threading.Tasks import Task + async def test(): + await Task.Delay(10) + return 'delayed' + self.assertEqual(run_coro(test()), 'delayed') + + def test_await_task_from_result(self): + """await Task.FromResult should return the value.""" + from System.Threading.Tasks import Task + async def test(): + result = await Task.FromResult(42) + return result + self.assertEqual(run_coro(test()), 42) + + def test_await_task_from_result_string(self): + """await Task.FromResult should return the string.""" + from System.Threading.Tasks import Task + async def test(): + result = await Task.FromResult("hello") + return result + self.assertEqual(run_coro(test()), "hello") + + def test_await_multiple_tasks(self): + """Multiple awaits in sequence.""" + from System.Threading.Tasks import Task + async def test(): + a = await Task.FromResult(10) + b = await Task.FromResult(20) + c = await Task.FromResult(30) + return a + b + c + self.assertEqual(run_coro(test()), 60) + + def test_task_has_await(self): + """Task objects should have __await__ method.""" + from System.Threading.Tasks import Task + task = Task.FromResult(99) + self.assertTrue(hasattr(task, '__await__')) + + def test_task_awaitable_protocol(self): + """Task.__await__() should return an iterable that raises StopIteration(value).""" + from System.Threading.Tasks import Task + task = Task.FromResult(42) + awaitable = task.__await__() + it = iter(awaitable) + try: + next(it) + self.fail("Expected StopIteration") + except StopIteration as e: + self.assertEqual(e.value, 42) + + def test_await_faulted_task(self): + """Awaiting a faulted task should propagate the exception.""" + from System import Exception as DotNetException + from System.Threading.Tasks import Task + async def test(): + await Task.FromException(DotNetException("boom")) + with self.assertRaises(DotNetException): + run_coro(test()) + + def test_cancelled_error_from_cancelled_task(self): + """Awaiting a cancelled task should raise CancelledError.""" + from System.Threading import CancellationTokenSource + from System.Threading.Tasks import Task + cts = CancellationTokenSource() + cts.Cancel() + async def test(): + await Task.FromCanceled(cts.Token) + with self.assertRaises(CancelledError): + run_coro(test()) + + def test_cancelled_error_type(self): + """CancelledError should be a subclass of Exception.""" + self.assertTrue(issubclass(CancelledError, Exception)) + + def test_operation_cancelled_maps_to_cancelled_error(self): + """System.OperationCanceledException should map to CancelledError.""" + from System import OperationCanceledException + try: + raise OperationCanceledException("test cancel") + except CancelledError: + pass # expected + + def test_cancellation_token_cancel(self): + """CancellationToken can be used with .NET async APIs.""" + from System.Threading import CancellationTokenSource + cts = CancellationTokenSource() + token = cts.Token + self.assertFalse(token.IsCancellationRequested) + cts.Cancel() + self.assertTrue(token.IsCancellationRequested) + + def test_await_valuetask(self): + """await a ValueTask (non-generic).""" + from System.Threading.Tasks import ValueTask + async def test(): + await ValueTask.CompletedTask + return 'done' + self.assertEqual(run_coro(test()), 'done') + + def test_await_valuetask_generic(self): + """await a ValueTask should return the value.""" + from System.Threading.Tasks import ValueTask + async def test(): + vt = ValueTask[int](42) + result = await vt + return result + self.assertEqual(run_coro(test()), 42) + + def test_valuetask_has_await(self): + """ValueTask should have __await__ method.""" + from System.Threading.Tasks import ValueTask + vt = ValueTask.CompletedTask + self.assertTrue(hasattr(vt, '__await__')) + + def test_valuetask_generic_has_await(self): + """ValueTask should have __await__ method.""" + from System.Threading.Tasks import ValueTask + vt = ValueTask[str]("hello") + self.assertTrue(hasattr(vt, '__await__')) + + def test_await_valuetask_string(self): + """await a ValueTask.""" + from System.Threading.Tasks import ValueTask + async def test(): + vt = ValueTask[str]("world") + return await vt + self.assertEqual(run_coro(test()), "world") + + +import sys +if sys.implementation.name == 'ironpython': + import clr + try: + clr.AddReference('IronPythonTest') + from IronPythonTest import AsyncInteropHelpers + _has_async_helpers = True + except Exception: + _has_async_helpers = False +else: + _has_async_helpers = False + + +@unittest.skipUnless(_has_async_helpers, "requires IronPythonTest with AsyncInteropHelpers") +class DotNetAsyncEnumerableTest(unittest.TestCase): + """Tests for async for over .NET IAsyncEnumerable.""" + + def test_async_for_ints(self): + """async for over IAsyncEnumerable.""" + async def test(): + result = [] + async for x in AsyncInteropHelpers.GetAsyncInts(1, 2, 3): + result.append(x) + return result + self.assertEqual(run_coro(test()), [1, 2, 3]) + + def test_async_for_strings(self): + """async for over IAsyncEnumerable.""" + async def test(): + result = [] + async for s in AsyncInteropHelpers.GetAsyncStrings("a", "b", "c"): + result.append(s) + return result + self.assertEqual(run_coro(test()), ["a", "b", "c"]) + + def test_async_for_empty(self): + """async for over empty IAsyncEnumerable.""" + async def test(): + result = [] + async for x in AsyncInteropHelpers.GetEmptyAsyncInts(): + result.append(x) + return result + self.assertEqual(run_coro(test()), []) + + def test_async_for_break(self): + """break inside async for over IAsyncEnumerable.""" + async def test(): + result = [] + async for x in AsyncInteropHelpers.GetAsyncInts(10, 20, 30, 40, 50): + if x == 30: + break + result.append(x) + return result + self.assertEqual(run_coro(test()), [10, 20]) + + def test_async_for_has_aiter(self): + """IAsyncEnumerable objects should have __aiter__.""" + stream = AsyncInteropHelpers.GetAsyncInts(1) + self.assertTrue(hasattr(stream, '__aiter__')) + + +@unittest.skipUnless(_has_async_helpers, "requires IronPythonTest with AsyncInteropHelpers") +class DotNetRealAsyncTaskTest(unittest.TestCase): + """Tests for await on real async Task methods. + + These test real .NET async methods where the runtime type is + AsyncStateMachineBox, not Task directly. + All methods include real delays (Task.Delay) to ensure truly async behavior. + """ + + def test_await_real_async_int(self): + """await a real async Task with delay.""" + async def test(): + return await AsyncInteropHelpers.GetAsyncInt(42) + self.assertEqual(run_coro(test()), 42) + + def test_await_real_async_string(self): + """await a real async Task with delay.""" + async def test(): + return await AsyncInteropHelpers.GetAsyncString("hello") + self.assertEqual(run_coro(test()), "hello") + + def test_await_real_async_void(self): + """await a real async Task (void) with delay.""" + async def test(): + await AsyncInteropHelpers.DoAsync() + return 'done' + self.assertEqual(run_coro(test()), 'done') + + def test_await_real_async_multiple(self): + """Multiple awaits on real async Task in sequence.""" + async def test(): + a = await AsyncInteropHelpers.GetAsyncInt(10) + b = await AsyncInteropHelpers.GetAsyncInt(20) + s = await AsyncInteropHelpers.GetAsyncString("!") + return str(a + b) + s + self.assertEqual(run_coro(test()), "30!") + + def test_await_real_async_mixed_with_python(self): + """Mix real .NET async with Python coroutines.""" + async def py_double(x): + return x * 2 + + async def test(): + val = await AsyncInteropHelpers.GetAsyncInt(5) + doubled = await py_double(val) + return doubled + self.assertEqual(run_coro(test()), 10) + + +@unittest.skipUnless(_has_async_helpers, "requires IronPythonTest with AsyncInteropHelpers") +class DotNetCancellationTest(unittest.TestCase): + """Tests for CancellationToken and CancelledError with .NET async methods.""" + + def test_cancel_async_task_int(self): + """Cancelling a Task should raise CancelledError.""" + from System.Threading import CancellationTokenSource + cts = CancellationTokenSource() + cts.Cancel() + async def test(): + return await AsyncInteropHelpers.GetAsyncIntWithCancellation(42, cts.Token) + with self.assertRaises(CancelledError): + run_coro(test()) + + def test_cancel_async_task_void(self): + """Cancelling a Task (void) should raise CancelledError.""" + from System.Threading import CancellationTokenSource + cts = CancellationTokenSource() + cts.Cancel() + async def test(): + await AsyncInteropHelpers.DoAsyncWithCancellation(cts.Token) + with self.assertRaises(CancelledError): + run_coro(test()) + + def test_cancel_async_enumerable(self): + """Cancelling during async for over IAsyncEnumerable should raise CancelledError.""" + from System.Threading import CancellationTokenSource + cts = CancellationTokenSource() + cts.Cancel() + async def test(): + result = [] + async for x in AsyncInteropHelpers.GetAsyncIntsWithCancellation(cts.Token, 1, 2, 3): + result.append(x) + return result + with self.assertRaises(CancelledError): + run_coro(test()) + + def test_cancelled_error_is_exception_subclass(self): + """CancelledError should be a subclass of Exception.""" + self.assertTrue(issubclass(CancelledError, Exception)) + + def test_cancelled_error_catch_as_exception(self): + """CancelledError should be catchable as Exception.""" + from System.Threading import CancellationTokenSource + cts = CancellationTokenSource() + cts.Cancel() + async def test(): + try: + await AsyncInteropHelpers.GetAsyncIntWithCancellation(99, cts.Token) + except Exception: + return 'caught' + self.assertEqual(run_coro(test()), 'caught') + + def test_operation_cancelled_maps_to_cancelled_error(self): + """System.OperationCanceledException raised directly should be catchable as CancelledError.""" + from System import OperationCanceledException + caught = False + try: + raise OperationCanceledException("test") + except CancelledError: + caught = True + self.assertTrue(caught) + + +run_test(__name__)