@@ -232,21 +232,21 @@ def trainer(
232
232
# we get the logic to read and parse the TensorBoard logs for free.
233
233
#
234
234
235
- from ax.metrics.tensorboard import TensorboardCurveMetric
235
+ from ax.metrics.tensorboard import TensorboardMetric
236
+ from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
236
237
237
-
238
- class MyTensorboardMetric(TensorboardCurveMetric):
238
+ class MyTensorboardMetric(TensorboardMetric):
239
239
240
240
# NOTE: We need to tell the new TensorBoard metric how to get the id /
241
241
# file handle for the TensorBoard logs from a trial. In this case
242
242
# our convention is to just save a separate file per trial in
243
243
# the prespecified log dir.
244
-    @classmethod
245
-    def get_ids_from_trials(cls, trials):
246
-        return {
247
-            trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()
248
-            for trial in trials
249
-        }
244
+    def _get_event_multiplexer_for_trial(self, trial):
245
+        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
246
+        mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
247
+        mul.Reload()
248
+
249
+        return mul
250
250
251
251
# This indicates whether the metric is queryable while the trial is
252
252
# still running. We don't use this in the current tutorial, but Ax
@@ -266,12 +266,12 @@ def is_available_while_running(cls):
266
266
267
267
val_acc = MyTensorboardMetric(
268
268
    name="val_acc",
269
-    curve_name="val_acc",
269
+    tag="val_acc",
270
270
    lower_is_better=False,
271
271
)
272
272
model_num_params = MyTensorboardMetric(
273
273
    name="num_params",
274
-    curve_name="num_params",
274
+    tag="num_params",
275
275
    lower_is_better=True,
276
276
)
277
277
0 commit comments