Skip to content

Commit 88aa86f

Browse files
committed
Post-refactor fixes
1 parent 72517c4 commit 88aa86f

File tree

6 files changed

+48
-32
lines changed

6 files changed

+48
-32
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ video by [@graemeniedermayer](https://github.com/graemeniedermayer), more exampl
2121
images generated by [@semjon00](https://github.com/semjon00) from CC0 photos, more examples [here](https://github.com/thygate/stable-diffusion-webui-depthmap-script/pull/56#issuecomment-1367596463).
2222

2323
## Changelog
24-
* v0.3.13
25-
* Large code refactor
26-
* Improved interface
27-
* Slightly changed the behaviour of various options
24+
* v0.4.0 large code refactor
25+
* UI improvements
26+
* slightly changed the behaviour of various options
27+
* extension may partially work even if some of the dependencies are unmet
2828
* v0.3.12
2929
* Fixed stereo image generation
3030
* Other bugfixes

scripts/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
9898
stereo_separation = inp["stereo_separation"]
9999

100100
# TODO: ideally, run_depthmap should not save meshes - that makes the function not pure
101-
print(f"\n{SCRIPT_NAME} {SCRIPT_VERSION} ({get_commit_hash()})")
101+
print(f"{SCRIPT_NAME} {SCRIPT_VERSION} ({get_commit_hash()})")
102102

103103
unload_sd_model()
104104

@@ -230,7 +230,7 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
230230

231231
if show_heat:
232232
from dzoedepth.utils.misc import colorize
233-
heatmap = colorize(img_output, cmap='inferno')
233+
heatmap = Image.fromarray(colorize(img_output, cmap='inferno'))
234234
generated_images[count]['heatmap'] = heatmap
235235

236236
if gen_stereo:
@@ -325,7 +325,7 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
325325
print(f'{str(e)}, some issue with generating inpainted mesh')
326326

327327
reload_sd_model()
328-
print("All done.")
328+
print("All done.\n")
329329
return generated_images, mesh_fi, meshsimple_fi
330330

331331

scripts/depthmap_generation.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from PIL import Image
44
from torchvision.transforms import Compose, transforms
55

6+
# TODO: depthmap_generation should not depend on WebUI
67
from modules import shared, devices
78
from modules.shared import opts, cmd_opts
89

@@ -29,7 +30,6 @@
2930
from pix2pix.options.test_options import TestOptions
3031
from pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
3132

32-
3333
# zoedepth
3434
from dzoedepth.models.builder import build_model
3535
from dzoedepth.utils.config import get_config
@@ -59,9 +59,6 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
5959

6060
def load_models(self, model_type, device: torch.device, boost: bool):
6161
"""Ensure that the depth model is loaded"""
62-
# TODO: supply correct values for zoedepth
63-
net_width = 512
64-
net_height = 512
6562

6663
# model path and name
6764
model_dir = "./models/midas"
@@ -171,22 +168,21 @@ def load_models(self, model_type, device: torch.device, boost: bool):
171168
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
172169
)
173170

171+
# When loading, zoedepth models will report the default net size.
172+
# It will be overridden by the generation settings.
174173
elif model_type == 7: # zoedepth_n
175174
print("zoedepth_n\n")
176175
conf = get_config("zoedepth", "infer")
177-
conf.img_size = [net_width, net_height]
178176
model = build_model(conf)
179177

180178
elif model_type == 8: # zoedepth_k
181179
print("zoedepth_k\n")
182180
conf = get_config("zoedepth", "infer", config_version="kitti")
183-
conf.img_size = [net_width, net_height]
184181
model = build_model(conf)
185182

186183
elif model_type == 9: # zoedepth_nk
187184
print("zoedepth_nk\n")
188185
conf = get_config("zoedepth_nk", "infer")
189-
conf.img_size = [net_width, net_height]
190186
model = build_model(conf)
191187

192188
model.eval() # prepare for evaluation
@@ -221,15 +217,20 @@ def load_models(self, model_type, device: torch.device, boost: bool):
221217

222218
devices.torch_gc()
223219

224-
def get_default_net_size(self, model_type):
220+
@staticmethod
221+
def get_default_net_size(model_type):
225222
# TODO: fill in, use in the GUI
226223
sizes = {
224+
0: [448, 448],
227225
1: [512, 512],
228226
2: [384, 384],
229227
3: [384, 384],
230228
4: [384, 384],
231229
5: [384, 384],
232230
6: [256, 256],
231+
7: [384, 512],
232+
8: [384, 768],
233+
9: [384, 512]
233234
}
234235
if model_type in sizes:
235236
return sizes[model_type]
@@ -254,8 +255,9 @@ def unload_models(self):
254255
self.device = None
255256

256257
def get_raw_prediction(self, input, net_width, net_height):
257-
"""Get prediction from the model currently loaded by the class.
258+
"""Get prediction from the model currently loaded by the ModelHolder object.
258259
If boost is enabled, net_width and net_height will be ignored."""
260+
# TODO: supply net size for zoedepth
259261
global device
260262
device = self.device
261263
# input image
@@ -264,17 +266,14 @@ def get_raw_prediction(self, input, net_width, net_height):
264266
if self.pix2pix_model is None:
265267
if self.depth_model_type == 0:
266268
raw_prediction = estimateleres(img, self.depth_model, net_width, net_height)
267-
raw_prediction_invert = True
268269
elif self.depth_model_type in [7, 8, 9]:
269270
raw_prediction = estimatezoedepth(input, self.depth_model, net_width, net_height)
270-
raw_prediction_invert = True
271271
else:
272272
raw_prediction = estimatemidas(img, self.depth_model, net_width, net_height,
273273
self.resize_mode, self.normalization)
274-
raw_prediction_invert = False
275274
else:
276275
raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model)
277-
raw_prediction_invert = False
276+
raw_prediction_invert = self.depth_model_type in [0, 7, 8, 9]
278277
return raw_prediction, raw_prediction_invert
279278

280279

scripts/interface_webui.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from modules.shared import opts
99
from modules.ui import plaintext_to_html
1010
from pathlib import Path
11+
from PIL import Image
1112

1213
from scripts.gradio_args_transport import GradioComponentBundle
1314
from scripts.main import *
1415
from scripts.core import core_generation_funnel, unload_models, run_makevideo
15-
from PIL import Image
16+
from scripts.depthmap_generation import ModelHolder
1617

1718

1819
# Ugly workaround to fix gradio tempfile issue
@@ -102,19 +103,19 @@ def main_ui_panel(is_depth_tab):
102103
with gr.Group():
103104
with gr.Row():
104105
inp += "gen_mesh", gr.Checkbox(
105-
label="Generate simple 3D mesh. "
106-
"(Fast, accurate only with ZoeDepth models and no boost, no custom maps)",
107-
value=False, visible=True)
106+
label="Generate simple 3D mesh", value=False, visible=True)
108107
with gr.Row(visible=False) as mesh_options_row_0:
108+
gr.Label(value="Generates fast, accurate only with ZoeDepth models and no boost, no custom maps")
109109
inp += "mesh_occlude", gr.Checkbox(label="Remove occluded edges", value=True, visible=True)
110110
inp += "mesh_spherical", gr.Checkbox(label="Equirectangular projection", value=False, visible=True)
111111

112112
if is_depth_tab:
113113
with gr.Group():
114114
with gr.Row():
115115
inp += "inpaint", gr.Checkbox(
116-
label="Generate 3D inpainted mesh. (Sloooow, required for generating videos)", value=False)
116+
label="Generate 3D inpainted mesh", value=False)
117117
with gr.Group(visible=False) as inpaint_options_row_0:
118+
gr.Label("Generation is sloooow, required for generating videos")
118119
inp += "inpaint_vids", gr.Checkbox(
119120
label="Generate 4 demo videos with 3D inpainted mesh.", value=False)
120121
gr.HTML("More options for generating video can be found in the Generate video tab")
@@ -139,6 +140,15 @@ def main_ui_panel(is_depth_tab):
139140

140141
inp += "gen_normal", gr.Checkbox(label="Generate Normalmap (hidden! api only)", value=False, visible=False)
141142

143+
def update_delault_net_size(model_type):
144+
w, h = ModelHolder.get_default_net_size(model_type)
145+
return inp['net_width'].update(value=w), inp['net_height'].update(value=h)
146+
inp['model_type'].change(
147+
fn=update_delault_net_size,
148+
inputs=inp['model_type'],
149+
outputs=[inp['net_width'], inp['net_height']]
150+
)
151+
142152
inp['boost'].change(
143153
fn=lambda a, b: (options_depend_on_boost.update(visible=not a),
144154
options_depend_on_match_size.update(visible=not a and not b)),
@@ -309,6 +319,7 @@ def on_ui_tabs():
309319
inp += gr.Image(label="Source", source="upload", interactive=True, type="pil",
310320
elem_id="depthmap_input_image")
311321
with gr.Group(visible=False) as custom_depthmap_row_0:
322+
# TODO: depthmap generation settings should disappear when using this
312323
inp += gr.File(label="Custom DepthMap", file_count="single", interactive=True,
313324
type="file", elem_id='custom_depthmap_img')
314325
inp += gr.Checkbox(elem_id="custom_depthmap", label="Use custom DepthMap", value=False)
@@ -471,13 +482,12 @@ def run_generate(*inputs):
471482
inputnames.append(None)
472483
if custom_depthmap:
473484
if custom_depthmap_img is None:
474-
return [], None, None, "Custom depthmap is not specified. " \
475-
"Please either supply it or disable this option.", ""
476-
inputdepthmaps.append(custom_depthmap_img)
485+
return [], None, None,\
486+
"Custom depthmap is not specified. Please either supply it or disable this option.", ""
487+
inputdepthmaps.append(Image.open(os.path.abspath(custom_depthmap_img.name)))
477488
else:
478489
inputdepthmaps.append(None)
479490
if depthmap_mode == '1': # Batch Process
480-
# convert files to pillow images
481491
for img in image_batch:
482492
image = Image.open(os.path.abspath(img.name))
483493
inputimages.append(image)

scripts/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66
SCRIPT_NAME = "DepthMap"
7-
SCRIPT_VERSION = "v0.3.13"
7+
SCRIPT_VERSION = "v0.4.0"
88

99
commit_hash = None # TODO: understand why it would spam to stderr if changed to ... = get_commit_hash()
1010
def get_commit_hash():

scripts/stereoimage_generation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
from numba import njit, prange
1+
try:
2+
from numba import njit, prange
3+
except Exception as e:
4+
print(f"WARINING! Numba failed to import! Stereoimage generation will be much slower! ({str(e)})")
5+
from builtins import range as prange
6+
def njit(parallel=False):
7+
def Inner(func): return lambda *args, **kwargs: func(*args, **kwargs)
8+
return Inner
29
import numpy as np
310
from PIL import Image
411

@@ -73,7 +80,7 @@ def apply_stereo_divergence(original_image, depth, divergence, separation, fill_
7380
)
7481

7582

76-
@njit
83+
@njit(parallel=False)
7784
def apply_stereo_divergence_naive(
7885
original_image, normalized_depth, divergence_px: float, separation_px: float, fill_technique):
7986
h, w, c = original_image.shape

0 commit comments

Comments
 (0)