Skip to content

Commit c8f233c

Browse files
gh-132805: annotationlib: Fix handling of non-constant values in FORWARDREF (#132812)
Co-authored-by: David C Ellis <ducksual@gmail.com>
1 parent 7cb86c5 commit c8f233c

File tree

3 files changed

+250
-43
lines changed

3 files changed

+250
-43
lines changed

Lib/annotationlib.py

Lines changed: 132 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Format(enum.IntEnum):
3838
"__weakref__",
3939
"__arg__",
4040
"__globals__",
41+
"__extra_names__",
4142
"__code__",
4243
"__ast_node__",
4344
"__cell__",
@@ -82,6 +83,7 @@ def __init__(
8283
# is created through __class__ assignment on a _Stringifier object.
8384
self.__globals__ = None
8485
self.__cell__ = None
86+
self.__extra_names__ = None
8587
# These are initially None but serve as a cache and may be set to a non-None
8688
# value later.
8789
self.__code__ = None
@@ -151,6 +153,8 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
151153
if not self.__forward_is_class__ or param_name not in globals:
152154
globals[param_name] = param
153155
locals.pop(param_name, None)
156+
if self.__extra_names__:
157+
locals = {**locals, **self.__extra_names__}
154158

155159
arg = self.__forward_arg__
156160
if arg.isidentifier() and not keyword.iskeyword(arg):
@@ -231,6 +235,10 @@ def __eq__(self, other):
231235
and self.__forward_is_class__ == other.__forward_is_class__
232236
and self.__cell__ == other.__cell__
233237
and self.__owner__ == other.__owner__
238+
and (
239+
(tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) ==
240+
(tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None)
241+
)
234242
)
235243

236244
def __hash__(self):
@@ -241,6 +249,7 @@ def __hash__(self):
241249
self.__forward_is_class__,
242250
self.__cell__,
243251
self.__owner__,
252+
tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None,
244253
))
245254

246255
def __or__(self, other):
@@ -274,6 +283,7 @@ def __init__(
274283
cell=None,
275284
*,
276285
stringifier_dict,
286+
extra_names=None,
277287
):
278288
# Either an AST node or a simple str (for the common case where a ForwardRef
279289
# represent a single name).
@@ -285,49 +295,91 @@ def __init__(
285295
self.__code__ = None
286296
self.__ast_node__ = node
287297
self.__globals__ = globals
298+
self.__extra_names__ = extra_names
288299
self.__cell__ = cell
289300
self.__owner__ = owner
290301
self.__stringifier_dict__ = stringifier_dict
291302

292303
def __convert_to_ast(self, other):
293304
if isinstance(other, _Stringifier):
294305
if isinstance(other.__ast_node__, str):
295-
return ast.Name(id=other.__ast_node__)
296-
return other.__ast_node__
297-
elif isinstance(other, slice):
306+
return ast.Name(id=other.__ast_node__), other.__extra_names__
307+
return other.__ast_node__, other.__extra_names__
308+
elif (
309+
# In STRING format we don't bother with the create_unique_name() dance;
310+
# it's better to emit the repr() of the object instead of an opaque name.
311+
self.__stringifier_dict__.format == Format.STRING
312+
or other is None
313+
or type(other) in (str, int, float, bool, complex)
314+
):
315+
return ast.Constant(value=other), None
316+
elif type(other) is dict:
317+
extra_names = {}
318+
keys = []
319+
values = []
320+
for key, value in other.items():
321+
new_key, new_extra_names = self.__convert_to_ast(key)
322+
if new_extra_names is not None:
323+
extra_names.update(new_extra_names)
324+
keys.append(new_key)
325+
new_value, new_extra_names = self.__convert_to_ast(value)
326+
if new_extra_names is not None:
327+
extra_names.update(new_extra_names)
328+
values.append(new_value)
329+
return ast.Dict(keys, values), extra_names
330+
elif type(other) in (list, tuple, set):
331+
extra_names = {}
332+
elts = []
333+
for elt in other:
334+
new_elt, new_extra_names = self.__convert_to_ast(elt)
335+
if new_extra_names is not None:
336+
extra_names.update(new_extra_names)
337+
elts.append(new_elt)
338+
ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)]
339+
return ast_class(elts), extra_names
340+
else:
341+
name = self.__stringifier_dict__.create_unique_name()
342+
return ast.Name(id=name), {name: other}
343+
344+
def __convert_to_ast_getitem(self, other):
345+
if isinstance(other, slice):
346+
extra_names = {}
347+
348+
def conv(obj):
349+
if obj is None:
350+
return None
351+
new_obj, new_extra_names = self.__convert_to_ast(obj)
352+
if new_extra_names is not None:
353+
extra_names.update(new_extra_names)
354+
return new_obj
355+
298356
return ast.Slice(
299-
lower=(
300-
self.__convert_to_ast(other.start)
301-
if other.start is not None
302-
else None
303-
),
304-
upper=(
305-
self.__convert_to_ast(other.stop)
306-
if other.stop is not None
307-
else None
308-
),
309-
step=(
310-
self.__convert_to_ast(other.step)
311-
if other.step is not None
312-
else None
313-
),
314-
)
357+
lower=conv(other.start),
358+
upper=conv(other.stop),
359+
step=conv(other.step),
360+
), extra_names
315361
else:
316-
return ast.Constant(value=other)
362+
return self.__convert_to_ast(other)
317363

318364
def __get_ast(self):
319365
node = self.__ast_node__
320366
if isinstance(node, str):
321367
return ast.Name(id=node)
322368
return node
323369

324-
def __make_new(self, node):
370+
def __make_new(self, node, extra_names=None):
371+
new_extra_names = {}
372+
if self.__extra_names__ is not None:
373+
new_extra_names.update(self.__extra_names__)
374+
if extra_names is not None:
375+
new_extra_names.update(extra_names)
325376
stringifier = _Stringifier(
326377
node,
327378
self.__globals__,
328379
self.__owner__,
329380
self.__forward_is_class__,
330381
stringifier_dict=self.__stringifier_dict__,
382+
extra_names=new_extra_names or None,
331383
)
332384
self.__stringifier_dict__.stringifiers.append(stringifier)
333385
return stringifier
@@ -343,27 +395,37 @@ def __getitem__(self, other):
343395
if self.__ast_node__ == "__classdict__":
344396
raise KeyError
345397
if isinstance(other, tuple):
346-
elts = [self.__convert_to_ast(elt) for elt in other]
398+
extra_names = {}
399+
elts = []
400+
for elt in other:
401+
new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
402+
if new_extra_names is not None:
403+
extra_names.update(new_extra_names)
404+
elts.append(new_elt)
347405
other = ast.Tuple(elts)
348406
else:
349-
other = self.__convert_to_ast(other)
407+
other, extra_names = self.__convert_to_ast_getitem(other)
350408
assert isinstance(other, ast.AST), repr(other)
351-
return self.__make_new(ast.Subscript(self.__get_ast(), other))
409+
return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names)
352410

353411
def __getattr__(self, attr):
354412
return self.__make_new(ast.Attribute(self.__get_ast(), attr))
355413

356414
def __call__(self, *args, **kwargs):
357-
return self.__make_new(
358-
ast.Call(
359-
self.__get_ast(),
360-
[self.__convert_to_ast(arg) for arg in args],
361-
[
362-
ast.keyword(key, self.__convert_to_ast(value))
363-
for key, value in kwargs.items()
364-
],
365-
)
366-
)
415+
extra_names = {}
416+
ast_args = []
417+
for arg in args:
418+
new_arg, new_extra_names = self.__convert_to_ast(arg)
419+
if new_extra_names is not None:
420+
extra_names.update(new_extra_names)
421+
ast_args.append(new_arg)
422+
ast_kwargs = []
423+
for key, value in kwargs.items():
424+
new_value, new_extra_names = self.__convert_to_ast(value)
425+
if new_extra_names is not None:
426+
extra_names.update(new_extra_names)
427+
ast_kwargs.append(ast.keyword(key, new_value))
428+
return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names)
367429

368430
def __iter__(self):
369431
yield self.__make_new(ast.Starred(self.__get_ast()))
@@ -378,8 +440,9 @@ def __format__(self, format_spec):
378440

379441
def _make_binop(op: ast.AST):
380442
def binop(self, other):
443+
rhs, extra_names = self.__convert_to_ast(other)
381444
return self.__make_new(
382-
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
445+
ast.BinOp(self.__get_ast(), op, rhs), extra_names
383446
)
384447

385448
return binop
@@ -402,8 +465,9 @@ def binop(self, other):
402465

403466
def _make_rbinop(op: ast.AST):
404467
def rbinop(self, other):
468+
new_other, extra_names = self.__convert_to_ast(other)
405469
return self.__make_new(
406-
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
470+
ast.BinOp(new_other, op, self.__get_ast()), extra_names
407471
)
408472

409473
return rbinop
@@ -426,12 +490,14 @@ def rbinop(self, other):
426490

427491
def _make_compare(op):
428492
def compare(self, other):
493+
rhs, extra_names = self.__convert_to_ast(other)
429494
return self.__make_new(
430495
ast.Compare(
431496
left=self.__get_ast(),
432497
ops=[op],
433-
comparators=[self.__convert_to_ast(other)],
434-
)
498+
comparators=[rhs],
499+
),
500+
extra_names,
435501
)
436502

437503
return compare
@@ -459,13 +525,15 @@ def unary_op(self):
459525

460526

461527
class _StringifierDict(dict):
462-
def __init__(self, namespace, globals=None, owner=None, is_class=False):
528+
def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format):
463529
super().__init__(namespace)
464530
self.namespace = namespace
465531
self.globals = globals
466532
self.owner = owner
467533
self.is_class = is_class
468534
self.stringifiers = []
535+
self.next_id = 1
536+
self.format = format
469537

470538
def __missing__(self, key):
471539
fwdref = _Stringifier(
@@ -478,6 +546,11 @@ def __missing__(self, key):
478546
self.stringifiers.append(fwdref)
479547
return fwdref
480548

549+
def create_unique_name(self):
550+
name = f"__annotationlib_name_{self.next_id}__"
551+
self.next_id += 1
552+
return name
553+
481554

482555
def call_evaluate_function(evaluate, format, *, owner=None):
483556
"""Call an evaluate function. Evaluate functions are normally generated for
@@ -521,7 +594,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
521594
# possibly constants if the annotate function uses them directly). We then
522595
# convert each of those into a string to get an approximation of the
523596
# original source.
524-
globals = _StringifierDict({})
597+
globals = _StringifierDict({}, format=format)
525598
if annotate.__closure__:
526599
freevars = annotate.__code__.co_freevars
527600
new_closure = []
@@ -544,9 +617,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
544617
)
545618
annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
546619
if _is_evaluate:
547-
return annos if isinstance(annos, str) else repr(annos)
620+
return _stringify_single(annos)
548621
return {
549-
key: val if isinstance(val, str) else repr(val)
622+
key: _stringify_single(val)
550623
for key, val in annos.items()
551624
}
552625
elif format == Format.FORWARDREF:
@@ -569,7 +642,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
569642
# that returns a bool and an defined set of attributes.
570643
namespace = {**annotate.__builtins__, **annotate.__globals__}
571644
is_class = isinstance(owner, type)
572-
globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
645+
globals = _StringifierDict(
646+
namespace,
647+
globals=annotate.__globals__,
648+
owner=owner,
649+
is_class=is_class,
650+
format=format,
651+
)
573652
if annotate.__closure__:
574653
freevars = annotate.__code__.co_freevars
575654
new_closure = []
@@ -619,6 +698,16 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
619698
raise ValueError(f"Invalid format: {format!r}")
620699

621700

701+
def _stringify_single(anno):
702+
if anno is ...:
703+
return "..."
704+
# We have to handle str specially to support PEP 563 stringified annotations.
705+
elif isinstance(anno, str):
706+
return anno
707+
else:
708+
return repr(anno)
709+
710+
622711
def get_annotate_from_class_namespace(obj):
623712
"""Retrieve the annotate function from a class namespace dictionary.
624713

0 commit comments

Comments
 (0)