Skip to content

Commit 03afa5b

Browse files
tomicaprettoricardoV94
authored andcommitted
Remove SquareDiagonal Op and replace it with a square_diagonal function
1 parent 2b15ce1 commit 03afa5b

2 files changed

Lines changed: 30 additions & 80 deletions

File tree

pytensor/sparse/basic.py

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,53 +1445,13 @@ def infer_shape(self, fgraph, nodes, shapes):
14451445
diag = Diag()
14461446

14471447

1448-
class SquareDiagonal(Op):
1449-
"""Produce a square sparse (csc) matrix with a diagonal given by a dense vector.
1450-
1451-
Notes
1452-
-----
1453-
The grad implemented is regular, i.e. not structured.
1454-
1455-
"""
1456-
1457-
__props__ = ()
1458-
1459-
def make_node(self, diag):
1460-
"""
1461-
1462-
Parameters
1463-
----------
1464-
x
1465-
Dense vector for the diagonal.
1466-
1467-
"""
1468-
diag = ptb.as_tensor_variable(diag)
1469-
if diag.type.ndim != 1:
1470-
raise TypeError("data argument must be a vector", diag.type)
1471-
1472-
return Apply(self, [diag], [SparseTensorType(dtype=diag.dtype, format="csc")()])
1473-
1474-
def perform(self, node, inputs, outputs):
1475-
(z,) = outputs
1476-
diag = inputs[0]
1477-
1478-
N = len(diag)
1479-
data = diag[:N]
1480-
indices = list(range(N))
1481-
indptr = list(range(N + 1))
1482-
tup = (data, indices, indptr)
1483-
1484-
z[0] = scipy.sparse.csc_matrix(tup, copy=True)
1485-
1486-
def grad(self, inputs, gout):
1487-
(gz,) = gout
1488-
return [diag(gz)]
1489-
1490-
def infer_shape(self, fgraph, nodes, shapes):
1491-
return [(shapes[0][0], shapes[0][0])]
1492-
1493-
1494-
square_diagonal = SquareDiagonal()
1448+
def square_diagonal(diag):
1449+
"""Produce a square sparse (csc) matrix with a diagonal given by a dense vector."""
1450+
n = diag.shape[0]
1451+
data = ptb.as_tensor_variable(diag)
1452+
indices = ptb.arange(n, dtype=np.int32)
1453+
indptr = ptb.arange(n + 1, dtype=np.int32)
1454+
return CSC(data, indices, indptr, ptb.as_tensor((n, n)))
14951455

14961456

14971457
class EnsureSortedIndices(Op):

tests/sparse/test_basic.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
Remove0,
3030
SparseFromDense,
3131
SparseTensorType,
32-
SquareDiagonal,
3332
Transpose,
3433
VStack,
3534
_is_sparse,
@@ -966,42 +965,33 @@ def test_grad(self):
966965
verify_grad_sparse(self.op, data, structured=False)
967966

968967

969-
class TestSquareDiagonal(utt.InferShapeTester):
970-
def setup_method(self):
971-
super().setup_method()
972-
self.op_class = SquareDiagonal
973-
self.op = square_diagonal
974-
968+
class TestSquareDiagonal:
975969
def test_op(self):
976-
for format in sparse.sparse_formats:
977-
for size in range(5, 9):
978-
variable = [vector()]
979-
data = [np.random.random(size).astype(config.floatX)]
980-
981-
f = pytensor.function(variable, self.op(*variable))
982-
tested = f(*data).toarray()
983-
984-
expected = np.diag(*data)
985-
utt.assert_allclose(expected, tested)
986-
assert tested.dtype == expected.dtype
987-
assert tested.shape == expected.shape
988-
989-
def test_infer_shape(self):
990-
for format in sparse.sparse_formats:
991-
for size in range(5, 9):
992-
variable = [vector()]
993-
data = [np.random.random(size).astype(config.floatX)]
994-
995-
self._compile_and_check(
996-
variable, [self.op(*variable)], data, self.op_class
997-
)
970+
x = vector(dtype=config.floatX)
971+
y = square_diagonal(x)
972+
f = pytensor.function([x], y)
973+
974+
size = 11
975+
values = np.random.random(size).astype(config.floatX)
976+
tested = f(values)
977+
978+
assert tested.format == "csc"
979+
utt.assert_allclose(tested.toarray(), np.diag(values))
980+
assert tuple(tested.shape) == (values.size, values.size)
981+
np.testing.assert_array_equal(
982+
tested.indices, np.arange(values.size, dtype="int32")
983+
)
984+
np.testing.assert_array_equal(
985+
tested.indptr, np.arange(values.size + 1, dtype="int32")
986+
)
998987

999988
def test_grad(self):
1000-
for format in sparse.sparse_formats:
1001-
for size in range(5, 9):
1002-
data = [np.random.random(size).astype(config.floatX)]
989+
values = [np.random.random(13).astype(config.floatX)]
990+
verify_grad_sparse(square_diagonal, values, structured=False)
1003991

1004-
verify_grad_sparse(self.op, data, structured=False)
992+
def test_rejects_non_vector_input(self):
993+
with pytest.raises(TypeError, match="data argument must be a vector"):
994+
square_diagonal(matrix(dtype=config.floatX))
1005995

1006996

1007997
class TestEnsureSortedIndices(utt.InferShapeTester):

0 commit comments

Comments
 (0)