Skip to content

Commit 204ea5b

Browse files
committed
Tiling mode
This is very cool, but it still has some seams for some reason ;( Drawbacks: Clutters UI Closes #444
1 parent b1bd3cb commit 204ea5b

File tree

4 files changed

+37
-7
lines changed

4 files changed

+37
-7
lines changed

src/common_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, default_value=None, *args):
2121
NET_SIZE_MATCH = False
2222
NET_WIDTH = 448
2323
NET_HEIGHT = 448
24+
TILING_MODE = False
2425

2526
DO_OUTPUT_DEPTH = True
2627
OUTPUT_DEPTH_INVERT = False

src/common_ui.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def main_ui_panel(is_depth_tab):
4848
with gr.Row(visible=False) as options_depend_on_match_size:
4949
inp += go.NET_WIDTH, gr.Slider(minimum=64, maximum=2048, step=64, label='Net width')
5050
inp += go.NET_HEIGHT, gr.Slider(minimum=64, maximum=2048, step=64, label='Net height')
51+
with gr.Row():
52+
inp += go.TILING_MODE, gr.Checkbox(
53+
label='Tiling mode', info='Reduces seams that appear if the depthmap is tiled into a grid'
54+
)
5155

5256
with gr.Box() as cur_option_root:
5357
inp -= 'depthmap_gen_row_2', cur_option_root
@@ -75,7 +79,7 @@ def main_ui_panel(is_depth_tab):
7579

7680
with gr.Box():
7781
with gr.Row():
78-
inp += go.GEN_STEREO, gr.Checkbox(label="Generate stereoscopic image(s)")
82+
inp += go.GEN_STEREO, gr.Checkbox(label="Generate stereoscopic (3D) image(s)")
7983
with gr.Column(visible=False) as stereo_options:
8084
with gr.Row():
8185
inp += go.STEREO_MODES, gr.CheckboxGroup(
@@ -178,6 +182,13 @@ def update_default_net_size(model_type):
178182
outputs=[inp[go.NET_SIZE_MATCH], options_depend_on_match_size]
179183
)
180184
inp.add_rule(options_depend_on_match_size, 'visible-if-not', go.NET_SIZE_MATCH)
185+
inp[go.TILING_MODE].change( # Go boost! Wroom!..
186+
fn=lambda a: (
187+
inp[go.BOOST].update(value=False), inp[go.NET_SIZE_MATCH].update(value=True)
188+
) if a else (inp[go.BOOST].update(), inp[go.NET_SIZE_MATCH].update()),
189+
inputs=[inp[go.TILING_MODE]],
190+
outputs=[inp[go.BOOST], inp[go.NET_SIZE_MATCH]]
191+
)
181192

182193
inp.add_rule(options_depend_on_output_depth_1, 'visible-if', go.DO_OUTPUT_DEPTH)
183194
inp.add_rule(go.OUTPUT_DEPTH_INVERT, 'visible-if', go.DO_OUTPUT_DEPTH)

src/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
121121
try:
122122
if not inputdepthmaps_complete:
123123
print("Loading model(s) ..")
124-
model_holder.ensure_models(inp[go.MODEL_TYPE], device, inp[go.BOOST])
124+
model_holder.ensure_models(inp[go.MODEL_TYPE], device, inp[go.BOOST], inp[go.TILING_MODE])
125125
print("Computing output(s) ..")
126126
# iterate over input images
127127
for count in trange(0, len(inputimages)):
@@ -170,6 +170,7 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
170170
# override net size (size may be different for different images)
171171
if inp[go.NET_SIZE_MATCH]:
172172
# Round up to a multiple of 32 to avoid potential issues
173+
# TODO: buggs for Depth Anything
173174
net_width = (inputimages[count].width + 31) // 32 * 32
174175
net_height = (inputimages[count].height + 31) // 32 * 32
175176
else:

src/depthmap_generation.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self):
4343
# Extra stuff
4444
self.resize_mode = None
4545
self.normalization = None
46+
self.tiling_mode = False
4647

4748

4849
def update_settings(self, **kvargs):
@@ -51,18 +52,23 @@ def update_settings(self, **kvargs):
5152
setattr(self, k, v)
5253

5354

54-
def ensure_models(self, model_type, device: torch.device, boost: bool):
55+
def ensure_models(self, model_type, device: torch.device, boost: bool, tiling_mode: bool = False):
5556
# TODO: could make it more granular
5657
if model_type == -1 or model_type is None:
5758
self.unload_models()
5859
return
5960
# Certain optimisations are irreversible and not device-agnostic, thus changing device requires reloading
60-
if model_type != self.depth_model_type or boost != (self.pix2pix_model is not None) or device != self.device:
61+
if (
62+
model_type != self.depth_model_type or
63+
boost != (self.pix2pix_model is not None) or
64+
device != self.device or
65+
tiling_mode != self.tiling_mode
66+
):
6167
self.unload_models()
62-
self.load_models(model_type, device, boost)
68+
self.load_models(model_type, device, boost, tiling_mode)
6369
self.reload()
6470

65-
def load_models(self, model_type, device: torch.device, boost: bool):
71+
def load_models(self, model_type, device: torch.device, boost: bool, tiling_mode: bool = False):
6672
"""Ensure that the depth model is loaded"""
6773

6874
# TODO: we need to at least try to find models downloaded by other plugins (e.g. controlnet)
@@ -205,7 +211,6 @@ def load_models(self, model_type, device: torch.device, boost: bool):
205211
model.enable_xformers_memory_efficient_attention()
206212
except:
207213
pass # run without xformers
208-
209214
elif model_type == 11: # depth_anything
210215
from depth_anything.dpt import DPT_DINOv2
211216
# This will download the model... to some place
@@ -223,6 +228,17 @@ def load_models(self, model_type, device: torch.device, boost: bool):
223228

224229
model.load_state_dict(torch.load(model_path))
225230

231+
if tiling_mode:
232+
def flatten(el):
233+
flattened = [flatten(children) for children in el.children()]
234+
res = [el]
235+
for c in flattened:
236+
res += c
237+
return res
238+
layers = flatten(model) # Hijacking the model
239+
for layer in [layer for layer in layers if type(layer) == torch.nn.Conv2d or type(layer) == torch.nn.Conv1d]:
240+
layer.padding_mode = 'circular'
241+
226242
if model_type in range(0, 10):
227243
model.eval() # prepare for evaluation
228244
# optimize
@@ -238,6 +254,7 @@ def load_models(self, model_type, device: torch.device, boost: bool):
238254
self.depth_model_type = model_type
239255
self.resize_mode = resize_mode
240256
self.normalization = normalization
257+
self.tiling_mode = tiling_mode
241258

242259
self.device = device
243260

0 commit comments

Comments
 (0)