Skip to content

refactoring + standalone api #328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import os
import pathlib

import time
import src.misc


Expand All @@ -29,6 +29,7 @@ def maybe_chdir():
# Command-line options of the standalone launcher.
parser = argparse.ArgumentParser()
parser.add_argument("--share", help="Create public link", action='store_true')
# --listen only switches server_name to "0.0.0.0"; it does not create a public
# link, so its help text must not repeat the one of --share.
parser.add_argument("--listen", help="Bind the server to 0.0.0.0 instead of the default host", action='store_true')
parser.add_argument("--api", help="Start-up api", action='store_true')
parser.add_argument("--no_chdir", help="Do not try to use the root of stable-diffusion-webui", action='store_true')
args = parser.parse_args()

Expand All @@ -37,4 +38,15 @@ def maybe_chdir():
maybe_chdir()
# --listen selects "0.0.0.0" as the server name; otherwise it is left as None.
server_name = "0.0.0.0" if args.listen else None
import src.common_ui
# NOTE(review): this page renders the diff without +/- markers; the next line
# appears to be the pre-change launch call that the code below replaces - confirm.
src.common_ui.on_ui_tabs().launch(share=args.share, server_name=server_name)

# Build the UI once so the same object can be launched and handed to the API.
ui_block = src.common_ui.on_ui_tabs()

# Without --api: launch the UI only.
if not args.api:
ui_block.launch(share=args.share, server_name=server_name)
else:
# With --api: launch with prevent_thread_lock=True and keep the first element
# of the returned tuple as the app on which the API routes are registered.
app, _, _ = ui_block.launch(share=args.share, server_name=server_name, prevent_thread_lock=True)
print(f"starting depth api")
from src.api.api_standalone import init_api
init_api(ui_block, app)
# Loop forever, sleeping 0.1 s per iteration. Presumably this keeps the process
# alive because launch() was asked not to lock the thread - confirm.
while True:
time.sleep(0.1)
11 changes: 11 additions & 0 deletions scripts/depthmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,14 @@ def add_option(name, default_value, description, name_prefix='depthmap_script'):
from modules import script_callbacks
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(lambda: [(common_ui.on_ui_tabs(), "Depth", "depthmap_interface")])

# API script: when the 'api' command-line option is set, hand
# api_routes.depth_api to script_callbacks.on_app_started.
from src.api import api_routes

try:
    import modules.script_callbacks as script_callbacks
    if backbone.get_cmd_opt('api', False):
        script_callbacks.on_app_started(api_routes.depth_api)
except Exception as e:
    # Best effort: a failure here is reported, not raised. Catching Exception
    # (instead of a bare except) lets KeyboardInterrupt and SystemExit
    # propagate, and the message now carries the cause.
    print(f'DepthMap API could not start: {e!r}')

81 changes: 0 additions & 81 deletions scripts/depthmap_api.py

This file was deleted.

32 changes: 32 additions & 0 deletions src/api/api_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

# API-level options. Entries that are commented out are not implemented yet.
api_options = dict(
    # 'outputs': ["depth"],  # list of outputs to send in the response, e.g. ["depth", "normalmap", 'heatmap', "normal", 'background_removed']
    # 'conversions': "",  # TODO implement. Offering some conversions server-side is a good idea because they are often hard to do in JS on the client.
    save="",  # TODO implement. To save on the local machine. Can be very helpful for debugging.
)

# TODO: these defaults are intended to be temporary.
api_defaults = dict(BOOST=False, NET_SIZE_MATCH=True)

# These values are enforced after the user's inputs.
api_forced = dict(GEN_SIMPLE_MESH=False, GEN_INPAINTED_MESH=False)

# Model name -> model index.
# TODO: find a way to remove this without forcing people to know the indexes of the models.
models_to_index = {
    model_name: model_index
    for model_index, model_name in enumerate((
        'res101',
        'dpt_beit_large_512 (midas 3.1)',
        'dpt_beit_large_384 (midas 3.1)',
        'dpt_large_384 (midas 3.0)',
        'dpt_hybrid_384 (midas 3.0)',
        'midas_v21',
        'midas_v21_small',
        'zoedepth_n (indoor)',
        'zoedepth_k (outdoor)',
        'zoedepth_nk',
    ))
}
81 changes: 81 additions & 0 deletions src/api/api_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import numpy as np
from PIL import PngImagePlugin, Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException

from src.core import core_generation_funnel
from src import backbone
from src.api.api_constants import api_defaults, api_forced, models_to_index

def decode_base64_to_image(encoding):
    """Decode base64 text, optionally wrapped in a data URI, into a PIL image.

    Args:
        encoding: Base64 text, either bare or as a ``data:image/...`` URI
            whose payload follows the first comma.

    Returns:
        The image opened with ``Image.open`` from the decoded bytes.

    Raises:
        HTTPException: status 500, detail "Invalid encoded image", whenever
            the input cannot be parsed, decoded or opened as an image.
    """
    try:
        if encoding.startswith("data:image/"):
            # The payload is everything after the first comma. Taking it this
            # way also accepts URIs with extra parameters before "base64"
            # (e.g. ";charset=..."), and a URI without a comma yields an empty
            # payload that fails below instead of raising IndexError here.
            encoding = encoding.partition(",")[2]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e

# TODO check that internally we always use png.
def encode_pil_to_base64(image, image_type='png'):
    """Serialise a PIL image as PNG and return the base64-encoded bytes.

    String key/value pairs found in ``image.info`` are written as PNG text
    metadata; if there are none, no metadata object is passed to ``save``.

    Args:
        image: Image to encode.
        image_type: Only 'png' is accepted.

    Returns:
        The result of ``base64.b64encode`` on the PNG data.

    Raises:
        HTTPException: status 500 when ``image_type`` is not 'png'.
    """
    if image_type != 'png':
        raise HTTPException(status_code=500, detail="Invalid image format")

    png_info = PngImagePlugin.PngInfo()
    has_text = False
    for info_key, info_value in image.info.items():
        if not (isinstance(info_key, str) and isinstance(info_value, str)):
            continue
        png_info.add_text(info_key, info_value)
        has_text = True

    with BytesIO() as buffer:
        image.save(buffer, format="PNG", pnginfo=png_info if has_text else None)
        png_bytes = buffer.getvalue()

    return base64.b64encode(png_bytes)

def encode_to_base64(image):
    """Return a base64 representation of ``image``, dispatching on its type.

    Args:
        image: A string (returned unchanged), a PIL image or a NumPy array.

    Returns:
        The input itself for strings, the result of ``encode_pil_to_base64``
        or ``encode_np_to_base64`` for images and arrays, and ``""`` for any
        other type.
    """
    # isinstance (not an exact type comparison) so that subclasses of
    # Image.Image are encoded too instead of falling through to "".
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""

def encode_np_to_base64(image):
    """Convert an array to a PIL image and encode it with encode_pil_to_base64."""
    return encode_pil_to_base64(Image.fromarray(image))

def to_base64_PIL(encoding: str):
    """Decode a base64 image and rebuild it as a PIL image from uint8 pixels."""
    decoded = decode_base64_to_image(encoding)
    pixels = np.array(decoded).astype('uint8')
    return Image.fromarray(pixels)


def api_gen(input_images, client_options):
    """Run depth generation for base64-encoded images with merged options.

    Options are built in three layers: a copy of ``api_defaults``, then the
    client's options (keys upper-cased, ``model_type`` translated from a model
    name to its index), then ``api_forced``, which overrides everything else.

    Args:
        input_images: Base64-encoded images, decoded with ``to_base64_PIL``.
        client_options: Mapping of option name to value supplied by the client.

    Returns:
        Whatever ``core_generation_funnel`` returns for the decoded images.

    Raises:
        HTTPException: status 422 when ``model_type`` is not a known model name.
    """
    options = api_defaults.copy()

    #TODO try-catch type errors here
    for key, value in client_options.items():
        if key == "model_type":
            try:
                value = models_to_index[value]
            except (KeyError, TypeError) as e:
                # Report a bad model name as a client error instead of letting
                # a bare KeyError surface as an internal server error.
                raise HTTPException(
                    status_code=422,
                    detail=f"Unknown model_type {value!r}; expected one of {list(models_to_index)}",
                ) from e
        options[key.upper()] = value

    # Forced values are applied last so client input cannot override them.
    options.update(api_forced)

    print(f"Processing {str(len(input_images))} images through the API")

    print(options)

    pil_images = [to_base64_PIL(input_image) for input_image in input_images]
    outpath = backbone.get_outpath()
    return core_generation_funnel(outpath, pil_images, None, None, options)
49 changes: 49 additions & 0 deletions src/api/api_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Non-public API. Don't host publicly - SECURITY RISKS!
# (will only be on with --api starting option)
# Currently no API stability guarantees are provided - API may break on any new commit.

from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException

from typing import Dict, List
from PIL import Image

from src.common_constants import GenerationOptions as go
from src.misc import SCRIPT_VERSION
from src.api.api_constants import api_options, models_to_index
from src.api.api_core import api_gen, encode_to_base64

# _ parameter is needed for auto1111 extensions (_ is type gr.Blocks)
def depth_api(_, app: FastAPI):
    """Register the depth endpoints on ``app``.

    Routes added:
        GET  /depth/version      - the script version.
        GET  /depth/get_options  - generation option names, API option names
                                   and model names.
        POST /depth/generate     - run generation on base64 images and return
                                   the PIL-image results, base64-encoded.

    Args:
        _: Unused by this function.
        app: Application the routes are attached to.
    """
    @app.get("/depth/version")
    async def version():
        return dict(version=SCRIPT_VERSION)

    @app.get("/depth/get_options")
    async def get_options():
        gen_option_names = [option.name.lower() for option in go]
        return {
            "gen_options": gen_option_names,
            "api_options": list(api_options.keys()),
            "model_names": list(models_to_index.keys()),
        }

    @app.post("/depth/generate")
    async def process(
        input_images: List[str] = Body([], title='Input Images'),
        generate_options: Dict[str, object] = Body({}, title='Generation options', options=[x.name.lower() for x in go]),
        api_options: Dict[str, object] = Body({}, title='Api options', options=api_options),
    ):
        if not input_images:
            raise HTTPException(status_code=422, detail="No images supplied")

        # Only results that are PIL images are sent back; anything else the
        # generator yields is skipped.
        encoded_images = [
            encode_to_base64(output)
            for _index, _kind, output in api_gen(input_images, generate_options)
            if isinstance(output, Image.Image)
        ]
        return {"images": encoded_images, "info": "Success"}


16 changes: 16 additions & 0 deletions src/api/api_standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from fastapi import FastAPI
import uvicorn
from src.api.api_routes import depth_api

# without gradio
def init_api_no_webui():
    """Serve the depth API on its own FastAPI app, without a Gradio UI.

    Creates a FastAPI app, registers the depth routes on it and runs it with
    uvicorn on 127.0.0.1:7860.
    """
    app = FastAPI()
    print("setting up api endpoints")
    depth_api('', app)
    print("api running")
    # Serve the app the routes were just registered on. The former import
    # string 'src.api.api_standalone:depth_api' named the route-registration
    # function, which is not an ASGI application.
    uvicorn.run(app, port=7860, host="127.0.0.1")

def init_api(block, app):
    """Register the depth API routes on an already existing app.

    Args:
        block: Passed through as the first argument of ``depth_api``.
        app: Application object the routes are registered on.
    """
    print("setting up api endpoints")
    depth_api(block, app)
    print("api running")
3 changes: 2 additions & 1 deletion src/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def torch_gc():
# Integer Unix timestamp taken when this line runs; it prefixes every sequence number.
launched_at = int(datetime.now().timestamp())
# Counter that get_next_sequence_number increments on every call.
backbone_current_seq_number = 0

# NOTE(review): two `def` lines follow because this page shows the diff without
# +/- markers; presumably the first is the pre-change signature and the second
# (with the comment above it) is its replacement - confirm.
def get_next_sequence_number(outpath=None, basename=None):
# Make sure to preserve the function signature when calling!
def get_next_sequence_number(outpath, basename):
# Neither parameter is used in the body: the result depends only on
# launched_at and the module-level counter.
global backbone_current_seq_number
backbone_current_seq_number += 1
# Concatenate the launch timestamp with the counter (zero-padded to at least
# four digits) and return the result as an int.
return int(f"{launched_at}{backbone_current_seq_number:04}")
Expand Down
2 changes: 1 addition & 1 deletion src/video_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def gen_video(video, outpath, inp, custom_depthmap=None, colorvids_bitrate=None,

imgs = [x[2] for x in img_results if x[1] == gen]
basename = f'{gen}_video'
frames_to_video(fps, imgs, outpath, f"depthmap-{backbone.get_next_sequence_number()}-{basename}",
frames_to_video(fps, imgs, outpath, f"depthmap-{backbone.get_next_sequence_number(outpath, basename)}-{basename}",
colorvids_bitrate)
print('All done. Video(s) saved!')
return '<h3>Videos generated</h3>' if len(gens) > 1 else '<h3>Video generated</h3>' if len(gens) == 1 \
Expand Down