Skip to content

Commit 3d14829

Browse files
committed
Remove code duplication, allow more settings
1 parent c9d799c commit 3d14829

File tree

1 file changed

+16
-87
lines changed

1 file changed

+16
-87
lines changed

scripts/depthmap_api.py

Lines changed: 16 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
import gradio as gr
1111

12-
from modules.api.models import *
12+
from modules.api.models import List, Dict
1313
from modules.api import api
14-
from modules.shared import opts
1514

1615
from src.core import core_generation_funnel
17-
from src.common_ui import main_ui_panel
1816
from src.misc import SCRIPT_VERSION
1917
from src import backbone
18+
from src.common_constants import GenerationOptions as go
19+
2020

2121
def encode_to_base64(image):
2222
if type(image) is str:
@@ -28,20 +28,15 @@ def encode_to_base64(image):
2828
else:
2929
return ""
3030

31+
3132
def encode_np_to_base64(image):
3233
pil = Image.fromarray(image)
3334
return api.encode_pil_to_base64(pil)
3435

36+
3537
def to_base64_PIL(encoding: str):
3638
return Image.fromarray(np.array(api.decode_base64_to_image(encoding)).astype('uint8'))
3739

38-
#TODO: is this slow?
39-
def get_defaults():
40-
default_gradio = main_ui_panel(True).internal
41-
defaults = {}
42-
for key, value in default_gradio.items():
43-
defaults[key]= value.value
44-
return defaults
4540

4641
def depth_api(_: gr.Blocks, app: FastAPI):
4742
@app.get("/depth/version")
@@ -50,97 +45,31 @@ async def version():
5045

5146
@app.get("/depth/get_options")
5247
async def get_options():
53-
default_input = get_defaults()
54-
return {"settings": sorted(list(default_input.internal.keys()))}
55-
56-
#This will be the stable basic api
57-
@app.post("/depth/process")
48+
return {"options": sorted([x.name.lower() for x in go])}
49+
50+
# TODO: some potential inputs not supported (like custom depthmaps)
51+
@app.post("/depth/generate")
5852
async def process(
5953
depth_input_images: List[str] = Body([], title='Input Images'),
60-
compute_device:str = Body("GPU", title='CPU or GPU', options="'GPU', 'CPU'"),
61-
model_type:str = Body('zoedepth_n (indoor)', title='depth model', options="'res101', 'dpt_beit_large_512 (midas 3.1)', 'dpt_beit_large_384 (midas 3.1)', 'dpt_large_384 (midas 3.0)', 'dpt_hybrid_384 (midas 3.0)', 'midas_v21', 'midas_v21_small', 'zoedepth_n (indoor)', 'zoedepth_k (outdoor)', 'zoedepth_nk'"),
62-
net_width:int = Body(512, title="net width"),
63-
net_height:int = Body(512, title="net height"),
64-
net_size_match:bool = Body(True, title="match original image size"),
65-
boost:bool = Body(False, title="use boost algorithm"),
66-
output_depth_invert:bool = Body(False, title="invert depthmap")
54+
options: Dict[str, object] = Body("options", title='Generation options'),
6755
):
68-
default_inputs = get_defaults()
69-
override = {
70-
# TODO: These indexing aren't soo nice
71-
'compute_device': compute_device,
72-
'model_type': ['res101', 'dpt_beit_large_512 (midas 3.1)',
73-
'dpt_beit_large_384 (midas 3.1)', 'dpt_large_384 (midas 3.0)',
74-
'dpt_hybrid_384 (midas 3.0)',
75-
'midas_v21', 'midas_v21_small',
76-
'zoedepth_n (indoor)', 'zoedepth_k (outdoor)', 'zoedepth_nk'].index(model_type),
77-
'net_width': net_width,
78-
'net_height': net_height,
79-
'net_size_match': net_size_match,
80-
'boost': boost,
81-
'output_depth_invert': output_depth_invert,
82-
}
83-
84-
for key, value in override.items():
85-
default_inputs[key] = value
86-
8756
if len(depth_input_images) == 0:
88-
raise HTTPException(
89-
status_code=422, detail="No image selected")
57+
raise HTTPException(status_code=422, detail="No image supplied")
9058

91-
print(f"Processing {str(len(depth_input_images))} images with the depth module.")
59+
print(f"Processing {str(len(depth_input_images))} images trough the API.")
9260

9361
PIL_images = []
9462
for input_image in depth_input_images:
9563
PIL_images.append(to_base64_PIL(input_image))
9664

97-
outpath = opts.outdir_samples or opts.outdir_extras_samples
98-
img_gen = core_generation_funnel(outpath, PIL_images, None, None, default_inputs)[0]
65+
outpath = backbone.get_outpath()
66+
results, _, _ = core_generation_funnel(outpath, PIL_images, None, None, options)
9967

100-
# This just keeps depth image throws everything else away
101-
results = [img['depth'] for img in img_gen]
68+
# TODO: Fix: this just keeps depth image throws everything else away
69+
results = [img['depth'] for img in results]
10270
results64 = list(map(encode_to_base64, results))
10371

10472
return {"images": results64, "info": "Success"}
105-
106-
#This will be direct process for overriding the default settings
107-
@app.post("/depth/raw_process")
108-
async def raw_process(
109-
depth_input_images: List[str] = Body([], title='Input Images'),
110-
override: dict = Body({}, title="a dictionary containing exact internal keys to depthmap")
111-
):
112-
113-
default_inputs = get_defaults()
114-
for key, value in override.items():
115-
default_inputs[key] = value
116-
117-
if len(depth_input_images) == 0:
118-
raise HTTPException(
119-
status_code=422, detail="No image selected")
120-
121-
print(f"Processing {str(len(depth_input_images))} images with the depth module.")
122-
123-
PIL_images = []
124-
for input_image in depth_input_images:
125-
PIL_images.append(to_base64_PIL(input_image))
126-
127-
outpath = opts.outdir_samples or opts.outdir_extras_samples
128-
img_gen = core_generation_funnel(outpath, PIL_images, None, None, default_inputs)[0]
129-
130-
# This just keeps depth image throws everything else away
131-
results = [img['depth'] for img in img_gen]
132-
results64 = list(map(encode_to_base64, results))
133-
return {"images": results64, "info": "Success"}
134-
135-
# TODO: add functionality
136-
# most different output formats (.obj, etc) should have different apis because otherwise network bloat might become a thing
137-
138-
@app.post("/depth/extras_process")
139-
async def extras_process(
140-
depth_input_images: List[str] = Body([], title='Input Images')
141-
):
142-
143-
return {"images": depth_input_images, "info": "Success"}
14473

14574
try:
14675
import modules.script_callbacks as script_callbacks

0 commit comments

Comments
 (0)