-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathinference.py
More file actions
63 lines (58 loc) · 2.47 KB
/
inference.py
File metadata and controls
63 lines (58 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import timm
assert timm.__version__ == "0.3.2" # version check for timm
from ops.argparser import argparser_infer
from ops.file_format_convert import convert_to_pkl
def main(args):
    """Run one HiCFoundation inference job described by ``args``.

    The input file is converted to ``.pkl`` inside the output directory,
    smoothed first when ``args.task == 1`` (reproducibility analysis), and
    the resulting pickle path is handed to
    ``inference.main_worker.main_worker`` together with ``args``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``output``, ``input``, ``resolution`` and ``task``; the whole
            namespace is forwarded unchanged to ``main_worker``.
    """
    import socket

    # The local IP is printed for diagnostics only. A host name that does
    # not resolve (socket.gaierror is an OSError subclass) must not abort
    # the whole inference run, so fall back to a placeholder instead.
    try:
        local_ip = socket.gethostbyname(socket.gethostname())
    except OSError as err:
        local_ip = "unavailable ({})".format(err)
    print("local ip: ", local_ip)

    # Format processing: convert the supported input formats to .pkl for
    # further processing. The output directory is created if missing.
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir, exist_ok=True)
    input_file = os.path.abspath(args.input)
    config_resolution = args.resolution
    input_pkl = convert_to_pkl(input_file, output_dir, config_resolution)

    # For reproducibility analysis the matrix is smoothed before embeddings
    # are generated; the smoothed pickle replaces the converted one.
    if args.task == 1:
        from ops.smooth_matrix import smooth_pkl

        smooth_pkl_file = os.path.join(output_dir, "input_smoothed.pkl")
        input_pkl = smooth_pkl(input_pkl, smooth_pkl_file)
        print("Reproducibility analysis smoothed input matrix saved to ", input_pkl)

    # Imported here rather than at module top, as in the original script.
    from inference.main_worker import main_worker

    main_worker(args, input_pkl)
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, validate them, then
    # hand over to main(). Every validation failure exits with status 1.
    #
    # sys.exit is used instead of the bare `exit` helper: `exit` is added by
    # the `site` module for interactive use and is not available when the
    # interpreter runs without `site` (e.g. `python -S`).
    import sys

    print("HiCFoundation inference started!")
    parser = argparser_infer()
    args = parser.parse_args()

    # Restrict the visible GPUs before main() runs.
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Banner printed for each supported --task value.
    task_names = {
        1: "Reproducibility analysis",
        2: "Loop calling",
        3: "Resolution enhancement",
        4: "Epigenomic assay prediction",
        5: "scHi-C enhancement",
        6: "Hi-C embedding generation",
    }
    task_name = task_names.get(args.task)
    if task_name is None:
        print("Unknown task specified ", args.task)
        print("Please specify the task using --task with 1,2,3,4,5,6")
        sys.exit(1)
    print(task_name)

    # Embedding generation additionally needs a valid layer index.
    if args.task == 6:
        embed_depth = args.embed_depth
        if embed_depth > 8:
            print("Error: embed_depth is larger than 8, that is beyond decoder depth. Please set embed_depth<=8")
            print("0 indicates the encoder output, k indicates the k-th decoder layer's output")
            sys.exit(1)

    # Check the specified input size: both dimensions must be a multiple of
    # args.patch_size.
    if args.input_row_size % args.patch_size != 0 or args.input_col_size % args.patch_size != 0:
        print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
        sys.exit(1)

    # All checks passed: run the inference pipeline.
    main(args)