From 4b6497c7049cac06eaf21d206d4b7694fe730a8a Mon Sep 17 00:00:00 2001 From: TheMemoryDealer <32904619+TheMemoryDealer@users.noreply.github.com> Date: Thu, 1 Jun 2023 12:16:44 +0100 Subject: [PATCH 1/2] Updates #836 as suggested in https://github.com/pytorch/pytorch/issues/16885#issuecomment-551779897 --- beginner_source/former_torchies/parallelism_tutorial.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/beginner_source/former_torchies/parallelism_tutorial.py b/beginner_source/former_torchies/parallelism_tutorial.py index 18c14c43167..aad0328b48f 100644 --- a/beginner_source/former_torchies/parallelism_tutorial.py +++ b/beginner_source/former_torchies/parallelism_tutorial.py @@ -51,9 +51,12 @@ def forward(self, x): # clashes in their names. For those who still want to access the attributes, # a workaround is to use a subclass of ``DataParallel`` as below. -class MyDataParallel(nn.DataParallel): +class MyDataParallel(DataParallel): def __getattr__(self, name): - return getattr(self.module, name) + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) ######################################################################## # **Primitives on which DataParallel is implemented upon:** From 9a04cebb45299f3669481c8a47196303d0738eea Mon Sep 17 00:00:00 2001 From: TheMemoryDealer <32904619+TheMemoryDealer@users.noreply.github.com> Date: Thu, 1 Jun 2023 13:24:33 +0100 Subject: [PATCH 2/2] Update parallelism_tutorial.py added nn.DataParallel --- beginner_source/former_torchies/parallelism_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/former_torchies/parallelism_tutorial.py b/beginner_source/former_torchies/parallelism_tutorial.py index aad0328b48f..a11d844e1bd 100644 --- a/beginner_source/former_torchies/parallelism_tutorial.py +++ b/beginner_source/former_torchies/parallelism_tutorial.py @@ -51,7 +51,7 @@ def forward(self, x): # clashes in their names. For those who still want to access the attributes, # a workaround is to use a subclass of ``DataParallel`` as below. -class MyDataParallel(DataParallel): +class MyDataParallel(nn.DataParallel): def __getattr__(self, name): try: return super().__getattr__(name)