|
26 | 26 |
|
27 | 27 | import unittest |
28 | 28 | import numpy as np |
| 29 | +from itertools import product |
29 | 30 | import modmesh as mm |
30 | 31 |
|
31 | 32 |
|
@@ -75,16 +76,47 @@ def test_identity_matrix(self): |
75 | 76 | # 3x3 matrix |
76 | 77 | a_data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], |
77 | 78 | [7.0, 8.0, 9.0]], dtype=self.dtype) |
78 | | - identity_data = np.eye(3, dtype=self.dtype) |
79 | 79 |
|
80 | 80 | a = self.SimpleArray(array=a_data) |
81 | | - identity = self.SimpleArray(array=identity_data) |
| 81 | + identity = self.SimpleArray.eye(3) |
82 | 82 |
|
83 | 83 | result = a.matmul(identity) |
84 | 84 |
|
85 | 85 | self.assertEqual(list(result.shape), [3, 3]) |
86 | 86 | np.testing.assert_array_almost_equal(result.ndarray, a_data) |
87 | 87 |
|
| 88 | + def test_eye_method(self): |
| 89 | + """Test eye method creates correct identity matrices""" |
| 90 | + # Test cases: different sizes |
| 91 | + test_sizes = [1, 2, 3, 4, 5, 10] |
| 92 | + |
| 93 | + for size in test_sizes: |
| 94 | + with self.subTest(size=size): |
| 95 | + # Create identity matrix using our eye method |
| 96 | + identity = self.SimpleArray.eye(size) |
| 97 | + |
| 98 | + # Create expected identity matrix using NumPy |
| 99 | + expected = np.eye(size, dtype=self.dtype) |
| 100 | + |
| 101 | + # Check shape |
| 102 | + self.assertEqual(list(identity.shape), [size, size]) |
| 103 | + |
| 104 | + # Check array values |
| 105 | + np.testing.assert_array_almost_equal(identity.ndarray, |
| 106 | + expected) |
| 107 | + |
| 108 | + # Verify diagonal and off-diagonal elements explicitly |
| 109 | + # using product |
| 110 | + for i, j in product(range(size), repeat=2): |
| 111 | + if i == j: |
| 112 | + self.assertEqual(identity[i, j], 1.0, |
| 113 | + f"Diagonal element ({i},{j}) " |
| 114 | + f"should be 1.0") |
| 115 | + else: |
| 116 | + self.assertEqual(identity[i, j], 0.0, |
| 117 | + f"Off-diagonal element ({i},{j}) " |
| 118 | + f"should be 0.0") |
| 119 | + |
88 | 120 | def test_zero_matrix(self): |
89 | 121 | """Test multiplication with zero matrix""" |
90 | 122 | a_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) |
|
0 commit comments