@@ -43,6 +43,7 @@ def __init__(self):
         # Extra stuff
         self.resize_mode = None
         self.normalization = None
+        self.tiling_mode = False


     def update_settings(self, **kvargs):
@@ -51,18 +52,23 @@ def update_settings(self, **kvargs):
             setattr(self, k, v)


-    def ensure_models(self, model_type, device: torch.device, boost: bool):
+    def ensure_models(self, model_type, device: torch.device, boost: bool, tiling_mode: bool = False):
         # TODO: could make it more granular
         if model_type == -1 or model_type is None:
             self.unload_models()
             return
         # Certain optimisations are irreversible and not device-agnostic, thus changing device requires reloading
-        if model_type != self.depth_model_type or boost != (self.pix2pix_model is not None) or device != self.device:
+        if (
+                model_type != self.depth_model_type or
+                boost != (self.pix2pix_model is not None) or
+                device != self.device or
+                tiling_mode != self.tiling_mode
+        ):
             self.unload_models()
-            self.load_models(model_type, device, boost)
+            self.load_models(model_type, device, boost, tiling_mode)
             self.reload()

-    def load_models(self, model_type, device: torch.device, boost: bool):
+    def load_models(self, model_type, device: torch.device, boost: bool, tiling_mode: bool = False):
         """Ensure that the depth model is loaded"""

         # TODO: we need to at least try to find models downloaded by other plugins (e.g. controlnet)
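With this change a caller opts into seamless (tiling) depth estimation by passing the new flag, and a change to tiling_mode now invalidates the cached model the same way a change of model type, boost or device does. A hedged usage sketch follows; model_holder is a placeholder name for whatever instance owns these methods, and model_type 11 corresponds to Depth Anything per the diff below:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative only: 'model_holder' stands in for the object exposing ensure_models().
# First call loads Depth Anything (model_type == 11) with circular padding applied.
model_holder.ensure_models(model_type=11, device=device, boost=False, tiling_mode=True)

# Calling again with identical arguments is a no-op; flipping tiling_mode triggers
# unload_models() followed by load_models(..., tiling_mode=False), per the condition above.
model_holder.ensure_models(model_type=11, device=device, boost=False, tiling_mode=False)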
@@ -205,7 +211,6 @@ def load_models(self, model_type, device: torch.device, boost: bool):
                     model.enable_xformers_memory_efficient_attention()
                 except:
                     pass  # run without xformers
-
         elif model_type == 11:  # depth_anything
             from depth_anything.dpt import DPT_DINOv2
             # This will download the model... to some place
@@ -223,6 +228,17 @@ def load_models(self, model_type, device: torch.device, boost: bool):

             model.load_state_dict(torch.load(model_path))

+        if tiling_mode:
+            def flatten(el):
+                flattened = [flatten(children) for children in el.children()]
+                res = [el]
+                for c in flattened:
+                    res += c
+                return res
+            layers = flatten(model)  # Hijacking the model
+            for layer in [layer for layer in layers if type(layer) == torch.nn.Conv2d or type(layer) == torch.nn.Conv1d]:
+                layer.padding_mode = 'circular'
+
         if model_type in range(0, 10):
             model.eval()  # prepare for evaluation
         # optimize
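For readers unfamiliar with the trick: the added block recursively collects every submodule and switches the convolution layers to circular (wrap-around) padding, so the predicted depth map can tile seamlessly. Below is a minimal, self-contained sketch of the same idea on a toy network; the toy model and shapes are illustrative only, and torch.nn.Module.modules() performs the same traversal as the flatten() helper in the diff:

import torch

# Toy stand-in for the depth model (assumption: not the extension's code).
toy = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 1, kernel_size=3, padding=1),
)

for layer in toy.modules():
    if isinstance(layer, (torch.nn.Conv1d, torch.nn.Conv2d)):
        # 'circular' padding wraps the tensor around its spatial edges, so the
        # left/right and top/bottom borders line up, which is what makes the
        # resulting output tileable.
        layer.padding_mode = 'circular'

x = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    y = toy(x)
print(y.shape)  # torch.Size([1, 1, 16, 16]); spatial size is unchanged

Setting padding_mode after construction works because the padding amount itself is unchanged; only the fill strategy used at the borders differs.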
@@ -238,6 +254,7 @@ def load_models(self, model_type, device: torch.device, boost: bool):
         self.depth_model_type = model_type
         self.resize_mode = resize_mode
         self.normalization = normalization
+        self.tiling_mode = tiling_mode

         self.device = device
