|
34 | 34 | import java.math.BigDecimal; |
35 | 35 | import java.math.BigInteger; |
36 | 36 | import java.math.MathContext; |
37 | | -import java.util.Arrays; |
38 | 37 | import java.util.List; |
39 | 38 |
|
40 | 39 | import com.oracle.graal.python.PythonLanguage; |
|
50 | 49 | import com.oracle.graal.python.builtins.objects.ints.IntBuiltins; |
51 | 50 | import com.oracle.graal.python.builtins.objects.ints.PInt; |
52 | 51 | import com.oracle.graal.python.builtins.objects.tuple.PTuple; |
| 52 | +import com.oracle.graal.python.builtins.objects.type.TpSlots; |
| 53 | +import com.oracle.graal.python.builtins.objects.type.slots.TpSlot; |
| 54 | +import com.oracle.graal.python.builtins.objects.type.slots.TpSlotIterNext; |
53 | 55 | import com.oracle.graal.python.lib.IteratorExhausted; |
54 | 56 | import com.oracle.graal.python.lib.PyBoolCheckNode; |
55 | 57 | import com.oracle.graal.python.lib.PyFloatAsDoubleNode; |
|
91 | 93 | import com.oracle.graal.python.runtime.exception.PException; |
92 | 94 | import com.oracle.graal.python.runtime.object.PFactory; |
93 | 95 | import com.oracle.graal.python.util.OverflowException; |
| 96 | +import com.oracle.graal.python.util.XSum; |
94 | 97 | import com.oracle.truffle.api.CompilerDirectives; |
95 | 98 | import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; |
96 | 99 | import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; |
|
108 | 111 | import com.oracle.truffle.api.dsl.Specialization; |
109 | 112 | import com.oracle.truffle.api.dsl.TypeSystemReference; |
110 | 113 | import com.oracle.truffle.api.frame.VirtualFrame; |
| 114 | +import com.oracle.truffle.api.nodes.LoopNode; |
111 | 115 | import com.oracle.truffle.api.nodes.Node; |
112 | 116 | import com.oracle.truffle.api.profiles.InlinedConditionProfile; |
113 | 117 | import com.oracle.truffle.api.profiles.InlinedLoopConditionProfile; |
@@ -850,127 +854,65 @@ protected ArgumentClinicProvider getArgumentClinic() { |
850 | 854 | @GenerateNodeFactory |
851 | 855 | public abstract static class FsumNode extends PythonUnaryBuiltinNode { |
852 | 856 |
|
| 857 | + /** |
| | + * Sums the items of {@code iterable}: each item is converted with {@code asDoubleNode} and |
| | + * added to an {@link XSum.SmallAccumulator}. Returns NaN when the accumulator reports a NaN |
| | + * result, returns the infinite result when it reports one (raising ValueError if that value |
| | + * is itself NaN), and otherwise returns the rounded sum, raising OverflowError if rounding |
| | + * yields an infinity. |
| | + * <p> |
| 858 | + * Note: this specialization uses an inlined version of {@link PyIterNextNode} with the |
| 859 | + * tp_iternext slot moved out of the loop. |
| | + * <p> |
| | + * NOTE(review): reportLoopCount is in a finally block inside the loop, so it runs on every |
| | + * iteration with the running count; confirm that this is intended. |
| | + * NOTE(review): the PException handler also covers the asDoubleNode call, not only the |
| | + * tp_iternext call; confirm that this is intended. |
| 860 | + */ |
853 | 861 | @Specialization |
854 | 862 | static double doIt(VirtualFrame frame, Object iterable, |
855 | 863 | @Bind Node inliningTarget, |
856 | 864 | @Cached PyObjectGetIter getIter, |
857 | | - @Cached PyIterNextNode nextNode, |
| 865 | + @Cached GetClassNode nextNodeGetClassNode, |
| 866 | + @Cached TpSlots.GetCachedTpSlotsNode nextNodeGetSlots, |
| 867 | + @Cached TpSlotIterNext.CallSlotTpIterNextNode nextNodeCallNext, |
| 868 | + @Cached IsBuiltinObjectProfile nextNodeStopIterationProfile, |
858 | 869 | @Cached PyFloatAsDoubleNode asDoubleNode, |
859 | | - @Cached InlinedLoopConditionProfile loopProfile, |
860 | 870 | @Cached PRaiseNode raiseNode) { |
861 | | - /* |
862 | | - * This implementation is taken from CPython. The performance is not good. Should be |
863 | | - * faster. It can be easily replace with much simpler code based on BigDecimal: |
864 | | - * |
865 | | - * BigDecimal result = BigDecimal.ZERO; |
866 | | - * |
867 | | - * in cycle just: result = result.add(BigDecimal.valueof(x); ... The current |
868 | | - * implementation is little bit faster. The testFSum in test_math.py takes in different |
869 | | - * implementations: CPython ~0.6s CurrentImpl: ~14.3s Using BigDecimal: ~15.1 |
870 | | - */ |
871 | 871 | Object iterator = getIter.execute(frame, inliningTarget, iterable); |
872 | | - double x, y, t, hi, lo = 0, yr, inf_sum = 0, special_sum = 0, sum; |
873 | | - double xsave; |
874 | | - int i, j, n = 0, arayLength = 32; |
875 | | - double[] p = new double[arayLength]; |
876 | | - boolean exhausted = false; |
877 | | - while (loopProfile.profile(inliningTarget, !exhausted)) { |
878 | | - try { |
879 | | - Object next = nextNode.execute(frame, inliningTarget, iterator); |
880 | | - x = asDoubleNode.execute(frame, inliningTarget, next); |
881 | | - xsave = x; |
882 | | - for (i = j = 0; j < n; j++) { /* for y in partials */ |
883 | | - y = p[j]; |
884 | | - if (Math.abs(x) < Math.abs(y)) { |
885 | | - t = x; |
886 | | - x = y; |
887 | | - y = t; |
888 | | - } |
889 | | - hi = x + y; |
890 | | - yr = hi - x; |
891 | | - lo = y - yr; |
892 | | - if (lo != 0.0) { |
893 | | - p[i++] = lo; |
894 | | - } |
895 | | - x = hi; |
896 | | - } |
897 | 872 |
|
898 | | - n = i; |
899 | | - if (x != 0.0) { |
900 | | - if (!Double.isFinite(x)) { |
901 | | - /* |
902 | | - * a nonfinite x could arise either as a result of intermediate |
903 | | - * overflow, or as a result of a nan or inf in the summands |
904 | | - */ |
905 | | - if (Double.isFinite(xsave)) { |
906 | | - throw raiseNode.raise(inliningTarget, OverflowError, ErrorMessages.INTERMEDIATE_OVERFLOW_IN, "fsum"); |
907 | | - } |
908 | | - if (Double.isInfinite(xsave)) { |
909 | | - inf_sum += xsave; |
910 | | - } |
911 | | - special_sum += xsave; |
912 | | - /* reset partials */ |
913 | | - n = 0; |
914 | | - } else if (n >= arayLength) { |
915 | | - arayLength += arayLength; |
916 | | - p = Arrays.copyOf(p, arayLength); |
917 | | - } else { |
918 | | - p[n++] = x; |
919 | | - } |
| 873 | + TpSlot tpIternext = nextNodeGetSlots.execute(inliningTarget, nextNodeGetClassNode.execute(inliningTarget, iterator)).tp_iternext(); |
| 874 | + assert tpIternext != null; |
| 875 | + |
| 876 | + var acc = new XSum.SmallAccumulator(); |
| 877 | + int loopCount = 0; |
| 878 | + while (true) { |
| 879 | + try { |
| 880 | + Object next = nextNodeCallNext.execute(frame, inliningTarget, tpIternext, iterator); |
| 881 | + if (CompilerDirectives.hasNextTier() && loopCount < Integer.MAX_VALUE) { |
| 882 | + loopCount++; |
920 | 883 | } |
| 884 | + acc.add(asDoubleNode.execute(frame, inliningTarget, next)); |
921 | 885 | } catch (IteratorExhausted e) { |
922 | | - exhausted = true; |
| 886 | + break; |
| 887 | + } catch (PException e) { |
| 888 | + e.expectStopIteration(inliningTarget, nextNodeStopIterationProfile); |
| 889 | + break; |
| 890 | + } finally { |
| 891 | + LoopNode.reportLoopCount(inliningTarget, loopCount); |
923 | 892 | } |
924 | 893 | } |
925 | 894 |
|
926 | | - if (special_sum != 0.0) { |
927 | | - if (Double.isNaN(inf_sum)) { |
| 895 | + if (acc.isNaNResult()) { |
| 896 | + return Double.NaN; |
| 897 | + } |
| 898 | + |
| 899 | + if (acc.isInfiniteResult()) { |
| 900 | + double result = acc.getInfiniteResult(); |
| 901 | + if (Double.isNaN(result)) { |
928 | 902 | throw raiseNode.raise(inliningTarget, ValueError, ErrorMessages.NEG_INF_PLUS_INF_IN); |
929 | 903 | } else { |
930 | | - sum = special_sum; |
931 | | - return sum; |
| 904 | + assert Double.isInfinite(result); |
| 905 | + return result; |
932 | 906 | } |
933 | 907 | } |
934 | 908 |
|
935 | | - hi = 0.0; |
936 | | - if (n > 0) { |
937 | | - hi = p[--n]; |
938 | | - /* |
939 | | - * sum_exact(ps, hi) from the top, stop when the sum becomes inexact. |
940 | | - */ |
941 | | - while (n > 0) { |
942 | | - x = hi; |
943 | | - y = p[--n]; |
944 | | - assert (Math.abs(y) < Math.abs(x)); |
945 | | - hi = x + y; |
946 | | - yr = hi - x; |
947 | | - lo = y - yr; |
948 | | - if (lo != 0.0) { |
949 | | - break; |
950 | | - } |
951 | | - } |
952 | | - /* |
953 | | - * Make half-even rounding work across multiple partials. Needed so that sum([1e-16, |
954 | | - * 1, 1e16]) will round-up the last digit to two instead of down to zero (the 1e-16 |
955 | | - * makes the 1 slightly closer to two). With a potential 1 ULP rounding error |
956 | | - * fixed-up, math.fsum() can guarantee commutativity. |
957 | | - */ |
958 | | - if (n > 0 && ((lo < 0.0 && p[n - 1] < 0.0) || |
959 | | - (lo > 0.0 && p[n - 1] > 0.0))) { |
960 | | - y = lo * 2.0; |
961 | | - x = hi + y; |
962 | | - yr = x - hi; |
963 | | - if (compareAsBigDecimal(y, yr) == 0) { |
964 | | - hi = x; |
965 | | - } |
966 | | - } |
| 909 | + double result = acc.round(); |
| 910 | + // +Inf or -Inf if exponent has overflowed |
| 911 | + if (Double.isInfinite(result)) { |
| 912 | + throw raiseNode.raise(inliningTarget, OverflowError, ErrorMessages.INTERMEDIATE_OVERFLOW_IN, "fsum"); |
| 913 | + } else { |
| 914 | + return result; |
967 | 915 | } |
968 | | - return hi; |
969 | | - } |
970 | | -
971 | | - @TruffleBoundary |
972 | | - private static int compareAsBigDecimal(double y, double yr) { |
973 | | - return BigDecimal.valueOf(y).compareTo(BigDecimal.valueOf(yr)); |
974 | 916 | } |
975 | 917 | } |
976 | 918 |
|
|
0 commit comments