@@ -87,21 +87,31 @@ def evaluate_batch(examples):
 )
 
 tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
-early_stopping = workflow.EarlyStopping(...)
+gradient_metrics = metrics.gradient_metrics(tensorboard_logger)
+early_stopping = workflow.EarlyStopping(
+    tensorboard_logger,
+    lambda summaries: summaries['early_stopping']['accuracy'],
+)
 
-for epoch in tqdm(range(config['max_epochs'])):
-    for examples in tqdm(gradient_data_loader):
+for epoch in range(config['max_epochs']):
+    for examples in workflow.progress(
+        gradient_data_loader, gradient_metrics[['loss', 'accuracy']]
+    ):
         output = train_batch(examples)
-        metrics.gradient_metrics(output, tensorboard_logger)
+        gradient_metrics.update_(output)
+        gradient_metrics.log()
 
         # optional: schedule learning rate
 
     for name, data_loader in evaluate_data_loaders:
+
+        metrics = metrics.evaluate_metrics(name, tensorboard_logger)
         for examples in tqdm(data_loader):
             output = evaluate_batch(examples)
-            # TODO: metrics need state?
-            metrics.evaluate_metrics(output, tensorboard_logger)
+            metrics.update_(output)
+
+        metrics.log()
 
-        improved, out_of_patience = early_stopping.score(output)
+        improved, out_of_patience = early_stopping.score_(output)
         if improved:
             torch.save(train_state, 'model_checkpoint.pt')
         elif out_of_patience(output):
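
For reference, one way the stateful metrics objects used above could behave is sketched below. This is a minimal sketch, not the library's actual implementation: the ScalarMetrics class and the assumption that each batch output is a dict of scalar values (or 0-dim tensors) are hypothetical, and the real metrics.gradient_metrics / metrics.evaluate_metrics helpers, as well as workflow.progress and workflow.EarlyStopping, may differ.

import torch
from torch.utils.tensorboard import SummaryWriter


class ScalarMetrics:
    '''Hypothetical stateful metrics object with update_ / log methods.
    Assumes each batch output is a dict mapping metric names to scalars.'''

    def __init__(self, tag, tensorboard_logger: SummaryWriter):
        self.tag = tag
        self.tensorboard_logger = tensorboard_logger
        self.sums = {}
        self.n_batches = 0
        self.step = 0

    def update_(self, output):
        # Accumulate running sums for every scalar in the batch output.
        for name, value in output.items():
            if torch.is_tensor(value):
                value = value.detach().item()
            self.sums[name] = self.sums.get(name, 0.0) + float(value)
        self.n_batches += 1
        return self

    def log(self):
        # Write batch-averaged metrics to TensorBoard, then reset the state.
        for name, total in self.sums.items():
            self.tensorboard_logger.add_scalar(
                f'{self.tag}/{name}', total / max(self.n_batches, 1), self.step
            )
        self.step += 1
        self.sums, self.n_batches = {}, 0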