Skip to content

Commit 1272f99

Browse files
committed
[update] implementation for larger matrices to use Vec<f32> instead of a Vec of Vecs
1 parent f7498e2 commit 1272f99

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

crates/lambda-rs/src/math/matrix.rs

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -437,21 +437,22 @@ where
437437
return Ok(determinant_gaussian_stack::<4>(data));
438438
}
439439

440-
// Convert to a mutable dense matrix so we can perform in-place elimination.
441-
let mut working = vec![vec![0.0_f32; rows]; rows];
442-
for (i, row) in working.iter_mut().enumerate() {
443-
for (j, value) in row.iter_mut().enumerate() {
444-
*value = self.at(i, j);
440+
// Use a contiguous row-major buffer for larger matrices to keep the
441+
// fallback cache-friendlier than `Vec<Vec<f32>>`.
442+
let mut working = vec![0.0_f32; rows * rows];
443+
for i in 0..rows {
444+
for j in 0..rows {
445+
working[i * rows + j] = self.at(i, j);
445446
}
446447
}
447448

448449
let mut sign = 1.0_f32;
449450
for pivot in 0..rows {
450451
// Partial pivoting improves numerical stability and avoids division by 0.
451452
let mut pivot_row = pivot;
452-
let mut pivot_abs = working[pivot][pivot].abs();
453-
for (candidate, row) in working.iter().enumerate().skip(pivot + 1) {
454-
let candidate_abs = row[pivot].abs();
453+
let mut pivot_abs = working[pivot * rows + pivot].abs();
454+
for candidate in (pivot + 1)..rows {
455+
let candidate_abs = working[candidate * rows + pivot].abs();
455456
if candidate_abs > pivot_abs {
456457
pivot_abs = candidate_abs;
457458
pivot_row = candidate;
@@ -463,24 +464,28 @@ where
463464
}
464465

465466
if pivot_row != pivot {
466-
working.swap(pivot, pivot_row);
467+
for column in 0..rows {
468+
working.swap(pivot * rows + column, pivot_row * rows + column);
469+
}
467470
sign = -sign;
468471
}
469472

470-
let pivot_value = working[pivot][pivot];
471-
let pivot_tail: Vec<f32> = working[pivot][pivot..].to_vec();
472-
for row in working.iter_mut().skip(pivot + 1) {
473-
let factor = row[pivot] / pivot_value;
473+
let pivot_value = working[pivot * rows + pivot];
474+
for row in (pivot + 1)..rows {
475+
let factor = working[row * rows + pivot] / pivot_value;
474476
if factor == 0.0 {
475477
continue;
476478
}
477-
for (dst, src) in row[pivot..].iter_mut().zip(pivot_tail.iter()) {
478-
*dst -= factor * *src;
479+
for column in pivot..rows {
480+
let row_idx = row * rows + column;
481+
let pivot_idx = pivot * rows + column;
482+
working[row_idx] -= factor * working[pivot_idx];
479483
}
480484
}
481485
}
482486

483-
let diagonal_product = (0..rows).map(|i| working[i][i]).product::<f32>();
487+
let diagonal_product =
488+
(0..rows).map(|i| working[i * rows + i]).product::<f32>();
484489
return Ok(sign * diagonal_product);
485490
}
486491

0 commit comments

Comments
 (0)