Skip to content

Commit af64de8

Browse files
committed
[WIP] marigold
1 parent 4887a9c commit af64de8

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

install.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,5 @@ def ensure(module_name, min_version=None):
5252

5353
if platform.system() == 'Darwin':
5454
ensure('pyqt6')
55+
56+
launch.git_clone("https://github.com/prs-eth/Marigold", "repositories/Marigold", "Marigold", "cc78ff3")

src/common_ui.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def main_ui_panel(is_depth_tab):
3737
'dpt_beit_large_384 (midas 3.1)', 'dpt_large_384 (midas 3.0)',
3838
'dpt_hybrid_384 (midas 3.0)',
3939
'midas_v21', 'midas_v21_small',
40-
'zoedepth_n (indoor)', 'zoedepth_k (outdoor)', 'zoedepth_nk'],
40+
'zoedepth_n (indoor)', 'zoedepth_k (outdoor)', 'zoedepth_nk',
41+
'Marigold v1'],
4142
type="index")
4243
with gr.Box() as cur_option_root:
4344
inp -= 'depthmap_gen_row_1', cur_option_root

src/depthmap_generation.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def load_models(self, model_type, device: torch.device, boost: bool):
7474
model_dir = "./models/midas"
7575
if model_type == 0:
7676
model_dir = "./models/leres"
77+
if model_type == 10:
78+
"./models/marigold"
7779
# create paths to model if not present
7880
os.makedirs(model_dir, exist_ok=True)
7981
os.makedirs('./models/pix2pix', exist_ok=True)
@@ -194,6 +196,12 @@ def load_models(self, model_type, device: torch.device, boost: bool):
194196
conf = get_config("zoedepth_nk", "infer")
195197
model = build_model(conf)
196198

199+
elif model_type == 10: # Marigold v1
200+
# TODO: pass more parameters
201+
model_path = f"{model_dir}/marigold_v1/"
202+
from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
203+
model = MarigoldPipeline.from_pretrained(model_path)
204+
197205
model.eval() # prepare for evaluation
198206
# optimize
199207
if device == torch.device("cuda") and model_type in [0, 1, 2, 3, 4, 5, 6]:
@@ -288,10 +296,12 @@ def get_raw_prediction(self, input, net_width, net_height):
288296
raw_prediction = estimateleres(img, self.depth_model, net_width, net_height)
289297
elif self.depth_model_type in [7, 8, 9]:
290298
raw_prediction = estimatezoedepth(input, self.depth_model, net_width, net_height)
291-
else:
299+
elif self.depth_model_type in [1, 2, 3, 4, 5, 6]:
292300
raw_prediction = estimatemidas(img, self.depth_model, net_width, net_height,
293301
self.resize_mode, self.normalization, self.no_half,
294302
self.precision == "autocast")
303+
elif self.depth_model_type == 10:
304+
raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height, self.device)
295305
else:
296306
raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
297307
self.boost_whole_size_threshold)
@@ -395,6 +405,52 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
395405
return prediction
396406

397407

408+
def estimatemarigold(image, model, w, h, device):
    """Estimate a depth map with a Marigold diffusion pipeline.

    Adapted from run.py in the Marigold repository: the image is normalized,
    run through the pipeline `n_repeat` times, and the individual draws are
    merged with Marigold's ensembling helper.

    :param image: HWC uint8-range RGB image as a numpy array
                  (values are divided by 255, so 0..255 is assumed —
                  NOTE(review): confirm against callers).
    :param model: a loaded MarigoldPipeline instance.
    :param w: unused; kept for signature parity with the other estimators.
    :param h: unused; kept for signature parity with the other estimators.
    :param device: torch device the pipeline inputs are moved to.
    :return: a 2D numpy depth prediction.
    """
    # Only the ensembling helper is needed here; the pipeline itself is
    # passed in as `model`. (The original also imported MarigoldPipeline,
    # chw2hwc, colorize_depth_maps, resize_max_res and seed_all — all unused.)
    from repositories.Marigold.src.util.ensemble import ensemble_depths

    n_repeat = 10               # independent diffusion draws to ensemble
    denoise_steps = 10          # denoising steps per draw
    regularizer_strength = 0.02
    max_iter = 5
    tol = 1e-3
    reduction_method = "median"
    merging_max_res = None

    # From Marigold repository run.py
    with torch.no_grad():
        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = rgb / 255.0
        rgb_norm = torch.from_numpy(rgb_norm).unsqueeze(0).float()
        rgb_norm = rgb_norm.to(device)

        model.unet.eval()
        depth_pred_ls = []
        for _ in range(n_repeat):
            depth_pred_raw = model.forward(
                rgb_norm, num_inference_steps=denoise_steps, init_depth_latent=None
            )
            # clip prediction to the valid normalized depth range
            depth_pred_raw = torch.clip(depth_pred_raw, -1.0, 1.0)
            depth_pred_ls.append(depth_pred_raw.detach().cpu().numpy().copy())

        depth_preds = np.concatenate(depth_pred_ls, axis=0).squeeze()
        if n_repeat > 1:
            # Align and merge the individual draws into one depth map; the
            # returned per-pixel uncertainty is not used here.
            depth_pred, _ = ensemble_depths(
                depth_preds,
                regularizer_strength=regularizer_strength,
                max_iter=max_iter,
                tol=tol,
                reduction=reduction_method,
                max_res=merging_max_res,
                device=device,
            )
        else:
            depth_pred = depth_preds
        return depth_pred
453+
398454
class ImageandPatchs:
399455
def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
400456
self.root_dir = root_dir

0 commit comments

Comments
 (0)