diff --git a/intermediate_source/memory_format_tutorial.py b/intermediate_source/memory_format_tutorial.py index 244e23ac204..014f1668504 100644 --- a/intermediate_source/memory_format_tutorial.py +++ b/intermediate_source/memory_format_tutorial.py @@ -261,14 +261,17 @@ def check_cl(*args, **kwargs): return result return check_cl +old_attrs = dict() def attribute(m): + old_attrs[m] = dict() for i in dir(m): e = getattr(m, i) exclude_functions = ['is_cuda', 'has_names', 'numel', 'stride', 'Tensor', 'is_contiguous', '__class__'] if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e): try: + old_attrs[m][i] = e setattr(m, i, check_wrapper(e)) except Exception as e: print(i) @@ -286,6 +289,13 @@ def attribute(m): # guide https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators. # +###################################################################### +# Code below is to recover the attributes of torch. + +for (m, attrs) in old_attrs.items(): + for (k,v) in attrs.items(): + setattr(m, k, v) + ###################################################################### # Work to do # ----------