# lowbit_lowrank.py — forked from FasterDecoding/BitDelta.
import os
import torch
import torch.nn.functional as F
from bitdelta.diff import compress_diff, save_diff, save_full_model
from bitdelta.misc import find_corr_stddev
from bitdelta.utils import get_model, parse_args, get_tokenizer
from tqdm import tqdm
# Script entry: load a base model and a fine-tuned model, then compress their
# weight difference into a separate copy of the fine-tuned model via
# bitdelta.diff.compress_diff.

# Options come from bitdelta.utils.parse_args (presumably the command line).
# The attributes read below select, per model, the checkpoint path, the
# device, and the memory-map setting passed to get_model.
args = parse_args()

# NOTE(review): the tokenizer is not used anywhere below in this script; it is
# kept for parity with the upstream BitDelta scripts. Confirm before removing.
tokenizer = get_tokenizer(args.base_model)

# The base and fine-tuned models are only read as references for the diff, so
# they are loaded without recording autograd history. (The scraped source had
# lost this block's indentation, which made the file a SyntaxError.)
with torch.no_grad():
    base_model = get_model(
        args.base_model,
        args.base_model_device,
        args.base_model_memory_map,
    )
    finetuned_model = get_model(
        args.finetuned_model,
        args.finetuned_model_device,
        args.finetuned_model_memory_map,
    )

# A second, independent copy of the fine-tuned checkpoint receives the
# compressed result. It may live on its own device and memory map.
# NOTE(review): compress_diff presumably rewrites this model's modules in place,
# and it is loaded outside no_grad as in upstream BitDelta (where the compressed
# copy's parameters are later trained). Confirm against bitdelta.diff.
finetuned_compressed_model = get_model(
    args.finetuned_model,
    args.finetuned_compressed_model_device,
    args.finetuned_compressed_model_memory_map,
)

print("compressing diff...")
compress_diff(base_model, finetuned_model, finetuned_compressed_model)