From 002d38c7d4747d312447cadf9b8bc2bbfa89b61c Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 30 Jul 2024 17:28:40 -0400 Subject: [PATCH 1/2] Update ax_multiobjective_nas_tutorial.py to use TensorboardMetric Ax's TensorboardCurveMetric is going away soon and is replaced by TensorboardMetric --- .../ax_multiobjective_nas_tutorial.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/intermediate_source/ax_multiobjective_nas_tutorial.py b/intermediate_source/ax_multiobjective_nas_tutorial.py index 79b096b9e64..0f1ae21a556 100644 --- a/intermediate_source/ax_multiobjective_nas_tutorial.py +++ b/intermediate_source/ax_multiobjective_nas_tutorial.py @@ -232,21 +232,21 @@ def trainer( # we get the logic to read and parse the TensorBoard logs for free. # -from ax.metrics.tensorboard import TensorboardCurveMetric +from ax.metrics.tensorboard import TensorboardMetric +from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer - -class MyTensorboardMetric(TensorboardCurveMetric): +class MyTensorboardMetric(TensorboardMetric): # NOTE: We need to tell the new TensorBoard metric how to get the id / # file handle for the TensorBoard logs from a trial. In this case # our convention is to just save a separate file per trial in # the prespecified log dir. - @classmethod - def get_ids_from_trials(cls, trials): - return { - trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix() - for trial in trials - } + def _get_event_multiplexer_for_trial(self, trial): + mul = event_multiplexer.EventMultiplexer(max_reload_threads=20) + mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None) + mul.Reload() + + return mul # This indicates whether the metric is queryable while the trial is # still running. We don't use this in the current tutorial, but Ax @@ -266,12 +266,12 @@ def is_available_while_running(cls): val_acc = MyTensorboardMetric( name="val_acc", - curve_name="val_acc", + tag="val_acc", lower_is_better=False, ) model_num_params = MyTensorboardMetric( name="num_params", - curve_name="num_params", + tag="num_params", lower_is_better=True, ) From 925b1b181711072bab3a4410a706497e904baa6d Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 30 Jul 2024 20:27:14 -0400 Subject: [PATCH 2/2] Update requirements.txt Ensure Ax is a version with TensorboardMetric --- .ci/docker/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index a355966683a..99a4118cdb8 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -30,7 +30,7 @@ pytorch-lightning torchx torchrl==0.3.0 tensordict==0.3.0 -ax-platform +ax-platform>==0.4.0 nbformat>==5.9.2 datasets transformers