diff --git a/intermediate_source/dist_tuto.rst b/intermediate_source/dist_tuto.rst index 2e1bba33e23..35f6341395f 100644 --- a/intermediate_source/dist_tuto.rst +++ b/intermediate_source/dist_tuto.rst @@ -327,7 +327,7 @@ the following few lines: transforms.Normalize((0.1307,), (0.3081,)) ])) size = dist.get_world_size() - bsz = 128 / float(size) + bsz = 128 // size partition_sizes = [1.0 / size for _ in range(size)] partition = DataPartitioner(dataset, partition_sizes) partition = partition.use(dist.get_rank())