Skip to content

Commit bc1cd7a

Browse files
authored
[FEATURE]: katz centrality
[FEATURE]: katz centrality
2 parents 313089c + 87e8534 commit bc1cd7a

5 files changed

Lines changed: 236 additions & 1 deletion

File tree

cpp_easygraph/cpp_easygraph.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ PYBIND11_MODULE(cpp_easygraph, m) {
7777

7878
m.def("cpp_closeness_centrality", &closeness_centrality, py::arg("G"), py::arg("weight") = "weight", py::arg("cutoff") = py::none(), py::arg("sources") = py::none());
7979
m.def("cpp_betweenness_centrality", &betweenness_centrality, py::arg("G"), py::arg("weight") = "weight", py::arg("cutoff") = py::none(),py::arg("sources") = py::none(), py::arg("normalized") = py::bool_(true), py::arg("endpoints") = py::bool_(false));
80+
m.def("cpp_katz_centrality", &cpp_katz_centrality, py::arg("G"), py::arg("alpha") = 0.1, py::arg("beta") = 1.0, py::arg("max_iter") = 1000, py::arg("tol") = 1e-6, py::arg("normalized") = true);
8081
m.def("cpp_k_core", &core_decomposition, py::arg("G"));
8182
m.def("cpp_density", &density, py::arg("G"));
8283
m.def("cpp_constraint", &constraint, py::arg("G"), py::arg("nodes") = py::none(), py::arg("weight") = py::none(), py::arg("n_workers") = py::none());

cpp_easygraph/functions/centrality/centrality.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,12 @@
44

55
py::object closeness_centrality(py::object G, py::object weight, py::object cutoff, py::object sources);
66
py::object betweenness_centrality(py::object G, py::object weight, py::object cutoff, py::object sources,
7-
py::object normalized, py::object endpoints);
7+
py::object normalized, py::object endpoints);
8+
py::object cpp_katz_centrality(
9+
py::object G,
10+
py::object py_alpha,
11+
py::object py_beta,
12+
py::object py_max_iter,
13+
py::object py_tol,
14+
py::object py_normalized
15+
);
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include <cmath>
2+
#include <vector>
3+
#include <pybind11/pybind11.h>
4+
#include <pybind11/stl.h>
5+
#include "centrality.h"
6+
#include "../../classes/graph.h"
7+
8+
namespace py = pybind11;
9+
10+
py::object cpp_katz_centrality(
11+
py::object G,
12+
py::object py_alpha,
13+
py::object py_beta,
14+
py::object py_max_iter,
15+
py::object py_tol,
16+
py::object py_normalized
17+
) {
18+
try {
19+
Graph& graph = G.cast<Graph&>();
20+
auto csr = graph.gen_CSR();
21+
int n = csr->nodes.size();
22+
23+
if (n == 0) {
24+
return py::dict();
25+
}
26+
27+
// Initialize vectors
28+
std::vector<double> x0(n, 1.0);
29+
std::vector<double> x1(n);
30+
std::vector<double>* x_prev = &x0;
31+
std::vector<double>* x_next = &x1;
32+
33+
// Process beta parameter
34+
std::vector<double> b(n);
35+
if (py::isinstance<py::float_>(py_beta) || py::isinstance<py::int_>(py_beta)) {
36+
double beta_val = py_beta.cast<double>();
37+
for (int i = 0; i < n; i++) {
38+
b[i] = beta_val;
39+
}
40+
} else if (py::isinstance<py::dict>(py_beta)) {
41+
py::dict beta_dict = py_beta.cast<py::dict>();
42+
for (int i = 0; i < n; i++) {
43+
node_t internal_id = csr->nodes[i];
44+
py::object node_obj = graph.id_to_node[py::cast(internal_id)];
45+
if (beta_dict.contains(node_obj)) {
46+
b[i] = beta_dict[node_obj].cast<double>();
47+
} else {
48+
b[i] = 1.0;
49+
}
50+
}
51+
} else {
52+
throw py::type_error("beta must be a float or a dict");
53+
}
54+
55+
// Extract parameters
56+
double alpha = py_alpha.cast<double>();
57+
int max_iter = py_max_iter.cast<int>();
58+
double tol = py_tol.cast<double>();
59+
bool normalized = py_normalized.cast<bool>();
60+
61+
// Iterative updates
62+
int iter = 0;
63+
for (; iter < max_iter; iter++) {
64+
for (int i = 0; i < n; i++) {
65+
double sum = 0.0;
66+
int start = csr->V[i];
67+
int end = csr->V[i + 1];
68+
for (int jj = start; jj < end; jj++) {
69+
int j = csr->E[jj];
70+
sum += (*x_prev)[j];
71+
}
72+
(*x_next)[i] = alpha * sum + b[i];
73+
}
74+
75+
// Check convergence
76+
double change = 0.0;
77+
for (int i = 0; i < n; i++) {
78+
change += std::abs((*x_next)[i] - (*x_prev)[i]);
79+
}
80+
81+
if (change < tol) {
82+
break;
83+
}
84+
85+
std::swap(x_prev, x_next);
86+
}
87+
88+
// Handle convergence failure
89+
if (iter == max_iter) {
90+
throw std::runtime_error("Katz centrality failed to converge in " + std::to_string(max_iter) + " iterations");
91+
}
92+
93+
// Normalization
94+
std::vector<double>& x_final = *x_next;
95+
if (normalized) {
96+
double norm = 0.0;
97+
for (double val : x_final) {
98+
norm += val * val;
99+
}
100+
norm = std::sqrt(norm);
101+
if (norm > 0) {
102+
for (int i = 0; i < n; i++) {
103+
x_final[i] /= norm;
104+
}
105+
}
106+
}
107+
108+
// Prepare results
109+
py::dict result;
110+
for (int i = 0; i < n; i++) {
111+
node_t internal_id = csr->nodes[i];
112+
py::object node_obj = graph.id_to_node[py::cast(internal_id)];
113+
result[node_obj] = x_final[i];
114+
}
115+
116+
return result;
117+
} catch (const std::exception& e) {
118+
throw std::runtime_error(e.what());
119+
}
120+
}

easygraph/functions/centrality/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from .flowbetweenness import *
66
from .laplacian import *
77
from .pagerank import *
8+
from .katz_centrality import *
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from easygraph.utils import *
2+
import numpy as np
3+
from easygraph.utils.decorators import *
4+
5+
__all__ = ["katz_centrality"]
6+
7+
@not_implemented_for("multigraph")
8+
@hybrid("cpp_katz_centrality")
9+
def katz_centrality(G, alpha=0.1, beta=1.0, max_iter=1000, tol=1e-6, normalized=True):
10+
r"""
11+
Compute the Katz centrality for nodes in a graph.
12+
13+
Katz centrality computes the influence of a node based on the total number
14+
of walks between nodes, attenuated by a factor of their length. It is
15+
defined as the solution to the linear system:
16+
17+
.. math::
18+
19+
x = \alpha A x + \beta
20+
21+
where:
22+
- \( A \) is the adjacency matrix of the graph,
23+
- \( \alpha \) is a scalar attenuation factor,
24+
- \( \beta \) is the bias vector (typically all ones),
25+
- and \( x \) is the resulting centrality vector.
26+
27+
The algorithm runs an iterative fixed-point method until convergence.
28+
29+
Parameters
30+
----------
31+
G : easygraph.Graph
32+
An EasyGraph graph instance. Must be simple (non-multigraph).
33+
34+
alpha : float, optional (default=0.1)
35+
Attenuation factor, must be smaller than the reciprocal of the largest
36+
eigenvalue of the adjacency matrix to ensure convergence.
37+
38+
beta : float or dict, optional (default=1.0)
39+
Bias term. Can be a constant scalar applied to all nodes, or a dictionary
40+
mapping node IDs to values.
41+
42+
max_iter : int, optional (default=1000)
43+
Maximum number of iterations before the algorithm terminates.
44+
45+
tol : float, optional (default=1e-6)
46+
Convergence tolerance. Iteration stops when the L1 norm of the difference
47+
between successive iterations is below this threshold.
48+
49+
normalized : bool, optional (default=True)
50+
If True, the result vector will be normalized to unit norm (L2).
51+
52+
Returns
53+
-------
54+
dict
55+
A dictionary mapping node IDs to Katz centrality scores.
56+
57+
Raises
58+
------
59+
RuntimeError
60+
If the algorithm fails to converge within `max_iter` iterations.
61+
62+
Examples
63+
--------
64+
>>> import easygraph as eg
65+
>>> from easygraph import katz_centrality
66+
>>> G = eg.Graph()
67+
>>> G.add_edges_from([(0, 1), (1, 2), (2, 3)])
68+
>>> katz_centrality(G, alpha=0.05)
69+
{0: 0.370..., 1: 0.447..., 2: 0.447..., 3: 0.370...}
70+
"""
71+
# Create node ordering
72+
nodes = list(G.nodes)
73+
n = len(nodes)
74+
node_to_index = {node: i for i, node in enumerate(nodes)}
75+
index_to_node = {i: node for i, node in enumerate(nodes)}
76+
77+
# Build adjacency matrix
78+
A = np.zeros((n, n), dtype=np.float64)
79+
for u in G.nodes:
80+
for v in G.adj[u]:
81+
A[node_to_index[u], node_to_index[v]] = 1.0
82+
83+
# Initialize x and beta
84+
x = np.ones(n, dtype=np.float64)
85+
if isinstance(beta, dict):
86+
b = np.array([beta.get(index_to_node[i], 1.0) for i in range(n)])
87+
else:
88+
b = np.ones(n, dtype=np.float64) * beta
89+
90+
# Iterative update using vectorized ops
91+
for _ in range(max_iter):
92+
x_new = alpha * A @ x + b
93+
if np.linalg.norm(x_new - x, ord=1) < tol:
94+
break
95+
x = x_new
96+
else:
97+
raise RuntimeError(f"Katz centrality failed to converge in {max_iter} iterations")
98+
99+
if normalized:
100+
norm = np.linalg.norm(x)
101+
if norm > 0:
102+
x /= norm
103+
104+
result = {index_to_node[i]: float(x[i]) for i in range(n)}
105+
return result

0 commit comments

Comments
 (0)