@@ -1857,13 +1857,16 @@ void profile_envspeed(int total_agents, int num_buffers, int num_threads, int ho
18571857#endif // USE_STATIC_ENV
18581858
/*
 * Print the command-line help text to stdout.
 *
 * prog: program name substituted into the "Usage:" line.
 *
 * Which profiles are listed depends on the build: "forwardcall" appears only
 * when USE_TORCH is defined, and "envspeed" (with its --buffers / --threads /
 * --horizon options) only when USE_STATIC_ENV is defined.
 *
 * The defaults printed for the envspeed options mirror the initial values
 * main() assigns before parsing (BUF, 16 and T); keep the two in sync.
 */
void print_usage(const char *prog) {
    printf("Usage: %s <profile> [options]\n", prog);
    printf("  kernels - Individual kernel profiling (no nsys needed)\n");
#ifdef USE_TORCH
    printf("  forwardcall - Inference forward pass\n");
#endif
#ifdef USE_STATIC_ENV
    printf("  envspeed - Environment step throughput (static linked)\n");
    /* Options are nested under envspeed because only that profile consumes them. */
    printf("    --buffers N - Number of buffers (default: %d)\n", BUF);
    printf("    --threads N - Number of threads (default: 16)\n");
    printf("    --horizon N - Horizon length (default: %d)\n", T);
#endif
    printf("  all - Run all profiles\n");
}
@@ -1875,6 +1878,17 @@ int main(int argc, char** argv) {
18751878 }
18761879
18771880 const char * profile = argv[1 ];
1881+
1882+ // Parse optional CLI args for envspeed
1883+ int buffers = BUF;
1884+ int threads = 16 ;
1885+ int horizon = T;
1886+ for (int i = 2 ; i < argc - 1 ; i++) {
1887+ if (strcmp (argv[i], " --buffers" ) == 0 ) buffers = atoi (argv[++i]);
1888+ else if (strcmp (argv[i], " --threads" ) == 0 ) threads = atoi (argv[++i]);
1889+ else if (strcmp (argv[i], " --horizon" ) == 0 ) horizon = atoi (argv[++i]);
1890+ }
1891+
18781892 warmup_gpu ();
18791893
18801894 // Using typical breakout settings: INPUT_SIZE=96, H=128, A=4
@@ -1900,7 +1914,7 @@ int main(int argc, char** argv) {
19001914
19011915#ifdef USE_STATIC_ENV
19021916 if (strcmp (profile, " envspeed" ) == 0 || strcmp (profile, " all" ) == 0 ) {
1903- profile_envspeed (BUF* BR, BUF, 16 , T );
1917+ profile_envspeed (buffers * BR, buffers, threads, horizon );
19041918 }
19051919#endif
19061920
0 commit comments