Skip to content

Commit bb9ccb7

Browse files
authored
Add transform to handle training data with empty ground truth boxes (#1256)
Fixes #1197 . ### Description Add transform to handle training data with empty ground truth boxes Add verbose choice for training script ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [x] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` ---------
1 parent 2ecaef9 commit bb9ccb7

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

detection/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ python3 luna16_training.py \
8888
-e ./config/environment_luna16_fold${i}.json \
8989
-c ./config/config_train_luna16_16g.json
9090
```
91+
If you are tuning hyper-parameters, please also add `--verbose` flag.
92+
Details about matched anchors during training will be printed out.
9193

9294
For each fold, 95% of the training data is used for training, while the rest 5% is used for validation and model selection.
9395
The training and validation curves for 300 epochs of 10 folds are shown below. The upper row shows the training losses for box regression and classification. The bottom row shows the validation mAP and mAR for IoU ranging from 0.1 to 0.5.

detection/generate_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
RandRotateBox90d,
3838
RandZoomBoxd,
3939
ConvertBoxModed,
40+
StandardizeEmptyBoxd,
4041
)
4142

4243

@@ -70,7 +71,6 @@ def generate_detection_train_transform(
7071
Return:
7172
training transform for detection
7273
"""
73-
amp = True
7474
if amp:
7575
compute_dtype = torch.float16
7676
else:
@@ -82,6 +82,7 @@ def generate_detection_train_transform(
8282
EnsureChannelFirstd(keys=[image_key]),
8383
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
8484
EnsureTyped(keys=[label_key], dtype=torch.long),
85+
StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
8586
Orientationd(keys=[image_key], axcodes="RAS"),
8687
intensity_transform,
8788
EnsureTyped(keys=[image_key], dtype=torch.float16),
@@ -216,7 +217,6 @@ def generate_detection_val_transform(
216217
Return:
217218
validation transform for detection
218219
"""
219-
amp = True
220220
if amp:
221221
compute_dtype = torch.float16
222222
else:
@@ -228,6 +228,7 @@ def generate_detection_val_transform(
228228
EnsureChannelFirstd(keys=[image_key]),
229229
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
230230
EnsureTyped(keys=[label_key], dtype=torch.long),
231+
StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
231232
Orientationd(keys=[image_key], axcodes="RAS"),
232233
intensity_transform,
233234
ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode),
@@ -272,7 +273,6 @@ def generate_detection_inference_transform(
272273
Return:
273274
validation transform for detection
274275
"""
275-
amp = True
276276
if amp:
277277
compute_dtype = torch.float16
278278
else:

detection/luna16_training.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def main():
5757
default="./config/config_train.json",
5858
help="config json file that stores hyper-parameters",
5959
)
60+
parser.add_argument(
61+
"-v",
62+
"--verbose",
63+
default=False,
64+
action="store_true",
65+
help="whether to print verbose detail during training, recommand True when you are not sure about hyper-parameters",
66+
)
6067
args = parser.parse_args()
6168

6269
set_determinism(seed=0)
@@ -188,7 +195,7 @@ def main():
188195
)
189196

190197
# 3) build detector
191-
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False).to(device)
198+
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=args.verbose).to(device)
192199

193200
# set training components
194201
detector.set_atss_matcher(num_candidates=4, center_in_gt=False)

0 commit comments

Comments
 (0)