1+ # 文件名:data_structures/binary_tree/splay_tree.py
2+ from __future__ import annotations
3+
4+ class Node :
5+ """Splay树节点类"""
6+ def __init__ (self , key : int ):
7+ self .key = key
8+ self .left : Node | None = None
9+ self .right : Node | None = None
10+
11+ class SplayTree :
12+ """
13+ 伸展树(Splay Tree)实现
14+ 特性:每次访问节点后,将该节点旋转到根位置,优化局部性访问性能
15+ """
16+ def __init__ (self ):
17+ self .root : Node | None = None
18+
19+ def _right_rotate (self , x : Node ) -> Node :
20+ """右旋操作(zig)"""
21+ y = x .left
22+ x .left = y .right
23+ y .right = x
24+ return y
25+
26+ def _left_rotate (self , x : Node ) -> Node :
27+ """左旋操作(zig)"""
28+ y = x .right
29+ x .right = y .left
30+ y .left = x
31+ return y
32+
33+ def _splay (self , root : Node | None , key : int ) -> Node | None :
34+ """
35+ 核心伸展操作:将包含key的节点旋转到根
36+ 包含zig、zig-zig、zig-zag三种模式
37+ """
38+ if root is None or root .key == key :
39+ return root
40+
41+ # 目标节点在左子树
42+ if key < root .key :
43+ if root .left is None :
44+ return root
45+ # Zig-Zig模式:左-左
46+ if key < root .left .key :
47+ root .left .left = self ._splay (root .left .left , key )
48+ root = self ._right_rotate (root )
49+ # Zig-Zag模式:左-右
50+ elif key > root .left .key :
51+ root .left .right = self ._splay (root .left .right , key )
52+ if root .left .right :
53+ root .left = self ._left_rotate (root .left )
54+ return root .left if root .left is None else self ._right_rotate (root )
55+
56+ # 目标节点在右子树
57+ else :
58+ if root .right is None :
59+ return root
60+ # Zig-Zig模式:右-右
61+ if key > root .right .key :
62+ root .right .right = self ._splay (root .right .right , key )
63+ root = self ._left_rotate (root )
64+ # Zig-Zag模式:右-左
65+ elif key < root .right .key :
66+ root .right .left = self ._splay (root .right .left , key )
67+ if root .right .left :
68+ root .right = self ._right_rotate (root .right )
69+ return root .right if root .right is None else self ._left_rotate (root )
70+
71+ def search (self , key : int ) -> bool :
72+ """搜索指定key,存在返回True,不存在返回False"""
73+ self .root = self ._splay (self .root , key )
74+ return self .root is not None and self .root .key == key
75+
76+ def insert (self , key : int ) -> None :
77+ """插入新节点"""
78+ if self .root is None :
79+ self .root = Node (key )
80+ return
81+
82+ self .root = self ._splay (self .root , key )
83+ if self .root .key == key :
84+ return # 已存在,无需插入
85+
86+ new_node = Node (key )
87+ if key < self .root .key :
88+ new_node .right = self .root
89+ new_node .left = self .root .left
90+ self .root .left = None
91+ else :
92+ new_node .left = self .root
93+ new_node .right = self .root .right
94+ self .root .right = None
95+ self .root = new_node
96+
97+ def delete (self , key : int ) -> None :
98+ """删除指定key的节点"""
99+ if self .root is None :
100+ return
101+
102+ self .root = self ._splay (self .root , key )
103+ if self .root .key != key :
104+ return # 节点不存在
105+
106+ # 左右子树合并
107+ if self .root .left is None :
108+ self .root = self .root .right
109+ else :
110+ temp = self .root .right
111+ self .root = self .root .left
112+ self .root = self ._splay (self .root , key )
113+ self .root .right = temp
114+
115+ def find_min (self ) -> int | None :
116+ """查找树中最小值"""
117+ if self .root is None :
118+ return None
119+ current = self .root
120+ while current .left :
121+ current = current .left
122+ self .root = self ._splay (self .root , current .key )
123+ return current .key
124+
125+ def find_max (self ) -> int | None :
126+ """查找树中最大值"""
127+ if self .root is None :
128+ return None
129+ current = self .root
130+ while current .right :
131+ current = current .right
132+ self .root = self ._splay (self .root , current .key )
133+ return current .key
134+
135+ def inorder_traversal (self , root : Node | None , result : list [int ]) -> None :
136+ """中序遍历(有序输出)"""
137+ if root :
138+ self .inorder_traversal (root .left , result )
139+ result .append (root .key )
140+ self .inorder_traversal (root .right , result )
141+
142+ def get_size (self , root : Node | None ) -> int :
143+ """获取树的节点总数"""
144+ if root is None :
145+ return 0
146+ return 1 + self .get_size (root .left ) + self .get_size (root .right )
147+
148+ def get_height (self , root : Node | None ) -> int :
149+ """获取树的高度"""
150+ if root is None :
151+ return - 1
152+ return 1 + max (self .get_height (root .left ), self .get_height (root .right ))
153+
154+ def print_tree (self , root : Node | None , indent : str = "" , last : bool = True ) -> None :
155+ """可视化打印树结构"""
156+ if root :
157+ print (indent + ("└── " if last else "├── " ) + str (root .key ))
158+ indent += " " if last else "│ "
159+ self .print_tree (root .left , indent , False )
160+ self .print_tree (root .right , indent , True )
161+
162+ # 测试用例
163+ if __name__ == "__main__" :
164+ tree = SplayTree ()
165+ # 插入测试
166+ for key in [10 , 20 , 30 , 40 , 50 , 25 ]:
167+ tree .insert (key )
168+ print ("中序遍历:" , end = " " )
169+ traversal_result = []
170+ tree .inorder_traversal (tree .root , traversal_result )
171+ print (traversal_result )
172+
173+ # 搜索测试(访问后节点会被伸展到根)
174+ print ("\n 搜索25:" , tree .search (25 ))
175+ print ("搜索后根节点:" , tree .root .key )
176+
177+ # 删除测试
178+ tree .delete (30 )
179+ print ("\n 删除30后中序遍历:" , end = " " )
180+ traversal_result = []
181+ tree .inorder_traversal (tree .root , traversal_result )
182+ print (traversal_result )
183+
184+ # 辅助方法测试
185+ print ("\n 树大小:" , tree .get_size (tree .root ))
186+ print ("树高度:" , tree .get_height (tree .root ))
187+ print ("最小值:" , tree .find_min ())
188+ print ("最大值:" , tree .find_max ())
189+
190+ print ("\n 树结构:" )
191+ tree .print_tree (tree .root )
0 commit comments