-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathtorch_save_tensor.py
More file actions
90 lines (76 loc) · 2.99 KB
/
torch_save_tensor.py
File metadata and controls
90 lines (76 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import time
import argparse
import torch
import os
from torch_save_utils import PINNED_BUFFER_MB, load_io_ops
from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save
import deepspeed
from deepspeed.accelerator import get_accelerator
import os
def run(args):
device = get_accelerator().current_device_name() if args.gpu else 'cpu'
buffer = torch.randint(high=128,
size=(args.mb_size * (1024**2), ),
dtype=torch.uint8,
device=device)
fn_dict = {
'test_save': test_save,
'test_ds_mock_save': test_ds_mock_save,
'test_ds_py_save': test_ds_py_save,
'test_ds_gds_fast_save': test_ds_gds_fast_save,
'test_ds_aio_fast_save': test_ds_aio_fast_save,
}
for tag, fn in fn_dict.items():
if tag == 'test_ds_gds_fast_save' and not args.gpu:
continue
file = os.path.join(args.folder, f'{tag}_{args.mb_size}MB.pt')
print(f'checkpoint file = {file}')
st = time.time()
write_sec = fn(file, buffer, args)
gb_per_sec = args.mb_size / (1024.0 * write_sec)
gb_size = os.path.getsize(file) / (1024**3)
print(
f'{tag} -- {gb_size:5.2f} GB, {write_sec:5.2f} secs, {gb_per_sec:5.2f} GB/s'
)
print(f'*********************************************')
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--folder',
default=None,
type=str,
required=True,
help='Folder to use for I/O.')
parser.add_argument('--mb_size',
type=int,
default=None,
required=True,
help='Size of tensor to save in MB.')
parser.add_argument('--zipfile',
action='store_true',
help='Use torch zipfile save format')
parser.add_argument('--gpu', action='store_true', help='Use gpu tensors.')
parser.add_argument('--io_buffer_mb',
type=int,
default=PINNED_BUFFER_MB,
help='Size of pinned i/o buffer in MB.')
parser.add_argument('--no-statistics',
action='store_true',
help='Suppress low-level performance statistics.')
parser.add_argument('--single_io_buffer',
action='store_true',
help='Disable double buffering of i/o buffer.')
args = parser.parse_args()
print(f'args = {args}')
return args
def main():
print(
f'Performance test of torch.save() integration of fast tensor checkpointing.'
)
args = parse_arguments()
if not os.path.exists(args.folder):
print(f'Invalid folder: {args.folder}')
quit()
load_io_ops(args)
run(args)
if __name__ == "__main__":
main()