From a6381fc831792e337991e12b3a793c5aab29052d Mon Sep 17 00:00:00 2001 From: min-guk Date: Tue, 25 Feb 2025 23:46:44 +0900 Subject: [PATCH 1/9] prog level fedplanner --- .../hops/fedplanner/FederatedMemoTable.java | 332 +++----- .../fedplanner/FederatedMemoTablePrinter.java | 302 ++++--- .../FederatedPlanCostEnumerator.java | 770 +++++++++++++----- .../FederatedPlanCostEstimator.java | 467 ++++++----- .../FederatedPlanCostEnumeratorTest.java | 158 ++-- .../FederatedPlanCostEnumeratorTest4.dml | 28 + .../FederatedPlanCostEnumeratorTest5.dml | 26 + .../FederatedPlanCostEnumeratorTest6.dml | 34 + .../FederatedPlanCostEnumeratorTest7.dml | 28 + .../FederatedPlanCostEnumeratorTest8.dml | 49 ++ .../FederatedPlanCostEnumeratorTest9.dml | 58 ++ 11 files changed, 1422 insertions(+), 830 deletions(-) create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index b2b58871f62..dae809179b6 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -17,200 +17,138 @@ * under the License. 
*/ -package org.apache.sysds.hops.fedplanner; - -import org.apache.sysds.hops.Hop; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.ArrayList; -import java.util.Map; - -/** - * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. - * This table stores and manages different execution plan variants for each Hop and fedOutType combination, - * facilitating the optimization of federated execution plans. - */ -public class FederatedMemoTable { - // Maps Hop ID and fedOutType pairs to their plan variants - private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); - - /** - * Adds a new federated plan to the memo table. - * Creates a new variant list if none exists for the given Hop and fedOutType. - * - * @param hop The Hop node - * @param fedOutType The federated output type - * @param planChilds List of child plan references - * @return The newly created FedPlan - */ - public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { - long hopID = hop.getHopID(); - FedPlanVariants fedPlanVariantList; - - if (contains(hopID, fedOutType)) { - fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - } else { - fedPlanVariantList = new FedPlanVariants(hop, fedOutType); - hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); - } - - FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); - fedPlanVariantList.addFedPlan(newPlan); - - return newPlan; - } - - /** - * Retrieves the minimum cost child plan considering the parent's output type. - * The cost is calculated using getParentViewCost to account for potential type mismatches. - * - * @param fedPlanPair ??? 
- * @return min cost fed plan - */ - public FedPlan getMinCostFedPlan(Pair fedPlanPair) { - FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); - return fedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - } - - public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { - return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - } - - public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { - return hopMemoTable.get(fedPlanPair); - } - - public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { - // Todo: Consider whether to verify if pruning has been performed - FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - return fedPlanVariantList._fedPlanVariants.get(0); - } - - public FedPlan getFedPlanAfterPrune(Pair fedPlanPair) { - // Todo: Consider whether to verify if pruning has been performed - FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); - return fedPlanVariantList._fedPlanVariants.get(0); - } - - /** - * Checks if the memo table contains an entry for a given Hop and fedOutType. - * - * @param hopID The Hop ID. - * @param fedOutType The associated fedOutType. - * @return True if the entry exists, false otherwise. - */ - public boolean contains(long hopID, FederatedOutput fedOutType) { - return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); - } - - /** - * Prunes the specified entry in the memo table, retaining only the minimum-cost - * FedPlan for the given Hop ID and federated output type. 
- * - * @param hopID The ID of the Hop to prune - * @param federatedOutput The federated output type associated with the Hop - */ - public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) { - hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune(); - } - - /** - * Represents common properties and costs associated with a Hop. - * This class holds a reference to the Hop and tracks its execution and network transfer costs. - */ - public static class HopCommon { - protected final Hop hopRef; // Reference to the associated Hop - protected double selfCost; // Current execution cost (compute + memory access) - protected double netTransferCost; // Network transfer cost - - protected HopCommon(Hop hopRef) { - this.hopRef = hopRef; - this.selfCost = 0; - this.netTransferCost = 0; - } - } - - /** - * Represents a collection of federated execution plan variants for a specific Hop and FederatedOutput. - * This class contains cost information and references to the associated plans. - * It uses HopCommon to store common properties and costs related to the Hop. 
- */ - public static class FedPlanVariants { - protected HopCommon hopCommon; // Common properties and costs for the Hop - private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) - protected List _fedPlanVariants; // List of plan variants - - public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { - this.hopCommon = new HopCommon(hopRef); - this.fedOutType = fedOutType; - this._fedPlanVariants = new ArrayList<>(); - } - - public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} - public List getFedPlanVariants() {return _fedPlanVariants;} - public boolean isEmpty() {return _fedPlanVariants.isEmpty();} - - public void prune() { - if (_fedPlanVariants.size() > 1) { - // Find the FedPlan with the minimum cost - FedPlan minCostPlan = _fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - - // Retain only the minimum cost plan - _fedPlanVariants.clear(); - _fedPlanVariants.add(minCostPlan); - } - } - } - - /** - * Represents a single federated execution plan with its associated costs and dependencies. - * This class contains: - * 1. selfCost: Cost of current hop (compute + input/output memory access) - * 2. totalCost: Cumulative cost including this plan and all child plans - * 3. netTransferCost: Network transfer cost for this plan to parent plan. - * - * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. 
- */ - public static class FedPlan { - private double totalCost; // Total cost including child plans - private final FedPlanVariants fedPlanVariants; // Reference to variant list - private final List> childFedPlans; // Child plan references - - public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { - this.totalCost = 0; - this.childFedPlans = childFedPlans; - this.fedPlanVariants = fedPlanVariants; - } - - public void setTotalCost(double totalCost) {this.totalCost = totalCost;} - public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;} - public void setNetTransferCost(double netTransferCost) {fedPlanVariants.hopCommon.netTransferCost = netTransferCost;} - - public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;} - public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();} - public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} - public double getTotalCost() {return totalCost;} - public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;} - public double getNetTransferCost() {return fedPlanVariants.hopCommon.netTransferCost;} - public List> getChildFedPlans() {return childFedPlans;} - - /** - * Calculates the conditional network transfer cost based on output type compatibility. - * Returns 0 if output types match, otherwise returns the network transfer cost. - * @param parentFedOutType The federated output type of the parent plan. - * @return The conditional network transfer cost. 
- */ - public double getCondNetTransferCost(FederatedOutput parentFedOutType) { - if (parentFedOutType == getFedOutType()) return 0; - return fedPlanVariants.hopCommon.netTransferCost; - } - } -} + package org.apache.sysds.hops.fedplanner; + + import java.util.Comparator; + import java.util.HashMap; + import java.util.List; + import java.util.ArrayList; + import java.util.Map; + import org.apache.sysds.hops.Hop; + import org.apache.commons.lang3.tuple.Pair; + import org.apache.commons.lang3.tuple.ImmutablePair; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + + /** + * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. + * This table stores and manages different execution plan variants for each Hop and fedOutType combination, + * facilitating the optimization of federated execution plans. + */ + public class FederatedMemoTable { + // Maps Hop ID and fedOutType pairs to their plan variants + private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); + + public void addFedPlanVariants(long hopID, FederatedOutput fedOutType, FedPlanVariants fedPlanVariants) { + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariants); + } + + public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { + return hopMemoTable.get(fedPlanPair); + } + + public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + return fedPlanVariantList._fedPlanVariants.get(0); + } + + public FedPlan getFedPlanAfterPrune(Pair fedPlanPair) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); + return fedPlanVariantList._fedPlanVariants.get(0); + } + + public boolean contains(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); + } + + /** + * Represents a single federated execution plan with 
its associated costs and dependencies. + * This class contains: + * 1. selfCost: Cost of the current hop (computation + input/output memory access). + * 2. cumulativeCost: Total cost including this plan's selfCost and all child plans' cumulativeCost. + * 3. forwardingCost: Network transfer cost for this plan to the parent plan. + * + * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. + */ + public static class FedPlan { + private double cumulativeCost; // Total cost = sum of selfCost + cumulativeCost of child plans + private final FedPlanVariants fedPlanVariants; // Reference to variant list + private final List> childFedPlans; // Child plan references + + public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List> childFedPlans) { + this.cumulativeCost = cumulativeCost; + this.fedPlanVariants = fedPlanVariants; + this.childFedPlans = childFedPlans; + } + + public Hop getHopRef() {return fedPlanVariants.hopCommon.getHopRef();} + public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();} + public FederatedOutput getFedOutType() {return fedPlanVariants.getFedOutType();} + public double getCumulativeCost() {return cumulativeCost;} + public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();} + public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();} + public double getWeight() {return fedPlanVariants.hopCommon.getWeight();} + public List> getChildFedPlans() {return childFedPlans;} + } + + /** + * Represents a collection of federated execution plan variants for a specific Hop and FederatedOutput. + * This class contains cost information and references to the associated plans. + * It uses HopCommon to store common properties and costs related to the Hop. 
+ */ + public static class FedPlanVariants { + protected HopCommon hopCommon; // Common properties and costs for the Hop + private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) + protected List _fedPlanVariants; // List of plan variants + + public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) { + this.hopCommon = hopCommon; + this.fedOutType = fedOutType; + this._fedPlanVariants = new ArrayList<>(); + } + + public boolean isEmpty() {return _fedPlanVariants.isEmpty();} + public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} + public List getFedPlanVariants() {return _fedPlanVariants;} + public FederatedOutput getFedOutType() {return fedOutType;} + + public void pruneFedPlans() { + if (_fedPlanVariants.size() > 1) { + // Find the FedPlan with the minimum cumulative cost + FedPlan minCostPlan = _fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) + .orElse(null); + + // Retain only the minimum cost plan + _fedPlanVariants.clear(); + _fedPlanVariants.add(minCostPlan); + } + } + } + + /** + * Represents common properties and costs associated with a Hop. + * This class holds a reference to the Hop and tracks its execution and network forwarding (transfer) costs. 
+ */ + public static class HopCommon { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Cost of the hop's computation and memory access + protected double forwardingCost; // Cost of forwarding the hop's output to its parent + protected double weight; // Weight used to calculate cost based on hop execution frequency + + public HopCommon(Hop hopRef, double weight) { + this.hopRef = hopRef; + this.selfCost = 0; + this.forwardingCost = 0; + this.weight = weight; + } + + public Hop getHopRef() {return hopRef;} + public double getSelfCost() {return selfCost;} + public double getForwardingCost() {return forwardingCost;} + public double getWeight() {return weight;} + + protected void setSelfCost(double selfCost) {this.selfCost = selfCost;} + protected void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;} + } + } + \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index f7b3343a986..ddddc641d2e 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -1,139 +1,189 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - package org.apache.sysds.hops.fedplanner; import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.runtime.instructions.fed.FEDInstruction; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.HashSet; import java.util.List; import java.util.Set; public class FederatedMemoTablePrinter { - /** - * Recursively prints a tree representation of the DAG starting from the given root FedPlan. - * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node. - * Additionally, prints the additional total cost once at the beginning. - * - * @param rootFedPlan The starting point FedPlan to print - * @param memoTable The memoization table containing FedPlan variants - * @param additionalTotalCost The additional cost to be printed once - */ - public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, FederatedMemoTable memoTable, - double additionalTotalCost) { - System.out.println("Additional Cost: " + additionalTotalCost); - Set visited = new HashSet<>(); - printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0); - } - - /** - * Helper method to recursively print the FedPlan tree. 
- * - * @param plan The current FedPlan to print - * @param visited Set to keep track of visited FedPlans (prevents cycles) - * @param depth The current depth level for indentation - */ - private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, - Set visited, int depth) { - if (plan == null || visited.contains(plan)) { - return; - } - - visited.add(plan); - - Hop hop = plan.getHopRef(); - StringBuilder sb = new StringBuilder(); - - // Add FedPlan information - sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) - .append(plan.getHopRef().getOpString()) - .append(" [") - .append(plan.getFedOutType()) - .append("]"); - - StringBuilder childs = new StringBuilder(); - childs.append(" ("); - boolean childAdded = false; - for( Hop input : hop.getInput()){ - childs.append(childAdded?",":""); - childs.append(input.getHopID()); - childAdded = true; - } - childs.append(")"); - if( childAdded ) - sb.append(childs.toString()); - - - sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", - plan.getTotalCost(), - plan.getSelfCost(), - plan.getNetTransferCost())); - - // Add matrix characteristics - sb.append(" [") - .append(hop.getDim1()).append(", ") - .append(hop.getDim2()).append(", ") - .append(hop.getBlocksize()).append(", ") - .append(hop.getNnz()); - - if (hop.getUpdateType().isInPlace()) { - sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); - } - sb.append("]"); - - // Add memory estimates - sb.append(" [") - .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") - .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") - .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") - .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); - - // Add reblock and checkpoint requirements - if (hop.requiresReblock() && hop.requiresCheckpoint()) { - sb.append(" [rblk, chkpt]"); - } else if (hop.requiresReblock()) { - sb.append(" 
[rblk]"); - } else if (hop.requiresCheckpoint()) { - sb.append(" [chkpt]"); - } - - // Add execution type - if (hop.getExecType() != null) { - sb.append(", ").append(hop.getExecType()); - } - - System.out.println(sb); - - // Process child nodes - List> childFedPlanPairs = plan.getChildFedPlans(); - for (int i = 0; i < childFedPlanPairs.size(); i++) { - Pair childFedPlanPair = childFedPlanPairs.get(i); - FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); - if (childVariants == null || childVariants.isEmpty()) - continue; - - for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { - printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); - } - } - } + /** + * Recursively prints a tree representation of the DAG starting from the given root FedPlan. + * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node. + * Additionally, prints the additional total cost once at the beginning. + * + * @param rootFedPlan The starting point FedPlan to print + * @param memoTable The memoization table containing FedPlan variants + * @param additionalTotalCost The additional cost to be printed once + */ + public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Set rootHopStatSet, + FederatedMemoTable memoTable, double additionalTotalCost) { + System.out.println("Additional Cost: " + additionalTotalCost); + Set visited = new HashSet<>(); + printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0); + + for (Hop hop : rootHopStatSet) { + FedPlan plan = memoTable.getFedPlanAfterPrune(hop.getHopID(), FederatedOutput.LOUT); + printNotReferencedFedPlanRecursive(plan, memoTable, visited, 1); + } + } + + /** + * Helper method to recursively print the FedPlan tree. 
+ * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + */ + private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, + Set visited, int depth) { + long hopID = plan.getHopRef().getHopID(); + + if (visited.contains(hopID)) { + return; + } + + visited.add(hopID); + printFedPlan(plan, depth, true); + + // Process child nodes + List> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); i++) { + Pair childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printNotReferencedFedPlanRecursive(childPlan, memoTable, visited, depth + 1); + } + } + } + + /** + * Helper method to recursively print the FedPlan tree. 
+ * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + */ + private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, + Set visited, int depth) { + long hopID = 0; + + if (depth == 0) { + hopID = -1; + } else { + hopID = plan.getHopRef().getHopID(); + } + + if (visited.contains(hopID)) { + return; + } + + visited.add(hopID); + printFedPlan(plan, depth, false); + + // Process child nodes + List> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); i++) { + Pair childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); + } + } + } + + private static void printFedPlan(FederatedMemoTable.FedPlan plan, int depth, boolean isNotReferenced) { + StringBuilder sb = new StringBuilder(); + Hop hop = null; + + if (depth == 0){ + sb.append("(R) ROOT [Root]"); + } else { + hop = plan.getHopRef(); + // Add FedPlan information + sb.append(String.format("(%d) ", hop.getHopID())) + .append(hop.getOpString()) + .append(" ["); + + if (isNotReferenced) { + sb.append("NRef"); + } else{ + sb.append(plan.getFedOutType()); + } + sb.append("]"); + } + + StringBuilder childs = new StringBuilder(); + childs.append(" ("); + + boolean childAdded = false; + for (Pair childPair : plan.getChildFedPlans()){ + childs.append(childAdded?",":""); + childs.append(childPair.getLeft()); + childAdded = true; + } + + childs.append(")"); + + if( childAdded ) + sb.append(childs.toString()); + + if (depth == 0){ + sb.append(String.format(" {Total: %.1f}", plan.getCumulativeCost())); + 
System.out.println(sb); + return; + } + + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f, Weight: %.1f}", + plan.getCumulativeCost(), + plan.getSelfCost(), + plan.getForwardingCost(), + plan.getWeight())); + + // Add matrix characteristics + sb.append(" [") + .append(hop.getDim1()).append(", ") + .append(hop.getDim2()).append(", ") + .append(hop.getBlocksize()).append(", ") + .append(hop.getNnz()); + + if (hop.getUpdateType().isInPlace()) { + sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); + } + sb.append("]"); + + // Add memory estimates + sb.append(" [") + .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") + .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); + + // Add reblock and checkpoint requirements + if (hop.requiresReblock() && hop.requiresCheckpoint()) { + sb.append(" [rblk, chkpt]"); + } else if (hop.requiresReblock()) { + sb.append(" [rblk]"); + } else if (hop.requiresCheckpoint()) { + sb.append(" [chkpt]"); + } + + // Add execution type + if (hop.getExecType() != null) { + sb.append(", ").append(hop.getExecType()); + } + + System.out.println(sb); + } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index be1cfa7cdf3..56586a30622 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -17,218 +17,558 @@ * under the License. 
*/ -package org.apache.sysds.hops.fedplanner; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Objects; -import java.util.LinkedHashMap; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - -/** - * Enumerates and evaluates all possible federated execution plans for a given Hop DAG. - * Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator - * to compute their costs. - */ -public class FederatedPlanCostEnumerator { - /** - * Entry point for federated plan enumeration. This method creates a memo table - * and returns the minimum cost plan for the entire Directed Acyclic Graph (DAG). - * It also resolves conflicts where FedPlans have different FederatedOutput types. - * - * @param rootHop The root Hop node from which to start the plan enumeration. - * @param printTree A boolean flag indicating whether to print the federated plan tree. - * @return The optimal FedPlan with the minimum cost for the entire DAG. 
- */ - public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) { - // Create new memo table to store all plan variants - FederatedMemoTable memoTable = new FederatedMemoTable(); - - // Recursively enumerate all possible plans - enumerateFederatedPlanCost(rootHop, memoTable); - - // Return the minimum cost plan for the root node - FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); - - // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); - - // Optionally print the federated plan tree if requested - if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); - - return optimalPlan; - } - - /** - * Recursively enumerates all possible federated execution plans for a Hop DAG. - * For each node: - * 1. First processes all input nodes recursively if not already processed - * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs - * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination - * - * The enumeration uses a bottom-up approach where: - * - Each input combination is represented by a binary number (i) - * - Bit j in i determines whether input j is FOUT (1) or LOUT (0) - * - Total number of combinations is 2^numInputs - * - * @param hop ? - * @param memoTable ? 
- */ - private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { - int numInputs = hop.getInput().size(); - - // Process all input nodes first if not already in memo table - for (Hop inputHop : hop.getInput()) { - if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) - && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { - enumerateFederatedPlanCost(inputHop, memoTable); - } - } - - // Generate all possible input combinations using binary representation - // i represents a specific combination of FOUT/LOUT for inputs - for (int i = 0; i < (1 << numInputs); i++) { - List> planChilds = new ArrayList<>(); - - // For each input, determine if it should be FOUT or LOUT based on bit j in i - for (int j = 0; j < numInputs; j++) { - Hop inputHop = hop.getInput().get(j); - // If bit j is set (1), use FOUT; otherwise use LOUT - FederatedOutput childType = ((i & (1 << j)) != 0) ? - FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - } - - // Create and evaluate FOUT variant for current input combination - FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); - - // Create and evaluate LOUT variant for current input combination - FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); - } - - // Prune MemoTable for hop. - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT); - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT); - } - - /** - * Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants. - * Used to select the final execution plan after enumeration. - * - * @param HopID ? - * @param memoTable ? - * @return ? 
- */ - private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { - FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); - FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); - - FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - - if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() - < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { - return minFOutFedPlan; - } - return minlOutFedPlan; - } - - /** - * Detects and resolves conflicts in federated plans starting from the root plan. - * This function performs a breadth-first search (BFS) to traverse the federated plan tree. - * It identifies conflicts where the same plan ID has different federated output types. - * For each conflict, it records the plan ID and its conflicting parent plans. - * The function ensures that each plan ID is associated with a consistent federated output type - * by resolving these conflicts iteratively. - * - * The process involves: - * - Using a map to track conflicts, associating each plan ID with its federated output type - * and a list of parent plans. - * - Storing detected conflicts in a linked map, each entry containing a plan ID and its - * conflicting parent plans. - * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. - * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan - * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. - * - Resolving conflicts by ensuring a consistent federated output type across the plan. 
- * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. - * - * @param rootPlan The root federated plan from which to start the conflict detection. - * @param memoTable The memoization table used to retrieve pruned federated plans. - * @return The cumulative additional cost for resolving conflicts. - */ - private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { - // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans - Map>> conflictCheckMap = new HashMap<>(); - - // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans - LinkedHashMap> conflictLinkedMap = new LinkedHashMap<>(); - - // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) - LinkedHashMap bfsLinkedMap = new LinkedHashMap<>(); - bfsLinkedMap.put(rootPlan, true); - - // Array to store cumulative additional cost for resolving conflicts - double[] cumulativeAdditionalCost = new double[]{0.0}; - - while (!bfsLinkedMap.isEmpty()) { - // Perform BFS to detect conflicts in federated plans - while (!bfsLinkedMap.isEmpty()) { - FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); - bfsLinkedMap.remove(currentPlan); - - // Iterate over each child plan of the current plan - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); - - // Check if the child plan ID is already visited - if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { - // Retrieve the existing conflict pair for the child plan - Pair> conflictChildPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); - // Add the current plan to the list of parent plans - conflictChildPlanPair.getRight().add(currentPlan); - - // If the federated output type differs, a conflict is detected - if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { - // If this is the 
first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) - // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children occur - if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { - conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); - bfsLinkedMap.remove(childFedPlan); - } - } - } else { - // If no conflict exists, create a new entry in the conflict check map - List parentFedPlanList = new ArrayList<>(); - parentFedPlanList.add(currentPlan); - - // Map the child plan ID to its output type and list of parent plans - conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); - // Add the child plan to the BFS queue - bfsLinkedMap.put(childFedPlan, true); - } - } - } - // Resolve these conflicts to ensure a consistent federated output type across the plan - // Re-run BFS with resolved conflicts - bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); - conflictLinkedMap.clear(); - } - - // Return the cumulative additional cost for resolving conflicts - return cumulativeAdditionalCost[0]; - } -} + package org.apache.sysds.hops.fedplanner; + import java.util.ArrayList; + import java.util.List; + import java.util.Map; + import java.util.HashMap; + import java.util.LinkedHashMap; + import java.util.Optional; + import java.util.Set; + import java.util.HashSet; + + import org.apache.commons.lang3.tuple.Pair; + + import org.apache.commons.lang3.tuple.ImmutablePair; + import org.apache.sysds.common.Types; + import org.apache.sysds.hops.DataOp; + import org.apache.sysds.hops.Hop; + import org.apache.sysds.hops.LiteralOp; + import org.apache.sysds.hops.UnaryOp; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; + import 
org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; + import org.apache.sysds.hops.rewrite.HopRewriteUtils; + import org.apache.sysds.parser.DMLProgram; + import org.apache.sysds.parser.ForStatement; + import org.apache.sysds.parser.ForStatementBlock; + import org.apache.sysds.parser.FunctionStatement; + import org.apache.sysds.parser.FunctionStatementBlock; + import org.apache.sysds.parser.IfStatement; + import org.apache.sysds.parser.IfStatementBlock; + import org.apache.sysds.parser.StatementBlock; + import org.apache.sysds.parser.WhileStatement; + import org.apache.sysds.parser.WhileStatementBlock; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + import org.apache.sysds.runtime.util.UtilFunctions; + + public class FederatedPlanCostEnumerator { + private static final double DEFAULT_LOOP_WEIGHT = 10.0; + private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5; + + /** + * Enumerates the entire DML program to generate federated execution plans. + * It processes each statement block, computes the optimal federated plan, + * detects and resolves conflicts, and optionally prints the plan tree. + * + * @param prog The DML program to enumerate. + * @param isPrint A boolean indicating whether to print the federated plan tree. 
+ */ + public static void enumerateProgram(DMLProgram prog, boolean isPrint) { + FederatedMemoTable memoTable = new FederatedMemoTable(); + + Map> outerTransTable = new HashMap<>(); + Map> formerInnerTransTable = new HashMap<>(); + Set progRootHopSet = new HashSet<>(); // Set of hops for the root dummy node + // TODO: Just for debug, remove later + Set statRootHopSet = new HashSet<>(); // Set of hops that have no parent but are not referenced + + for (StatementBlock sb : prog.getStatementBlocks()) { + Optional.ofNullable(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false)) + .ifPresent(outerTransTable::putAll); + } + + FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable); + + // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types + double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); + + // Print the federated plan tree if requested + if (isPrint) { + FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost); + } + } + + + /** + * Enumerates the statement block and updates the transient and memoization tables. + * This method processes different types of statement blocks such as If, For, While, and Function blocks. + * It recursively enumerates the Hop DAGs within these blocks and updates the corresponding tables. + * The method also calculates weights recursively for if-else/loops and handles inner and outer block distinctions. + * + * @param sb The statement block to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. 
+ * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. + * @return A map of inner transient writes. + */ + public static Map> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { + Map> innerTransTable = new HashMap<>(); + + if (sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + + enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + // Treat outerTransTable as immutable in inner blocks + // Write TWrite of sb sequentially in innerTransTable, and update formerInnerTransTable after the sb ends + // In case of if-else, create separate formerInnerTransTables for if and else, merge them after completion, and update formerInnerTransTable + Map> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + Map> elseFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + + for (StatementBlock csb : istmt.getIfBody()){ + ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + for (StatementBlock csb : istmt.getElseBody()){ + elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, elseFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + // If there are common keys: merge elseValue list into ifValue list + elseFormerInnerTransTable.forEach((key, elseValue) -> { + ifFormerInnerTransTable.merge(key, 
elseValue, (ifValue, newValue) -> { + ifValue.addAll(newValue); + return ifValue; + }); + }); + // Update innerTransTable + innerTransTable.putAll(ifFormerInnerTransTable); + } else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + + // Calculate for-loop iteration count if possible + double loopWeight = DEFAULT_LOOP_WEIGHT; + Hop from = fsb.getFromHops().getInput().get(0); + Hop to = fsb.getToHops().getInput().get(0); + Hop incr = (fsb.getIncrementHops() != null) ? + fsb.getIncrementHops().getInput().get(0) : new LiteralOp(1); + + // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values) + if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) { + double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from); + double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to); + double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr); + if( dfrom > dto && dincr == 1 ) + dincr = -1; + loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false); + } + weight *= loopWeight; + + enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getIncrementHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + 
weight *= DEFAULT_LOOP_WEIGHT; + + enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } else if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + + // TODO: NOT descent multiple types (use hash set for functions using function name) + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } else { //generic (last-level) + if( sb.getHops() != null ){ + for(Hop c : sb.getHops()) + // In the statement block, if isInner, write hopDAG in innerTransTable, if not, write directly in outerTransTable + enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + } + } + return innerTransTable; + } + + /** + * Enumerates the statement blocks within a body and updates the transient and memoization tables. + * + * @param sbList The list of statement blocks to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. 
+ */ + public static void enumerateStatementBlockBody(List sbList, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight) { + // The statement blocks within the body reference outerTransTable and formerInnerTransTable as immutable read-only, + // and record TWrite in the innerTransTable of the statement block within the body. + // Update the formerInnerTransTable with the contents of the returned innerTransTable. + for (StatementBlock sb : sbList) + formerInnerTransTable.putAll(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true)); + + // Then update and return the innerTransTable of the statement block containing the body. + innerTransTable.putAll(formerInnerTransTable); + } + + /** + * Enumerates the statement hop DAG within a statement block. + * This method recursively enumerates all possible federated execution plans + * and identifies hops to connect to the root dummy node. + * + * @param rootHop The root Hop of the DAG to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of root hops for debugging purposes. + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. 
+ */ + public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { + // Recursively enumerate all possible plans + rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInnerBlock); + + // Identify hops to connect to the root dummy node + + if ((rootHop instanceof DataOp && (rootHop.getName().equals("__pred"))) // TWrite "__pred" + || (rootHop instanceof UnaryOp && ((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print) + // Connect TWrite pred and u(print) to the root dummy node + // TODO: Should the last unreferenced TWrite be connected? + progRootHopSet.add(rootHop); + } else { + // TODO: Just for debug, remove later + // For identifying TWrites that are not referenced later + statRootHopSet.add(rootHop); + } + } + + /** + * Rewires and enumerates federated execution plans for a given Hop. + * This method processes all input nodes, rewires TWrite and TRead operations, + * and generates federated plan variants for both inner and outer code blocks. + * + * @param hop The Hop for which to rewire and enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param weight The weight associated with the current Hop. + * @param isInner A boolean indicating if the current block is an inner block. 
+ */ + private static void rewireAndEnumerateFedPlan(Hop hop, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, double weight, boolean isInner) { + // Process all input nodes first if not already in memo table + for (Hop inputHop : hop.getInput()) { + long inputHopID = inputHop.getHopID(); + if (!memoTable.contains(inputHopID, FederatedOutput.FOUT) + && !memoTable.contains(inputHopID, FederatedOutput.LOUT)) { + rewireAndEnumerateFedPlan(inputHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner); + } + } + + // Detect and Rewire TWrite and TRead operations + List childHops = hop.getInput(); + if (hop instanceof DataOp && !(hop.getName().equals("__pred"))){ + String hopName = hop.getName(); + + if (isInner){ // If it's an inner code block + if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ + innerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); + } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ + // Copy existing and add TWrite + childHops = new ArrayList<>(childHops); + List additionalChildHops = null; + + // Read according to priority + if (innerTransTable.containsKey(hopName)){ + additionalChildHops = innerTransTable.get(hopName); + } else if (formerInnerTransTable.containsKey(hopName)){ + additionalChildHops = formerInnerTransTable.get(hopName); + } else if (outerTransTable.containsKey(hopName)){ + additionalChildHops = outerTransTable.get(hopName); + } + + if (additionalChildHops != null) { + childHops.addAll(additionalChildHops); + } + } + } else { // If it's an outer code block + if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ + // Add directly to outerTransTable + outerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); + } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ + childHops = new ArrayList<>(childHops); + + // TODO: In the case of for (i in 1:10), there is no hop 
that writes TWrite for i. + // Read directly from outerTransTable and add + List additionalChildHops = outerTransTable.get(hopName); + if (additionalChildHops != null) { + childHops.addAll(additionalChildHops); + } + } + } + } + + // Enumerate the federated plan for the current Hop + enumerateFedPlan(hop, memoTable, childHops, weight); + } + + /** + * Enumerates federated execution plans for a given Hop. + * This method calculates the self cost and child costs for the Hop, + * generates federated plan variants for both LOUT and FOUT output types, + * and prunes redundant plans before adding them to the memo table. + * + * @param hop The Hop for which to enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param childHops The list of child hops. + * @param weight The weight associated with the current Hop. + */ + private static void enumerateFedPlan(Hop hop, FederatedMemoTable memoTable, List childHops, double weight){ + long hopID = hop.getHopID(); + HopCommon hopCommon = new HopCommon(hop, weight); + double selfCost = FederatedPlanCostEstimator.computeHopCost(hopCommon); + + FedPlanVariants lOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); + FedPlanVariants fOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); + + int numInputs = childHops.size(); + int numInitInputs = hop.getInput().size(); + + double[][] childCumulativeCost = new double[numInputs][2]; // # of child, LOUT/FOUT of child + double[] childForwardingCost = new double[numInputs]; // # of child + + // The self cost follows its own weight, while the forwarding cost follows the parent's weight. 
+ FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childHops, childCumulativeCost, childForwardingCost); + + if (numInitInputs == numInputs){ + enumerateOnlyInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } else { + enumerateTReadInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } + + // Prune the FedPlans to remove redundant plans + lOutFedPlanVariants.pruneFedPlans(); + fOutFedPlanVariants.pruneFedPlans(); + + // Add the FedPlanVariants to the memo table + memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, lOutFedPlanVariants); + memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, fOutFedPlanVariants); + } + + /** + * Enumerates federated execution plans for initial child hops only. + * This method generates all possible combinations of federated output types (LOUT and FOUT) + * for the initial child hops and calculates their cumulative costs. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. + */ + private static void enumerateOnlyInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs, List childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + // Iterate 2^n times, generating two FedPlans (LOUT, FOUT) each time. 
+ for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List> planChilds = new ArrayList<>(); + // LOUT and FOUT share the same planChilds in each iteration (only forwarding cost differs). + enumerateInitChildFedPlan(numInitInputs, childHops, planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds)); + } + } + + /** + * Enumerates federated execution plans for a TRead hop. + * This method calculates the cumulative costs for both LOUT and FOUT federated output types + * by considering the additional child hops, which are TWrite hops. + * It generates all possible combinations of federated output types for the initial child hops + * and adds the pre-calculated costs of the TWrite child hops to these combinations. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param numInputs The total number of input hops, including additional TWrite hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. 
+ */ + private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, + int numInitInputs, int numInputs, List childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + double lOutTReadCumulativeCost = selfCost; + double fOutTReadCumulativeCost = selfCost; + + List> lOutTReadPlanChilds = new ArrayList<>(); + List> fOutTReadPlanChilds = new ArrayList<>(); + + // Pre-calculate the cost for the additional child hop, which is a TWrite hop, of the TRead hop. + // Constraint: TWrite must have the same FedOutType as TRead. + for (int j = numInitInputs; j < numInputs; j++) { + Hop inputHop = childHops.get(j); + lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT)); + fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT)); + + lOutTReadCumulativeCost += childCumulativeCost[j][0]; + fOutTReadCumulativeCost += childCumulativeCost[j][1]; + // Skip TWrite -> TRead as they have the same FedOutType. + } + + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List> lOutPlanChilds = new ArrayList<>(); + enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + // Copy lOutPlanChilds to create fOutPlanChilds and add the pre-calculated cost of the TWrite child hop. + List> fOutPlanChilds = new ArrayList<>(lOutPlanChilds); + + lOutPlanChilds.addAll(lOutTReadPlanChilds); + fOutPlanChilds.addAll(fOutTReadPlanChilds); + + cumulativeCost[0] += lOutTReadCumulativeCost; + cumulativeCost[1] += fOutTReadCumulativeCost; + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds)); + } + } + + // Calculates costs for initial child hops, determining FOUT or LOUT based on `i`. 
+ private static void enumerateInitChildFedPlan(int numInitInputs, List childHops, List> planChilds, + double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){ + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInitInputs; j++) { + Hop inputHop = childHops.get(j); + // Calculate the bit value to decide between FOUT and LOUT for the current input + final int bit = (i & (1 << j)) != 0 ? 1 : 0; // Determine the bit value (decides FOUT/LOUT) + final FederatedOutput childType = (bit == 1) ? FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + + // Update the cumulative cost for LOUT, FOUT + cumulativeCost[0] += childCumulativeCost[j][bit] + childForwardingCost[j] * bit; + cumulativeCost[1] += childCumulativeCost[j][bit] + childForwardingCost[j] * (1 - bit); + } + } + + // Creates a dummy root node (fedplan) and selects the FedPlan with the minimum cost to return. + // The dummy root node does not have LOUT or FOUT. 
+ private static FedPlan getMinCostRootFedPlan(Set progRootHopSet, FederatedMemoTable memoTable) { + double cumulativeCost = 0; + List> rootFedPlanChilds = new ArrayList<>(); + + // Iterate over each Hop in the progRootHopSet + for (Hop endHop : progRootHopSet){ + // Retrieve the pruned FedPlan for LOUT and FOUT from the memo table + FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.LOUT); + FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.FOUT); + + // Compare the cumulative costs of LOUT and FOUT FedPlans + if (lOutFedPlan.getCumulativeCost() <= fOutFedPlan.getCumulativeCost()){ + cumulativeCost += lOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.LOUT)); + } else{ + cumulativeCost += fOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.FOUT)); + } + } + + return new FedPlan(cumulativeCost, null, rootFedPlanChilds); + } + + /** + * Detects and resolves conflicts in federated plans starting from the root plan. + * This function performs a breadth-first search (BFS) to traverse the federated plan tree. + * It identifies conflicts where the same plan ID has different federated output types. + * For each conflict, it records the plan ID and its conflicting parent plans. + * The function ensures that each plan ID is associated with a consistent federated output type + * by resolving these conflicts iteratively. + * + * The process involves: + * - Using a map to track conflicts, associating each plan ID with its federated output type + * and a list of parent plans. + * - Storing detected conflicts in a linked map, each entry containing a plan ID and its + * conflicting parent plans. + * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. 
+ * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan + * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. + * - Resolving conflicts by ensuring a consistent federated output type across the plan. + * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. + * + * @param rootPlan The root federated plan from which to start the conflict detection. + * @param memoTable The memoization table used to retrieve pruned federated plans. + * @return The cumulative additional cost for resolving conflicts. + */ + private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { + // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans + Map>> conflictCheckMap = new HashMap<>(); + + // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans + LinkedHashMap> conflictLinkedMap = new LinkedHashMap<>(); + + // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) + LinkedHashMap bfsLinkedMap = new LinkedHashMap<>(); + bfsLinkedMap.put(rootPlan, true); + + // Array to store cumulative additional cost for resolving conflicts + double[] cumulativeAdditionalCost = new double[]{0.0}; + + while (!bfsLinkedMap.isEmpty()) { + // Perform BFS to detect conflicts in federated plans + while (!bfsLinkedMap.isEmpty()) { + FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); + bfsLinkedMap.remove(currentPlan); + + // Iterate over each child plan of the current plan + for (Pair childPlanPair : currentPlan.getChildFedPlans()) { + FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); + + // Check if the child plan ID is already visited + if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { + // Retrieve the existing conflict pair for the child plan + Pair> conflictChildPlanPair = 
conflictCheckMap.get(childPlanPair.getLeft()); + // Add the current plan to the list of parent plans + conflictChildPlanPair.getRight().add(currentPlan); + + // If the federated output type differs, a conflict is detected + if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { + // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) + // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children occur + if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { + conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); + bfsLinkedMap.remove(childFedPlan); + } + } + } else { + // If no conflict exists, create a new entry in the conflict check map + List parentFedPlanList = new ArrayList<>(); + parentFedPlanList.add(currentPlan); + + // Map the child plan ID to its output type and list of parent plans + conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); + // Add the child plan to the BFS queue + bfsLinkedMap.put(childFedPlan, true); + } + } + } + // Resolve these conflicts to ensure a consistent federated output type across the plan + // Re-run BFS with resolved conflicts + bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); + conflictLinkedMap.clear(); + } + + // Return the cumulative additional cost for resolving conflicts + return cumulativeAdditionalCost[0]; + } + } + \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index 7bc7339563a..55b1c9daa15 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -17,224 +17,249 @@ * under the License. */ -package org.apache.sysds.hops.fedplanner; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.cost.ComputeCost; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - -import java.util.LinkedHashMap; -import java.util.NoSuchElementException; -import java.util.List; -import java.util.Map; - -/** - * Cost estimator for federated execution plans. - * Calculates computation, memory access, and network transfer costs for federated operations. - * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. - */ -public class FederatedPlanCostEstimator { - // Default value is used as a reasonable estimate since we only need - // to compare relative costs between different federated plans - // Memory bandwidth for local computations (25 GB/s) - private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; - // Network bandwidth for data transfers between federated sites (1 Gbps) - private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; - - /** - * Computes total cost of federated plan by: - * 1. Computing current node cost (if not cached) - * 2. Adding minimum-cost child plans - * 3. 
Including network transfer costs when needed - * - * @param currentPlan Plan to compute cost for - * @param memoTable Table containing all plan variants - */ - public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { - double totalCost; - Hop currentHop = currentPlan.getHopRef(); - - // Step 1: Calculate current node costs if not already computed - if (currentPlan.getSelfCost() == 0) { - // Compute cost for current node (computation + memory access) - totalCost = computeCurrentCost(currentHop); - currentPlan.setSelfCost(totalCost); - // Calculate potential network transfer cost if federation type changes - currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); - } else { - totalCost = currentPlan.getSelfCost(); - } - - // Step 2: Process each child plan and add their costs - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - // Find minimum cost child plan considering federation type compatibility - // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents - // because we're selecting child plans independently for each parent - FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair); - - // Add child plan cost (includes network transfer cost if federation types differ) - totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); - } - - // Step 3: Set final cumulative cost including current node - currentPlan.setTotalCost(totalCost); - } - - /** - * Resolves conflicts in federated plans where different plans have different FederatedOutput types. - * This function traverses the list of conflicting plans in reverse order to ensure that conflicts - * are resolved from the bottom-up, allowing for consistent federated output types across the plan. - * It calculates additional costs for each potential resolution and updates the cumulative additional cost. 
- * - * @param memoTable The FederatedMemoTable containing all federated plan variants. - * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. - * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. - * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. - */ - public static LinkedHashMap resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { - // LinkedHashMap to store resolved federated plans for BFS traversal. - LinkedHashMap resolvedFedPlanLinkedMap = new LinkedHashMap<>(); - - // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts - for (Map.Entry> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { - long conflictHopID = conflictFedPlanPair.getKey(); - List conflictParentFedPlans = conflictFedPlanPair.getValue(); - - // Retrieve the conflicting federated plans for LOUT and FOUT types - FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); - FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); - - // Variables to store additional costs for LOUT and FOUT types - double lOutAdditionalCost = 0; - double fOutAdditionalCost = 0; - - // Flags to check if the plan involves network transfer - // Network transfer cost is calculated only once, even if it occurs multiple times - boolean isLOutNetTransfer = false; - boolean isFOutNetTransfer = false; - - // Determine the optimal federated output type based on the calculated costs - FederatedOutput optimalFedOutType; - - // Iterate over each parent federated plan in the current conflict pair - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - // Find the calculated FedOutType of the child plan - Pair cacluatedConflictPlanPair = 
conflictParentFedPlan.getChildFedPlans().stream() - .filter(pair -> pair.getLeft().equals(conflictHopID)) - .findFirst() - .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); - - // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. - // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. - // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. - // CASE 5. Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. - // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. - // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. - - // Adjust LOUT, FOUT costs based on the calculated plan's output type - if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { - // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost - // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
- fOutAdditionalCost += confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { - // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred - // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later - isFOutNetTransfer = true; - } else { - // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred - // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later - isLOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - - // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - } - } else { - lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { - isLOutNetTransfer = true; - } else { - isFOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - } - } - } - - // Add network transfer costs if applicable - if (isLOutNetTransfer) { - lOutAdditionalCost += confilctLOutFedPlan.getNetTransferCost(); - } - if (isFOutNetTransfer) { - fOutAdditionalCost += confilctFOutFedPlan.getNetTransferCost(); - } - - // Determine the optimal federated output type based on the calculated costs - if (lOutAdditionalCost <= fOutAdditionalCost) { - optimalFedOutType = FederatedOutput.LOUT; - cumulativeAdditionalCost[0] += lOutAdditionalCost; - resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); - } else { - optimalFedOutType = FederatedOutput.FOUT; - cumulativeAdditionalCost[0] += fOutAdditionalCost; - 
resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); - } - - // Update only the optimal federated output type, not the cost itself or recursively - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { - if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { - int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); - conflictParentFedPlan.getChildFedPlans().set(index, - Pair.of(childPlanPair.getLeft(), optimalFedOutType)); - break; - } - } - } - } - return resolvedFedPlanLinkedMap; - } - - /** - * Computes the cost for the current Hop node. - * - * @param currentHop The Hop node whose cost needs to be computed - * @return The total cost for the current node's operation - */ - private static double computeCurrentCost(Hop currentHop){ - double computeCost = ComputeCost.getHOPComputeCost(currentHop); - double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); - double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); - - // Compute total cost assuming: - // 1. Computation and input access can be overlapped (hence taking max) - // 2. Output access must wait for both to complete (hence adding) - return Math.max(computeCost, inputAccessCost) + ouputAccessCost; - } - - /** - * Calculates the memory access cost based on data size and memory bandwidth. - * - * @param memSize Size of data to be accessed (in bytes) - * @return Time cost for memory access (in seconds) - */ - private static double computeHopMemoryAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; - } - - /** - * Calculates the network transfer cost based on data size and network bandwidth. - * Used when federation status changes between parent and child plans. 
- * - * @param memSize Size of data to be transferred (in bytes) - * @return Time cost for network transfer (in seconds) - */ - private static double computeHopNetworkAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; - } -} + package org.apache.sysds.hops.fedplanner; + import org.apache.commons.lang3.tuple.Pair; + import org.apache.sysds.common.Types; + import org.apache.sysds.hops.DataOp; + import org.apache.sysds.hops.Hop; + import org.apache.sysds.hops.cost.ComputeCost; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + + import java.util.LinkedHashMap; + import java.util.NoSuchElementException; + import java.util.List; + import java.util.Map; + + /** + * Cost estimator for federated execution plans. + * Calculates computation, memory access, and network transfer costs for federated operations. + * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. 
+ */ + public class FederatedPlanCostEstimator { + // Default value is used as a reasonable estimate since we only need + // to compare relative costs between different federated plans + // Memory bandwidth for local computations (25 GB/s) + private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; + // Network bandwidth for data transfers between federated sites (1 Gbps) + private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + + // Retrieves the cumulative and forwarding costs of the child hops and stores them in arrays + public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List inputHops, + double[][] childCumulativeCost, double[] childForwardingCost) { + for (int i = 0; i < inputHops.size(); i++) { + long childHopID = inputHops.get(i).getHopID(); + + FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); + FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); + + // The cumulative cost of the child already includes the weight + childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); + childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); + + // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements? + childForwardingCost[i] = hopCommon.weight * childLOutFedPlan.getForwardingCost(); + } + } + + /** + * Computes the cost associated with a given Hop node. + * This method calculates both the self cost and the forwarding cost for the Hop, + * taking into account its type and the number of parent nodes. + * + * @param hopCommon The HopCommon object containing the Hop and its properties. + * @return The self cost of the Hop. 
+ */ + public static double computeHopCost(HopCommon hopCommon){ + // TWrite and TRead are meta-data operations, hence selfCost is zero + if (hopCommon.hopRef instanceof DataOp){ + if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTWRITE ){ + hopCommon.setSelfCost(0); + // Since TWrite and TRead have the same FedOutType, forwarding cost is zero + hopCommon.setForwardingCost(0); + return 0; + } else if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTREAD) { + hopCommon.setSelfCost(0); + // TRead may have a different FedOutType from its parent, so calculate forwarding cost + // TODO: Uncertain about the number of TWrites + hopCommon.setForwardingCost(computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate())); + return 0; + } + } + + // In loops, selfCost is repeated, but forwarding may not be + // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?) + double selfCost = hopCommon.weight * computeSelfCost(hopCommon.hopRef); + double forwardingCost = computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate()); + + int numParents = hopCommon.hopRef.getParent().size(); + if (numParents >= 2) { + selfCost /= numParents; + forwardingCost /= numParents; + } + + hopCommon.setSelfCost(selfCost); + hopCommon.setForwardingCost(forwardingCost); + + return selfCost; + } + + /** + * Computes the cost for the current Hop node. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeSelfCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); + + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. 
Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. + * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. + * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopForwardingCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; + } + + /** + * Resolves conflicts in federated plans where different plans have different FederatedOutput types. + * This function traverses the list of conflicting plans in reverse order to ensure that conflicts + * are resolved from the bottom-up, allowing for consistent federated output types across the plan. + * It calculates additional costs for each potential resolution and updates the cumulative additional cost. + * + * @param memoTable The FederatedMemoTable containing all federated plan variants. + * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. + * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. + * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. + */ + public static LinkedHashMap resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { + // LinkedHashMap to store resolved federated plans for BFS traversal. 
+ LinkedHashMap resolvedFedPlanLinkedMap = new LinkedHashMap<>(); + + // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts + for (Map.Entry> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { + long conflictHopID = conflictFedPlanPair.getKey(); + List conflictParentFedPlans = conflictFedPlanPair.getValue(); + + // Retrieve the conflicting federated plans for LOUT and FOUT types + FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); + FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); + + // Variables to store additional costs for LOUT and FOUT types + double lOutAdditionalCost = 0; + double fOutAdditionalCost = 0; + + // Flags to check if the plan involves network transfer + // Network transfer cost is calculated only once, even if it occurs multiple times + boolean isLOutForwarding = false; + boolean isFOutForwarding = false; + + // Determine the optimal federated output type based on the calculated costs + FederatedOutput optimalFedOutType; + + // Iterate over each parent federated plan in the current conflict pair + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { + // Find the calculated FedOutType of the child plan + Pair cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() + .filter(pair -> pair.getLeft().equals(conflictHopID)) + .findFirst() + .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); + + // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. + // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. + // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. + // CASE 5. 
Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. + // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. + // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. + + // Adjust LOUT, FOUT costs based on the calculated plan's output type + if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { + // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost + // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. + fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { + // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred + // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later + isFOutForwarding = true; + } else { + // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred + // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later + isLOutForwarding = true; + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + + // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + } + } else { + lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { + isLOutForwarding = true; + } else { + isFOutForwarding = true; + lOutAdditionalCost -= 
confilctLOutFedPlan.getForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + } + } + } + + // Add network transfer costs if applicable + if (isLOutForwarding) { + lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); + } + if (isFOutForwarding) { + fOutAdditionalCost += confilctFOutFedPlan.getForwardingCost(); + } + + // Determine the optimal federated output type based on the calculated costs + if (lOutAdditionalCost <= fOutAdditionalCost) { + optimalFedOutType = FederatedOutput.LOUT; + cumulativeAdditionalCost[0] += lOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); + } else { + optimalFedOutType = FederatedOutput.FOUT; + cumulativeAdditionalCost[0] += fOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); + } + + // Update only the optimal federated output type, not the cost itself or recursively + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { + for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { + if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { + int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); + conflictParentFedPlan.getChildFedPlans().set(index, + Pair.of(childPlanPair.getLeft(), optimalFedOutType)); + break; + } + } + } + } + return resolvedFedPlanLinkedMap; + } + } + \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 20485588d32..d23f7ebcf92 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -17,75 +17,91 @@ * under the License. 
*/ -package org.apache.sysds.test.component.federated; + package org.apache.sysds.test.component.federated; -import java.io.IOException; -import java.util.HashMap; - -import org.apache.sysds.hops.Hop; -import org.junit.Assert; -import org.junit.Test; -import org.apache.sysds.api.DMLScript; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; -import org.apache.sysds.parser.DMLProgram; -import org.apache.sysds.parser.DMLTranslator; -import org.apache.sysds.parser.ParserFactory; -import org.apache.sysds.parser.ParserWrapper; -import org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.test.TestConfiguration; -import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; - - -public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase -{ - private static final String TEST_DIR = "functions/federated/privacy/"; - private static final String HOME = SCRIPT_DIR + TEST_DIR; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; - - @Override - public void setUp() {} - - @Test - public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } - - @Test - public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } - - @Test - public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } - - // Todo: Need to write test scripts for the federated version - private void runTest( String scriptFilename ) { - int index = scriptFilename.lastIndexOf(".dml"); - String testName = scriptFilename.substring(0, index > 0 ? 
index : scriptFilename.length()); - TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); - addTestConfiguration(testName, testConfig); - loadTestConfiguration(testConfig); - - try { - DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); - ConfigurationManager.setLocalConfig(conf); - - //read script - String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); - - //parsing and dependency analysis - ParserWrapper parser = ParserFactory.createParser(); - DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - dmlt.constructHops(prog); - dmlt.rewriteHopsDAG(prog); - dmlt.constructLops(prog); - - Hop hops = prog.getStatementBlocks().get(0).getHops().get(0); - FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true); - } - catch (IOException e) { - e.printStackTrace(); - Assert.fail(); - } - } -} + import java.io.IOException; + import java.util.HashMap; + import org.junit.Assert; + import org.junit.Test; + import org.apache.sysds.api.DMLScript; + import org.apache.sysds.conf.ConfigurationManager; + import org.apache.sysds.conf.DMLConfig; + import org.apache.sysds.parser.DMLProgram; + import org.apache.sysds.parser.DMLTranslator; + import org.apache.sysds.parser.ParserFactory; + import org.apache.sysds.parser.ParserWrapper; + import org.apache.sysds.test.AutomatedTestBase; + import org.apache.sysds.test.TestConfiguration; + import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + + public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase + { + private static final String TEST_DIR = "functions/federated/privacy/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; + + 
@Override + public void setUp() {} + + @Test + public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } + + @Test + public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } + + @Test + public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } + + @Test + public void testFederatedPlanCostEnumerator4() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } + + @Test + public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); } + + @Test + public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); } + + @Test + public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } + + @Test + public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest8.dml"); } + + @Test + public void testFederatedPlanCostEnumerator9() { runTest("FederatedPlanCostEnumeratorTest9.dml"); } + + // Todo: Need to write test scripts for the federated version + private void runTest( String scriptFilename ) { + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? 
index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + try { + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + //read script + String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); + + //parsing and dependency analysis + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); + dmlt.rewriteLopDAG(prog); + + FederatedPlanCostEnumerator.enumerateProgram(prog, true); + } + catch (IOException e) { + e.printStackTrace(); + Assert.fail(); + } + } + } + \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml new file mode 100644 index 00000000000..06533df144d --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +a = matrix(7,10,10); +if (sum(a) > 0.5) + b = a * 2; +else + b = a * 3; +c = sqrt(b); +print(sum(c)); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml new file mode 100644 index 00000000000..2721bbcbaf6 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +for( i in 1:100 ) +{ + b = i + 1; + print(b); +} \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml new file mode 100644 index 00000000000..b95ae1b5bb0 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +A = matrix(7, rows=10, cols=10) +b = rand(rows = 1, cols = ncol(A), min = 1, max = 2); +i = 0 + +while (sum(b) < i) { + i = i + 1 + b = b + i + A = A * A + s = b %*% A + print(mean(s)) +} +c = sqrt(A) +print(sum(c)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml new file mode 100644 index 00000000000..e3efaa28515 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +a = 1; + +parfor( i in 1:10 ) +{ + b = i + a; + #print(b); +} diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml new file mode 100644 index 00000000000..1587ff613b4 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml @@ -0,0 +1,49 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +a = rand(); +b= rand(); +c= rand(); +d= rand(); +e= rand(); +f= rand(); +h= rand(); +i= rand(); + +if (a < 30){ + a = a + b; + + if (a < 20) { + a = a * c; + } else { + a = a + d; + + if (a < 10) { + a = a + e; + } else { + a = a + f; + } + } +} else { + a = a + h; +} +c = a + i; +print(mean(c)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml new file mode 100644 index 00000000000..b5713374f2c --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml @@ -0,0 +1,58 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +# Define UDFs +meanUser = function (matrix[double] A) return (double m) { + m = sum(A)/nrow(A) +} + +minMaxUser = function( matrix[double] M) return (double minVal, double maxVal) { + minVal = min(M); + maxVal = max(M); +} + +# Recursive function: Calculate factorial +factorialUser = function(int n) return (int result) { + if (n <= 1) { + result = 1; # base case + } else { + result = n * factorialUser(n - 1); # recursive call + } +} + +# Main script +# 1. Create matrix and calculate statistics +M = rand(rows=4, cols=4, min=1, max=5); # 4x4 random matrix +avg = meanUser(M); +[min_val, max_val] = minMaxUser(M); + +# 2. Call recursive function (factorial) +number = 5; +fact_result = factorialUser(number); + +# 3. Print results +print("=== Matrix Statistics ==="); +print("Average: " + avg); +print("Min: " + min_val + ", Max: " + max_val); + +print("\n=== Recursive Function ==="); +print("Factorial of " + number + ": " + fact_result); \ No newline at end of file From 51a290fca154fe38b2c1abf0b32fdd6f2d3d2cb3 Mon Sep 17 00:00:00 2001 From: min-guk Date: Wed, 8 Jan 2025 16:55:23 +0900 Subject: [PATCH 2/9] Update detectConflictFedPlan and resolveConflictFedPlan --- .../hops/fedplanner/FederatedMemoTable.java | 426 ++++++---- .../FederatedPlanCostEnumerator.java | 746 +++++------------- .../FederatedPlanCostEstimator.java | 457 +++++------ .../FederatedPlanCostEnumeratorTest.java | 155 ++-- 4 files changed, 761 insertions(+), 1023 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index dae809179b6..a18376e188e 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -17,138 +17,294 @@ * under the License. 
*/ - package org.apache.sysds.hops.fedplanner; - - import java.util.Comparator; - import java.util.HashMap; - import java.util.List; - import java.util.ArrayList; - import java.util.Map; - import org.apache.sysds.hops.Hop; - import org.apache.commons.lang3.tuple.Pair; - import org.apache.commons.lang3.tuple.ImmutablePair; - import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - - /** - * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. - * This table stores and manages different execution plan variants for each Hop and fedOutType combination, - * facilitating the optimization of federated execution plans. - */ - public class FederatedMemoTable { - // Maps Hop ID and fedOutType pairs to their plan variants - private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); - - public void addFedPlanVariants(long hopID, FederatedOutput fedOutType, FedPlanVariants fedPlanVariants) { - hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariants); - } - - public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { - return hopMemoTable.get(fedPlanPair); - } - - public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { - FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - return fedPlanVariantList._fedPlanVariants.get(0); - } - - public FedPlan getFedPlanAfterPrune(Pair fedPlanPair) { - FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); - return fedPlanVariantList._fedPlanVariants.get(0); - } - - public boolean contains(long hopID, FederatedOutput fedOutType) { - return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); - } - - /** - * Represents a single federated execution plan with its associated costs and dependencies. - * This class contains: - * 1. selfCost: Cost of the current hop (computation + input/output memory access). - * 2. 
cumulativeCost: Total cost including this plan's selfCost and all child plans' cumulativeCost. - * 3. forwardingCost: Network transfer cost for this plan to the parent plan. - * - * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. - */ - public static class FedPlan { - private double cumulativeCost; // Total cost = sum of selfCost + cumulativeCost of child plans - private final FedPlanVariants fedPlanVariants; // Reference to variant list - private final List> childFedPlans; // Child plan references - - public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List> childFedPlans) { - this.cumulativeCost = cumulativeCost; - this.fedPlanVariants = fedPlanVariants; - this.childFedPlans = childFedPlans; - } - - public Hop getHopRef() {return fedPlanVariants.hopCommon.getHopRef();} - public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();} - public FederatedOutput getFedOutType() {return fedPlanVariants.getFedOutType();} - public double getCumulativeCost() {return cumulativeCost;} - public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();} - public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();} - public double getWeight() {return fedPlanVariants.hopCommon.getWeight();} - public List> getChildFedPlans() {return childFedPlans;} - } - - /** - * Represents a collection of federated execution plan variants for a specific Hop and FederatedOutput. - * This class contains cost information and references to the associated plans. - * It uses HopCommon to store common properties and costs related to the Hop. 
- */ - public static class FedPlanVariants { - protected HopCommon hopCommon; // Common properties and costs for the Hop - private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) - protected List _fedPlanVariants; // List of plan variants - - public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) { - this.hopCommon = hopCommon; - this.fedOutType = fedOutType; - this._fedPlanVariants = new ArrayList<>(); - } - - public boolean isEmpty() {return _fedPlanVariants.isEmpty();} - public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} - public List getFedPlanVariants() {return _fedPlanVariants;} - public FederatedOutput getFedOutType() {return fedOutType;} - - public void pruneFedPlans() { - if (_fedPlanVariants.size() > 1) { - // Find the FedPlan with the minimum cumulative cost - FedPlan minCostPlan = _fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) - .orElse(null); - - // Retain only the minimum cost plan - _fedPlanVariants.clear(); - _fedPlanVariants.add(minCostPlan); - } - } - } - - /** - * Represents common properties and costs associated with a Hop. - * This class holds a reference to the Hop and tracks its execution and network forwarding (transfer) costs. 
- */ - public static class HopCommon { - protected final Hop hopRef; // Reference to the associated Hop - protected double selfCost; // Cost of the hop's computation and memory access - protected double forwardingCost; // Cost of forwarding the hop's output to its parent - protected double weight; // Weight used to calculate cost based on hop execution frequency - - public HopCommon(Hop hopRef, double weight) { - this.hopRef = hopRef; - this.selfCost = 0; - this.forwardingCost = 0; - this.weight = weight; - } - - public Hop getHopRef() {return hopRef;} - public double getSelfCost() {return selfCost;} - public double getForwardingCost() {return forwardingCost;} - public double getWeight() {return weight;} - - protected void setSelfCost(double selfCost) {this.selfCost = selfCost;} - protected void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;} - } - } - \ No newline at end of file +package org.apache.sysds.hops.fedplanner; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashSet; +import java.util.Set; + +/** + * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. + * This table stores and manages different execution plan variants for each Hop and fedOutType combination, + * facilitating the optimization of federated execution plans. + */ +public class FederatedMemoTable { + // Maps Hop ID and fedOutType pairs to their plan variants + private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); + + /** + * Adds a new federated plan to the memo table. 
+ * Creates a new variant list if none exists for the given Hop and fedOutType. + * + * @param hop The Hop node + * @param fedOutType The federated output type + * @param planChilds List of child plan references + * @return The newly created FedPlan + */ + public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { + long hopID = hop.getHopID(); + FedPlanVariants fedPlanVariantList; + + if (contains(hopID, fedOutType)) { + fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } else { + fedPlanVariantList = new FedPlanVariants(hop, fedOutType); + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); + } + + FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); + fedPlanVariantList.addFedPlan(newPlan); + + return newPlan; + } + + /** + * Retrieves the minimum-cost plan among the variants stored for the given Hop and fedOutType. + * Plans are compared by their total cost (FedPlan#getTotalCost). + * + * @param hopID The Hop ID. + * @param fedOutType The federated output type. + * @return The FedPlan with the lowest total cost, or null if the variant list is empty. + */ + public FedPlan getMinCostFedPlan(long hopID, FederatedOutput fedOutType) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + return fedPlanVariantList._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + } + + public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } + + public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { + // Todo: Consider whether to verify if pruning has been performed + FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + return fedPlanVariantList._fedPlanVariants.get(0); + } + + /** + * Checks if the memo table contains an entry for a given Hop and fedOutType. + * + * @param hopID The Hop ID.
+ * @param fedOutType The associated fedOutType. + * @return True if the entry exists, false otherwise. + */ + public boolean contains(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); + } + + /** + * Prunes all entries in the memo table, retaining only the minimum-cost + * FedPlan for each entry. + */ + public void pruneMemoTable() { + for (Map.Entry, FedPlanVariants> entry : hopMemoTable.entrySet()) { + List fedPlanList = entry.getValue().getFedPlanVariants(); + if (fedPlanList.size() > 1) { + // Find the FedPlan with the minimum cost + FedPlan minCostPlan = fedPlanList.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + // Retain only the minimum cost plan + fedPlanList.clear(); + fedPlanList.add(minCostPlan); + } + } + } + + // Todo: Separate print functions from FederatedMemoTable + /** + * Recursively prints a tree representation of the DAG starting from the given root FedPlan. + * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node. + * + * @param rootFedPlan The starting point FedPlan to print + */ + public void printFedPlanTree(FedPlan rootFedPlan) { + Set visited = new HashSet<>(); + printFedPlanTreeRecursive(rootFedPlan, visited, 0, true); + } + + /** + * Helper method to recursively print the FedPlan tree. 
+ * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + * @param isLast Whether this node is the last child of its parent + */ + private void printFedPlanTreeRecursive(FedPlan plan, Set visited, int depth, boolean isLast) { + if (plan == null || visited.contains(plan)) { + return; + } + + visited.add(plan); + + Hop hop = plan.getHopRef(); + StringBuilder sb = new StringBuilder(); + + // Add FedPlan information + sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) + .append(plan.getHopRef().getOpString()) + .append(" [") + .append(plan.getFedOutType()) + .append("]"); + + StringBuilder childs = new StringBuilder(); + childs.append(" ("); + boolean childAdded = false; + for( Hop input : hop.getInput()){ + childs.append(childAdded?",":""); + childs.append(input.getHopID()); + childAdded = true; + } + childs.append(")"); + if( childAdded ) + sb.append(childs.toString()); + + + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", + plan.getTotalCost(), + plan.getSelfCost(), + plan.getNetTransferCost())); + + // Add matrix characteristics + sb.append(" [") + .append(hop.getDim1()).append(", ") + .append(hop.getDim2()).append(", ") + .append(hop.getBlocksize()).append(", ") + .append(hop.getNnz()); + + if (hop.getUpdateType().isInPlace()) { + sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); + } + sb.append("]"); + + // Add memory estimates + sb.append(" [") + .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") + .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); + + // Add reblock and checkpoint requirements + if (hop.requiresReblock() && hop.requiresCheckpoint()) { + sb.append(" [rblk, chkpt]"); + } else if 
(hop.requiresReblock()) { + sb.append(" [rblk]"); + } else if (hop.requiresCheckpoint()) { + sb.append(" [chkpt]"); + } + + // Add execution type + if (hop.getExecType() != null) { + sb.append(", ").append(hop.getExecType()); + } + + System.out.println(sb); + + // Process child nodes + List> childRefs = plan.getChildFedPlans(); + for (int i = 0; i < childRefs.size(); i++) { + Pair childRef = childRefs.get(i); + FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight()); + if (childVariants == null || childVariants.getFedPlanVariants().isEmpty()) + continue; + + boolean isLastChild = (i == childRefs.size() - 1); + for (FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild); + } + } + } + + /** + * Represents common properties and costs associated with a Hop. + * This class holds a reference to the Hop and tracks its execution and network transfer costs. + */ + public static class HopCommon { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Current execution cost (compute + memory access) + protected double netTransferCost; // Network transfer cost + + protected HopCommon(Hop hopRef) { + this.hopRef = hopRef; + this.selfCost = 0; + this.netTransferCost = 0; + } + } + + /** + * Represents a collection of federated execution plan variants for a specific Hop and FederatedOutput. + * This class contains cost information and references to the associated plans. + * It uses HopCommon to store common properties and costs related to the Hop. 
+ */ + public static class FedPlanVariants { + protected HopCommon hopCommon; // Common properties and costs for the Hop + private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) + protected List _fedPlanVariants; // List of plan variants + + public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { + this.hopCommon = new HopCommon(hopRef); + this.fedOutType = fedOutType; + this._fedPlanVariants = new ArrayList<>(); + } + + public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} + public List getFedPlanVariants() {return _fedPlanVariants;} + } + + /** + * Represents a single federated execution plan with its associated costs and dependencies. + * This class contains: + * 1. selfCost: Cost of current hop (compute + input/output memory access) + * 2. totalCost: Cumulative cost including this plan and all child plans + * 3. netTransferCost: Network transfer cost for this plan to parent plan. + * + * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. 
+ */ + public static class FedPlan { + private double totalCost; // Total cost including child plans + private final FedPlanVariants fedPlanVariants; // Reference to variant list + private final List> childFedPlans; // Child plan references + + public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { + this.totalCost = 0; + this.childFedPlans = childFedPlans; + this.fedPlanVariants = fedPlanVariants; + } + + public void setTotalCost(double totalCost) {this.totalCost = totalCost;} + public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;} + public void setNetTransferCost(double netTransferCost) {fedPlanVariants.hopCommon.netTransferCost = netTransferCost;} + + public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;} + public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();} + public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} + public double getTotalCost() {return totalCost;} + public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;} + public double getNetTransferCost() {return fedPlanVariants.hopCommon.netTransferCost;} + public List> getChildFedPlans() {return childFedPlans;} + + /** + * Calculates the conditional network transfer cost based on output type compatibility. + * Returns 0 if output types match, otherwise returns the network transfer cost. + * @param parentFedOutType The federated output type of the parent plan. + * @return The conditional network transfer cost. 
+ */ + public double getCondNetTransferCost(FederatedOutput parentFedOutType) { + if (parentFedOutType == getFedOutType()) return 0; + return fedPlanVariants.hopCommon.netTransferCost; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index 56586a30622..db1583ab2fb 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -17,558 +17,194 @@ * under the License. */ - package org.apache.sysds.hops.fedplanner; - import java.util.ArrayList; - import java.util.List; - import java.util.Map; - import java.util.HashMap; - import java.util.LinkedHashMap; - import java.util.Optional; - import java.util.Set; - import java.util.HashSet; - - import org.apache.commons.lang3.tuple.Pair; - - import org.apache.commons.lang3.tuple.ImmutablePair; - import org.apache.sysds.common.Types; - import org.apache.sysds.hops.DataOp; - import org.apache.sysds.hops.Hop; - import org.apache.sysds.hops.LiteralOp; - import org.apache.sysds.hops.UnaryOp; - import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; - import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; - import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; - import org.apache.sysds.hops.rewrite.HopRewriteUtils; - import org.apache.sysds.parser.DMLProgram; - import org.apache.sysds.parser.ForStatement; - import org.apache.sysds.parser.ForStatementBlock; - import org.apache.sysds.parser.FunctionStatement; - import org.apache.sysds.parser.FunctionStatementBlock; - import org.apache.sysds.parser.IfStatement; - import org.apache.sysds.parser.IfStatementBlock; - import org.apache.sysds.parser.StatementBlock; - import org.apache.sysds.parser.WhileStatement; - import org.apache.sysds.parser.WhileStatementBlock; - import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - import org.apache.sysds.runtime.util.UtilFunctions; - - public class FederatedPlanCostEnumerator { - private static final double DEFAULT_LOOP_WEIGHT = 10.0; - private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5; - - /** - * Enumerates the entire DML program to generate federated execution plans. - * It processes each statement block, computes the optimal federated plan, - * detects and resolves conflicts, and optionally prints the plan tree. - * - * @param prog The DML program to enumerate. - * @param isPrint A boolean indicating whether to print the federated plan tree. - */ - public static void enumerateProgram(DMLProgram prog, boolean isPrint) { - FederatedMemoTable memoTable = new FederatedMemoTable(); - - Map> outerTransTable = new HashMap<>(); - Map> formerInnerTransTable = new HashMap<>(); - Set progRootHopSet = new HashSet<>(); // Set of hops for the root dummy node - // TODO: Just for debug, remove later - Set statRootHopSet = new HashSet<>(); // Set of hops that have no parent but are not referenced - - for (StatementBlock sb : prog.getStatementBlocks()) { - Optional.ofNullable(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false)) - .ifPresent(outerTransTable::putAll); - } - - FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable); - - // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); - - // Print the federated plan tree if requested - if (isPrint) { - FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost); - } - } - - - /** - * Enumerates the statement block and updates the transient and memoization tables. 
- * This method processes different types of statement blocks such as If, For, While, and Function blocks. - * It recursively enumerates the Hop DAGs within these blocks and updates the corresponding tables. - * The method also calculates weights recursively for if-else/loops and handles inner and outer block distinctions. - * - * @param sb The statement block to enumerate. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track immutable outer transient writes. - * @param formerInnerTransTable The table to track immutable former inner transient writes. - * @param progRootHopSet The set of hops to connect to the root dummy node. - * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). - * @param weight The weight associated with the current Hop. - * @param isInnerBlock A boolean indicating if the current block is an inner block. - * @return A map of inner transient writes. - */ - public static Map> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { - Map> innerTransTable = new HashMap<>(); - - if (sb instanceof IfStatementBlock) { - IfStatementBlock isb = (IfStatementBlock) sb; - IfStatement istmt = (IfStatement)isb.getStatement(0); - - enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - - // Treat outerTransTable as immutable in inner blocks - // Write TWrite of sb sequentially in innerTransTable, and update formerInnerTransTable after the sb ends - // In case of if-else, create separate formerInnerTransTables for if and else, merge them after completion, and update formerInnerTransTable - Map> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable); - Map> elseFormerInnerTransTable = new 
HashMap<>(formerInnerTransTable); - - for (StatementBlock csb : istmt.getIfBody()){ - ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); - } - - for (StatementBlock csb : istmt.getElseBody()){ - elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, elseFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); - } - - // If there are common keys: merge elseValue list into ifValue list - elseFormerInnerTransTable.forEach((key, elseValue) -> { - ifFormerInnerTransTable.merge(key, elseValue, (ifValue, newValue) -> { - ifValue.addAll(newValue); - return ifValue; - }); - }); - // Update innerTransTable - innerTransTable.putAll(ifFormerInnerTransTable); - } else if (sb instanceof ForStatementBlock) { //incl parfor - ForStatementBlock fsb = (ForStatementBlock) sb; - ForStatement fstmt = (ForStatement)fsb.getStatement(0); - - // Calculate for-loop iteration count if possible - double loopWeight = DEFAULT_LOOP_WEIGHT; - Hop from = fsb.getFromHops().getInput().get(0); - Hop to = fsb.getToHops().getInput().get(0); - Hop incr = (fsb.getIncrementHops() != null) ? 
- fsb.getIncrementHops().getInput().get(0) : new LiteralOp(1); - - // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values) - if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) { - double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from); - double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to); - double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr); - if( dfrom > dto && dincr == 1 ) - dincr = -1; - loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false); - } - weight *= loopWeight; - - enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - enumerateHopDAG(fsb.getIncrementHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - - enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); - } else if (sb instanceof WhileStatementBlock) { - WhileStatementBlock wsb = (WhileStatementBlock) sb; - WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - weight *= DEFAULT_LOOP_WEIGHT; - - enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); - } else if (sb instanceof FunctionStatementBlock) { - FunctionStatementBlock fsb = (FunctionStatementBlock)sb; - FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - - // TODO: NOT descent multiple types (use hash 
set for functions using function name) - enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); - } else { //generic (last-level) - if( sb.getHops() != null ){ - for(Hop c : sb.getHops()) - // In the statement block, if isInner, write hopDAG in innerTransTable, if not, write directly in outerTransTable - enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - } - } - return innerTransTable; - } - - /** - * Enumerates the statement blocks within a body and updates the transient and memoization tables. - * - * @param sbList The list of statement blocks to enumerate. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track immutable outer transient writes. - * @param formerInnerTransTable The table to track immutable former inner transient writes. - * @param innerTransTable The table to track inner transient writes. - * @param progRootHopSet The set of hops to connect to the root dummy node. - * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). - * @param weight The weight associated with the current Hop. - */ - public static void enumerateStatementBlockBody(List sbList, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight) { - // The statement blocks within the body reference outerTransTable and formerInnerTransTable as immutable read-only, - // and record TWrite in the innerTransTable of the statement block within the body. - // Update the formerInnerTransTable with the contents of the returned innerTransTable. 
- for (StatementBlock sb : sbList) - formerInnerTransTable.putAll(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true)); - - // Then update and return the innerTransTable of the statement block containing the body. - innerTransTable.putAll(formerInnerTransTable); - } - - /** - * Enumerates the statement hop DAG within a statement block. - * This method recursively enumerates all possible federated execution plans - * and identifies hops to connect to the root dummy node. - * - * @param rootHop The root Hop of the DAG to enumerate. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track transient writes. - * @param formerInnerTransTable The table to track immutable inner transient writes. - * @param innerTransTable The table to track inner transient writes. - * @param progRootHopSet The set of hops to connect to the root dummy node. - * @param statRootHopSet The set of root hops for debugging purposes. - * @param weight The weight associated with the current Hop. - * @param isInnerBlock A boolean indicating if the current block is an inner block. - */ - public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { - // Recursively enumerate all possible plans - rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInnerBlock); - - // Identify hops to connect to the root dummy node - - if ((rootHop instanceof DataOp && (rootHop.getName().equals("__pred"))) // TWrite "__pred" - || (rootHop instanceof UnaryOp && ((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print) - // Connect TWrite pred and u(print) to the root dummy node - // TODO: Should the last unreferenced TWrite be connected? 
- progRootHopSet.add(rootHop); - } else { - // TODO: Just for debug, remove later - // For identifying TWrites that are not referenced later - statRootHopSet.add(rootHop); - } - } - - /** - * Rewires and enumerates federated execution plans for a given Hop. - * This method processes all input nodes, rewires TWrite and TRead operations, - * and generates federated plan variants for both inner and outer code blocks. - * - * @param hop The Hop for which to rewire and enumerate federated plans. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track transient writes. - * @param formerInnerTransTable The table to track immutable inner transient writes. - * @param innerTransTable The table to track inner transient writes. - * @param weight The weight associated with the current Hop. - * @param isInner A boolean indicating if the current block is an inner block. - */ - private static void rewireAndEnumerateFedPlan(Hop hop, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Map> innerTransTable, double weight, boolean isInner) { - // Process all input nodes first if not already in memo table - for (Hop inputHop : hop.getInput()) { - long inputHopID = inputHop.getHopID(); - if (!memoTable.contains(inputHopID, FederatedOutput.FOUT) - && !memoTable.contains(inputHopID, FederatedOutput.LOUT)) { - rewireAndEnumerateFedPlan(inputHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner); - } - } - - // Detect and Rewire TWrite and TRead operations - List childHops = hop.getInput(); - if (hop instanceof DataOp && !(hop.getName().equals("__pred"))){ - String hopName = hop.getName(); - - if (isInner){ // If it's an inner code block - if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ - innerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); - } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ - // Copy existing and add 
TWrite - childHops = new ArrayList<>(childHops); - List additionalChildHops = null; - - // Read according to priority - if (innerTransTable.containsKey(hopName)){ - additionalChildHops = innerTransTable.get(hopName); - } else if (formerInnerTransTable.containsKey(hopName)){ - additionalChildHops = formerInnerTransTable.get(hopName); - } else if (outerTransTable.containsKey(hopName)){ - additionalChildHops = outerTransTable.get(hopName); - } - - if (additionalChildHops != null) { - childHops.addAll(additionalChildHops); - } - } - } else { // If it's an outer code block - if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ - // Add directly to outerTransTable - outerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); - } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ - childHops = new ArrayList<>(childHops); - - // TODO: In the case of for (i in 1:10), there is no hop that writes TWrite for i. - // Read directly from outerTransTable and add - List additionalChildHops = outerTransTable.get(hopName); - if (additionalChildHops != null) { - childHops.addAll(additionalChildHops); - } - } - } - } - - // Enumerate the federated plan for the current Hop - enumerateFedPlan(hop, memoTable, childHops, weight); - } - - /** - * Enumerates federated execution plans for a given Hop. - * This method calculates the self cost and child costs for the Hop, - * generates federated plan variants for both LOUT and FOUT output types, - * and prunes redundant plans before adding them to the memo table. - * - * @param hop The Hop for which to enumerate federated plans. - * @param memoTable The memoization table to store plan variants. - * @param childHops The list of child hops. - * @param weight The weight associated with the current Hop. 
- */ - private static void enumerateFedPlan(Hop hop, FederatedMemoTable memoTable, List childHops, double weight){ - long hopID = hop.getHopID(); - HopCommon hopCommon = new HopCommon(hop, weight); - double selfCost = FederatedPlanCostEstimator.computeHopCost(hopCommon); - - FedPlanVariants lOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); - FedPlanVariants fOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); - - int numInputs = childHops.size(); - int numInitInputs = hop.getInput().size(); - - double[][] childCumulativeCost = new double[numInputs][2]; // # of child, LOUT/FOUT of child - double[] childForwardingCost = new double[numInputs]; // # of child - - // The self cost follows its own weight, while the forwarding cost follows the parent's weight. - FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childHops, childCumulativeCost, childForwardingCost); - - if (numInitInputs == numInputs){ - enumerateOnlyInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); - } else { - enumerateTReadInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); - } - - // Prune the FedPlans to remove redundant plans - lOutFedPlanVariants.pruneFedPlans(); - fOutFedPlanVariants.pruneFedPlans(); - - // Add the FedPlanVariants to the memo table - memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, lOutFedPlanVariants); - memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, fOutFedPlanVariants); - } - - /** - * Enumerates federated execution plans for initial child hops only. - * This method generates all possible combinations of federated output types (LOUT and FOUT) - * for the initial child hops and calculates their cumulative costs. - * - * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. 
- * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. - * @param numInitInputs The number of initial input hops. - * @param childHops The list of child hops. - * @param childCumulativeCost The cumulative costs for each child hop. - * @param childForwardingCost The forwarding costs for each child hop. - * @param selfCost The self cost of the current hop. - */ - private static void enumerateOnlyInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs, List childHops, - double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ - // Iterate 2^n times, generating two FedPlans (LOUT, FOUT) each time. - for (int i = 0; i < (1 << numInitInputs); i++) { - double[] cumulativeCost = new double[]{selfCost, selfCost}; - List> planChilds = new ArrayList<>(); - // LOUT and FOUT share the same planChilds in each iteration (only forwarding cost differs). - enumerateInitChildFedPlan(numInitInputs, childHops, planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); - - lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds)); - fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds)); - } - } - - /** - * Enumerates federated execution plans for a TRead hop. - * This method calculates the cumulative costs for both LOUT and FOUT federated output types - * by considering the additional child hops, which are TWrite hops. - * It generates all possible combinations of federated output types for the initial child hops - * and adds the pre-calculated costs of the TWrite child hops to these combinations. - * - * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. - * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. - * @param numInitInputs The number of initial input hops. 
- * @param numInputs The total number of input hops, including additional TWrite hops. - * @param childHops The list of child hops. - * @param childCumulativeCost The cumulative costs for each child hop. - * @param childForwardingCost The forwarding costs for each child hop. - * @param selfCost The self cost of the current hop. - */ - private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, - int numInitInputs, int numInputs, List childHops, - double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ - double lOutTReadCumulativeCost = selfCost; - double fOutTReadCumulativeCost = selfCost; - - List> lOutTReadPlanChilds = new ArrayList<>(); - List> fOutTReadPlanChilds = new ArrayList<>(); - - // Pre-calculate the cost for the additional child hop, which is a TWrite hop, of the TRead hop. - // Constraint: TWrite must have the same FedOutType as TRead. - for (int j = numInitInputs; j < numInputs; j++) { - Hop inputHop = childHops.get(j); - lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT)); - fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT)); - - lOutTReadCumulativeCost += childCumulativeCost[j][0]; - fOutTReadCumulativeCost += childCumulativeCost[j][1]; - // Skip TWrite -> TRead as they have the same FedOutType. - } - - for (int i = 0; i < (1 << numInitInputs); i++) { - double[] cumulativeCost = new double[]{selfCost, selfCost}; - List> lOutPlanChilds = new ArrayList<>(); - enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); - - // Copy lOutPlanChilds to create fOutPlanChilds and add the pre-calculated cost of the TWrite child hop. 
- List> fOutPlanChilds = new ArrayList<>(lOutPlanChilds); - - lOutPlanChilds.addAll(lOutTReadPlanChilds); - fOutPlanChilds.addAll(fOutTReadPlanChilds); - - cumulativeCost[0] += lOutTReadCumulativeCost; - cumulativeCost[1] += fOutTReadCumulativeCost; - - lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds)); - fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds)); - } - } - - // Calculates costs for initial child hops, determining FOUT or LOUT based on `i`. - private static void enumerateInitChildFedPlan(int numInitInputs, List childHops, List> planChilds, - double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){ - // For each input, determine if it should be FOUT or LOUT based on bit j in i - for (int j = 0; j < numInitInputs; j++) { - Hop inputHop = childHops.get(j); - // Calculate the bit value to decide between FOUT and LOUT for the current input - final int bit = (i & (1 << j)) != 0 ? 1 : 0; // Determine the bit value (decides FOUT/LOUT) - final FederatedOutput childType = (bit == 1) ? FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - - // Update the cumulative cost for LOUT, FOUT - cumulativeCost[0] += childCumulativeCost[j][bit] + childForwardingCost[j] * bit; - cumulativeCost[1] += childCumulativeCost[j][bit] + childForwardingCost[j] * (1 - bit); - } - } - - // Creates a dummy root node (fedplan) and selects the FedPlan with the minimum cost to return. - // The dummy root node does not have LOUT or FOUT. 
- private static FedPlan getMinCostRootFedPlan(Set progRootHopSet, FederatedMemoTable memoTable) { - double cumulativeCost = 0; - List> rootFedPlanChilds = new ArrayList<>(); - - // Iterate over each Hop in the progRootHopSet - for (Hop endHop : progRootHopSet){ - // Retrieve the pruned FedPlan for LOUT and FOUT from the memo table - FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.LOUT); - FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.FOUT); - - // Compare the cumulative costs of LOUT and FOUT FedPlans - if (lOutFedPlan.getCumulativeCost() <= fOutFedPlan.getCumulativeCost()){ - cumulativeCost += lOutFedPlan.getCumulativeCost(); - rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.LOUT)); - } else{ - cumulativeCost += fOutFedPlan.getCumulativeCost(); - rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.FOUT)); - } - } - - return new FedPlan(cumulativeCost, null, rootFedPlanChilds); - } - - /** - * Detects and resolves conflicts in federated plans starting from the root plan. - * This function performs a breadth-first search (BFS) to traverse the federated plan tree. - * It identifies conflicts where the same plan ID has different federated output types. - * For each conflict, it records the plan ID and its conflicting parent plans. - * The function ensures that each plan ID is associated with a consistent federated output type - * by resolving these conflicts iteratively. - * - * The process involves: - * - Using a map to track conflicts, associating each plan ID with its federated output type - * and a list of parent plans. - * - Storing detected conflicts in a linked map, each entry containing a plan ID and its - * conflicting parent plans. - * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. 
- * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan - * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. - * - Resolving conflicts by ensuring a consistent federated output type across the plan. - * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. - * - * @param rootPlan The root federated plan from which to start the conflict detection. - * @param memoTable The memoization table used to retrieve pruned federated plans. - * @return The cumulative additional cost for resolving conflicts. - */ - private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { - // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans - Map>> conflictCheckMap = new HashMap<>(); - - // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans - LinkedHashMap> conflictLinkedMap = new LinkedHashMap<>(); - - // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) - LinkedHashMap bfsLinkedMap = new LinkedHashMap<>(); - bfsLinkedMap.put(rootPlan, true); - - // Array to store cumulative additional cost for resolving conflicts - double[] cumulativeAdditionalCost = new double[]{0.0}; - - while (!bfsLinkedMap.isEmpty()) { - // Perform BFS to detect conflicts in federated plans - while (!bfsLinkedMap.isEmpty()) { - FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); - bfsLinkedMap.remove(currentPlan); - - // Iterate over each child plan of the current plan - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); - - // Check if the child plan ID is already visited - if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { - // Retrieve the existing conflict pair for the child plan - Pair> conflictChildPlanPair = 
conflictCheckMap.get(childPlanPair.getLeft()); - // Add the current plan to the list of parent plans - conflictChildPlanPair.getRight().add(currentPlan); - - // If the federated output type differs, a conflict is detected - if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { - // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) - // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children occur - if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { - conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); - bfsLinkedMap.remove(childFedPlan); - } - } - } else { - // If no conflict exists, create a new entry in the conflict check map - List parentFedPlanList = new ArrayList<>(); - parentFedPlanList.add(currentPlan); - - // Map the child plan ID to its output type and list of parent plans - conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); - // Add the child plan to the BFS queue - bfsLinkedMap.put(childFedPlan, true); - } - } - } - // Resolve these conflicts to ensure a consistent federated output type across the plan - // Re-run BFS with resolved conflicts - bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); - conflictLinkedMap.clear(); - } - - // Return the cumulative additional cost for resolving conflicts - return cumulativeAdditionalCost[0]; - } - } - \ No newline at end of file +package org.apache.sysds.hops.fedplanner; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Objects; +import java.util.Queue; +import java.util.LinkedList; + +import org.apache.commons.lang3.tuple.Pair; +import 
org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +/** + * Enumerates and evaluates all possible federated execution plans for a given Hop DAG. + * Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator + * to compute their costs. + */ +public class FederatedPlanCostEnumerator { + /** + * Entry point for federated plan enumeration. This method creates a memo table + * and returns the minimum cost plan for the entire Directed Acyclic Graph (DAG). + * It also resolves conflicts where FedPlans have different FederatedOutput types. + * + * @param rootHop The root Hop node from which to start the plan enumeration. + * @param printTree A boolean flag indicating whether to print the federated plan tree. + * @return The optimal FedPlan with the minimum cost for the entire DAG. 
+ */ + public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) { + // Create new memo table to store all plan variants + FederatedMemoTable memoTable = new FederatedMemoTable(); + + // Recursively enumerate all possible plans + enumerateFederatedPlanCost(rootHop, memoTable); + + // Return the minimum cost plan for the root node + FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); + memoTable.pruneMemoTable(); + + // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types + List>> conflictFedPlanList = detectConflictFedPlan(optimalPlan, memoTable); + + // Resolve these conflicts to ensure a consistent federated output type across the plan + FederatedPlanCostEstimator.resolveConflictFedPlan(optimalPlan, memoTable, conflictFedPlanList); + + // Optionally print the federated plan tree if requested + if (printTree) memoTable.printFedPlanTree(optimalPlan); + + return optimalPlan; + } + + /** + * Recursively enumerates all possible federated execution plans for a Hop DAG. + * For each node: + * 1. First processes all input nodes recursively if not already processed + * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs + * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination + * + * The enumeration uses a bottom-up approach where: + * - Each input combination is represented by a binary number (i) + * - Bit j in i determines whether input j is FOUT (1) or LOUT (0) + * - Total number of combinations is 2^numInputs + * + * @param hop ? + * @param memoTable ? 
+ */ + private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { + int numInputs = hop.getInput().size(); + + // Process all input nodes first if not already in memo table + for (Hop inputHop : hop.getInput()) { + if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) + && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { + enumerateFederatedPlanCost(inputHop, memoTable); + } + } + + // Generate all possible input combinations using binary representation + // i represents a specific combination of FOUT/LOUT for inputs + for (int i = 0; i < (1 << numInputs); i++) { + List> planChilds = new ArrayList<>(); + + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInputs; j++) { + Hop inputHop = hop.getInput().get(j); + // If bit j is set (1), use FOUT; otherwise use LOUT + FederatedOutput childType = ((i & (1 << j)) != 0) ? + FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + } + + // Create and evaluate FOUT variant for current input combination + FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); + + // Create and evaluate LOUT variant for current input combination + FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); + } + } + + /** + * Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants. + * Used to select the final execution plan after enumeration. + * + * @param HopID ? + * @param memoTable ? + * @return ? 
+ */ + private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { + FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); + FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); + + FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() + < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { + return minFOutFedPlan; + } + return minlOutFedPlan; + } + + /** + * Detects conflicts in federated plans starting from the root plan. + * This function performs a breadth-first search (BFS) to traverse the federated plan tree. + * It identifies conflicts where the same plan ID has different federated output types + * and returns a list of such conflicts, each represented by a plan ID and its conflicting parent plans. + * + * @param rootPlan The root federated plan from which to start the conflict detection. + * @param memoTable The memoization table used to retrieve pruned federated plans. + * @return A list of pairs, each containing a plan ID and a list of parent plans that have conflicting federated outputs. 
+ */ + private static List>> detectConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { + // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans + Map>> conflictCheckMap = new HashMap<>(); + // List to store detected conflicts, each with a plan ID and its conflicting parent plans + List>> conflictFedPlanList = new ArrayList<>(); + + // Queue for BFS traversal starting from the root plan + Queue bfsQueue = new LinkedList<>(); + bfsQueue.add(rootPlan); + + // Perform BFS to detect conflicts in federated plans + while (!bfsQueue.isEmpty()) { + FedPlan currentPlan = bfsQueue.poll(); + + // Iterate over each child plan of the current plan + for (Pair childPlanPair : currentPlan.getChildFedPlans()) { + FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair.getLeft(), childPlanPair.getRight()); + + // Check if the child plan ID is already in the conflict check map + if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { + // Retrieve the existing conflict pair for the child plan + Pair> conflictFedPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); + // Add the current plan to the list of parent plans + conflictFedPlanPair.getRight().add(currentPlan); + + // If the federated output type differs, a conflict is detected + if (conflictFedPlanPair.getLeft() != childPlanPair.getRight()) { + // Add the conflict to the conflict list + conflictFedPlanList.add(new ImmutablePair<>(childPlanPair.getLeft(), conflictFedPlanPair.getRight())); + // Add the child plan to the BFS queue for further exploration + // Todo: Unsure whether to skip or continue traversal when encountering the same Hop ID with different FederatedOutput types + bfsQueue.add(childFedPlan); + } + } else { + // If no conflict exists, create a new entry in the conflict check map + List parentFedPlanList = new ArrayList<>(); + parentFedPlanList.add(currentPlan); + + // Map the child plan ID to its output type and list of parent plans + 
conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); + // Add the child plan to the BFS queue + bfsQueue.add(childFedPlan); + } + } + } + + // Return the list of detected conflicts + return conflictFedPlanList; + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index 55b1c9daa15..be59bb6fda7 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -17,249 +17,214 @@ * under the License. */ - package org.apache.sysds.hops.fedplanner; - import org.apache.commons.lang3.tuple.Pair; - import org.apache.sysds.common.Types; - import org.apache.sysds.hops.DataOp; - import org.apache.sysds.hops.Hop; - import org.apache.sysds.hops.cost.ComputeCost; - import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; - import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; - import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - - import java.util.LinkedHashMap; - import java.util.NoSuchElementException; - import java.util.List; - import java.util.Map; - - /** - * Cost estimator for federated execution plans. - * Calculates computation, memory access, and network transfer costs for federated operations. - * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. 
- */ - public class FederatedPlanCostEstimator { - // Default value is used as a reasonable estimate since we only need - // to compare relative costs between different federated plans - // Memory bandwidth for local computations (25 GB/s) - private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; - // Network bandwidth for data transfers between federated sites (1 Gbps) - private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; - - // Retrieves the cumulative and forwarding costs of the child hops and stores them in arrays - public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List inputHops, - double[][] childCumulativeCost, double[] childForwardingCost) { - for (int i = 0; i < inputHops.size(); i++) { - long childHopID = inputHops.get(i).getHopID(); - - FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); - FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); - - // The cumulative cost of the child already includes the weight - childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); - childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); - - // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements? - childForwardingCost[i] = hopCommon.weight * childLOutFedPlan.getForwardingCost(); - } - } - - /** - * Computes the cost associated with a given Hop node. - * This method calculates both the self cost and the forwarding cost for the Hop, - * taking into account its type and the number of parent nodes. - * - * @param hopCommon The HopCommon object containing the Hop and its properties. - * @return The self cost of the Hop. 
- */ - public static double computeHopCost(HopCommon hopCommon){ - // TWrite and TRead are meta-data operations, hence selfCost is zero - if (hopCommon.hopRef instanceof DataOp){ - if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTWRITE ){ - hopCommon.setSelfCost(0); - // Since TWrite and TRead have the same FedOutType, forwarding cost is zero - hopCommon.setForwardingCost(0); - return 0; - } else if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTREAD) { - hopCommon.setSelfCost(0); - // TRead may have a different FedOutType from its parent, so calculate forwarding cost - // TODO: Uncertain about the number of TWrites - hopCommon.setForwardingCost(computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate())); - return 0; - } - } - - // In loops, selfCost is repeated, but forwarding may not be - // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?) - double selfCost = hopCommon.weight * computeSelfCost(hopCommon.hopRef); - double forwardingCost = computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate()); - - int numParents = hopCommon.hopRef.getParent().size(); - if (numParents >= 2) { - selfCost /= numParents; - forwardingCost /= numParents; - } - - hopCommon.setSelfCost(selfCost); - hopCommon.setForwardingCost(forwardingCost); - - return selfCost; - } - - /** - * Computes the cost for the current Hop node. - * - * @param currentHop The Hop node whose cost needs to be computed - * @return The total cost for the current node's operation - */ - private static double computeSelfCost(Hop currentHop){ - double computeCost = ComputeCost.getHOPComputeCost(currentHop); - double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); - double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); - - // Compute total cost assuming: - // 1. Computation and input access can be overlapped (hence taking max) - // 2. 
Output access must wait for both to complete (hence adding) - return Math.max(computeCost, inputAccessCost) + ouputAccessCost; - } - - /** - * Calculates the memory access cost based on data size and memory bandwidth. - * - * @param memSize Size of data to be accessed (in bytes) - * @return Time cost for memory access (in seconds) - */ - private static double computeHopMemoryAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; - } - - /** - * Calculates the network transfer cost based on data size and network bandwidth. - * Used when federation status changes between parent and child plans. - * - * @param memSize Size of data to be transferred (in bytes) - * @return Time cost for network transfer (in seconds) - */ - private static double computeHopForwardingCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; - } - - /** - * Resolves conflicts in federated plans where different plans have different FederatedOutput types. - * This function traverses the list of conflicting plans in reverse order to ensure that conflicts - * are resolved from the bottom-up, allowing for consistent federated output types across the plan. - * It calculates additional costs for each potential resolution and updates the cumulative additional cost. - * - * @param memoTable The FederatedMemoTable containing all federated plan variants. - * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. - * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. - * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. - */ - public static LinkedHashMap resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { - // LinkedHashMap to store resolved federated plans for BFS traversal. 
- LinkedHashMap resolvedFedPlanLinkedMap = new LinkedHashMap<>(); - - // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts - for (Map.Entry> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { - long conflictHopID = conflictFedPlanPair.getKey(); - List conflictParentFedPlans = conflictFedPlanPair.getValue(); - - // Retrieve the conflicting federated plans for LOUT and FOUT types - FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); - FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); - - // Variables to store additional costs for LOUT and FOUT types - double lOutAdditionalCost = 0; - double fOutAdditionalCost = 0; - - // Flags to check if the plan involves network transfer - // Network transfer cost is calculated only once, even if it occurs multiple times - boolean isLOutForwarding = false; - boolean isFOutForwarding = false; - - // Determine the optimal federated output type based on the calculated costs - FederatedOutput optimalFedOutType; - - // Iterate over each parent federated plan in the current conflict pair - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - // Find the calculated FedOutType of the child plan - Pair cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() - .filter(pair -> pair.getLeft().equals(conflictHopID)) - .findFirst() - .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); - - // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. - // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. - // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. - // CASE 5. 
Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. - // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. - // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. - - // Adjust LOUT, FOUT costs based on the calculated plan's output type - if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { - // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost - // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. - fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { - // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred - // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later - isFOutForwarding = true; - } else { - // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred - // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later - isLOutForwarding = true; - lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - - // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - } - } else { - lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { - isLOutForwarding = true; - } else { - isFOutForwarding = true; - lOutAdditionalCost -= 
confilctLOutFedPlan.getForwardingCost(); - fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - } - } - } - - // Add network transfer costs if applicable - if (isLOutForwarding) { - lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); - } - if (isFOutForwarding) { - fOutAdditionalCost += confilctFOutFedPlan.getForwardingCost(); - } - - // Determine the optimal federated output type based on the calculated costs - if (lOutAdditionalCost <= fOutAdditionalCost) { - optimalFedOutType = FederatedOutput.LOUT; - cumulativeAdditionalCost[0] += lOutAdditionalCost; - resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); - } else { - optimalFedOutType = FederatedOutput.FOUT; - cumulativeAdditionalCost[0] += fOutAdditionalCost; - resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); - } - - // Update only the optimal federated output type, not the cost itself or recursively - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { - if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { - int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); - conflictParentFedPlan.getChildFedPlans().set(index, - Pair.of(childPlanPair.getLeft(), optimalFedOutType)); - break; - } - } - } - } - return resolvedFedPlanLinkedMap; - } - } - \ No newline at end of file +package org.apache.sysds.hops.fedplanner; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.cost.ComputeCost; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import java.util.NoSuchElementException; +import java.util.List; + +/** + * Cost estimator for federated execution plans. + * Calculates computation, memory access, and network transfer costs for federated operations. 
+ * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. + */ +public class FederatedPlanCostEstimator { + // Default value is used as a reasonable estimate since we only need + // to compare relative costs between different federated plans + // Memory bandwidth for local computations (25 GB/s) + private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; + // Network bandwidth for data transfers between federated sites (1 Gbps) + private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + + /** + * Computes total cost of federated plan by: + * 1. Computing current node cost (if not cached) + * 2. Adding minimum-cost child plans + * 3. Including network transfer costs when needed + * + * @param currentPlan Plan to compute cost for + * @param memoTable Table containing all plan variants + */ + public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { + double totalCost; + Hop currentHop = currentPlan.getHopRef(); + + // Step 1: Calculate current node costs if not already computed + if (currentPlan.getSelfCost() == 0) { + // Compute cost for current node (computation + memory access) + totalCost = computeCurrentCost(currentHop); + currentPlan.setSelfCost(totalCost); + // Calculate potential network transfer cost if federation type changes + currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); + } else { + totalCost = currentPlan.getSelfCost(); + } + + // Step 2: Process each child plan and add their costs + for (Pair planRefMeta : currentPlan.getChildFedPlans()) { + // Find minimum cost child plan considering federation type compatibility + // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents + // because we're selecting child plans independently for each parent + FedPlan planRef = memoTable.getMinCostFedPlan(planRefMeta.getLeft(), planRefMeta.getRight()); + + // Add child plan cost 
(includes network transfer cost if federation types differ) + totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); + } + + // Step 3: Set final cumulative cost including current node + currentPlan.setTotalCost(totalCost); + } + + /** + * Resolves conflicts in federated plans where different plans have different FederatedOutput types. + * This function traverses the list of conflicting plans in reverse order to ensure that conflicts + * are resolved from the bottom-up, allowing for consistent federated output types across the plan. + * + * @param currentPlan The current FedPlan being evaluated for conflicts. + * @param memoTable The FederatedMemoTable containing all federated plan variants. + * @param conflictFedPlanList A list of pairs, each containing a plan ID and a list of parent plans + * that have conflicting federated outputs. + */ + public static void resolveConflictFedPlan(FedPlan currentPlan, FederatedMemoTable memoTable, List>> conflictFedPlanList) { + // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts + for (int i = conflictFedPlanList.size() - 1; i >= 0; i--) { + Pair> conflictFedPlanPair = conflictFedPlanList.get(i); + + // Retrieve the conflicting federated plans for LOUT and FOUT types + FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictFedPlanPair.getLeft(), FederatedOutput.LOUT); + FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictFedPlanPair.getLeft(), FederatedOutput.FOUT); + + double lOutCost = 0; + double fOutCost = 0; + + // Flags to check if the plan involves network transfer + // Network transfer cost is calculated only once, even if it occurs multiple times + boolean isLOutNetTransfer = false; + boolean isFOutNetTransfer = false; + + FederatedOutput optimalFedOutType; + + // Iterate over each parent federated plan in the current conflict pair + for (FedPlan conflictParentFedPlan : conflictFedPlanPair.getValue()) { + // Find 
the calculated FedOutType of the child plan + Pair cacluatedCurrentPlan = conflictParentFedPlan.getChildFedPlans().stream() + .filter(pair -> pair.getLeft().equals(currentPlan.getHopID())) + .findFirst() + .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + currentPlan.getHopID())); + + // Accumulate the total costs for both LOUT and FOUT + // Total cost includes compute and memory access, but not network transfer cost + lOutCost += conflictParentFedPlan.getTotalCost(); + fOutCost += conflictParentFedPlan.getTotalCost(); + + // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. + // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. + // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. + // CASE 5. Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. + // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. + // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. + + // Adjust LOUT, FOUT costs based on the calculated plan's output type + if (cacluatedCurrentPlan.getRight() == FederatedOutput.LOUT) { + // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost + // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
+ fOutCost -= confilctLOutFedPlan.getTotalCost(); + fOutCost += confilctFOutFedPlan.getTotalCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { + // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred + // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later + isFOutNetTransfer = true; + } else { + // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred + // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later + isLOutNetTransfer = true; + lOutCost -= confilctLOutFedPlan.getNetTransferCost(); + // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it + fOutCost -= confilctLOutFedPlan.getNetTransferCost(); + } + } else { + lOutCost -= confilctFOutFedPlan.getTotalCost(); + lOutCost += confilctLOutFedPlan.getTotalCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { + isLOutNetTransfer = true; + } else { + isFOutNetTransfer = true; + lOutCost -= confilctLOutFedPlan.getNetTransferCost(); + fOutCost -= confilctLOutFedPlan.getNetTransferCost(); + } + } + } + + // Add network transfer costs if applicable + if (isLOutNetTransfer) { + lOutCost += confilctLOutFedPlan.getNetTransferCost(); + } + if (isFOutNetTransfer) { + fOutCost += confilctFOutFedPlan.getNetTransferCost(); + } + + // Determine the optimal federated output type based on the calculated costs + if (lOutCost < fOutCost) { + optimalFedOutType = FederatedOutput.LOUT; + } else { + optimalFedOutType = FederatedOutput.FOUT; + } + + // Update only the optimal federated output type, not the cost itself or recursively + for (FedPlan conflictParentFedPlan : conflictFedPlanPair.getValue()) { + for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { + if (childPlanPair.getLeft() == 
currentPlan.getHopID() && childPlanPair.getRight() != optimalFedOutType) { + int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); + conflictParentFedPlan.getChildFedPlans().set(index, + Pair.of(childPlanPair.getLeft(), optimalFedOutType)); + } + } + } + } + } + + /** + * Computes the cost for the current Hop node. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeCurrentCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); + + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. + * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. 
+ * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopNetworkAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index d23f7ebcf92..1d0740fbc04 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -17,91 +17,72 @@ * under the License. */ - package org.apache.sysds.test.component.federated; +package org.apache.sysds.test.component.federated; - import java.io.IOException; - import java.util.HashMap; - import org.junit.Assert; - import org.junit.Test; - import org.apache.sysds.api.DMLScript; - import org.apache.sysds.conf.ConfigurationManager; - import org.apache.sysds.conf.DMLConfig; - import org.apache.sysds.parser.DMLProgram; - import org.apache.sysds.parser.DMLTranslator; - import org.apache.sysds.parser.ParserFactory; - import org.apache.sysds.parser.ParserWrapper; - import org.apache.sysds.test.AutomatedTestBase; - import org.apache.sysds.test.TestConfiguration; - import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; - - public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase - { - private static final String TEST_DIR = "functions/federated/privacy/"; - private static final String HOME = SCRIPT_DIR + TEST_DIR; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; - - @Override - public void setUp() {} - - @Test - public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } - - @Test - public void testFederatedPlanCostEnumerator2() 
{ runTest("FederatedPlanCostEnumeratorTest2.dml"); } - - @Test - public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } - - @Test - public void testFederatedPlanCostEnumerator4() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } - - @Test - public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); } - - @Test - public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); } - - @Test - public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } - - @Test - public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest8.dml"); } - - @Test - public void testFederatedPlanCostEnumerator9() { runTest("FederatedPlanCostEnumeratorTest9.dml"); } - - // Todo: Need to write test scripts for the federated version - private void runTest( String scriptFilename ) { - int index = scriptFilename.lastIndexOf(".dml"); - String testName = scriptFilename.substring(0, index > 0 ? 
index : scriptFilename.length()); - TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); - addTestConfiguration(testName, testConfig); - loadTestConfiguration(testConfig); - - try { - DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); - ConfigurationManager.setLocalConfig(conf); - - //read script - String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); - - //parsing and dependency analysis - ParserWrapper parser = ParserFactory.createParser(); - DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - dmlt.constructHops(prog); - dmlt.rewriteHopsDAG(prog); - dmlt.constructLops(prog); - dmlt.rewriteLopDAG(prog); - - FederatedPlanCostEnumerator.enumerateProgram(prog, true); - } - catch (IOException e) { - e.printStackTrace(); - Assert.fail(); - } - } - } - \ No newline at end of file +import java.io.IOException; +import java.util.HashMap; + +import org.apache.sysds.hops.Hop; +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DMLTranslator; +import org.apache.sysds.parser.ParserFactory; +import org.apache.sysds.parser.ParserWrapper; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + + +public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase +{ + private static final String TEST_DIR = "functions/federated/privacy/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; + + 
@Override + public void setUp() {} + + @Test + public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } + + @Test + public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } + + // Todo: Need to write test scripts for the federated version + private void runTest( String scriptFilename ) { + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + try { + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + //read script + String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); + + //parsing and dependency analysis + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); + + Hop hops = prog.getStatementBlocks().get(0).getHops().get(0); + FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true); + } + catch (IOException e) { + e.printStackTrace(); + Assert.fail(); + } + } +} From 0c0690aae9932d6bf808e3bfb2a29919236e1252 Mon Sep 17 00:00:00 2001 From: min-guk Date: Sun, 12 Jan 2025 04:53:25 +0900 Subject: [PATCH 3/9] Update detectConflictFedPlan, resolveConflictFedPlan, and MemoTablePrinter --- .../hops/fedplanner/FederatedMemoTable.java | 158 ++++-------------- .../fedplanner/FederatedMemoTablePrinter.java | 133 ++++----------- .../FederatedPlanCostEnumerator.java | 130 ++++++++------ .../FederatedPlanCostEstimator.java | 88 
+++++----- 4 files changed, 190 insertions(+), 319 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index a18376e188e..c84d697a8e6 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -70,13 +70,9 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List(hopID, fedOutType)); + public FedPlan getMinCostFedPlan(Pair fedPlanPair) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); return fedPlanVariantList._fedPlanVariants.stream() .min(Comparator.comparingDouble(FedPlan::getTotalCost)) .orElse(null); @@ -86,12 +82,22 @@ public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); } + public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { + return hopMemoTable.get(fedPlanPair); + } + public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); return fedPlanVariantList._fedPlanVariants.get(0); } + public FedPlan getFedPlanAfterPrune(Pair fedPlanPair) { + // Todo: Consider whether to verify if pruning has been performed + FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); + return fedPlanVariantList._fedPlanVariants.get(0); + } + /** * Checks if the memo table contains an entry for a given Hop and fedOutType. * @@ -104,128 +110,14 @@ public boolean contains(long hopID, FederatedOutput fedOutType) { } /** - * Prunes all entries in the memo table, retaining only the minimum-cost - * FedPlan for each entry. 
- */ - public void pruneMemoTable() { - for (Map.Entry, FedPlanVariants> entry : hopMemoTable.entrySet()) { - List fedPlanList = entry.getValue().getFedPlanVariants(); - if (fedPlanList.size() > 1) { - // Find the FedPlan with the minimum cost - FedPlan minCostPlan = fedPlanList.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - - // Retain only the minimum cost plan - fedPlanList.clear(); - fedPlanList.add(minCostPlan); - } - } - } - - // Todo: Separate print functions from FederatedMemoTable - /** - * Recursively prints a tree representation of the DAG starting from the given root FedPlan. - * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node. + * Prunes the specified entry in the memo table, retaining only the minimum-cost + * FedPlan for the given Hop ID and federated output type. * - * @param rootFedPlan The starting point FedPlan to print + * @param hopID The ID of the Hop to prune + * @param federatedOutput The federated output type associated with the Hop */ - public void printFedPlanTree(FedPlan rootFedPlan) { - Set visited = new HashSet<>(); - printFedPlanTreeRecursive(rootFedPlan, visited, 0, true); - } - - /** - * Helper method to recursively print the FedPlan tree. 
- * - * @param plan The current FedPlan to print - * @param visited Set to keep track of visited FedPlans (prevents cycles) - * @param depth The current depth level for indentation - * @param isLast Whether this node is the last child of its parent - */ - private void printFedPlanTreeRecursive(FedPlan plan, Set visited, int depth, boolean isLast) { - if (plan == null || visited.contains(plan)) { - return; - } - - visited.add(plan); - - Hop hop = plan.getHopRef(); - StringBuilder sb = new StringBuilder(); - - // Add FedPlan information - sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) - .append(plan.getHopRef().getOpString()) - .append(" [") - .append(plan.getFedOutType()) - .append("]"); - - StringBuilder childs = new StringBuilder(); - childs.append(" ("); - boolean childAdded = false; - for( Hop input : hop.getInput()){ - childs.append(childAdded?",":""); - childs.append(input.getHopID()); - childAdded = true; - } - childs.append(")"); - if( childAdded ) - sb.append(childs.toString()); - - - sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", - plan.getTotalCost(), - plan.getSelfCost(), - plan.getNetTransferCost())); - - // Add matrix characteristics - sb.append(" [") - .append(hop.getDim1()).append(", ") - .append(hop.getDim2()).append(", ") - .append(hop.getBlocksize()).append(", ") - .append(hop.getNnz()); - - if (hop.getUpdateType().isInPlace()) { - sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); - } - sb.append("]"); - - // Add memory estimates - sb.append(" [") - .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") - .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") - .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") - .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); - - // Add reblock and checkpoint requirements - if (hop.requiresReblock() && hop.requiresCheckpoint()) { - sb.append(" [rblk, chkpt]"); - } else if 
(hop.requiresReblock()) { - sb.append(" [rblk]"); - } else if (hop.requiresCheckpoint()) { - sb.append(" [chkpt]"); - } - - // Add execution type - if (hop.getExecType() != null) { - sb.append(", ").append(hop.getExecType()); - } - - System.out.println(sb); - - // Process child nodes - List> childRefs = plan.getChildFedPlans(); - for (int i = 0; i < childRefs.size(); i++) { - Pair childRef = childRefs.get(i); - FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight()); - if (childVariants == null || childVariants.getFedPlanVariants().isEmpty()) - continue; - - boolean isLastChild = (i == childRefs.size() - 1); - for (FedPlan childPlan : childVariants.getFedPlanVariants()) { - printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild); - } - } + public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) { + hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune(); } /** @@ -262,6 +154,20 @@ public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} public List getFedPlanVariants() {return _fedPlanVariants;} + public boolean isEmpty() {return _fedPlanVariants.isEmpty();} + + public void prune() { + if (_fedPlanVariants.size() > 1) { + // Find the FedPlan with the minimum cost + FedPlan minCostPlan = _fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + // Retain only the minimum cost plan + _fedPlanVariants.clear(); + _fedPlanVariants.add(minCostPlan); + } + } } /** diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index ddddc641d2e..22d7f083c45 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -3,9 +3,7 @@ import 
org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.runtime.instructions.fed.FEDInstruction; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.HashSet; import java.util.List; @@ -21,48 +19,11 @@ public class FederatedMemoTablePrinter { * @param memoTable The memoization table containing FedPlan variants * @param additionalTotalCost The additional cost to be printed once */ - public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Set rootHopStatSet, - FederatedMemoTable memoTable, double additionalTotalCost) { + public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, FederatedMemoTable memoTable, + double additionalTotalCost) { System.out.println("Additional Cost: " + additionalTotalCost); - Set visited = new HashSet<>(); + Set visited = new HashSet<>(); printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0); - - for (Hop hop : rootHopStatSet) { - FedPlan plan = memoTable.getFedPlanAfterPrune(hop.getHopID(), FederatedOutput.LOUT); - printNotReferencedFedPlanRecursive(plan, memoTable, visited, 1); - } - } - - /** - * Helper method to recursively print the FedPlan tree. 
- * - * @param plan The current FedPlan to print - * @param visited Set to keep track of visited FedPlans (prevents cycles) - * @param depth The current depth level for indentation - */ - private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, - Set visited, int depth) { - long hopID = plan.getHopRef().getHopID(); - - if (visited.contains(hopID)) { - return; - } - - visited.add(hopID); - printFedPlan(plan, depth, true); - - // Process child nodes - List> childFedPlanPairs = plan.getChildFedPlans(); - for (int i = 0; i < childFedPlanPairs.size(); i++) { - Pair childFedPlanPair = childFedPlanPairs.get(i); - FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); - if (childVariants == null || childVariants.isEmpty()) - continue; - - for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { - printNotReferencedFedPlanRecursive(childPlan, memoTable, visited, depth + 1); - } - } } /** @@ -73,83 +34,40 @@ private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPla * @param depth The current depth level for indentation */ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, - Set visited, int depth) { - long hopID = 0; - - if (depth == 0) { - hopID = -1; - } else { - hopID = plan.getHopRef().getHopID(); - } - - if (visited.contains(hopID)) { + Set visited, int depth) { + if (plan == null || visited.contains(plan)) { return; } - visited.add(hopID); - printFedPlan(plan, depth, false); - - // Process child nodes - List> childFedPlanPairs = plan.getChildFedPlans(); - for (int i = 0; i < childFedPlanPairs.size(); i++) { - Pair childFedPlanPair = childFedPlanPairs.get(i); - FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); - if (childVariants == null || childVariants.isEmpty()) - continue; - - for (FederatedMemoTable.FedPlan 
childPlan : childVariants.getFedPlanVariants()) { - printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); - } - } - } + visited.add(plan); - private static void printFedPlan(FederatedMemoTable.FedPlan plan, int depth, boolean isNotReferenced) { + Hop hop = plan.getHopRef(); StringBuilder sb = new StringBuilder(); - Hop hop = null; - - if (depth == 0){ - sb.append("(R) ROOT [Root]"); - } else { - hop = plan.getHopRef(); - // Add FedPlan information - sb.append(String.format("(%d) ", hop.getHopID())) - .append(hop.getOpString()) - .append(" ["); - - if (isNotReferenced) { - sb.append("NRef"); - } else{ - sb.append(plan.getFedOutType()); - } - sb.append("]"); - } + + // Add FedPlan information + sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) + .append(plan.getHopRef().getOpString()) + .append(" [") + .append(plan.getFedOutType()) + .append("]"); StringBuilder childs = new StringBuilder(); childs.append(" ("); - boolean childAdded = false; - for (Pair childPair : plan.getChildFedPlans()){ + for( Hop input : hop.getInput()){ childs.append(childAdded?",":""); - childs.append(childPair.getLeft()); + childs.append(input.getHopID()); childAdded = true; } - childs.append(")"); - if( childAdded ) sb.append(childs.toString()); - if (depth == 0){ - sb.append(String.format(" {Total: %.1f}", plan.getCumulativeCost())); - System.out.println(sb); - return; - } - sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f, Weight: %.1f}", - plan.getCumulativeCost(), + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", + plan.getTotalCost(), plan.getSelfCost(), - plan.getForwardingCost(), - plan.getWeight())); + plan.getNetTransferCost())); // Add matrix characteristics sb.append(" [") @@ -185,5 +103,18 @@ private static void printFedPlan(FederatedMemoTable.FedPlan plan, int depth, boo } System.out.println(sb); + + // Process child nodes + List> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); 
i++) { + Pair childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); + } + } } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index db1583ab2fb..be1cfa7cdf3 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -24,8 +24,7 @@ import java.util.Comparator; import java.util.HashMap; import java.util.Objects; -import java.util.Queue; -import java.util.LinkedList; +import java.util.LinkedHashMap; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -58,16 +57,12 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) // Return the minimum cost plan for the root node FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); - memoTable.pruneMemoTable(); // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - List>> conflictFedPlanList = detectConflictFedPlan(optimalPlan, memoTable); - - // Resolve these conflicts to ensure a consistent federated output type across the plan - FederatedPlanCostEstimator.resolveConflictFedPlan(optimalPlan, memoTable, conflictFedPlanList); + double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); // Optionally print the federated plan tree if requested - if (printTree) memoTable.printFedPlanTree(optimalPlan); + if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, 
additionalTotalCost); return optimalPlan; } @@ -120,6 +115,10 @@ private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoT FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); } + + // Prune MemoTable for hop. + memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT); + memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT); } /** @@ -149,62 +148,87 @@ private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memo } /** - * Detects conflicts in federated plans starting from the root plan. + * Detects and resolves conflicts in federated plans starting from the root plan. * This function performs a breadth-first search (BFS) to traverse the federated plan tree. - * It identifies conflicts where the same plan ID has different federated output types - * and returns a list of such conflicts, each represented by a plan ID and its conflicting parent plans. + * It identifies conflicts where the same plan ID has different federated output types. + * For each conflict, it records the plan ID and its conflicting parent plans. + * The function ensures that each plan ID is associated with a consistent federated output type + * by resolving these conflicts iteratively. + * + * The process involves: + * - Using a map to track conflicts, associating each plan ID with its federated output type + * and a list of parent plans. + * - Storing detected conflicts in a linked map, each entry containing a plan ID and its + * conflicting parent plans. + * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. + * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan + * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. + * - Resolving conflicts by ensuring a consistent federated output type across the plan. 
+ * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. * * @param rootPlan The root federated plan from which to start the conflict detection. * @param memoTable The memoization table used to retrieve pruned federated plans. - * @return A list of pairs, each containing a plan ID and a list of parent plans that have conflicting federated outputs. + * @return The cumulative additional cost for resolving conflicts. */ - private static List>> detectConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { + private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans Map>> conflictCheckMap = new HashMap<>(); - // List to store detected conflicts, each with a plan ID and its conflicting parent plans - List>> conflictFedPlanList = new ArrayList<>(); - - // Queue for BFS traversal starting from the root plan - Queue bfsQueue = new LinkedList<>(); - bfsQueue.add(rootPlan); - - // Perform BFS to detect conflicts in federated plans - while (!bfsQueue.isEmpty()) { - FedPlan currentPlan = bfsQueue.poll(); - - // Iterate over each child plan of the current plan - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair.getLeft(), childPlanPair.getRight()); - - // Check if the child plan ID is already in the conflict check map - if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { - // Retrieve the existing conflict pair for the child plan - Pair> conflictFedPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); - // Add the current plan to the list of parent plans - conflictFedPlanPair.getRight().add(currentPlan); - - // If the federated output type differs, a conflict is detected - if (conflictFedPlanPair.getLeft() != childPlanPair.getRight()) { - // Add the conflict to the conflict list - 
conflictFedPlanList.add(new ImmutablePair<>(childPlanPair.getLeft(), conflictFedPlanPair.getRight())); - // Add the child plan to the BFS queue for further exploration - // Todo: Unsure whether to skip or continue traversal when encountering the same Hop ID with different FederatedOutput types - bfsQueue.add(childFedPlan); + + // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans + LinkedHashMap> conflictLinkedMap = new LinkedHashMap<>(); + + // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) + LinkedHashMap bfsLinkedMap = new LinkedHashMap<>(); + bfsLinkedMap.put(rootPlan, true); + + // Array to store cumulative additional cost for resolving conflicts + double[] cumulativeAdditionalCost = new double[]{0.0}; + + while (!bfsLinkedMap.isEmpty()) { + // Perform BFS to detect conflicts in federated plans + while (!bfsLinkedMap.isEmpty()) { + FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); + bfsLinkedMap.remove(currentPlan); + + // Iterate over each child plan of the current plan + for (Pair childPlanPair : currentPlan.getChildFedPlans()) { + FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); + + // Check if the child plan ID is already visited + if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { + // Retrieve the existing conflict pair for the child plan + Pair> conflictChildPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); + // Add the current plan to the list of parent plans + conflictChildPlanPair.getRight().add(currentPlan); + + // If the federated output type differs, a conflict is detected + if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { + // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) + // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children 
occur + if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { + conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); + bfsLinkedMap.remove(childFedPlan); + } + } + } else { + // If no conflict exists, create a new entry in the conflict check map + List parentFedPlanList = new ArrayList<>(); + parentFedPlanList.add(currentPlan); + + // Map the child plan ID to its output type and list of parent plans + conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); + // Add the child plan to the BFS queue + bfsLinkedMap.put(childFedPlan, true); } - } else { - // If no conflict exists, create a new entry in the conflict check map - List parentFedPlanList = new ArrayList<>(); - parentFedPlanList.add(currentPlan); - - // Map the child plan ID to its output type and list of parent plans - conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); - // Add the child plan to the BFS queue - bfsQueue.add(childFedPlan); } } + // Resolve these conflicts to ensure a consistent federated output type across the plan + // Re-run BFS with resolved conflicts + bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); + conflictLinkedMap.clear(); } - // Return the list of detected conflicts - return conflictFedPlanList; + // Return the cumulative additional cost for resolving conflicts + return cumulativeAdditionalCost[0]; } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index be59bb6fda7..7bc7339563a 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -23,8 +23,11 @@ import org.apache.sysds.hops.cost.ComputeCost; import 
org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +import java.util.LinkedHashMap; import java.util.NoSuchElementException; import java.util.List; +import java.util.Map; /** * Cost estimator for federated execution plans. @@ -64,11 +67,11 @@ public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTa } // Step 2: Process each child plan and add their costs - for (Pair planRefMeta : currentPlan.getChildFedPlans()) { + for (Pair childPlanPair : currentPlan.getChildFedPlans()) { // Find minimum cost child plan considering federation type compatibility // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents // because we're selecting child plans independently for each parent - FedPlan planRef = memoTable.getMinCostFedPlan(planRefMeta.getLeft(), planRefMeta.getRight()); + FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair); // Add child plan cost (includes network transfer cost if federation types differ) totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); @@ -82,44 +85,46 @@ public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTa * Resolves conflicts in federated plans where different plans have different FederatedOutput types. * This function traverses the list of conflicting plans in reverse order to ensure that conflicts * are resolved from the bottom-up, allowing for consistent federated output types across the plan. + * It calculates additional costs for each potential resolution and updates the cumulative additional cost. * - * @param currentPlan The current FedPlan being evaluated for conflicts. * @param memoTable The FederatedMemoTable containing all federated plan variants. - * @param conflictFedPlanList A list of pairs, each containing a plan ID and a list of parent plans - * that have conflicting federated outputs. 
+ * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. + * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. + * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. */ - public static void resolveConflictFedPlan(FedPlan currentPlan, FederatedMemoTable memoTable, List>> conflictFedPlanList) { + public static LinkedHashMap resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { + // LinkedHashMap to store resolved federated plans for BFS traversal. + LinkedHashMap resolvedFedPlanLinkedMap = new LinkedHashMap<>(); + // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts - for (int i = conflictFedPlanList.size() - 1; i >= 0; i--) { - Pair> conflictFedPlanPair = conflictFedPlanList.get(i); - + for (Map.Entry> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { + long conflictHopID = conflictFedPlanPair.getKey(); + List conflictParentFedPlans = conflictFedPlanPair.getValue(); + // Retrieve the conflicting federated plans for LOUT and FOUT types - FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictFedPlanPair.getLeft(), FederatedOutput.LOUT); - FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictFedPlanPair.getLeft(), FederatedOutput.FOUT); + FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); + FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); + + // Variables to store additional costs for LOUT and FOUT types + double lOutAdditionalCost = 0; + double fOutAdditionalCost = 0; - double lOutCost = 0; - double fOutCost = 0; - // Flags to check if the plan involves network transfer // Network transfer cost is calculated only once, even if it occurs multiple times 
boolean isLOutNetTransfer = false; boolean isFOutNetTransfer = false; + // Determine the optimal federated output type based on the calculated costs FederatedOutput optimalFedOutType; - + // Iterate over each parent federated plan in the current conflict pair - for (FedPlan conflictParentFedPlan : conflictFedPlanPair.getValue()) { + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { // Find the calculated FedOutType of the child plan - Pair cacluatedCurrentPlan = conflictParentFedPlan.getChildFedPlans().stream() - .filter(pair -> pair.getLeft().equals(currentPlan.getHopID())) + Pair cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() + .filter(pair -> pair.getLeft().equals(conflictHopID)) .findFirst() - .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + currentPlan.getHopID())); - - // Accumulate the total costs for both LOUT and FOUT - // Total cost includes compute and memory access, but not network transfer cost - lOutCost += conflictParentFedPlan.getTotalCost(); - fOutCost += conflictParentFedPlan.getTotalCost(); - + .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); + // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. @@ -130,11 +135,10 @@ public static void resolveConflictFedPlan(FedPlan currentPlan, FederatedMemoTabl // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. 
// Adjust LOUT, FOUT costs based on the calculated plan's output type - if (cacluatedCurrentPlan.getRight() == FederatedOutput.LOUT) { + if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. - fOutCost -= confilctLOutFedPlan.getTotalCost(); - fOutCost += confilctFOutFedPlan.getTotalCost(); + fOutAdditionalCost += confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred @@ -144,50 +148,56 @@ public static void resolveConflictFedPlan(FedPlan currentPlan, FederatedMemoTabl // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later isLOutNetTransfer = true; - lOutCost -= confilctLOutFedPlan.getNetTransferCost(); + lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); + // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutCost -= confilctLOutFedPlan.getNetTransferCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); } } else { - lOutCost -= confilctFOutFedPlan.getTotalCost(); - lOutCost += confilctLOutFedPlan.getTotalCost(); + lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { isLOutNetTransfer = true; } else { isFOutNetTransfer = true; - lOutCost -= confilctLOutFedPlan.getNetTransferCost(); - fOutCost -= confilctLOutFedPlan.getNetTransferCost(); + lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); + 
fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); } } } // Add network transfer costs if applicable if (isLOutNetTransfer) { - lOutCost += confilctLOutFedPlan.getNetTransferCost(); + lOutAdditionalCost += confilctLOutFedPlan.getNetTransferCost(); } if (isFOutNetTransfer) { - fOutCost += confilctFOutFedPlan.getNetTransferCost(); + fOutAdditionalCost += confilctFOutFedPlan.getNetTransferCost(); } // Determine the optimal federated output type based on the calculated costs - if (lOutCost < fOutCost) { + if (lOutAdditionalCost <= fOutAdditionalCost) { optimalFedOutType = FederatedOutput.LOUT; + cumulativeAdditionalCost[0] += lOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); } else { optimalFedOutType = FederatedOutput.FOUT; + cumulativeAdditionalCost[0] += fOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); } // Update only the optimal federated output type, not the cost itself or recursively - for (FedPlan conflictParentFedPlan : conflictFedPlanPair.getValue()) { + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { - if (childPlanPair.getLeft() == currentPlan.getHopID() && childPlanPair.getRight() != optimalFedOutType) { + if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); conflictParentFedPlan.getChildFedPlans().set(index, Pair.of(childPlanPair.getLeft(), optimalFedOutType)); + break; } } } } + return resolvedFedPlanLinkedMap; } /** From 5094822c1fee9befc998a6ac2a4bfe8dce8f1524 Mon Sep 17 00:00:00 2001 From: min-guk Date: Sun, 12 Jan 2025 05:16:09 +0900 Subject: [PATCH 4/9] Add if-else DML test script --- .../component/federated/FederatedPlanCostEnumeratorTest.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 1d0740fbc04..20485588d32 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -51,7 +51,10 @@ public void setUp() {} @Test public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } - + + @Test + public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } + // Todo: Need to write test scripts for the federated version private void runTest( String scriptFilename ) { int index = scriptFilename.lastIndexOf(".dml"); From b771558a9942a69f30395516f725320306451bdc Mon Sep 17 00:00:00 2001 From: min-guk Date: Mon, 10 Feb 2025 21:30:01 +0900 Subject: [PATCH 5/9] Optimal Planner --- .../hops/fedplanner/FederatedMemoTable.java | 320 ++++++++---- .../fedplanner/FederatedMemoTablePrinter.java | 4 +- .../FederatedPlanCostEnumerator.java | 474 ++++++++++++++++-- .../FederatedPlanCostEstimator.java | 224 +++++++-- 4 files changed, 827 insertions(+), 195 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index c84d697a8e6..196e52b6de1 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -19,18 +19,19 @@ package org.apache.sysds.hops.fedplanner; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.OptimizerUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.Comparator; 
import java.util.HashMap; import java.util.List; import java.util.ArrayList; import java.util.Map; -import java.util.HashSet; -import java.util.Set; +import java.util.Arrays; +import java.util.Collections; +import org.apache.sysds.hops.Hop; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator.ConflictMergeResolveInfo; +import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator.ResolvedType; /** * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. @@ -41,98 +42,196 @@ public class FederatedMemoTable { // Maps Hop ID and fedOutType pairs to their plan variants private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); - /** - * Adds a new federated plan to the memo table. - * Creates a new variant list if none exists for the given Hop and fedOutType. 
- * - * @param hop The Hop node - * @param fedOutType The federated output type - * @param planChilds List of child plan references - * @return The newly created FedPlan - */ - public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { - long hopID = hop.getHopID(); - FedPlanVariants fedPlanVariantList; - - if (contains(hopID, fedOutType)) { - fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - } else { - fedPlanVariantList = new FedPlanVariants(hop, fedOutType); - hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); - } - - FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); - fedPlanVariantList.addFedPlan(newPlan); - - return newPlan; + public void addFedPlanVariants(long hopID, FederatedOutput fedOutType, FedPlanVariants fedPlanVariants) { + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariants); } - /** - * Retrieves the minimum cost child plan considering the parent's output type. - * The cost is calculated using getParentViewCost to account for potential type mismatches. 
- */ - public FedPlan getMinCostFedPlan(Pair fedPlanPair) { - FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); - return fedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); + public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { + return hopMemoTable.get(fedPlanPair); } public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); } - public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { - return hopMemoTable.get(fedPlanPair); - } - public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { - // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); return fedPlanVariantList._fedPlanVariants.get(0); } public FedPlan getFedPlanAfterPrune(Pair fedPlanPair) { - // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); return fedPlanVariantList._fedPlanVariants.get(0); } - /** - * Checks if the memo table contains an entry for a given Hop and fedOutType. - * - * @param hopID The Hop ID. - * @param fedOutType The associated fedOutType. - * @return True if the entry exists, false otherwise. - */ public boolean contains(long hopID, FederatedOutput fedOutType) { return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); } - /** - * Prunes the specified entry in the memo table, retaining only the minimum-cost - * FedPlan for the given Hop ID and federated output type. 
- * - * @param hopID The ID of the Hop to prune - * @param federatedOutput The federated output type associated with the Hop - */ - public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) { - hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune(); - } + public static class ConflictedFedPlanVariants extends FedPlanVariants { + public List conflictInfos; + protected int numConflictCombinations; + // 2^(# of conflicts), 2^(# of childs) + protected double[][] cumulativeCost; + protected int[][] forwardingBitMap; - /** - * Represents common properties and costs associated with a Hop. - * This class holds a reference to the Hop and tracks its execution and network transfer costs. - */ - public static class HopCommon { - protected final Hop hopRef; // Reference to the associated Hop - protected double selfCost; // Current execution cost (compute + memory access) - protected double netTransferCost; // Network transfer cost + // bitset array (java class) >> arbitary length >> - protected HopCommon(Hop hopRef) { - this.hopRef = hopRef; - this.selfCost = 0; - this.netTransferCost = 0; + public ConflictedFedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType, + List conflictMergeResolveInfos) { + super(hopCommon, fedOutType); + this.conflictInfos = conflictMergeResolveInfos; + this.numConflictCombinations = 1 << this.conflictInfos.size(); + this.cumulativeCost = new double[this.numConflictCombinations][this._fedPlanVariants.size()]; + this.forwardingBitMap = new int[this.numConflictCombinations][this._fedPlanVariants.size()]; + // Initialize isForwardBitMap to 0 + for (int i = 0; i < this.numConflictCombinations; i++) { + Arrays.fill(this.cumulativeCost[i], 0); + Arrays.fill(this.forwardingBitMap[i], 0); + } + } + + // Todo: (최적화) java bitset 사용하여, 다수의 conflict 처리할 수 있도록 해야 함. + // Todo: (구현) 만약 resolve point (converge, first-split & last-merge) child로 내려가면서 recursive하게 prune 해야 함. 
(이때, parents의 LOUT/FOUT의 Optimal Plan을 동시에 고려해야함) + public void pruneConflictedFedPlans() { + // Step 1: Initialize prunedCost and prunedIsForwardingBitMap with minimal values per combination + double[][] prunedCost = new double[this.numConflictCombinations][1]; + int[][] prunedIsForwardingBitMap = new int[this.numConflictCombinations][1]; + List prunedFedPlanVariants = new ArrayList<>(); + + for (int i = 0; i < this.numConflictCombinations; i++) { + double minCost = Double.MAX_VALUE; + int minIndex = -1; + for (int j = 0; j < _fedPlanVariants.size(); j++) { + if (cumulativeCost[i][j] < minCost) { + minCost = cumulativeCost[i][j]; + minIndex = j; + } + } + prunedCost[i][0] = minCost; + prunedIsForwardingBitMap[i][0] = (minIndex != -1) ? forwardingBitMap[i][minIndex] : 0; + prunedFedPlanVariants.add(_fedPlanVariants.get(minIndex)); + } + + this.cumulativeCost = prunedCost; + this.forwardingBitMap = prunedIsForwardingBitMap; + this._fedPlanVariants = prunedFedPlanVariants; + + // Step 2: Collect resolved conflict bit positions + List resolvedBits = new ArrayList<>(); + for (int i = 0; i < conflictInfos.size(); i++) { + ConflictMergeResolveInfo info = conflictInfos.get(i); + if (info.getResolvedType() == ResolvedType.RESOLVE) { + resolvedBits.add(i); // Assuming index corresponds to bit position + } + } + + int resolvedBitsSize = resolvedBits.size(); + + // CASE 1: if not resolved, return + if (resolvedBitsSize == 0){ + return; + } + + // CASE 2: if all resolved, transform to FedPlanVariants + if (resolvedBits.size() == conflictInfos.size()){ + double minCost = Double.MAX_VALUE; + int minCostIdx = -1; + + for (int i = 0; i < this.numConflictCombinations; i++) { + if (cumulativeCost[i][0] < minCost) { + minCost = cumulativeCost[i][0]; + minCostIdx = i; + } + } + + FedPlan finalFedPlan = this.getFedPlanVariants().get(minCostIdx); + finalFedPlan.setCumulativeCost(minCost); + this._fedPlanVariants.clear(); + this._fedPlanVariants.add(finalFedPlan); + + this.conflictInfos 
= null; + this.cumulativeCost = null; + this.forwardingBitMap = null; + this.numConflictCombinations = 0; + + return; + } + + // CASE 3: if some resolved, some not, merge them + int mask = 0; + for (int bit : resolvedBits) { + mask |= (1 << bit); + } + mask = ~mask; + + List unresolvedBits = new ArrayList<>(); + for (int bit = 0; bit < conflictInfos.size(); bit++) { + if (!resolvedBits.contains(bit)) { + unresolvedBits.add(bit); + } + } + Collections.sort(unresolvedBits); // Ensure consistent ordering + + // Create newConflictInfos with unresolved conflicts + List newConflictInfos = new ArrayList<>(); + for (int bit : unresolvedBits) { + newConflictInfos.add(conflictInfos.get(bit)); + } + + // Step 4: Group combinations by their base (ignoring resolved bits) + Map> groups = new HashMap<>(); + for (int i = 0; i < this.numConflictCombinations; i++) { + int base = i & mask; + groups.computeIfAbsent(base, k -> new ArrayList<>()).add(i); + } + + // Step 5: Merge groups and create new arrays with reduced size + int newSize = 1 << unresolvedBits.size(); + double[][] newPrunedCost = new double[newSize][1]; + int[][] newPrunedBitMap = new int[newSize][1]; + List newPrunedFedPlanVariants = new ArrayList<>(newSize); + Arrays.fill(newPrunedCost, Double.MAX_VALUE); + + for (Map.Entry> entry : groups.entrySet()) { + int base = entry.getKey(); + List group = entry.getValue(); + + // Find minimal cost and bitmap in the group + double minGroupCost = Double.MAX_VALUE; + int minBitmap = 0; + int minIdx = -1; + + for (int comb : group) { + if (cumulativeCost[comb][0] < minGroupCost) { + minGroupCost = cumulativeCost[comb][0]; + minBitmap = forwardingBitMap[comb][0]; + minIdx = comb; + } + } + + // Compute new index based on unresolved bits + int newIndex = 0; + for (int i = 0; i < unresolvedBits.size(); i++) { + int bitPos = unresolvedBits.get(i); + if ((base & (1 << bitPos)) != 0) { + newIndex |= (1 << i); // Set the i-th bit in newIndex + } + } + + // Update newPruned arrays + if 
(newIndex < newSize) { + newPrunedCost[newIndex][0] = minGroupCost; + newPrunedBitMap[newIndex][0] = minBitmap; + newPrunedFedPlanVariants.add(newIndex, _fedPlanVariants.get(minIdx)); + } + } + + // Replace the pruned arrays with the merged results and update size + this.conflictInfos = newConflictInfos; + this.cumulativeCost = newPrunedCost; + this.forwardingBitMap = newPrunedBitMap; + this.numConflictCombinations = newSize; // Update to the new reduced size } } @@ -146,21 +245,24 @@ public static class FedPlanVariants { private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) protected List _fedPlanVariants; // List of plan variants - public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { - this.hopCommon = new HopCommon(hopRef); + public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) { + this.hopCommon = hopCommon; this.fedOutType = fedOutType; this._fedPlanVariants = new ArrayList<>(); } + public boolean isEmpty() {return _fedPlanVariants.isEmpty();} public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} public List getFedPlanVariants() {return _fedPlanVariants;} - public boolean isEmpty() {return _fedPlanVariants.isEmpty();} + public FederatedOutput getFedOutType() {return fedOutType;} + public double getSelfCost() {return hopCommon.getSelfCost();} + public double getForwardingCost() {return hopCommon.getForwardingCost();} - public void prune() { + public void pruneFedPlans() { if (_fedPlanVariants.size() > 1) { // Find the FedPlan with the minimum cost FedPlan minCostPlan = _fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) .orElse(null); // Retain only the minimum cost plan @@ -174,43 +276,53 @@ public void prune() { * Represents a single federated execution plan with its associated costs and dependencies. * This class contains: * 1. 
selfCost: Cost of current hop (compute + input/output memory access) - * 2. totalCost: Cumulative cost including this plan and all child plans + * 2. cumulativeCost: Cumulative cost including this plan and all child plans * 3. netTransferCost: Network transfer cost for this plan to parent plan. * * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. */ public static class FedPlan { - private double totalCost; // Total cost including child plans + private double cumulativeCost; // Total cost including child plans private final FedPlanVariants fedPlanVariants; // Reference to variant list private final List> childFedPlans; // Child plan references - public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { - this.totalCost = 0; - this.childFedPlans = childFedPlans; + public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List> childFedPlans) { + this.cumulativeCost = cumulativeCost; this.fedPlanVariants = fedPlanVariants; + this.childFedPlans = childFedPlans; } - public void setTotalCost(double totalCost) {this.totalCost = totalCost;} - public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;} - public void setNetTransferCost(double netTransferCost) {fedPlanVariants.hopCommon.netTransferCost = netTransferCost;} - - public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;} - public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();} - public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} - public double getTotalCost() {return totalCost;} - public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;} - public double getNetTransferCost() {return fedPlanVariants.hopCommon.netTransferCost;} + public Hop getHopRef() {return fedPlanVariants.hopCommon.getHopRef();} + public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();} + public FederatedOutput getFedOutType() {return 
fedPlanVariants.getFedOutType();} + public double getCumulativeCost() {return cumulativeCost;} + public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();} + public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();} public List> getChildFedPlans() {return childFedPlans;} - /** - * Calculates the conditional network transfer cost based on output type compatibility. - * Returns 0 if output types match, otherwise returns the network transfer cost. - * @param parentFedOutType The federated output type of the parent plan. - * @return The conditional network transfer cost. - */ - public double getCondNetTransferCost(FederatedOutput parentFedOutType) { - if (parentFedOutType == getFedOutType()) return 0; - return fedPlanVariants.hopCommon.netTransferCost; + public void setCumulativeCost(double cumulativeCost) {this.cumulativeCost = cumulativeCost;} + } + + /** + * Represents common properties and costs associated with a Hop. + * This class holds a reference to the Hop and tracks its execution and network transfer costs. 
+ */ + public static class HopCommon { + protected final Hop hopRef; + protected double selfCost; + protected double forwardingCost; + + public HopCommon(Hop hopRef) { + this.hopRef = hopRef; + this.selfCost = 0; + this.forwardingCost = 0; } + + public Hop getHopRef() {return hopRef;} + public double getSelfCost() {return selfCost;} + public double getForwardingCost() {return forwardingCost;} + + public void setSelfCost(double selfCost) {this.selfCost = selfCost;} + public void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;} } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index 22d7f083c45..f73165b3c5c 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -65,9 +65,9 @@ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, F sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", - plan.getTotalCost(), + plan.getCumulativeCost(), plan.getSelfCost(), - plan.getNetTransferCost())); + plan.getForwardingCost())); // Add matrix characteristics sb.append(" [") diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index be1cfa7cdf3..11e6b907873 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -27,8 +27,11 @@ import java.util.LinkedHashMap; import org.apache.commons.lang3.tuple.Pair; + import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; +import 
org.apache.sysds.hops.fedplanner.FederatedMemoTable.ConflictedFedPlanVariants; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; @@ -48,25 +51,292 @@ public class FederatedPlanCostEnumerator { * @param printTree A boolean flag indicating whether to print the federated plan tree. * @return The optimal FedPlan with the minimum cost for the entire DAG. */ - public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) { + public static FedPlan enumerateOptimalFederatedPlanCost(Hop rootHop, boolean printTree) { + Set visited = new HashSet<>(); + Map> conflictMergeResolveMap = new HashMap<>(); + Map> resolveMap = new HashMap<>(); + detectPossibleConflicts(rootHop, visited, conflictMergeResolveMap, resolveMap); + // Create new memo table to store all plan variants FederatedMemoTable memoTable = new FederatedMemoTable(); - // Recursively enumerate all possible plans - enumerateFederatedPlanCost(rootHop, memoTable); + enumerateFederatedPlanCost(rootHop, memoTable, conflictMergeResolveMap); // Return the minimum cost plan for the root node FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); + // double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); // Optionally print the federated plan tree if requested - if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); + // if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); return optimalPlan; } + public static void detectPossibleConflicts(Hop hop, Set visited, Map> conflictMergeResolveMap, Map> 
resolveMap) { + for (Hop inputHop : hop.getInput()) { + if (visited.contains(hop.getHopID())) + return; + + visited.add(hop.getHopID()); + + if (inputHop.getParent().size() > 1) + findMergeResolvePaths(inputHop, conflictMergeResolveMap); + + detectPossibleConflicts(inputHop, visited, conflictMergeResolveMap); + } + } + + /** + * Identifies and marks conflicts and merge points in a Hop DAG starting from a conflicted Hop. + * A conflicted Hop is one that has multiple parent nodes, indicating potential execution path conflicts. + * + * The algorithm performs a breadth-first search (BFS) through the DAG to: + * 1. Start from a conflicted hop (one with multiple parents) + * 2. Traverse upward through parent nodes using BFS + * 3. Track merge points where execution paths converge + * 4. Mark nodes as resolved when all required merges are found + * 5. Track the count of merged hops at each merge point + * + * @param conflictedHop The Hop node with multiple parents that initiates the conflict detection + * @param conflictMergeResolveMap Map storing conflict and merge information for each Hop ID + */ + private static void findMergeResolvePaths(Hop conflictedHop, Map> conflictMergeResolveMap, Map resolveMap) { + // Initialize counter for remaining merges needed (parents - 1 since we need n-1 merges for n paths) + long conflictedHopID = conflictedHop.getHopID(); + int leftMergeCount = conflictedHop.getParent().size() - 1; + boolean isConverged = true; + + Set visited = new HashSet<>(); + Queue> BFSqueue = new LinkedList<>(); + + long convergeHopID = -1; + List topResolveHops = new ArrayList<>(); + List topResolveHopIDs = new ArrayList<>(); + + Map splitPointMap = new HashMap<>(); + Set mergeHopIDs = new HashSet<>(); + Set splitHopIDs = new HashSet<>(); + + // 여러 개의 부모 집합을 추가하는 경우 + for (Hop parentHop : conflictedHop.getParent()) { + SplitInfo splitInfo = new SplitInfo(parentHop); + BFSqueue.offer(Pair.of(parentHop, splitInfo)); + splitPointMap.put(parentHop.getHopID(), 
splitInfo); + } + + // 의문점 1. 모든 hop을 다 거치는가? + // 의문점 2. resolve Point 너머도 진행되지는 않았는가? 진행되었다면 지워야 한다. + + // Start BFS traversal through the DAG + while (!BFSqueue.isEmpty() || leftMergeCount > 0) { + Pair current = BFSqueue.poll(); + Hop currentHop = current.getKey(); + SplitInfo splitInfo = current.getValue(); + int numOfParent = currentHop.getParent().size(); + + if (numOfParent == 0) { + isConverged = false; + leftMergeCount--; + updateConflictResolveType(conflictMergeResolveMap, currentHop.getHopID(), conflictedHopID, false, false, ResolvedType.TOP); + topResolveHopIDs.add(currentHop.getHopID()); + topResolveHops.add(currentHop); + continue; + } + + // For nodes with multiple parents, update the merge count + // Each additional parent represents another path that needs to be merged + boolean isSplited = false; + if (numOfParent > 1){ + isSplited = true; + leftMergeCount += numOfParent - 1; + } + + // Process all parent nodes of the current node + for (Hop parentHop : currentHop.getParent()) { + long parentHopID = parentHop.getHopID(); + + if (isSplited) { + splitHopIDs.add(parentHopID); + } + + // Handle potential merge points (nodes with multiple inputs) + if (parentHop.getInput().size() > 1) { + // If node was previously visited, update merge information + if (visited.contains(parentHopID)) { + leftMergeCount--; + mergeHopIDs.add(parentHopID); + + if (leftMergeCount == 0 && isConverged){ + updateConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, true, isSplited, ResolvedType.RESOLVE); + convergeHopID = parentHopID; + } else { + updateConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, true, isSplited, ResolvedType.INNER_PATH); + } + } else { + // First visit to this node - initialize tracking information + visited.add(parentHopID); + BFSqueue.offer(parentHop); + addConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, false, isSplited, ResolvedType.INNER_PATH); + } + } else { + // Handle 
nodes with single input + // No need to track visit count as these aren't merge points + BFSqueue.offer(parentHop); + addConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, false, isSplited, ResolvedType.INNER_PATH); + } + } + } + + ResolveInfo resolveInfo; + + if (isConverged) { + resolveInfo = new ResolveInfo(conflictedHopID, convergeHopID, null, null); + } else { + for (Hop topHop : topResolveHops) { + boolean isfound = false; + + while (!isfound) { + // 공통점 1: 자신의 부모에서 더 이상 merge가 발생하지 않음 + // 공통점 2: 자식이 자식들이 split하였다면, 반드시 merge 되어야 함. + // 차이점 1: last-merge는 자신이 merge하나, first-split은 자신이 merge하지 않음. + // 차이점 2: last-merge는 자식이 split하지 않아도 되나, first-split은 자식이 반드시 split해야 함. + + for (Hop childHop : topHop.getInput()) { + // Todo: 여기부터 하자. + // visited, merge인지, split인지, split되면 merge 되었는지... + // bfs queues는 hop과 hop의 split point들을 가지고 다님. + // merge가 되면 마지막 split point를 지우고, 차례대로 지움. + + if (!visited.contains(childHop.getHopID())) + continue; + + + if (mergeHopIDs.contains(childHop.getHopID()) && childHop.getParent().size() == 1) { + isfound = true; + updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, ResolvedType.FIRST_SPLIT_LAST_MERGE); + } + + if (mergeHopIDs.contains(childHop.getHopID()) && childHop.getParent().size() > 1) { + for (Hop childParentHop : childHop.getParent()) { + if (childParentHop == topHop) + continue; + + if (childParentHop is Merged) + + } + } + + if () + + if (childHop.getParent().size() == 1) { + if (mergeHopIDs.contains(childHop.getHopID())) { + if (childHop.getParent().size() == 1) { + isfound = true; + updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, ResolvedType.FIRST_SPLIT_LAST_MERGE); + } else{ + + } + + } + + if (splitHopIDs.contains(childHop.getHopID())) { + + } + } + } + } + + + // // childHop이 merge혹은 initial parent일 때까지 내려가야함. 
+ // if (childInfo.isMerged() || initialParentHopIDs.contains(childHop.getHopID())) { + // // 1. single-parent이면, child가 last-merge 혹은 first-split임 + // if (childHop.getParent().size() == 1) { + // isfound = true; + // updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, ResolvedType.FIRST_SPLIT_LAST_MERGE); + // } else { + // ResolvedType resolvedType = conflictMergeResolveMap.get(childHop.getHopID()).stream() + // .filter(resolveInfo -> resolveInfo.conflictedHopID == conflictedHopID) + // .findFirst() + // .get() + // .getResolvedType(); + + // if (resolvedType != ResolvedType.INNER_PATH && resolvedType != ResolvedType.OUTER_PATH) { + // isfound = true; + // updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, resolvedType); + // } + + // for (Hop parentHop : childHop.getParent()) { + // // childHop의 다른 parent가 merge되었는지 확인해야함. + // // merge한 hop을 기억해야함 + // // split한 hop이면 더해졌을 수도 있으니 그것도 문제임 + // // path에서 split 포인트를 기억하고 있어야 하나? + // // 나중에 모았다가 진행해야 하는 듯. + // // left merge count가 줄어드는 건 맞으니까. + // // 서로 엉킬수도 있나? + // } + // // 2. multi-parent이면, child가 first-split임. + // // 2-1: 다른 parent가 모두 merge하지 않으면, childHop은 last-merge임 + // // 2-2: 다른 parent가 하나라도 merge하면, currentHop이 first-split임. 
+ // } + // // end case decision + // break; + // } else { + // currentHop = childHop; + // updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, false, false, ResolvedType.OUTER_PATH); + // } + } + resolveInfo = new ResolveInfo(conflictedHopID, convergeHopID, topResolveHopIDs, firstSplitLastMergeHopIDs); + } + resolveMap.put(conflictedHopID, resolveInfo); + } + + public static class SplitInfo { + private Hop hopRef; + private int numOfParents; + private Set mergeParentHopIDs; + + public SplitInfo(Hop hopRef) { + this.hopRef = hopRef; + this.numOfParents = hopRef.getParent().size(); + this.mergeParentHopIDs = new HashSet<>(); + } + } + + private static void updateConflictResolveType(Map> conflictMergeResolveMap, long currentHopID, long conflictedHopID, boolean isMerged, boolean isSplited, ResolvedType resolvedType) { + List mergeInfoList = conflictMergeResolveMap.get(currentHopID); + mergeInfoList.stream() + .filter(info -> info.conflictedHopID == conflictedHopID) + .forEach(info -> { + info.isMerged |= isMerged; + info.isSplited |= isSplited; + info.resolvedType = resolvedType; + }); + } + + private static void addConflictResolveType(Map> conflictMergeResolveMap, + long currentHopID, long conflictedHopID, boolean isMerged, boolean isSplited, ResolvedType resolvedType) { + conflictMergeResolveMap.putIfAbsent(currentHopID, new ArrayList<>()); + conflictMergeResolveMap.get(currentHopID).add(new ConflictMergeResolveInfo(conflictedHopID, isMerged, isSplited, resolvedType)); + } + + public static class ResolveInfo { + private long conflictHopID; + private long convergeHopID; + private List topResolveHopIDs; + private List firstSplitLastMergeHopIDs; + + public ResolveInfo(long conflictHopID, long convergeHopID, List topResolveHopIDs, List firstSplitLastMergeHopIDs) { + this.conflictHopID = conflictHopID; + this.convergeHopID = convergeHopID; + this.topResolveHopIDs = topResolveHopIDs; + this.firstSplitLastMergeHopIDs = 
firstSplitLastMergeHopIDs; + } + } + + + /** * Recursively enumerates all possible federated execution plans for a Hop DAG. * For each node: @@ -82,43 +352,123 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) * @param hop ? * @param memoTable ? */ - private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { - int numInputs = hop.getInput().size(); + private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable, + Map> conflictMergeResolveMap) { // Process all input nodes first if not already in memo table for (Hop inputHop : hop.getInput()) { if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { - enumerateFederatedPlanCost(inputHop, memoTable); + enumerateFederatedPlanCost(inputHop, memoTable, conflictMergeResolveMap); } } + long hopID = hop.getHopID(); + HopCommon hopCommon = new HopCommon(hop); + FederatedPlanCostEstimator.computeHopCost(hopCommon); + + int numInputs = hop.getInput().size(); + double selfCost = hopCommon.getSelfCost(); + + // Todo: (구현) conflict hop의 initial parent 처리 + // Todo: (구현) resolve point 위에서 처리 (resolve, first-split & last-merge, top-level) + + if (!conflictMergeResolveMap.containsKey(hopID)){ + FedPlanVariants LOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); + FedPlanVariants FOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); + + // # of child, LOUT/FOUT of child + double[][] childCumulativeCost = new double[numInputs][2]; + // # of child + double[] childForwardingCost = new double[numInputs]; + + FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childCumulativeCost, childForwardingCost); + + for (int i = 0; i < (1 << numInputs); i++) { + List> planChilds = new ArrayList<>(); + double lOutCumulativeCost = selfCost; + double fOutCumulativeCost = selfCost; + + // For each input, determine if it should be 
FOUT or LOUT based on bit j in i + for (int j = 0; j < numInputs; j++) { + Hop inputHop = hop.getInput().get(j); + final int bit = (i & (1 << j)) != 0 ? 1 : 0; // bit 값 계산 (FOUT/LOUT 결정) + final FederatedOutput childType = (bit == 1) ? FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); - // Generate all possible input combinations using binary representation - // i represents a specific combination of FOUT/LOUT for inputs - for (int i = 0; i < (1 << numInputs); i++) { - List> planChilds = new ArrayList<>(); - - // For each input, determine if it should be FOUT or LOUT based on bit j in i - for (int j = 0; j < numInputs; j++) { - Hop inputHop = hop.getInput().get(j); - // If bit j is set (1), use FOUT; otherwise use LOUT - FederatedOutput childType = ((i & (1 << j)) != 0) ? - FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); + lOutCumulativeCost += childCumulativeCost[j][bit]; + fOutCumulativeCost += childCumulativeCost[j][bit]; + // 비트 기반 산술 연산을 사용하여 전달 비용 추가 + fOutCumulativeCost += childForwardingCost[j] * (1 - bit); // bit == 0일 때 활성화 + lOutCumulativeCost += childForwardingCost[j] * bit; // bit == 1일 때 활성화 + } + LOutFedPlanVariants.addFedPlan(new FedPlan(lOutCumulativeCost, LOutFedPlanVariants, planChilds)); + FOutFedPlanVariants.addFedPlan(new FedPlan(fOutCumulativeCost, FOutFedPlanVariants, planChilds)); } + LOutFedPlanVariants.pruneFedPlans(); + FOutFedPlanVariants.pruneFedPlans(); + + memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, LOutFedPlanVariants); + memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, FOutFedPlanVariants); + } else { + List conflictMergeResolveInfos = conflictMergeResolveMap.get(hopID); + conflictMergeResolveInfos.sort(Comparator.comparingLong(ConflictMergeResolveInfo::getConflictedHopID)); + + ConflictedFedPlanVariants LOutFedPlanVariants = new ConflictedFedPlanVariants(hopCommon, FederatedOutput.LOUT, 
conflictMergeResolveInfos); + ConflictedFedPlanVariants FOutFedPlanVariants = new ConflictedFedPlanVariants(hopCommon, FederatedOutput.FOUT, conflictMergeResolveInfos); - // Create and evaluate FOUT variant for current input combination - FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); + int numOfConflictCombinations = 1 << conflictMergeResolveInfos.size(); + double mergeCost = FederatedPlanCostEstimator.computeMergeCost(conflictMergeResolveInfos, memoTable); + selfCost += mergeCost; - // Create and evaluate LOUT variant for current input combination - FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); - } + // 2^(# of conflicts), # of childs, LOUT/FOUT of child + double[][][] childCumulativeCost = new double[numOfConflictCombinations][numInputs][2]; + int[][][] childForwardingBitMap = new int[numOfConflictCombinations][numInputs][2]; + double[] childForwardingCost = new double[numInputs]; // # of childs + + FederatedPlanCostEstimator.getConflictedChildCosts(hopCommon, memoTable, conflictMergeResolveInfos, childCumulativeCost, childForwardingBitMap, childForwardingCost); + + for (int i = 0; i < (1 << numInputs); i++) { + List> planChilds = new ArrayList<>(); + + for (int j = 0; j < numOfConflictCombinations; j++) { + LOutFedPlanVariants.cumulativeCost[j][i] = selfCost; + FOutFedPlanVariants.cumulativeCost[j][i] = selfCost; + } + + for (int j = 0; j < numInputs; j++) { + Hop inputHop = hop.getInput().get(j); + + final int bit = (i & (1 << j)) != 0 ? 1 : 0; // bit 값 계산 (FOUT/LOUT 결정) + final FederatedOutput childType = (bit == 1) ? 
FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + + for (int k = 0; k < numOfConflictCombinations; k++) { + // 비트 기반 인덱스를 사용하여 누적 비용 업데이트 + LOutFedPlanVariants.cumulativeCost[k][i] += childCumulativeCost[k][j][bit]; + FOutFedPlanVariants.cumulativeCost[k][i] += childCumulativeCost[k][j][bit]; + + // 비트 기반 산술 연산을 사용하여 전달 비용 추가 + FOutFedPlanVariants.cumulativeCost[k][i] += childForwardingCost[j] * (1 - bit); // bit == 0일 때 활성화 + LOutFedPlanVariants.cumulativeCost[k][i] += childForwardingCost[j] * bit; // bit == 1일 때 활성화 + + if (mergeCost != 0) { + FederatedPlanCostEstimator.computeForwardingMergeCost(LOutFedPlanVariants.forwardingBitMap[k][i], + childForwardingBitMap[k][j][bit], conflictMergeResolveInfos, memoTable); + } - // Prune MemoTable for hop. - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT); - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT); + LOutFedPlanVariants.forwardingBitMap[k][i] |= childForwardingBitMap[k][j][bit]; + FOutFedPlanVariants.forwardingBitMap[k][i] |= childForwardingBitMap[k][j][bit]; + } + } + LOutFedPlanVariants.addFedPlan(new FedPlan(0, LOutFedPlanVariants, planChilds)); + FOutFedPlanVariants.addFedPlan(new FedPlan(0, FOutFedPlanVariants, planChilds)); + } + LOutFedPlanVariants.pruneConflictedFedPlans(); + FOutFedPlanVariants.pruneConflictedFedPlans(); + + memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, LOutFedPlanVariants); + memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, FOutFedPlanVariants); + } } /** @@ -130,21 +480,14 @@ private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoT * @return ? 
*/ private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { - FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); - FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); - - FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - - if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() - < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { - return minFOutFedPlan; + FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(HopID, FederatedOutput.LOUT); + FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(HopID, FederatedOutput.FOUT); + + if (lOutFedPlan.getCumulativeCost() < fOutFedPlan.getCumulativeCost()){ + return lOutFedPlan; + } else{ + return fOutFedPlan; } - return minlOutFedPlan; } /** @@ -231,4 +574,47 @@ private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, Federate // Return the cumulative additional cost for resolving conflicts return cumulativeAdditionalCost[0]; } + + /** + * Data structure to store conflict and merge information for a specific Hop. + * This class maintains the state of conflict resolution and merge operations + * for a given Hop in the execution plan. 
+ */ + public static class ConflictMergeResolveInfo { + private long conflictedHopID; // ID of the Hop that originated the conflict + private boolean isMerged; + private boolean isSplited; + private ResolvedType resolvedType; + + public ConflictMergeResolveInfo(long conflictedHopID, boolean isMerged, boolean isSplited, ResolvedType resolvedType) { + this.conflictedHopID = conflictedHopID; + this.isMerged = isMerged; + this.isSplited = isSplited; + this.resolvedType = resolvedType; + } + + public long getConflictedHopID() { + return conflictedHopID; + } + + public boolean isMerged() { + return isMerged; + } + + public boolean isSplited() { + return isSplited; + } + + public ResolvedType getResolvedType() { + return resolvedType; + } + } + + public static enum ResolvedType { + INNER_PATH, + OUTER_PATH, + FIRST_SPLIT_LAST_MERGE, // 첫 분기점 또는 마지막 + RESOLVE, // 해결 지점 + TOP // 최상위 지점 + }; } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index 7bc7339563a..3ae8b37a82c 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -22,8 +22,13 @@ import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.cost.ComputeCost; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; +import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator.ConflictMergeResolveInfo; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.ConflictedFedPlanVariants; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.NoSuchElementException; import java.util.List; @@ -42,43 +47,138 
@@ public class FederatedPlanCostEstimator { // Network bandwidth for data transfers between federated sites (1 Gbps) private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; - /** - * Computes total cost of federated plan by: - * 1. Computing current node cost (if not cached) - * 2. Adding minimum-cost child plans - * 3. Including network transfer costs when needed - * - * @param currentPlan Plan to compute cost for - * @param memoTable Table containing all plan variants - */ - public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { - double totalCost; - Hop currentHop = currentPlan.getHopRef(); - - // Step 1: Calculate current node costs if not already computed - if (currentPlan.getSelfCost() == 0) { - // Compute cost for current node (computation + memory access) - totalCost = computeCurrentCost(currentHop); - currentPlan.setSelfCost(totalCost); - // Calculate potential network transfer cost if federation type changes - currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); - } else { - totalCost = currentPlan.getSelfCost(); - } + public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, double[][] childCumulativeCost, double[] childForwardingCost) { + List inputHops = hopCommon.hopRef.getInput(); - // Step 2: Process each child plan and add their costs - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - // Find minimum cost child plan considering federation type compatibility - // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents - // because we're selecting child plans independently for each parent - FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair); - - // Add child plan cost (includes network transfer cost if federation types differ) - totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); + for (int i = 0; i < inputHops.size(); 
i++) { + long childHopID = inputHops.get(i).getHopID(); + + FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); + FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); + + childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); + childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); + childForwardingCost[i] = childLOutFedPlan.getForwardingCost(); + } + } + + public static void getConflictedChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List conflictMergeResolveInfos, + double[][][] childCumulativeCost, int[][][] childForwardingBitMap, double[] childForwardingCost) { + List inputHops = hopCommon.hopRef.getInput(); + int numConflictCombinations = 1 << conflictMergeResolveInfos.size(); + + for (int i = 0; i < inputHops.size(); i++) { + long childHopID = inputHops.get(i).getHopID(); + + FedPlanVariants childLOutVariants = memoTable.getFedPlanVariants(childHopID, FederatedOutput.LOUT); + FedPlanVariants childFOutVariants = memoTable.getFedPlanVariants(childHopID, FederatedOutput.FOUT); + + childForwardingCost[i] = childLOutVariants.getForwardingCost(); + + if (childLOutVariants instanceof ConflictedFedPlanVariants) { + FedPlan childLOutFedPlan = childLOutVariants.getFedPlanVariants().get(0); + FedPlan childFOutFedPlan = childFOutVariants.getFedPlanVariants().get(0); + + for (int j = 0; j < numConflictCombinations; j++) { + childCumulativeCost[j][i][0] = childLOutFedPlan.getCumulativeCost(); + childCumulativeCost[j][i][1] = childFOutFedPlan.getCumulativeCost(); + } + } + else { + ConflictedFedPlanVariants conflictedChildLOutVariants = (ConflictedFedPlanVariants) childLOutVariants; + ConflictedFedPlanVariants conflictedChildFOutVariants = (ConflictedFedPlanVariants) childFOutVariants; + + computeConflictedChildCosts(conflictMergeResolveInfos, conflictedChildLOutVariants, childCumulativeCost, childForwardingBitMap, i, 0); + 
computeConflictedChildCosts(conflictMergeResolveInfos, conflictedChildFOutVariants, childCumulativeCost, childForwardingBitMap, i, 1); + } + } + } + + private static void computeConflictedChildCosts(List conflictInfos, ConflictedFedPlanVariants conflictedChildVariants, + double[][][] childCumulativeCost, int[][][] childForwardingBitMap, int childIdx, int fedOutTypeIdx){ + int i = 0, j = 0; + int pLen = conflictInfos.size(); + int cLen = conflictedChildVariants.conflictInfos.size(); + int numConflictCombinations = 1 << conflictInfos.size(); + + // Step 1: 공통 제약 조건과 비공통 자식 위치 계산 + List common = new ArrayList<>(); + List nonCommonChildPos = new ArrayList<>(); + + while (i < pLen && j < cLen) { + long pHopID = conflictInfos.get(i).getConflictedHopID(); + long cHopID = conflictedChildVariants.conflictInfos.get(j).getConflictedHopID(); + + if (pHopID == cHopID) { + int pBitPos = pLen - 1 - i; + int cBitPos = cLen - 1 - j; + common.add(new CommonConstraint(pHopID, pBitPos, cBitPos)); + i++; + j++; + } else if (pHopID < cHopID) { + i++; + } else { + int cBitPos = cLen - 1 - j; + nonCommonChildPos.add(cBitPos); + j++; + } + } + + int restNumBits = nonCommonChildPos.size(); + for (int parentIdx = 0; parentIdx < numConflictCombinations; parentIdx++) { + // 공통 제약 조건을 기반으로 baseChildIdx 계산 + int baseChildIdx = 0; + for (CommonConstraint cc : common) { + int bit = (parentIdx >> cc.pBitPos) & 1; + baseChildIdx |= (bit << cc.cBitPos); + } + + // 최소 비용을 가진 자식 인덱스 찾기 + double minChildCost = Double.MAX_VALUE; + int minChildIdx = -1; + for (int restValue = 0; restValue < (1 << restNumBits); restValue++) { + int temp = 0; + for (int bitIdx = 0; bitIdx < restNumBits; bitIdx++) { + if (((restValue >> bitIdx) & 1) == 1) { + temp |= (1 << nonCommonChildPos.get(bitIdx)); + } + } + int tempChildIdx = baseChildIdx | temp; + if (conflictedChildVariants.cumulativeCost[tempChildIdx][0] < minChildCost) { + minChildCost = conflictedChildVariants.cumulativeCost[tempChildIdx][0]; + minChildIdx = 
tempChildIdx; + } + } + + // 자식의 isForwardBitMap을 부모의 비트 위치로 변환 + int childForwardBitMap = conflictedChildVariants.forwardingBitMap[minChildIdx][0]; + int convertedBitmask = 0; + for (CommonConstraint cc : common) { + int childBit = (childForwardBitMap >> cc.cBitPos) & 1; + if (childBit == 1) { + convertedBitmask |= (1 << cc.pBitPos); + } + } + + childCumulativeCost[parentIdx][childIdx][fedOutTypeIdx] = minChildCost; + childForwardingBitMap[parentIdx][childIdx][fedOutTypeIdx] = convertedBitmask; + } + } + + // Todo: (최적화) 추후에 MemoTable retrieve 하지 않게 최적화 가능 + public static double computeForwardingMergeCost(int parentBitmask, int childBitmask, List conflictInfos, FederatedMemoTable memoTable){ + int overlappingBits = parentBitmask & childBitmask; + double overlappingForwardingCost = 0.0; + + int pLen = conflictInfos.size(); + for (int b = 0; b < pLen; b++) { + int bitPos = pLen - 1 - b; + if ((overlappingBits & (1 << bitPos)) != 0) { + overlappingForwardingCost += memoTable.getFedPlanVariants(conflictInfos.get(b).getConflictedHopID(), FederatedOutput.LOUT).getForwardingCost(); + } } - // Step 3: Set final cumulative cost including current node - currentPlan.setTotalCost(totalCost); + return overlappingForwardingCost; } /** @@ -138,7 +238,7 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
- fOutAdditionalCost += confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost(); + fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred @@ -148,30 +248,30 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later isLOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); } } else { - lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost(); + lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { isLOutNetTransfer = true; } else { isFOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); } } } // Add network transfer costs if applicable if (isLOutNetTransfer) { - lOutAdditionalCost += confilctLOutFedPlan.getNetTransferCost(); + lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); } if (isFOutNetTransfer) { - fOutAdditionalCost += confilctFOutFedPlan.getNetTransferCost(); + fOutAdditionalCost += 
confilctFOutFedPlan.getForwardingCost(); } // Determine the optimal federated output type based on the calculated costs @@ -199,14 +299,36 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe } return resolvedFedPlanLinkedMap; } - + + // Todo: (구현) forwarding bitmap을 본 뒤, merge cost 일일히 type에 따라 계산해야함. + public static double computeMergeCost(List conflictMergeResolveInfos, FederatedMemoTable memoTable){ + double mergeCost = 0; + + for (ConflictMergeResolveInfo conflictInfo: conflictMergeResolveInfos){ + int numOfMergedHops = conflictInfo.getNumOfMergedHops(); + + if (numOfMergedHops != 0){ + double selfCost = memoTable.getFedPlanVariants(conflictInfo.getConflictedHopID(), FederatedOutput.LOUT).getSelfCost(); + mergeCost += selfCost * numOfMergedHops; + } + } + + return mergeCost; + } + + public static void computeHopCost(HopCommon hopCommon){ + Hop hop = hopCommon.hopRef; + hopCommon.setSelfCost(computeSelfCost(hop)); + hopCommon.setForwardingCost(computeHopForwardingCost(hop.getOutputMemEstimate())); + } + /** * Computes the cost for the current Hop node. 
* * @param currentHop The Hop node whose cost needs to be computed * @return The total cost for the current node's operation */ - private static double computeCurrentCost(Hop currentHop){ + private static double computeSelfCost(Hop currentHop){ double computeCost = ComputeCost.getHOPComputeCost(currentHop); double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); @@ -234,7 +356,19 @@ private static double computeHopMemoryAccessCost(double memSize) { * @param memSize Size of data to be transferred (in bytes) * @return Time cost for network transfer (in seconds) */ - private static double computeHopNetworkAccessCost(double memSize) { + private static double computeHopForwardingCost(double memSize) { return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; } + + public static class CommonConstraint { + long name; + int pBitPos; + int cBitPos; + + CommonConstraint(long name, int pBitPos, int cBitPos) { + this.name = name; + this.pBitPos = pBitPos; + this.cBitPos = cBitPos; + } + } } From 16a8d00609422454b5eb7d7304380dfb0b450be8 Mon Sep 17 00:00:00 2001 From: min-guk Date: Tue, 21 Jan 2025 01:01:17 +0900 Subject: [PATCH 6/9] Enumerator for an optimal federated plan at the program level --- .../hops/fedplanner/FederatedMemoTable.java | 320 ++++------ .../fedplanner/FederatedMemoTablePrinter.java | 4 +- .../FederatedPlanCostEnumerator.java | 563 ++++-------------- .../FederatedPlanCostEstimator.java | 240 ++------ .../FederatedPlanCostEnumeratorTest.java | 20 +- .../FederatedPlanCostEnumeratorTest5.dml | 2 +- .../FederatedPlanCostEnumeratorTest6.dml | 19 +- 7 files changed, 315 insertions(+), 853 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index 196e52b6de1..82d05e4f286 100644 --- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -19,19 +19,15 @@ package org.apache.sysds.hops.fedplanner; +import org.apache.sysds.hops.Hop; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.ArrayList; import java.util.Map; -import java.util.Arrays; -import java.util.Collections; -import org.apache.sysds.hops.Hop; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; -import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator.ConflictMergeResolveInfo; -import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator.ResolvedType; /** * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. @@ -42,196 +38,98 @@ public class FederatedMemoTable { // Maps Hop ID and fedOutType pairs to their plan variants private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); - public void addFedPlanVariants(long hopID, FederatedOutput fedOutType, FedPlanVariants fedPlanVariants) { - hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariants); + /** + * Adds a new federated plan to the memo table. + * Creates a new variant list if none exists for the given Hop and fedOutType. 
+ * + * @param hop The Hop node + * @param fedOutType The federated output type + * @param planChilds List of child plan references + * @return The newly created FedPlan + */ + public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { + long hopID = hop.getHopID(); + FedPlanVariants fedPlanVariantList; + + if (contains(hopID, fedOutType)) { + fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } else { + fedPlanVariantList = new FedPlanVariants(hop, fedOutType); + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); + } + + FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); + fedPlanVariantList.addFedPlan(newPlan); + + return newPlan; } - public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { - return hopMemoTable.get(fedPlanPair); + /** + * Retrieves the minimum cost child plan considering the parent's output type. + * The cost is calculated using getParentViewCost to account for potential type mismatches. 
+ */ + public FedPlan getMinCostFedPlan(Pair fedPlanPair) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); + return fedPlanVariantList._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); } public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); } + public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { + return hopMemoTable.get(fedPlanPair); + } + public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { + // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); return fedPlanVariantList._fedPlanVariants.get(0); } public FedPlan getFedPlanAfterPrune(Pair fedPlanPair) { + // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); return fedPlanVariantList._fedPlanVariants.get(0); } + /** + * Checks if the memo table contains an entry for a given Hop and fedOutType. + * + * @param hopID The Hop ID. + * @param fedOutType The associated fedOutType. + * @return True if the entry exists, false otherwise. 
+ */ public boolean contains(long hopID, FederatedOutput fedOutType) { return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); } - public static class ConflictedFedPlanVariants extends FedPlanVariants { - public List conflictInfos; - protected int numConflictCombinations; - // 2^(# of conflicts), 2^(# of childs) - protected double[][] cumulativeCost; - protected int[][] forwardingBitMap; - - // bitset array (java class) >> arbitary length >> - - public ConflictedFedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType, - List conflictMergeResolveInfos) { - super(hopCommon, fedOutType); - this.conflictInfos = conflictMergeResolveInfos; - this.numConflictCombinations = 1 << this.conflictInfos.size(); - this.cumulativeCost = new double[this.numConflictCombinations][this._fedPlanVariants.size()]; - this.forwardingBitMap = new int[this.numConflictCombinations][this._fedPlanVariants.size()]; - // Initialize isForwardBitMap to 0 - for (int i = 0; i < this.numConflictCombinations; i++) { - Arrays.fill(this.cumulativeCost[i], 0); - Arrays.fill(this.forwardingBitMap[i], 0); - } - } - - // Todo: (최적화) java bitset 사용하여, 다수의 conflict 처리할 수 있도록 해야 함. - // Todo: (구현) 만약 resolve point (converge, first-split & last-merge) child로 내려가면서 recursive하게 prune 해야 함. 
(이때, parents의 LOUT/FOUT의 Optimal Plan을 동시에 고려해야함) - public void pruneConflictedFedPlans() { - // Step 1: Initialize prunedCost and prunedIsForwardingBitMap with minimal values per combination - double[][] prunedCost = new double[this.numConflictCombinations][1]; - int[][] prunedIsForwardingBitMap = new int[this.numConflictCombinations][1]; - List prunedFedPlanVariants = new ArrayList<>(); - - for (int i = 0; i < this.numConflictCombinations; i++) { - double minCost = Double.MAX_VALUE; - int minIndex = -1; - for (int j = 0; j < _fedPlanVariants.size(); j++) { - if (cumulativeCost[i][j] < minCost) { - minCost = cumulativeCost[i][j]; - minIndex = j; - } - } - prunedCost[i][0] = minCost; - prunedIsForwardingBitMap[i][0] = (minIndex != -1) ? forwardingBitMap[i][minIndex] : 0; - prunedFedPlanVariants.add(_fedPlanVariants.get(minIndex)); - } - - this.cumulativeCost = prunedCost; - this.forwardingBitMap = prunedIsForwardingBitMap; - this._fedPlanVariants = prunedFedPlanVariants; - - // Step 2: Collect resolved conflict bit positions - List resolvedBits = new ArrayList<>(); - for (int i = 0; i < conflictInfos.size(); i++) { - ConflictMergeResolveInfo info = conflictInfos.get(i); - if (info.getResolvedType() == ResolvedType.RESOLVE) { - resolvedBits.add(i); // Assuming index corresponds to bit position - } - } - - int resolvedBitsSize = resolvedBits.size(); - - // CASE 1: if not resolved, return - if (resolvedBitsSize == 0){ - return; - } - - // CASE 2: if all resolved, transform to FedPlanVariants - if (resolvedBits.size() == conflictInfos.size()){ - double minCost = Double.MAX_VALUE; - int minCostIdx = -1; - - for (int i = 0; i < this.numConflictCombinations; i++) { - if (cumulativeCost[i][0] < minCost) { - minCost = cumulativeCost[i][0]; - minCostIdx = i; - } - } - - FedPlan finalFedPlan = this.getFedPlanVariants().get(minCostIdx); - finalFedPlan.setCumulativeCost(minCost); - this._fedPlanVariants.clear(); - this._fedPlanVariants.add(finalFedPlan); - - this.conflictInfos 
= null; - this.cumulativeCost = null; - this.forwardingBitMap = null; - this.numConflictCombinations = 0; - - return; - } - - // CASE 3: if some resolved, some not, merge them - int mask = 0; - for (int bit : resolvedBits) { - mask |= (1 << bit); - } - mask = ~mask; - - List unresolvedBits = new ArrayList<>(); - for (int bit = 0; bit < conflictInfos.size(); bit++) { - if (!resolvedBits.contains(bit)) { - unresolvedBits.add(bit); - } - } - Collections.sort(unresolvedBits); // Ensure consistent ordering - - // Create newConflictInfos with unresolved conflicts - List newConflictInfos = new ArrayList<>(); - for (int bit : unresolvedBits) { - newConflictInfos.add(conflictInfos.get(bit)); - } + /** + * Prunes the specified entry in the memo table, retaining only the minimum-cost + * FedPlan for the given Hop ID and federated output type. + * + * @param hopID The ID of the Hop to prune + * @param federatedOutput The federated output type associated with the Hop + */ + public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) { + hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune(); + } - // Step 4: Group combinations by their base (ignoring resolved bits) - Map> groups = new HashMap<>(); - for (int i = 0; i < this.numConflictCombinations; i++) { - int base = i & mask; - groups.computeIfAbsent(base, k -> new ArrayList<>()).add(i); - } - - // Step 5: Merge groups and create new arrays with reduced size - int newSize = 1 << unresolvedBits.size(); - double[][] newPrunedCost = new double[newSize][1]; - int[][] newPrunedBitMap = new int[newSize][1]; - List newPrunedFedPlanVariants = new ArrayList<>(newSize); - Arrays.fill(newPrunedCost, Double.MAX_VALUE); - - for (Map.Entry> entry : groups.entrySet()) { - int base = entry.getKey(); - List group = entry.getValue(); - - // Find minimal cost and bitmap in the group - double minGroupCost = Double.MAX_VALUE; - int minBitmap = 0; - int minIdx = -1; + /** + * Represents common properties and costs 
associated with a Hop. + * This class holds a reference to the Hop and tracks its execution and network transfer costs. + */ + public static class HopCommon { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Current execution cost (compute + memory access) + protected double forwardingCost; // Network transfer cost - for (int comb : group) { - if (cumulativeCost[comb][0] < minGroupCost) { - minGroupCost = cumulativeCost[comb][0]; - minBitmap = forwardingBitMap[comb][0]; - minIdx = comb; - } - } - - // Compute new index based on unresolved bits - int newIndex = 0; - for (int i = 0; i < unresolvedBits.size(); i++) { - int bitPos = unresolvedBits.get(i); - if ((base & (1 << bitPos)) != 0) { - newIndex |= (1 << i); // Set the i-th bit in newIndex - } - } - - // Update newPruned arrays - if (newIndex < newSize) { - newPrunedCost[newIndex][0] = minGroupCost; - newPrunedBitMap[newIndex][0] = minBitmap; - newPrunedFedPlanVariants.add(newIndex, _fedPlanVariants.get(minIdx)); - } - } - - // Replace the pruned arrays with the merged results and update size - this.conflictInfos = newConflictInfos; - this.cumulativeCost = newPrunedCost; - this.forwardingBitMap = newPrunedBitMap; - this.numConflictCombinations = newSize; // Update to the new reduced size + protected HopCommon(Hop hopRef) { + this.hopRef = hopRef; + this.selfCost = 0; + this.forwardingCost = 0; } } @@ -245,24 +143,21 @@ public static class FedPlanVariants { private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) protected List _fedPlanVariants; // List of plan variants - public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) { - this.hopCommon = hopCommon; + public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { + this.hopCommon = new HopCommon(hopRef); this.fedOutType = fedOutType; this._fedPlanVariants = new ArrayList<>(); } - public boolean isEmpty() {return _fedPlanVariants.isEmpty();} public void addFedPlan(FedPlan 
fedPlan) {_fedPlanVariants.add(fedPlan);} public List getFedPlanVariants() {return _fedPlanVariants;} - public FederatedOutput getFedOutType() {return fedOutType;} - public double getSelfCost() {return hopCommon.getSelfCost();} - public double getForwardingCost() {return hopCommon.getForwardingCost();} + public boolean isEmpty() {return _fedPlanVariants.isEmpty();} - public void pruneFedPlans() { + public void prune() { if (_fedPlanVariants.size() > 1) { // Find the FedPlan with the minimum cost FedPlan minCostPlan = _fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) .orElse(null); // Retain only the minimum cost plan @@ -276,53 +171,44 @@ public void pruneFedPlans() { * Represents a single federated execution plan with its associated costs and dependencies. * This class contains: * 1. selfCost: Cost of current hop (compute + input/output memory access) - * 2. cumulativeCost: Cumulative cost including this plan and all child plans - * 3. netTransferCost: Network transfer cost for this plan to parent plan. + * 2. totalCost: Cumulative cost including this plan and all child plans + * 3. forwardingCost: Network transfer cost for this plan to parent plan. * * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. 
*/ public static class FedPlan { - private double cumulativeCost; // Total cost including child plans + private double totalCost; // Total cost including child plans private final FedPlanVariants fedPlanVariants; // Reference to variant list private final List> childFedPlans; // Child plan references - public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List> childFedPlans) { - this.cumulativeCost = cumulativeCost; + public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { + this.totalCost = 0; + this.childFedPlans = childFedPlans; this.fedPlanVariants = fedPlanVariants; - this.childFedPlans = childFedPlans; } - public Hop getHopRef() {return fedPlanVariants.hopCommon.getHopRef();} - public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();} - public FederatedOutput getFedOutType() {return fedPlanVariants.getFedOutType();} - public double getCumulativeCost() {return cumulativeCost;} - public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();} - public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();} + public void setTotalCost(double totalCost) {this.totalCost = totalCost;} + public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;} + public void setForwardingCost(double forwardingCost) {fedPlanVariants.hopCommon.forwardingCost = forwardingCost;} + public void applyIterationWeight(int iteration) {totalCost *= iteration;} + + public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;} + public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();} + public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} + public double getTotalCost() {return totalCost;} + public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;} + public double setForwardingCost() {return fedPlanVariants.hopCommon.forwardingCost;} public List> getChildFedPlans() {return childFedPlans;} - public void 
setCumulativeCost(double cumulativeCost) {this.cumulativeCost = cumulativeCost;} - } - - /** - * Represents common properties and costs associated with a Hop. - * This class holds a reference to the Hop and tracks its execution and network transfer costs. - */ - public static class HopCommon { - protected final Hop hopRef; - protected double selfCost; - protected double forwardingCost; - - public HopCommon(Hop hopRef) { - this.hopRef = hopRef; - this.selfCost = 0; - this.forwardingCost = 0; + /** + * Calculates the conditional network transfer cost based on output type compatibility. + * Returns 0 if output types match, otherwise returns the network transfer cost. + * @param parentFedOutType The federated output type of the parent plan. + * @return The conditional network transfer cost. + */ + public double getCondForwardingCost(FederatedOutput parentFedOutType) { + if (parentFedOutType == getFedOutType()) return 0; + return fedPlanVariants.hopCommon.forwardingCost; } - - public Hop getHopRef() {return hopRef;} - public double getSelfCost() {return selfCost;} - public double getForwardingCost() {return forwardingCost;} - - public void setSelfCost(double selfCost) {this.selfCost = selfCost;} - public void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;} } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index f73165b3c5c..391868efcd7 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -65,9 +65,9 @@ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, F sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", - plan.getCumulativeCost(), + plan.getTotalCost(), plan.getSelfCost(), - plan.getForwardingCost())); + plan.setForwardingCost())); // Add matrix 
characteristics sb.append(" [") diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index 11e6b907873..f626e27c1bc 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -27,13 +27,20 @@ import java.util.LinkedHashMap; import org.apache.commons.lang3.tuple.Pair; - import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.ConflictedFedPlanVariants; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; /** @@ -42,301 +49,109 @@ * to compute their costs. */ public class FederatedPlanCostEnumerator { + public static void enumerateProgram(DMLProgram prog) { + for(StatementBlock sb : prog.getStatementBlocks()) + enumerateStatementBlock(sb); + } + + /** + * Recursively enumerates federated execution plans for a given statement block. + * This method processes each type of statement block (If, For, While, Function, and generic) + * to determine the optimal federated plan. 
+ * + * @param sb The statement block to enumerate. + */ + public static void enumerateStatementBlock(StatementBlock sb) { + // While enumerating the program, recursively determine the optimal FedPlan and MemoTable + // for each statement block and statement. + // 1. How to recursively integrate optimal FedPlans and MemoTables across statements and statement blocks? + // 1) Is it determined using the same dynamic programming approach, or simply by summing the minimal plans? + // 2. Is there a need to share the MemoTable? Are there data/hop dependencies between statements? + // 3. How to predict the number of iterations for For and While loops? + // 1) If from/to/increment are constants: Calculations can be done at compile time. + // 2) If they are variables: Use default values at compile time, adjust at runtime, or predict using ML models. + + if (sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + + enumerateFederatedPlanCost(isb.getPredicateHops()); + + for (StatementBlock csb : istmt.getIfBody()) + enumerateStatementBlock(csb); + for (StatementBlock csb : istmt.getElseBody()) + enumerateStatementBlock(csb); + + // Todo: 1. apply iteration weight to csbFedPlans (if: 0.5, else: 0.5) + // Todo: 2. Merge predFedPlans + } else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + + enumerateFederatedPlanCost(fsb.getFromHops()); + enumerateFederatedPlanCost(fsb.getToHops()); + enumerateFederatedPlanCost(fsb.getIncrementHops()); + + for (StatementBlock csb : fstmt.getBody()) + enumerateStatementBlock(csb); + + // Todo: 1. get(predict) # of Iterations + // Todo: 2. apply iteration weight to csbFedPlans + // Todo: 3. 
Merge csbFedPlans and predFedPlans + } else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + enumerateFederatedPlanCost(wsb.getPredicateHops()); + + ArrayList csbFedPlans = new ArrayList<>(); + for (StatementBlock csb : wstmt.getBody()) + enumerateStatementBlock(csb); + + // Todo: 1. get(predict) # of Iterations + // Todo: 2. apply iteration weight to csbFedPlans + // Todo: 3. Merge csbFedPlans and predFedPlans + } else if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock csb : fstmt.getBody()) + enumerateStatementBlock(csb); + + // Todo: 1. Merge csbFedPlans + } else { //generic (last-level) + if( sb.getHops() != null ) + for( Hop c : sb.getHops() ) + enumerateFederatedPlanCost(c); + } + } + /** * Entry point for federated plan enumeration. This method creates a memo table * and returns the minimum cost plan for the entire Directed Acyclic Graph (DAG). * It also resolves conflicts where FedPlans have different FederatedOutput types. * * @param rootHop The root Hop node from which to start the plan enumeration. - * @param printTree A boolean flag indicating whether to print the federated plan tree. * @return The optimal FedPlan with the minimum cost for the entire DAG. 
*/ - public static FedPlan enumerateOptimalFederatedPlanCost(Hop rootHop, boolean printTree) { - Set visited = new HashSet<>(); - Map> conflictMergeResolveMap = new HashMap<>(); - Map> resolveMap = new HashMap<>(); - detectPossibleConflicts(rootHop, visited, conflictMergeResolveMap, resolveMap); - + public static FedPlan enumerateFederatedPlanCost(Hop rootHop) { // Create new memo table to store all plan variants FederatedMemoTable memoTable = new FederatedMemoTable(); + // Recursively enumerate all possible plans - enumerateFederatedPlanCost(rootHop, memoTable, conflictMergeResolveMap); + enumerateFederatedPlanCost(rootHop, memoTable); // Return the minimum cost plan for the root node FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - // double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); + double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); - // Optionally print the federated plan tree if requested - // if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); + // Print the federated plan tree if requested + FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); return optimalPlan; } - public static void detectPossibleConflicts(Hop hop, Set visited, Map> conflictMergeResolveMap, Map> resolveMap) { - for (Hop inputHop : hop.getInput()) { - if (visited.contains(hop.getHopID())) - return; - - visited.add(hop.getHopID()); - - if (inputHop.getParent().size() > 1) - findMergeResolvePaths(inputHop, conflictMergeResolveMap); - - detectPossibleConflicts(inputHop, visited, conflictMergeResolveMap); - } - } - - /** - * Identifies and marks conflicts and merge points in a Hop DAG starting from a conflicted Hop. 
- * A conflicted Hop is one that has multiple parent nodes, indicating potential execution path conflicts. - * - * The algorithm performs a breadth-first search (BFS) through the DAG to: - * 1. Start from a conflicted hop (one with multiple parents) - * 2. Traverse upward through parent nodes using BFS - * 3. Track merge points where execution paths converge - * 4. Mark nodes as resolved when all required merges are found - * 5. Track the count of merged hops at each merge point - * - * @param conflictedHop The Hop node with multiple parents that initiates the conflict detection - * @param conflictMergeResolveMap Map storing conflict and merge information for each Hop ID - */ - private static void findMergeResolvePaths(Hop conflictedHop, Map> conflictMergeResolveMap, Map resolveMap) { - // Initialize counter for remaining merges needed (parents - 1 since we need n-1 merges for n paths) - long conflictedHopID = conflictedHop.getHopID(); - int leftMergeCount = conflictedHop.getParent().size() - 1; - boolean isConverged = true; - - Set visited = new HashSet<>(); - Queue> BFSqueue = new LinkedList<>(); - - long convergeHopID = -1; - List topResolveHops = new ArrayList<>(); - List topResolveHopIDs = new ArrayList<>(); - - Map splitPointMap = new HashMap<>(); - Set mergeHopIDs = new HashSet<>(); - Set splitHopIDs = new HashSet<>(); - - // 여러 개의 부모 집합을 추가하는 경우 - for (Hop parentHop : conflictedHop.getParent()) { - SplitInfo splitInfo = new SplitInfo(parentHop); - BFSqueue.offer(Pair.of(parentHop, splitInfo)); - splitPointMap.put(parentHop.getHopID(), splitInfo); - } - - // 의문점 1. 모든 hop을 다 거치는가? - // 의문점 2. resolve Point 너머도 진행되지는 않았는가? 진행되었다면 지워야 한다. 
- - // Start BFS traversal through the DAG - while (!BFSqueue.isEmpty() || leftMergeCount > 0) { - Pair current = BFSqueue.poll(); - Hop currentHop = current.getKey(); - SplitInfo splitInfo = current.getValue(); - int numOfParent = currentHop.getParent().size(); - - if (numOfParent == 0) { - isConverged = false; - leftMergeCount--; - updateConflictResolveType(conflictMergeResolveMap, currentHop.getHopID(), conflictedHopID, false, false, ResolvedType.TOP); - topResolveHopIDs.add(currentHop.getHopID()); - topResolveHops.add(currentHop); - continue; - } - - // For nodes with multiple parents, update the merge count - // Each additional parent represents another path that needs to be merged - boolean isSplited = false; - if (numOfParent > 1){ - isSplited = true; - leftMergeCount += numOfParent - 1; - } - - // Process all parent nodes of the current node - for (Hop parentHop : currentHop.getParent()) { - long parentHopID = parentHop.getHopID(); - - if (isSplited) { - splitHopIDs.add(parentHopID); - } - - // Handle potential merge points (nodes with multiple inputs) - if (parentHop.getInput().size() > 1) { - // If node was previously visited, update merge information - if (visited.contains(parentHopID)) { - leftMergeCount--; - mergeHopIDs.add(parentHopID); - - if (leftMergeCount == 0 && isConverged){ - updateConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, true, isSplited, ResolvedType.RESOLVE); - convergeHopID = parentHopID; - } else { - updateConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, true, isSplited, ResolvedType.INNER_PATH); - } - } else { - // First visit to this node - initialize tracking information - visited.add(parentHopID); - BFSqueue.offer(parentHop); - addConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, false, isSplited, ResolvedType.INNER_PATH); - } - } else { - // Handle nodes with single input - // No need to track visit count as these aren't merge points - 
BFSqueue.offer(parentHop); - addConflictResolveType(conflictMergeResolveMap, parentHopID, conflictedHopID, false, isSplited, ResolvedType.INNER_PATH); - } - } - } - - ResolveInfo resolveInfo; - - if (isConverged) { - resolveInfo = new ResolveInfo(conflictedHopID, convergeHopID, null, null); - } else { - for (Hop topHop : topResolveHops) { - boolean isfound = false; - - while (!isfound) { - // 공통점 1: 자신의 부모에서 더 이상 merge가 발생하지 않음 - // 공통점 2: 자식이 자식들이 split하였다면, 반드시 merge 되어야 함. - // 차이점 1: last-merge는 자신이 merge하나, first-split은 자신이 merge하지 않음. - // 차이점 2: last-merge는 자식이 split하지 않아도 되나, first-split은 자식이 반드시 split해야 함. - - for (Hop childHop : topHop.getInput()) { - // Todo: 여기부터 하자. - // visited, merge인지, split인지, split되면 merge 되었는지... - // bfs queues는 hop과 hop의 split point들을 가지고 다님. - // merge가 되면 마지막 split point를 지우고, 차례대로 지움. - - if (!visited.contains(childHop.getHopID())) - continue; - - - if (mergeHopIDs.contains(childHop.getHopID()) && childHop.getParent().size() == 1) { - isfound = true; - updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, ResolvedType.FIRST_SPLIT_LAST_MERGE); - } - - if (mergeHopIDs.contains(childHop.getHopID()) && childHop.getParent().size() > 1) { - for (Hop childParentHop : childHop.getParent()) { - if (childParentHop == topHop) - continue; - - if (childParentHop is Merged) - - } - } - - if () - - if (childHop.getParent().size() == 1) { - if (mergeHopIDs.contains(childHop.getHopID())) { - if (childHop.getParent().size() == 1) { - isfound = true; - updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, ResolvedType.FIRST_SPLIT_LAST_MERGE); - } else{ - - } - - } - - if (splitHopIDs.contains(childHop.getHopID())) { - - } - } - } - } - - - // // childHop이 merge혹은 initial parent일 때까지 내려가야함. - // if (childInfo.isMerged() || initialParentHopIDs.contains(childHop.getHopID())) { - // // 1. 
single-parent이면, child가 last-merge 혹은 first-split임 - // if (childHop.getParent().size() == 1) { - // isfound = true; - // updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, ResolvedType.FIRST_SPLIT_LAST_MERGE); - // } else { - // ResolvedType resolvedType = conflictMergeResolveMap.get(childHop.getHopID()).stream() - // .filter(resolveInfo -> resolveInfo.conflictedHopID == conflictedHopID) - // .findFirst() - // .get() - // .getResolvedType(); - - // if (resolvedType != ResolvedType.INNER_PATH && resolvedType != ResolvedType.OUTER_PATH) { - // isfound = true; - // updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, true, false, resolvedType); - // } - - // for (Hop parentHop : childHop.getParent()) { - // // childHop의 다른 parent가 merge되었는지 확인해야함. - // // merge한 hop을 기억해야함 - // // split한 hop이면 더해졌을 수도 있으니 그것도 문제임 - // // path에서 split 포인트를 기억하고 있어야 하나? - // // 나중에 모았다가 진행해야 하는 듯. - // // left merge count가 줄어드는 건 맞으니까. - // // 서로 엉킬수도 있나? - // } - // // 2. multi-parent이면, child가 first-split임. - // // 2-1: 다른 parent가 모두 merge하지 않으면, childHop은 last-merge임 - // // 2-2: 다른 parent가 하나라도 merge하면, currentHop이 first-split임. 
- // } - // // end case decision - // break; - // } else { - // currentHop = childHop; - // updateConflictResolveType(conflictMergeResolveMap, childHop.getHopID(), conflictedHopID, false, false, ResolvedType.OUTER_PATH); - // } - } - resolveInfo = new ResolveInfo(conflictedHopID, convergeHopID, topResolveHopIDs, firstSplitLastMergeHopIDs); - } - resolveMap.put(conflictedHopID, resolveInfo); - } - - public static class SplitInfo { - private Hop hopRef; - private int numOfParents; - private Set mergeParentHopIDs; - - public SplitInfo(Hop hopRef) { - this.hopRef = hopRef; - this.numOfParents = hopRef.getParent().size(); - this.mergeParentHopIDs = new HashSet<>(); - } - } - - private static void updateConflictResolveType(Map> conflictMergeResolveMap, long currentHopID, long conflictedHopID, boolean isMerged, boolean isSplited, ResolvedType resolvedType) { - List mergeInfoList = conflictMergeResolveMap.get(currentHopID); - mergeInfoList.stream() - .filter(info -> info.conflictedHopID == conflictedHopID) - .forEach(info -> { - info.isMerged |= isMerged; - info.isSplited |= isSplited; - info.resolvedType = resolvedType; - }); - } - - private static void addConflictResolveType(Map> conflictMergeResolveMap, - long currentHopID, long conflictedHopID, boolean isMerged, boolean isSplited, ResolvedType resolvedType) { - conflictMergeResolveMap.putIfAbsent(currentHopID, new ArrayList<>()); - conflictMergeResolveMap.get(currentHopID).add(new ConflictMergeResolveInfo(conflictedHopID, isMerged, isSplited, resolvedType)); - } - - public static class ResolveInfo { - private long conflictHopID; - private long convergeHopID; - private List topResolveHopIDs; - private List firstSplitLastMergeHopIDs; - - public ResolveInfo(long conflictHopID, long convergeHopID, List topResolveHopIDs, List firstSplitLastMergeHopIDs) { - this.conflictHopID = conflictHopID; - this.convergeHopID = convergeHopID; - this.topResolveHopIDs = topResolveHopIDs; - this.firstSplitLastMergeHopIDs = 
firstSplitLastMergeHopIDs; - } - } - - - /** * Recursively enumerates all possible federated execution plans for a Hop DAG. * For each node: @@ -352,123 +167,43 @@ public ResolveInfo(long conflictHopID, long convergeHopID, List topResolve * @param hop ? * @param memoTable ? */ - private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable, - Map> conflictMergeResolveMap) { + private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { + int numInputs = hop.getInput().size(); // Process all input nodes first if not already in memo table for (Hop inputHop : hop.getInput()) { if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { - enumerateFederatedPlanCost(inputHop, memoTable, conflictMergeResolveMap); + enumerateFederatedPlanCost(inputHop, memoTable); } } - long hopID = hop.getHopID(); - HopCommon hopCommon = new HopCommon(hop); - FederatedPlanCostEstimator.computeHopCost(hopCommon); - int numInputs = hop.getInput().size(); - double selfCost = hopCommon.getSelfCost(); - - // Todo: (구현) conflict hop의 initial parent 처리 - // Todo: (구현) resolve point 위에서 처리 (resolve, first-split & last-merge, top-level) - - if (!conflictMergeResolveMap.containsKey(hopID)){ - FedPlanVariants LOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); - FedPlanVariants FOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); - - // # of child, LOUT/FOUT of child - double[][] childCumulativeCost = new double[numInputs][2]; - // # of child - double[] childForwardingCost = new double[numInputs]; - - FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childCumulativeCost, childForwardingCost); - - for (int i = 0; i < (1 << numInputs); i++) { - List> planChilds = new ArrayList<>(); - double lOutCumulativeCost = selfCost; - double fOutCumulativeCost = selfCost; - - // For each input, determine if it should be FOUT or 
LOUT based on bit j in i - for (int j = 0; j < numInputs; j++) { - Hop inputHop = hop.getInput().get(j); - final int bit = (i & (1 << j)) != 0 ? 1 : 0; // bit 값 계산 (FOUT/LOUT 결정) - final FederatedOutput childType = (bit == 1) ? FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - - lOutCumulativeCost += childCumulativeCost[j][bit]; - fOutCumulativeCost += childCumulativeCost[j][bit]; - // 비트 기반 산술 연산을 사용하여 전달 비용 추가 - fOutCumulativeCost += childForwardingCost[j] * (1 - bit); // bit == 0일 때 활성화 - lOutCumulativeCost += childForwardingCost[j] * bit; // bit == 1일 때 활성화 - } - LOutFedPlanVariants.addFedPlan(new FedPlan(lOutCumulativeCost, LOutFedPlanVariants, planChilds)); - FOutFedPlanVariants.addFedPlan(new FedPlan(fOutCumulativeCost, FOutFedPlanVariants, planChilds)); + // Generate all possible input combinations using binary representation + // i represents a specific combination of FOUT/LOUT for inputs + for (int i = 0; i < (1 << numInputs); i++) { + List> planChilds = new ArrayList<>(); + + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInputs; j++) { + Hop inputHop = hop.getInput().get(j); + // If bit j is set (1), use FOUT; otherwise use LOUT + FederatedOutput childType = ((i & (1 << j)) != 0) ? 
+ FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); } - LOutFedPlanVariants.pruneFedPlans(); - FOutFedPlanVariants.pruneFedPlans(); - - memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, LOutFedPlanVariants); - memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, FOutFedPlanVariants); - } else { - List conflictMergeResolveInfos = conflictMergeResolveMap.get(hopID); - conflictMergeResolveInfos.sort(Comparator.comparingLong(ConflictMergeResolveInfo::getConflictedHopID)); - - ConflictedFedPlanVariants LOutFedPlanVariants = new ConflictedFedPlanVariants(hopCommon, FederatedOutput.LOUT, conflictMergeResolveInfos); - ConflictedFedPlanVariants FOutFedPlanVariants = new ConflictedFedPlanVariants(hopCommon, FederatedOutput.FOUT, conflictMergeResolveInfos); - int numOfConflictCombinations = 1 << conflictMergeResolveInfos.size(); - double mergeCost = FederatedPlanCostEstimator.computeMergeCost(conflictMergeResolveInfos, memoTable); - selfCost += mergeCost; - - // 2^(# of conflicts), # of childs, LOUT/FOUT of child - double[][][] childCumulativeCost = new double[numOfConflictCombinations][numInputs][2]; - int[][][] childForwardingBitMap = new int[numOfConflictCombinations][numInputs][2]; - double[] childForwardingCost = new double[numInputs]; // # of childs - - FederatedPlanCostEstimator.getConflictedChildCosts(hopCommon, memoTable, conflictMergeResolveInfos, childCumulativeCost, childForwardingBitMap, childForwardingCost); - - for (int i = 0; i < (1 << numInputs); i++) { - List> planChilds = new ArrayList<>(); - - for (int j = 0; j < numOfConflictCombinations; j++) { - LOutFedPlanVariants.cumulativeCost[j][i] = selfCost; - FOutFedPlanVariants.cumulativeCost[j][i] = selfCost; - } - - for (int j = 0; j < numInputs; j++) { - Hop inputHop = hop.getInput().get(j); - - final int bit = (i & (1 << j)) != 0 ? 1 : 0; // bit 값 계산 (FOUT/LOUT 결정) - final FederatedOutput childType = (bit == 1) ? 
FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - - for (int k = 0; k < numOfConflictCombinations; k++) { - // 비트 기반 인덱스를 사용하여 누적 비용 업데이트 - LOutFedPlanVariants.cumulativeCost[k][i] += childCumulativeCost[k][j][bit]; - FOutFedPlanVariants.cumulativeCost[k][i] += childCumulativeCost[k][j][bit]; - - // 비트 기반 산술 연산을 사용하여 전달 비용 추가 - FOutFedPlanVariants.cumulativeCost[k][i] += childForwardingCost[j] * (1 - bit); // bit == 0일 때 활성화 - LOutFedPlanVariants.cumulativeCost[k][i] += childForwardingCost[j] * bit; // bit == 1일 때 활성화 - - if (mergeCost != 0) { - FederatedPlanCostEstimator.computeForwardingMergeCost(LOutFedPlanVariants.forwardingBitMap[k][i], - childForwardingBitMap[k][j][bit], conflictMergeResolveInfos, memoTable); - } + // Create and evaluate FOUT variant for current input combination + FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); - LOutFedPlanVariants.forwardingBitMap[k][i] |= childForwardingBitMap[k][j][bit]; - FOutFedPlanVariants.forwardingBitMap[k][i] |= childForwardingBitMap[k][j][bit]; - } - } - LOutFedPlanVariants.addFedPlan(new FedPlan(0, LOutFedPlanVariants, planChilds)); - FOutFedPlanVariants.addFedPlan(new FedPlan(0, FOutFedPlanVariants, planChilds)); - } - LOutFedPlanVariants.pruneConflictedFedPlans(); - FOutFedPlanVariants.pruneConflictedFedPlans(); - - memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, LOutFedPlanVariants); - memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, FOutFedPlanVariants); + // Create and evaluate LOUT variant for current input combination + FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); } + + // Prune MemoTable for hop. 
+ memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT); + memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT); } /** @@ -480,14 +215,21 @@ private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoT * @return ? */ private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { - FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(HopID, FederatedOutput.LOUT); - FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(HopID, FederatedOutput.FOUT); - - if (lOutFedPlan.getCumulativeCost() < fOutFedPlan.getCumulativeCost()){ - return lOutFedPlan; - } else{ - return fOutFedPlan; + FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); + FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); + + FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() + < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { + return minFOutFedPlan; } + return minlOutFedPlan; } /** @@ -574,47 +316,4 @@ private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, Federate // Return the cumulative additional cost for resolving conflicts return cumulativeAdditionalCost[0]; } - - /** - * Data structure to store conflict and merge information for a specific Hop. - * This class maintains the state of conflict resolution and merge operations - * for a given Hop in the execution plan. 
- */ - public static class ConflictMergeResolveInfo { - private long conflictedHopID; // ID of the Hop that originated the conflict - private boolean isMerged; - private boolean isSplited; - private ResolvedType resolvedType; - - public ConflictMergeResolveInfo(long conflictedHopID, boolean isMerged, boolean isSplited, ResolvedType resolvedType) { - this.conflictedHopID = conflictedHopID; - this.isMerged = isMerged; - this.isSplited = isSplited; - this.resolvedType = resolvedType; - } - - public long getConflictedHopID() { - return conflictedHopID; - } - - public boolean isMerged() { - return isMerged; - } - - public boolean isSplited() { - return isSplited; - } - - public ResolvedType getResolvedType() { - return resolvedType; - } - } - - public static enum ResolvedType { - INNER_PATH, - OUTER_PATH, - FIRST_SPLIT_LAST_MERGE, // 첫 분기점 또는 마지막 - RESOLVE, // 해결 지점 - TOP // 최상위 지점 - }; } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index 3ae8b37a82c..f48332ac752 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -22,13 +22,8 @@ import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.cost.ComputeCost; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; -import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator.ConflictMergeResolveInfo; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.ConflictedFedPlanVariants; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; -import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.NoSuchElementException; import java.util.List; @@ -47,138 +42,43 
@@ public class FederatedPlanCostEstimator { // Network bandwidth for data transfers between federated sites (1 Gbps) private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; - public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, double[][] childCumulativeCost, double[] childForwardingCost) { - List inputHops = hopCommon.hopRef.getInput(); - - for (int i = 0; i < inputHops.size(); i++) { - long childHopID = inputHops.get(i).getHopID(); - - FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); - FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); - - childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); - childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); - childForwardingCost[i] = childLOutFedPlan.getForwardingCost(); - } - } - - public static void getConflictedChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List conflictMergeResolveInfos, - double[][][] childCumulativeCost, int[][][] childForwardingBitMap, double[] childForwardingCost) { - List inputHops = hopCommon.hopRef.getInput(); - int numConflictCombinations = 1 << conflictMergeResolveInfos.size(); - - for (int i = 0; i < inputHops.size(); i++) { - long childHopID = inputHops.get(i).getHopID(); - - FedPlanVariants childLOutVariants = memoTable.getFedPlanVariants(childHopID, FederatedOutput.LOUT); - FedPlanVariants childFOutVariants = memoTable.getFedPlanVariants(childHopID, FederatedOutput.FOUT); - - childForwardingCost[i] = childLOutVariants.getForwardingCost(); - - if (childLOutVariants instanceof ConflictedFedPlanVariants) { - FedPlan childLOutFedPlan = childLOutVariants.getFedPlanVariants().get(0); - FedPlan childFOutFedPlan = childFOutVariants.getFedPlanVariants().get(0); - - for (int j = 0; j < numConflictCombinations; j++) { - childCumulativeCost[j][i][0] = childLOutFedPlan.getCumulativeCost(); - childCumulativeCost[j][i][1] = 
childFOutFedPlan.getCumulativeCost(); - } - } - else { - ConflictedFedPlanVariants conflictedChildLOutVariants = (ConflictedFedPlanVariants) childLOutVariants; - ConflictedFedPlanVariants conflictedChildFOutVariants = (ConflictedFedPlanVariants) childFOutVariants; - - computeConflictedChildCosts(conflictMergeResolveInfos, conflictedChildLOutVariants, childCumulativeCost, childForwardingBitMap, i, 0); - computeConflictedChildCosts(conflictMergeResolveInfos, conflictedChildFOutVariants, childCumulativeCost, childForwardingBitMap, i, 1); - } - } - } - - private static void computeConflictedChildCosts(List conflictInfos, ConflictedFedPlanVariants conflictedChildVariants, - double[][][] childCumulativeCost, int[][][] childForwardingBitMap, int childIdx, int fedOutTypeIdx){ - int i = 0, j = 0; - int pLen = conflictInfos.size(); - int cLen = conflictedChildVariants.conflictInfos.size(); - int numConflictCombinations = 1 << conflictInfos.size(); - - // Step 1: 공통 제약 조건과 비공통 자식 위치 계산 - List common = new ArrayList<>(); - List nonCommonChildPos = new ArrayList<>(); - - while (i < pLen && j < cLen) { - long pHopID = conflictInfos.get(i).getConflictedHopID(); - long cHopID = conflictedChildVariants.conflictInfos.get(j).getConflictedHopID(); - - if (pHopID == cHopID) { - int pBitPos = pLen - 1 - i; - int cBitPos = cLen - 1 - j; - common.add(new CommonConstraint(pHopID, pBitPos, cBitPos)); - i++; - j++; - } else if (pHopID < cHopID) { - i++; - } else { - int cBitPos = cLen - 1 - j; - nonCommonChildPos.add(cBitPos); - j++; - } - } - - int restNumBits = nonCommonChildPos.size(); - for (int parentIdx = 0; parentIdx < numConflictCombinations; parentIdx++) { - // 공통 제약 조건을 기반으로 baseChildIdx 계산 - int baseChildIdx = 0; - for (CommonConstraint cc : common) { - int bit = (parentIdx >> cc.pBitPos) & 1; - baseChildIdx |= (bit << cc.cBitPos); - } - - // 최소 비용을 가진 자식 인덱스 찾기 - double minChildCost = Double.MAX_VALUE; - int minChildIdx = -1; - for (int restValue = 0; restValue < (1 << 
restNumBits); restValue++) { - int temp = 0; - for (int bitIdx = 0; bitIdx < restNumBits; bitIdx++) { - if (((restValue >> bitIdx) & 1) == 1) { - temp |= (1 << nonCommonChildPos.get(bitIdx)); - } - } - int tempChildIdx = baseChildIdx | temp; - if (conflictedChildVariants.cumulativeCost[tempChildIdx][0] < minChildCost) { - minChildCost = conflictedChildVariants.cumulativeCost[tempChildIdx][0]; - minChildIdx = tempChildIdx; - } - } - - // 자식의 isForwardBitMap을 부모의 비트 위치로 변환 - int childForwardBitMap = conflictedChildVariants.forwardingBitMap[minChildIdx][0]; - int convertedBitmask = 0; - for (CommonConstraint cc : common) { - int childBit = (childForwardBitMap >> cc.cBitPos) & 1; - if (childBit == 1) { - convertedBitmask |= (1 << cc.pBitPos); - } - } - - childCumulativeCost[parentIdx][childIdx][fedOutTypeIdx] = minChildCost; - childForwardingBitMap[parentIdx][childIdx][fedOutTypeIdx] = convertedBitmask; + /** + * Computes total cost of federated plan by: + * 1. Computing current node cost (if not cached) + * 2. Adding minimum-cost child plans + * 3. 
Including network transfer costs when needed + * + * @param currentPlan Plan to compute cost for + * @param memoTable Table containing all plan variants + */ + public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { + double totalCost; + Hop currentHop = currentPlan.getHopRef(); + + // Step 1: Calculate current node costs if not already computed + if (currentPlan.getSelfCost() == 0) { + // Compute cost for current node (computation + memory access) + totalCost = computeCurrentCost(currentHop); + currentPlan.setSelfCost(totalCost); + // Calculate potential network transfer cost if federation type changes + currentPlan.setForwardingCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); + } else { + totalCost = currentPlan.getSelfCost(); } - } - - // Todo: (최적화) 추후에 MemoTable retrieve 하지 않게 최적화 가능 - public static double computeForwardingMergeCost(int parentBitmask, int childBitmask, List conflictInfos, FederatedMemoTable memoTable){ - int overlappingBits = parentBitmask & childBitmask; - double overlappingForwardingCost = 0.0; - - int pLen = conflictInfos.size(); - for (int b = 0; b < pLen; b++) { - int bitPos = pLen - 1 - b; - if ((overlappingBits & (1 << bitPos)) != 0) { - overlappingForwardingCost += memoTable.getFedPlanVariants(conflictInfos.get(b).getConflictedHopID(), FederatedOutput.LOUT).getForwardingCost(); - } + + // Step 2: Process each child plan and add their costs + for (Pair childPlanPair : currentPlan.getChildFedPlans()) { + // Find minimum cost child plan considering federation type compatibility + // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents + // because we're selecting child plans independently for each parent + FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair); + + // Add child plan cost (includes network transfer cost if federation types differ) + totalCost += planRef.getTotalCost() + 
planRef.getCondForwardingCost(currentPlan.getFedOutType()); } - return overlappingForwardingCost; + // Step 3: Set final cumulative cost including current node + currentPlan.setTotalCost(totalCost); } /** @@ -211,8 +111,8 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe // Flags to check if the plan involves network transfer // Network transfer cost is calculated only once, even if it occurs multiple times - boolean isLOutNetTransfer = false; - boolean isFOutNetTransfer = false; + boolean isLOutForwarding = false; + boolean isFOutForwarding = false; // Determine the optimal federated output type based on the calculated costs FederatedOutput optimalFedOutType; @@ -238,40 +138,40 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
- fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); + fOutAdditionalCost += confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later - isFOutNetTransfer = true; + isFOutForwarding = true; } else { // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later - isLOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + isLOutForwarding = true; + lOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); } } else { - lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); + lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { - isLOutNetTransfer = true; + isLOutForwarding = true; } else { - isFOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + isFOutForwarding = true; + lOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); } } } // Add network transfer costs if applicable - if (isLOutNetTransfer) { - lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); + if 
(isLOutForwarding) { + lOutAdditionalCost += confilctLOutFedPlan.setForwardingCost(); } - if (isFOutNetTransfer) { - fOutAdditionalCost += confilctFOutFedPlan.getForwardingCost(); + if (isFOutForwarding) { + fOutAdditionalCost += confilctFOutFedPlan.setForwardingCost(); } // Determine the optimal federated output type based on the calculated costs @@ -299,36 +199,14 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe } return resolvedFedPlanLinkedMap; } - - // Todo: (구현) forwarding bitmap을 본 뒤, merge cost 일일히 type에 따라 계산해야함. - public static double computeMergeCost(List conflictMergeResolveInfos, FederatedMemoTable memoTable){ - double mergeCost = 0; - - for (ConflictMergeResolveInfo conflictInfo: conflictMergeResolveInfos){ - int numOfMergedHops = conflictInfo.getNumOfMergedHops(); - - if (numOfMergedHops != 0){ - double selfCost = memoTable.getFedPlanVariants(conflictInfo.getConflictedHopID(), FederatedOutput.LOUT).getSelfCost(); - mergeCost += selfCost * numOfMergedHops; - } - } - - return mergeCost; - } - - public static void computeHopCost(HopCommon hopCommon){ - Hop hop = hopCommon.hopRef; - hopCommon.setSelfCost(computeSelfCost(hop)); - hopCommon.setForwardingCost(computeHopForwardingCost(hop.getOutputMemEstimate())); - } - + /** * Computes the cost for the current Hop node. 
* * @param currentHop The Hop node whose cost needs to be computed * @return The total cost for the current node's operation */ - private static double computeSelfCost(Hop currentHop){ + private static double computeCurrentCost(Hop currentHop){ double computeCost = ComputeCost.getHOPComputeCost(currentHop); double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); @@ -356,19 +234,7 @@ private static double computeHopMemoryAccessCost(double memSize) { * @param memSize Size of data to be transferred (in bytes) * @return Time cost for network transfer (in seconds) */ - private static double computeHopForwardingCost(double memSize) { + private static double computeHopNetworkAccessCost(double memSize) { return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; } - - public static class CommonConstraint { - long name; - int pBitPos; - int cBitPos; - - CommonConstraint(long name, int pBitPos, int cBitPos) { - this.name = name; - this.pBitPos = pBitPos; - this.cBitPos = cBitPos; - } - } } diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 20485588d32..9d69067a987 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -45,7 +45,7 @@ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase @Override public void setUp() {} - + @Test public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } @@ -55,6 +55,21 @@ public void setUp() {} @Test public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } + @Test + public void testFederatedPlanCostEnumerator4() { 
runTest("FederatedPlanCostEnumeratorTest4.dml"); } + + @Test + public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); } + + @Test + public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); } + + @Test + public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } + + @Test + public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } + // Todo: Need to write test scripts for the federated version private void runTest( String scriptFilename ) { int index = scriptFilename.lastIndexOf(".dml"); @@ -80,8 +95,7 @@ private void runTest( String scriptFilename ) { dmlt.rewriteHopsDAG(prog); dmlt.constructLops(prog); - Hop hops = prog.getStatementBlocks().get(0).getHops().get(0); - FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true); + FederatedPlanCostEnumerator.enumerateProgram(prog); } catch (IOException e) { e.printStackTrace(); diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml index 2721bbcbaf6..19b65223305 100644 --- a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml @@ -19,7 +19,7 @@ # #------------------------------------------------------------- -for( i in 1:100 ) +for( i in 1:10 ) { b = i + 1; print(b); diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml index b95ae1b5bb0..4a0ca5eaa72 100644 --- a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml @@ -19,16 +19,13 @@ # 
#------------------------------------------------------------- -A = matrix(7, rows=10, cols=10) -b = rand(rows = 1, cols = ncol(A), min = 1, max = 2); +A = matrix(7,10,10); +b = rand(rows = 1, cols = ncol(A), min = 1, max = 2) +d = sum(b) * 8 i = 0 - -while (sum(b) < i) { - i = i + 1 - b = b + i - A = A * A - s = b %*% A - print(mean(s)) +while(sum(b) < d){ + i = i + 1 + b = b + i + s = b %*% A + print(mean(s)) } -c = sqrt(A) -print(sum(c)) \ No newline at end of file From fd9479dcad561b601a3a2b339308eb163d470b7e Mon Sep 17 00:00:00 2001 From: min-guk Date: Tue, 11 Feb 2025 18:04:00 +0900 Subject: [PATCH 7/9] program level fed planer --- .../FederatedPlanCostEnumerator.java | 55 +++++++++++++------ .../FederatedPlanCostEnumeratorTest.java | 9 ++- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index f626e27c1bc..692522adbde 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -28,6 +28,8 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.DataOp; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; @@ -50,8 +52,11 @@ */ public class FederatedPlanCostEnumerator { public static void enumerateProgram(DMLProgram prog) { + FederatedMemoTable memoTable = new FederatedMemoTable(); + Map transTable = new HashMap<>(); + for(StatementBlock sb : prog.getStatementBlocks()) - enumerateStatementBlock(sb); + enumerateStatementBlock(sb, memoTable, transTable); } /** @@ -61,7 +66,7 @@ public static void enumerateProgram(DMLProgram prog) { * * 
@param sb The statement block to enumerate. */ - public static void enumerateStatementBlock(StatementBlock sb) { + public static void enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map transTable) { // While enumerating the program, recursively determine the optimal FedPlan and MemoTable // for each statement block and statement. // 1. How to recursively integrate optimal FedPlans and MemoTables across statements and statement blocks? @@ -75,12 +80,12 @@ public static void enumerateStatementBlock(StatementBlock sb) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); - enumerateFederatedPlanCost(isb.getPredicateHops()); + enumerateHopDAG(isb.getPredicateHops(), memoTable, transTable); for (StatementBlock csb : istmt.getIfBody()) - enumerateStatementBlock(csb); + enumerateStatementBlock(csb, memoTable, transTable); for (StatementBlock csb : istmt.getElseBody()) - enumerateStatementBlock(csb); + enumerateStatementBlock(csb, memoTable, transTable); // Todo: 1. apply iteration weight to csbFedPlans (if: 0.5, else: 0.5) // Todo: 2. Merge predFedPlans @@ -89,12 +94,12 @@ public static void enumerateStatementBlock(StatementBlock sb) { ForStatement fstmt = (ForStatement)fsb.getStatement(0); - enumerateFederatedPlanCost(fsb.getFromHops()); - enumerateFederatedPlanCost(fsb.getToHops()); - enumerateFederatedPlanCost(fsb.getIncrementHops()); + enumerateHopDAG(fsb.getFromHops(), memoTable, transTable); + enumerateHopDAG(fsb.getToHops(), memoTable, transTable); + enumerateHopDAG(fsb.getIncrementHops(), memoTable, transTable); for (StatementBlock csb : fstmt.getBody()) - enumerateStatementBlock(csb); + enumerateStatementBlock(csb, memoTable, transTable); // Todo: 1. get(predict) # of Iterations // Todo: 2. 
apply iteration weight to csbFedPlans @@ -102,11 +107,11 @@ public static void enumerateStatementBlock(StatementBlock sb) { } else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - enumerateFederatedPlanCost(wsb.getPredicateHops()); + enumerateHopDAG(wsb.getPredicateHops(), memoTable, transTable); ArrayList csbFedPlans = new ArrayList<>(); for (StatementBlock csb : wstmt.getBody()) - enumerateStatementBlock(csb); + enumerateStatementBlock(csb, memoTable, transTable); // Todo: 1. get(predict) # of Iterations // Todo: 2. apply iteration weight to csbFedPlans @@ -115,13 +120,13 @@ public static void enumerateStatementBlock(StatementBlock sb) { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (StatementBlock csb : fstmt.getBody()) - enumerateStatementBlock(csb); + enumerateStatementBlock(csb, memoTable, transTable); // Todo: 1. Merge csbFedPlans } else { //generic (last-level) if( sb.getHops() != null ) for( Hop c : sb.getHops() ) - enumerateFederatedPlanCost(c); + enumerateHopDAG(c, memoTable, transTable); } } @@ -133,12 +138,11 @@ public static void enumerateStatementBlock(StatementBlock sb) { * @param rootHop The root Hop node from which to start the plan enumeration. * @return The optimal FedPlan with the minimum cost for the entire DAG. 
*/ - public static FedPlan enumerateFederatedPlanCost(Hop rootHop) { + public static FedPlan enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map transTable) { // Create new memo table to store all plan variants - FederatedMemoTable memoTable = new FederatedMemoTable(); // Recursively enumerate all possible plans - enumerateFederatedPlanCost(rootHop, memoTable); + enumerateHop(rootHop, memoTable, transTable); // Return the minimum cost plan for the root node FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); @@ -167,14 +171,29 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop) { * @param hop ? * @param memoTable ? */ - private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { + private static void enumerateHop(Hop hop, FederatedMemoTable memoTable, Map transTable) { int numInputs = hop.getInput().size(); // Process all input nodes first if not already in memo table for (Hop inputHop : hop.getInput()) { if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { - enumerateFederatedPlanCost(inputHop, memoTable); + enumerateHop(inputHop, memoTable, transTable); + } + } + + if (hop instanceof DataOp + && ((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTWRITE + && !(hop.getName().equals("__pred"))){ + transTable.put(hop.getName(), hop.getHopID()); + } + + if (hop instanceof DataOp + && !(hop.getName().equals("__pred"))){ + if (((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTWRITE){ + transTable.put(hop.getName(), hop.getHopID()); + } else if (((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTREAD){ + long rWriteHopID = transTable.get(hop.getName()); } } diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 9d69067a987..8fd17998e96 100644 --- 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -20,9 +20,15 @@ package org.apache.sysds.test.component.federated; import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; +import org.apache.sysds.common.Types; import org.apache.sysds.hops.Hop; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.functions.federated.algorithms.FederatedL2SVMTest; import org.junit.Assert; import org.junit.Test; import org.apache.sysds.api.DMLScript; @@ -67,9 +73,6 @@ public void setUp() {} @Test public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } - @Test - public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } - // Todo: Need to write test scripts for the federated version private void runTest( String scriptFilename ) { int index = scriptFilename.lastIndexOf(".dml"); From 879dc4857d75c472d5e81179b0e3d99ad027b58f Mon Sep 17 00:00:00 2001 From: min-guk Date: Tue, 25 Feb 2025 04:15:08 +0900 Subject: [PATCH 8/9] program level fed planer --- graph.py | 247 ++++++++ .../hops/fedplanner/FederatedMemoTable.java | 171 ++---- .../fedplanner/FederatedMemoTablePrinter.java | 133 +++-- .../FederatedPlanCostEnumerator.java | 535 +++++++++++++----- .../FederatedPlanCostEstimator.java | 178 +++--- .../FederatedPlanCostEnumeratorTest.java | 18 +- .../FederatedPlanCostEnumeratorTest5.dml | 2 +- .../FederatedPlanCostEnumeratorTest6.dml | 19 +- 8 files changed, 909 insertions(+), 394 deletions(-) create mode 100644 graph.py diff --git a/graph.py b/graph.py new file mode 100644 index 00000000000..7b0ba6c7a79 --- /dev/null +++ b/graph.py @@ -0,0 +1,247 @@ +import sys +import re +import networkx as nx +import matplotlib.pyplot as plt + 
+try: + import pygraphviz + from networkx.drawing.nx_agraph import graphviz_layout + HAS_PYGRAPHVIZ = True +except ImportError: + HAS_PYGRAPHVIZ = False + print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n" + "If not installed, we will use an alternative layout (spring_layout).") + + +def parse_line(line: str): + """ + Parse a single line from the trace file to extract: + - Node ID + - Operation (hop name) + - Kind (e.g., FOUT, LOUT, NREF) + - Total cost + - Weight + - Refs (list of IDs that this node depends on) + """ + + # 1) Match a node ID in the form of "(R)" or "()" + match_id = re.match(r'^\((R|\d+)\)', line) + if not match_id: + return None + node_id = match_id.group(1) + + # 2) The remaining string after the node ID + after_id = line[match_id.end():].strip() + + # Extract operation (hop name) before the first "[" + match_label = re.search(r'^(.*?)\s*\[', after_id) + if match_label: + operation = match_label.group(1).strip() + else: + operation = after_id.strip() + + # 3) Extract the kind (content inside the first pair of brackets "[]") + match_bracket = re.search(r'\[([^\]]+)\]', after_id) + if match_bracket: + kind = match_bracket.group(1).strip() + else: + kind = "" + + # 4) Extract total and weight from the content inside curly braces "{}" + total = "" + weight = "" + match_curly = re.search(r'\{([^}]+)\}', line) + if match_curly: + curly_content = match_curly.group(1) + m_total = re.search(r'Total:\s*([\d\.]+)', curly_content) + m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content) + if m_total: + total = m_total.group(1) + if m_weight: + weight = m_weight.group(1) + + # 5) Extract reference nodes: look for the first parenthesis containing numbers after the hop name + match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id) + if match_refs: + ref_str = match_refs.group(1) + refs = [r.strip() for r in ref_str.split(',') if r.strip().isdigit()] + else: + refs = [] + + return { + 'node_id': node_id, + 
'operation': operation, + 'kind': kind, + 'total': total, + 'weight': weight, + 'refs': refs + } + + +def build_dag_from_file(filename: str): + """ + Read a trace file line by line and build a directed acyclic graph (DAG) using NetworkX. + """ + G = nx.DiGraph() + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + info = parse_line(line) + if not info: + continue + + node_id = info['node_id'] + operation = info['operation'] + kind = info['kind'] + total = info['total'] + weight = info['weight'] + refs = info['refs'] + + # Add node with attributes + G.add_node(node_id, label=operation, kind=kind, total=total, weight=weight) + + # Add edges from references to this node + for r in refs: + if r not in G: + G.add_node(r, label=r, kind="", total="", weight="") + G.add_edge(r, node_id) + return G + + +def main(): + """ + Main function that: + - Reads a filename from command-line arguments + - Builds a DAG from the file + - Draws and displays the DAG using matplotlib + """ + + # Get filename from command-line argument + if len(sys.argv) < 2: + print("[ERROR] No filename provided.\nUsage: python plot_federated_dag.py ") + sys.exit(1) + filename = sys.argv[1] + + print(f"[INFO] Running with filename '{filename}'") + + # Build the DAG + G = build_dag_from_file(filename) + + # Print debug info: nodes and edges + print("Nodes:", G.nodes(data=True)) + print("Edges:", list(G.edges())) + + # Decide on layout + if HAS_PYGRAPHVIZ: + # graphviz_layout with rankdir=BT (bottom to top), etc. 
+ pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8') + else: + # Fallback layout if pygraphviz is not installed + pos = nx.spring_layout(G, seed=42) + + # Dynamically adjust figure size based on number of nodes + node_count = len(G.nodes()) + fig_width = 10 + node_count / 10.0 + fig_height = 6 + node_count / 10.0 + plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300) + ax = plt.gca() + ax.set_facecolor('white') + + # Generate labels for each node in the format: + # node_id: operation_name + # C (W) + labels = { + n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})" + for n in G.nodes() + } + + # Function to determine color based on 'kind' + def get_color(n): + k = G.nodes[n].get('kind', '').lower() + if k == 'fout': + return 'tomato' + elif k == 'lout': + return 'dodgerblue' + elif k == 'nref': + return 'mediumpurple' + else: + return 'mediumseagreen' + + # Determine node shapes based on operation name: + # - '^' (triangle) if the label contains "twrite" + # - 's' (square) if the label contains "tread" + # - 'o' (circle) otherwise + triangle_nodes = [n for n in G.nodes() if 'twrite' in G.nodes[n].get('label', '').lower()] + square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label', '').lower()] + other_nodes = [ + n for n in G.nodes() + if 'twrite' not in G.nodes[n].get('label', '').lower() and + 'tread' not in G.nodes[n].get('label', '').lower() + ] + + # Colors for each group + triangle_colors = [get_color(n) for n in triangle_nodes] + square_colors = [get_color(n) for n in square_nodes] + other_colors = [get_color(n) for n in other_nodes] + + # Draw nodes group-wise + node_collection_triangle = nx.draw_networkx_nodes( + G, pos, nodelist=triangle_nodes, node_size=800, + node_color=triangle_colors, node_shape='^', ax=ax + ) + node_collection_square = nx.draw_networkx_nodes( + G, pos, nodelist=square_nodes, node_size=800, + 
node_color=square_colors, node_shape='s', ax=ax + ) + node_collection_other = nx.draw_networkx_nodes( + G, pos, nodelist=other_nodes, node_size=800, + node_color=other_colors, node_shape='o', ax=ax + ) + + # Set z-order for nodes, edges, and labels + node_collection_triangle.set_zorder(1) + node_collection_square.set_zorder(1) + node_collection_other.set_zorder(1) + + edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax) + if isinstance(edge_collection, list): + for ec in edge_collection: + ec.set_zorder(2) + else: + edge_collection.set_zorder(2) + + label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax) + for text in label_dict.values(): + text.set_zorder(3) + + # Set the title + plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold") + + # Provide a small legend on the top-right or top-left + plt.text(1, 1, + "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))", + fontsize=12, ha='right', va='top', transform=ax.transAxes) + + # Example mini-legend for different 'kind' values + plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes) + plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes) + plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes) + + plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes) + plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes) + plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes) + + plt.axis("off") + + # Save the plot to a file with the same name as the input file, but with a .png extension + output_filename = f"{filename.rsplit('.', 1)[0]}.png" + plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight') + + plt.show() + + +if __name__ == '__main__': + main() diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java 
index 82d05e4f286..b35723b8173 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -19,15 +19,15 @@ package org.apache.sysds.hops.fedplanner; -import org.apache.sysds.hops.Hop; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.ArrayList; import java.util.Map; +import org.apache.sysds.hops.Hop; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; /** * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. @@ -38,45 +38,8 @@ public class FederatedMemoTable { // Maps Hop ID and fedOutType pairs to their plan variants private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); - /** - * Adds a new federated plan to the memo table. - * Creates a new variant list if none exists for the given Hop and fedOutType. 
- * - * @param hop The Hop node - * @param fedOutType The federated output type - * @param planChilds List of child plan references - * @return The newly created FedPlan - */ - public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { - long hopID = hop.getHopID(); - FedPlanVariants fedPlanVariantList; - - if (contains(hopID, fedOutType)) { - fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - } else { - fedPlanVariantList = new FedPlanVariants(hop, fedOutType); - hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); - } - - FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); - fedPlanVariantList.addFedPlan(newPlan); - - return newPlan; - } - - /** - * Retrieves the minimum cost child plan considering the parent's output type. - * The cost is calculated using getParentViewCost to account for potential type mismatches. - */ - public FedPlan getMinCostFedPlan(Pair fedPlanPair) { - FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); - return fedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - } - - public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { - return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + public void addFedPlanVariants(long hopID, FederatedOutput fedOutType, FedPlanVariants fedPlanVariants) { + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariants); } public FedPlanVariants getFedPlanVariants(Pair fedPlanPair) { @@ -84,53 +47,47 @@ public FedPlanVariants getFedPlanVariants(Pair fedPlanPai } public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { - // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); return fedPlanVariantList._fedPlanVariants.get(0); } public FedPlan 
getFedPlanAfterPrune(Pair fedPlanPair) { - // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); return fedPlanVariantList._fedPlanVariants.get(0); } - /** - * Checks if the memo table contains an entry for a given Hop and fedOutType. - * - * @param hopID The Hop ID. - * @param fedOutType The associated fedOutType. - * @return True if the entry exists, false otherwise. - */ public boolean contains(long hopID, FederatedOutput fedOutType) { return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); } /** - * Prunes the specified entry in the memo table, retaining only the minimum-cost - * FedPlan for the given Hop ID and federated output type. - * - * @param hopID The ID of the Hop to prune - * @param federatedOutput The federated output type associated with the Hop - */ - public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) { - hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune(); - } - - /** - * Represents common properties and costs associated with a Hop. - * This class holds a reference to the Hop and tracks its execution and network transfer costs. + * Represents a single federated execution plan with its associated costs and dependencies. + * This class contains: + * 1. selfCost: Cost of the current hop (computation + input/output memory access). + * 2. cumulativeCost: Total cost including this plan's selfCost and all child plans' cumulativeCost. + * 3. forwardingCost: Network transfer cost for this plan to the parent plan. + * + * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. 
*/ - public static class HopCommon { - protected final Hop hopRef; // Reference to the associated Hop - protected double selfCost; // Current execution cost (compute + memory access) - protected double forwardingCost; // Network transfer cost + public static class FedPlan { + private double cumulativeCost; // Total cost = sum of selfCost + cumulativeCost of child plans + private final FedPlanVariants fedPlanVariants; // Reference to variant list + private final List> childFedPlans; // Child plan references - protected HopCommon(Hop hopRef) { - this.hopRef = hopRef; - this.selfCost = 0; - this.forwardingCost = 0; + public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List> childFedPlans) { + this.cumulativeCost = cumulativeCost; + this.fedPlanVariants = fedPlanVariants; + this.childFedPlans = childFedPlans; } + + public Hop getHopRef() {return fedPlanVariants.hopCommon.getHopRef();} + public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();} + public FederatedOutput getFedOutType() {return fedPlanVariants.getFedOutType();} + public double getCumulativeCost() {return cumulativeCost;} + public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();} + public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();} + public double getWeight() {return fedPlanVariants.hopCommon.getWeight();} + public List> getChildFedPlans() {return childFedPlans;} } /** @@ -143,21 +100,22 @@ public static class FedPlanVariants { private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) protected List _fedPlanVariants; // List of plan variants - public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { - this.hopCommon = new HopCommon(hopRef); + public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) { + this.hopCommon = hopCommon; this.fedOutType = fedOutType; this._fedPlanVariants = new ArrayList<>(); } + public boolean isEmpty() {return _fedPlanVariants.isEmpty();} 
public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} public List getFedPlanVariants() {return _fedPlanVariants;} - public boolean isEmpty() {return _fedPlanVariants.isEmpty();} + public FederatedOutput getFedOutType() {return fedOutType;} - public void prune() { + public void pruneFedPlans() { if (_fedPlanVariants.size() > 1) { - // Find the FedPlan with the minimum cost + // Find the FedPlan with the minimum cumulative cost FedPlan minCostPlan = _fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) .orElse(null); // Retain only the minimum cost plan @@ -168,47 +126,28 @@ public void prune() { } /** - * Represents a single federated execution plan with its associated costs and dependencies. - * This class contains: - * 1. selfCost: Cost of current hop (compute + input/output memory access) - * 2. totalCost: Cumulative cost including this plan and all child plans - * 3. forwardingCost: Network transfer cost for this plan to parent plan. - * - * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. + * Represents common properties and costs associated with a Hop. + * This class holds a reference to the Hop and tracks its execution and network forwarding (transfer) costs. 
*/ - public static class FedPlan { - private double totalCost; // Total cost including child plans - private final FedPlanVariants fedPlanVariants; // Reference to variant list - private final List> childFedPlans; // Child plan references + public static class HopCommon { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Cost of the hop's computation and memory access + protected double forwardingCost; // Cost of forwarding the hop's output to its parent + protected double weight; // Weight used to calculate cost based on hop execution frequency - public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { - this.totalCost = 0; - this.childFedPlans = childFedPlans; - this.fedPlanVariants = fedPlanVariants; + public HopCommon(Hop hopRef, double weight) { + this.hopRef = hopRef; + this.selfCost = 0; + this.forwardingCost = 0; + this.weight = weight; } - public void setTotalCost(double totalCost) {this.totalCost = totalCost;} - public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;} - public void setForwardingCost(double forwardingCost) {fedPlanVariants.hopCommon.forwardingCost = forwardingCost;} - public void applyIterationWeight(int iteration) {totalCost *= iteration;} - - public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;} - public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();} - public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} - public double getTotalCost() {return totalCost;} - public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;} - public double setForwardingCost() {return fedPlanVariants.hopCommon.forwardingCost;} - public List> getChildFedPlans() {return childFedPlans;} + public Hop getHopRef() {return hopRef;} + public double getSelfCost() {return selfCost;} + public double getForwardingCost() {return forwardingCost;} + public double getWeight() {return weight;} - /** - * Calculates the 
conditional network transfer cost based on output type compatibility. - * Returns 0 if output types match, otherwise returns the network transfer cost. - * @param parentFedOutType The federated output type of the parent plan. - * @return The conditional network transfer cost. - */ - public double getCondForwardingCost(FederatedOutput parentFedOutType) { - if (parentFedOutType == getFedOutType()) return 0; - return fedPlanVariants.hopCommon.forwardingCost; - } + protected void setSelfCost(double selfCost) {this.selfCost = selfCost;} + protected void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;} } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index 391868efcd7..ddddc641d2e 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -3,7 +3,9 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.runtime.instructions.fed.FEDInstruction; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.HashSet; import java.util.List; @@ -19,11 +21,48 @@ public class FederatedMemoTablePrinter { * @param memoTable The memoization table containing FedPlan variants * @param additionalTotalCost The additional cost to be printed once */ - public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, FederatedMemoTable memoTable, - double additionalTotalCost) { + public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Set rootHopStatSet, + FederatedMemoTable memoTable, double additionalTotalCost) { System.out.println("Additional Cost: " + additionalTotalCost); - Set visited 
= new HashSet<>(); + Set visited = new HashSet<>(); printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0); + + for (Hop hop : rootHopStatSet) { + FedPlan plan = memoTable.getFedPlanAfterPrune(hop.getHopID(), FederatedOutput.LOUT); + printNotReferencedFedPlanRecursive(plan, memoTable, visited, 1); + } + } + + /** + * Helper method to recursively print the FedPlan tree. + * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + */ + private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, + Set visited, int depth) { + long hopID = plan.getHopRef().getHopID(); + + if (visited.contains(hopID)) { + return; + } + + visited.add(hopID); + printFedPlan(plan, depth, true); + + // Process child nodes + List> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); i++) { + Pair childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printNotReferencedFedPlanRecursive(childPlan, memoTable, visited, depth + 1); + } + } } /** @@ -34,40 +73,83 @@ public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Fede * @param depth The current depth level for indentation */ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, - Set visited, int depth) { - if (plan == null || visited.contains(plan)) { + Set visited, int depth) { + long hopID = 0; + + if (depth == 0) { + hopID = -1; + } else { + hopID = plan.getHopRef().getHopID(); + } + + if (visited.contains(hopID)) { return; } - visited.add(plan); + visited.add(hopID); + printFedPlan(plan, depth, 
false); - Hop hop = plan.getHopRef(); - StringBuilder sb = new StringBuilder(); + // Process child nodes + List> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); i++) { + Pair childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); + } + } + } - // Add FedPlan information - sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) - .append(plan.getHopRef().getOpString()) - .append(" [") - .append(plan.getFedOutType()) - .append("]"); + private static void printFedPlan(FederatedMemoTable.FedPlan plan, int depth, boolean isNotReferenced) { + StringBuilder sb = new StringBuilder(); + Hop hop = null; + + if (depth == 0){ + sb.append("(R) ROOT [Root]"); + } else { + hop = plan.getHopRef(); + // Add FedPlan information + sb.append(String.format("(%d) ", hop.getHopID())) + .append(hop.getOpString()) + .append(" ["); + + if (isNotReferenced) { + sb.append("NRef"); + } else{ + sb.append(plan.getFedOutType()); + } + sb.append("]"); + } StringBuilder childs = new StringBuilder(); childs.append(" ("); + boolean childAdded = false; - for( Hop input : hop.getInput()){ + for (Pair childPair : plan.getChildFedPlans()){ childs.append(childAdded?",":""); - childs.append(input.getHopID()); + childs.append(childPair.getLeft()); childAdded = true; } + childs.append(")"); + if( childAdded ) sb.append(childs.toString()); + if (depth == 0){ + sb.append(String.format(" {Total: %.1f}", plan.getCumulativeCost())); + System.out.println(sb); + return; + } - sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", - plan.getTotalCost(), + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f, Weight: %.1f}", + 
plan.getCumulativeCost(), plan.getSelfCost(), - plan.setForwardingCost())); + plan.getForwardingCost(), + plan.getWeight())); // Add matrix characteristics sb.append(" [") @@ -103,18 +185,5 @@ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, F } System.out.println(sb); - - // Process child nodes - List> childFedPlanPairs = plan.getChildFedPlans(); - for (int i = 0; i < childFedPlanPairs.size(); i++) { - Pair childFedPlanPair = childFedPlanPairs.get(i); - FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); - if (childVariants == null || childVariants.isEmpty()) - continue; - - for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { - printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); - } - } } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index 692522adbde..f32bc4a76b9 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -21,18 +21,24 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Comparator; import java.util.HashMap; -import java.util.Objects; import java.util.LinkedHashMap; +import java.util.Optional; +import java.util.Set; +import java.util.HashSet; import org.apache.commons.lang3.tuple.Pair; + import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.sysds.common.Types; import org.apache.sysds.hops.DataOp; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; 
+import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.ForStatement; import org.apache.sysds.parser.ForStatementBlock; @@ -44,211 +50,440 @@ import org.apache.sysds.parser.WhileStatement; import org.apache.sysds.parser.WhileStatementBlock; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import org.apache.sysds.runtime.util.UtilFunctions; -/** - * Enumerates and evaluates all possible federated execution plans for a given Hop DAG. - * Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator - * to compute their costs. - */ public class FederatedPlanCostEnumerator { - public static void enumerateProgram(DMLProgram prog) { + private static final double DEFAULT_LOOP_WEIGHT = 10.0; + private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5; + + /** + * Enumerates the entire DML program to generate federated execution plans. + * It processes each statement block, computes the optimal federated plan, + * detects and resolves conflicts, and optionally prints the plan tree. + * + * @param prog The DML program to enumerate. + * @param isPrint A boolean indicating whether to print the federated plan tree. 
+ */ + public static void enumerateProgram(DMLProgram prog, boolean isPrint) { FederatedMemoTable memoTable = new FederatedMemoTable(); - Map transTable = new HashMap<>(); - for(StatementBlock sb : prog.getStatementBlocks()) - enumerateStatementBlock(sb, memoTable, transTable); + Map> outerTransTable = new HashMap<>(); + Map> formerInnerTransTable = new HashMap<>(); + Set progRootHopSet = new HashSet<>(); // Set of hops for the root dummy node + // TODO: Just for debug, remove later + Set statRootHopSet = new HashSet<>(); // Set of hops that have no parent but are not referenced + + for (StatementBlock sb : prog.getStatementBlocks()) { + Optional.ofNullable(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false)) + .ifPresent(outerTransTable::putAll); + } + + FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable); + + // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types + double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); + + // Print the federated plan tree if requested + if (isPrint) { + FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost); + } } + /** - * Recursively enumerates federated execution plans for a given statement block. - * This method processes each type of statement block (If, For, While, Function, and generic) - * to determine the optimal federated plan. - * + * Enumerates the statement block and updates the transient and memoization tables. + * This method processes different types of statement blocks such as If, For, While, and Function blocks. + * It recursively enumerates the Hop DAGs within these blocks and updates the corresponding tables. + * The method also calculates weights recursively for if-else/loops and handles inner and outer block distinctions. + * * @param sb The statement block to enumerate. 
+ * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. + * @return A map of inner transient writes. */ - public static void enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map transTable) { - // While enumerating the program, recursively determine the optimal FedPlan and MemoTable - // for each statement block and statement. - // 1. How to recursively integrate optimal FedPlans and MemoTables across statements and statement blocks? - // 1) Is it determined using the same dynamic programming approach, or simply by summing the minimal plans? - // 2. Is there a need to share the MemoTable? Are there data/hop dependencies between statements? - // 3. How to predict the number of iterations for For and While loops? - // 1) If from/to/increment are constants: Calculations can be done at compile time. - // 2) If they are variables: Use default values at compile time, adjust at runtime, or predict using ML models. 
+ public static Map> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { + Map> innerTransTable = new HashMap<>(); if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); - enumerateHopDAG(isb.getPredicateHops(), memoTable, transTable); + enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + // Treat outerTransTable as immutable in inner blocks + // Write TWrite of sb sequentially in innerTransTable, and update formerInnerTransTable after the sb ends + // In case of if-else, create separate formerInnerTransTables for if and else, merge them after completion, and update formerInnerTransTable + Map> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + Map> elseFormerInnerTransTable = new HashMap<>(formerInnerTransTable); - for (StatementBlock csb : istmt.getIfBody()) - enumerateStatementBlock(csb, memoTable, transTable); - for (StatementBlock csb : istmt.getElseBody()) - enumerateStatementBlock(csb, memoTable, transTable); + for (StatementBlock csb : istmt.getIfBody()){ + ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + for (StatementBlock csb : istmt.getElseBody()){ + elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, elseFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } - // Todo: 1. apply iteration weight to csbFedPlans (if: 0.5, else: 0.5) - // Todo: 2. 
Merge predFedPlans + // If there are common keys: merge elseValue list into ifValue list + elseFormerInnerTransTable.forEach((key, elseValue) -> { + ifFormerInnerTransTable.merge(key, elseValue, (ifValue, newValue) -> { + ifValue.addAll(newValue); + return ifValue; + }); + }); + // Update innerTransTable + innerTransTable.putAll(ifFormerInnerTransTable); } else if (sb instanceof ForStatementBlock) { //incl parfor ForStatementBlock fsb = (ForStatementBlock) sb; - ForStatement fstmt = (ForStatement)fsb.getStatement(0); - enumerateHopDAG(fsb.getFromHops(), memoTable, transTable); - enumerateHopDAG(fsb.getToHops(), memoTable, transTable); - enumerateHopDAG(fsb.getIncrementHops(), memoTable, transTable); + // Calculate for-loop iteration count if possible + double loopWeight = DEFAULT_LOOP_WEIGHT; + Hop from = fsb.getFromHops().getInput().get(0); + Hop to = fsb.getToHops().getInput().get(0); + Hop incr = (fsb.getIncrementHops() != null) ? + fsb.getIncrementHops().getInput().get(0) : new LiteralOp(1); + + // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values) + if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) { + double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from); + double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to); + double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr); + if( dfrom > dto && dincr == 1 ) + dincr = -1; + loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false); + } + weight *= loopWeight; - for (StatementBlock csb : fstmt.getBody()) - enumerateStatementBlock(csb, memoTable, transTable); + enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + 
enumerateHopDAG(fsb.getIncrementHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - // Todo: 1. get(predict) # of Iterations - // Todo: 2. apply iteration weight to csbFedPlans - // Todo: 3. Merge csbFedPlans and predFedPlans + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); } else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - enumerateHopDAG(wsb.getPredicateHops(), memoTable, transTable); - - ArrayList csbFedPlans = new ArrayList<>(); - for (StatementBlock csb : wstmt.getBody()) - enumerateStatementBlock(csb, memoTable, transTable); + weight *= DEFAULT_LOOP_WEIGHT; - // Todo: 1. get(predict) # of Iterations - // Todo: 2. apply iteration weight to csbFedPlans - // Todo: 3. Merge csbFedPlans and predFedPlans - } else if (sb instanceof FunctionStatementBlock) { + enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } else if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - for (StatementBlock csb : fstmt.getBody()) - enumerateStatementBlock(csb, memoTable, transTable); - // Todo: 1. 
Merge csbFedPlans + // TODO: NOT descent multiple types (use hash set for functions using function name) + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); } else { //generic (last-level) - if( sb.getHops() != null ) - for( Hop c : sb.getHops() ) - enumerateHopDAG(c, memoTable, transTable); + if( sb.getHops() != null ){ + for(Hop c : sb.getHops()) + // In the statement block, if isInner, write hopDAG in innerTransTable, if not, write directly in outerTransTable + enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + } } + return innerTransTable; } - + /** - * Entry point for federated plan enumeration. This method creates a memo table - * and returns the minimum cost plan for the entire Directed Acyclic Graph (DAG). - * It also resolves conflicts where FedPlans have different FederatedOutput types. - * - * @param rootHop The root Hop node from which to start the plan enumeration. - * @return The optimal FedPlan with the minimum cost for the entire DAG. + * Enumerates the statement blocks within a body and updates the transient and memoization tables. + * + * @param sbList The list of statement blocks to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. 
*/ - public static FedPlan enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map transTable) { - // Create new memo table to store all plan variants + public static void enumerateStatementBlockBody(List sbList, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight) { + // The statement blocks within the body reference outerTransTable and formerInnerTransTable as immutable read-only, + // and record TWrite in the innerTransTable of the statement block within the body. + // Update the formerInnerTransTable with the contents of the returned innerTransTable. + for (StatementBlock sb : sbList) + formerInnerTransTable.putAll(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true)); + + // Then update and return the innerTransTable of the statement block containing the body. + innerTransTable.putAll(formerInnerTransTable); + } + /** + * Enumerates the statement hop DAG within a statement block. + * This method recursively enumerates all possible federated execution plans + * and identifies hops to connect to the root dummy node. + * + * @param rootHop The root Hop of the DAG to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of root hops for debugging purposes. + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. 
+ */ + public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { // Recursively enumerate all possible plans - enumerateHop(rootHop, memoTable, transTable); - - // Return the minimum cost plan for the root node - FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); - - // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); - - // Print the federated plan tree if requested - FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); - - return optimalPlan; + rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInnerBlock); + + // Identify hops to connect to the root dummy node + + if ((rootHop instanceof DataOp && (rootHop.getName().equals("__pred"))) // TWrite "__pred" + || (rootHop instanceof UnaryOp && ((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print) + // Connect TWrite pred and u(print) to the root dummy node + // TODO: Should the last unreferenced TWrite be connected? + progRootHopSet.add(rootHop); + } else { + // TODO: Just for debug, remove later + // For identifying TWrites that are not referenced later + statRootHopSet.add(rootHop); + } } /** - * Recursively enumerates all possible federated execution plans for a Hop DAG. - * For each node: - * 1. First processes all input nodes recursively if not already processed - * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs - * 3. 
Creates and evaluates both FOUT and LOUT variants for current node with each input combination - * - * The enumeration uses a bottom-up approach where: - * - Each input combination is represented by a binary number (i) - * - Bit j in i determines whether input j is FOUT (1) or LOUT (0) - * - Total number of combinations is 2^numInputs - * - * @param hop ? - * @param memoTable ? + * Rewires and enumerates federated execution plans for a given Hop. + * This method processes all input nodes, rewires TWrite and TRead operations, + * and generates federated plan variants for both inner and outer code blocks. + * + * @param hop The Hop for which to rewire and enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param weight The weight associated with the current Hop. + * @param isInner A boolean indicating if the current block is an inner block. 
*/ - private static void enumerateHop(Hop hop, FederatedMemoTable memoTable, Map transTable) { - int numInputs = hop.getInput().size(); - + private static void rewireAndEnumerateFedPlan(Hop hop, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, double weight, boolean isInner) { // Process all input nodes first if not already in memo table for (Hop inputHop : hop.getInput()) { - if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) - && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { - enumerateHop(inputHop, memoTable, transTable); + long inputHopID = inputHop.getHopID(); + if (!memoTable.contains(inputHopID, FederatedOutput.FOUT) + && !memoTable.contains(inputHopID, FederatedOutput.LOUT)) { + rewireAndEnumerateFedPlan(inputHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner); } } - if (hop instanceof DataOp - && ((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTWRITE - && !(hop.getName().equals("__pred"))){ - transTable.put(hop.getName(), hop.getHopID()); - } + // Detect and Rewire TWrite and TRead operations + List childHops = hop.getInput(); + if (hop instanceof DataOp && !(hop.getName().equals("__pred"))){ + String hopName = hop.getName(); + + if (isInner){ // If it's an inner code block + if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ + innerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); + } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ + // Copy existing and add TWrite + childHops = new ArrayList<>(childHops); + List additionalChildHops = null; + + // Read according to priority + if (innerTransTable.containsKey(hopName)){ + additionalChildHops = innerTransTable.get(hopName); + } else if (formerInnerTransTable.containsKey(hopName)){ + additionalChildHops = formerInnerTransTable.get(hopName); + } else if (outerTransTable.containsKey(hopName)){ + additionalChildHops = 
outerTransTable.get(hopName); + } - if (hop instanceof DataOp - && !(hop.getName().equals("__pred"))){ - if (((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTWRITE){ - transTable.put(hop.getName(), hop.getHopID()); - } else if (((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTREAD){ - long rWriteHopID = transTable.get(hop.getName()); + if (additionalChildHops != null) { + childHops.addAll(additionalChildHops); + } + } + } else { // If it's an outer code block + if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ + // Add directly to outerTransTable + outerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); + } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ + childHops = new ArrayList<>(childHops); + + // TODO: In the case of for (i in 1:10), there is no hop that writes TWrite for i. + // Read directly from outerTransTable and add + List additionalChildHops = outerTransTable.get(hopName); + if (additionalChildHops != null) { + childHops.addAll(additionalChildHops); + } + } } } - // Generate all possible input combinations using binary representation - // i represents a specific combination of FOUT/LOUT for inputs - for (int i = 0; i < (1 << numInputs); i++) { - List> planChilds = new ArrayList<>(); - - // For each input, determine if it should be FOUT or LOUT based on bit j in i - for (int j = 0; j < numInputs; j++) { - Hop inputHop = hop.getInput().get(j); - // If bit j is set (1), use FOUT; otherwise use LOUT - FederatedOutput childType = ((i & (1 << j)) != 0) ? 
- FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - } - - // Create and evaluate FOUT variant for current input combination - FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); + // Enumerate the federated plan for the current Hop + enumerateFedPlan(hop, memoTable, childHops, weight); + } - // Create and evaluate LOUT variant for current input combination - FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); + /** + * Enumerates federated execution plans for a given Hop. + * This method calculates the self cost and child costs for the Hop, + * generates federated plan variants for both LOUT and FOUT output types, + * and prunes redundant plans before adding them to the memo table. + * + * @param hop The Hop for which to enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param childHops The list of child hops. + * @param weight The weight associated with the current Hop. 
+ */ + private static void enumerateFedPlan(Hop hop, FederatedMemoTable memoTable, List childHops, double weight){ + long hopID = hop.getHopID(); + HopCommon hopCommon = new HopCommon(hop, weight); + double selfCost = FederatedPlanCostEstimator.computeHopCost(hopCommon); + + FedPlanVariants lOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); + FedPlanVariants fOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); + + int numInputs = childHops.size(); + int numInitInputs = hop.getInput().size(); + + double[][] childCumulativeCost = new double[numInputs][2]; // # of child, LOUT/FOUT of child + double[] childForwardingCost = new double[numInputs]; // # of child + + // The self cost follows its own weight, while the forwarding cost follows the parent's weight. + FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childHops, childCumulativeCost, childForwardingCost); + + if (numInitInputs == numInputs){ + enumerateOnlyInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } else { + enumerateTReadInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); } - // Prune MemoTable for hop. - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT); - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT); + // Prune the FedPlans to remove redundant plans + lOutFedPlanVariants.pruneFedPlans(); + fOutFedPlanVariants.pruneFedPlans(); + + // Add the FedPlanVariants to the memo table + memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, lOutFedPlanVariants); + memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, fOutFedPlanVariants); } /** - * Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants. - * Used to select the final execution plan after enumeration. - * - * @param HopID ? 
- * @param memoTable ? - * @return ? + * Enumerates federated execution plans for initial child hops only. + * This method generates all possible combinations of federated output types (LOUT and FOUT) + * for the initial child hops and calculates their cumulative costs. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. + */ + private static void enumerateOnlyInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs, List childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + // Iterate 2^n times, generating two FedPlans (LOUT, FOUT) each time. + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List> planChilds = new ArrayList<>(); + // LOUT and FOUT share the same planChilds in each iteration (only forwarding cost differs). + enumerateInitChildFedPlan(numInitInputs, childHops, planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds)); + } + } + + /** + * Enumerates federated execution plans for a TRead hop. + * This method calculates the cumulative costs for both LOUT and FOUT federated output types + * by considering the additional child hops, which are TWrite hops. 
+ * It generates all possible combinations of federated output types for the initial child hops + * and adds the pre-calculated costs of the TWrite child hops to these combinations. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param numInputs The total number of input hops, including additional TWrite hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. */ - private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { - FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); - FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); - - FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - - if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() - < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { - return minFOutFedPlan; + private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, + int numInitInputs, int numInputs, List childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + double lOutTReadCumulativeCost = selfCost; + double fOutTReadCumulativeCost = selfCost; + + List> lOutTReadPlanChilds = new ArrayList<>(); + List> fOutTReadPlanChilds = new ArrayList<>(); + + // Pre-calculate the cost for the additional child hop, which is a 
TWrite hop, of the TRead hop. + // Constraint: TWrite must have the same FedOutType as TRead. + for (int j = numInitInputs; j < numInputs; j++) { + Hop inputHop = childHops.get(j); + lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT)); + fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT)); + + lOutTReadCumulativeCost += childCumulativeCost[j][0]; + fOutTReadCumulativeCost += childCumulativeCost[j][1]; + // Skip TWrite -> TRead as they have the same FedOutType. } - return minlOutFedPlan; + + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List> lOutPlanChilds = new ArrayList<>(); + enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + // Copy lOutPlanChilds to create fOutPlanChilds and add the pre-calculated cost of the TWrite child hop. + List> fOutPlanChilds = new ArrayList<>(lOutPlanChilds); + + lOutPlanChilds.addAll(lOutTReadPlanChilds); + fOutPlanChilds.addAll(fOutTReadPlanChilds); + + cumulativeCost[0] += lOutTReadCumulativeCost; + cumulativeCost[1] += fOutTReadCumulativeCost; + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds)); + } + } + + // Calculates costs for initial child hops, determining FOUT or LOUT based on `i`. + private static void enumerateInitChildFedPlan(int numInitInputs, List childHops, List> planChilds, + double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){ + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInitInputs; j++) { + Hop inputHop = childHops.get(j); + // Calculate the bit value to decide between FOUT and LOUT for the current input + final int bit = (i & (1 << j)) != 0 ? 
1 : 0; // Determine the bit value (decides FOUT/LOUT) + final FederatedOutput childType = (bit == 1) ? FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + + // Update the cumulative cost for LOUT, FOUT + cumulativeCost[0] += childCumulativeCost[j][bit] + childForwardingCost[j] * bit; + cumulativeCost[1] += childCumulativeCost[j][bit] + childForwardingCost[j] * (1 - bit); + } + } + + // Creates a dummy root node (fedplan) and selects the FedPlan with the minimum cost to return. + // The dummy root node does not have LOUT or FOUT. + private static FedPlan getMinCostRootFedPlan(Set progRootHopSet, FederatedMemoTable memoTable) { + double cumulativeCost = 0; + List> rootFedPlanChilds = new ArrayList<>(); + + // Iterate over each Hop in the progRootHopSet + for (Hop endHop : progRootHopSet){ + // Retrieve the pruned FedPlan for LOUT and FOUT from the memo table + FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.LOUT); + FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.FOUT); + + // Compare the cumulative costs of LOUT and FOUT FedPlans + if (lOutFedPlan.getCumulativeCost() <= fOutFedPlan.getCumulativeCost()){ + cumulativeCost += lOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.LOUT)); + } else{ + cumulativeCost += fOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.FOUT)); + } + } + + return new FedPlan(cumulativeCost, null, rootFedPlanChilds); } /** diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index f48332ac752..1f2c2802f46 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -19,9 +19,12 @@ 
package org.apache.sysds.hops.fedplanner; import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.DataOp; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.cost.ComputeCost; import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.LinkedHashMap; @@ -42,43 +45,102 @@ public class FederatedPlanCostEstimator { // Network bandwidth for data transfers between federated sites (1 Gbps) private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + // Retrieves the cumulative and forwarding costs of the child hops and stores them in arrays + public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List inputHops, + double[][] childCumulativeCost, double[] childForwardingCost) { + for (int i = 0; i < inputHops.size(); i++) { + long childHopID = inputHops.get(i).getHopID(); + + FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); + FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); + + // The cumulative cost of the child already includes the weight + childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); + childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); + + // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements? + childForwardingCost[i] = hopCommon.weight * childLOutFedPlan.getForwardingCost(); + } + } + /** - * Computes total cost of federated plan by: - * 1. Computing current node cost (if not cached) - * 2. Adding minimum-cost child plans - * 3. Including network transfer costs when needed + * Computes the cost associated with a given Hop node. 
+ * This method calculates both the self cost and the forwarding cost for the Hop, + * taking into account its type and the number of parent nodes. * - * @param currentPlan Plan to compute cost for - * @param memoTable Table containing all plan variants + * @param hopCommon The HopCommon object containing the Hop and its properties. + * @return The self cost of the Hop. */ - public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { - double totalCost; - Hop currentHop = currentPlan.getHopRef(); - - // Step 1: Calculate current node costs if not already computed - if (currentPlan.getSelfCost() == 0) { - // Compute cost for current node (computation + memory access) - totalCost = computeCurrentCost(currentHop); - currentPlan.setSelfCost(totalCost); - // Calculate potential network transfer cost if federation type changes - currentPlan.setForwardingCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); - } else { - totalCost = currentPlan.getSelfCost(); + public static double computeHopCost(HopCommon hopCommon){ + // TWrite and TRead are meta-data operations, hence selfCost is zero + if (hopCommon.hopRef instanceof DataOp){ + if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTWRITE ){ + hopCommon.setSelfCost(0); + // Since TWrite and TRead have the same FedOutType, forwarding cost is zero + hopCommon.setForwardingCost(0); + return 0; + } else if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTREAD) { + hopCommon.setSelfCost(0); + // TRead may have a different FedOutType from its parent, so calculate forwarding cost + // TODO: Uncertain about the number of TWrites + hopCommon.setForwardingCost(computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate())); + return 0; + } } - - // Step 2: Process each child plan and add their costs - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - // Find minimum cost child plan considering federation type compatibility - // Note: This 
approach might lead to suboptimal or wrong solutions when a child has multiple parents - // because we're selecting child plans independently for each parent - FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair); - - // Add child plan cost (includes network transfer cost if federation types differ) - totalCost += planRef.getTotalCost() + planRef.getCondForwardingCost(currentPlan.getFedOutType()); + + // In loops, selfCost is repeated, but forwarding may not be + // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?) + double selfCost = hopCommon.weight * computeSelfCost(hopCommon.hopRef); + double forwardingCost = computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate()); + + int numParents = hopCommon.hopRef.getParent().size(); + if (numParents >= 2) { + selfCost /= numParents; + forwardingCost /= numParents; } + + hopCommon.setSelfCost(selfCost); + hopCommon.setForwardingCost(forwardingCost); + + return selfCost; + } + + /** + * Computes the cost for the current Hop node. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeSelfCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); - // Step 3: Set final cumulative cost including current node - currentPlan.setTotalCost(totalCost); + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. 
+ * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. + * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopForwardingCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; } /** @@ -138,7 +200,7 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
- fOutAdditionalCost += confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost(); + fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred @@ -148,30 +210,30 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later isLOutForwarding = true; - lOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); } } else { - lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost(); + lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { isLOutForwarding = true; } else { isFOutForwarding = true; - lOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); - fOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost(); + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); } } } // Add network transfer costs if applicable if (isLOutForwarding) { - lOutAdditionalCost += confilctLOutFedPlan.setForwardingCost(); + lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); } if (isFOutForwarding) { - fOutAdditionalCost += confilctFOutFedPlan.setForwardingCost(); + fOutAdditionalCost += 
confilctFOutFedPlan.getForwardingCost(); } // Determine the optimal federated output type based on the calculated costs @@ -199,42 +261,4 @@ public static LinkedHashMap resolveConflictFedPlan(FederatedMe } return resolvedFedPlanLinkedMap; } - - /** - * Computes the cost for the current Hop node. - * - * @param currentHop The Hop node whose cost needs to be computed - * @return The total cost for the current node's operation - */ - private static double computeCurrentCost(Hop currentHop){ - double computeCost = ComputeCost.getHOPComputeCost(currentHop); - double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); - double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); - - // Compute total cost assuming: - // 1. Computation and input access can be overlapped (hence taking max) - // 2. Output access must wait for both to complete (hence adding) - return Math.max(computeCost, inputAccessCost) + ouputAccessCost; - } - - /** - * Calculates the memory access cost based on data size and memory bandwidth. - * - * @param memSize Size of data to be accessed (in bytes) - * @return Time cost for memory access (in seconds) - */ - private static double computeHopMemoryAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; - } - - /** - * Calculates the network transfer cost based on data size and network bandwidth. - * Used when federation status changes between parent and child plans. 
- * - * @param memSize Size of data to be transferred (in bytes) - * @return Time cost for network transfer (in seconds) - */ - private static double computeHopNetworkAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; - } } diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 8fd17998e96..3edfbc581ad 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -20,15 +20,7 @@ package org.apache.sysds.test.component.federated; import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; - -import org.apache.sysds.common.Types; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.runtime.meta.MatrixCharacteristics; -import org.apache.sysds.test.TestUtils; -import org.apache.sysds.test.functions.federated.algorithms.FederatedL2SVMTest; import org.junit.Assert; import org.junit.Test; import org.apache.sysds.api.DMLScript; @@ -42,7 +34,6 @@ import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; - public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase { private static final String TEST_DIR = "functions/federated/privacy/"; @@ -73,6 +64,12 @@ public void setUp() {} @Test public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } + @Test + public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest8.dml"); } + + @Test + public void testFederatedPlanCostEnumerator9() { runTest("FederatedPlanCostEnumeratorTest9.dml"); } + // Todo: Need to write test scripts for the federated version private void runTest( String scriptFilename ) { int 
index = scriptFilename.lastIndexOf(".dml"); @@ -97,8 +94,9 @@ private void runTest( String scriptFilename ) { dmlt.constructHops(prog); dmlt.rewriteHopsDAG(prog); dmlt.constructLops(prog); + dmlt.rewriteLopDAG(prog); - FederatedPlanCostEnumerator.enumerateProgram(prog); + FederatedPlanCostEnumerator.enumerateProgram(prog, true); } catch (IOException e) { e.printStackTrace(); diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml index 19b65223305..2721bbcbaf6 100644 --- a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml @@ -19,7 +19,7 @@ # #------------------------------------------------------------- -for( i in 1:10 ) +for( i in 1:100 ) { b = i + 1; print(b); diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml index 4a0ca5eaa72..b95ae1b5bb0 100644 --- a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml @@ -19,13 +19,16 @@ # #------------------------------------------------------------- -A = matrix(7,10,10); -b = rand(rows = 1, cols = ncol(A), min = 1, max = 2) -d = sum(b) * 8 +A = matrix(7, rows=10, cols=10) +b = rand(rows = 1, cols = ncol(A), min = 1, max = 2); i = 0 -while(sum(b) < d){ - i = i + 1 - b = b + i - s = b %*% A - print(mean(s)) + +while (sum(b) < i) { + i = i + 1 + b = b + i + A = A * A + s = b %*% A + print(mean(s)) } +c = sqrt(A) +print(sum(c)) \ No newline at end of file From e48d7af92bfe6a4a7b79c9f1317cf2cd5c8cbcb6 Mon Sep 17 00:00:00 2001 From: min-guk Date: Wed, 26 Feb 2025 07:47:46 +0900 Subject: [PATCH 9/9] program level fed planer --- 
.../FederatedPlanCostEnumerator.java | 1076 +++++++++-------- .../FederatedPlanCostEstimator.java | 490 ++++---- .../FederatedPlanCostEnumeratorTest.java | 172 +-- .../federated/FederatedPlanVisualizer.py | 247 ++++ .../FederatedPlanCostEnumeratorTest10.dml | 33 + 5 files changed, 1163 insertions(+), 855 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py create mode 100644 src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index f32bc4a76b9..f3e8cc286db 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -17,557 +17,581 @@ * under the License. */ -package org.apache.sysds.hops.fedplanner; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Optional; -import java.util.Set; -import java.util.HashSet; - -import org.apache.commons.lang3.tuple.Pair; - -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.common.Types; -import org.apache.sysds.hops.DataOp; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.LiteralOp; -import org.apache.sysds.hops.UnaryOp; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; -import org.apache.sysds.hops.rewrite.HopRewriteUtils; -import org.apache.sysds.parser.DMLProgram; -import org.apache.sysds.parser.ForStatement; -import org.apache.sysds.parser.ForStatementBlock; -import org.apache.sysds.parser.FunctionStatement; -import 
org.apache.sysds.parser.FunctionStatementBlock; -import org.apache.sysds.parser.IfStatement; -import org.apache.sysds.parser.IfStatementBlock; -import org.apache.sysds.parser.StatementBlock; -import org.apache.sysds.parser.WhileStatement; -import org.apache.sysds.parser.WhileStatementBlock; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; -import org.apache.sysds.runtime.util.UtilFunctions; - -public class FederatedPlanCostEnumerator { - private static final double DEFAULT_LOOP_WEIGHT = 10.0; - private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5; - - /** - * Enumerates the entire DML program to generate federated execution plans. - * It processes each statement block, computes the optimal federated plan, - * detects and resolves conflicts, and optionally prints the plan tree. - * - * @param prog The DML program to enumerate. - * @param isPrint A boolean indicating whether to print the federated plan tree. - */ - public static void enumerateProgram(DMLProgram prog, boolean isPrint) { - FederatedMemoTable memoTable = new FederatedMemoTable(); - - Map> outerTransTable = new HashMap<>(); - Map> formerInnerTransTable = new HashMap<>(); - Set progRootHopSet = new HashSet<>(); // Set of hops for the root dummy node - // TODO: Just for debug, remove later - Set statRootHopSet = new HashSet<>(); // Set of hops that have no parent but are not referenced - - for (StatementBlock sb : prog.getStatementBlocks()) { - Optional.ofNullable(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false)) - .ifPresent(outerTransTable::putAll); - } - - FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable); - - // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); - - // Print the federated plan tree if requested - if (isPrint) { - 
FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost); - } - } - - - /** - * Enumerates the statement block and updates the transient and memoization tables. - * This method processes different types of statement blocks such as If, For, While, and Function blocks. - * It recursively enumerates the Hop DAGs within these blocks and updates the corresponding tables. - * The method also calculates weights recursively for if-else/loops and handles inner and outer block distinctions. - * - * @param sb The statement block to enumerate. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track immutable outer transient writes. - * @param formerInnerTransTable The table to track immutable former inner transient writes. - * @param progRootHopSet The set of hops to connect to the root dummy node. - * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). - * @param weight The weight associated with the current Hop. - * @param isInnerBlock A boolean indicating if the current block is an inner block. - * @return A map of inner transient writes. 
- */ - public static Map> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { - Map> innerTransTable = new HashMap<>(); - - if (sb instanceof IfStatementBlock) { - IfStatementBlock isb = (IfStatementBlock) sb; - IfStatement istmt = (IfStatement)isb.getStatement(0); - - enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - - // Treat outerTransTable as immutable in inner blocks - // Write TWrite of sb sequentially in innerTransTable, and update formerInnerTransTable after the sb ends - // In case of if-else, create separate formerInnerTransTables for if and else, merge them after completion, and update formerInnerTransTable - Map> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable); - Map> elseFormerInnerTransTable = new HashMap<>(formerInnerTransTable); - - for (StatementBlock csb : istmt.getIfBody()){ - ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); - } - - for (StatementBlock csb : istmt.getElseBody()){ - elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, elseFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); - } - - // If there are common keys: merge elseValue list into ifValue list - elseFormerInnerTransTable.forEach((key, elseValue) -> { - ifFormerInnerTransTable.merge(key, elseValue, (ifValue, newValue) -> { - ifValue.addAll(newValue); - return ifValue; - }); - }); - // Update innerTransTable - innerTransTable.putAll(ifFormerInnerTransTable); - } else if (sb instanceof ForStatementBlock) { //incl parfor - ForStatementBlock fsb = (ForStatementBlock) sb; - ForStatement 
fstmt = (ForStatement)fsb.getStatement(0); - - // Calculate for-loop iteration count if possible - double loopWeight = DEFAULT_LOOP_WEIGHT; - Hop from = fsb.getFromHops().getInput().get(0); - Hop to = fsb.getToHops().getInput().get(0); - Hop incr = (fsb.getIncrementHops() != null) ? - fsb.getIncrementHops().getInput().get(0) : new LiteralOp(1); - - // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values) - if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) { - double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from); - double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to); - double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr); - if( dfrom > dto && dincr == 1 ) - dincr = -1; - loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false); - } - weight *= loopWeight; - - enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - enumerateHopDAG(fsb.getIncrementHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - - enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); - } else if (sb instanceof WhileStatementBlock) { - WhileStatementBlock wsb = (WhileStatementBlock) sb; - WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - weight *= DEFAULT_LOOP_WEIGHT; - - enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, 
innerTransTable, progRootHopSet, statRootHopSet, weight); - } else if (sb instanceof FunctionStatementBlock) { - FunctionStatementBlock fsb = (FunctionStatementBlock)sb; - FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - - // TODO: NOT descent multiple types (use hash set for functions using function name) - enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); - } else { //generic (last-level) - if( sb.getHops() != null ){ - for(Hop c : sb.getHops()) - // In the statement block, if isInner, write hopDAG in innerTransTable, if not, write directly in outerTransTable - enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); - } - } - return innerTransTable; - } - - /** - * Enumerates the statement blocks within a body and updates the transient and memoization tables. - * - * @param sbList The list of statement blocks to enumerate. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track immutable outer transient writes. - * @param formerInnerTransTable The table to track immutable former inner transient writes. - * @param innerTransTable The table to track inner transient writes. - * @param progRootHopSet The set of hops to connect to the root dummy node. - * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). - * @param weight The weight associated with the current Hop. 
- */ - public static void enumerateStatementBlockBody(List sbList, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight) { - // The statement blocks within the body reference outerTransTable and formerInnerTransTable as immutable read-only, - // and record TWrite in the innerTransTable of the statement block within the body. - // Update the formerInnerTransTable with the contents of the returned innerTransTable. - for (StatementBlock sb : sbList) - formerInnerTransTable.putAll(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true)); - - // Then update and return the innerTransTable of the statement block containing the body. - innerTransTable.putAll(formerInnerTransTable); - } - - /** - * Enumerates the statement hop DAG within a statement block. - * This method recursively enumerates all possible federated execution plans - * and identifies hops to connect to the root dummy node. - * - * @param rootHop The root Hop of the DAG to enumerate. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track transient writes. - * @param formerInnerTransTable The table to track immutable inner transient writes. - * @param innerTransTable The table to track inner transient writes. - * @param progRootHopSet The set of hops to connect to the root dummy node. - * @param statRootHopSet The set of root hops for debugging purposes. - * @param weight The weight associated with the current Hop. - * @param isInnerBlock A boolean indicating if the current block is an inner block. 
- */ - public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { - // Recursively enumerate all possible plans - rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInnerBlock); - - // Identify hops to connect to the root dummy node - - if ((rootHop instanceof DataOp && (rootHop.getName().equals("__pred"))) // TWrite "__pred" - || (rootHop instanceof UnaryOp && ((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print) - // Connect TWrite pred and u(print) to the root dummy node - // TODO: Should the last unreferenced TWrite be connected? - progRootHopSet.add(rootHop); - } else { - // TODO: Just for debug, remove later - // For identifying TWrites that are not referenced later - statRootHopSet.add(rootHop); - } - } - - /** - * Rewires and enumerates federated execution plans for a given Hop. - * This method processes all input nodes, rewires TWrite and TRead operations, - * and generates federated plan variants for both inner and outer code blocks. - * - * @param hop The Hop for which to rewire and enumerate federated plans. - * @param memoTable The memoization table to store plan variants. - * @param outerTransTable The table to track transient writes. - * @param formerInnerTransTable The table to track immutable inner transient writes. - * @param innerTransTable The table to track inner transient writes. - * @param weight The weight associated with the current Hop. - * @param isInner A boolean indicating if the current block is an inner block. 
- */ - private static void rewireAndEnumerateFedPlan(Hop hop, FederatedMemoTable memoTable, Map> outerTransTable, - Map> formerInnerTransTable, Map> innerTransTable, double weight, boolean isInner) { + package org.apache.sysds.hops.fedplanner; + import java.util.ArrayList; + import java.util.List; + import java.util.Map; + import java.util.HashMap; + import java.util.LinkedHashMap; + import java.util.Optional; + import java.util.Set; + import java.util.HashSet; + + import org.apache.commons.lang3.tuple.Pair; + + import org.apache.commons.lang3.tuple.ImmutablePair; + import org.apache.sysds.common.Types; + import org.apache.sysds.hops.DataOp; + import org.apache.sysds.hops.Hop; + import org.apache.sysds.hops.LiteralOp; + import org.apache.sysds.hops.UnaryOp; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; + import org.apache.sysds.hops.rewrite.HopRewriteUtils; + import org.apache.sysds.parser.DMLProgram; + import org.apache.sysds.parser.ForStatement; + import org.apache.sysds.parser.ForStatementBlock; + import org.apache.sysds.parser.FunctionStatement; + import org.apache.sysds.parser.FunctionStatementBlock; + import org.apache.sysds.parser.IfStatement; + import org.apache.sysds.parser.IfStatementBlock; + import org.apache.sysds.parser.StatementBlock; + import org.apache.sysds.parser.WhileStatement; + import org.apache.sysds.parser.WhileStatementBlock; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + import org.apache.sysds.runtime.util.UtilFunctions; + + public class FederatedPlanCostEnumerator { + private static final double DEFAULT_LOOP_WEIGHT = 10.0; + private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5; + + /** + * Enumerates the entire DML program to generate federated execution plans. 
+ * It processes each statement block, computes the optimal federated plan, + * detects and resolves conflicts, and optionally prints the plan tree. + * + * @param prog The DML program to enumerate. + * @param isPrint A boolean indicating whether to print the federated plan tree. + */ + public static void enumerateProgram(DMLProgram prog, boolean isPrint) { + FederatedMemoTable memoTable = new FederatedMemoTable(); + + Map> outerTransTable = new HashMap<>(); + Map> formerInnerTransTable = new HashMap<>(); + Set progRootHopSet = new HashSet<>(); // Set of hops for the root dummy node + // TODO: Just for debug, remove later + Set statRootHopSet = new HashSet<>(); // Set of hops that have no parent but are not referenced + + for (StatementBlock sb : prog.getStatementBlocks()) { + Optional.ofNullable(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false)) + .ifPresent(outerTransTable::putAll); + } + + FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable); + + // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types + double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); + + // Print the federated plan tree if requested + if (isPrint) { + FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost); + } + } + + + /** + * Enumerates the statement block and updates the transient and memoization tables. + * This method processes different types of statement blocks such as If, For, While, and Function blocks. + * It recursively enumerates the Hop DAGs within these blocks and updates the corresponding tables. + * The method also calculates weights recursively for if-else/loops and handles inner and outer block distinctions. + * + * @param sb The statement block to enumerate. + * @param memoTable The memoization table to store plan variants. 
+ * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. + * @return A map of inner transient writes. + */ + public static Map> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { + Map> innerTransTable = new HashMap<>(); + + if (sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + + enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + // Treat outerTransTable as immutable in inner blocks + // Write TWrite of sb sequentially in innerTransTable, and update formerInnerTransTable after the sb ends + // In case of if-else, create separate formerInnerTransTables for if and else, merge them after completion, and update formerInnerTransTable + Map> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + Map> elseFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + + for (StatementBlock csb : istmt.getIfBody()){ + ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + for (StatementBlock csb : istmt.getElseBody()){ + elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, elseFormerInnerTransTable, 
progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + // If there are common keys: merge elseValue list into ifValue list + elseFormerInnerTransTable.forEach((key, elseValue) -> { + ifFormerInnerTransTable.merge(key, elseValue, (ifValue, newValue) -> { + ifValue.addAll(newValue); + return ifValue; + }); + }); + // Update innerTransTable + innerTransTable.putAll(ifFormerInnerTransTable); + } + else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + + // Calculate for-loop iteration count if possible + double loopWeight = DEFAULT_LOOP_WEIGHT; + Hop from = fsb.getFromHops().getInput().get(0); + Hop to = fsb.getToHops().getInput().get(0); + Hop incr = (fsb.getIncrementHops() != null) ? + fsb.getIncrementHops().getInput().get(0) : new LiteralOp(1); + + // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values) + if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) { + double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from); + double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to); + double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr); + if( dfrom > dto && dincr == 1 ) + dincr = -1; + loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false); + } + weight *= loopWeight; + + enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getIncrementHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, 
formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } + else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + weight *= DEFAULT_LOOP_WEIGHT; + + enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } + else if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + + // TODO: Do not descend for visited functions (use a hash set for functions using their names) + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } + else { //generic (last-level) + if( sb.getHops() != null ){ + for(Hop c : sb.getHops()) + // In the statement block, if isInner, write hopDAG in innerTransTable, if not, write directly in outerTransTable + enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + } + } + return innerTransTable; + } + + /** + * Enumerates the statement blocks within a body and updates the transient and memoization tables. + * + * @param sbList The list of statement blocks to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. 
+ * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. + */ + public static void enumerateStatementBlockBody(List sbList, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight) { + // The statement blocks within the body reference outerTransTable and formerInnerTransTable as immutable read-only, + // and record TWrite in the innerTransTable of the statement block within the body. + // Update the formerInnerTransTable with the contents of the returned innerTransTable. + for (StatementBlock sb : sbList) + formerInnerTransTable.putAll(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true)); + + // Then update and return the innerTransTable of the statement block containing the body. + innerTransTable.putAll(formerInnerTransTable); + } + + /** + * Enumerates the statement hop DAG within a statement block. + * This method recursively enumerates all possible federated execution plans + * and identifies hops to connect to the root dummy node. + * + * @param rootHop The root Hop of the DAG to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of root hops for debugging purposes. + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. 
+ */ + public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, Set progRootHopSet, Set statRootHopSet, double weight, boolean isInnerBlock) { + // Recursively enumerate all possible plans + rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInnerBlock); + + // Identify hops to connect to the root dummy node + + if ((rootHop instanceof DataOp && (rootHop.getName().equals("__pred"))) // TWrite "__pred" + || (rootHop instanceof UnaryOp && ((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print) + // Connect TWrite pred and u(print) to the root dummy node + // TODO: Should we check all statement-level root hops to see if they are not referenced? + progRootHopSet.add(rootHop); + } else { + // TODO: Just for debug, remove later + // For identifying TWrites that are not referenced later + statRootHopSet.add(rootHop); + } + } + + /** + * Rewires and enumerates federated execution plans for a given Hop. + * This method processes all input nodes, rewires TWrite and TRead operations, + * and generates federated plan variants for both inner and outer code blocks. + * + * @param hop The Hop for which to rewire and enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param weight The weight associated with the current Hop. + * @param isInner A boolean indicating if the current block is an inner block. 
+ */ + private static void rewireAndEnumerateFedPlan(Hop hop, FederatedMemoTable memoTable, Map> outerTransTable, + Map> formerInnerTransTable, Map> innerTransTable, + double weight, boolean isInner) { // Process all input nodes first if not already in memo table for (Hop inputHop : hop.getInput()) { long inputHopID = inputHop.getHopID(); if (!memoTable.contains(inputHopID, FederatedOutput.FOUT) - && !memoTable.contains(inputHopID, FederatedOutput.LOUT)) { - rewireAndEnumerateFedPlan(inputHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner); + && !memoTable.contains(inputHopID, FederatedOutput.LOUT)) { + rewireAndEnumerateFedPlan(inputHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner); } } - // Detect and Rewire TWrite and TRead operations - List childHops = hop.getInput(); - if (hop instanceof DataOp && !(hop.getName().equals("__pred"))){ - String hopName = hop.getName(); - - if (isInner){ // If it's an inner code block - if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ - innerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); - } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ - // Copy existing and add TWrite - childHops = new ArrayList<>(childHops); - List additionalChildHops = null; - - // Read according to priority - if (innerTransTable.containsKey(hopName)){ - additionalChildHops = innerTransTable.get(hopName); - } else if (formerInnerTransTable.containsKey(hopName)){ - additionalChildHops = formerInnerTransTable.get(hopName); - } else if (outerTransTable.containsKey(hopName)){ - additionalChildHops = outerTransTable.get(hopName); - } - - if (additionalChildHops != null) { - childHops.addAll(additionalChildHops); - } - } - } else { // If it's an outer code block - if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTWRITE){ - // Add directly to outerTransTable - outerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); 
- } else if (((DataOp)hop).getOp() == Types.OpOpData.TRANSIENTREAD){ - childHops = new ArrayList<>(childHops); - - // TODO: In the case of for (i in 1:10), there is no hop that writes TWrite for i. - // Read directly from outerTransTable and add - List additionalChildHops = outerTransTable.get(hopName); - if (additionalChildHops != null) { - childHops.addAll(additionalChildHops); - } - } - } - } + // Determine modified child hops based on DataOp type and transient operations + List childHops = rewireTransReadWrite(hop, outerTransTable, formerInnerTransTable, innerTransTable, isInner); // Enumerate the federated plan for the current Hop enumerateFedPlan(hop, memoTable, childHops, weight); } - /** - * Enumerates federated execution plans for a given Hop. - * This method calculates the self cost and child costs for the Hop, - * generates federated plan variants for both LOUT and FOUT output types, - * and prunes redundant plans before adding them to the memo table. - * - * @param hop The Hop for which to enumerate federated plans. - * @param memoTable The memoization table to store plan variants. - * @param childHops The list of child hops. - * @param weight The weight associated with the current Hop. 
- */ - private static void enumerateFedPlan(Hop hop, FederatedMemoTable memoTable, List childHops, double weight){ - long hopID = hop.getHopID(); - HopCommon hopCommon = new HopCommon(hop, weight); - double selfCost = FederatedPlanCostEstimator.computeHopCost(hopCommon); - - FedPlanVariants lOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); - FedPlanVariants fOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); - - int numInputs = childHops.size(); - int numInitInputs = hop.getInput().size(); - - double[][] childCumulativeCost = new double[numInputs][2]; // # of child, LOUT/FOUT of child - double[] childForwardingCost = new double[numInputs]; // # of child - - // The self cost follows its own weight, while the forwarding cost follows the parent's weight. - FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childHops, childCumulativeCost, childForwardingCost); - - if (numInitInputs == numInputs){ - enumerateOnlyInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); - } else { - enumerateTReadInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + private static List rewireTransReadWrite(Hop hop, Map> outerTransTable, + Map> formerInnerTransTable, + Map> innerTransTable, boolean isInner) { + List childHops = hop.getInput(); + + if (!(hop instanceof DataOp) || hop.getName().equals("__pred")) { + return childHops; // Early exit for non-DataOp or __pred } - // Prune the FedPlans to remove redundant plans - lOutFedPlanVariants.pruneFedPlans(); - fOutFedPlanVariants.pruneFedPlans(); + DataOp dataOp = (DataOp) hop; + Types.OpOpData opType = dataOp.getOp(); + String hopName = dataOp.getName(); - // Add the FedPlanVariants to the memo table - memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, lOutFedPlanVariants); - 
memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, fOutFedPlanVariants); - } - - /** - * Enumerates federated execution plans for initial child hops only. - * This method generates all possible combinations of federated output types (LOUT and FOUT) - * for the initial child hops and calculates their cumulative costs. - * - * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. - * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. - * @param numInitInputs The number of initial input hops. - * @param childHops The list of child hops. - * @param childCumulativeCost The cumulative costs for each child hop. - * @param childForwardingCost The forwarding costs for each child hop. - * @param selfCost The self cost of the current hop. - */ - private static void enumerateOnlyInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs, List childHops, - double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ - // Iterate 2^n times, generating two FedPlans (LOUT, FOUT) each time. - for (int i = 0; i < (1 << numInitInputs); i++) { - double[] cumulativeCost = new double[]{selfCost, selfCost}; - List> planChilds = new ArrayList<>(); - // LOUT and FOUT share the same planChilds in each iteration (only forwarding cost differs). - enumerateInitChildFedPlan(numInitInputs, childHops, planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); - - lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds)); - fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds)); + if (isInner && opType == Types.OpOpData.TRANSIENTWRITE) { + innerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); } - } - - /** - * Enumerates federated execution plans for a TRead hop. 
- * This method calculates the cumulative costs for both LOUT and FOUT federated output types - * by considering the additional child hops, which are TWrite hops. - * It generates all possible combinations of federated output types for the initial child hops - * and adds the pre-calculated costs of the TWrite child hops to these combinations. - * - * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. - * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. - * @param numInitInputs The number of initial input hops. - * @param numInputs The total number of input hops, including additional TWrite hops. - * @param childHops The list of child hops. - * @param childCumulativeCost The cumulative costs for each child hop. - * @param childForwardingCost The forwarding costs for each child hop. - * @param selfCost The self cost of the current hop. - */ - private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, - int numInitInputs, int numInputs, List childHops, - double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ - double lOutTReadCumulativeCost = selfCost; - double fOutTReadCumulativeCost = selfCost; - - List> lOutTReadPlanChilds = new ArrayList<>(); - List> fOutTReadPlanChilds = new ArrayList<>(); - - // Pre-calculate the cost for the additional child hop, which is a TWrite hop, of the TRead hop. - // Constraint: TWrite must have the same FedOutType as TRead. - for (int j = numInitInputs; j < numInputs; j++) { - Hop inputHop = childHops.get(j); - lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT)); - fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT)); - - lOutTReadCumulativeCost += childCumulativeCost[j][0]; - fOutTReadCumulativeCost += childCumulativeCost[j][1]; - // Skip TWrite -> TRead as they have the same FedOutType. 
+ else if (isInner && opType == Types.OpOpData.TRANSIENTREAD) { + childHops = rewireInnerTransRead(childHops, hopName, + innerTransTable, formerInnerTransTable, outerTransTable); } - - for (int i = 0; i < (1 << numInitInputs); i++) { - double[] cumulativeCost = new double[]{selfCost, selfCost}; - List> lOutPlanChilds = new ArrayList<>(); - enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); - - // Copy lOutPlanChilds to create fOutPlanChilds and add the pre-calculated cost of the TWrite child hop. - List> fOutPlanChilds = new ArrayList<>(lOutPlanChilds); - - lOutPlanChilds.addAll(lOutTReadPlanChilds); - fOutPlanChilds.addAll(fOutTReadPlanChilds); - - cumulativeCost[0] += lOutTReadCumulativeCost; - cumulativeCost[1] += fOutTReadCumulativeCost; - - lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds)); - fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds)); + else if (!isInner && opType == Types.OpOpData.TRANSIENTWRITE) { + outerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); } - } - - // Calculates costs for initial child hops, determining FOUT or LOUT based on `i`. - private static void enumerateInitChildFedPlan(int numInitInputs, List childHops, List> planChilds, - double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){ - // For each input, determine if it should be FOUT or LOUT based on bit j in i - for (int j = 0; j < numInitInputs; j++) { - Hop inputHop = childHops.get(j); - // Calculate the bit value to decide between FOUT and LOUT for the current input - final int bit = (i & (1 << j)) != 0 ? 1 : 0; // Determine the bit value (decides FOUT/LOUT) - final FederatedOutput childType = (bit == 1) ? 
FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - - // Update the cumulative cost for LOUT, FOUT - cumulativeCost[0] += childCumulativeCost[j][bit] + childForwardingCost[j] * bit; - cumulativeCost[1] += childCumulativeCost[j][bit] + childForwardingCost[j] * (1 - bit); + else if (!isInner && opType == Types.OpOpData.TRANSIENTREAD) { + childHops = rewireOuterTransRead(childHops, hopName, outerTransTable); } - } - // Creates a dummy root node (fedplan) and selects the FedPlan with the minimum cost to return. - // The dummy root node does not have LOUT or FOUT. - private static FedPlan getMinCostRootFedPlan(Set progRootHopSet, FederatedMemoTable memoTable) { - double cumulativeCost = 0; - List> rootFedPlanChilds = new ArrayList<>(); + return childHops; + } - // Iterate over each Hop in the progRootHopSet - for (Hop endHop : progRootHopSet){ - // Retrieve the pruned FedPlan for LOUT and FOUT from the memo table - FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.LOUT); - FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.FOUT); + private static List rewireInnerTransRead(List childHops, String hopName, Map> innerTransTable, + Map> formerInnerTransTable, Map> outerTransTable) { + List newChildHops = new ArrayList<>(childHops); - // Compare the cumulative costs of LOUT and FOUT FedPlans - if (lOutFedPlan.getCumulativeCost() <= fOutFedPlan.getCumulativeCost()){ - cumulativeCost += lOutFedPlan.getCumulativeCost(); - rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.LOUT)); - } else{ - cumulativeCost += fOutFedPlan.getCumulativeCost(); - rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.FOUT)); - } + // Read according to priority: inner -> formerInner -> outer + List additionalChildHops = innerTransTable.get(hopName); + if (additionalChildHops == null) { + additionalChildHops = formerInnerTransTable.get(hopName); + } 
+ if (additionalChildHops == null) { + additionalChildHops = outerTransTable.get(hopName); } - return new FedPlan(cumulativeCost, null, rootFedPlanChilds); + if (additionalChildHops != null) { + newChildHops.addAll(additionalChildHops); + } + return newChildHops; } - /** - * Detects and resolves conflicts in federated plans starting from the root plan. - * This function performs a breadth-first search (BFS) to traverse the federated plan tree. - * It identifies conflicts where the same plan ID has different federated output types. - * For each conflict, it records the plan ID and its conflicting parent plans. - * The function ensures that each plan ID is associated with a consistent federated output type - * by resolving these conflicts iteratively. - * - * The process involves: - * - Using a map to track conflicts, associating each plan ID with its federated output type - * and a list of parent plans. - * - Storing detected conflicts in a linked map, each entry containing a plan ID and its - * conflicting parent plans. - * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. - * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan - * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. - * - Resolving conflicts by ensuring a consistent federated output type across the plan. - * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. - * - * @param rootPlan The root federated plan from which to start the conflict detection. - * @param memoTable The memoization table used to retrieve pruned federated plans. - * @return The cumulative additional cost for resolving conflicts. 
- */ - private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { - // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans - Map>> conflictCheckMap = new HashMap<>(); - - // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans - LinkedHashMap> conflictLinkedMap = new LinkedHashMap<>(); - - // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) - LinkedHashMap bfsLinkedMap = new LinkedHashMap<>(); - bfsLinkedMap.put(rootPlan, true); - - // Array to store cumulative additional cost for resolving conflicts - double[] cumulativeAdditionalCost = new double[]{0.0}; - - while (!bfsLinkedMap.isEmpty()) { - // Perform BFS to detect conflicts in federated plans - while (!bfsLinkedMap.isEmpty()) { - FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); - bfsLinkedMap.remove(currentPlan); - - // Iterate over each child plan of the current plan - for (Pair childPlanPair : currentPlan.getChildFedPlans()) { - FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); - - // Check if the child plan ID is already visited - if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { - // Retrieve the existing conflict pair for the child plan - Pair> conflictChildPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); - // Add the current plan to the list of parent plans - conflictChildPlanPair.getRight().add(currentPlan); - - // If the federated output type differs, a conflict is detected - if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { - // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) - // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children occur - if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { 
- conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); - bfsLinkedMap.remove(childFedPlan); - } - } - } else { - // If no conflict exists, create a new entry in the conflict check map - List parentFedPlanList = new ArrayList<>(); - parentFedPlanList.add(currentPlan); - - // Map the child plan ID to its output type and list of parent plans - conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); - // Add the child plan to the BFS queue - bfsLinkedMap.put(childFedPlan, true); - } - } - } - // Resolve these conflicts to ensure a consistent federated output type across the plan - // Re-run BFS with resolved conflicts - bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); - conflictLinkedMap.clear(); + private static List rewireOuterTransRead(List childHops, String hopName, Map> outerTransTable) { + List newChildHops = new ArrayList<>(childHops); + List additionalChildHops = outerTransTable.get(hopName); + if (additionalChildHops != null) { + newChildHops.addAll(additionalChildHops); } - - // Return the cumulative additional cost for resolving conflicts - return cumulativeAdditionalCost[0]; + return newChildHops; } -} + + /** + * Enumerates federated execution plans for a given Hop. + * This method calculates the self cost and child costs for the Hop, + * generates federated plan variants for both LOUT and FOUT output types, + * and prunes redundant plans before adding them to the memo table. + * + * @param hop The Hop for which to enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param childHops The list of child hops. + * @param weight The weight associated with the current Hop. 
+ */ + private static void enumerateFedPlan(Hop hop, FederatedMemoTable memoTable, List childHops, double weight){ + long hopID = hop.getHopID(); + HopCommon hopCommon = new HopCommon(hop, weight); + double selfCost = FederatedPlanCostEstimator.computeHopCost(hopCommon); + + FedPlanVariants lOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); + FedPlanVariants fOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); + + int numInputs = childHops.size(); + int numInitInputs = hop.getInput().size(); + + double[][] childCumulativeCost = new double[numInputs][2]; // # of child, LOUT/FOUT of child + double[] childForwardingCost = new double[numInputs]; // # of child + + // The self cost follows its own weight, while the forwarding cost follows the parent's weight. + FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childHops, childCumulativeCost, childForwardingCost); + + if (numInitInputs == numInputs){ + enumerateOnlyInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } else { + enumerateTReadInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } + + // Prune the FedPlans to remove redundant plans + lOutFedPlanVariants.pruneFedPlans(); + fOutFedPlanVariants.pruneFedPlans(); + + // Add the FedPlanVariants to the memo table + memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, lOutFedPlanVariants); + memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, fOutFedPlanVariants); + } + + /** + * Enumerates federated execution plans for initial child hops only. + * This method generates all possible combinations of federated output types (LOUT and FOUT) + * for the initial child hops and calculates their cumulative costs. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. 
+ * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. + */ + private static void enumerateOnlyInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs, List childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + // Iterate 2^n times, generating two FedPlans (LOUT, FOUT) each time. + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List> planChilds = new ArrayList<>(); + // LOUT and FOUT share the same planChilds in each iteration (only forwarding cost differs). + enumerateInitChildFedPlan(numInitInputs, childHops, planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds)); + } + } + + /** + * Enumerates federated execution plans for a TRead hop. + * This method calculates the cumulative costs for both LOUT and FOUT federated output types + * by considering the additional child hops, which are TWrite hops. + * It generates all possible combinations of federated output types for the initial child hops + * and adds the pre-calculated costs of the TWrite child hops to these combinations. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. 
+ * @param numInputs The total number of input hops, including additional TWrite hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. + */ + private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, + int numInitInputs, int numInputs, List childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + double lOutTReadCumulativeCost = selfCost; + double fOutTReadCumulativeCost = selfCost; + + List> lOutTReadPlanChilds = new ArrayList<>(); + List> fOutTReadPlanChilds = new ArrayList<>(); + + // Pre-calculate the cost for the additional child hop, which is a TWrite hop, of the TRead hop. + // Constraint: TWrite must have the same FedOutType as TRead. + for (int j = numInitInputs; j < numInputs; j++) { + Hop inputHop = childHops.get(j); + lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT)); + fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT)); + + lOutTReadCumulativeCost += childCumulativeCost[j][0]; + fOutTReadCumulativeCost += childCumulativeCost[j][1]; + // Skip TWrite -> TRead as they have the same FedOutType. + } + + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List> lOutPlanChilds = new ArrayList<>(); + enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + // Copy lOutPlanChilds to create fOutPlanChilds and add the pre-calculated cost of the TWrite child hop. 
+ List> fOutPlanChilds = new ArrayList<>(lOutPlanChilds); + + lOutPlanChilds.addAll(lOutTReadPlanChilds); + fOutPlanChilds.addAll(fOutTReadPlanChilds); + + cumulativeCost[0] += lOutTReadCumulativeCost; + cumulativeCost[1] += fOutTReadCumulativeCost; + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds)); + } + } + + // Calculates costs for initial child hops, determining FOUT or LOUT based on `i`. + private static void enumerateInitChildFedPlan(int numInitInputs, List childHops, List> planChilds, + double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){ + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInitInputs; j++) { + Hop inputHop = childHops.get(j); + // Calculate the bit value to decide between FOUT and LOUT for the current input + final int bit = (i & (1 << j)) != 0 ? 1 : 0; // Determine the bit value (decides FOUT/LOUT) + final FederatedOutput childType = (bit == 1) ? FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + + // Update the cumulative cost for LOUT, FOUT + cumulativeCost[0] += childCumulativeCost[j][bit] + childForwardingCost[j] * bit; + cumulativeCost[1] += childCumulativeCost[j][bit] + childForwardingCost[j] * (1 - bit); + } + } + + // Creates a dummy root node (fedplan) and selects the FedPlan with the minimum cost to return. + // The dummy root node does not have LOUT or FOUT. 
+ private static FedPlan getMinCostRootFedPlan(Set progRootHopSet, FederatedMemoTable memoTable) { + double cumulativeCost = 0; + List> rootFedPlanChilds = new ArrayList<>(); + + // Iterate over each Hop in the progRootHopSet + for (Hop endHop : progRootHopSet){ + // Retrieve the pruned FedPlan for LOUT and FOUT from the memo table + FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.LOUT); + FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.FOUT); + + // Compare the cumulative costs of LOUT and FOUT FedPlans + if (lOutFedPlan.getCumulativeCost() <= fOutFedPlan.getCumulativeCost()){ + cumulativeCost += lOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.LOUT)); + } else{ + cumulativeCost += fOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.FOUT)); + } + } + + return new FedPlan(cumulativeCost, null, rootFedPlanChilds); + } + + /** + * Detects and resolves conflicts in federated plans starting from the root plan. + * This function performs a breadth-first search (BFS) to traverse the federated plan tree. + * It identifies conflicts where the same plan ID has different federated output types. + * For each conflict, it records the plan ID and its conflicting parent plans. + * The function ensures that each plan ID is associated with a consistent federated output type + * by resolving these conflicts iteratively. + * + * The process involves: + * - Using a map to track conflicts, associating each plan ID with its federated output type + * and a list of parent plans. + * - Storing detected conflicts in a linked map, each entry containing a plan ID and its + * conflicting parent plans. + * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. 
+ * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan + * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. + * - Resolving conflicts by ensuring a consistent federated output type across the plan. + * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. + * + * @param rootPlan The root federated plan from which to start the conflict detection. + * @param memoTable The memoization table used to retrieve pruned federated plans. + * @return The cumulative additional cost for resolving conflicts. + */ + private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { + // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans + Map>> conflictCheckMap = new HashMap<>(); + + // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans + LinkedHashMap> conflictLinkedMap = new LinkedHashMap<>(); + + // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) + LinkedHashMap bfsLinkedMap = new LinkedHashMap<>(); + bfsLinkedMap.put(rootPlan, true); + + // Array to store cumulative additional cost for resolving conflicts + double[] cumulativeAdditionalCost = new double[]{0.0}; + + while (!bfsLinkedMap.isEmpty()) { + // Perform BFS to detect conflicts in federated plans + while (!bfsLinkedMap.isEmpty()) { + FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); + bfsLinkedMap.remove(currentPlan); + + // Iterate over each child plan of the current plan + for (Pair childPlanPair : currentPlan.getChildFedPlans()) { + FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); + + // Check if the child plan ID is already visited + if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { + // Retrieve the existing conflict pair for the child plan + Pair> conflictChildPlanPair = 
conflictCheckMap.get(childPlanPair.getLeft()); + // Add the current plan to the list of parent plans + conflictChildPlanPair.getRight().add(currentPlan); + + // If the federated output type differs, a conflict is detected + if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { + // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) + // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children occur + if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { + conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); + bfsLinkedMap.remove(childFedPlan); + } + } + } else { + // If no conflict exists, create a new entry in the conflict check map + List parentFedPlanList = new ArrayList<>(); + parentFedPlanList.add(currentPlan); + + // Map the child plan ID to its output type and list of parent plans + conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); + // Add the child plan to the BFS queue + bfsLinkedMap.put(childFedPlan, true); + } + } + } + // Resolve these conflicts to ensure a consistent federated output type across the plan + // Re-run BFS with resolved conflicts + bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); + conflictLinkedMap.clear(); + } + + // Return the cumulative additional cost for resolving conflicts + return cumulativeAdditionalCost[0]; + } + } + \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index 1f2c2802f46..9ff405ab283 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -17,248 +17,248 @@ * under the License. */ -package org.apache.sysds.hops.fedplanner; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.sysds.common.Types; -import org.apache.sysds.hops.DataOp; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.cost.ComputeCost; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - -import java.util.LinkedHashMap; -import java.util.NoSuchElementException; -import java.util.List; -import java.util.Map; - -/** - * Cost estimator for federated execution plans. - * Calculates computation, memory access, and network transfer costs for federated operations. - * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. - */ -public class FederatedPlanCostEstimator { - // Default value is used as a reasonable estimate since we only need - // to compare relative costs between different federated plans - // Memory bandwidth for local computations (25 GB/s) - private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; - // Network bandwidth for data transfers between federated sites (1 Gbps) - private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; - - // Retrieves the cumulative and forwarding costs of the child hops and stores them in arrays - public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List inputHops, - double[][] childCumulativeCost, double[] childForwardingCost) { - for (int i = 0; i < inputHops.size(); i++) { - long childHopID = inputHops.get(i).getHopID(); - - FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); - FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); - - // The cumulative 
cost of the child already includes the weight - childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); - childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); - - // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements? - childForwardingCost[i] = hopCommon.weight * childLOutFedPlan.getForwardingCost(); - } - } - - /** - * Computes the cost associated with a given Hop node. - * This method calculates both the self cost and the forwarding cost for the Hop, - * taking into account its type and the number of parent nodes. - * - * @param hopCommon The HopCommon object containing the Hop and its properties. - * @return The self cost of the Hop. - */ - public static double computeHopCost(HopCommon hopCommon){ - // TWrite and TRead are meta-data operations, hence selfCost is zero - if (hopCommon.hopRef instanceof DataOp){ - if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTWRITE ){ - hopCommon.setSelfCost(0); - // Since TWrite and TRead have the same FedOutType, forwarding cost is zero - hopCommon.setForwardingCost(0); - return 0; - } else if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTREAD) { - hopCommon.setSelfCost(0); - // TRead may have a different FedOutType from its parent, so calculate forwarding cost - // TODO: Uncertain about the number of TWrites - hopCommon.setForwardingCost(computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate())); - return 0; - } - } - - // In loops, selfCost is repeated, but forwarding may not be - // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?) 
- double selfCost = hopCommon.weight * computeSelfCost(hopCommon.hopRef); - double forwardingCost = computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate()); - - int numParents = hopCommon.hopRef.getParent().size(); - if (numParents >= 2) { - selfCost /= numParents; - forwardingCost /= numParents; - } - - hopCommon.setSelfCost(selfCost); - hopCommon.setForwardingCost(forwardingCost); - - return selfCost; - } - - /** - * Computes the cost for the current Hop node. - * - * @param currentHop The Hop node whose cost needs to be computed - * @return The total cost for the current node's operation - */ - private static double computeSelfCost(Hop currentHop){ - double computeCost = ComputeCost.getHOPComputeCost(currentHop); - double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); - double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); - - // Compute total cost assuming: - // 1. Computation and input access can be overlapped (hence taking max) - // 2. Output access must wait for both to complete (hence adding) - return Math.max(computeCost, inputAccessCost) + ouputAccessCost; - } - - /** - * Calculates the memory access cost based on data size and memory bandwidth. - * - * @param memSize Size of data to be accessed (in bytes) - * @return Time cost for memory access (in seconds) - */ - private static double computeHopMemoryAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; - } - - /** - * Calculates the network transfer cost based on data size and network bandwidth. - * Used when federation status changes between parent and child plans. 
- * - * @param memSize Size of data to be transferred (in bytes) - * @return Time cost for network transfer (in seconds) - */ - private static double computeHopForwardingCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; - } - - /** - * Resolves conflicts in federated plans where different plans have different FederatedOutput types. - * This function traverses the list of conflicting plans in reverse order to ensure that conflicts - * are resolved from the bottom-up, allowing for consistent federated output types across the plan. - * It calculates additional costs for each potential resolution and updates the cumulative additional cost. - * - * @param memoTable The FederatedMemoTable containing all federated plan variants. - * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. - * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. - * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. - */ - public static LinkedHashMap resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { - // LinkedHashMap to store resolved federated plans for BFS traversal. 
- LinkedHashMap resolvedFedPlanLinkedMap = new LinkedHashMap<>(); - - // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts - for (Map.Entry> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { - long conflictHopID = conflictFedPlanPair.getKey(); - List conflictParentFedPlans = conflictFedPlanPair.getValue(); - - // Retrieve the conflicting federated plans for LOUT and FOUT types - FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); - FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); - - // Variables to store additional costs for LOUT and FOUT types - double lOutAdditionalCost = 0; - double fOutAdditionalCost = 0; - - // Flags to check if the plan involves network transfer - // Network transfer cost is calculated only once, even if it occurs multiple times - boolean isLOutForwarding = false; - boolean isFOutForwarding = false; - - // Determine the optimal federated output type based on the calculated costs - FederatedOutput optimalFedOutType; - - // Iterate over each parent federated plan in the current conflict pair - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - // Find the calculated FedOutType of the child plan - Pair cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() - .filter(pair -> pair.getLeft().equals(conflictHopID)) - .findFirst() - .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); - - // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. - // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. - // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. - // CASE 5. 
Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. - // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. - // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. - - // Adjust LOUT, FOUT costs based on the calculated plan's output type - if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { - // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost - // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. - fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { - // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred - // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later - isFOutForwarding = true; - } else { - // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred - // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later - isLOutForwarding = true; - lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - - // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - } - } else { - lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { - isLOutForwarding = true; - } else { - isFOutForwarding = true; - lOutAdditionalCost -= 
confilctLOutFedPlan.getForwardingCost(); - fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); - } - } - } - - // Add network transfer costs if applicable - if (isLOutForwarding) { - lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); - } - if (isFOutForwarding) { - fOutAdditionalCost += confilctFOutFedPlan.getForwardingCost(); - } - - // Determine the optimal federated output type based on the calculated costs - if (lOutAdditionalCost <= fOutAdditionalCost) { - optimalFedOutType = FederatedOutput.LOUT; - cumulativeAdditionalCost[0] += lOutAdditionalCost; - resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); - } else { - optimalFedOutType = FederatedOutput.FOUT; - cumulativeAdditionalCost[0] += fOutAdditionalCost; - resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); - } - - // Update only the optimal federated output type, not the cost itself or recursively - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { - if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { - int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); - conflictParentFedPlan.getChildFedPlans().set(index, - Pair.of(childPlanPair.getLeft(), optimalFedOutType)); - break; - } - } - } - } - return resolvedFedPlanLinkedMap; - } -} + package org.apache.sysds.hops.fedplanner; + import org.apache.commons.lang3.tuple.Pair; + import org.apache.sysds.common.Types; + import org.apache.sysds.hops.DataOp; + import org.apache.sysds.hops.Hop; + import org.apache.sysds.hops.cost.ComputeCost; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + + import java.util.LinkedHashMap; + import java.util.NoSuchElementException; + import java.util.List; + import 
java.util.Map; + + /** + * Cost estimator for federated execution plans. + * Calculates computation, memory access, and network transfer costs for federated operations. + * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. + */ + public class FederatedPlanCostEstimator { + // Default value is used as a reasonable estimate since we only need + // to compare relative costs between different federated plans + // Memory bandwidth for local computations (25 GB/s) + private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; + // Network bandwidth for data transfers between federated sites (1 Gbps) + private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + + // Retrieves the cumulative and forwarding costs of the child hops and stores them in arrays + public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List inputHops, + double[][] childCumulativeCost, double[] childForwardingCost) { + for (int i = 0; i < inputHops.size(); i++) { + long childHopID = inputHops.get(i).getHopID(); + + FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); + FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); + + // The cumulative cost of the child already includes the weight + childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); + childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); + + // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements? + childForwardingCost[i] = hopCommon.weight * childLOutFedPlan.getForwardingCost(); + } + } + + /** + * Computes the cost associated with a given Hop node. + * This method calculates both the self cost and the forwarding cost for the Hop, + * taking into account its type and the number of parent nodes. + * + * @param hopCommon The HopCommon object containing the Hop and its properties. 
+ * @return The self cost of the Hop. + */ + public static double computeHopCost(HopCommon hopCommon){ + // TWrite and TRead are meta-data operations, hence selfCost is zero + if (hopCommon.hopRef instanceof DataOp){ + if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTWRITE ){ + hopCommon.setSelfCost(0); + // Since TWrite and TRead have the same FedOutType, forwarding cost is zero + hopCommon.setForwardingCost(0); + return 0; + } else if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTREAD) { + hopCommon.setSelfCost(0); + // TRead may have a different FedOutType from its parent, so calculate forwarding cost + hopCommon.setForwardingCost(computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate())); + return 0; + } + } + + // In loops, selfCost is repeated, but forwarding may not be + // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?) + double selfCost = hopCommon.weight * computeSelfCost(hopCommon.hopRef); + double forwardingCost = computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate()); + + int numParents = hopCommon.hopRef.getParent().size(); + if (numParents >= 2) { + selfCost /= numParents; + forwardingCost /= numParents; + } + + hopCommon.setSelfCost(selfCost); + hopCommon.setForwardingCost(forwardingCost); + + return selfCost; + } + + /** + * Computes the cost for the current Hop node. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeSelfCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); + + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. 
Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. + * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. + * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopForwardingCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; + } + + /** + * Resolves conflicts in federated plans where different plans have different FederatedOutput types. + * This function traverses the list of conflicting plans in reverse order to ensure that conflicts + * are resolved from the bottom-up, allowing for consistent federated output types across the plan. + * It calculates additional costs for each potential resolution and updates the cumulative additional cost. + * + * @param memoTable The FederatedMemoTable containing all federated plan variants. + * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. + * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. + * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. + */ + public static LinkedHashMap resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { + // LinkedHashMap to store resolved federated plans for BFS traversal. 
+ LinkedHashMap resolvedFedPlanLinkedMap = new LinkedHashMap<>(); + + // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts + for (Map.Entry> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { + long conflictHopID = conflictFedPlanPair.getKey(); + List conflictParentFedPlans = conflictFedPlanPair.getValue(); + + // Retrieve the conflicting federated plans for LOUT and FOUT types + FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); + FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); + + // Variables to store additional costs for LOUT and FOUT types + double lOutAdditionalCost = 0; + double fOutAdditionalCost = 0; + + // Flags to check if the plan involves network transfer + // Network transfer cost is calculated only once, even if it occurs multiple times + boolean isLOutForwarding = false; + boolean isFOutForwarding = false; + + // Determine the optimal federated output type based on the calculated costs + FederatedOutput optimalFedOutType; + + // Iterate over each parent federated plan in the current conflict pair + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { + // Find the calculated FedOutType of the child plan + Pair cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() + .filter(pair -> pair.getLeft().equals(conflictHopID)) + .findFirst() + .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); + + // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. + // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. + // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. + // CASE 5. 
Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. + // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. + // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. + + // Adjust LOUT, FOUT costs based on the calculated plan's output type + if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { + // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost + // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. + fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { + // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred + // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later + isFOutForwarding = true; + } else { + // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred + // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later + isLOutForwarding = true; + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + + // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + } + } else { + lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { + isLOutForwarding = true; + } else { + isFOutForwarding = true; + lOutAdditionalCost -= 
confilctLOutFedPlan.getForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + } + } + } + + // Add network transfer costs if applicable + if (isLOutForwarding) { + lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); + } + if (isFOutForwarding) { + fOutAdditionalCost += confilctFOutFedPlan.getForwardingCost(); + } + + // Determine the optimal federated output type based on the calculated costs + if (lOutAdditionalCost <= fOutAdditionalCost) { + optimalFedOutType = FederatedOutput.LOUT; + cumulativeAdditionalCost[0] += lOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); + } else { + optimalFedOutType = FederatedOutput.FOUT; + cumulativeAdditionalCost[0] += fOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); + } + + // Update only the optimal federated output type, not the cost itself or recursively + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { + for (Pair childPlanPair : conflictParentFedPlan.getChildFedPlans()) { + if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { + int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); + conflictParentFedPlan.getChildFedPlans().set(index, + Pair.of(childPlanPair.getLeft(), optimalFedOutType)); + break; + } + } + } + } + return resolvedFedPlanLinkedMap; + } + } + \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 3edfbc581ad..0bc7d9f84f5 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -17,90 +17,94 @@ * under the License. 
*/ -package org.apache.sysds.test.component.federated; + package org.apache.sysds.test.component.federated; -import java.io.IOException; -import java.util.HashMap; -import org.junit.Assert; -import org.junit.Test; -import org.apache.sysds.api.DMLScript; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; -import org.apache.sysds.parser.DMLProgram; -import org.apache.sysds.parser.DMLTranslator; -import org.apache.sysds.parser.ParserFactory; -import org.apache.sysds.parser.ParserWrapper; -import org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.test.TestConfiguration; -import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + import java.io.IOException; + import java.util.HashMap; + import org.junit.Assert; + import org.junit.Test; + import org.apache.sysds.api.DMLScript; + import org.apache.sysds.conf.ConfigurationManager; + import org.apache.sysds.conf.DMLConfig; + import org.apache.sysds.parser.DMLProgram; + import org.apache.sysds.parser.DMLTranslator; + import org.apache.sysds.parser.ParserFactory; + import org.apache.sysds.parser.ParserWrapper; + import org.apache.sysds.test.AutomatedTestBase; + import org.apache.sysds.test.TestConfiguration; + import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + + public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase + { + private static final String TEST_DIR = "functions/federated/privacy/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; + + @Override + public void setUp() {} + + @Test + public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } + + @Test + public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } + + @Test + public void testFederatedPlanCostEnumerator3() { 
runTest("FederatedPlanCostEnumeratorTest3.dml"); } + + @Test + public void testFederatedPlanCostEnumerator4() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } + + @Test + public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); } + + @Test + public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); } + + @Test + public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } + + @Test + public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest8.dml"); } + + @Test + public void testFederatedPlanCostEnumerator9() { runTest("FederatedPlanCostEnumeratorTest9.dml"); } -public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase -{ - private static final String TEST_DIR = "functions/federated/privacy/"; - private static final String HOME = SCRIPT_DIR + TEST_DIR; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; - - @Override - public void setUp() {} + @Test + public void testFederatedPlanCostEnumerator10() { runTest("FederatedPlanCostEnumeratorTest10.dml"); } - @Test - public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } - - @Test - public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } - - @Test - public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } - - @Test - public void testFederatedPlanCostEnumerator4() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } - - @Test - public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); } - - @Test - public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); } - - @Test - public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } - - @Test - 
public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest8.dml"); } - - @Test - public void testFederatedPlanCostEnumerator9() { runTest("FederatedPlanCostEnumeratorTest9.dml"); } - - // Todo: Need to write test scripts for the federated version - private void runTest( String scriptFilename ) { - int index = scriptFilename.lastIndexOf(".dml"); - String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length()); - TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); - addTestConfiguration(testName, testConfig); - loadTestConfiguration(testConfig); - - try { - DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); - ConfigurationManager.setLocalConfig(conf); - - //read script - String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); - - //parsing and dependency analysis - ParserWrapper parser = ParserFactory.createParser(); - DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - dmlt.constructHops(prog); - dmlt.rewriteHopsDAG(prog); - dmlt.constructLops(prog); - dmlt.rewriteLopDAG(prog); - - FederatedPlanCostEnumerator.enumerateProgram(prog, true); - } - catch (IOException e) { - e.printStackTrace(); - Assert.fail(); - } - } -} + // Todo: Need to write test scripts for the federated version + private void runTest( String scriptFilename ) { + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? 
index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + try { + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + //read script + String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); + + //parsing and dependency analysis + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); + dmlt.rewriteLopDAG(prog); + + FederatedPlanCostEnumerator.enumerateProgram(prog, true); + } + catch (IOException e) { + e.printStackTrace(); + Assert.fail(); + } + } + } + \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py new file mode 100644 index 00000000000..7b0ba6c7a79 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py @@ -0,0 +1,247 @@ +import sys +import re +import networkx as nx +import matplotlib.pyplot as plt + +try: + import pygraphviz + from networkx.drawing.nx_agraph import graphviz_layout + HAS_PYGRAPHVIZ = True +except ImportError: + HAS_PYGRAPHVIZ = False + print("[WARNING] pygraphviz not found. 
Please install via 'pip install pygraphviz'.\n" + "If not installed, we will use an alternative layout (spring_layout).") + + +def parse_line(line: str): + """ + Parse a single line from the trace file to extract: + - Node ID + - Operation (hop name) + - Kind (e.g., FOUT, LOUT, NREF) + - Total cost + - Weight + - Refs (list of IDs that this node depends on) + """ + + # 1) Match a node ID in the form of "(R)" or "()" + match_id = re.match(r'^\((R|\d+)\)', line) + if not match_id: + return None + node_id = match_id.group(1) + + # 2) The remaining string after the node ID + after_id = line[match_id.end():].strip() + + # Extract operation (hop name) before the first "[" + match_label = re.search(r'^(.*?)\s*\[', after_id) + if match_label: + operation = match_label.group(1).strip() + else: + operation = after_id.strip() + + # 3) Extract the kind (content inside the first pair of brackets "[]") + match_bracket = re.search(r'\[([^\]]+)\]', after_id) + if match_bracket: + kind = match_bracket.group(1).strip() + else: + kind = "" + + # 4) Extract total and weight from the content inside curly braces "{}" + total = "" + weight = "" + match_curly = re.search(r'\{([^}]+)\}', line) + if match_curly: + curly_content = match_curly.group(1) + m_total = re.search(r'Total:\s*([\d\.]+)', curly_content) + m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content) + if m_total: + total = m_total.group(1) + if m_weight: + weight = m_weight.group(1) + + # 5) Extract reference nodes: look for the first parenthesis containing numbers after the hop name + match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id) + if match_refs: + ref_str = match_refs.group(1) + refs = [r.strip() for r in ref_str.split(',') if r.strip().isdigit()] + else: + refs = [] + + return { + 'node_id': node_id, + 'operation': operation, + 'kind': kind, + 'total': total, + 'weight': weight, + 'refs': refs + } + + +def build_dag_from_file(filename: str): + """ + Read a trace file line by line and build a directed 
acyclic graph (DAG) using NetworkX. + """ + G = nx.DiGraph() + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + info = parse_line(line) + if not info: + continue + + node_id = info['node_id'] + operation = info['operation'] + kind = info['kind'] + total = info['total'] + weight = info['weight'] + refs = info['refs'] + + # Add node with attributes + G.add_node(node_id, label=operation, kind=kind, total=total, weight=weight) + + # Add edges from references to this node + for r in refs: + if r not in G: + G.add_node(r, label=r, kind="", total="", weight="") + G.add_edge(r, node_id) + return G + + +def main(): + """ + Main function that: + - Reads a filename from command-line arguments + - Builds a DAG from the file + - Draws and displays the DAG using matplotlib + """ + + # Get filename from command-line argument + if len(sys.argv) < 2: + print("[ERROR] No filename provided.\nUsage: python plot_federated_dag.py ") + sys.exit(1) + filename = sys.argv[1] + + print(f"[INFO] Running with filename '{filename}'") + + # Build the DAG + G = build_dag_from_file(filename) + + # Print debug info: nodes and edges + print("Nodes:", G.nodes(data=True)) + print("Edges:", list(G.edges())) + + # Decide on layout + if HAS_PYGRAPHVIZ: + # graphviz_layout with rankdir=BT (bottom to top), etc. 
+ pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8') + else: + # Fallback layout if pygraphviz is not installed + pos = nx.spring_layout(G, seed=42) + + # Dynamically adjust figure size based on number of nodes + node_count = len(G.nodes()) + fig_width = 10 + node_count / 10.0 + fig_height = 6 + node_count / 10.0 + plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300) + ax = plt.gca() + ax.set_facecolor('white') + + # Generate labels for each node in the format: + # node_id: operation_name + # C (W) + labels = { + n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})" + for n in G.nodes() + } + + # Function to determine color based on 'kind' + def get_color(n): + k = G.nodes[n].get('kind', '').lower() + if k == 'fout': + return 'tomato' + elif k == 'lout': + return 'dodgerblue' + elif k == 'nref': + return 'mediumpurple' + else: + return 'mediumseagreen' + + # Determine node shapes based on operation name: + # - '^' (triangle) if the label contains "twrite" + # - 's' (square) if the label contains "tread" + # - 'o' (circle) otherwise + triangle_nodes = [n for n in G.nodes() if 'twrite' in G.nodes[n].get('label', '').lower()] + square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label', '').lower()] + other_nodes = [ + n for n in G.nodes() + if 'twrite' not in G.nodes[n].get('label', '').lower() and + 'tread' not in G.nodes[n].get('label', '').lower() + ] + + # Colors for each group + triangle_colors = [get_color(n) for n in triangle_nodes] + square_colors = [get_color(n) for n in square_nodes] + other_colors = [get_color(n) for n in other_nodes] + + # Draw nodes group-wise + node_collection_triangle = nx.draw_networkx_nodes( + G, pos, nodelist=triangle_nodes, node_size=800, + node_color=triangle_colors, node_shape='^', ax=ax + ) + node_collection_square = nx.draw_networkx_nodes( + G, pos, nodelist=square_nodes, node_size=800, + 
node_color=square_colors, node_shape='s', ax=ax + ) + node_collection_other = nx.draw_networkx_nodes( + G, pos, nodelist=other_nodes, node_size=800, + node_color=other_colors, node_shape='o', ax=ax + ) + + # Set z-order for nodes, edges, and labels + node_collection_triangle.set_zorder(1) + node_collection_square.set_zorder(1) + node_collection_other.set_zorder(1) + + edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax) + if isinstance(edge_collection, list): + for ec in edge_collection: + ec.set_zorder(2) + else: + edge_collection.set_zorder(2) + + label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax) + for text in label_dict.values(): + text.set_zorder(3) + + # Set the title + plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold") + + # Provide a small legend on the top-right or top-left + plt.text(1, 1, + "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))", + fontsize=12, ha='right', va='top', transform=ax.transAxes) + + # Example mini-legend for different 'kind' values + plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes) + plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes) + plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes) + + plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes) + plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes) + plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes) + + plt.axis("off") + + # Save the plot to a file with the same name as the input file, but with a .png extension + output_filename = f"{filename.rsplit('.', 1)[0]}.png" + plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight') + + plt.show() + + +if __name__ == '__main__': + main() diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml 
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml new file mode 100644 index 00000000000..276de7bde91 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Recursive function: Calculate factorial +factorialUser = function(int n) return (int result) { + if (n <= 1) { + result = 1; # base case + } else { + result = n * factorialUser(n - 1); # recursive call + } +} + +number = 5; +fact_result = factorialUser(number); +print("Factorial of " + number + ": " + fact_result); \ No newline at end of file