diff --git a/dpnp/scipy/linalg/__init__.py b/dpnp/scipy/linalg/__init__.py index 3afc08a6fdb..81eadd890fa 100644 --- a/dpnp/scipy/linalg/__init__.py +++ b/dpnp/scipy/linalg/__init__.py @@ -35,9 +35,10 @@ """ -from ._decomp_lu import lu_factor, lu_solve +from ._decomp_lu import lu, lu_factor, lu_solve __all__ = [ + "lu", "lu_factor", "lu_solve", ] diff --git a/dpnp/scipy/linalg/_decomp_lu.py b/dpnp/scipy/linalg/_decomp_lu.py index 292d7fffe4b..c890bb8cc09 100644 --- a/dpnp/scipy/linalg/_decomp_lu.py +++ b/dpnp/scipy/linalg/_decomp_lu.py @@ -46,11 +46,154 @@ ) from ._utils import ( + dpnp_lu, dpnp_lu_factor, dpnp_lu_solve, ) +def lu( + a, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False +): + """ + Compute LU decomposition of a matrix with partial pivoting. + + The decomposition satisfies:: + + A = P @ L @ U + + where `P` is a permutation matrix, `L` is lower triangular with unit + diagonal elements, and `U` is upper triangular. If `permute_l` is set to + ``True`` then `L` is returned already permuted and hence satisfying + ``A = L @ U``. + + For full documentation refer to :obj:`scipy.linalg.lu`. + + Parameters + ---------- + a : (..., M, N) {dpnp.ndarray, usm_ndarray} + Input array to decompose. + permute_l : bool, optional + Perform the multiplication ``P @ L`` (Default: do not permute). + + Default: ``False``. + overwrite_a : {None, bool}, optional + Whether to overwrite data in `a` (may increase performance). + + Default: ``False``. + check_finite : {None, bool}, optional + Whether to check that the input matrix contains only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Default: ``True``. + p_indices : bool, optional + If ``True`` the permutation information is returned as row indices + instead of a permutation matrix. + + Default: ``False``. + + Returns + ------- + **(If ``permute_l`` is ``False``)** + + p : (..., M, M) dpnp.ndarray or (..., M) dpnp.ndarray + If `p_indices` is ``False`` (default), the permutation matrix. + The permutation matrix always has a real dtype (``float32`` or + ``float64``) even when `a` is complex, since it only contains + 0s and 1s. + If `p_indices` is ``True``, a 1-D (or batched) array of row + permutation indices such that ``A = L[p] @ U``. + l : (..., M, K) dpnp.ndarray + Lower triangular or trapezoidal matrix with unit diagonal. + ``K = min(M, N)``. + u : (..., K, N) dpnp.ndarray + Upper triangular or trapezoidal matrix. + + **(If ``permute_l`` is ``True``)** + + pl : (..., M, K) dpnp.ndarray + Permuted ``L`` matrix: ``pl = P @ L``. + ``K = min(M, N)``. + u : (..., K, N) dpnp.ndarray + Upper triangular or trapezoidal matrix. + + Notes + ----- + Permutation matrices are costly since they are nothing but row reorder of + ``L`` and hence indices are strongly recommended to be used instead if the + permutation is required. The relation in the 2D case then becomes simply + ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l` + to avoid complicated indexing tricks. + + In the 2D case, if one has the indices however, for some reason, the + permutation matrix is still needed then it can be constructed by + ``dpnp.eye(M)[P, :]``. + + Warning + ------- + This function synchronizes in order to validate array elements + when ``check_finite=True``, and also synchronizes to compute the + permutation from LAPACK pivot indices. + + See Also + -------- + :obj:`dpnp.scipy.linalg.lu_factor` : LU factorize a matrix + (compact representation). + :obj:`dpnp.scipy.linalg.lu_solve` : Solve an equation system using + the LU factorization of a matrix. + + Examples + -------- + >>> import dpnp as np + >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], + ... [7, 5, 6, 6], [5, 4, 4, 8]]) + >>> p, l, u = np.scipy.linalg.lu(A) + >>> np.allclose(A, p @ l @ u) + array(True) + + Retrieve the permutation as row indices with ``p_indices=True``: + + >>> p, l, u = np.scipy.linalg.lu(A, p_indices=True) + >>> p + array([1, 3, 0, 2]) + >>> np.allclose(A, l[p] @ u) + array(True) + + Return the permuted ``L`` directly with ``permute_l=True``: + + >>> pl, u = np.scipy.linalg.lu(A, permute_l=True) + >>> np.allclose(A, pl @ u) + array(True) + + Non-square matrices are supported: + + >>> B = np.array([[1, 2, 3], [4, 5, 6]]) + >>> p, l, u = np.scipy.linalg.lu(B) + >>> np.allclose(B, p @ l @ u) + array(True) + + Batched input: + + >>> C = np.random.randn(3, 2, 4, 4) + >>> p, l, u = np.scipy.linalg.lu(C) + >>> np.allclose(C, p @ l @ u) + array(True) + + """ + + dpnp.check_supported_arrays_type(a) + assert_stacked_2d(a) + + return dpnp_lu( + a, + overwrite_a=overwrite_a, + check_finite=check_finite, + p_indices=p_indices, + permute_l=permute_l, + ) + + def lu_factor(a, overwrite_a=False, check_finite=True): """ Compute the pivoted LU decomposition of `a` matrix. diff --git a/dpnp/scipy/linalg/_utils.py b/dpnp/scipy/linalg/_utils.py index f00db6fdfb9..1a62e8da721 100644 --- a/dpnp/scipy/linalg/_utils.py +++ b/dpnp/scipy/linalg/_utils.py @@ -83,6 +83,66 @@ def _align_lu_solve_broadcast(lu, b): return lu, b +def _apply_permutation_to_rows(mat, perm_indices): + """ + Apply a permutation to the rows (axis=-2) of a matrix. + + Returns ``out`` such that + ``out[..., i, :] = mat[..., perm_indices[..., i], :]``. + + For 2-D inputs this is equivalent to ``mat[perm_indices]`` (a single + device gather). For batched inputs :func:`dpnp.take_along_axis` is + used so the operation stays entirely on the device. + + Parameters + ---------- + mat : dpnp.ndarray, shape (..., M, N) + Matrix whose rows are to be permuted. + perm_indices : dpnp.ndarray, shape (..., M) + Permutation indices (dtype int64). + + Returns + ------- + out : dpnp.ndarray, shape (..., M, N) + Row-permuted matrix. + """ + + if perm_indices.ndim == 1: + # 2-D case: simple fancy indexing, single kernel launch. + return mat[perm_indices] + + # Batched case: ensure *mat* has the same batch dimensions as + # *perm_indices*. This is needed, for example, when permuting + # a shared identity matrix across a batch. + target_shape = perm_indices.shape[:-1] + mat.shape[-2:] + if mat.shape != target_shape: + mat = dpnp.broadcast_to(mat, target_shape) + + # Expand (..., M) → (..., M, 1), then broadcast to the full shape + # of *mat* so take_along_axis can gather along axis -2. + idx = dpnp.expand_dims(perm_indices, axis=-1) + idx = dpnp.broadcast_to(idx, target_shape).copy() + return dpnp.take_along_axis(mat, idx, axis=-2) + + +def _get_real_dtype(res_type): + """ + Return the real floating-point counterpart of *res_type*. + + SciPy uses the real dtype for the permutation matrix ``P`` even when + the input is complex (``P`` only contains 0s and 1s). + + ``float32`` and ``complex64`` → ``float32``; + ``float64`` and ``complex128`` → ``float64``. + """ + + if dpnp.issubdtype(res_type, dpnp.complexfloating): + return ( + dpnp.float32 if dpnp.dtype(res_type).itemsize <= 8 else dpnp.float64 + ) + return res_type + + def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals """SciPy-compatible LU factorization for batched inputs.""" @@ -338,6 +398,71 @@ def _map_trans_to_mkl(trans): raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") +def _pivots_to_permutation(piv, m): + """ + Convert 0-based LAPACK pivot indices (sequential row swaps) + to a permutation array. + + The returned permutation ``perm`` satisfies ``A[perm] = L @ U`` + (i.e. the forward row-permutation produced by LAPACK). + + The computation is performed entirely on the device. A host-side + Python loop of ``K = min(M, N)`` iterations drives the sequential + swap logic, but each iteration only launches device kernels + (:func:`dpnp.take_along_axis` for gather, + :func:`dpnp.put_along_axis` for scatter); **no data is transferred + between host and device**. + + .. note:: + + A future custom SYCL kernel could fuse all ``K`` swap steps + into a single launch to eliminate per-step kernel overhead. + + Parameters + ---------- + piv : dpnp.ndarray, shape (..., K) + 0-based pivot indices as returned by :obj:`dpnp_lu_factor`. + m : int + Number of rows of the original matrix. + + Returns + ------- + perm : dpnp.ndarray, shape (..., M), dtype int64 + Permutation indices. + """ + + batch_shape = piv.shape[:-1] + k = piv.shape[-1] + + # Initialise the identity permutation on the device. + perm = dpnp.broadcast_to( + dpnp.arange( + m, + dtype=dpnp.int64, + usm_type=piv.usm_type, + sycl_queue=piv.sycl_queue, + ), + (*batch_shape, m), + ).copy() + + # Apply sequential row swaps entirely on the device. + # Each iteration launches a small number of device kernels (gather + + # slice-assign + scatter) but never transfers data to the host. + for i in range(k): + # Pivot target for step *i*: shape (..., 1) + j = piv[..., i : i + 1] + + # Gather the two values to be swapped. + val_i = perm[..., i : i + 1].copy() # slice (free) + val_j = dpnp.take_along_axis(perm, j, axis=-1) # gather + + # Perform the swap. + perm[..., i : i + 1] = val_j # slice assign + dpnp.put_along_axis(perm, j, val_i, axis=-1) # scatter + + return perm + + def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): """ dpnp_lu_factor(a, overwrite_a=False, check_finite=True) @@ -432,6 +557,167 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): return (a_h, ipiv_h) +def dpnp_lu( + a, + overwrite_a=False, + check_finite=True, + p_indices=False, + permute_l=False, +): + """ + dpnp_lu(a, overwrite_a=False, check_finite=True, p_indices=False, + permute_l=False) + + Compute pivoted LU decomposition and return separate P, L, U matrices + (SciPy-compatible behavior). + + This function mimics the behavior of `scipy.linalg.lu` including + support for `permute_l`, `p_indices`, `overwrite_a`, and `check_finite`. + + """ + + a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type + m, n = a.shape[-2:] + k = min(m, n) + batch_shape = a.shape[:-2] + + res_type = _common_type(a) + + # The permutation matrix P uses a real dtype (SciPy convention): + # P only contains 0s and 1s, so complex storage would be wasteful. + real_type = _get_real_dtype(res_type) + + # ---- Fast path: scalar (1x1) matrices ---- + # For 1x1 input, P = I, L = I, U = A. This avoids invoking LAPACK + # entirely (matches SciPy's scalar fast path). + if m == 1 and n == 1: + if check_finite: + if not dpnp.isfinite(a).all(): + raise ValueError("array must not contain infs or NaNs") + + _L = dpnp.ones( + a.shape, + dtype=res_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + _U = dpnp.array(a, dtype=res_type) + + if permute_l: + return _L.copy(), _U + + if p_indices: + _p = dpnp.zeros( + (*batch_shape, 1), + dtype=dpnp.int64, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return _p, _L, _U + + _P = dpnp.ones( + a.shape, + dtype=real_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return _P, _L, _U + + # ---- Fast path: empty arrays ---- + if a.size == 0: + _L = dpnp.empty( + (*batch_shape, m, k), + dtype=res_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + _U = dpnp.empty( + (*batch_shape, k, n), + dtype=res_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + + if permute_l: + return _L, _U + + if p_indices: + _p = dpnp.empty( + (*batch_shape, m), + dtype=dpnp.int64, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return _p, _L, _U + + _P = dpnp.empty( + (*batch_shape, m, m), + dtype=real_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return _P, _L, _U + + # ---- General case: LAPACK factorization ---- + lu_compact, piv = dpnp_lu_factor( + a, overwrite_a=overwrite_a, check_finite=check_finite + ) + + # ---- Extract L: lower-triangular with unit diagonal ---- + # L has shape (..., M, K). + _L = dpnp.tril(lu_compact[..., :, :k], k=-1) + _L += dpnp.eye( + m, + k, + dtype=lu_compact.dtype, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + + # ---- Extract U: upper-triangular ---- + # U has shape (..., K, N). + _U = dpnp.triu(lu_compact[..., :k, :]) + + # ---- Convert pivot indices → row permutation ---- + # ``perm`` (forward): A[perm] = L @ U. + # This is the only step that requires a host transfer because the + # sequential swap semantics of LAPACK pivots cannot be parallelised. + # Only the small pivot array (min(M, N) elements per slice) is + # transferred; all subsequent work stays on the device. + perm = _pivots_to_permutation(piv, m) + + # ``inv_perm`` (inverse): A = L[inv_perm] @ U. + # This is SciPy's ``p_indices`` convention. + # ``dpnp.argsort`` is an efficient on-device O(M log M) operation + # that avoids a second host round-trip. + inv_perm = dpnp.argsort(perm, axis=-1).astype(dpnp.int64) + + if permute_l: + # Return (PL, U) where PL = P @ L = L[inv_perm]. + # A = PL @ U directly. + _PL = _apply_permutation_to_rows(_L, inv_perm) + return _PL, _U + + if p_indices: + # SciPy convention: A = L[inv_perm] @ U. + return inv_perm, _L, _U + + # ---- Build full permutation matrix P = I[inv_perm] ---- + # P has shape (..., M, M) with real dtype (SciPy convention). + # The gather from an identity matrix is efficient on device: + # each output row selects one row of the identity (one hot encoding). + _I = dpnp.eye( + m, + dtype=real_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + _P = _apply_permutation_to_rows(_I, inv_perm) + + return _P, _L, _U + + def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): """ dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 31d99d71ce4..39ddff42697 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2605,6 +2605,463 @@ def test_invalid_shapes(self, a_shape, b_shape): dpnp.scipy.linalg.lu_solve((lu, piv), b, check_finite=False) +class TestLu: + @staticmethod + def _make_nonsingular_np(shape, dtype, order): + A = generate_random_numpy_array(shape, dtype, order) + m, n = shape + k = min(m, n) + for i in range(k): + off = numpy.sum(numpy.abs(A[i, :n])) - numpy.abs(A[i, i]) + A[i, i] = A.dtype.type(off + 1.0) + return A + + @pytest.mark.parametrize( + "shape", + [(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_default(self, shape, order, dtype): + a_np = self._make_nonsingular_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + P, L, U = dpnp.scipy.linalg.lu(a_dp) + + m, n = shape + k = min(m, n) + assert P.shape == (m, m) + assert L.shape == (m, k) + assert U.shape == (k, n) + + A_cast = a_dp.astype(L.dtype, copy=False) + A_rec = P @ L @ U + assert dpnp.allclose(A_rec, A_cast, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", + [(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_permute_l(self, shape, order, dtype): + a_np = self._make_nonsingular_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + PL, U = dpnp.scipy.linalg.lu(a_dp, permute_l=True) + + m, n = shape + k = min(m, n) + assert PL.shape == (m, k) + assert U.shape == (k, n) + + A_cast = a_dp.astype(PL.dtype, copy=False) + A_rec = PL @ U + assert dpnp.allclose(A_rec, A_cast, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", + [(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_p_indices(self, shape, order, dtype): + a_np = self._make_nonsingular_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + p, L, U = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + + m, n = shape + k = min(m, n) + assert p.shape == (m,) + assert L.shape == (m, k) + assert U.shape == (k, n) + assert dpnp.issubdtype(p.dtype, dpnp.integer) + + p_np = dpnp.asnumpy(p) + L_np = dpnp.asnumpy(L) + U_np = dpnp.asnumpy(U) + A_rec = L_np[p_np] @ U_np + A_cast = a_dp.astype(L.dtype, copy=False) + assert_allclose(A_rec, dpnp.asnumpy(A_cast), rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "in_dtype, expected_p_dtype", + [ + (dpnp.float32, dpnp.float32), + (dpnp.float64, dpnp.float64), + (dpnp.complex64, dpnp.float32), + (dpnp.complex128, dpnp.float64), + ], + ) + def test_p_matrix_dtype(self, in_dtype, expected_p_dtype): + if in_dtype in (dpnp.float64, dpnp.complex128): + if not has_support_aspect64(): + pytest.skip("fp64 not supported on this device") + + a_np = self._make_nonsingular_np((4, 4), in_dtype, "F") + a_dp = dpnp.array(a_np, order="F") + P, L, U = dpnp.scipy.linalg.lu(a_dp) + + assert P.dtype == expected_p_dtype + assert dpnp.issubdtype(P.dtype, dpnp.floating) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_p_indices_dtype(self, dtype): + a_np = self._make_nonsingular_np((4, 4), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + p, _, _ = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + assert dpnp.issubdtype(p.dtype, dpnp.integer) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_l_structure(self, dtype): + a_np = self._make_nonsingular_np((5, 5), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + _, L, _ = dpnp.scipy.linalg.lu(a_dp) + L_np = dpnp.asnumpy(L) + + # unit diagonal + diag_abs = numpy.abs(numpy.diag(L_np)) + assert_allclose(diag_abs, numpy.ones(5, dtype=diag_abs.dtype)) + # lower triangular + assert_allclose(numpy.triu(L_np, 1), numpy.zeros_like(L_np)) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_u_upper_triangular(self, dtype): + a_np = self._make_nonsingular_np((5, 5), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + _, _, U = dpnp.scipy.linalg.lu(a_dp) + U_np = dpnp.asnumpy(U) + assert_allclose(numpy.tril(U_np, -1), numpy.zeros_like(U_np)) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_p_is_permutation(self, dtype): + a_np = self._make_nonsingular_np((5, 5), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + P, _, _ = dpnp.scipy.linalg.lu(a_dp) + P_np = dpnp.asnumpy(P) + + assert_allclose(P_np.sum(axis=0), numpy.ones(5, dtype=P_np.dtype)) + assert_allclose(P_np.sum(axis=1), numpy.ones(5, dtype=P_np.dtype)) + assert_allclose( + P_np.T @ P_np, numpy.eye(5, dtype=P_np.dtype), atol=1e-15 + ) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_modes_consistency(self, dtype): + a_np = self._make_nonsingular_np((5, 5), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + + P, L, U = dpnp.scipy.linalg.lu(a_dp) + PL, U2 = dpnp.scipy.linalg.lu(a_dp, permute_l=True) + p, L3, U3 = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + + A_cast = a_dp.astype(L.dtype, copy=False) + A1 = P @ L @ U + A2 = PL @ U2 + p_np = dpnp.asnumpy(p) + A3_np = dpnp.asnumpy(L3)[p_np] @ dpnp.asnumpy(U3) + + assert dpnp.allclose(A1, A_cast, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(A2, A_cast, rtol=1e-6, atol=1e-6) + assert_allclose(A3_np, dpnp.asnumpy(A_cast), rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_p_times_l_equals_pl(self, dtype): + a_np = self._make_nonsingular_np((5, 5), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + P, L, _ = dpnp.scipy.linalg.lu(a_dp) + PL, _ = dpnp.scipy.linalg.lu(a_dp, permute_l=True) + assert dpnp.allclose(P @ L, PL, rtol=1e-12, atol=1e-12) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_p_indices_to_matrix(self, dtype): + a_np = self._make_nonsingular_np((5, 5), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + P, _, _ = dpnp.scipy.linalg.lu(a_dp) + p, _, _ = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + P_from_idx = dpnp.eye(5, dtype=P.dtype)[p] + assert dpnp.allclose(P_from_idx, P, rtol=1e-15, atol=1e-15) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_overwrite_a_false(self, dtype): + a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F") + a_dp_orig = a_dp.copy() + dpnp.scipy.linalg.lu(a_dp, overwrite_a=False) + assert dpnp.allclose(a_dp, a_dp_orig) + + @pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)]) + def test_empty_inputs(self, shape): + a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + P, L, U = dpnp.scipy.linalg.lu(a_dp) + m, n = shape + k = min(m, n) + assert P.shape == (m, m) + assert L.shape == (m, k) + assert U.shape == (k, n) + + @pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)]) + def test_empty_permute_l(self, shape): + a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + PL, U = dpnp.scipy.linalg.lu(a_dp, permute_l=True) + m, n = shape + k = min(m, n) + assert PL.shape == (m, k) + assert U.shape == (k, n) + + @pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)]) + def test_empty_p_indices(self, shape): + a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + p, L, U = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + m, n = shape + k = min(m, n) + assert p.shape == (m,) + assert L.shape == (m, k) + assert U.shape == (k, n) + + @pytest.mark.parametrize( + "sl", + [ + (slice(None, None, 2), slice(None, None, 2)), + (slice(None, None, -1), slice(None, None, -1)), + ], + ) + def test_strided(self, sl): + base = self._make_nonsingular_np((7, 7), dpnp.default_float_type(), "F") + a_np = base[sl] + a_dp = dpnp.array(a_np) + + P, L, U = dpnp.scipy.linalg.lu(a_dp) + A_rec = P @ L @ U + assert dpnp.allclose(A_rec, a_dp, rtol=1e-6, atol=1e-6) + + def test_singular_matrix(self): + a_np = numpy.array([[1.0, 2.0], [2.0, 4.0]]) + a_dp = dpnp.array(a_np) + P, L, U = dpnp.scipy.linalg.lu(a_dp) + A_rec = dpnp.asnumpy(P @ L @ U) + assert_allclose(A_rec, a_np, atol=1e-12) + + def test_identity_matrix(self): + n = 4 + I_dp = dpnp.eye(n, dtype=dpnp.default_float_type()) + P, L, U = dpnp.scipy.linalg.lu(I_dp) + I_np = numpy.eye(n) + assert_allclose(dpnp.asnumpy(P), I_np, atol=1e-15) + assert_allclose(dpnp.asnumpy(L), I_np, atol=1e-15) + assert_allclose(dpnp.asnumpy(U), I_np, atol=1e-15) + + def test_1d_input_raises(self): + a_dp = dpnp.array([1.0, 2.0, 3.0]) + with pytest.raises(ValueError): + dpnp.scipy.linalg.lu(a_dp) + + @pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan]) + def test_check_finite_raises(self, bad): + a_dp = dpnp.array([[1.0, 2.0], [3.0, bad]], order="F") + assert_raises(ValueError, dpnp.scipy.linalg.lu, a_dp, check_finite=True) + + def test_check_finite_disabled(self): + a_dp = dpnp.array([[1.0, numpy.nan], [3.0, 4.0]]) + result = dpnp.scipy.linalg.lu(a_dp, check_finite=False) + assert len(result) == 3 + + +class TestLuBatched: + @staticmethod + def _make_nonsingular_nd_np(shape, dtype, order): + A = generate_random_numpy_array(shape, dtype, order) + m, n = shape[-2], shape[-1] + k = min(m, n) + A3 = A.reshape((-1, m, n)) + for B in A3: + for i in range(k): + off = numpy.sum(numpy.abs(B[i, :n])) - numpy.abs(B[i, i]) + B[i, i] = A.dtype.type(off + 1.0) + A = A3.reshape(shape) + A = numpy.array(A, order=order) + return A + + @staticmethod + def _reconstruct_p_indices(p, L, U): + """Reconstruct A from (p, L, U) for batched p_indices mode.""" + idx = dpnp.expand_dims(p, axis=-1) + idx = dpnp.broadcast_to(idx, L.shape).copy() + PL = dpnp.take_along_axis(L, idx, axis=-2) + return PL @ U + + @pytest.mark.parametrize( + "shape", + [(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)], + ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_default_batched(self, shape, order, dtype): + a_np = self._make_nonsingular_nd_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + P, L, U = dpnp.scipy.linalg.lu(a_dp) + + m, n = shape[-2], shape[-1] + k = min(m, n) + assert P.shape == (*shape[:-2], m, m) + assert L.shape == (*shape[:-2], m, k) + assert U.shape == (*shape[:-2], k, n) + + A_cast = a_dp.astype(L.dtype, copy=False) + A_rec = P @ L @ U + assert dpnp.allclose(A_rec, A_cast, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", + [(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)], + ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_permute_l_batched(self, shape, order, dtype): + a_np = self._make_nonsingular_nd_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + PL, U = dpnp.scipy.linalg.lu(a_dp, permute_l=True) + + m, n = shape[-2], shape[-1] + k = min(m, n) + assert PL.shape == (*shape[:-2], m, k) + assert U.shape == (*shape[:-2], k, n) + + A_cast = a_dp.astype(PL.dtype, copy=False) + A_rec = PL @ U + assert dpnp.allclose(A_rec, A_cast, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", + [(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)], + ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_p_indices_batched(self, shape, order, dtype): + a_np = self._make_nonsingular_nd_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + p, L, U = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + + m, n = shape[-2], shape[-1] + k = min(m, n) + assert p.shape == (*shape[:-2], m) + assert L.shape == (*shape[:-2], m, k) + assert U.shape == (*shape[:-2], k, n) + assert dpnp.issubdtype(p.dtype, dpnp.integer) + + A_cast = a_dp.astype(L.dtype, copy=False) + A_rec = self._reconstruct_p_indices(p, L, U) + assert dpnp.allclose(A_rec, A_cast, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_overwrite_a(self, dtype, order): + a_np = self._make_nonsingular_nd_np((3, 2, 2), dtype, order) + a_dp = dpnp.array(a_np, order=order) + a_dp_orig = a_dp.copy() + + dpnp.scipy.linalg.lu(a_dp, overwrite_a=False) + assert dpnp.allclose(a_dp, a_dp_orig) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_modes_consistency_batched(self, dtype): + a_np = self._make_nonsingular_nd_np((3, 4, 4), dtype, "F") + a_dp = dpnp.array(a_np, order="F") + A_cast = a_dp.astype( + ( + dpnp.complex128 + if dpnp.issubdtype(dtype, dpnp.complexfloating) + else dpnp.float64 + ), + copy=False, + ) + + P, L, U = dpnp.scipy.linalg.lu(a_dp) + PL, U2 = dpnp.scipy.linalg.lu(a_dp, permute_l=True) + p, L3, U3 = dpnp.scipy.linalg.lu(a_dp, p_indices=True) + + A1 = P @ L @ U + A2 = PL @ U2 + A3 = self._reconstruct_p_indices(p, L3, U3) + + A_cast2 = a_dp.astype(L.dtype, copy=False) + assert dpnp.allclose(A1, A_cast2, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(A2, A_cast2, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(A3, A_cast2, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)] + ) + def test_empty_inputs(self, shape): + a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + + P, L, U = dpnp.scipy.linalg.lu(a) + m, n = shape[-2:] + k = min(m, n) + assert P.shape == (*shape[:-2], m, m) + assert L.shape == (*shape[:-2], m, k) + assert U.shape == (*shape[:-2], k, n) + + @pytest.mark.parametrize( + "shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)] + ) + def test_empty_permute_l(self, shape): + a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + + PL, U = dpnp.scipy.linalg.lu(a, permute_l=True) + m, n = shape[-2:] + k = min(m, n) + assert PL.shape == (*shape[:-2], m, k) + assert U.shape == (*shape[:-2], k, n) + + @pytest.mark.parametrize( + "shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)] + ) + def test_empty_p_indices(self, shape): + a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + + p, L, U = dpnp.scipy.linalg.lu(a, p_indices=True) + m, n = shape[-2:] + k = min(m, n) + assert p.shape == (*shape[:-2], m) + assert L.shape == (*shape[:-2], m, k) + assert U.shape == (*shape[:-2], k, n) + + def test_strided(self): + a_np = self._make_nonsingular_nd_np( + (5, 3, 3), dpnp.default_float_type(), "F" + ) + a_dp = dpnp.array(a_np, order="F") + a_stride = a_dp[::2] + + P, L, U = dpnp.scipy.linalg.lu(a_stride) + for i in range(a_stride.shape[0]): + A_rec = dpnp.asnumpy(P[i] @ L[i] @ U[i]) + A_orig = dpnp.asnumpy(a_stride[i].astype(L.dtype, copy=False)) + assert_allclose(A_rec, A_orig, rtol=1e-6, atol=1e-6) + + def test_singular_matrix(self): + a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type()) + a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]]) + a[1] = dpnp.eye(2) + a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]]) + + P, L, U = dpnp.scipy.linalg.lu(a) + A_rec = P @ L @ U + assert dpnp.allclose(A_rec, a, rtol=1e-6, atol=1e-6) + + def test_check_finite_raises(self): + a = dpnp.ones((2, 3, 3), dtype=dpnp.default_float_type(), order="F") + a[1, 0, 0] = dpnp.nan + assert_raises(ValueError, dpnp.scipy.linalg.lu, a, check_finite=True) + + class TestMatrixPower: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index d1853579036..b0155658a4d 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1666,6 +1666,14 @@ def test_lu_factor(self, data, device): param_queue = param.sycl_queue assert_sycl_queue_equal(param_queue, a.sycl_queue) + def test_lu(self, data, device): + a = dpnp.array(data, device=device) + result = dpnp.scipy.linalg.lu(a) + + for param in result: + param_queue = param.sycl_queue + assert_sycl_queue_equal(param_queue, a.sycl_queue) + @pytest.mark.parametrize( "a_data, b_data", [ diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index 4fc0f2b958f..59a6b05f688 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1531,6 +1531,14 @@ def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other): "data", [[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]], ) + def test_lu(self, data, usm_type): + a = dpnp.array(data, usm_type=usm_type) + result = dpnp.scipy.linalg.lu(a) + + assert a.usm_type == usm_type + for param in result: + assert param.usm_type == a.usm_type + def test_lu_factor(self, data, usm_type): a = dpnp.array(data, usm_type=usm_type) result = dpnp.scipy.linalg.lu_factor(a)