@@ -34,47 +34,86 @@ def upload_to_drive(local_path, folder_id):
3434 f = DRIVE .files ().create (body = meta , media_body = media , fields = "id" ).execute ()
3535 return f ["id" ]
3636
37+ def run_cmd (cmd , cwd = "/app" ):
38+ result = subprocess .run (cmd , capture_output = True , text = True , cwd = cwd )
39+ if result .returncode != 0 :
40+ raise RuntimeError (result .stderr [- 3000 :])
41+ return result
42+
3743def handler (job ):
3844 inp = job ["input" ]
39- source_file_id = inp .get ("source_file_id" )
40- hint_file_id = inp .get ("hint_file_id" )
45+
46+ source_file_id = inp .get ("source_file_id" ) # Drive file ID of green screen clip/frame
47+ hint_file_id = inp .get ("hint_file_id" ) # Optional: provide your own alpha hint
4148 output_folder_id = inp .get ("output_folder_id" , DRIVE_FOLDER_ID )
49+ hint_method = inp .get ("hint_method" , "gvm" ) # "gvm", "videomama", or "manual"
4250 device = inp .get ("device" , "cuda" )
4351 despill = inp .get ("despill_strength" , 0 )
44- src_ext = inp .get ("source_ext" , "png" )
45- hint_ext = inp .get ("hint_ext" , "png" )
52+ src_ext = inp .get ("source_ext" , "mp4" )
4653
47- if not source_file_id or not hint_file_id :
48- return {"error" : "source_file_id and hint_file_id are required" }
54+ if not source_file_id :
55+ return {"error" : "source_file_id is required" }
56+ if hint_method == "manual" and not hint_file_id :
57+ return {"error" : "hint_file_id is required when hint_method is manual" }
4958
5059 with tempfile .TemporaryDirectory () as tmpdir :
5160 clips_dir = Path (tmpdir ) / "clips"
5261 shot_dir = clips_dir / "shot"
5362 out_dir = Path (tmpdir ) / "output"
54- (shot_dir / "Input" ).mkdir (parents = True )
55- (shot_dir / "AlphaHint" ).mkdir (parents = True )
63+ input_dir = shot_dir / "Input"
64+ hint_dir = shot_dir / "AlphaHint"
65+ input_dir .mkdir (parents = True )
66+ hint_dir .mkdir (parents = True )
5667 out_dir .mkdir ()
5768
58- runpod .serverless .progress_update (job , "Downloading from Drive..." )
59- download_from_drive (source_file_id , shot_dir / "Input" / f"frame.{ src_ext } " )
60- download_from_drive (hint_file_id , shot_dir / "AlphaHint" / f"frame.{ hint_ext } " )
69+ # Download source footage from Drive
70+ runpod .serverless .progress_update (job , "Downloading source from Drive..." )
71+ src_path = input_dir / f"source.{ src_ext } "
72+ download_from_drive (source_file_id , src_path )
6173
62- runpod .serverless .progress_update (job , "Running CorridorKey inference..." )
63- result = subprocess .run (
64- [
65- "/app/.venv/bin/python" , "corridorkey_cli.py" ,
66- "--action" , "run_inference" ,
74+ # Generate or download alpha hint
75+ if hint_method == "manual" :
76+ runpod .serverless .progress_update (job , "Downloading alpha hint from Drive..." )
77+ download_from_drive (hint_file_id , hint_dir / f"source.png" )
78+
79+ elif hint_method == "gvm" :
80+ runpod .serverless .progress_update (job , "Generating alpha hint with GVM..." )
81+ run_cmd ([
82+ "/app/.venv/bin/python" , "clip_manager.py" ,
83+ "--action" , "generate_hints" ,
6784 "--clips_dir" , str (clips_dir ),
68- "--output_dir " , str ( out_dir ) ,
85+ "--hint_method " , "gvm" ,
6986 "--device" , device ,
70- "--despill_strength" , str (despill ),
71- ],
72- capture_output = True , text = True , cwd = "/app" ,
73- )
87+ ])
7488
75- if result .returncode != 0 :
76- return {"error" : "Inference failed" , "stderr" : result .stderr [- 3000 :]}
89+ elif hint_method == "videomama" :
90+ runpod .serverless .progress_update (job , "Generating alpha hint with VideoMaMa..." )
91+ # VideoMaMa requires a rough mask hint — user must supply it
92+ if not hint_file_id :
93+ return {"error" : "VideoMaMa requires a hint_file_id (rough mask) in the VideoMamaMaskHint folder" }
94+ mask_dir = shot_dir / "VideoMamaMaskHint"
95+ mask_dir .mkdir (parents = True )
96+ download_from_drive (hint_file_id , mask_dir / f"source.png" )
97+ run_cmd ([
98+ "/app/.venv/bin/python" , "clip_manager.py" ,
99+ "--action" , "generate_hints" ,
100+ "--clips_dir" , str (clips_dir ),
101+ "--hint_method" , "videomama" ,
102+ "--device" , device ,
103+ ])
104+
105+ # Run CorridorKey inference
106+ runpod .serverless .progress_update (job , "Running CorridorKey inference..." )
107+ run_cmd ([
108+ "/app/.venv/bin/python" , "corridorkey_cli.py" ,
109+ "--action" , "run_inference" ,
110+ "--clips_dir" , str (clips_dir ),
111+ "--output_dir" , str (out_dir ),
112+ "--device" , device ,
113+ "--despill_strength" , str (despill ),
114+ ])
77115
116+ # Upload outputs to Drive
78117 runpod .serverless .progress_update (job , "Uploading results to Drive..." )
79118 uploaded = {}
80119 for folder in ["Processed" , "Matte" , "FG" , "Comp" ]:
0 commit comments