Skip to content

Commit 8807e35

Browse files
committed
[GR-60145] Use XSum port for math.fsum().
PullRequest: graalpython/4197
2 parents b084104 + f2ee178 commit 8807e35

11 files changed

Lines changed: 787 additions & 142 deletions

File tree

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2017, 2025, Oracle and/or its affiliates.
2+
* Copyright (c) 2017, 2026, Oracle and/or its affiliates.
33
* Copyright (c) 2013, Regents of the University of California
44
*
55
* All rights reserved.
@@ -498,19 +498,21 @@ static boolean doObject(VirtualFrame frame, Object object,
498498
@Cached PyIterNextNode nextNode,
499499
@Cached PyObjectIsTrueNode isTrueNode) {
500500
Object iterator = getIter.execute(frame, inliningTarget, object);
501-
int nbrIter = 0;
501+
int loopCount = 0;
502502

503503
while (true) {
504504
try {
505505
Object next = nextNode.execute(frame, inliningTarget, iterator);
506-
nbrIter++;
506+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
507+
loopCount++;
508+
}
507509
if (!isTrueNode.execute(frame, next)) {
508510
return false;
509511
}
510512
} catch (IteratorExhausted e) {
511513
break;
512514
} finally {
513-
LoopNode.reportLoopCount(inliningTarget, nbrIter);
515+
LoopNode.reportLoopCount(inliningTarget, loopCount);
514516
}
515517
}
516518

@@ -549,19 +551,21 @@ static boolean doObject(VirtualFrame frame, Object object,
549551
@Cached PyIterNextNode nextNode,
550552
@Cached PyObjectIsTrueNode isTrueNode) {
551553
Object iterator = getIter.execute(frame, inliningTarget, object);
552-
int nbrIter = 0;
554+
int loopCount = 0;
553555

554556
while (true) {
555557
try {
556558
Object next = nextNode.execute(frame, inliningTarget, iterator);
557-
nbrIter++;
559+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
560+
loopCount++;
561+
}
558562
if (isTrueNode.execute(frame, next)) {
559563
return true;
560564
}
561565
} catch (IteratorExhausted e) {
562566
break;
563567
} finally {
564-
LoopNode.reportLoopCount(inliningTarget, nbrIter);
568+
LoopNode.reportLoopCount(inliningTarget, loopCount);
565569
}
566570
}
567571

@@ -1563,11 +1567,13 @@ static Object minmaxSequenceWithKey(VirtualFrame frame, Node inliningTarget, Obj
15631567
currentKey = nextKey;
15641568
currentValue = nextValue;
15651569
}
1566-
loopCount++;
1570+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
1571+
loopCount++;
1572+
}
15671573
} catch (IteratorExhausted e) {
15681574
break;
15691575
} finally {
1570-
LoopNode.reportLoopCount(inliningTarget, loopCount < 0 ? Integer.MAX_VALUE : loopCount);
1576+
LoopNode.reportLoopCount(inliningTarget, loopCount);
15711577
}
15721578
}
15731579

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 45 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import java.math.BigDecimal;
3535
import java.math.BigInteger;
3636
import java.math.MathContext;
37-
import java.util.Arrays;
3837
import java.util.List;
3938

4039
import com.oracle.graal.python.PythonLanguage;
@@ -50,6 +49,9 @@
5049
import com.oracle.graal.python.builtins.objects.ints.IntBuiltins;
5150
import com.oracle.graal.python.builtins.objects.ints.PInt;
5251
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
52+
import com.oracle.graal.python.builtins.objects.type.TpSlots;
53+
import com.oracle.graal.python.builtins.objects.type.slots.TpSlot;
54+
import com.oracle.graal.python.builtins.objects.type.slots.TpSlotIterNext;
5355
import com.oracle.graal.python.lib.IteratorExhausted;
5456
import com.oracle.graal.python.lib.PyBoolCheckNode;
5557
import com.oracle.graal.python.lib.PyFloatAsDoubleNode;
@@ -91,6 +93,7 @@
9193
import com.oracle.graal.python.runtime.exception.PException;
9294
import com.oracle.graal.python.runtime.object.PFactory;
9395
import com.oracle.graal.python.util.OverflowException;
96+
import com.oracle.graal.python.util.XSum;
9497
import com.oracle.truffle.api.CompilerDirectives;
9598
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
9699
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -108,6 +111,7 @@
108111
import com.oracle.truffle.api.dsl.Specialization;
109112
import com.oracle.truffle.api.dsl.TypeSystemReference;
110113
import com.oracle.truffle.api.frame.VirtualFrame;
114+
import com.oracle.truffle.api.nodes.LoopNode;
111115
import com.oracle.truffle.api.nodes.Node;
112116
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
113117
import com.oracle.truffle.api.profiles.InlinedLoopConditionProfile;
@@ -850,127 +854,65 @@ protected ArgumentClinicProvider getArgumentClinic() {
850854
@GenerateNodeFactory
851855
public abstract static class FsumNode extends PythonUnaryBuiltinNode {
852856

857+
/**
858+
* Note: this specialization uses an inlined version of {@link PyIterNextNode} with the
859+
* tp_iternext slot moved out of the loop.
860+
*/
853861
@Specialization
854862
static double doIt(VirtualFrame frame, Object iterable,
855863
@Bind Node inliningTarget,
856864
@Cached PyObjectGetIter getIter,
857-
@Cached PyIterNextNode nextNode,
865+
@Cached GetClassNode nextNodeGetClassNode,
866+
@Cached TpSlots.GetCachedTpSlotsNode nextNodeGetSlots,
867+
@Cached TpSlotIterNext.CallSlotTpIterNextNode nextNodeCallNext,
868+
@Cached IsBuiltinObjectProfile nextNodeStopIterationProfile,
858869
@Cached PyFloatAsDoubleNode asDoubleNode,
859-
@Cached InlinedLoopConditionProfile loopProfile,
860870
@Cached PRaiseNode raiseNode) {
861-
/*
862-
* This implementation is taken from CPython. The performance is not good. Should be
863-
* faster. It can be easily replace with much simpler code based on BigDecimal:
864-
*
865-
* BigDecimal result = BigDecimal.ZERO;
866-
*
867-
* in cycle just: result = result.add(BigDecimal.valueof(x); ... The current
868-
* implementation is little bit faster. The testFSum in test_math.py takes in different
869-
* implementations: CPython ~0.6s CurrentImpl: ~14.3s Using BigDecimal: ~15.1
870-
*/
871871
Object iterator = getIter.execute(frame, inliningTarget, iterable);
872-
double x, y, t, hi, lo = 0, yr, inf_sum = 0, special_sum = 0, sum;
873-
double xsave;
874-
int i, j, n = 0, arayLength = 32;
875-
double[] p = new double[arayLength];
876-
boolean exhausted = false;
877-
while (loopProfile.profile(inliningTarget, !exhausted)) {
878-
try {
879-
Object next = nextNode.execute(frame, inliningTarget, iterator);
880-
x = asDoubleNode.execute(frame, inliningTarget, next);
881-
xsave = x;
882-
for (i = j = 0; j < n; j++) { /* for y in partials */
883-
y = p[j];
884-
if (Math.abs(x) < Math.abs(y)) {
885-
t = x;
886-
x = y;
887-
y = t;
888-
}
889-
hi = x + y;
890-
yr = hi - x;
891-
lo = y - yr;
892-
if (lo != 0.0) {
893-
p[i++] = lo;
894-
}
895-
x = hi;
896-
}
897872

898-
n = i;
899-
if (x != 0.0) {
900-
if (!Double.isFinite(x)) {
901-
/*
902-
* a nonfinite x could arise either as a result of intermediate
903-
* overflow, or as a result of a nan or inf in the summands
904-
*/
905-
if (Double.isFinite(xsave)) {
906-
throw raiseNode.raise(inliningTarget, OverflowError, ErrorMessages.INTERMEDIATE_OVERFLOW_IN, "fsum");
907-
}
908-
if (Double.isInfinite(xsave)) {
909-
inf_sum += xsave;
910-
}
911-
special_sum += xsave;
912-
/* reset partials */
913-
n = 0;
914-
} else if (n >= arayLength) {
915-
arayLength += arayLength;
916-
p = Arrays.copyOf(p, arayLength);
917-
} else {
918-
p[n++] = x;
919-
}
873+
TpSlot tpIternext = nextNodeGetSlots.execute(inliningTarget, nextNodeGetClassNode.execute(inliningTarget, iterator)).tp_iternext();
874+
assert tpIternext != null;
875+
876+
var acc = new XSum.SmallAccumulator();
877+
int loopCount = 0;
878+
while (true) {
879+
try {
880+
Object next = nextNodeCallNext.execute(frame, inliningTarget, tpIternext, iterator);
881+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
882+
loopCount++;
920883
}
884+
acc.add(asDoubleNode.execute(frame, inliningTarget, next));
921885
} catch (IteratorExhausted e) {
922-
exhausted = true;
886+
break;
887+
} catch (PException e) {
888+
e.expectStopIteration(inliningTarget, nextNodeStopIterationProfile);
889+
break;
890+
} finally {
891+
LoopNode.reportLoopCount(inliningTarget, loopCount);
923892
}
924893
}
925894

926-
if (special_sum != 0.0) {
927-
if (Double.isNaN(inf_sum)) {
895+
if (acc.isNaNResult()) {
896+
return Double.NaN;
897+
}
898+
899+
if (acc.isInfiniteResult()) {
900+
double result = acc.getInfiniteResult();
901+
if (Double.isNaN(result)) {
928902
throw raiseNode.raise(inliningTarget, ValueError, ErrorMessages.NEG_INF_PLUS_INF_IN);
929903
} else {
930-
sum = special_sum;
931-
return sum;
904+
assert Double.isInfinite(result);
905+
return result;
932906
}
933907
}
934908

935-
hi = 0.0;
936-
if (n > 0) {
937-
hi = p[--n];
938-
/*
939-
* sum_exact(ps, hi) from the top, stop when the sum becomes inexact.
940-
*/
941-
while (n > 0) {
942-
x = hi;
943-
y = p[--n];
944-
assert (Math.abs(y) < Math.abs(x));
945-
hi = x + y;
946-
yr = hi - x;
947-
lo = y - yr;
948-
if (lo != 0.0) {
949-
break;
950-
}
951-
}
952-
/*
953-
* Make half-even rounding work across multiple partials. Needed so that sum([1e-16,
954-
* 1, 1e16]) will round-up the last digit to two instead of down to zero (the 1e-16
955-
* makes the 1 slightly closer to two). With a potential 1 ULP rounding error
956-
* fixed-up, math.fsum() can guarantee commutativity.
957-
*/
958-
if (n > 0 && ((lo < 0.0 && p[n - 1] < 0.0) ||
959-
(lo > 0.0 && p[n - 1] > 0.0))) {
960-
y = lo * 2.0;
961-
x = hi + y;
962-
yr = x - hi;
963-
if (compareAsBigDecimal(y, yr) == 0) {
964-
hi = x;
965-
}
966-
}
909+
double result = acc.round();
910+
// +Inf or -Inf if exponent has overflowed
911+
if (Double.isInfinite(result)) {
912+
throw raiseNode.raise(inliningTarget, OverflowError, ErrorMessages.INTERMEDIATE_OVERFLOW_IN, "fsum");
913+
} else {
914+
return result;
967915
}
968-
return hi;
969-
}
970-
971-
@TruffleBoundary
972-
private static int compareAsBigDecimal(double y, double yr) {
973-
return BigDecimal.valueOf(y).compareTo(BigDecimal.valueOf(yr));
974916
}
975917
}
976918

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/functools/FunctoolsModuleBuiltins.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2018, 2026, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -45,7 +45,6 @@
4545
import static com.oracle.graal.python.nodes.BuiltinNames.T_FUNCTOOLS;
4646
import static com.oracle.graal.python.nodes.ErrorMessages.REDUCE_EMPTY_SEQ;
4747
import static com.oracle.graal.python.nodes.ErrorMessages.S_ARG_N_MUST_SUPPORT_ITERATION;
48-
import static com.oracle.truffle.api.nodes.LoopNode.reportLoopCount;
4948

5049
import java.util.List;
5150

@@ -75,6 +74,7 @@
7574
import com.oracle.truffle.api.dsl.NodeFactory;
7675
import com.oracle.truffle.api.dsl.Specialization;
7776
import com.oracle.truffle.api.frame.VirtualFrame;
77+
import com.oracle.truffle.api.nodes.LoopNode;
7878
import com.oracle.truffle.api.nodes.Node;
7979
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
8080

@@ -136,7 +136,7 @@ Object doReduce(VirtualFrame frame, Object function, Object sequence, Object ini
136136

137137
Object[] args = new Object[2];
138138

139-
int count = 0;
139+
int loopCount = 0;
140140
while (true) {
141141
Object op2;
142142
try {
@@ -152,11 +152,11 @@ Object doReduce(VirtualFrame frame, Object function, Object sequence, Object ini
152152
args[1] = op2;
153153
result = callNode.execute(frame, function, args);
154154
}
155-
if (CompilerDirectives.hasNextTier()) {
156-
count++;
155+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
156+
loopCount++;
157157
}
158158
}
159-
reportLoopCount(this, count >= 0 ? count : Integer.MAX_VALUE);
159+
LoopNode.reportLoopCount(this, loopCount);
160160

161161
if (result == null) {
162162
throw raiseNode.raise(inliningTarget, PythonBuiltinClassType.TypeError, REDUCE_EMPTY_SEQ);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageNodes.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2022, 2026, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -1216,12 +1216,12 @@ static boolean doIt(Frame frame, Node inliningTarget, HashingStorage aStorage, H
12161216
if (lenANode.execute(inliningTarget, aStorage) != lenBNode.execute(inliningTarget, bStorage)) {
12171217
return false;
12181218
}
1219-
int index = 0;
1219+
int loopCount = 0;
12201220
try {
12211221
HashingStorageIterator aIter = getAIter.execute(inliningTarget, aStorage);
12221222
while (loopProfile.profile(inliningTarget, aIterNext.execute(inliningTarget, aStorage, aIter))) {
1223-
if (CompilerDirectives.hasNextTier()) {
1224-
index++;
1223+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
1224+
loopCount++;
12251225
}
12261226

12271227
Object aKey = aIterKey.execute(inliningTarget, aStorage, aIter);
@@ -1236,8 +1236,8 @@ static boolean doIt(Frame frame, Node inliningTarget, HashingStorage aStorage, H
12361236
return false;
12371237
}
12381238
} finally {
1239-
if (index != 0) {
1240-
LoopNode.reportLoopCount(inliningTarget, index);
1239+
if (loopCount != 0) {
1240+
LoopNode.reportLoopCount(inliningTarget, loopCount);
12411241
}
12421242
}
12431243
return true;
@@ -1287,19 +1287,19 @@ static Object doIt(Frame frame, Node callbackInliningTarget, HashingStorage stor
12871287
@Cached HashingStorageGetIterator getIter,
12881288
@Cached HashingStorageIteratorNext iterNext,
12891289
@Cached InlinedLoopConditionProfile loopProfile) {
1290-
int index = 0;
1290+
int loopCount = 0;
12911291
Object accumulator = accumulatorIn;
12921292
try {
12931293
HashingStorageIterator aIter = getIter.execute(inliningTarget, storage);
12941294
while (loopProfile.profile(inliningTarget, iterNext.execute(inliningTarget, storage, aIter))) {
1295-
if (CompilerDirectives.hasNextTier()) {
1296-
index++;
1295+
if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) {
1296+
loopCount++;
12971297
}
12981298
accumulator = callback.execute(frame, callbackInliningTarget, storage, aIter, accumulator);
12991299
}
13001300
} finally {
1301-
if (index != 0) {
1302-
LoopNode.reportLoopCount(getIter, index);
1301+
if (loopCount != 0) {
1302+
LoopNode.reportLoopCount(getIter, loopCount);
13031303
}
13041304
}
13051305
return accumulator;

0 commit comments

Comments
 (0)