diff --git a/auto3dseg/tasks/hecktor22/hecktor_crop_neck_region.py b/auto3dseg/tasks/hecktor22/hecktor_crop_neck_region.py index b0405805bb..3adf6e1d07 100644 --- a/auto3dseg/tasks/hecktor22/hecktor_crop_neck_region.py +++ b/auto3dseg/tasks/hecktor22/hecktor_crop_neck_region.py @@ -50,7 +50,9 @@ def __call__(self, data): box_start, box_end = self.extract_roi(im_pet=im_pet, box_size=box_size) - if "label" in d and "label" in self.keys: + use_label = "label" in d and "label" in self.keys and (d["image"].shape[1:] == d["label"].shape[1:]) + + if use_label: # if label mask is available, let's check if the cropped region includes all foreground before_sum = d["label"].sum().item() after_sum = ( @@ -83,7 +85,8 @@ def __call__(self, data): d[self.end_coord_key] = box_end for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end}) + if key == "label" and not use_label: + continue d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d