forked from leeruibin/RORem
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_RORem_4S.py
More file actions
145 lines (135 loc) · 4.55 KB
/
inference_RORem_4S.py
File metadata and controls
145 lines (135 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from diffusers import AutoPipelineForInpainting
import torch
import os
from diffusers import UNet2DConditionModel
import argparse
from myutils.img_util import dilate_mask
from diffusers import LCMScheduler
from diffusers.utils import load_image
def parse_args():
    """Parse command-line arguments for RORem 4-step removal inference.

    Returns an ``argparse.Namespace`` with the model paths, input/output
    paths, and sampling options. ``--RORem_unet`` and ``--RORem_LoRA``
    are required; everything else has a default.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model",
        type=str,
        default="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--RORem_unet",
        type=str,
        default=None,
        required=True,
        help="Path to pretrain RORem Unet",
    )
    parser.add_argument(
        "--RORem_LoRA",
        type=str,
        default=None,
        required=True,
        # Fixed copy-paste error: this flag loads the LoRA weights, not the UNet.
        help="Path to pretrained RORem LoRA weights",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default=None,
        help="Path to the input image.",
    )
    parser.add_argument(
        "--mask_path",
        type=str,
        default=None,
        help="Path to the mask image.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the removal result.",
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--resolution",
        default=512,
        type=int
    )
    parser.add_argument(
        "--dilate_size",
        default=20,
        type=int,
        help="dilate the mask"
    )
    parser.add_argument(
        "--use_CFG",
        # argparse's type=bool treats any non-empty string as True, so parse
        # the literal text instead: only the string "true" (any case) enables CFG.
        type=lambda x: x.lower() == 'true',
        default=True,
        help="whether to enable CFG, can reduce the artifacts in the mask region in our final test"
    )
    args = parser.parse_args()
    return args
def main(args):
    """Run 4-step RORem object removal on one image/mask pair.

    Loads the SDXL-inpainting pipeline, swaps in the RORem-trained UNet,
    attaches the LCM scheduler and fuses the distillation LoRA (which is
    what enables 4-step sampling), then inpaints the masked region and
    saves the result to ``args.save_path``.

    Requires a CUDA device; runs the pipeline in float16.
    """
    if args.pretrained_model is None:
        pretrain_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    else:
        pretrain_path = args.pretrained_model
    # Step 1: load pretrained SDXL-inpainting model
    pipe_edit = AutoPipelineForInpainting.from_pretrained(
        pretrain_path,
        torch_dtype=torch.float16,
    )
    # Step 2: replace the stock UNet with the RORem-trained one
    unet = UNet2DConditionModel.from_pretrained(args.RORem_unet).to("cuda", dtype=torch.float16)
    print(f"Finish loading unet from {args.RORem_unet}!!")
    pipe_edit.unet = unet
    # Step 3: LCM scheduler + fused distillation LoRA -> 4-step sampling
    pipe_edit.scheduler = LCMScheduler.from_config(pipe_edit.scheduler.config)
    pipe_edit.load_lora_weights(args.RORem_LoRA)
    pipe_edit.fuse_lora()
    pipe_edit.to("cuda")

    height = width = args.resolution
    image_name = os.path.basename(args.image_path)
    if args.save_path is None:
        save_folder = "removal_result"
        os.makedirs(save_folder, exist_ok=True)
        args.save_path = f"{save_folder}/{image_name}"
    else:
        save_folder = os.path.dirname(args.save_path)
        # dirname() is "" when save_path is a bare filename; makedirs("")
        # would raise, so only create a directory when there is one.
        if save_folder:
            os.makedirs(save_folder, exist_ok=True)

    # BUG FIX: original read the non-existent ``args.input_path``
    # (the parser defines ``--image_path``), which raised AttributeError.
    input_image = load_image(args.image_path).resize((args.resolution, args.resolution))
    input_mask = load_image(args.mask_path).resize((args.resolution, args.resolution))
    if args.dilate_size != 0:
        # BUG FIX: original called dilate_mask on the undefined name
        # ``mask_image`` (NameError) and never updated ``input_mask``.
        input_mask = dilate_mask(input_mask, args.dilate_size)

    if not args.use_CFG:
        prompts = ""
        Removal_result = pipe_edit(
            prompt=prompts,
            height=height,
            width=width,
            image=input_image,
            mask_image=input_mask,
            guidance_scale=1.,
            num_inference_steps=4,
            strength=0.99,  # make sure to use `strength` below 1.0
        ).images[0]
    else:
        # we also find by adding these prompt, the model can work even better
        prompts = "4K, high quality, masterpiece, Highly detailed, Sharp focus, Professional, photorealistic, realistic"
        negative_prompts = "low quality, worst, bad proportions, blurry, extra finger, Deformed, disfigured, unclear background"
        Removal_result = pipe_edit(
            prompt=prompts,
            negative_prompt=negative_prompts,
            height=height,
            width=width,
            image=input_image,
            mask_image=input_mask,
            guidance_scale=1.,
            num_inference_steps=4,  # steps between 15 and 30 also work well
            strength=0.99,  # make sure to use `strength` below 1.0
        ).images[0]
    # NOTE(review): ``args.inference_steps`` is parsed but unused — the 4-step
    # distilled sampler is hard-coded above, matching the "4S" script name.
    # BUG FIX: original saved to ``save_folder`` (a directory), not the file path.
    Removal_result.save(args.save_path)
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run a single removal pass.
    main(parse_args())