def handles_memory_backpressure(self):
    """Whether this accelerator handles memory backpressure itself.

    Always ``True`` for this backend.
    """
    return True
def device_name(self, device_index=None):
    """Return the canonical device string for this accelerator.

    NOTE(review): ``device_index`` is accepted for interface
    compatibility but deliberately ignored — every index maps to the
    bare ``'hpu'`` name (per the "ignoring device_index" change).
    """
    return 'hpu'
def device(self, device_index=None):
    """Build a ``torch.device`` from this accelerator's device name.

    Delegates the name lookup to :meth:`device_name`, which may ignore
    ``device_index``.
    """
    name = self.device_name(device_index)
    return torch.device(name)
def get_rng_state(self, device_index=None):
    """Return the RNG state from the HPU backend.

    NOTE(review): ``device_index`` is accepted for interface
    compatibility; the visible call does not forward it.
    """
    rng = self.hpu.random
    return rng.get_rng_state()
def manual_seed(self, seed):
    """Seed the HPU RNG and propagate the backend call's return value."""
    rng = self.hpu.random
    return rng.manual_seed(seed)
def manual_seed_all(self, seed):
    """Seed the RNG on all HPU devices.

    Returns the backend call's result for consistency with
    ``manual_seed`` and ``initial_seed`` (both of which were changed to
    propagate their return value); callers that ignore the return value
    are unaffected.
    """
    return self.hpu.random.manual_seed_all(seed)
def initial_seed(self):
    """Return the initial seed reported by the HPU RNG backend."""
    rng = self.hpu.random
    return rng.initial_seed()
def default_generator(self, device_index):
    """Return the default RNG generator for the given device index."""
    generators = self.hpu.random.default_generators
    return generators[device_index]
@@ -288,6 +287,17 @@ def get_op_builder(self, class_name):
288287 else :
289288 return self .class_dict ['NotImplementedBuilder' ] if 'NotImplementedBuilder' in self .class_dict else None
290289
def get_compile_backend(self):
    """Name of the ``torch.compile`` backend to use on HPU."""
    return "hpu_backend"
# shall be removed once moving to torch.compile
def wrap_in_hpu_graph(self, module):
    """Wrap *module* in an HPU graph when the backend runs in lazy mode.

    In eager mode HPU graphs are unsupported, so the module is returned
    unchanged after printing a warning.
    """
    if not self.hpu.is_lazy():
        print("Warning: hpu graphs in eager mode is not supported, ignoring")
        return module
    return self.hpu.wrap_in_hpu_graph(module)
def build_extension(self):
    """Return torch's ``BuildExtension`` class used to compile op extensions."""
    from torch.utils.cpp_extension import BuildExtension
    return BuildExtension
0 commit comments