Skip to content

Commit 37a54f3

Browse files
Refactor the V1LabelledKwargDataflowGraph format to use a single-level map.
1 parent c802cee commit 37a54f3

2 files changed

Lines changed: 15 additions & 20 deletions

File tree

lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ type = "std::unordered_map<::FlexFlow::nonnegative_int, NodeLabel>"
3636

3737
[[fields]]
3838
name = "output_labels"
39-
type = "std::unordered_map<::FlexFlow::nonnegative_int, std::unordered_map<SlotName, OutputLabel>>"
39+
type = "std::unordered_map<::FlexFlow::V1GraphOutput<SlotName>, OutputLabel>"
4040

4141
[[fields]]
4242
name = "graph"

lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#include "utils/containers/map_keys.h"
88
#include "utils/containers/map_values.h"
99
#include "utils/containers/transform.h"
10-
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h"
10+
#include "utils/containers/unordered_map_from_pairs.h"
11+
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h"
1112
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h"
1213
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"
1314
#include "utils/graph/node/algorithms.h"
@@ -28,16 +29,13 @@ std::pair<V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>,
2829
std::unordered_map<nonnegative_int, NodeLabel> node_labels = map_values(
2930
nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); });
3031

31-
std::unordered_map<nonnegative_int, std::unordered_map<SlotName, OutputLabel>>
32-
output_labels = map_values(
33-
nodes.as_unordered_map(),
34-
[&](Node const &n) -> std::unordered_map<SlotName, OutputLabel> {
35-
return map_values(
36-
get_outgoing_kwarg_dataflow_outputs_for_node(g, n),
37-
[&](KwargDataflowOutput<SlotName> const &o) {
38-
return g.at(o);
39-
});
40-
});
32+
std::unordered_map<V1GraphOutput<SlotName>, OutputLabel> output_labels =
33+
unordered_map_from_pairs(transform(
34+
get_all_kwarg_dataflow_outputs(g),
35+
[&](KwargDataflowOutput<SlotName> const &o) {
36+
return std::pair{V1GraphOutput{nodes.at_r(o.node), o.slot_name},
37+
g.at(o)};
38+
}));
4139

4240
return {
4341
V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>{
@@ -63,14 +61,11 @@ std::pair<LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName>,
6361
std::unordered_map<Node, NodeLabel> node_labels = map_keys(
6462
v1.node_labels, [&](nonnegative_int n) { return node_map.at(n); });
6563

66-
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel> value_labels;
67-
for (auto const &[nonneg_node, slot_map] : v1.output_labels) {
68-
Node node = node_map.at(nonneg_node);
69-
for (auto const &[slot_name, label] : slot_map) {
70-
value_labels.emplace(KwargDataflowOutput<SlotName>{node, slot_name},
71-
label);
72-
}
73-
}
64+
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel> value_labels =
65+
map_keys(v1.output_labels, [&](V1GraphOutput<SlotName> const &o) {
66+
Node n = Node{o.node.size_t_from_nonnegative_int()};
67+
return KwargDataflowOutput<SlotName>{n, o.slot_name};
68+
});
7469

7570
return std::pair{kwarg_dataflow_graph_view_with_labelling(
7671
graph_view, node_labels, value_labels),

0 commit comments

Comments
 (0)