Skip to content

Commit bca777d

Browse files
committed
improve: early stopping setup
1 parent fe8118a commit bca777d

File tree

1 file changed

+17
-7
lines changed
  • template/{{cookiecutter.repository_name}}/{{cookiecutter.package_name}}

1 file changed

+17
-7
lines changed

template/{{cookiecutter.repository_name}}/{{cookiecutter.package_name}}/train.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -87,21 +87,31 @@ def evaluate_batch(examples):
87 87
)
88 88

89 89
tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
90-
early_stopping = workflow.EarlyStopping(...)
90+
gradient_metrics = metrics.gradient_metrics(tensorboard_logger)
91+
early_stopping = workflow.EarlyStopping(
92+
tensorboard_logger,
93+
lambda summaries: summaries['early_stopping']['accuracy'],
94+
)
91 95

92-
for epoch in tqdm(range(config['max_epochs'])):
93-
for examples in tqdm(gradient_data_loader):
96+
for epoch in range(config['max_epochs']):
97+
for examples in workflow.progress(
98+
gradient_data_loader, gradient_metrics[['loss', 'accuracy']]
99+
):
94 100
output = train_batch(examples)
95-
metrics.gradient_metrics(output, tensorboard_logger)
101+
gradient_metrics.update_(output)
102+
gradient_metrics.log()
96 103
# optional: schedule learning rate
97 104

98 105
for name, data_loader in evaluate_data_loaders:
106+
107+
metrics = metrics.evaluate_metrics(name, tensorboard_logger)
99 108
for examples in tqdm(data_loader):
100 109
output = evaluate_batch(examples)
101-
# TODO: metrics need state?
102-
metrics.evaluate_metrics(output, tensorboard_logger)
110+
metrics.update_(output)
111+
112+
metrics.log()
103 113

104-
improved, out_of_patience = early_stopping.score(output)
114+
improved, out_of_patience = early_stopping.score_(output)
105 115
if improved:
106 116
torch.save(train_state, 'model_checkpoint.pt')
107 117
elif out_of_patience(output):

0 commit comments

Comments
 (0)