Skip to content

Commit d70e557

Browse files
committed
Env vec speeds look reasonable in testing
1 parent 3d67cdb commit d70e557

6 files changed

Lines changed: 68 additions & 34 deletions

File tree

profile_kernels.cu

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1755,6 +1755,9 @@ EnvSpeedArgs* create_envspeedargs(int total_agents, int num_buffers, int num_thr
17551755
fprintf(stderr, "Failed to create environments\n");
17561756
return nullptr;
17571757
}
1758+
for (int i = 0; i < num_buffers; i++) {
1759+
cudaStreamCreateWithFlags(&vec->streams[i], cudaStreamNonBlocking);
1760+
}
17581761

17591762
int num_envs = vec->size;
17601763
printf("Created %d envs (%s) for %d total_agents\n", num_envs, TOSTRING(ENV_NAME), total_agents);
@@ -1794,25 +1797,23 @@ float profile_env_rollout(EnvSpeedArgs* args, const char* name) {
17941797
cudaEventCreate(&stop);
17951798

17961799
// Warmup
1797-
auto start_time = std::chrono::steady_clock::now();
1798-
for (int i = 0; i < 10; ++i) {
1800+
float start_time = get_time_sec();
1801+
for (int i = 0; i < 100; ++i) {
17991802
run_env_rollout(args);
18001803
cudaDeviceSynchronize();
1801-
auto now = std::chrono::steady_clock::now();
1802-
float elapsed = std::chrono::duration<float>(now - start_time).count();
1804+
auto elapsed = get_time_sec() - start_time;
18031805
if (elapsed > TIMEOUT_SEC) break;
18041806
}
18051807

1806-
start_time = std::chrono::steady_clock::now();
1808+
start_time = get_time_sec();
18071809
cudaProfilerStart();
18081810
if (name) nvtxRangePushA(name);
18091811
cudaEventRecord(start);
18101812
float completed = 0;
18111813
for (int i = 0; i < 1000; ++i) {
18121814
run_env_rollout(args);
18131815
completed += 1;
1814-
auto now = std::chrono::steady_clock::now();
1815-
float elapsed = std::chrono::duration<float>(now - start_time).count();
1816+
float elapsed = get_time_sec() - start_time;
18161817
if (elapsed > TIMEOUT_SEC) break;
18171818
}
18181819
cudaDeviceSynchronize();
@@ -1883,10 +1884,12 @@ int main(int argc, char** argv) {
18831884
int buffers = BUF;
18841885
int threads = 16;
18851886
int horizon = T;
1887+
int total_agents = BR * buffers;
18861888
for (int i = 2; i < argc - 1; i++) {
18871889
if (strcmp(argv[i], "--buffers") == 0) buffers = atoi(argv[++i]);
18881890
else if (strcmp(argv[i], "--threads") == 0) threads = atoi(argv[++i]);
18891891
else if (strcmp(argv[i], "--horizon") == 0) horizon = atoi(argv[++i]);
1892+
else if (strcmp(argv[i], "--total-agents") == 0) total_agents = atoi(argv[++i]);
18901893
}
18911894

18921895
warmup_gpu();
@@ -1914,7 +1917,7 @@ int main(int argc, char** argv) {
19141917

19151918
#ifdef USE_STATIC_ENV
19161919
if (strcmp(profile, "envspeed") == 0 || strcmp(profile, "all") == 0) {
1917-
profile_envspeed(buffers * BR, buffers, threads, horizon);
1920+
profile_envspeed(total_agents, buffers, threads, horizon);
19181921
}
19191922
#endif
19201923

pufferlib/config/ocean/benchmark.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ policy_name = Policy
55
rnn_name = Recurrent
66

77
[env]
8-
bandwidth = 1
8+
bandwidth = 512
99
compute = 0
1010

1111
[vec]

pufferlib/extensions/env_binding.c

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ static void* static_omp_threadmanager(void* arg) {
8989

9090
Env* envs = (Env*)vec->envs;
9191

92+
printf("Num workers: %d\n", num_workers);
9293
while (true) {
9394
while (atomic_load(&buffer_states[buf]) != OMP_RUNNING) {
9495
if (atomic_load(&threading->shutdown)) {
@@ -107,9 +108,15 @@ static void* static_omp_threadmanager(void* arg) {
107108
cudaMemcpyDeviceToHost, stream);
108109
cudaStreamSynchronize(stream);
109110

110-
#pragma omp parallel for schedule(static) num_threads(num_workers)
111-
for (int i = env_start; i < env_start + env_count; i++) {
112-
c_step(&envs[i]);
111+
if (num_workers > 1) {
112+
#pragma omp parallel for schedule(static) num_threads(num_workers)
113+
for (int i = env_start; i < env_start + env_count; i++) {
114+
c_step(&envs[i]);
115+
}
116+
} else {
117+
for (int i = env_start; i < env_start + env_count; i++) {
118+
c_step(&envs[i]);
119+
}
113120
}
114121

115122
cudaMemcpyAsync(

pufferlib/ocean/benchmark/benchmark.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ void c_step(Benchmark* env) {
2626
result = sinf(result + 0.1f);
2727
}
2828

29-
memset(env->observations, result, env->bandwidth);
29+
//memset(env->observations, result, env->bandwidth);
3030
}
3131

3232
void c_render(Benchmark* env) { }

pufferlib/ocean/benchmark/binding.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "benchmark.h"
2-
#define OBS_SIZE 1 // TODO: Current API forces you to edit this per obs size
2+
#define OBS_SIZE 512 // TODO: Current API forces you to edit this per obs size
33
#define NUM_ATNS 1
44
#define ACT_SIZES {2}
55
#define OBS_TYPE UNSIGNED_CHAR

pufferlib/ocean/breakout/breakout.h

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,16 @@ void compute_observations(Breakout* env) {
159159
env->observations[7] = env->score / 864.0f;
160160
env->observations[8] = env->num_balls / 5.0f;
161161
env->observations[9] = env->paddle_width / (2.0f * HALF_PADDLE_WIDTH);
162-
for (int i = 0; i < env->num_bricks; i++) {
163-
env->observations[10 + i] = env->brick_states[i];
164-
}
162+
memcpy(env->observations + 10, env->brick_states, sizeof(float) * env->num_bricks);
165163
}
166164

167165
// Collision of a stationary vertical line segment (xw,yw) to (xw,yw+hw)
168166
// with a moving line segment (x+vx*t,y+vy*t) to (x+vx*t,y+vy*t+h).
169167
static inline bool calc_vline_collision(float xw, float yw, float hw, float x,
170168
float y, float vx, float vy, float h, CollisionInfo* col) {
171169
float t_new = (xw - x) / vx;
172-
float topmost = fmin(yw + hw, y + h + vy * t_new);
173-
float botmost = fmax(yw, y + vy * t_new);
170+
float topmost = fminf(yw + hw, y + h + vy * t_new);
171+
float botmost = fmaxf(yw, y + vy * t_new);
174172
float overlap_new = topmost - botmost;
175173

176174
// Collision finds the smallest time of collision with the greatest overlap
@@ -248,26 +246,52 @@ static inline void calc_brick_collision(Breakout* env, int idx,
248246
}
249247
}
250248
static inline int column_index(Breakout* env, float x) {
251-
return (int)(floorf(x / env->brick_width));
249+
return (int)(x / env->brick_width);
252250
}
253251
static inline int row_index(Breakout* env, float y) {
254-
return (int)(floorf((y - Y_OFFSET) / env->brick_height));
252+
return (int)((y - Y_OFFSET) / env->brick_height);
255253
}
256254

257255
void calc_all_brick_collisions(Breakout* env, CollisionInfo* collision_info) {
258-
int column_from = column_index(env, fminf(env->ball_x + env->ball_vx, env->ball_x));
259-
column_from = fmaxf(column_from, 0);
260-
int column_to = column_index(env, fmaxf(env->ball_x + env->ball_width + env->ball_vx, env->ball_x + env->ball_width));
261-
column_to = fminf(column_to, env->brick_cols - 1);
262-
int row_from = row_index(env, fminf(env->ball_y + env->ball_vy, env->ball_y));
263-
row_from = fmaxf(row_from, 0);
264-
int row_to = row_index(env, fmaxf(env->ball_y + env->ball_height + env->ball_vy, env->ball_y + env->ball_height));
265-
row_to = fminf(row_to, env->brick_rows - 1);
256+
float ball_x = env->ball_x;
257+
float ball_x_dst = ball_x + env->ball_vx;
258+
float ball_y = env->ball_y;
259+
float ball_y_dst = ball_y + env->ball_vy;
260+
float ball_width = env->ball_width;
261+
float ball_height = env->ball_height;
262+
263+
int row_from = row_index(env, ball_y < ball_y_dst ? ball_y : ball_y_dst);
264+
if (row_from < 0) {
265+
row_from = 0;
266+
}
267+
268+
if (row_from > env->brick_rows) {
269+
return;
270+
}
271+
272+
int column_from = column_index(env, ball_x < ball_x_dst ? ball_x : ball_x_dst);
273+
if (column_from < 0) {
274+
column_from = 0;
275+
}
276+
277+
float ball_x_end = ball_x + ball_width;
278+
float ball_x_dst_end = ball_x_dst + ball_width;
279+
int column_to = column_index(env, ball_x_dst_end > ball_x_end ? ball_x_dst_end : ball_x_end);
280+
if (column_to >= env->brick_cols) {
281+
column_to = env->brick_cols - 1;
282+
}
283+
284+
float ball_y_end = ball_y + ball_height;
285+
float ball_y_dst_end = ball_y_dst + ball_height;
286+
int row_to = row_index(env, ball_y_dst_end > ball_y_end ? ball_y_dst_end : ball_y_end);
287+
if (row_to >= env->brick_rows) {
288+
row_to = env->brick_rows - 1;
289+
}
266290

267291
for (int row = row_from; row <= row_to; row++) {
268292
for (int column = column_from; column <= column_to; column++) {
269293
int brick_index = row * env->brick_cols + column;
270-
if (env->brick_states[brick_index] == 0.0)
294+
if (env->brick_states[brick_index] == 0.0f)
271295
calc_brick_collision(env, brick_index, collision_info);
272296
}
273297
}
@@ -297,8 +321,8 @@ bool calc_paddle_ball_collisions(Breakout* env, CollisionInfo* collision_info) {
297321
float relative_intersection = (
298322
(env->ball_x + env->ball_width / 2) - env->paddle_x) / env->paddle_width;
299323
float angle = -base_angle + relative_intersection * 2 * base_angle;
300-
env->ball_vx = sin(angle) * env->ball_speed * TICK_RATE;
301-
env->ball_vy = -cos(angle) * env->ball_speed * TICK_RATE;
324+
env->ball_vx = sinf(angle) * env->ball_speed * TICK_RATE;
325+
env->ball_vy = -cosf(angle) * env->ball_speed * TICK_RATE;
302326
env->hits += 1;
303327
if (env->hits % 4 == 0 && env->ball_speed < env->max_ball_speed) {
304328
env->ball_speed += 64;
@@ -430,8 +454,8 @@ void step_frame(Breakout* env, float action) {
430454
env->balls_fired = 1;
431455
float direction = M_PI / 3.25f;
432456

433-
env->ball_vy = cos(direction) * env->ball_speed * TICK_RATE;
434-
env->ball_vx = sin(direction) * env->ball_speed * TICK_RATE;
457+
env->ball_vy = cosf(direction) * env->ball_speed * TICK_RATE;
458+
env->ball_vx = sinf(direction) * env->ball_speed * TICK_RATE;
435459
if (rand() % 2 == 0) {
436460
env->ball_vx = -env->ball_vx;
437461
}

0 commit comments

Comments
 (0)