@@ -74,8 +74,6 @@ def load_models(self, model_type, device: torch.device, boost: bool):
74
74
model_dir = "./models/midas"
75
75
if model_type == 0 :
76
76
model_dir = "./models/leres"
77
- if model_type == 10 :
78
- "./models/marigold"
79
77
# create paths to model if not present
80
78
os .makedirs (model_dir , exist_ok = True )
81
79
os .makedirs ('./models/pix2pix' , exist_ok = True )
@@ -197,9 +195,9 @@ def load_models(self, model_type, device: torch.device, boost: bool):
197
195
model = build_model (conf )
198
196
199
197
elif model_type == 10 : # Marigold v1
200
- # TODO: pass more parameters
201
- model_path = f" { model_dir } /marigold_v1/"
202
- from repositories . Marigold .src .model .marigold_pipeline import MarigoldPipeline
198
+ model_path = "Bingxin/Marigold"
199
+ print ( model_path )
200
+ from Marigold .src .model .marigold_pipeline import MarigoldPipeline
203
201
model = MarigoldPipeline .from_pretrained (model_path )
204
202
205
203
model .eval () # prepare for evaluation
@@ -301,11 +299,11 @@ def get_raw_prediction(self, input, net_width, net_height):
301
299
self .resize_mode , self .normalization , self .no_half ,
302
300
self .precision == "autocast" )
303
301
elif self .depth_model_type == 10 :
304
- raw_prediction = estimatemarigold (img , self .depth_model , net_width , net_height , self . device )
302
+ raw_prediction = estimatemarigold (img , self .depth_model , net_width , net_height )
305
303
else :
306
304
raw_prediction = estimateboost (img , self .depth_model , self .depth_model_type , self .pix2pix_model ,
307
305
self .boost_whole_size_threshold )
308
- raw_prediction_invert = self .depth_model_type in [0 , 7 , 8 , 9 ]
306
+ raw_prediction_invert = self .depth_model_type in [0 , 7 , 8 , 9 , 10 ]
309
307
return raw_prediction , raw_prediction_invert
310
308
311
309
@@ -405,11 +403,11 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
405
403
return prediction
406
404
407
405
408
- def estimatemarigold (image , model , w , h , device ):
409
- from repositories . Marigold .src .model .marigold_pipeline import MarigoldPipeline
410
- from repositories . Marigold .src .util .ensemble import ensemble_depths
411
- from repositories . Marigold .src .util .image_util import chw2hwc , colorize_depth_maps , resize_max_res
412
- from repositories . Marigold .src .util .seed_all import seed_all
406
+ def estimatemarigold (image , model , w , h ):
407
+ from Marigold .src .model .marigold_pipeline import MarigoldPipeline
408
+ from Marigold .src .util .ensemble import ensemble_depths
409
+ from Marigold .src .util .image_util import chw2hwc , colorize_depth_maps , resize_max_res
410
+ from Marigold .src .util .seed_all import seed_all
413
411
414
412
n_repeat = 10
415
413
denoise_steps = 10
@@ -418,13 +416,18 @@ def estimatemarigold(image, model, w, h, device):
418
416
tol = 1e-3
419
417
reduction_method = "median"
420
418
merging_max_res = None
419
+ resize_to_max_res = None
421
420
422
421
# From Marigold repository run.py
423
422
with torch .no_grad ():
424
- rgb = np .transpose (image , (2 , 0 , 1 )) # [H, W, rgb] -> [rgb, H, W]
425
- rgb_norm = rgb / 255.0
423
+ if resize_to_max_res is not None :
424
+ image = (image * 255 ).astype (np .uint8 )
425
+ image = np .asarray (resize_max_res (
426
+ Image .fromarray (image ), max_edge_resolution = resize_to_max_res
427
+ )) / 255.0
428
+ rgb_norm = np .transpose (image , (2 , 0 , 1 )) # [H, W, rgb] -> [rgb, H, W]
426
429
rgb_norm = torch .from_numpy (rgb_norm ).unsqueeze (0 ).float ()
427
- rgb_norm = rgb_norm .to (device )
430
+ rgb_norm = rgb_norm .to (depthmap_device )
428
431
429
432
model .unet .eval ()
430
433
depth_pred_ls = []
@@ -445,7 +448,7 @@ def estimatemarigold(image, model, w, h, device):
445
448
tol = tol ,
446
449
reduction = reduction_method ,
447
450
max_res = merging_max_res ,
448
- device = device ,
451
+ device = depthmap_device ,
449
452
)
450
453
else :
451
454
depth_pred = depth_preds
@@ -942,6 +945,8 @@ def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel
942
945
def singleestimate (img , msize , model , net_type ):
943
946
if net_type == 0 :
944
947
return estimateleres (img , model , msize , msize )
948
+ elif net_type == 10 :
949
+ return estimatemarigold (img , model , msize , msize )
945
950
elif net_type >= 7 :
946
951
# np to PIL
947
952
return estimatezoedepth (Image .fromarray (np .uint8 (img * 255 )).convert ('RGB' ), model , msize , msize )
0 commit comments