Add transform to handle training data with empty ground truth boxes #1256


Merged 12 commits on Mar 25, 2023
2 changes: 2 additions & 0 deletions detection/README.md
@@ -88,6 +88,8 @@ python3 luna16_training.py \
-e ./config/environment_luna16_fold${i}.json \
-c ./config/config_train_luna16_16g.json
```
If you are tuning hyper-parameters, please also add the `--verbose` flag.
Details about the matched anchors during training will be printed out.

For each fold, 95% of the training data is used for training, while the remaining 5% is used for validation and model selection.
The training and validation curves for 300 epochs of 10 folds are shown below. The upper row shows the training losses for box regression and classification. The bottom row shows the validation mAP and mAR for IoU ranging from 0.1 to 0.5.
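As a hedged illustration of the 95%/5% per-fold split described above (a minimal sketch with hypothetical case IDs, not the repository's actual data-loading code):

```python
# Minimal sketch of a 95% / 5% train/validation split per fold.
# The case IDs are hypothetical; the repo reads real LUNA16 case lists.
import random

cases = [f"case_{i:03d}" for i in range(100)]  # hypothetical case list
random.seed(0)  # deterministic split, in the spirit of set_determinism(seed=0)
random.shuffle(cases)

n_val = max(1, round(0.05 * len(cases)))  # hold out 5% for validation
val_cases, train_cases = cases[:n_val], cases[n_val:]
print(len(train_cases), len(val_cases))  # 95 5
```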
6 changes: 3 additions & 3 deletions detection/generate_transforms.py
@@ -37,6 +37,7 @@
RandRotateBox90d,
RandZoomBoxd,
ConvertBoxModed,
StandardizeEmptyBoxd,
)


@@ -70,7 +71,6 @@ def generate_detection_train_transform(
Return:
training transform for detection
"""
amp = True
if amp:
compute_dtype = torch.float16
else:
@@ -82,6 +82,7 @@
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
EnsureTyped(keys=[label_key], dtype=torch.long),
StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
Orientationd(keys=[image_key], axcodes="RAS"),
intensity_transform,
EnsureTyped(keys=[image_key], dtype=torch.float16),
@@ -216,7 +217,6 @@ def generate_detection_val_transform(
Return:
validation transform for detection
"""
amp = True
if amp:
compute_dtype = torch.float16
else:
@@ -228,6 +228,7 @@
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
EnsureTyped(keys=[label_key], dtype=torch.long),
StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
Orientationd(keys=[image_key], axcodes="RAS"),
intensity_transform,
ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode),
@@ -272,7 +273,6 @@ def generate_detection_inference_transform(
Return:
validation transform for detection
"""
amp = True
if amp:
compute_dtype = torch.float16
else:
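The `StandardizeEmptyBoxd` transform inserted above is what lets these pipelines accept images with no ground-truth boxes: it gives an empty box array a well-defined shape instead of an ambiguous zero-size tensor. As a simplified numpy stand-in (an illustration of the idea, not MONAI's implementation), the standardization looks like:

```python
import numpy as np

def standardize_empty_boxes(boxes, spatial_dims=3):
    """Simplified stand-in for MONAI's StandardizeEmptyBoxd: give empty
    ground-truth boxes a consistent shape of (0, 2 * spatial_dims) so
    downstream box transforms and the detector handle them uniformly."""
    boxes = np.asarray(boxes, dtype=np.float32)
    if boxes.size == 0:
        return boxes.reshape(0, 2 * spatial_dims)
    return boxes

empty = standardize_empty_boxes([])  # an image with no lesions
print(empty.shape)  # (0, 6) for 3-D boxes
```

With this shape guarantee, later steps such as `ConvertBoxToStandardModed` can operate on an empty array without special-casing.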
9 changes: 8 additions & 1 deletion detection/luna16_training.py
@@ -57,6 +57,13 @@ def main():
default="./config/config_train.json",
help="config json file that stores hyper-parameters",
)
parser.add_argument(
"-v",
"--verbose",
default=False,
action="store_true",
help="whether to print verbose details during training; recommended when you are not sure about hyper-parameters",
)
args = parser.parse_args()

set_determinism(seed=0)
@@ -188,7 +195,7 @@ def main():
)

# 3) build detector
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False).to(device)
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=args.verbose).to(device)

# set training components
detector.set_atss_matcher(num_candidates=4, center_in_gt=False)
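The change above wires the new `--verbose` flag into the detector's `debug` mode. A minimal, self-contained sketch of that argparse pattern (the `RetinaNetDetector` construction is elided and shown only as a comment):

```python
import argparse

# Mirror of the flag added in this PR: store_true means the flag is
# False by default and flips to True when --verbose is passed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-v",
    "--verbose",
    default=False,
    action="store_true",
    help="whether to print verbose details during training",
)
args = parser.parse_args(["--verbose"])
print(args.verbose)  # True
# detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=args.verbose)
```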