Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 170 additions & 21 deletions src/FastExpressionCompiler/FastExpressionCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,15 @@ public static Result TryCollectInfo(ref ClosureInfo closure, Expression expr,
}
case ExpressionType.Conditional:
var condExpr = (ConditionalExpression)expr;
// Try structural branch elimination - skip collecting dead branch info
{
var reducedCond = Tools.TryReduceConditional(condExpr);
if (!ReferenceEquals(reducedCond, condExpr))
{
expr = reducedCond;
continue;
}
}
if ((r = TryCollectInfo(ref closure, condExpr.Test, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK ||
(r = TryCollectInfo(ref closure, condExpr.IfFalse, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK)
return r;
Expand Down Expand Up @@ -1604,6 +1613,16 @@ public static Result TryCollectInfo(ref ClosureInfo closure, Expression expr,

case ExpressionType.Switch:
var switchExpr = ((SwitchExpression)expr);
// Compile-time switch branch elimination (#489): if switch value is interpretable, collect only the matching branch
if (Interpreter.TryFindSwitchBranch(switchExpr, flags, out var switchMatchedBody))
{
if (switchMatchedBody != null)
{
expr = switchMatchedBody;
continue;
}
return r; // no matched body and no default → nothing to collect
}
if ((r = TryCollectInfo(ref closure, switchExpr.SwitchValue, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK ||
switchExpr.DefaultBody != null && // todo: @check is the order of collection affects the result?
(r = TryCollectInfo(ref closure, switchExpr.DefaultBody, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK)
Expand Down Expand Up @@ -2247,6 +2266,15 @@ public static bool TryEmit(Expression expr,
expr = testIsTrue ? condExpr.IfTrue : condExpr.IfFalse;
continue; // no recursion, just continue with the left or right side of condition
}
// Try structural branch elimination (e.g., null == Default(X) → always true/false)
{
var reducedCond = Tools.TryReduceConditional(condExpr);
if (!ReferenceEquals(reducedCond, condExpr))
{
expr = reducedCond;
continue;
}
}
return TryEmitConditional(testExpr, condExpr.IfTrue, condExpr.IfFalse, paramExprs, il, ref closure, setup, parent);

case ExpressionType.PostIncrementAssign:
Expand Down Expand Up @@ -5379,25 +5407,8 @@ private struct TestValueAndMultiTestCaseIndex
public int MultiTestValCaseBodyIdxPlusOne; // 0 means not multi-test case, otherwise index+1
}

private static long ConvertValueObjectToLong(object valObj)
{
Debug.Assert(valObj != null);
var type = valObj.GetType();
type = type.IsEnum ? Enum.GetUnderlyingType(type) : type;
return Type.GetTypeCode(type) switch
{
TypeCode.Char => (long)(char)valObj,
TypeCode.SByte => (long)(sbyte)valObj,
TypeCode.Byte => (long)(byte)valObj,
TypeCode.Int16 => (long)(short)valObj,
TypeCode.UInt16 => (long)(ushort)valObj,
TypeCode.Int32 => (long)(int)valObj,
TypeCode.UInt32 => (long)(uint)valObj,
TypeCode.Int64 => (long)valObj,
TypeCode.UInt64 => (long)(ulong)valObj,
_ => 0 // unreachable
};
}
private static long ConvertValueObjectToLong(object valObj) =>
Interpreter.ConvertValueObjectToLong(valObj);

#if LIGHT_EXPRESSION
private static bool TryEmitSwitch(SwitchExpression expr, IParameterProvider paramExprs, ILGenerator il, ref ClosureInfo closure,
Expand All @@ -5413,6 +5424,10 @@ private static bool TryEmitSwitch(SwitchExpression expr, IReadOnlyList<PE> param
var caseCount = cases.Count;
var defaultBody = expr.DefaultBody;

// Compile-time switch branch elimination (#489): if the switch value is interpretable, select the matching branch
if (Interpreter.TryFindSwitchBranch(expr, setup, out var matchedBody))
return matchedBody == null || TryEmit(matchedBody, paramExprs, il, ref closure, setup, parent);

// Optimization for the single case
if (caseCount == 1 & defaultBody != null)
{
Expand Down Expand Up @@ -7213,6 +7228,28 @@ internal static bool TryUnboxToPrimitiveValue(ref PValue value, object boxedValu
_ => UnreachableCase(code, (object)null)
};

/// <summary>Converts an integer/enum/char boxed value to <c>long</c> for uniform comparison.</summary>
[MethodImpl((MethodImplOptions)256)]
internal static long ConvertValueObjectToLong(object valObj)
{
Debug.Assert(valObj != null);
var type = valObj.GetType();
type = type.IsEnum ? Enum.GetUnderlyingType(type) : type;
return Type.GetTypeCode(type) switch
{
TypeCode.Char => (long)(char)valObj,
TypeCode.SByte => (long)(sbyte)valObj,
TypeCode.Byte => (long)(byte)valObj,
TypeCode.Int16 => (long)(short)valObj,
TypeCode.UInt16 => (long)(ushort)valObj,
TypeCode.Int32 => (long)(int)valObj,
TypeCode.UInt32 => (long)(uint)valObj,
TypeCode.Int64 => (long)valObj,
TypeCode.UInt64 => (long)(ulong)valObj,
_ => 0 // unreachable
};
}

internal static bool ComparePrimitiveValues(ref PValue left, ref PValue right, TypeCode code, ExpressionType nodeType)
{
switch (nodeType)
Expand Down Expand Up @@ -7545,7 +7582,7 @@ public static bool TryInterpretBool(out bool result, Expression expr, CompilerFl
{
var exprType = expr.Type;
Debug.Assert(exprType.IsPrimitive, // todo: @feat nullables are not supported yet // || Nullable.GetUnderlyingType(exprType)?.IsPrimitive == true,
"Can only reduce the boolean for the expressions of primitive types but found " + expr.Type);
"Can only reduce the boolean for the expressions of primitive type but found " + expr.Type);
result = false;
if ((flags & CompilerFlags.DisableInterpreter) != 0)
return false;
Expand All @@ -7564,6 +7601,95 @@ public static bool TryInterpretBool(out bool result, Expression expr, CompilerFl
}
}

/// <summary>
/// Tries to determine at compile time which branch a switch expression will take.
/// Works for integer/enum and string switch values with no custom equality method.
/// Returns true when the switch value is deterministic; <paramref name="matchedBody"/> is set to
/// the branch body to emit (null means use default body which may itself be null).
/// </summary>
public static bool TryFindSwitchBranch(SwitchExpression switchExpr, CompilerFlags flags, out Expression matchedBody)
{
matchedBody = null;
if (switchExpr.Comparison != null) return false; // custom equality: can't interpret statically
if ((flags & CompilerFlags.DisableInterpreter) != 0) return false;
var switchValueExpr = switchExpr.SwitchValue;
var switchValueType = switchValueExpr.Type;
var cases = switchExpr.Cases;
try
{
// String switch: only constant switch values supported
if (switchValueType == typeof(string))
{
if (switchValueExpr is not ConstantExpression ce) return false;
var switchStr = ce.Value;
for (var i = 0; i < cases.Count; i++)
{
var testValues = cases[i].TestValues;
for (var j = 0; j < testValues.Count; j++)
{
if (testValues[j] is not ConstantExpression testConst) return false;
if (Equals(switchStr, testConst.Value)) { matchedBody = cases[i].Body; return true; }
}
}
matchedBody = switchExpr.DefaultBody;
return true;
}

// Integer / enum / char switch
var effectiveType = switchValueType.IsEnum ? Enum.GetUnderlyingType(switchValueType) : switchValueType;
var typeCode = Type.GetTypeCode(effectiveType);
if (typeCode < TypeCode.Char || typeCode > TypeCode.UInt64) return false; // non-integral (e.g. float, decimal)

long switchValLong;
if (switchValueExpr is ConstantExpression switchConst && switchConst.Value != null)
switchValLong = ConvertValueObjectToLong(switchConst.Value);
else if (typeCode == TypeCode.Int32)
{
var intVal = 0;
if (!TryInterpretInt(ref intVal, switchValueExpr, switchValueExpr.NodeType)) return false;
switchValLong = intVal;
}
else
{
PValue pv = default;
if (!TryInterpretPrimitiveValue(ref pv, switchValueExpr, typeCode, switchValueExpr.NodeType)) return false;
switchValLong = PValueToLong(ref pv, typeCode);
}

for (var i = 0; i < cases.Count; i++)
{
var testValues = cases[i].TestValues;
for (var j = 0; j < testValues.Count; j++)
{
if (testValues[j] is not ConstantExpression testConst || testConst.Value == null) continue;
if (switchValLong == ConvertValueObjectToLong(testConst.Value)) { matchedBody = cases[i].Body; return true; }
}
}
matchedBody = switchExpr.DefaultBody;
return true;
}
catch
{
return false;
}
}

/// <summary>Converts a <see cref="PValue"/> union to a <c>long</c> for integer/char comparison.</summary>
[MethodImpl((MethodImplOptions)256)]
internal static long PValueToLong(ref PValue value, TypeCode code) => code switch
{
TypeCode.Char => (long)value.CharValue,
TypeCode.SByte => (long)value.SByteValue,
TypeCode.Byte => (long)value.ByteValue,
TypeCode.Int16 => (long)value.Int16Value,
TypeCode.UInt16 => (long)value.UInt16Value,
TypeCode.Int32 => (long)value.Int32Value,
TypeCode.UInt32 => (long)value.UInt32Value,
TypeCode.Int64 => value.Int64Value,
TypeCode.UInt64 => (long)value.UInt64Value,
_ => 0L,
};

// todo: @perf try split to `TryInterpretBinary` overload to streamline the calls for TryEmitConditional and similar
/// <summary>Tries to interpret the expression of the Primitive type of Constant, Convert, Logical, Comparison, Arithmetic.</summary>
internal static bool TryInterpretBool(ref bool resultBool, Expression expr, ExpressionType nodeType)
Expand Down Expand Up @@ -8591,7 +8717,9 @@ public static Expression TryReduceConditional(ConditionalExpression condExpr)
var testExpr = TryReduceConditionalTest(condExpr.Test);
if (testExpr is BinaryExpression bi && (bi.NodeType == ExpressionType.Equal || bi.NodeType == ExpressionType.NotEqual))
{
if (bi.Left is ConstantExpression lc && bi.Right is ConstantExpression rc)
var left = bi.Left;
var right = bi.Right;
if (left is ConstantExpression lc && right is ConstantExpression rc)
{
#if INTERPRETATION_DIAGNOSTICS
Console.WriteLine("//Reduced Conditional in Interpretation: " + condExpr);
Expand All @@ -8601,12 +8729,33 @@ public static Expression TryReduceConditional(ConditionalExpression condExpr)
? (equals ? condExpr.IfTrue : condExpr.IfFalse)
: (equals ? condExpr.IfFalse: condExpr.IfTrue);
}

// Handle compile-time branch elimination for null/default equality:
// e.g. Constant(null) == Default(typeof(X)) or Default(typeof(X)) == Constant(null)
// where X is a reference, interface, or nullable type - both represent null, so they are always equal
var leftIsNull = left is ConstantExpression lnc && lnc.Value == null ||
left is DefaultExpression lde && IsNullDefault(lde.Type);
var rightIsNull = right is ConstantExpression rnc && rnc.Value == null ||
right is DefaultExpression rde && IsNullDefault(rde.Type);
if (leftIsNull && rightIsNull)
{
#if INTERPRETATION_DIAGNOSTICS
Console.WriteLine("//Reduced Conditional (null/default equality) in Interpretation: " + condExpr);
#endif
// both sides represent null, so they are equal
return bi.NodeType == ExpressionType.Equal ? condExpr.IfTrue : condExpr.IfFalse;
}
}

return testExpr is ConstantExpression constExpr && constExpr.Value is bool testBool
? (testBool ? condExpr.IfTrue : condExpr.IfFalse)
: condExpr;
}

// Returns true if the type's default value is null (reference types, interfaces, and Nullable<T>)
[MethodImpl((MethodImplOptions)256)]
internal static bool IsNullDefault(Type type) =>
type.IsClass || type.IsInterface || Nullable.GetUnderlyingType(type) != null;
}

[RequiresUnreferencedCode(Trimming.Message)]
Expand Down
Loading
Loading