Skip to content

Commit 569e338

Browse files
authored
feat: Reduce nodes over heterogeneous graphs (#634)
* feat: Reduce nodes over heterogeneous graphs * test: Add test for reduce
1 parent 6fdcc51 commit 569e338

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

GNNlib/src/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ function reduce_nodes(aggr, indicator::AbstractVector, x)
2727
return NNlib.scatter(aggr, x, indicator)
2828
end
2929

30+
"""
31+
reduce_nodes(aggr, node_type, g, x)
32+
33+
Return the graph-wise aggregation of the node features `x` on type `node_type`
34+
given a heterogeneous graph `g`. The aggregation operator `aggr` can be `+`,
35+
`mean`, `max`, or `min`.
36+
"""
37+
function reduce_nodes(aggr, node_type, g::GNNHeteroGraph, x)
38+
return NNlib.scatter(aggr, x[node_type], graph_indicator(g, node_type))
39+
end
40+
3041
"""
3142
reduce_edges(aggr, g, e)
3243

GNNlib/test/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@
1919
@test r2 == r
2020
end
2121

22+
@testset "reduce_nodes" begin
23+
g = rand_bipartite_heterograph((5, 10), 20)
24+
x = (
25+
A = [Float32(i) for j = 1:1, i = 1:g.num_nodes[:A]],
26+
B = [Float32(0) for j = 1:2, _ = 1:g.num_nodes[:B]],
27+
)
28+
expected = sum(i for i = 1:g.num_nodes[:A])
29+
result = reduce_nodes(+, :A, g, x)
30+
@test result == [expected;;]
31+
end
32+
2233
@testset "reduce_edges" begin
2334
r = reduce_edges(mean, g, e)
2435
@test size(r) == (De, g.num_graphs)

0 commit comments

Comments
 (0)