@@ -74,6 +74,8 @@ def load_models(self, model_type, device: torch.device, boost: bool):
         model_dir = "./models/midas"
         if model_type == 0:
             model_dir = "./models/leres"
+        if model_type == 10:
+            model_dir = "./models/marigold"
         # create paths to model if not present
         os.makedirs(model_dir, exist_ok=True)
         os.makedirs('./models/pix2pix', exist_ok=True)
@@ -194,6 +196,12 @@ def load_models(self, model_type, device: torch.device, boost: bool):
             conf = get_config("zoedepth_nk", "infer")
             model = build_model(conf)

+        elif model_type == 10:  # Marigold v1
+            # TODO: pass more parameters
+            model_path = f"{model_dir}/marigold_v1/"
+            from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
+            model = MarigoldPipeline.from_pretrained(model_path)
+
         model.eval()  # prepare for evaluation
         # optimize
         if device == torch.device("cuda") and model_type in [0, 1, 2, 3, 4, 5, 6]:
@@ -288,10 +296,12 @@ def get_raw_prediction(self, input, net_width, net_height):
             raw_prediction = estimateleres(img, self.depth_model, net_width, net_height)
         elif self.depth_model_type in [7, 8, 9]:
             raw_prediction = estimatezoedepth(input, self.depth_model, net_width, net_height)
-        else:
+        elif self.depth_model_type in [1, 2, 3, 4, 5, 6]:
             raw_prediction = estimatemidas(img, self.depth_model, net_width, net_height,
                                            self.resize_mode, self.normalization, self.no_half,
                                            self.precision == "autocast")
+        elif self.depth_model_type == 10:
+            raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height, self.device)
         else:
             raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
                                            self.boost_whole_size_threshold)
@@ -395,6 +405,52 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
     return prediction


+def estimatemarigold(image, model, w, h, device):
+    from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
+    from repositories.Marigold.src.util.ensemble import ensemble_depths
+    from repositories.Marigold.src.util.image_util import chw2hwc, colorize_depth_maps, resize_max_res
+    from repositories.Marigold.src.util.seed_all import seed_all
+
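+    # Number of prediction passes, denoising steps per pass, and ensembling parameters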
+    n_repeat = 10
+    denoise_steps = 10
+    regularizer_strength = 0.02
+    max_iter = 5
+    tol = 1e-3
+    reduction_method = "median"
+    merging_max_res = None
+
+    # From Marigold repository run.py
+    with torch.no_grad():
+        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
+        rgb_norm = rgb / 255.0
+        rgb_norm = torch.from_numpy(rgb_norm).unsqueeze(0).float()
+        rgb_norm = rgb_norm.to(device)
+
+        model.unet.eval()
+        depth_pred_ls = []
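+        # Run the pipeline n_repeat times and collect the individual predictions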
+        for i_rep in range(n_repeat):
+            depth_pred_raw = model.forward(
+                rgb_norm, num_inference_steps=denoise_steps, init_depth_latent=None
+            )
+            # clip prediction
+            depth_pred_raw = torch.clip(depth_pred_raw, -1.0, 1.0)
+            depth_pred_ls.append(depth_pred_raw.detach().cpu().numpy().copy())
+
+        depth_preds = np.concatenate(depth_pred_ls, axis=0).squeeze()
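+        # Merge the individual passes into a single depth map; the second value
+        # returned by ensemble_depths (pred_uncert) is not used here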
+        if n_repeat > 1:
+            depth_pred, pred_uncert = ensemble_depths(
+                depth_preds,
+                regularizer_strength=regularizer_strength,
+                max_iter=max_iter,
+                tol=tol,
+                reduction=reduction_method,
+                max_res=merging_max_res,
+                device=device,
+            )
+        else:
+            depth_pred = depth_preds
+        return depth_pred
+
 class ImageandPatchs:
     def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
         self.root_dir = root_dir