Skip to content

Commit e4a47bf

Browse files
authored
Update ax_multiobjective_nas_tutorial.py to use TensorboardMetric (#2985)
Ax's TensorboardCurveMetric is going away soon and is replaced by TensorboardMetric.
1 parent a66464b commit e4a47bf

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

.ci/docker/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ torchx
3131
# TODO: use stable 0.5 when released
3232
-e git+https://github.com/pytorch/rl.git#egg=torchrl
3333
-e git+https://github.com/pytorch/tensordict.git#egg=tensordict
34-
ax-platform
34+
ax-platform>==0.4.0
3535
nbformat>==5.9.2
3636
datasets
3737
transformers
@@ -69,4 +69,4 @@ pygame==2.1.2
6969
pycocotools
7070
semilearn==0.3.2
7171
torchao==0.0.3
72-
segment_anything==1.0
72+
segment_anything==1.0

intermediate_source/ax_multiobjective_nas_tutorial.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,21 @@ def trainer(
232232
# we get the logic to read and parse the TensorBoard logs for free.
233233
#
234234

235-
from ax.metrics.tensorboard import TensorboardCurveMetric
235+
from ax.metrics.tensorboard import TensorboardMetric
236+
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
236237

237-
238-
class MyTensorboardMetric(TensorboardCurveMetric):
238+
class MyTensorboardMetric(TensorboardMetric):
239239

240240
# NOTE: We need to tell the new TensorBoard metric how to get the id /
241241
# file handle for the TensorBoard logs from a trial. In this case
242242
# our convention is to just save a separate file per trial in
243243
# the prespecified log dir.
244-
@classmethod
245-
def get_ids_from_trials(cls, trials):
246-
return {
247-
trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()
248-
for trial in trials
249-
}
244+
def _get_event_multiplexer_for_trial(self, trial):
245+
mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
246+
mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
247+
mul.Reload()
248+
249+
return mul
250250

251251
# This indicates whether the metric is queryable while the trial is
252252
# still running. We don't use this in the current tutorial, but Ax
@@ -266,12 +266,12 @@ def is_available_while_running(cls):
266266

267267
val_acc = MyTensorboardMetric(
268268
name="val_acc",
269-
curve_name="val_acc",
269+
tag="val_acc",
270270
lower_is_better=False,
271271
)
272272
model_num_params = MyTensorboardMetric(
273273
name="num_params",
274-
curve_name="num_params",
274+
tag="num_params",
275275
lower_is_better=True,
276276
)
277277

0 commit comments

Comments
 (0)