#include <fmt/color.h>  // fmt::fg / fmt::color, used for the highlighted output below
#include <fmt/format.h>
#include <cstring>

#include <forward/sampler.hpp>
#include <forward/tokenizer.hpp>

using namespace llama;
using namespace tensor;

template <Device D>
void run_inference(const char* path, tokenizer::Tokenizer& tok) {
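    // Templating on Device lets the CPU and CUDA backends share this one
    // implementation; main() below picks the instantiation at runtime.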
    size_t max_tokens = 128;
    size_t kv_cache_size = max_tokens;  // one KV-cache slot per token position

    sampler::GreedySampler<bfloat16, D> sampler{sampler::GreedyConfig{}, tok};

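    // Model hyperparameters are read from config.json; the weights are
    // loaded separately from the safetensors file below.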
    Model<bfloat16, D> mod(fmt::format("{}/config.json", path), max_tokens, kv_cache_size);

    fmt::println("Loading weights...");
    Loader<bfloat16, D> loader{fmt::format("{}/model.safetensors", path)};
    mod.load_weights(loader);

    fmt::println("Weights loaded! Performing inference...");

    // The prompt definition is elided in this excerpt; placeholder shown.
    const auto* prompt = "...";

    fmt::println("Prompt: {}", prompt);

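    // generate() returns the decoded completion together with timing stats
    // (time to first token, average inter-token latency, throughput).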
    auto [out, stats] = sampler.generate(mod, prompt, 12);

    auto colored_out = fmt::format(fmt::fg(fmt::color::aqua), "{}", out);

    fmt::println("{}{}", prompt, colored_out);

    fmt::println("");
    // TTFT = time to first token; ITL = inter-token latency.
    fmt::println("TTFT: {:.2f} ms", stats.ttft_ms);
    fmt::println("Avg ITL: {:.2f} ms", stats.avg_itl_ms);
    fmt::println("Tokens / sec: {:.2f}", stats.tokens_per_sec);
}

int main(int argc, char* argv[]) {
    const char* path = "./tests/model";
    bool use_cuda = false;

    // "--cuda" selects the CUDA backend; any other argument is treated as
    // the model directory.
    for (int i = 1; i < argc; ++i) {
        if (std::strcmp(argv[i], "--cuda") == 0) {
            use_cuda = true;
        } else {
            path = argv[i];
        }
    }

    tokenizer::Tokenizer tok(fmt::format("{}/tokenizer.json", path));

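    // The tokenizer is backend-independent, so it is built once on the host
    // and passed by reference to whichever backend runs.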
    if (use_cuda) {
#ifdef BACKEND_CUDA
        fmt::println("Using CUDA backend");
        run_inference<CUDA>(path, tok);
#else
        fmt::println("Error: CUDA backend not available. Rebuild with CUDA support.");
        return 1;
#endif
    } else {
        fmt::println("Using CPU backend");
        run_inference<CPU>(path, tok);
    }

    return 0;
}
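
// Usage sketch (binary name is hypothetical):
//   ./forward                         -> CPU backend, model in ./tests/model
//   ./forward --cuda /path/to/model   -> CUDA backend, custom model directory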