Skip to content

Commit 3fa9185

Browse files
committed
Post-refactor fixes vol. 2
* Reload model before generation, if it is offloaded to CPU
* Load model if boost got selected
* Do not try to offload pix2pix
* Net dimensions are multiple of 32 regardless of match size
* Change the default net size to default net size of the default model
* Fixed script mode
* UI fixes
1 parent cee5576 commit 3fa9185

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

scripts/depthmap.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def main_ui_panel(is_depth_tab):
4949
with gr.Group(visible=False) as options_depend_on_boost:
5050
inp += 'match_size', gr.Checkbox(label="Match net size to input size", value=False)
5151
with gr.Row(visible=False) as options_depend_on_match_size:
52-
inp += 'net_width', gr.Slider(minimum=64, maximum=2048, step=64, label='Net width', value=512)
53-
inp += 'net_height', gr.Slider(minimum=64, maximum=2048, step=64, label='Net height', value=512)
52+
inp += 'net_width', gr.Slider(minimum=64, maximum=2048, step=64, label='Net width', value=448)
53+
inp += 'net_height', gr.Slider(minimum=64, maximum=2048, step=64, label='Net height', value=448)
5454

5555
with gr.Group():
5656
with gr.Row():
@@ -104,18 +104,20 @@ def main_ui_panel(is_depth_tab):
104104
with gr.Row():
105105
inp += "gen_mesh", gr.Checkbox(
106106
label="Generate simple 3D mesh", value=False, visible=True)
107-
with gr.Row(visible=False) as mesh_options_row_0:
108-
gr.Label(value="Generates fast, accurate only with ZoeDepth models and no boost, no custom maps")
109-
inp += "mesh_occlude", gr.Checkbox(label="Remove occluded edges", value=True, visible=True)
110-
inp += "mesh_spherical", gr.Checkbox(label="Equirectangular projection", value=False, visible=True)
107+
with gr.Group(visible=False) as mesh_options:
108+
with gr.Row():
109+
gr.HTML(value="Generates fast, accurate only with ZoeDepth models and no boost, no custom maps")
110+
with gr.Row():
111+
inp += "mesh_occlude", gr.Checkbox(label="Remove occluded edges", value=True, visible=True)
112+
inp += "mesh_spherical", gr.Checkbox(label="Equirectangular projection", value=False, visible=True)
111113

112114
if is_depth_tab:
113115
with gr.Group():
114116
with gr.Row():
115117
inp += "inpaint", gr.Checkbox(
116118
label="Generate 3D inpainted mesh", value=False)
117119
with gr.Group(visible=False) as inpaint_options_row_0:
118-
gr.Label("Generation is sloooow, required for generating videos")
120+
gr.HTML("Generation is sloooow, required for generating videos")
119121
inp += "inpaint_vids", gr.Checkbox(
120122
label="Generate 4 demo videos with 3D inpainted mesh.", value=False)
121123
gr.HTML("More options for generating video can be found in the Generate video tab")
@@ -199,9 +201,9 @@ def stereo_options_visibility(v):
199201
)
200202

201203
inp['gen_mesh'].change(
202-
fn=lambda v: mesh_options_row_0.update(visible=v),
204+
fn=lambda v: mesh_options.update(visible=v),
203205
inputs=[inp['gen_mesh']],
204-
outputs=[mesh_options_row_0]
206+
outputs=[mesh_options]
205207
)
206208

207209
def inpaint_options_visibility(v):

src/core.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def reload_sd_model():
6262
def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp):
6363
if len(inputimages) == 0 or inputimages[0] is None:
6464
return [], '', ''
65-
if len(inputdepthmaps) == 0:
65+
if inputdepthmaps is None or len(inputdepthmaps) == 0:
6666
inputdepthmaps: list[Image] = [None for _ in range(len(inputimages))]
6767
inputdepthmaps_complete = all([x is not None for x in inputdepthmaps])
6868

@@ -78,8 +78,8 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
7878
gen_mesh = inp["gen_mesh"]
7979
gen_normal = inp["gen_normal"] if "gen_normal" in inp else False
8080
gen_stereo = inp["gen_stereo"]
81-
inpaint = inp["inpaint"]
82-
inpaint_vids = inp["inpaint_vids"]
81+
inpaint = inp["inpaint"] if "inpaint" in inp else False
82+
inpaint_vids = inp["inpaint_vids"] if "inpaint_vids" in inp else False
8383
invert_depth = inp["invert_depth"]
8484
match_size = inp["match_size"]
8585
mesh_occlude = inp["mesh_occlude"]
@@ -165,7 +165,9 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
165165
else:
166166
# override net size (size may be different for different images)
167167
if match_size:
168-
net_width, net_height = inputimages[count].width, inputimages[count].height
168+
# Round up to a multiple of 32 to avoid potential issues
169+
net_width = (inputimages[count].width + 31) // 32 * 32
170+
net_height = (inputimages[count].height + 31) // 32 * 32
169171
raw_prediction, raw_prediction_invert = \
170172
model_holder.get_raw_prediction(inputimages[count], net_width, net_height)
171173

@@ -304,14 +306,14 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
304306
else:
305307
raise e
306308
finally:
307-
if not (hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels):
309+
if hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels:
310+
model_holder.offload() # Swap to CPU memory
311+
else:
308312
if 'model' in locals():
309313
del model
310314
if 'pix2pixmodel' in locals():
311315
del pix2pix_model
312316
model_holder.unload_models()
313-
else:
314-
model_holder.swap_to_cpu_memory()
315317

316318
gc.collect()
317319
devices.torch_gc()

src/depthmap_generation.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self):
4242
self.pix2pix_model = None
4343
self.depth_model_type = None
4444
self.device = None # Target device, the model may be swapped from VRAM into RAM.
45+
self.offloaded = False # True means current device is not the target device
4546

4647
# Extra stuff
4748
self.resize_mode = None
@@ -53,9 +54,10 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
5354
self.unload_models()
5455
return
5556
# Certain optimisations are irreversible and not device-agnostic, thus changing device requires reloading
56-
if model_type != self.depth_model_type or boost != self.pix2pix_model is not None or device != self.device:
57+
if model_type != self.depth_model_type or boost != (self.pix2pix_model is not None) or device != self.device:
5758
self.unload_models()
5859
self.load_models(model_type, device, boost)
60+
self.reload()
5961

6062
def load_models(self, model_type, device: torch.device, boost: bool):
6163
"""Ensure that the depth model is loaded"""
@@ -236,11 +238,24 @@ def get_default_net_size(model_type):
236238
return sizes[model_type]
237239
return [512, 512]
238240

239-
def swap_to_cpu_memory(self):
241+
def offload(self):
    """Park the models in CPU RAM so that VRAM is freed up.

    Nothing happens when the target device already is the CPU, or when the
    models have been offloaded before and not reloaded since.
    """
    cpu = torch.device('cpu')
    if self.device == cpu or self.offloaded:
        return
    self.move_models_to(cpu)
    # Remember that the models are currently away from the target device.
    self.offloaded = True
246+
247+
def reload(self):
    """Undo a previous offload() by moving the models back to self.device.

    Does nothing when the models are not currently marked as offloaded.
    """
    if self.offloaded:
        self.move_models_to(self.device)
        # The models are on the target device again, so clear the flag.
        # Leaving it set would make every later offload() call a no-op
        # (offload() only acts when the flag is not set).
        self.offloaded = False
252+
253+
def move_models_to(self, device):
    """Move the loaded depth model onto `device`.

    A loaded pix2pix model is deliberately left where it is.
    """
    depth_net = self.depth_model
    if depth_net is not None:
        depth_net.to(device)
    if self.pix2pix_model is None:
        return
    # TODO: pix2pix offloading not implemented
244259

245260
def unload_models(self):
246261
if self.depth_model is not None or self.pix2pix_model is not None:

0 commit comments

Comments
 (0)