Skip to content

Commit 9287a0c

Browse files
authored
[Release/2.1][JIT] Fix typed enum handling in 3.11 (#109807)
In Python-3.11+ typed enums (such as `enum.IntEnum`) retain `__new__`,`__str__` and so on method of the base class via `__init__subclass__()` method (see https://docs.python.org/3/whatsnew/3.11.html#enum ), i.e. following code ```python import sys import inspect from enum import Enum class IntColor(int, Enum): RED = 1 GREEN = 2 class Color(Enum): RED = 1 GREEN = 2 def get_methods(cls): def predicate(m): if not inspect.isfunction(m) and not inspect.ismethod(m): return False return m.__name__ in cls.__dict__ return inspect.getmembers(cls, predicate=predicate) if __name__ == "__main__": print(sys.version) print(f"IntColor methods {get_methods(IntColor)}") print(f"Color methods {get_methods(Color)}") ``` Returns empty list for both cases for older Python, but on Python-3.11+ it returns list contains of enum constructors and others: ```shell % conda run -n py310 python bar.py 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:41:52) [Clang 15.0.7 ] IntColor methods [] Color methods [] % conda run -n py311 python bar.py 3.11.0 | packaged by conda-forge | (main, Oct 25 2022, 06:21:25) [Clang 14.0.4 ] IntColor methods [('__format__', <function Enum.__format__ at 0x105006ac0>), ('__new__', <function Enum.__new__ at 0x105006660>), ('__repr__', <function Enum.__repr__ at 0x1050068e0>)] Color methods [] ``` This change allows typed enums to be scriptable on 3.11, by explicitly marking several `enum.Enum` method to be dropped by jit script and adds test that typed enums are jit-scriptable. Fixes #108933 Cherry-pick of #109717 into release/2.1 branch. Approved by: https://github.com/atalman, https://github.com/davidberard98 (cherry picked from commit 55685d5)
1 parent c464075 commit 9287a0c

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

test/jit/test_enum.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,13 @@ class Color(Enum):
362362
GREEN = 2
363363

364364
torch.jit.script(Color)
365+
366+
# Regression test for https://github.com/pytorch/pytorch/issues/108933
367+
def test_typed_enum(self):
368+
class Color(int, Enum):
369+
RED = 1
370+
GREEN = 2
371+
372+
@torch.jit.script
373+
def is_red(x: Color) -> bool:
374+
return x == Color.RED

torch/_jit_internal.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def createResolutionCallbackForClassMethods(cls):
450450
# Skip built-ins, as they do not have global scope nor type hints
451451
# Needed to support `enum.Enum` derived classes in Python-3.11
452452
# That adds `_new_member_` property which is an alias to `__new__`
453-
fns = [fn for fn in fns if not inspect.isbuiltin(fn)]
453+
fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
454454
captures = {}
455455

456456
for fn in fns:
@@ -1491,3 +1491,13 @@ def _extract_tensors(obj):
14911491
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
14921492
extractor.dump(obj)
14931493
return tensors
1494+
1495+
1496+
# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
1497+
# that were previously dropped. To preserve the behavior, explicitly drop them there
1498+
1499+
if sys.version_info > (3, 10):
1500+
_drop(enum.Enum.__new__)
1501+
_drop(enum.Enum.__format__)
1502+
_drop(enum.Enum.__repr__)
1503+
_drop(enum.Enum.__str__)

0 commit comments

Comments
 (0)