Skip to content

Commit 8a3abcc

Browse files
refractor + standalone api
1 parent 3fddffb commit 8a3abcc

File tree

4 files changed

+92
-54
lines changed

4 files changed

+92
-54
lines changed

scripts/depthmap_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
# Currently no API stability guarantees are provided - API may break on any new commit.
44

55
from src import backbone
6-
from src.api import api_extension
6+
from api import api_routes
77

88
try:
99
import modules.script_callbacks as script_callbacks
1010
if backbone.get_cmd_opt('api', False):
11-
script_callbacks.on_app_started(api_extension.depth_api)
11+
script_callbacks.on_app_started(api_routes.depth_api)
1212
except:
1313
print('DepthMap API could not start')

src/api/api_core.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import numpy as np
2+
from PIL import PngImagePlugin, Image
3+
import base64
4+
from io import BytesIO
5+
from fastapi.exceptions import HTTPException
6+
7+
import gradio as gr
8+
9+
10+
from src.core import core_generation_funnel, CoreGenerationFunnelInp
11+
from src import backbone
12+
from src.api.api_constants import api_defaults, api_forced, models_to_index
13+
14+
# moedified from modules/api/api.py auto1111
15+
def decode_base64_to_image(encoding):
16+
if encoding.startswith("data:image/"):
17+
encoding = encoding.split(";")[1].split(",")[1]
18+
try:
19+
image = Image.open(BytesIO(base64.b64decode(encoding)))
20+
return image
21+
except Exception as e:
22+
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
23+
24+
# modified from modules/api/api.py auto1111. TODO check that internally we always use png. Removed webp and jpeg
25+
def encode_pil_to_base64(image, image_type='png'):
26+
with BytesIO() as output_bytes:
27+
28+
if image_type == 'png':
29+
use_metadata = False
30+
metadata = PngImagePlugin.PngInfo()
31+
for key, value in image.info.items():
32+
if isinstance(key, str) and isinstance(value, str):
33+
metadata.add_text(key, value)
34+
use_metadata = True
35+
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None))
36+
37+
else:
38+
raise HTTPException(status_code=500, detail="Invalid image format")
39+
40+
bytes_data = output_bytes.getvalue()
41+
42+
return base64.b64encode(bytes_data)
43+
44+
def encode_to_base64(image):
45+
if type(image) is str:
46+
return image
47+
elif type(image) is Image.Image:
48+
return encode_pil_to_base64(image)
49+
elif type(image) is np.ndarray:
50+
return encode_np_to_base64(image)
51+
else:
52+
return ""
53+
54+
def encode_np_to_base64(image):
55+
pil = Image.fromarray(image)
56+
return encode_pil_to_base64(pil)
57+
58+
def to_base64_PIL(encoding: str):
59+
return Image.fromarray(np.array(decode_base64_to_image(encoding)).astype('uint8'))
60+
61+
62+
def api_gen(input_images, client_options):
63+
64+
default_options = CoreGenerationFunnelInp(api_defaults).values
65+
66+
#TODO try-catch type errors here
67+
for key, value in client_options.items():
68+
if key == "model_type":
69+
default_options[key] = models_to_index(value)
70+
continue
71+
default_options[key] = value
72+
73+
for key, value in api_forced.items():
74+
default_options[key.lower()] = value
75+
76+
print(f"Processing {str(len(input_images))} images through the API")
77+
78+
print(default_options)
79+
80+
pil_images = []
81+
for input_image in input_images:
82+
pil_images.append(to_base64_PIL(input_image))
83+
outpath = backbone.get_outpath()
84+
gen_obj = core_generation_funnel(outpath, pil_images, None, None, default_options)
85+
return gen_obj

src/api/api_extension.py renamed to src/api/api_routes.py

Lines changed: 5 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,66 +2,18 @@
22
# (will only be on with --api starting option)
33
# Currently no API stability guarantees are provided - API may break on any new commit.
44

5-
import numpy as np
65
from fastapi import FastAPI, Body
76
from fastapi.exceptions import HTTPException
8-
from PIL import Image
97
from itertools import tee
10-
import json
118

129
import gradio as gr
1310

14-
from modules.api.models import List, Dict
15-
from modules.api import api
11+
from typing import Dict, List
1612

1713
from src.common_constants import GenerationOptions as go
18-
from src.core import core_generation_funnel, CoreGenerationFunnelInp
19-
from src import backbone
2014
from src.misc import SCRIPT_VERSION
21-
from src.api.api_constants import api_defaults, api_forced, api_options, models_to_index
22-
23-
def encode_to_base64(image):
24-
if type(image) is str:
25-
return image
26-
elif type(image) is Image.Image:
27-
return api.encode_pil_to_base64(image)
28-
elif type(image) is np.ndarray:
29-
return encode_np_to_base64(image)
30-
else:
31-
return ""
32-
33-
def encode_np_to_base64(image):
34-
pil = Image.fromarray(image)
35-
return api.encode_pil_to_base64(pil)
36-
37-
def to_base64_PIL(encoding: str):
38-
return Image.fromarray(np.array(api.decode_base64_to_image(encoding)).astype('uint8'))
39-
40-
41-
def api_gen(input_images, client_options):
42-
43-
default_options = CoreGenerationFunnelInp(api_defaults).values
44-
45-
#TODO try-catch type errors here
46-
for key, value in client_options.items():
47-
if key == "model_type":
48-
default_options[key] = models_to_index(value)
49-
continue
50-
default_options[key] = value
51-
52-
for key, value in api_forced.items():
53-
default_options[key.lower()] = value
54-
55-
print(f"Processing {str(len(input_images))} images through the API")
56-
57-
print(default_options)
58-
59-
pil_images = []
60-
for input_image in input_images:
61-
pil_images.append(to_base64_PIL(input_image))
62-
outpath = backbone.get_outpath()
63-
gen_obj = core_generation_funnel(outpath, pil_images, None, None, default_options)
64-
return gen_obj
15+
from src.api.api_constants import api_options, models_to_index
16+
from api.api_core import api_gen, encode_to_base64
6517

6618
def depth_api(_: gr.Blocks, app: FastAPI):
6719
@app.get("/depth/version")
@@ -72,7 +24,8 @@ async def version():
7224
async def get_options():
7325
return {
7426
"gen_options": [x.name.lower() for x in go],
75-
"api_options": api_options
27+
"api_options": api_options,
28+
"model_names": models_to_index.keys()
7629
}
7730

7831
@app.post("/depth/generate")

src/api/api_standalone.py

Whitespace-only changes.

0 commit comments

Comments
 (0)