Skip to content

Commit abc8bb0

Browse files
committed
Add benchmark env
1 parent 0837e76 commit abc8bb0

3 files changed

Lines changed: 18 additions & 2 deletions

File tree

profile_kernels.cu

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,13 +1857,16 @@ void profile_envspeed(int total_agents, int num_buffers, int num_threads, int ho
18571857
#endif // USE_STATIC_ENV
18581858

18591859
void print_usage(const char* prog) {
1860-
printf("Usage: %s <profile>\n", prog);
1860+
printf("Usage: %s <profile> [options]\n", prog);
18611861
printf(" kernels - Individual kernel profiling (no nsys needed)\n");
18621862
#ifdef USE_TORCH
18631863
printf(" forwardcall - Inference forward pass\n");
18641864
#endif
18651865
#ifdef USE_STATIC_ENV
18661866
printf(" envspeed - Environment step throughput (static linked)\n");
1867+
printf(" --buffers N - Number of buffers (default: %d)\n", BUF);
1868+
printf(" --threads N - Number of threads (default: 16)\n");
1869+
printf(" --horizon N - Horizon length (default: %d)\n", T);
18671870
#endif
18681871
printf(" all - Run all profiles\n");
18691872
}
@@ -1875,6 +1878,17 @@ int main(int argc, char** argv) {
18751878
}
18761879

18771880
const char* profile = argv[1];
1881+
1882+
// Parse optional CLI args for envspeed
1883+
int buffers = BUF;
1884+
int threads = 16;
1885+
int horizon = T;
1886+
for (int i = 2; i < argc - 1; i++) {
1887+
if (strcmp(argv[i], "--buffers") == 0) buffers = atoi(argv[++i]);
1888+
else if (strcmp(argv[i], "--threads") == 0) threads = atoi(argv[++i]);
1889+
else if (strcmp(argv[i], "--horizon") == 0) horizon = atoi(argv[++i]);
1890+
}
1891+
18781892
warmup_gpu();
18791893

18801894
// Using typical breakout settings: INPUT_SIZE=96, H=128, A=4
@@ -1900,7 +1914,7 @@ int main(int argc, char** argv) {
19001914

19011915
#ifdef USE_STATIC_ENV
19021916
if (strcmp(profile, "envspeed") == 0 || strcmp(profile, "all") == 0) {
1903-
profile_envspeed(BUF*BR, BUF, 16, T);
1917+
profile_envspeed(buffers * BR, buffers, threads, horizon);
19041918
}
19051919
#endif
19061920

pufferlib/extensions/env_binding.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <omp.h>
55
#include <stdatomic.h>
66
#include <pthread.h>
7+
#include <stdbool.h>
78

89
#include "env_binding.h"
910
#include "binding.h"

pufferlib/ocean/benchmark/benchmark.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ typedef struct {
1313
double* actions;
1414
float* rewards;
1515
float* terminals;
16+
int num_agents;
1617
int bandwidth;
1718
int compute;
1819
} Benchmark;

0 commit comments

Comments
 (0)