Skip to content

Commit f9c5a15

Browse files
Update core.py
run_makevideo optional outpath and basename. When is called from the api the user can specify the folder and the file name.
1 parent 28d4097 commit f9c5a15

File tree

1 file changed

+20
-95
lines changed

1 file changed

+20
-95
lines changed

src/core.py

Lines changed: 20 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,6 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
209209
inpaint_imgs.append(inputimages[count])
210210
inpaint_depths.append(img_output)
211211

212-
if inp[go.RUN_MAKEVIDEO_API]:
213-
inpaint_imgs.append(inputimages[count])
214-
inpaint_depths.append(img_output)
215-
216212
# applying background masks after depth
217213
if inp[go.GEN_REMBG]:
218214
print('applying background masks')
@@ -342,15 +338,6 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
342338
except Exception as e:
343339
print(f'{str(e)}, some issue with generating inpainted mesh')
344340

345-
if inp[go.RUN_MAKEVIDEO_API]:
346-
try:
347-
video = run_makevideo_api(device, inpaint_imgs, inpaint_depths, inputnames, outpath,
348-
inp[go.RUN_MAKEVIDEO_API],
349-
1, "mp4")
350-
yield 0, 'video', None
351-
except Exception as e:
352-
print(f'{str(e)}, some issue with generating video')
353-
354341
backbone.reload_sd_model()
355342
print("All done.\n")
356343

@@ -591,12 +578,10 @@ def run_3dphoto_videos(mesh_fi, basename, outpath, num_frames, fps, crop_border,
591578
fnExt=vid_format)
592579
return fn_saved
593580

594-
595-
# called from gen vid tab button
596-
def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa):
581+
def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, outpath=None, basename=None):
597582
if len(fn_mesh) == 0 or not os.path.exists(fn_mesh):
598583
raise Exception("Could not open mesh.")
599-
584+
600585
vid_ssaa = int(vid_ssaa)
601586

602587
# traj type
@@ -621,20 +606,24 @@ def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_bord
621606
raise Exception("Crop Border requires 4 elements.")
622607
crop_border = [float(borders[0]), float(borders[1]), float(borders[2]), float(borders[3])]
623608

624-
# output path and filename mess ..
625-
basename = Path(fn_mesh).stem
626-
outpath = backbone.get_outpath()
627-
# unique filename
628-
basecount = backbone.get_next_sequence_number(outpath, basename)
629-
if basecount > 0: basecount = basecount - 1
630-
fullfn = None
631-
for i in range(500):
632-
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
633-
fullfn = os.path.join(outpath, f"{fn}_." + vid_format)
634-
if not os.path.exists(fullfn):
635-
break
636-
basename = Path(fullfn).stem
637-
basename = basename[:-1]
609+
if not outpath:
610+
outpath = backbone.get_outpath()
611+
612+
if not basename:
613+
# output path and filename mess ..
614+
basename = Path(fn_mesh).stem
615+
616+
# unique filename
617+
basecount = backbone.get_next_sequence_number(outpath, basename)
618+
if basecount > 0: basecount = basecount - 1
619+
fullfn = None
620+
for i in range(500):
621+
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
622+
fullfn = os.path.join(outpath, f"{fn}_." + vid_format)
623+
if not os.path.exists(fullfn):
624+
break
625+
basename = Path(fullfn).stem
626+
basename = basename[:-1]
638627

639628
print("Loading mesh ..")
640629

@@ -643,70 +632,6 @@ def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_bord
643632

644633
return fn_saved[-1], fn_saved[-1], ''
645634

646-
def run_makevideo_api(device, inpaint_imgs, inpaint_depths, inputnames, outpath, gen_video_params, vid_ssaa, vid_format):
647-
648-
required_params = ["vid_numframes", "vid_fps", "vid_traj", "vid_shift", "vid_border", "dolly", "vid_format", "vid_ssaa", "output_filename"]
649-
650-
missing_params = [param for param in required_params if param not in gen_video_params]
651-
652-
if missing_params:
653-
raise ValueError(f"Missing required parameter(s): {', '.join(missing_params)}")
654-
655-
mesh_fi_filename = gen_video_params.get('mesh_fi_filename', None)
656-
657-
if mesh_fi_filename and os.path.exists(mesh_fi_filename):
658-
659-
mesh_fi = mesh_fi_filename
660-
print("Loaded existing mesh from: ", mesh_fi)
661-
else:
662-
#If there is no mesh file generate it.
663-
mesh_fi = run_3dphoto(device, inpaint_imgs, inpaint_depths, inputnames, outpath, None, 1, "mp4")
664-
print("Created mesh in: ", mesh_fi)
665-
666-
vid_numframes = gen_video_params["vid_numframes"]
667-
vid_fps = gen_video_params["vid_fps"]
668-
vid_traj = gen_video_params["vid_traj"]
669-
vid_shift = gen_video_params["vid_shift"]
670-
vid_border = gen_video_params["vid_border"]
671-
dolly = gen_video_params["dolly"]
672-
vid_format = gen_video_params["vid_format"]
673-
vid_ssaa = int(gen_video_params["vid_ssaa"])
674-
675-
output_filename = gen_video_params["output_filename"]
676-
output_path = os.path.dirname(output_filename)
677-
basename, extension = os.path.splitext(os.path.basename(output_filename))
678-
679-
# Comparing video_format with the extension
680-
if vid_format != extension[1:]:
681-
raise ValueError(f"Video format '{vid_format}' does not match with the extension '{extension}'.")
682-
683-
# traj type
684-
if vid_traj == 0:
685-
vid_traj = ['straight-line']
686-
elif vid_traj == 1:
687-
vid_traj = ['double-straight-line']
688-
elif vid_traj == 2:
689-
vid_traj = ['circle']
690-
691-
num_fps = int(vid_fps)
692-
num_frames = int(vid_numframes)
693-
shifts = vid_shift.split(',')
694-
if len(shifts) != 3:
695-
raise Exception("Translate requires 3 elements.")
696-
x_shift_range = [float(shifts[0])]
697-
y_shift_range = [float(shifts[1])]
698-
z_shift_range = [float(shifts[2])]
699-
700-
borders = vid_border.split(',')
701-
if len(borders) != 4:
702-
raise Exception("Crop Border requires 4 elements.")
703-
crop_border = [float(borders[0]), float(borders[1]), float(borders[2]), float(borders[3])]
704-
705-
fn_saved = run_3dphoto_videos(mesh_fi, basename, output_path, num_frames, num_fps, crop_border, vid_traj, x_shift_range,
706-
y_shift_range, z_shift_range, [''], dolly, vid_format, vid_ssaa)
707-
708-
return fn_saved[-1]
709-
710635
def unload_models():
711636
model_holder.unload_models()
712637

0 commit comments

Comments
 (0)