diff --git a/detection/README.md b/detection/README.md index d4e7da35b5..80661e3873 100644 --- a/detection/README.md +++ b/detection/README.md @@ -88,6 +88,8 @@ python3 luna16_training.py \ -e ./config/environment_luna16_fold${i}.json \ -c ./config/config_train_luna16_16g.json ``` +If you are tuning hyper-parameters, please also add `--verbose` flag. +Details about matched anchors during training will be printed out. For each fold, 95% of the training data is used for training, while the rest 5% is used for validation and model selection. The training and validation curves for 300 epochs of 10 folds are shown below. The upper row shows the training losses for box regression and classification. The bottom row shows the validation mAP and mAR for IoU ranging from 0.1 to 0.5. diff --git a/detection/generate_transforms.py b/detection/generate_transforms.py index 32c5ffd22e..b255532742 100644 --- a/detection/generate_transforms.py +++ b/detection/generate_transforms.py @@ -37,6 +37,7 @@ RandRotateBox90d, RandZoomBoxd, ConvertBoxModed, + StandardizeEmptyBoxd, ) @@ -70,7 +71,6 @@ def generate_detection_train_transform( Return: training transform for detection """ - amp = True if amp: compute_dtype = torch.float16 else: @@ -82,6 +82,7 @@ def generate_detection_train_transform( EnsureChannelFirstd(keys=[image_key]), EnsureTyped(keys=[image_key, box_key], dtype=torch.float32), EnsureTyped(keys=[label_key], dtype=torch.long), + StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key), Orientationd(keys=[image_key], axcodes="RAS"), intensity_transform, EnsureTyped(keys=[image_key], dtype=torch.float16), @@ -216,7 +217,6 @@ def generate_detection_val_transform( Return: validation transform for detection """ - amp = True if amp: compute_dtype = torch.float16 else: @@ -228,6 +228,7 @@ def generate_detection_val_transform( EnsureChannelFirstd(keys=[image_key]), EnsureTyped(keys=[image_key, box_key], dtype=torch.float32), EnsureTyped(keys=[label_key], dtype=torch.long), + StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key), Orientationd(keys=[image_key], axcodes="RAS"), intensity_transform, ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode), @@ -272,7 +273,6 @@ def generate_detection_inference_transform( Return: validation transform for detection """ - amp = True if amp: compute_dtype = torch.float16 else: diff --git a/detection/luna16_training.py b/detection/luna16_training.py index bb11e5ceb4..9897697dac 100644 --- a/detection/luna16_training.py +++ b/detection/luna16_training.py @@ -57,6 +57,13 @@ def main(): default="./config/config_train.json", help="config json file that stores hyper-parameters", ) + parser.add_argument( + "-v", + "--verbose", + default=False, + action="store_true", + help="whether to print verbose detail during training, recommand True when you are not sure about hyper-parameters", + ) args = parser.parse_args() set_determinism(seed=0) @@ -188,7 +195,7 @@ def main(): ) # 3) build detector - detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False).to(device) + detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=args.verbose).to(device) # set training components detector.set_atss_matcher(num_candidates=4, center_in_gt=False)