Skip to content

Commit 002d38c

Browse files
authored
Update ax_multiobjective_nas_tutorial.py to use TensorboardMetric
Ax's TensorboardCurveMetric is going away soon and is replaced by TensorboardMetric
1 parent c3882db commit 002d38c

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

intermediate_source/ax_multiobjective_nas_tutorial.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,21 @@ def trainer(
232232
# we get the logic to read and parse the TensorBoard logs for free.
233233
#
234234

235-
from ax.metrics.tensorboard import TensorboardCurveMetric
235+
from ax.metrics.tensorboard import TensorboardMetric
236+
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
236237

237-
238-
class MyTensorboardMetric(TensorboardCurveMetric):
238+
class MyTensorboardMetric(TensorboardMetric):
239239

240240
# NOTE: We need to tell the new TensorBoard metric how to get the id /
241241
# file handle for the TensorBoard logs from a trial. In this case
242242
# our convention is to just save a separate file per trial in
243243
# the prespecified log dir.
244-
@classmethod
245-
def get_ids_from_trials(cls, trials):
246-
return {
247-
trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()
248-
for trial in trials
249-
}
244+
def _get_event_multiplexer_for_trial(self, trial):
245+
mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
246+
mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
247+
mul.Reload()
248+
249+
return mul
250250

251251
# This indicates whether the metric is queryable while the trial is
252252
# still running. We don't use this in the current tutorial, but Ax
@@ -266,12 +266,12 @@ def is_available_while_running(cls):
266266

267267
val_acc = MyTensorboardMetric(
268268
name="val_acc",
269-
curve_name="val_acc",
269+
tag="val_acc",
270270
lower_is_better=False,
271271
)
272272
model_num_params = MyTensorboardMetric(
273273
name="num_params",
274-
curve_name="num_params",
274+
tag="num_params",
275275
lower_is_better=True,
276276
)
277277

0 commit comments

Comments
 (0)