Skip to content

Commit ea57950

Browse files
authored
Update handler.py
1 parent d3213a0 commit ea57950

File tree

1 file changed

+62
-23
lines changed

1 file changed

+62
-23
lines changed

handler.py

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,47 +34,86 @@ def upload_to_drive(local_path, folder_id):
3434
f = DRIVE.files().create(body=meta, media_body=media, fields="id").execute()
3535
return f["id"]
3636

37+
def run_cmd(cmd, cwd="/app"):
38+
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
39+
if result.returncode != 0:
40+
raise RuntimeError(result.stderr[-3000:])
41+
return result
42+
3743
def handler(job):
3844
inp = job["input"]
39-
source_file_id = inp.get("source_file_id")
40-
hint_file_id = inp.get("hint_file_id")
45+
46+
source_file_id = inp.get("source_file_id") # Drive file ID of green screen clip/frame
47+
hint_file_id = inp.get("hint_file_id") # Optional: provide your own alpha hint
4148
output_folder_id = inp.get("output_folder_id", DRIVE_FOLDER_ID)
49+
hint_method = inp.get("hint_method", "gvm") # "gvm", "videomama", or "manual"
4250
device = inp.get("device", "cuda")
4351
despill = inp.get("despill_strength", 0)
44-
src_ext = inp.get("source_ext", "png")
45-
hint_ext = inp.get("hint_ext", "png")
52+
src_ext = inp.get("source_ext", "mp4")
4653

47-
if not source_file_id or not hint_file_id:
48-
return {"error": "source_file_id and hint_file_id are required"}
54+
if not source_file_id:
55+
return {"error": "source_file_id is required"}
56+
if hint_method == "manual" and not hint_file_id:
57+
return {"error": "hint_file_id is required when hint_method is manual"}
4958

5059
with tempfile.TemporaryDirectory() as tmpdir:
5160
clips_dir = Path(tmpdir) / "clips"
5261
shot_dir = clips_dir / "shot"
5362
out_dir = Path(tmpdir) / "output"
54-
(shot_dir / "Input").mkdir(parents=True)
55-
(shot_dir / "AlphaHint").mkdir(parents=True)
63+
input_dir = shot_dir / "Input"
64+
hint_dir = shot_dir / "AlphaHint"
65+
input_dir.mkdir(parents=True)
66+
hint_dir.mkdir(parents=True)
5667
out_dir.mkdir()
5768

58-
runpod.serverless.progress_update(job, "Downloading from Drive...")
59-
download_from_drive(source_file_id, shot_dir / "Input" / f"frame.{src_ext}")
60-
download_from_drive(hint_file_id, shot_dir / "AlphaHint" / f"frame.{hint_ext}")
69+
# Download source footage from Drive
70+
runpod.serverless.progress_update(job, "Downloading source from Drive...")
71+
src_path = input_dir / f"source.{src_ext}"
72+
download_from_drive(source_file_id, src_path)
6173

62-
runpod.serverless.progress_update(job, "Running CorridorKey inference...")
63-
result = subprocess.run(
64-
[
65-
"/app/.venv/bin/python", "corridorkey_cli.py",
66-
"--action", "run_inference",
74+
# Generate or download alpha hint
75+
if hint_method == "manual":
76+
runpod.serverless.progress_update(job, "Downloading alpha hint from Drive...")
77+
download_from_drive(hint_file_id, hint_dir / f"source.png")
78+
79+
elif hint_method == "gvm":
80+
runpod.serverless.progress_update(job, "Generating alpha hint with GVM...")
81+
run_cmd([
82+
"/app/.venv/bin/python", "clip_manager.py",
83+
"--action", "generate_hints",
6784
"--clips_dir", str(clips_dir),
68-
"--output_dir", str(out_dir),
85+
"--hint_method", "gvm",
6986
"--device", device,
70-
"--despill_strength", str(despill),
71-
],
72-
capture_output=True, text=True, cwd="/app",
73-
)
87+
])
7488

75-
if result.returncode != 0:
76-
return {"error": "Inference failed", "stderr": result.stderr[-3000:]}
89+
elif hint_method == "videomama":
90+
runpod.serverless.progress_update(job, "Generating alpha hint with VideoMaMa...")
91+
# VideoMaMa requires a rough mask hint — user must supply it
92+
if not hint_file_id:
93+
return {"error": "VideoMaMa requires a hint_file_id (rough mask) in the VideoMamaMaskHint folder"}
94+
mask_dir = shot_dir / "VideoMamaMaskHint"
95+
mask_dir.mkdir(parents=True)
96+
download_from_drive(hint_file_id, mask_dir / f"source.png")
97+
run_cmd([
98+
"/app/.venv/bin/python", "clip_manager.py",
99+
"--action", "generate_hints",
100+
"--clips_dir", str(clips_dir),
101+
"--hint_method", "videomama",
102+
"--device", device,
103+
])
104+
105+
# Run CorridorKey inference
106+
runpod.serverless.progress_update(job, "Running CorridorKey inference...")
107+
run_cmd([
108+
"/app/.venv/bin/python", "corridorkey_cli.py",
109+
"--action", "run_inference",
110+
"--clips_dir", str(clips_dir),
111+
"--output_dir", str(out_dir),
112+
"--device", device,
113+
"--despill_strength", str(despill),
114+
])
77115

116+
# Upload outputs to Drive
78117
runpod.serverless.progress_update(job, "Uploading results to Drive...")
79118
uploaded = {}
80119
for folder in ["Processed", "Matte", "FG", "Comp"]:

0 commit comments

Comments
 (0)