diff --git a/scripts/depthmap_api.py b/scripts/depthmap_api.py index 9c2d59d..5767fd0 100644 --- a/scripts/depthmap_api.py +++ b/scripts/depthmap_api.py @@ -1,7 +1,8 @@ -# Non-public API. Don't host publicly - SECURITY RISKS! -# (will only be on with --api starting option) -# Currently no API stability guarantees are provided - API may break on any new commit. +# DO NOT HOST PUBLICLY - SECURITY RISKS! +# (the API will only be on with --api starting option) +# Currently no API stability guarantees are provided - API may break on any new commit (but hopefully won't). +import os import numpy as np from fastapi import FastAPI, Body from fastapi.exceptions import HTTPException @@ -12,7 +13,7 @@ from typing import Dict, List from modules.api import api -from src.core import core_generation_funnel +from src.core import core_generation_funnel, run_makevideo from src.misc import SCRIPT_VERSION from src import backbone from src.common_constants import GenerationOptions as go @@ -70,12 +71,110 @@ async def process( if not isinstance(result, Image.Image): continue results_based += [encode_to_base64(result)] + return {"images": results_based, "info": "Success"} + @app.post("/depth/generate/video") + async def process_video( + depth_input_images: List[str] = Body([], title='Input Images'), + options: Dict[str, object] = Body("options", title='Generation options'), + ): + if len(depth_input_images) == 0: + raise HTTPException(status_code=422, detail="No images supplied") + print(f"Processing {str(len(depth_input_images))} images trough the API") + + available_models = { + 'res101': 0, + 'dpt_beit_large_512': 1, #midas 3.1 + 'dpt_beit_large_384': 2, #midas 3.1 + 'dpt_large_384': 3, #midas 3.0 + 'dpt_hybrid_384': 4, #midas 3.0 + 'midas_v21': 5, + 'midas_v21_small': 6, + 'zoedepth_n': 7, #indoor + 'zoedepth_k': 8, #outdoor + 'zoedepth_nk': 9, + } + + model_type = options["model_type"] + + model_id = None + if isinstance(model_type, str): + # Check if the string is in the available_models dictionary + if model_type in available_models: + model_id = available_models[model_type] + else: + available_strings = list(available_models.keys()) + raise HTTPException(status_code=400, detail={'error': 'Invalid model string', 'available_models': available_strings}) + elif isinstance(model_type, int): + model_id = model_type + else: + raise HTTPException(status_code=400, detail={'error': 'Invalid model parameter type'}) + + options["model_type"] = model_id + + video_parameters = options["video_parameters"] + + required_params = ["vid_numframes", "vid_fps", "vid_traj", "vid_shift", "vid_border", "dolly", "vid_format", "vid_ssaa", "output_filename"] + + missing_params = [param for param in required_params if param not in video_parameters] + + if missing_params: + raise HTTPException(status_code=400, detail={'error': f"Missing required parameter(s): {', '.join(missing_params)}"}) + + vid_numframes = video_parameters["vid_numframes"] + vid_fps = video_parameters["vid_fps"] + vid_traj = video_parameters["vid_traj"] + vid_shift = video_parameters["vid_shift"] + vid_border = video_parameters["vid_border"] + dolly = video_parameters["dolly"] + vid_format = video_parameters["vid_format"] + vid_ssaa = int(video_parameters["vid_ssaa"]) + + output_filename = video_parameters["output_filename"] + output_path = os.path.dirname(output_filename) + basename, extension = os.path.splitext(os.path.basename(output_filename)) + + # Comparing video_format with the extension + if vid_format != extension[1:]: + raise HTTPException(status_code=400, detail={'error': f"Video format '{vid_format}' does not match with the extension '{extension}'."}) + + pil_images = [] + for input_image in depth_input_images: + pil_images.append(to_base64_PIL(input_image)) + outpath = backbone.get_outpath() + + mesh_fi_filename = video_parameters.get('mesh_fi_filename', None) + + if mesh_fi_filename and os.path.exists(mesh_fi_filename): + mesh_fi = mesh_fi_filename + print("Loaded existing mesh from: ", mesh_fi) + else: + # If there is no mesh file generate it. + options["GEN_INPAINTED_MESH"] = True + + gen_obj = core_generation_funnel(outpath, pil_images, None, None, options) + + mesh_fi = None + for count, type, result in gen_obj: + if type == 'inpainted_mesh': + mesh_fi = result + break + + if mesh_fi: + print("Created mesh in: ", mesh_fi) + else: + raise HTTPException(status_code=400, detail={'error': "The mesh has not been created"}) + + run_makevideo(mesh_fi, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, output_path, basename) + + return {"info": "Success"} + try: import modules.script_callbacks as script_callbacks if backbone.get_cmd_opt('api', False): script_callbacks.on_app_started(depth_api) + print("Started the depthmap API. DO NOT HOST PUBLICLY - SECURITY RISKS!") except: print('DepthMap API could not start') diff --git a/src/core.py b/src/core.py index ee6a5dc..12621bd 100644 --- a/src/core.py +++ b/src/core.py @@ -578,9 +578,8 @@ def run_3dphoto_videos(mesh_fi, basename, outpath, num_frames, fps, crop_border, fnExt=vid_format) return fn_saved - -# called from gen vid tab button -def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa): +def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, + outpath=None, basename=None): if len(fn_mesh) == 0 or not os.path.exists(fn_mesh): raise Exception("Could not open mesh.") @@ -608,20 +607,24 @@ def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_bord raise Exception("Crop Border requires 4 elements.") crop_border = [float(borders[0]), float(borders[1]), float(borders[2]), float(borders[3])] - # output path and filename mess .. - basename = Path(fn_mesh).stem - outpath = backbone.get_outpath() - # unique filename - basecount = backbone.get_next_sequence_number(outpath, basename) - if basecount > 0: basecount = basecount - 1 - fullfn = None - for i in range(500): - fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" - fullfn = os.path.join(outpath, f"{fn}_." + vid_format) - if not os.path.exists(fullfn): - break - basename = Path(fullfn).stem - basename = basename[:-1] + if not outpath: + outpath = backbone.get_outpath() + + if not basename: + # output path and filename mess .. + basename = Path(fn_mesh).stem + + # unique filename + basecount = backbone.get_next_sequence_number(outpath, basename) + if basecount > 0: basecount = basecount - 1 + fullfn = None + for i in range(500): + fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" + fullfn = os.path.join(outpath, f"{fn}_." + vid_format) + if not os.path.exists(fullfn): + break + basename = Path(fullfn).stem + basename = basename[:-1] print("Loading mesh ..") @@ -630,7 +633,6 @@ def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_bord return fn_saved[-1], fn_saved[-1], '' - def unload_models(): model_holder.unload_models()