Skip to content

Commit 128fe5b

Browse files
authored
support for marigold (#385)
support for marigold
1 parent af64de8 commit 128fe5b

File tree

7 files changed: 49 additions and 20 deletions.

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
11
## Changelog
2+
### 0.4.5
3+
* Support for [Marigold](https://marigoldmonodepth.github.io). [PR #385](https://github.com/thygate/stable-diffusion-webui-depthmap-script/pull/385).
24
### 0.4.4
35
* Compatibility with stable-diffusion-webui 1.6.0
46
### 0.4.3 video processing tab

README.md

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,3 +198,16 @@ ZoeDepth :
198198
copyright = {arXiv.org perpetual, non-exclusive license}
199199
}
200200
```
201+
202+
Marigold - Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation:
203+
204+
```
205+
@misc{ke2023repurposing,
206+
title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
207+
author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
208+
year={2023},
209+
eprint={2312.02145},
210+
archivePrefix={arXiv},
211+
primaryClass={cs.CV}
212+
}
213+
```

install.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,11 @@ def ensure(module_name, min_version=None):
3838
launch.run_pip('install "moviepy==1.0.2"', "moviepy requirement for depthmap script")
3939
ensure('transforms3d', '0.4.1')
4040

41+
ensure('transformers', '4.32.1')
42+
ensure('xformers', '0.0.21')
43+
ensure('accelerate', '0.22.0')
44+
ensure('diffusers', '0.20.1')
45+
4146
ensure('imageio') # 2.4.1
4247
try: # Dirty hack to not reinstall every time
4348
importlib_metadata.version('imageio-ffmpeg')
@@ -53,4 +58,4 @@ def ensure(module_name, min_version=None):
5358
if platform.system() == 'Darwin':
5459
ensure('pyqt6')
5560

56-
launch.git_clone("https://github.com/prs-eth/Marigold", "repositories/Marigold", "Marigold", "cc78ff3")
61+
launch.git_clone("https://github.com/prs-eth/Marigold", "Marigold", "Marigold", "cc78ff3")

pix2pix/models/pix2pix4depth_model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -94,8 +94,8 @@ def set_input_train(self, input):
9494
self.real_A = torch.cat((self.outer, self.inner), 1)
9595

9696
def set_input(self, outer, inner):
97-
inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
98-
outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)
97+
inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0).float()
98+
outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0).float()
9999

100100
inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner))
101101
outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer))

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,5 +16,9 @@ transforms3d>=0.4.1
1616
imageio>=2.4.1,<3.0
1717
imageio-ffmpeg
1818
networkx>=2.5
19+
transformers>=4.32.1 # For Marigold
20+
xformers==0.0.21 # For Marigold
21+
accelerate>=0.22.0 # For Marigold
22+
diffusers>=0.20.1 # For Marigold
1923
pyqt5; sys_platform == 'windows'
2024
pyqt6; sys_platform != 'windows'

src/depthmap_generation.py

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -74,8 +74,6 @@ def load_models(self, model_type, device: torch.device, boost: bool):
7474
model_dir = "./models/midas"
7575
if model_type == 0:
7676
model_dir = "./models/leres"
77-
if model_type == 10:
78-
"./models/marigold"
7977
# create paths to model if not present
8078
os.makedirs(model_dir, exist_ok=True)
8179
os.makedirs('./models/pix2pix', exist_ok=True)
@@ -197,9 +195,9 @@ def load_models(self, model_type, device: torch.device, boost: bool):
197195
model = build_model(conf)
198196

199197
elif model_type == 10: # Marigold v1
200-
# TODO: pass more parameters
201-
model_path = f"{model_dir}/marigold_v1/"
202-
from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
198+
model_path = "Bingxin/Marigold"
199+
print(model_path)
200+
from Marigold.src.model.marigold_pipeline import MarigoldPipeline
203201
model = MarigoldPipeline.from_pretrained(model_path)
204202

205203
model.eval() # prepare for evaluation
@@ -301,11 +299,11 @@ def get_raw_prediction(self, input, net_width, net_height):
301299
self.resize_mode, self.normalization, self.no_half,
302300
self.precision == "autocast")
303301
elif self.depth_model_type == 10:
304-
raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height, self.device)
302+
raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height)
305303
else:
306304
raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
307305
self.boost_whole_size_threshold)
308-
raw_prediction_invert = self.depth_model_type in [0, 7, 8, 9]
306+
raw_prediction_invert = self.depth_model_type in [0, 7, 8, 9, 10]
309307
return raw_prediction, raw_prediction_invert
310308

311309

@@ -405,11 +403,11 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
405403
return prediction
406404

407405

408-
def estimatemarigold(image, model, w, h, device):
409-
from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
410-
from repositories.Marigold.src.util.ensemble import ensemble_depths
411-
from repositories.Marigold.src.util.image_util import chw2hwc, colorize_depth_maps, resize_max_res
412-
from repositories.Marigold.src.util.seed_all import seed_all
406+
def estimatemarigold(image, model, w, h):
407+
from Marigold.src.model.marigold_pipeline import MarigoldPipeline
408+
from Marigold.src.util.ensemble import ensemble_depths
409+
from Marigold.src.util.image_util import chw2hwc, colorize_depth_maps, resize_max_res
410+
from Marigold.src.util.seed_all import seed_all
413411

414412
n_repeat = 10
415413
denoise_steps = 10
@@ -418,13 +416,18 @@ def estimatemarigold(image, model, w, h, device):
418416
tol = 1e-3
419417
reduction_method = "median"
420418
merging_max_res = None
419+
resize_to_max_res = None
421420

422421
# From Marigold repository run.py
423422
with torch.no_grad():
424-
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
425-
rgb_norm = rgb / 255.0
423+
if resize_to_max_res is not None:
424+
image = (image * 255).astype(np.uint8)
425+
image = np.asarray(resize_max_res(
426+
Image.fromarray(image), max_edge_resolution=resize_to_max_res
427+
)) / 255.0
428+
rgb_norm = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
426429
rgb_norm = torch.from_numpy(rgb_norm).unsqueeze(0).float()
427-
rgb_norm = rgb_norm.to(device)
430+
rgb_norm = rgb_norm.to(depthmap_device)
428431

429432
model.unet.eval()
430433
depth_pred_ls = []
@@ -445,7 +448,7 @@ def estimatemarigold(image, model, w, h, device):
445448
tol=tol,
446449
reduction=reduction_method,
447450
max_res=merging_max_res,
448-
device=device,
451+
device=depthmap_device,
449452
)
450453
else:
451454
depth_pred = depth_preds
@@ -942,6 +945,8 @@ def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel
942945
def singleestimate(img, msize, model, net_type):
943946
if net_type == 0:
944947
return estimateleres(img, model, msize, msize)
948+
elif net_type == 10:
949+
return estimatemarigold(img, model, msize, msize)
945950
elif net_type >= 7:
946951
# np to PIL
947952
return estimatezoedepth(Image.fromarray(np.uint8(img * 255)).convert('RGB'), model, msize, msize)

src/misc.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def get_commit_hash():
1515

1616
REPOSITORY_NAME = "stable-diffusion-webui-depthmap-script"
1717
SCRIPT_NAME = "DepthMap"
18-
SCRIPT_VERSION = "v0.4.4"
18+
SCRIPT_VERSION = "v0.4.5"
1919
SCRIPT_FULL_NAME = f"{SCRIPT_NAME} {SCRIPT_VERSION} ({get_commit_hash()})"
2020

2121

0 commit comments

Comments
 (0)