diff --git a/main.py b/main.py index a05cb2d..67f8c5d 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,7 @@ import argparse import os import pathlib - +import time import src.misc @@ -29,6 +29,7 @@ def maybe_chdir(): parser = argparse.ArgumentParser() parser.add_argument("--share", help="Create public link", action='store_true') parser.add_argument("--listen", help="Create public link", action='store_true') + parser.add_argument("--api", help="Start-up api", action='store_true') parser.add_argument("--no_chdir", help="Do not try to use the root of stable-diffusion-webui", action='store_true') args = parser.parse_args() @@ -37,4 +38,15 @@ def maybe_chdir(): maybe_chdir() server_name = "0.0.0.0" if args.listen else None import src.common_ui - src.common_ui.on_ui_tabs().launch(share=args.share, server_name=server_name) + + ui_block = src.common_ui.on_ui_tabs() + + if not args.api: + ui_block.launch(share=args.share, server_name=server_name) + else: + app, _, _ = ui_block.launch(share=args.share, server_name=server_name, prevent_thread_lock=True) + print(f"starting depth api") + from src.api.api_standalone import init_api + init_api(ui_block, app) + while True: + time.sleep(0.1) \ No newline at end of file diff --git a/scripts/depthmap.py b/scripts/depthmap.py index 7dbdbb2..461538f 100644 --- a/scripts/depthmap.py +++ b/scripts/depthmap.py @@ -96,3 +96,14 @@ def add_option(name, default_value, description, name_prefix='depthmap_script'): from modules import script_callbacks script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_tabs(lambda: [(common_ui.on_ui_tabs(), "Depth", "depthmap_interface")]) + +# API script +from src.api import api_routes + +try: + import modules.script_callbacks as script_callbacks + if backbone.get_cmd_opt('api', False): + script_callbacks.on_app_started(api_routes.depth_api) +except: + print('DepthMap API could not start') + diff --git a/scripts/depthmap_api.py b/scripts/depthmap_api.py deleted file mode 100644 index 
f731dea..0000000 --- a/scripts/depthmap_api.py +++ /dev/null @@ -1,81 +0,0 @@ -# Non-public API. Don't host publicly - SECURITY RISKS! -# (will only be on with --api starting option) -# Currently no API stability guarantees are provided - API may break on any new commit. - -import numpy as np -from fastapi import FastAPI, Body -from fastapi.exceptions import HTTPException -from PIL import Image - -import gradio as gr - -from modules.api.models import List, Dict -from modules.api import api - -from src.core import core_generation_funnel -from src.misc import SCRIPT_VERSION -from src import backbone -from src.common_constants import GenerationOptions as go - - -def encode_to_base64(image): - if type(image) is str: - return image - elif type(image) is Image.Image: - return api.encode_pil_to_base64(image) - elif type(image) is np.ndarray: - return encode_np_to_base64(image) - else: - return "" - - -def encode_np_to_base64(image): - pil = Image.fromarray(image) - return api.encode_pil_to_base64(pil) - - -def to_base64_PIL(encoding: str): - return Image.fromarray(np.array(api.decode_base64_to_image(encoding)).astype('uint8')) - - -def depth_api(_: gr.Blocks, app: FastAPI): - @app.get("/depth/version") - async def version(): - return {"version": SCRIPT_VERSION} - - @app.get("/depth/get_options") - async def get_options(): - return {"options": sorted([x.name.lower() for x in go])} - - # TODO: some potential inputs not supported (like custom depthmaps) - @app.post("/depth/generate") - async def process( - depth_input_images: List[str] = Body([], title='Input Images'), - options: Dict[str, object] = Body("options", title='Generation options'), - ): - # TODO: restrict mesh options - - if len(depth_input_images) == 0: - raise HTTPException(status_code=422, detail="No images supplied") - print(f"Processing {str(len(depth_input_images))} images trough the API") - - pil_images = [] - for input_image in depth_input_images: - pil_images.append(to_base64_PIL(input_image)) - outpath 
= backbone.get_outpath() - gen_obj = core_generation_funnel(outpath, pil_images, None, None, options) - - results_based = [] - for count, type, result in gen_obj: - if not isinstance(result, Image.Image): - continue - results_based += [encode_to_base64(result)] - return {"images": results_based, "info": "Success"} - - -try: - import modules.script_callbacks as script_callbacks - if backbone.get_cmd_opt('api', False): - script_callbacks.on_app_started(depth_api) -except: - print('DepthMap API could not start') diff --git a/src/api/api_constants.py b/src/api/api_constants.py new file mode 100644 index 0000000..ff33784 --- /dev/null +++ b/src/api/api_constants.py @@ -0,0 +1,32 @@ + +api_options = { + #'outputs': ["depth"], # list of outputs to send in response. examples ["depth", "normalmap", 'heatmap', "normal", 'background_removed'] etc + #'conversions': "", #TODO implement. it's a good idea to give some options serverside for because often that's challenging in js/clientside + 'save':"" #TODO implement. To save on local machine. Can be very helpful for debugging. 
+} + +# TODO: These are intended to be temporary +api_defaults={ + "BOOST": False, + "NET_SIZE_MATCH": True +} + +#These are enforced after user inputs +api_forced={ + "GEN_SIMPLE_MESH": False, + "GEN_INPAINTED_MESH": False +} + +#model diction TODO find a way to remove without forcing people do know indexes of models +models_to_index = { + 'res101':0, + 'dpt_beit_large_512 (midas 3.1)':1, + 'dpt_beit_large_384 (midas 3.1)':2, + 'dpt_large_384 (midas 3.0)':3, + 'dpt_hybrid_384 (midas 3.0)':4, + 'midas_v21':5, + 'midas_v21_small':6, + 'zoedepth_n (indoor)':7, + 'zoedepth_k (outdoor)':8, + 'zoedepth_nk':9 +} \ No newline at end of file diff --git a/src/api/api_core.py b/src/api/api_core.py new file mode 100644 index 0000000..06faccb --- /dev/null +++ b/src/api/api_core.py @@ -0,0 +1,81 @@ +import numpy as np +from PIL import PngImagePlugin, Image +import base64 +from io import BytesIO +from fastapi.exceptions import HTTPException + +from src.core import core_generation_funnel +from src import backbone +from src.api.api_constants import api_defaults, api_forced, models_to_index + +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid encoded image") from e + +# TODO check that internally we always use png. 
def encode_pil_to_base64(image, image_type='png'):
    """Serialize a PIL image to base64-encoded PNG bytes.

    String-valued entries of ``image.info`` are copied into PNG text chunks
    so that textual metadata survives the round trip.

    Args:
        image: PIL image to serialize.
        image_type: Output format; only ``'png'`` is supported.

    Returns:
        ``bytes`` containing the base64 encoding of the PNG file.

    Raises:
        HTTPException: status 500 when ``image_type`` is not ``'png'``.
    """
    if image_type != 'png':
        raise HTTPException(status_code=500, detail="Invalid image format")

    metadata = PngImagePlugin.PngInfo()
    use_metadata = False
    for key, value in image.info.items():
        # PNG text chunks can only carry str -> str pairs.
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    with BytesIO() as output_bytes:
        image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None))
        return base64.b64encode(output_bytes.getvalue())


def encode_to_base64(image):
    """Encode *image* to base64, dispatching on its type.

    Args:
        image: A ``str`` (returned unchanged), a PIL image, or a numpy array.

    Returns:
        The input string, the base64 PNG bytes, or ``""`` for any other type.
    """
    # isinstance (not ``type(x) is``) so that PIL subclasses, e.g. the
    # PngImageFile objects produced by Image.open, are encoded instead of
    # silently turning into "".
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""


def encode_np_to_base64(image):
    """Convert a numpy array to a PIL image and return its base64 PNG bytes."""
    return encode_pil_to_base64(Image.fromarray(image))


def to_base64_PIL(encoding: str):
    """Decode a base64 string into a PIL image backed by a uint8 array.

    Despite the name, the direction is base64 -> PIL.
    """
    decoded = decode_base64_to_image(encoding)
    return Image.fromarray(np.array(decoded).astype('uint8'))


def api_gen(input_images, client_options):
    """Run depth generation for an API request.

    Options are layered in this order: ``api_defaults``, then the client's
    options (keys upper-cased; a ``model_type`` name is translated through
    ``models_to_index``), then ``api_forced``, which always wins.

    Args:
        input_images: Base64-encoded input images.
        client_options: Mapping of generation option name to value.

    Returns:
        The object returned by ``core_generation_funnel``.

    Raises:
        HTTPException: status 422 when ``client_options`` is not a mapping or
            ``model_type`` is not a known model name.
    """
    if not isinstance(client_options, dict):
        raise HTTPException(status_code=422, detail="Generation options must be an object")

    options = api_defaults.copy()

    for key, value in client_options.items():
        if key == "model_type":
            try:
                value = models_to_index[value]
            except (KeyError, TypeError) as e:
                # Unknown name (KeyError) or unhashable JSON value (TypeError):
                # this is a client error, not a server fault.
                raise HTTPException(status_code=422, detail=f"Unknown model_type: {value!r}") from e
        options[key.upper()] = value

    # Forced options are applied last so clients cannot override them.
    options.update(api_forced)

    print(f"Processing {str(len(input_images))} images through the API")

    print(options)

    pil_images = [to_base64_PIL(input_image) for input_image in input_images]
    outpath = backbone.get_outpath()
    return core_generation_funnel(outpath, pil_images, None, None, options)


# --- src/api/api_routes.py (new file) ---
# Non-public API. Don't host publicly - SECURITY RISKS!
# (will only be on with --api starting option)
# Currently no API stability guarantees are provided - API may break on any new commit.

from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException

from typing import Dict, List
from PIL import Image

from src.common_constants import GenerationOptions as go
from src.misc import SCRIPT_VERSION
from src.api.api_constants import api_options, models_to_index
from src.api.api_core import api_gen, encode_to_base64


# _ parameter is needed for auto1111 extensions (_ is type gr.Blocks)
def depth_api(_, app: FastAPI):
    """Register the depth endpoints on *app*.

    Routes added:
        GET  /depth/version      -> ``{"version": SCRIPT_VERSION}``
        GET  /depth/get_options  -> available generation options, API options
                                    and model names
        POST /depth/generate     -> base64-encoded result images

    Args:
        _: Unused; present so the function matches the callback signature
            used by auto1111 extensions.
        app: FastAPI application to attach the routes to.
    """
    @app.get("/depth/version")
    async def version():
        return {"version": SCRIPT_VERSION}

    @app.get("/depth/get_options")
    async def get_options():
        return {
            "gen_options": [x.name.lower() for x in go],
            "api_options": list(api_options.keys()),
            "model_names": list(models_to_index.keys())
        }

    @app.post("/depth/generate")
    async def process(
        input_images: List[str] = Body([], title='Input Images'),
        generate_options: Dict[str, object] = Body({}, title='Generation options', options= [x.name.lower() for x in go]),
        # NOTE: this parameter shadows the module-level ``api_options`` inside
        # ``process``. The name is part of the request body schema, so it is
        # kept. The default expression still sees the module-level dict
        # because defaults are evaluated in the enclosing scope.
        api_options: Dict[str, object] = Body({}, title='Api options', options= api_options)
    ):
        if not input_images:
            raise HTTPException(status_code=422, detail="No images supplied")

        gen_obj = api_gen(input_images, generate_options)

        # Each item is a 3-tuple; only PIL image results are returned to the
        # client. Names avoid shadowing the ``type`` builtin.
        results_based = [
            encode_to_base64(result)
            for _count, _kind, result in gen_obj
            if isinstance(result, Image.Image)
        ]
        return {"images": results_based, "info": "Success"}


# --- src/api/api_standalone.py (new file) ---
from fastapi import FastAPI
import uvicorn
from src.api.api_routes import depth_api


# without gradio
def init_api_no_webui():
    """Create a bare FastAPI app, register the depth routes and serve it.

    Blocks in ``uvicorn.run`` while serving on 127.0.0.1:7860.
    """
    app = FastAPI()
    print("setting up api endpoints")
    depth_api('', app)
    print("api running")
    # Serve the app instance that actually carries the routes. The previous
    # import string pointed at ``depth_api``, which is a route-registration
    # function and not an ASGI application.
    uvicorn.run(app, port=7860, host="127.0.0.1")


def init_api(block, app):
    """Register the depth routes on an existing app (e.g. one owned by gradio).

    Args:
        block: Passed through as the first argument of ``depth_api``.
        app: FastAPI application to attach the routes to.
    """
    print("setting up api endpoints")
    depth_api(block, app)
    print("api running")