Skip to content

Commit 1feb07f

Browse files
authored
Merge branch 'main' into fix-torchrl-deps
2 parents e3c7f14 + f5c28eb commit 1feb07f

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

.ci/docker/requirements.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ pytorch-lightning
3030
torchx
3131
torchrl==0.5.0
3232
tensordict==0.5.0
33-
ax-platform
33+
# TODO: use stable 0.5 when released
34+
-e git+https://github.com/pytorch/rl.git#egg=torchrl
35+
-e git+https://github.com/pytorch/tensordict.git#egg=tensordict
36+
ax-platform>==0.4.0
3437
nbformat>==5.9.2
3538
datasets
3639
transformers
@@ -68,4 +71,4 @@ pygame==2.1.2
6871
pycocotools
6972
semilearn==0.3.2
7073
torchao==0.0.3
71-
segment_anything==1.0
74+
segment_anything==1.0

intermediate_source/ax_multiobjective_nas_tutorial.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,21 @@ def trainer(
232232
# we get the logic to read and parse the TensorBoard logs for free.
233233
#
234234

235-
from ax.metrics.tensorboard import TensorboardCurveMetric
235+
from ax.metrics.tensorboard import TensorboardMetric
236+
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
236237

237-
238-
class MyTensorboardMetric(TensorboardCurveMetric):
238+
class MyTensorboardMetric(TensorboardMetric):
239239

240240
# NOTE: We need to tell the new TensorBoard metric how to get the id /
241241
# file handle for the TensorBoard logs from a trial. In this case
242242
# our convention is to just save a separate file per trial in
243243
# the prespecified log dir.
244-
@classmethod
245-
def get_ids_from_trials(cls, trials):
246-
return {
247-
trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()
248-
for trial in trials
249-
}
244+
def _get_event_multiplexer_for_trial(self, trial):
245+
mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
246+
mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
247+
mul.Reload()
248+
249+
return mul
250250

251251
# This indicates whether the metric is queryable while the trial is
252252
# still running. We don't use this in the current tutorial, but Ax
@@ -266,12 +266,12 @@ def is_available_while_running(cls):
266266

267267
val_acc = MyTensorboardMetric(
268268
name="val_acc",
269-
curve_name="val_acc",
269+
tag="val_acc",
270270
lower_is_better=False,
271271
)
272272
model_num_params = MyTensorboardMetric(
273273
name="num_params",
274-
curve_name="num_params",
274+
tag="num_params",
275275
lower_is_better=True,
276276
)
277277

prototype_source/pt2e_quantizer.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Prerequisites:
88

99
Required:
1010

11-
- `Torchdynamo concepts in PyTorch <https://pytorch.org/docs/stable/dynamo/index.html>`__
11+
- `Torchdynamo concepts in PyTorch <https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html>`__
1212

1313
- `Quantization concepts in PyTorch <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`__
1414

0 commit comments

Comments
 (0)