-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSimpleAgent.py
More file actions
83 lines (75 loc) · 2.64 KB
/
SimpleAgent.py
File metadata and controls
83 lines (75 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
A Deterministic Agent that tries to put the paddle below the ball
"""
import gym
import argparse
import numpy as np
def paddleCol(obs):
    """Return the center column of the paddle.

    Looks along screen row 190, reading channel 0 of each pixel, from
    column 10 up to (not including) column 150.  The first few columns are
    skipped because they only show the left side of the screen.  The first
    non-zero pixel is treated as the paddle's left edge and 7 is added to
    it to obtain the center column.

    ``obs`` is indexed as ``obs[row][column][channel]``.

    Returns 94 when no non-zero pixel is found in the scanned range.
    """
    left_edges = (col for col in range(10, 150) if obs[190][col][0] != 0)
    first_lit = next(left_edges, None)
    if first_lit is None:
        # Nothing visible on the paddle row: fall back to a fixed column.
        return 94
    return first_lit + 7
def ballCol(obs):
    """Return the column the ball is in, or -1 if it is not visible.

    Searches screen rows 63 through 187, top to bottom and left to right
    across columns 0 through 159, for the first pixel whose channel-0
    value is exactly 200.  The search starts at row 63 because the lowest
    bricks sharing the ball's starting color sit on row 62.

    ``obs`` is indexed as ``obs[row][column][channel]``.

    Returns the matching column plus one, or the sentinel -1 when no pixel
    in the searched area matches, leaving it to the caller to decide what
    to do while the ball is out of sight.
    """
    for row in range(63, 188):
        for col in range(160):
            if obs[row][col][0] == 200:
                return col + 1
    return -1
def determineAction(paddleCol, ballCol):
    """Choose the action code for the current paddle and ball columns.

    Returns:
        1 when ``ballCol`` is -1 (ball not found), or when the two columns
          are within 7 of each other;
        2 when the ball is more than 7 columns past the paddle
          (``ballCol > paddleCol``) -- presumably "move right", confirm
          against the environment's action meanings;
        3 when the ball is more than 7 columns before the paddle
          (``ballCol < paddleCol``) -- presumably "move left".
    """
    if ballCol == -1:
        return 1

    # Positive offset: the ball's column is greater than the paddle's.
    offset = ballCol - paddleCol
    if offset > 7:
        return 2
    if offset < -7:
        return 3
    return 1
def main():
    """Run the deterministic "keep the paddle under the ball" agent.

    Parses ``--env-name`` (default ``BreakoutDeterministic-v4``), creates
    and seeds the Gym environment, then loops forever: locate the paddle
    and the ball in the current frame, step the environment with the
    action chosen by ``determineAction``, accumulate the episode reward,
    and render.  A message is printed every time the running episode
    reward exceeds the best seen so far.  When an episode finishes the
    running reward is zeroed and the environment is reset.

    The loop has no exit condition of its own; it ends only through an
    exception such as ``KeyboardInterrupt``.  The ``try``/``finally``
    guarantees the environment is closed in that case (previously the
    ``env.close()`` call after the loop could never be reached).
    """
    np.set_printoptions(threshold=np.inf)
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-name', type=str, default='BreakoutDeterministic-v4')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    try:
        np.random.seed(123)
        env.seed(123)

        obs = env.reset()
        best_rew = float('-inf')
        rew = 0

        # Start the game by pressing fire; only the resulting frame is kept.
        obs, _, _, _ = env.step(1)

        while True:
            # The screen is a 210x160 array of pixels with three values for
            # Atari's RGB.
            # The paddle is on rows 188-191 inclusive (initial center is 189.5)
            # and starts on columns 86-101 inclusive (initial center is 93.5).
            # Row 190 is checked to find the paddle, which is then moved based
            # on its location relative to the ball.
            paddleCenter = paddleCol(obs)
            ballCenter = ballCol(obs)

            action = determineAction(paddleCenter, ballCenter)
            obs, stepRew, done, _ = env.step(action)

            rew += stepRew
            if rew > best_rew:
                print("new best reward {} => {}".format(best_rew, rew))
                best_rew = rew

            env.render()

            if done:
                # Episode over: start a fresh one with a zeroed running reward.
                rew = 0
                obs = env.reset()
    finally:
        # Always release the emulator/window, even on Ctrl-C or an error.
        env.close()
# Start the agent only when this file is executed as a script, so that
# importing it as a module does not launch the game.
if __name__ == "__main__":
    main()