Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpnp/scipy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@

"""

from ._decomp_lu import lu_factor, lu_solve
from ._decomp_lu import lu, lu_factor, lu_solve

__all__ = [
"lu",
"lu_factor",
"lu_solve",
]
143 changes: 143 additions & 0 deletions dpnp/scipy/linalg/_decomp_lu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,154 @@
)

from ._utils import (
dpnp_lu,
dpnp_lu_factor,
dpnp_lu_solve,
)


def lu(
a, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
):
"""
Compute LU decomposition of a matrix with partial pivoting.

The decomposition satisfies::

A = P @ L @ U

where `P` is a permutation matrix, `L` is lower triangular with unit
diagonal elements, and `U` is upper triangular. If `permute_l` is set to
``True`` then `L` is returned already permuted and hence satisfying
``A = L @ U``.

For full documentation refer to :obj:`scipy.linalg.lu`.

Parameters
----------
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
Input array to decompose.
permute_l : bool, optional
Perform the multiplication ``P @ L`` (Default: do not permute).

Default: ``False``.
overwrite_a : {None, bool}, optional
Whether to overwrite data in `a` (may increase performance).

Default: ``False``.
check_finite : {None, bool}, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Default: ``True``.
p_indices : bool, optional
If ``True`` the permutation information is returned as row indices
instead of a permutation matrix.

Default: ``False``.

Returns
-------
**(If ``permute_l`` is ``False``)**

p : (..., M, M) dpnp.ndarray or (..., M) dpnp.ndarray
If `p_indices` is ``False`` (default), the permutation matrix.
The permutation matrix always has a real dtype (``float32`` or
``float64``) even when `a` is complex, since it only contains
0s and 1s.
If `p_indices` is ``True``, a 1-D (or batched) array of row
permutation indices such that ``A = L[p] @ U``.
l : (..., M, K) dpnp.ndarray
Lower triangular or trapezoidal matrix with unit diagonal.
``K = min(M, N)``.
u : (..., K, N) dpnp.ndarray
Upper triangular or trapezoidal matrix.

**(If ``permute_l`` is ``True``)**

pl : (..., M, K) dpnp.ndarray
Permuted ``L`` matrix: ``pl = P @ L``.
``K = min(M, N)``.
u : (..., K, N) dpnp.ndarray
Upper triangular or trapezoidal matrix.

Notes
-----
Permutation matrices are costly since they are nothing but row reorder of
``L`` and hence indices are strongly recommended to be used instead if the
permutation is required. The relation in the 2D case then becomes simply
``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
to avoid complicated indexing tricks.

In the 2D case, if one has the indices however, for some reason, the
permutation matrix is still needed then it can be constructed by
``dpnp.eye(M)[P, :]``.

Warning
-------
This function synchronizes in order to validate array elements
when ``check_finite=True``, and also synchronizes to compute the
permutation from LAPACK pivot indices.

See Also
--------
:obj:`dpnp.scipy.linalg.lu_factor` : LU factorize a matrix
(compact representation).
:obj:`dpnp.scipy.linalg.lu_solve` : Solve an equation system using
the LU factorization of a matrix.

Examples
--------
>>> import dpnp as np
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8],
... [7, 5, 6, 6], [5, 4, 4, 8]])
>>> p, l, u = np.scipy.linalg.lu(A)
>>> np.allclose(A, p @ l @ u)
array(True)

Retrieve the permutation as row indices with ``p_indices=True``:

>>> p, l, u = np.scipy.linalg.lu(A, p_indices=True)
>>> p
array([1, 3, 0, 2])
>>> np.allclose(A, l[p] @ u)
array(True)

Return the permuted ``L`` directly with ``permute_l=True``:

>>> pl, u = np.scipy.linalg.lu(A, permute_l=True)
>>> np.allclose(A, pl @ u)
array(True)

Non-square matrices are supported:

>>> B = np.array([[1, 2, 3], [4, 5, 6]])
>>> p, l, u = np.scipy.linalg.lu(B)
>>> np.allclose(B, p @ l @ u)
array(True)

Batched input:

>>> C = np.random.randn(3, 2, 4, 4)
>>> p, l, u = np.scipy.linalg.lu(C)
>>> np.allclose(C, p @ l @ u)
array(True)

"""

dpnp.check_supported_arrays_type(a)
assert_stacked_2d(a)

return dpnp_lu(
a,
overwrite_a=overwrite_a,
check_finite=check_finite,
p_indices=p_indices,
permute_l=permute_l,
)


def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute the pivoted LU decomposition of `a` matrix.
Expand Down
Loading