From 188f0a3d4349d88d3e631dc68a4fed51c16eaf38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20G=C3=B6tz?= Date: Mon, 29 Jan 2024 22:44:14 +0100 Subject: [PATCH] Fix batch size calculation in dist_tuto Batch size must be an int, not a float. This change fixes it, basically doing the same as in https://github.com/seba-1511/dist_tuto.pth/blob/a552567061a9985cdcfe72ecb9b47e4630d6a7fe/train_dist.py#L85. --- intermediate_source/dist_tuto.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/dist_tuto.rst b/intermediate_source/dist_tuto.rst index 9a0ceb7a4a8..ece84930d2c 100644 --- a/intermediate_source/dist_tuto.rst +++ b/intermediate_source/dist_tuto.rst @@ -327,7 +327,7 @@ the following few lines: transforms.Normalize((0.1307,), (0.3081,)) ])) size = dist.get_world_size() - bsz = 128 / float(size) + bsz = 128 // size partition_sizes = [1.0 / size for _ in range(size)] partition = DataPartitioner(dataset, partition_sizes) partition = partition.use(dist.get_rank())