Skip to content

Commit c61e49c

Browse files
guyang3532brianjo
andauthored
Recover the attributes of torch in memory_format_tutorial (#1112)
Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent fa7cff7 commit c61e49c

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

intermediate_source/memory_format_tutorial.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,14 +261,17 @@ def check_cl(*args, **kwargs):
261261
return result
262262
return check_cl
263263

264+
old_attrs = dict()
264265

265266
def attribute(m):
267+
old_attrs[m] = dict()
266268
for i in dir(m):
267269
e = getattr(m, i)
268270
exclude_functions = ['is_cuda', 'has_names', 'numel',
269271
'stride', 'Tensor', 'is_contiguous', '__class__']
270272
if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e):
271273
try:
274+
old_attrs[m][i] = e
272275
setattr(m, i, check_wrapper(e))
273276
except Exception as e:
274277
print(i)
@@ -286,6 +289,13 @@ def attribute(m):
286289
# guide https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
287290
#
288291

292+
######################################################################
293+
# Code below is to recover the attributes of torch.
294+
295+
for (m, attrs) in old_attrs.items():
296+
for (k,v) in attrs.items():
297+
setattr(m, k, v)
298+
289299
######################################################################
290300
# Work to do
291301
# ----------

0 commit comments

Comments
 (0)