Commit 987ee87

Half mode for Depth Anything and ZoeDepth k and nk
Reduces VRAM usage and may provide speedups. Differences in the result should be imperceptible. Breaks exact reproducibility, but it was never a priority.
Parent commit: 204ea5b
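For context on the VRAM claim in the commit message, here is a minimal sketch in plain PyTorch (a toy layer, not one of the repo's depth models): converting parameters to float16 halves their memory footprint, which is where the savings come from.

import torch

# Toy layer standing in for a depth model; any module follows the same arithmetic.
model = torch.nn.Conv2d(3, 64, 3)

def param_bytes(m):
    return sum(p.numel() * p.element_size() for p in m.parameters())

print(param_bytes(model))          # float32 parameters
print(param_bytes(model.half()))   # float16 parameters - half the bytes

Activations and intermediate buffers shrink similarly during inference, which is why the savings are visible on the GPU and not just in the checkpoint size.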

5 files changed (+13, -5 lines)


dzoedepth/models/depth_model.py

Lines changed: 2 additions & 1 deletion
@@ -137,7 +137,8 @@ def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, out
             with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
             output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy".
         """
-        x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device)
+        # dtype IS ADDED, NOT PRESENT IN THE MAINLINE
+        x = transforms.ToTensor()(pil_img).unsqueeze(0).to(device=self.device, dtype=next(self.parameters()).dtype)
         out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs)
         if output_type == "numpy":
             return out_tensor.squeeze().cpu().numpy()
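The reason for the added dtype argument, sketched with plain PyTorch rather than repo code: transforms.ToTensor() always produces float32, so once the model has been converted with .half() the input must be cast to the parameters' dtype, or the first layer raises a float/half mismatch. The CUDA check below is an assumption, since half mode is normally used on the GPU.

import torch

# Toy stand-in for the pattern used in infer_pil.
model = torch.nn.Conv2d(3, 8, 3)
if torch.cuda.is_available():
    model = model.cuda().half()

x = torch.rand(1, 3, 64, 64)                      # float32, like transforms.ToTensor() output
param = next(model.parameters())
x = x.to(device=param.device, dtype=param.dtype)  # mirrors the added .to(...) call
out = model(x)                                    # runs in either full or half mode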

dzoedepth/models/layers/patch_transformer.py

Lines changed: 2 additions & 1 deletion
@@ -86,6 +86,7 @@ def forward(self, x):
         # change to S,N,E format required by transformer
         embeddings = embeddings.permute(2, 0, 1)
         S, N, E = embeddings.shape
-        embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device)
+        # dtype IS ADDED, NOT PRESENT IN THE MAINLINE
+        embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device).to(dtype=embeddings.dtype)
         x = self.transformer_encoder(embeddings) # .shape = S, N, E
         return x
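Why the extra .to(dtype=...) matters here, shown with toy tensors (not repo code): PyTorch's type promotion would otherwise turn float16 embeddings plus a float32 positional encoding into a float32 result, and the half-precision transformer_encoder downstream would then receive the wrong dtype.

import torch

emb = torch.zeros(4, 1, 8, dtype=torch.float16)      # embeddings in half mode
pe = torch.zeros(4, 1, 8)                             # positional encoding, float32 by default

print((emb + pe).dtype)                       # torch.float32 - silently promoted
print((emb + pe.to(dtype=emb.dtype)).dtype)   # torch.float16 - stays in half precision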

scripts/depthmap_api.py

Lines changed: 3 additions & 0 deletions
@@ -83,6 +83,7 @@ async def process_video(
         raise HTTPException(status_code=422, detail="No images supplied")
     print(f"Processing {str(len(depth_input_images))} images trough the API")

+    # You can use either these strings, or integers
     available_models = {
         'res101': 0,
         'dpt_beit_large_512': 1, #midas 3.1
@@ -94,6 +95,8 @@ async def process_video(
         'zoedepth_n': 7, #indoor
         'zoedepth_k': 8, #outdoor
         'zoedepth_nk': 9,
+        'marigold_v1': 10,
+        'depth_anything': 11
     }

     model_type = options["model_type"]
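The new comment says the API accepts either the string keys or the integer values. The diff does not show how options["model_type"] is resolved, so the following is only a hypothetical sketch of that lookup, using an abbreviated copy of the mapping above; resolve_model_type is not a function from the repo.

# Abbreviated copy of the mapping in depthmap_api.py.
available_models = {
    'res101': 0,
    'zoedepth_nk': 9,
    'marigold_v1': 10,
    'depth_anything': 11,
}

def resolve_model_type(value):
    # Accept a known string key or a known integer index; anything else is an error.
    if isinstance(value, str) and value in available_models:
        return available_models[value]
    if isinstance(value, int) and value in available_models.values():
        return value
    raise ValueError(f"Unknown model_type: {value!r}")

print(resolve_model_type('depth_anything'))  # 11
print(resolve_model_type(11))                # 11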

src/common_constants.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def __init__(self, default_value=None, *args):
     STEREO_DIVERGENCE = 2.5
     STEREO_SEPARATION = 0.0
     STEREO_FILL_ALGO = "polylines_sharp"
-    STEREO_OFFSET_EXPONENT = 2.0
+    STEREO_OFFSET_EXPONENT = 1.0
     STEREO_BALANCE = 0.0

     GEN_NORMALMAP = False

src/depthmap_generation.py

Lines changed: 5 additions & 2 deletions
@@ -246,7 +246,9 @@ def flatten(el):
         if model_type in [0, 1, 2, 3, 4, 5, 6]:
             model = model.to(memory_format=torch.channels_last)  # TODO: weird
         if not self.no_half:
-            if model_type in [1, 2, 3, 4, 5, 6] and not boost:  # TODO: zoedepth, Marigold and depth_anything, too?
+            # Marigold can be done
+            # TODO: Fix for zoedepth_n - it completely trips and generates black images
+            if model_type in [1, 2, 3, 4, 5, 6, 8, 9, 11] and not boost:
                 model = model.half()
         model.to(device)  # to correct device

@@ -484,7 +486,8 @@ def estimatedepthanything(image, model, w, h):
     )

     timage = transform({"image": image})["image"]
-    timage = torch.from_numpy(timage).unsqueeze(0).to(next(model.parameters()).device)
+    timage = torch.from_numpy(timage).unsqueeze(0).to(device=next(model.parameters()).device,
+                                                      dtype=next(model.parameters()).dtype)

     with torch.no_grad():
         depth = model(timage)
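Restating the gating in the first hunk as a standalone snippet (names taken from the diff; the surrounding class and model loaders are not shown here): half conversion is skipped when the user disabled it, when boost is active, and for model types outside the allow-list, notably zoedepth_n (7), which the TODO says produces black images, and marigold_v1 (10), which the comment marks as possible but not yet enabled.

# Model types this commit converts with .half(); indices match depthmap_api.py.
HALF_CAPABLE = {1, 2, 3, 4, 5, 6, 8, 9, 11}

def maybe_half(model, model_type, no_half, boost):
    # Mirrors the condition in depthmap_generation.py: only convert when allowed.
    if not no_half and not boost and model_type in HALF_CAPABLE:
        model = model.half()
    return model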
