diff --git a/lua/nvim-tree/lib.lua b/lua/nvim-tree/lib.lua index 66ab0ed5188..9f0a2c45748 100644 --- a/lua/nvim-tree/lib.lua +++ b/lua/nvim-tree/lib.lua @@ -74,11 +74,21 @@ end -- If node is grouped, return the last node in the group. Otherwise, return the given node. function M.get_last_group_node(node) - local next = node - while next.group_next do - next = next.group_next + local next_node = node + while next_node.group_next do + next_node = next_node.group_next end - return next + return next_node +end + +function M.get_all_nodes_in_group(node) + local next_node = utils.get_parent_of_group(node) + local nodes = {} + while next_node do + table.insert(nodes, next_node) + next_node = next_node.group_next + end + return nodes end function M.expand_or_collapse(node) @@ -90,8 +100,10 @@ function M.expand_or_collapse(node) core.get_explorer():expand(node) end - node = M.get_last_group_node(node) - node.open = not node.open + local open = not M.get_last_group_node(node).open + for _, n in ipairs(M.get_all_nodes_in_group(node)) do + n.open = open + end renderer.draw() end