Skip to content

Commit 937c40d

Browse files
authored
robot_activity_model内の配列をなるべく減らすために,dataclassを使うように変更 (#504)
1 parent 95b39cc commit 937c40d

1 file changed

Lines changed: 145 additions & 92 deletions

File tree

consai_game/consai_game/world_model/robot_activity_model.py

Lines changed: 145 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from consai_msgs.msg import MotionCommand
2727

28-
from dataclasses import dataclass
28+
from dataclasses import dataclass, field
2929
from enum import Enum, auto
3030

3131

@@ -45,7 +45,96 @@ class OurRobotsArrived:
4545
arrived: bool = False
4646

4747

48+
@dataclass
49+
class RobotInfo:
50+
"""単一のロボット情報を保持するデータクラス."""
51+
52+
# ロボットID
53+
robot_id: int = 0
54+
55+
# 目標位置までの距離
56+
desired_distance: float = float("inf")
57+
# ボールまでの距離
58+
ball_distance: float = float("inf")
59+
# プレースメント位置までの距離
60+
placement_distance: float = float("inf")
61+
62+
# 目標位置に到着しているかのフラグ
63+
arrived: bool = False
64+
65+
66+
@dataclass
67+
class RobotsInfo:
68+
"""自ロボットの情報を保持するデータクラス."""
69+
70+
robots: dict[int, RobotInfo] = field(default_factory=dict)
71+
72+
def clear(self):
73+
"""全ロボット情報を初期化して空にするメソッド."""
74+
self.robots.clear()
75+
76+
def visible_ids(self) -> list[int]:
77+
"""可視ロボットのIDリストを返すメソッド."""
78+
return list(self.robots.keys())
79+
80+
def arrived_ids(self) -> list[int]:
81+
"""目標位置に到達したロボットのIDリストを返すメソッド."""
82+
return [r.robot_id for r in self.robots.values() if r.arrived]
83+
84+
def all_arrived(self) -> bool:
85+
"""全ロボットが目標位置に到達しているかを返すメソッド."""
86+
return all(r.arrived for r in self.robots.values())
87+
88+
def get(self, robot_id: int) -> RobotInfo:
89+
"""指定したロボットIDのRobotInfoを返す。存在しない場合はKeyErrorメソッド."""
90+
return self.robots[robot_id]
91+
92+
def __getitem__(self, robot_id: int) -> RobotInfo:
93+
"""辞書のようにロボットIDでRobotInfoへアクセスできるようにするメソッド."""
94+
return self.robots[robot_id]
95+
96+
def __setitem__(self, robot_id: int, value: RobotInfo):
97+
"""辞書のようにロボットIDでRobotInfoを設定できるようにするメソッド."""
98+
self.robots[robot_id] = value
99+
100+
def __contains__(self, robot_id: int) -> bool:
101+
"""ロボットIDが含まれているか判定するメソッド."""
102+
return robot_id in self.robots
103+
104+
def __len__(self):
105+
"""可視ロボット数を返すメソッド."""
106+
return len(self.robots)
107+
108+
def keys(self):
109+
"""可視ロボットのID一覧を返すメソッド."""
110+
return self.robots.keys()
111+
112+
def values(self):
113+
"""可視ロボットのRobotInfo一覧を返すメソッド."""
114+
return self.robots.values()
115+
116+
def items(self):
117+
"""可視ロボットの(ID, RobotInfo)タプル一覧を返すメソッド."""
118+
return self.robots.items()
119+
120+
121+
@dataclass
122+
class OurRobotsInfo:
123+
"""自ロボットの情報を保持するデータクラス."""
124+
125+
our_robots: RobotsInfo = field(default_factory=RobotsInfo)
126+
127+
128+
@dataclass
129+
class TheirRobotsInfo:
130+
"""敵ロボットの情報を保持するデータクラス."""
131+
132+
their_robots: RobotsInfo = field(default_factory=RobotsInfo)
133+
134+
48135
class ProhibitedKickRobotSearchState(Enum):
136+
"""キック禁止ロボット探索状態を表す列挙型."""
137+
49138
BEFORE_SEARCH = 0
50139
SHOULD_FIRST_SEARCH = auto()
51140
SHOULD_SECOND_SEARCH = auto()
@@ -60,18 +149,12 @@ class RobotActivityModel:
60149
INVALID_ROBOT_ID = -1
61150

62151
def __init__(self):
63-
"""ロボットの可視状態と順序リストの初期化関数."""
64-
self.ordered_our_visible_robots: list[int] = []
65-
self.ordered_their_visible_robots: list[int] = []
66-
self.our_visible_robots: list[int] = []
67-
self.their_visible_robots: list[int] = []
68-
self.our_robots_by_ball_distance: list[int] = []
69-
self.their_robots_by_ball_distance: list[int] = []
70-
self.our_robots_by_placement_distance: list[int] = []
152+
"""RobotActivityModelの初期化."""
153+
self.our_visible_robots = RobotsInfo()
154+
self.their_visible_robots = RobotsInfo()
71155
self.our_ball_receive_score: list[ReceiveScore] = []
72-
self.our_robots_arrived_list: list[OurRobotsArrived] = []
73-
self.our_prohibited_kick_robot_id: int = self.INVALID_ROBOT_ID # 直近のキック禁止ロボットID
74-
self.prohibited_kick_robot_candidate_id: int = self.INVALID_ROBOT_ID # キック禁止ロボット候補ID
156+
self.our_prohibited_kick_robot_id: int = self.INVALID_ROBOT_ID
157+
self.prohibited_kick_robot_candidate_id: int = self.INVALID_ROBOT_ID
75158
self.prohibited_kick_robot_search_state = ProhibitedKickRobotSearchState.BEFORE_SEARCH
76159
self.number_of_their_robots_in_our_area: int = 0
77160

@@ -83,67 +166,76 @@ def update(
83166
game_config: GameConfigModel,
84167
referee: RefereeModel,
85168
):
86-
"""ロボットの可視状態を更新し, 順序づけされたIDリストを更新する関数."""
87-
self.our_visible_robots = [robot.robot_id for robot in robots.our_robots.values() if robot.is_visible]
88-
self.their_visible_robots = [robot.robot_id for robot in robots.their_robots.values() if robot.is_visible]
169+
"""ロボットの可視状態を更新し, 距離や状態をセットする関数."""
89170

90-
self.ordered_our_visible_robots = self.ordered_merge(
91-
self.ordered_our_visible_robots,
92-
self.our_visible_robots,
93-
)
94-
self.ordered_their_visible_robots = self.ordered_merge(
95-
self.ordered_their_visible_robots,
96-
self.their_visible_robots,
97-
)
171+
# 可視ロボットをRobotInfoで管理
172+
self.our_visible_robots.clear()
173+
for robot in robots.our_robots.values():
174+
if robot.is_visible:
175+
info = RobotInfo(
176+
robot_id=robot.robot_id,
177+
ball_distance=tools.get_distance(robot.pos, ball.pos),
178+
placement_distance=tools.get_distance(robot.pos, referee.placement_pos),
179+
)
180+
self.our_visible_robots[robot.robot_id] = info
98181

99-
# ボールに近い順にリストを作る
182+
self.their_visible_robots.clear()
183+
for robot in robots.their_robots.values():
184+
if robot.is_visible:
185+
info = RobotInfo(
186+
robot_id=robot.robot_id,
187+
ball_distance=tools.get_distance(robot.pos, ball.pos),
188+
placement_distance=tools.get_distance(robot.pos, referee.placement_pos),
189+
)
190+
self.their_visible_robots[robot.robot_id] = info
191+
192+
# ボールに近い順
100193
self.our_robots_by_ball_distance = [
101-
robot_id
102-
for robot_id, _ in self.robot_ball_distances(
103-
robots.our_visible_robots,
104-
ball,
105-
)
194+
r.robot_id for r in sorted(self.our_visible_robots.values(), key=lambda x: x.ball_distance)
106195
]
107196
self.their_robots_by_ball_distance = [
108-
robot_id
109-
for robot_id, _ in self.robot_ball_distances(
110-
robots.their_visible_robots,
111-
ball,
112-
)
197+
r.robot_id for r in sorted(self.their_visible_robots.values(), key=lambda x: x.ball_distance)
113198
]
114-
115-
# プレースメント位置に近い順にリストを作る
199+
# プレースメント位置に近い順
116200
self.our_robots_by_placement_distance = [
117-
robot_id
118-
for robot_id, _ in self.robot_placement_distances(
119-
robots.our_visible_robots,
120-
referee,
121-
)
201+
r.robot_id for r in sorted(self.our_visible_robots.values(), key=lambda x: x.placement_distance)
122202
]
123203

124-
# ボールを受け取れるスコアを計算する
204+
# ボール受け取りスコア
125205
self.our_ball_receive_score = self.calc_ball_receive_score_list(
126-
robots=robots.our_visible_robots,
206+
robots={rid: robots.our_robots[rid] for rid in self.our_visible_robots.visible_ids()},
127207
ball=ball,
128208
ball_activity=ball_activity,
129209
game_config=game_config,
130210
)
131211

132-
# 自陣にいる相手ロボットの台数を計算する
133-
self.number_of_their_robots_in_our_area = self.count_their_robots(robots.their_visible_robots)
212+
# 自陣にいる相手ロボット数
213+
self.number_of_their_robots_in_our_area = sum(
214+
1 for r in self.their_visible_robots.values() if robots.their_robots[r.robot_id].pos.x < 0
215+
)
134216

135-
# ダブルタッチ防止のために、キック禁止ロボット情報を更新する
136217
self.update_prohibited_kick_robot(ball_activity, referee)
137218

138-
def ordered_merge(self, prev_list: list[int], new_list: list[int]) -> list[int]:
139-
"""過去の順序を保ちながら, 新しいリストでマージする関数."""
140-
# new_listに存在するものを残す
141-
output_list = [r for r in prev_list if r in new_list]
219+
def update_our_robots_arrived(self, robots: dict[int, Robot], commands: list[MotionCommand]):
220+
"""各ロボットが目標位置に到達したかをRobotInfoにセット"""
221+
for command in commands:
222+
if command.robot_id not in robots:
223+
continue
224+
robot = robots[command.robot_id]
225+
dist = tools.get_distance(robot.pos, command.desired_pose)
226+
if command.robot_id in self.our_visible_robots:
227+
self.our_visible_robots[command.robot_id].arrived = dist < self.DIST_ROBOT_TO_DESIRED_THRESHOLD
142228

143-
# 新しい要素を追加する
144-
output_list.extend([r for r in new_list if r not in output_list])
229+
@property
230+
def our_robots_arrived(self) -> bool:
231+
"""すべての自ロボットが目標位置に到達したか"""
232+
return self.our_visible_robots.all_arrived()
145233

146-
return output_list
234+
def our_robot_arrived(self, robot_id: int) -> bool:
235+
"""指定したロボットが目標位置に到達したか"""
236+
if robot_id in self.our_visible_robots:
237+
return self.our_visible_robots[robot_id].arrived
238+
return False
147239

148240
def count_their_robots(self, robots: dict[int, Robot]) -> int:
149241
"""自陣にいる相手ロボットの数を返す関数."""
@@ -231,33 +323,6 @@ def calc_intercept_time(self, robot: Robot, ball: BallModel, game_config: GameCo
231323
return intercept_time
232324
return float("inf")
233325

234-
def update_our_robots_arrived(self, our_visible_robots: dict[int, Robot], commands: list[MotionCommand]) -> bool:
235-
"""各ロボットが目標位置に到達したか判定する関数."""
236-
237-
# 初期化
238-
self.our_robots_arrived_list = []
239-
# エラー処理
240-
if len(our_visible_robots) == 0 or len(commands) == 0:
241-
return
242-
243-
# 更新
244-
for command in commands:
245-
if command.robot_id not in our_visible_robots.keys():
246-
continue
247-
248-
robot = our_visible_robots[command.robot_id]
249-
robot_pos = robot.pos
250-
desired_pose = command.desired_pose
251-
# ロボットと目標位置の距離を計算
252-
dist_robot_to_desired = tools.get_distance(robot_pos, desired_pose)
253-
# 目標位置に到達したか判定結果をリストに追加
254-
self.our_robots_arrived_list.append(
255-
OurRobotsArrived(
256-
robot_id=robot.robot_id,
257-
arrived=dist_robot_to_desired < self.DIST_ROBOT_TO_DESIRED_THRESHOLD,
258-
)
259-
)
260-
261326
def update_prohibited_kick_robot(self, ball_activity: BallActivityModel, referee: RefereeModel):
262327
"""レフェリー信号を見て、ダブルタッチをしてはいけないロボットを更新する関数."""
263328
# ストップゲームで初期化する
@@ -300,15 +365,3 @@ def update_prohibited_kick_robot(self, ball_activity: BallActivityModel, referee
300365
# 違うロボットがボールを蹴ったら探索終了
301366
self.prohibited_kick_robot_search_state = ProhibitedKickRobotSearchState.SEARCH_COMPLETED
302367
self.our_prohibited_kick_robot_id = self.INVALID_ROBOT_ID
303-
304-
@property
305-
def our_robots_arrived(self) -> bool:
306-
"""すべての自ロボットが目標位置に到達したかを返す関数."""
307-
return all([robot.arrived for robot in self.our_robots_arrived_list])
308-
309-
def our_robot_arrived(self, robot_id: int) -> bool:
310-
"""指定したロボットが目標位置に到達したかを返す関数."""
311-
for robot in self.our_robots_arrived_list:
312-
if robot.robot_id == robot_id:
313-
return robot.arrived
314-
return False

0 commit comments

Comments
 (0)