Skip to content

Commit e6c8a35

Browse files
committed
feat: Add Splay Tree implementation
1 parent 456d644 commit e6c8a35

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# 文件名:data_structures/binary_tree/splay_tree.py
2+
from __future__ import annotations
3+
4+
class Node:
5+
"""Splay树节点类"""
6+
def __init__(self, key: int):
7+
self.key = key
8+
self.left: Node | None = None
9+
self.right: Node | None = None
10+
11+
class SplayTree:
12+
"""
13+
伸展树(Splay Tree)实现
14+
特性:每次访问节点后,将该节点旋转到根位置,优化局部性访问性能
15+
"""
16+
def __init__(self):
17+
self.root: Node | None = None
18+
19+
def _right_rotate(self, x: Node) -> Node:
20+
"""右旋操作(zig)"""
21+
y = x.left
22+
x.left = y.right
23+
y.right = x
24+
return y
25+
26+
def _left_rotate(self, x: Node) -> Node:
27+
"""左旋操作(zig)"""
28+
y = x.right
29+
x.right = y.left
30+
y.left = x
31+
return y
32+
33+
def _splay(self, root: Node | None, key: int) -> Node | None:
34+
"""
35+
核心伸展操作:将包含key的节点旋转到根
36+
包含zig、zig-zig、zig-zag三种模式
37+
"""
38+
if root is None or root.key == key:
39+
return root
40+
41+
# 目标节点在左子树
42+
if key < root.key:
43+
if root.left is None:
44+
return root
45+
# Zig-Zig模式:左-左
46+
if key < root.left.key:
47+
root.left.left = self._splay(root.left.left, key)
48+
root = self._right_rotate(root)
49+
# Zig-Zag模式:左-右
50+
elif key > root.left.key:
51+
root.left.right = self._splay(root.left.right, key)
52+
if root.left.right:
53+
root.left = self._left_rotate(root.left)
54+
return root.left if root.left is None else self._right_rotate(root)
55+
56+
# 目标节点在右子树
57+
else:
58+
if root.right is None:
59+
return root
60+
# Zig-Zig模式:右-右
61+
if key > root.right.key:
62+
root.right.right = self._splay(root.right.right, key)
63+
root = self._left_rotate(root)
64+
# Zig-Zag模式:右-左
65+
elif key < root.right.key:
66+
root.right.left = self._splay(root.right.left, key)
67+
if root.right.left:
68+
root.right = self._right_rotate(root.right)
69+
return root.right if root.right is None else self._left_rotate(root)
70+
71+
def search(self, key: int) -> bool:
72+
"""搜索指定key,存在返回True,不存在返回False"""
73+
self.root = self._splay(self.root, key)
74+
return self.root is not None and self.root.key == key
75+
76+
def insert(self, key: int) -> None:
77+
"""插入新节点"""
78+
if self.root is None:
79+
self.root = Node(key)
80+
return
81+
82+
self.root = self._splay(self.root, key)
83+
if self.root.key == key:
84+
return # 已存在,无需插入
85+
86+
new_node = Node(key)
87+
if key < self.root.key:
88+
new_node.right = self.root
89+
new_node.left = self.root.left
90+
self.root.left = None
91+
else:
92+
new_node.left = self.root
93+
new_node.right = self.root.right
94+
self.root.right = None
95+
self.root = new_node
96+
97+
def delete(self, key: int) -> None:
98+
"""删除指定key的节点"""
99+
if self.root is None:
100+
return
101+
102+
self.root = self._splay(self.root, key)
103+
if self.root.key != key:
104+
return # 节点不存在
105+
106+
# 左右子树合并
107+
if self.root.left is None:
108+
self.root = self.root.right
109+
else:
110+
temp = self.root.right
111+
self.root = self.root.left
112+
self.root = self._splay(self.root, key)
113+
self.root.right = temp
114+
115+
def find_min(self) -> int | None:
116+
"""查找树中最小值"""
117+
if self.root is None:
118+
return None
119+
current = self.root
120+
while current.left:
121+
current = current.left
122+
self.root = self._splay(self.root, current.key)
123+
return current.key
124+
125+
def find_max(self) -> int | None:
126+
"""查找树中最大值"""
127+
if self.root is None:
128+
return None
129+
current = self.root
130+
while current.right:
131+
current = current.right
132+
self.root = self._splay(self.root, current.key)
133+
return current.key
134+
135+
def inorder_traversal(self, root: Node | None, result: list[int]) -> None:
136+
"""中序遍历(有序输出)"""
137+
if root:
138+
self.inorder_traversal(root.left, result)
139+
result.append(root.key)
140+
self.inorder_traversal(root.right, result)
141+
142+
def get_size(self, root: Node | None) -> int:
143+
"""获取树的节点总数"""
144+
if root is None:
145+
return 0
146+
return 1 + self.get_size(root.left) + self.get_size(root.right)
147+
148+
def get_height(self, root: Node | None) -> int:
149+
"""获取树的高度"""
150+
if root is None:
151+
return -1
152+
return 1 + max(self.get_height(root.left), self.get_height(root.right))
153+
154+
def print_tree(self, root: Node | None, indent: str = "", last: bool = True) -> None:
155+
"""可视化打印树结构"""
156+
if root:
157+
print(indent + ("└── " if last else "├── ") + str(root.key))
158+
indent += " " if last else "│ "
159+
self.print_tree(root.left, indent, False)
160+
self.print_tree(root.right, indent, True)
161+
162+
# 测试用例
163+
if __name__ == "__main__":
164+
tree = SplayTree()
165+
# 插入测试
166+
for key in [10, 20, 30, 40, 50, 25]:
167+
tree.insert(key)
168+
print("中序遍历:", end=" ")
169+
traversal_result = []
170+
tree.inorder_traversal(tree.root, traversal_result)
171+
print(traversal_result)
172+
173+
# 搜索测试(访问后节点会被伸展到根)
174+
print("\n搜索25:", tree.search(25))
175+
print("搜索后根节点:", tree.root.key)
176+
177+
# 删除测试
178+
tree.delete(30)
179+
print("\n删除30后中序遍历:", end=" ")
180+
traversal_result = []
181+
tree.inorder_traversal(tree.root, traversal_result)
182+
print(traversal_result)
183+
184+
# 辅助方法测试
185+
print("\n树大小:", tree.get_size(tree.root))
186+
print("树高度:", tree.get_height(tree.root))
187+
print("最小值:", tree.find_min())
188+
print("最大值:", tree.find_max())
189+
190+
print("\n树结构:")
191+
tree.print_tree(tree.root)

0 commit comments

Comments
 (0)