@@ -280,8 +280,18 @@ def get_args():
280280
281281 logging .basicConfig (level = logging .INFO , format = '%(levelname)s: %(message)s' )
282282 args = get_args ()
283- device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
284- #logging.info(f'Using device {device}')
283+ # Check if CUDA is available
284+ if torch .cuda .is_available ():
285+ logging .info ("CUDA is available. Using CUDA..." )
286+ device = torch .device ("cuda:0" )
287+ elif torch .backends .mps .is_available (): # Check if MPS is available (for macOS)
288+ logging .info ("MPS is available. Using MPS..." )
289+ device = torch .device ("mps" )
290+ else :
291+ logging .info ("Neither CUDA nor MPS is available. Using CPU..." )
292+ device = torch .device ("cpu" )
293+
294+ logging .info (f'Using device { device } ' )
285295
286296 img_size = Define_image_size (args .uniform , args .dataset )
287297 dataset_name = args .dataset
@@ -347,79 +357,79 @@ def get_args():
347357
348358
349359 for i in range (1 ):
350- net_G_1 .load_state_dict (torch .load ( checkpoint_saved_1 + 'CP_best_F1_all.pth' ))
351- net_G_A_1 .load_state_dict (torch .load ( checkpoint_saved_1 + 'CP_best_F1_A.pth' ))
352- net_G_V_1 .load_state_dict (torch .load (checkpoint_saved_1 + 'CP_best_F1_V.pth' ))
360+ net_G_1 .load_state_dict (torch .load ( checkpoint_saved_1 + 'CP_best_F1_all.pth' , map_location = device ))
361+ net_G_A_1 .load_state_dict (torch .load ( checkpoint_saved_1 + 'CP_best_F1_A.pth' , map_location = device ))
362+ net_G_V_1 .load_state_dict (torch .load (checkpoint_saved_1 + 'CP_best_F1_V.pth' , map_location = device ))
353363 net_G_1 .eval ()
354364 net_G_A_1 .eval ()
355365 net_G_V_1 .eval ()
356366 net_G_1 .to (device = device )
357367 net_G_A_1 .to (device = device )
358368 net_G_V_1 .to (device = device )
359369
360- net_G_2 .load_state_dict (torch .load ( checkpoint_saved_2 + 'CP_best_F1_all.pth' ))
361- net_G_A_2 .load_state_dict (torch .load ( checkpoint_saved_2 + 'CP_best_F1_A.pth' ))
362- net_G_V_2 .load_state_dict (torch .load (checkpoint_saved_2 + 'CP_best_F1_V.pth' ))
370+ net_G_2 .load_state_dict (torch .load ( checkpoint_saved_2 + 'CP_best_F1_all.pth' , map_location = device ))
371+ net_G_A_2 .load_state_dict (torch .load ( checkpoint_saved_2 + 'CP_best_F1_A.pth' , map_location = device ))
372+ net_G_V_2 .load_state_dict (torch .load (checkpoint_saved_2 + 'CP_best_F1_V.pth' , map_location = device ))
363373 net_G_2 .eval ()
364374 net_G_A_2 .eval ()
365375 net_G_V_2 .eval ()
366376 net_G_2 .to (device = device )
367377 net_G_A_2 .to (device = device )
368378 net_G_V_2 .to (device = device )
369379
370- net_G_3 .load_state_dict (torch .load ( checkpoint_saved_3 + 'CP_best_F1_all.pth' ))
371- net_G_A_3 .load_state_dict (torch .load ( checkpoint_saved_3 + 'CP_best_F1_A.pth' ))
372- net_G_V_3 .load_state_dict (torch .load (checkpoint_saved_3 + 'CP_best_F1_V.pth' ))
380+ net_G_3 .load_state_dict (torch .load ( checkpoint_saved_3 + 'CP_best_F1_all.pth' , map_location = device ))
381+ net_G_A_3 .load_state_dict (torch .load ( checkpoint_saved_3 + 'CP_best_F1_A.pth' , map_location = device ))
382+ net_G_V_3 .load_state_dict (torch .load (checkpoint_saved_3 + 'CP_best_F1_V.pth' , map_location = device ))
373383 net_G_3 .eval ()
374384 net_G_A_3 .eval ()
375385 net_G_V_3 .eval ()
376386 net_G_3 .to (device = device )
377387 net_G_A_3 .to (device = device )
378388 net_G_V_3 .to (device = device )
379389
380- net_G_4 .load_state_dict (torch .load ( checkpoint_saved_4 + 'CP_best_F1_all.pth' ))
381- net_G_A_4 .load_state_dict (torch .load ( checkpoint_saved_4 + 'CP_best_F1_A.pth' ))
382- net_G_V_4 .load_state_dict (torch .load (checkpoint_saved_4 + 'CP_best_F1_V.pth' ))
390+ net_G_4 .load_state_dict (torch .load ( checkpoint_saved_4 + 'CP_best_F1_all.pth' , map_location = device ))
391+ net_G_A_4 .load_state_dict (torch .load ( checkpoint_saved_4 + 'CP_best_F1_A.pth' , map_location = device ))
392+ net_G_V_4 .load_state_dict (torch .load (checkpoint_saved_4 + 'CP_best_F1_V.pth' , map_location = device ))
383393 net_G_4 .eval ()
384394 net_G_A_4 .eval ()
385395 net_G_V_4 .eval ()
386396 net_G_4 .to (device = device )
387397 net_G_A_4 .to (device = device )
388398 net_G_V_4 .to (device = device )
389399
390- net_G_5 .load_state_dict (torch .load ( checkpoint_saved_5 + 'CP_best_F1_all.pth' ))
391- net_G_A_5 .load_state_dict (torch .load ( checkpoint_saved_5 + 'CP_best_F1_A.pth' ))
392- net_G_V_5 .load_state_dict (torch .load (checkpoint_saved_5 + 'CP_best_F1_V.pth' ))
400+ net_G_5 .load_state_dict (torch .load ( checkpoint_saved_5 + 'CP_best_F1_all.pth' , map_location = device ))
401+ net_G_A_5 .load_state_dict (torch .load ( checkpoint_saved_5 + 'CP_best_F1_A.pth' , map_location = device ))
402+ net_G_V_5 .load_state_dict (torch .load (checkpoint_saved_5 + 'CP_best_F1_V.pth' , map_location = device ))
393403 net_G_5 .eval ()
394404 net_G_A_5 .eval ()
395405 net_G_V_5 .eval ()
396406 net_G_5 .to (device = device )
397407 net_G_A_5 .to (device = device )
398408 net_G_V_5 .to (device = device )
399409
400- net_G_6 .load_state_dict (torch .load ( checkpoint_saved_6 + 'CP_best_F1_all.pth' ))
401- net_G_A_6 .load_state_dict (torch .load ( checkpoint_saved_6 + 'CP_best_F1_A.pth' ))
402- net_G_V_6 .load_state_dict (torch .load (checkpoint_saved_6 + 'CP_best_F1_V.pth' ))
410+ net_G_6 .load_state_dict (torch .load ( checkpoint_saved_6 + 'CP_best_F1_all.pth' , map_location = device ))
411+ net_G_A_6 .load_state_dict (torch .load ( checkpoint_saved_6 + 'CP_best_F1_A.pth' , map_location = device ))
412+ net_G_V_6 .load_state_dict (torch .load (checkpoint_saved_6 + 'CP_best_F1_V.pth' , map_location = device ))
403413 net_G_6 .eval ()
404414 net_G_A_6 .eval ()
405415 net_G_V_6 .eval ()
406416 net_G_6 .to (device = device )
407417 net_G_A_6 .to (device = device )
408418 net_G_V_6 .to (device = device )
409419
410- net_G_7 .load_state_dict (torch .load ( checkpoint_saved_7 + 'CP_best_F1_all.pth' ))
411- net_G_A_7 .load_state_dict (torch .load ( checkpoint_saved_7 + 'CP_best_F1_A.pth' ))
412- net_G_V_7 .load_state_dict (torch .load (checkpoint_saved_7 + 'CP_best_F1_V.pth' ))
420+ net_G_7 .load_state_dict (torch .load ( checkpoint_saved_7 + 'CP_best_F1_all.pth' , map_location = device ))
421+ net_G_A_7 .load_state_dict (torch .load ( checkpoint_saved_7 + 'CP_best_F1_A.pth' , map_location = device ))
422+ net_G_V_7 .load_state_dict (torch .load (checkpoint_saved_7 + 'CP_best_F1_V.pth' , map_location = device ))
413423 net_G_7 .eval ()
414424 net_G_A_7 .eval ()
415425 net_G_V_7 .eval ()
416426 net_G_7 .to (device = device )
417427 net_G_A_7 .to (device = device )
418428 net_G_V_7 .to (device = device )
419429
420- net_G_8 .load_state_dict (torch .load ( checkpoint_saved_8 + 'CP_best_F1_all.pth' ))
421- net_G_A_8 .load_state_dict (torch .load ( checkpoint_saved_8 + 'CP_best_F1_A.pth' ))
422- net_G_V_8 .load_state_dict (torch .load (checkpoint_saved_8 + 'CP_best_F1_V.pth' ))
430+ net_G_8 .load_state_dict (torch .load ( checkpoint_saved_8 + 'CP_best_F1_all.pth' , map_location = device ))
431+ net_G_A_8 .load_state_dict (torch .load ( checkpoint_saved_8 + 'CP_best_F1_A.pth' , map_location = device ))
432+ net_G_V_8 .load_state_dict (torch .load (checkpoint_saved_8 + 'CP_best_F1_V.pth' , map_location = device ))
423433 net_G_8 .eval ()
424434 net_G_A_8 .eval ()
425435 net_G_V_8 .eval ()
0 commit comments