Skip to content

Add Transformer API for in-place manipulation of AST #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions fluent.syntax/fluent/syntax/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ class Visitor(object):
The boolean value of the returned value determines if the visitor
descends into the children of the given AST node.
'''
def visit(self, value):
if isinstance(value, BaseNode):
self.visit_node(value)
if isinstance(value, list):
for node in value:
self.visit(node)

def visit_node(self, node):
def visit(self, node):
if isinstance(node, list):
for child in node:
self.visit(child)
return
if not isinstance(node, BaseNode):
return
nodename = type(node).__name__
visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit)
should_descend = visit(node)
Expand All @@ -33,6 +32,41 @@ def generic_visit(self, node):
return True


class Transformer(Visitor):
    '''In-place AST Transformer pattern.

    Subclass this to create an in-place modified variant
    of the given AST.
    If you need to keep the original AST around, pass
    a `node.clone()` to the transformer.

    Handlers are looked up by the name `visit_<NodeClassName>`; node
    types without a dedicated handler go through `generic_visit`.
    Whatever a handler returns takes the place of the visited child in
    its parent, and returning `None` removes the child.
    '''
    def visit(self, node):
        '''Dispatch `node` to its handler and return the handler's result.

        Values that are not `BaseNode` instances are returned unchanged.
        '''
        if not isinstance(node, BaseNode):
            return node

        nodename = type(node).__name__
        visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit)
        return visit(node)

    def generic_visit(self, node):
        '''Visit every child of `node`, updating `node` in place.

        List-valued attributes keep their list object; its contents are
        replaced by the visited children, dropping those for which
        `visit` returned `None`. Node-valued attributes are reassigned to
        the visited result, or deleted from `node` when it is `None`.
        All other attributes are left untouched.

        Returns `node` itself.
        '''
        # Iterate over a snapshot: `delattr` below shrinks the node's
        # `__dict__`, and on Python 3 iterating the live `items()` view
        # while the dict changes size raises RuntimeError.
        for propname, propvalue in list(vars(node).items()):
            if isinstance(propvalue, list):
                new_vals = []
                for child in propvalue:
                    new_val = self.visit(child)
                    if new_val is not None:
                        new_vals.append(new_val)
                # in-place manipulation
                propvalue[:] = new_vals
            elif isinstance(propvalue, BaseNode):
                new_val = self.visit(propvalue)
                if new_val is None:
                    delattr(node, propname)
                else:
                    setattr(node, propname, new_val)
        return node


def to_json(value, fn=None):
if isinstance(value, BaseNode):
return value.to_json(fn)
Expand Down Expand Up @@ -79,7 +113,9 @@ class BaseNode(object):
"""

def traverse(self, fun):
"""Postorder-traverse this node and apply `fun` to all child nodes.
"""DEPRECATED. Please use Visitor or Transformer.

Postorder-traverse this node and apply `fun` to all child nodes.

Traverse this node depth-first applying `fun` to subnodes and leaves.
Children are processed before parents (postorder traversal).
Expand All @@ -103,6 +139,23 @@ def visit(value):

return fun(node)

def clone(self):
"""Create a deep clone of the current node."""
def visit(value):
"""Clone node and its descendants."""
if isinstance(value, BaseNode):
return value.clone()
if isinstance(value, list):
return [visit(child) for child in value]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember that at some point we added a tuple check here: cf564c3. I think the goal was to support Annotations' args field, which is a tuple.

We should probably fix it in

class ParseError(Exception):
def __init__(self, code, *args):
self.code = code
self.args = args
self.message = get_error_message(code, args)
, but perhaps it's safer to add the check here first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(value, tuple):
return tuple(visit(child) for child in value)
return value

# Use all attributes found on the node as kwargs to the constructor.
return self.__class__(
**{name: visit(value) for name, value in vars(self).items()}
)

def equals(self, other, ignored_fields=['span']):
"""Compare two nodes.

Expand Down
6 changes: 6 additions & 0 deletions fluent.syntax/tests/syntax/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_same_simple_message(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_selector_message(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -41,6 +42,7 @@ def test_same_selector_message(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_complex_placeable_message(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -49,6 +51,7 @@ def test_same_complex_placeable_message(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_message_with_attribute(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -58,6 +61,7 @@ def test_same_message_with_attribute(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_message_with_attributes(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -68,6 +72,7 @@ def test_same_message_with_attributes(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_junk(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -76,6 +81,7 @@ def test_same_junk(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))


class TestOrderEquals(unittest.TestCase):
Expand Down
81 changes: 80 additions & 1 deletion fluent.syntax/tests/syntax/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ def test_resource(self):
)


class TestTransformer(unittest.TestCase):
    def test(self):
        """Transforming edits the resource in place, keeping node identity."""
        # NOTE(review): the literal below is reproduced exactly as it
        # appears in this view; any indentation inside it may have been
        # lost in extraction -- confirm against the repository.
        resource = FluentParser().parse(dedent_ftl('''\
one = Message
two = Messages
three = Has a
.an = Message string in the Attribute
'''))
        second_value = resource.body[1].value
        backup = resource.clone()

        transformer = ReplaceTransformer('Message', 'Term')
        transformed = transformer.visit(resource)

        # The very same objects come back, not copies.
        self.assertIs(transformed, resource)
        self.assertIs(transformed.body[1].value, second_value)
        # The content changed relative to the clone taken beforehand.
        self.assertFalse(transformed.equals(backup))
        first_element = transformed.body[1].value.elements[0]
        self.assertEqual(first_element.value, 'Terms')


class WordCounter(object):
def __init__(self):
self.word_count = 0
Expand All @@ -70,6 +94,34 @@ def visit_TextElement(self, node):
return False


class ReplaceText(object):
    """Callable that swaps `before` for `after` in text element values."""

    def __init__(self, before, after):
        self.before, self.after = before, after

    def __call__(self, node):
        """Perform find and replace on text values only"""
        # Exact type match only; every other node is handed back untouched.
        if type(node) is not ast.TextElement:
            return node
        replaced = node.value.replace(self.before, self.after)
        node.value = replaced
        return node


class ReplaceTransformer(ast.Transformer):
    """Transformer that swaps `before` for `after` in text element values."""

    def __init__(self, before, after):
        self.before, self.after = before, after

    def generic_visit(self, node):
        """Return spans and annotations as-is; defer to the base otherwise."""
        is_skipped = isinstance(node, (ast.Span, ast.Annotation))
        if not is_skipped:
            return super(ReplaceTransformer, self).generic_visit(node)
        return node

    def visit_TextElement(self, node):
        """Perform find and replace on text values only"""
        replaced = node.value.replace(self.before, self.after)
        node.value = replaced
        return node


class TestPerf(unittest.TestCase):
def setUp(self):
parser = FluentParser()
Expand All @@ -89,6 +141,27 @@ def test_visitor(self):
counter.visit(self.resource)
self.assertEqual(counter.word_count, 277)

def test_edit_traverse(self):
    """Edit the resource via `traverse` with a `ReplaceText` callable."""
    replace = ReplaceText('Tab', 'Reiter')
    edited = self.resource.traverse(replace)
    first_element = edited.body[4].attributes[0].value.elements[0]
    self.assertEqual(first_element.value, 'New Reiter')

def test_edit_transform(self):
    """Edit the resource directly with a `ReplaceTransformer`."""
    transformer = ReplaceTransformer('Tab', 'Reiter')
    edited = transformer.visit(self.resource)
    first_element = edited.body[4].attributes[0].value.elements[0]
    self.assertEqual(first_element.value, 'New Reiter')

def test_edit_cloned(self):
    """Edit a clone of the resource with a `ReplaceTransformer`."""
    transformer = ReplaceTransformer('Tab', 'Reiter')
    edited = transformer.visit(self.resource.clone())
    first_element = edited.body[4].attributes[0].value.elements[0]
    self.assertEqual(first_element.value, 'New Reiter')


def gather_stats(method, repeat=10, number=50):
t = timeit.Timer(
Expand All @@ -107,7 +180,13 @@ def gather_stats(method, repeat=10, number=50):


if __name__=='__main__':
for m in ('traverse', 'visitor'):
for m in (
'traverse',
'visitor',
'edit_traverse',
'edit_transform',
'edit_cloned',
):
results = gather_stats(m)
try:
import statistics
Expand Down