Skip to content

Commit ba91771

Browse files
committed
Add gram-schmidt implementation
1 parent 678dedb commit ba91771

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

linear_algebra/gram_schmidt.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import math
2+
3+
4+
def gram_schmidt(input_vectors: list[list[float]]) -> list[list[float]]:
5+
"""
6+
Implements the Gram-Schmidt process to orthonormalize a set of vectors.
7+
Reference: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
8+
9+
Case 1: Standard 2D Orthonormalization
10+
>>> v1 = [[3.0, 1.0], [2.0, 2.0]]
11+
>>> result1 = gram_schmidt(v1)
12+
>>> [[round(x, 3) for x in v] for v in result1]
13+
[[0.949, 0.316], [-0.316, 0.949]]
14+
15+
Case 2: 3D Vectors (The example from your error log)
16+
>>> v2 = [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
17+
>>> result2 = gram_schmidt(v2)
18+
>>> [[round(x, 3) for x in v] for v in result2]
19+
[[0.707, 0.707, 0.0], [0.408, -0.408, 0.816], [-0.577, 0.577, 0.577]]
20+
21+
Case 3: Vectors that are already orthonormal (should remain unchanged)
22+
>>> v3 = [[1.0, 0.0], [0.0, 1.0]]
23+
>>> gram_schmidt(v3)
24+
[[1.0, 0.0], [0.0, 1.0]]
25+
"""
26+
27+
orthonormal_basis: list[list[float]] = []
28+
29+
for v in input_vectors:
30+
# Create a copy of the current vector to work on (equivalent to v_dash)
31+
v_orthogonal = list(v)
32+
33+
for u in orthonormal_basis:
34+
# Manual Dot Product: (v_dash * u)
35+
dot_product = sum(vi * ui for vi, ui in zip(v_orthogonal, u))
36+
37+
# Manual Vector Subtraction & Scalar Mult: v_dash = v_dash - u * dot_product
38+
v_orthogonal = [vi - (dot_product * ui) for vi, ui in zip(v_orthogonal, u)]
39+
40+
# Manual Norm Calculation: (v_dash * v_dash) ** 0.5
41+
norm = math.sqrt(sum(xi**2 for xi in v_orthogonal))
42+
43+
if norm < 1e-15:
44+
raise ValueError("The vectors are not linearly independent.")
45+
46+
# Manual Scalar Multiplication: u_new = v_dash * (1/norm)
47+
u_new = [xi / norm for xi in v_orthogonal]
48+
orthonormal_basis.append(u_new)
49+
50+
return orthonormal_basis
51+
52+
53+
if __name__ == "__main__":
54+
import doctest
55+
56+
doctest.testmod()

0 commit comments

Comments
 (0)