@@ -66,8 +66,10 @@ inline void bind_infer_engine(py::module &m) {
             }
             return state_dict_tp_all;
         })
-        .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
-        .def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
+        .def(
+            "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
+        .def(
+            "reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
         .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
             auto cfg = self.get_cache_config();
             return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })
@@ -81,7 +83,7 @@ inline void bind_infer_engine(py::module &m) {
                    std::shared_ptr<infinilm::cache::CacheConfig> cache_cfg,
                    bool enable_graph_compiling,
                    const std::string &attention_backend,
-                   const std::string &kv_cache_dtype) {
+                   std::optional<infinicore::DataType> kv_cache_dtype) {
                 return std::make_shared<InferEngine>(
                     model_path,
                     dist,
@@ -97,7 +99,7 @@ inline void bind_infer_engine(py::module &m) {
              py::arg("cache_config") = py::none(),
              py::arg("enable_graph_compiling") = false,
              py::arg("attention_backend") = "default",
-             py::arg("kv_cache_dtype") = "")
+             py::arg("kv_cache_dtype") = py::none())
         .def("load_param", &InferEngine::load_param,
              py::arg("name"), py::arg("param"),
              "Load a parameter tensor into all workers (each worker picks its shard)")
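
On the Python side, these two hunks change `kv_cache_dtype` from a string (defaulting to `""`) to an optional `infinicore::DataType` (defaulting to `None`), which pybind11 maps through `std::optional`. A minimal call sketch, assuming the bindings are importable as `infinilm` and that `infinicore` exposes a `DataType` enum with an `F16` member (those names are assumptions; only the keyword names and defaults come from the diff):

```python
from infinilm import InferEngine  # assumed module name
import infinicore                 # assumed module name

# kv_cache_dtype now defaults to None instead of "".
engine = InferEngine(
    "/path/to/model",             # model_path (placeholder)
    dist,                         # distributed config (placeholder)
    cache_config=None,
    enable_graph_compiling=False,
    attention_backend="default",
    kv_cache_dtype=None,
)

# An explicit dtype is now passed as a DataType value rather than a string:
engine_f16 = InferEngine(
    "/path/to/model",
    dist,
    kv_cache_dtype=infinicore.DataType.F16,  # hypothetical enum member
)
```
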
@@ -112,8 +114,10 @@ inline void bind_infer_engine(py::module &m) {
             }
             return state_dict_tp_all;
         })
-        .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
-        .def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
+        .def(
+            "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
+        .def(
+            "reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
         .def("get_cache_config", [](const InferEngine &self) {
             auto cfg = self.get_cache_config();
             return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
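
Taken together, the `reset_cache` and `get_cache_config` bindings let Python either clear the cache configuration (`None` maps to a null `CacheConfig*` in the lambda) or retrieve an independent copy of the current one via `unique_copy()`. A hedged sketch of the call pattern, reusing the hypothetical `engine` from the sketch above:

```python
# cache_config defaults to None, which the binding converts to nullptr.
engine.reset_cache()             # equivalent to engine.reset_cache(None)

# get_cache_config returns an independent copy of the current config,
# which can be passed back to reset_cache.
cfg = engine.get_cache_config()
engine.reset_cache(cfg)
```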