From ed086e77985672f224be37a84b13fe10b0e77ef3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 12 Dec 2024 23:26:00 -0800 Subject: [PATCH 01/11] feat: add preallocation utilities for expression --- src/Expression.jl | 8 +++++++ src/ParametricExpression.jl | 29 ++++++++++++++++++++++-- src/base.jl | 45 +++++++++++++++++++++++++++++-------- 3 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/Expression.jl b/src/Expression.jl index be927269..a5d512a0 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -317,6 +317,14 @@ function extract_gradient( return extract_gradient(gradient.tree, get_tree(ex)) end +function preallocate_expression(prototype::Expression, n::Union{Nothing,Integer}=nothing) + return (; tree=preallocate_expression(DE.get_contents(prototype), n)) +end +function DE.copy_node!(dest::NamedTuple, src::Expression) + tree = DE.copy_node!(dest.tree, DE.get_contents(src)) + return DE.with_contents(src, tree) +end + """ string_tree( ex::AbstractExpression, diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 60f5fb41..b2be8440 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -4,7 +4,8 @@ using DispatchDoctor: @stable, @unstable using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum -using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce +using ..NodeModule: + AbstractExpressionNode, Node, tree_mapreduce, with_contents, with_metadata using ..ExpressionModule: AbstractExpression, Metadata using ..ChainRulesModule: NodeTangent @@ -17,7 +18,9 @@ import ..NodeModule: leaf_convert, leaf_hash, leaf_equal, - branch_copy! + branch_copy!, + copy_node!, + preallocate_expression import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, @@ -444,6 +447,28 @@ end return node_type(; val=ex) end end +function preallocate_expression( + prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing +) + return (; + tree=preallocate_expression(get_contents(prototype), n), + parameters=similar(get_metadata(prototype).parameters), + ) +end +function copy_node!(dest::NamedTuple, src::ParametricExpression) + new_tree = copy_node!(dest.tree, get_contents(src)) + metadata = get_metadata(src) + new_parameters = dest.parameters + new_parameters .= metadata.parameters + new_metadata = Metadata((; + operators=metadata.operators, + variable_names=metadata.variable_names, + parameters=new_parameters, + parameter_names=metadata.parameter_names, + )) + # TODO: Better interface for this^ + return with_metadata(with_contents(src, new_tree), new_metadata) +end ############################################################################### end diff --git a/src/base.jl b/src/base.jl index 8a656ab3..28d54b93 100644 --- a/src/base.jl +++ b/src/base.jl @@ -488,23 +488,25 @@ end # In-place versions """ - copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode} + copy_node!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode} Copy a node, recursively copying all children nodes, in-place to an array of pre-allocated nodes. This should result in no extra allocations. """ function copy_node!( - dest::AbstractArray{N}, - src::N; - break_sharing::Val{BS}=Val(false), - ref::Base.RefValue{<:Integer}=Ref(0), -) where {BS,N<:AbstractExpressionNode} - ref.x = 0 + dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing +) where {N<:AbstractExpressionNode} + _ref = if ref === nothing + Ref(0) + else + ref.x = 0 + ref + end return tree_mapreduce( - leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf), + leaf -> leaf_copy!(@inbounds(dest[_ref.x += 1]), leaf), identity, ((p, c::Vararg{Any,M}) where {M}) -> - branch_copy!(@inbounds(dest[ref.x += 1]), p, c...), + branch_copy!(@inbounds(dest[_ref.x += 1]), p, c...), src, N; break_sharing=Val(BS), @@ -533,6 +535,31 @@ function branch_copy!( return dest end +""" + preallocate_expression(prototype::AbstractExpressionNode, n=nothing) + +Preallocate an array of empty nodes matching the type of `prototype`. If `n` is provided, use that length, otherwise use `length(prototype)`. + +A given return value of this will be passed to `copy_node!` as the first argument, +so it should be compatible. +""" +function preallocate_expression( + prototype::N, n::Union{Nothing,Integer}=nothing +) where {T,N<:AbstractExpressionNode{T}} + num_nodes = @something(n, length(prototype)) + return N[with_type_parameters(N, T)() for _ in 1:num_nodes] +end + +function copy_node!(::Nothing, src::AbstractExpression) + return copy(src) +end +function preallocate_expression(::AbstractExpression, ::Union{Nothing,Integer}=nothing) + return nothing +end +# We don't require users to overload this, as it's not part of the required interface. +# Also, there's no way to generally do this from the required interface, so for backwards +# compatibility, we just return nothing. + """ copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false)) From 1f97a0b08ab354d30437d77dd021a166071b38b4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:07:47 -0800 Subject: [PATCH 02/11] feat: rename to `allocate_container` and `copy_into!` --- src/DynamicExpressions.jl | 3 +- src/Expression.jl | 25 +++++++++---- src/Interfaces.jl | 20 +++++----- src/NodePreallocation.jl | 68 +++++++++++++++++++++++++++++++++ src/ParametricExpression.jl | 15 ++++---- src/base.jl | 75 ------------------------------------- test/test_copy_inplace.jl | 12 +++--- 7 files changed, 110 insertions(+), 108 deletions(-) create mode 100644 src/NodePreallocation.jl diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 8e32d899..6c0ba5f8 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -9,6 +9,7 @@ using DispatchDoctor: @stable, @unstable include("OperatorEnum.jl") include("Node.jl") include("NodeUtils.jl") + include("NodePreallocation.jl") include("Strings.jl") include("Evaluate.jl") include("EvaluateDerivative.jl") @@ -41,11 +42,11 @@ import .ValueInterfaceModule: GraphNode, Node, copy_node, - copy_node!, set_node!, tree_mapreduce, filter_map, filter_map! +import .NodePreallocationModule: allocate_container, copy_into! import .NodeModule: constructorof, with_type_parameters, diff --git a/src/Expression.jl b/src/Expression.jl index a5d512a0..32cd2bce 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -317,14 +317,6 @@ function extract_gradient( return extract_gradient(gradient.tree, get_tree(ex)) end -function preallocate_expression(prototype::Expression, n::Union{Nothing,Integer}=nothing) - return (; tree=preallocate_expression(DE.get_contents(prototype), n)) -end -function DE.copy_node!(dest::NamedTuple, src::Expression) - tree = DE.copy_node!(dest.tree, DE.get_contents(src)) - return DE.with_contents(src, tree) -end - """ string_tree( ex::AbstractExpression, @@ -510,4 +502,21 @@ function (ex::AbstractExpression)( return get_tree(ex)(X, get_operators(ex, operators); kws...) end +# We don't require users to overload this, as it's not part of the required interface. +# Also, there's no way to generally do this from the required interface, so for backwards +# compatibility, we just return nothing. +function copy_into!(::Nothing, src::AbstractExpression) + return copy(src) +end +function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing) + return nothing +end +function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing) + return (; tree=allocate_container(get_contents(prototype), n)) +end +function copy_into!(dest::NamedTuple, src::Expression) + tree = copy_into!(dest.tree, get_contents(src)) + return with_contents(src, tree) +end + end diff --git a/src/Interfaces.jl b/src/Interfaces.jl index b950ec97..52031031 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -13,12 +13,12 @@ using ..NodeModule: default_allocator, with_type_parameters, leaf_copy, - leaf_copy!, + leaf_copy_into!, leaf_convert, leaf_hash, leaf_equal, branch_copy, - branch_copy!, + branch_copy_into!, branch_convert, branch_hash, branch_equal, @@ -264,10 +264,10 @@ function _check_leaf_copy(tree::AbstractExpressionNode) tree.degree != 0 && return true return leaf_copy(tree) isa typeof(tree) end -function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T} +function _check_leaf_copy_into!(tree::AbstractExpressionNode{T}) where {T} tree.degree != 0 && return true new_leaf = constructorof(typeof(tree))(; val=zero(T)) - ret = leaf_copy!(new_leaf, tree) + ret = leaf_copy_into!(new_leaf, tree) return new_leaf == tree && ret === new_leaf end function _check_leaf_convert(tree::AbstractExpressionNode) @@ -292,16 +292,16 @@ function _check_branch_copy(tree::AbstractExpressionNode) return branch_copy(tree, tree.l, tree.r) isa typeof(tree) end end -function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T} +function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T} if tree.degree == 0 return true end new_branch = constructorof(typeof(tree))(; val=zero(T)) if tree.degree == 1 - ret = branch_copy!(new_branch, tree, copy(tree.l)) + ret = branch_copy_into!(new_branch, tree, copy(tree.l)) return new_branch == tree && ret === new_branch else - ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r)) + ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r)) return new_branch == tree && ret === new_branch end end @@ -373,12 +373,12 @@ ni_components = ( ), optional = ( leaf_copy = "copies a leaf node" => _check_leaf_copy, - leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!, + leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!, leaf_convert = "converts a leaf node" => _check_leaf_convert, leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash, leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal, branch_copy = "copies a branch node" => _check_branch_copy, - branch_copy! = "copies a branch node in-place" => _check_branch_copy!, + branch_copy_into! = "copies a branch node in-place" => _check_branch_copy_into!, branch_convert = "converts a branch node" => _check_branch_convert, branch_hash = "computes the hash of a branch node" => _check_branch_hash, branch_equal = "checks equality of two branch nodes" => _check_branch_equal, @@ -419,7 +419,7 @@ ni_description = ( [Arguments()] ) @implements( - NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))}, + NodeInterface{all_ni_methods_except((:leaf_copy_into!, :branch_copy_into!))}, GraphNode, [Arguments()] ) diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl new file mode 100644 index 00000000..6fb1d012 --- /dev/null +++ b/src/NodePreallocation.jl @@ -0,0 +1,68 @@ +module NodePreallocationModule + +using ..NodeModule: + AbstractExpressionNode, + with_type_parameters, + tree_mapreduce, + leaf_copy, + branch_copy, + set_node! + +""" + copy_into!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode} + +Copy a node, recursively copying all children nodes, in-place to an +array of pre-allocated nodes. This should result in no extra allocations. +""" +function copy_into!( + dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing +) where {N<:AbstractExpressionNode} + _ref = if ref === nothing + Ref(0) + else + ref.x = 0 + ref + end + return tree_mapreduce( + leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf), + identity, + ((p, c::Vararg{Any,M}) where {M}) -> + branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...), + src, + N; + break_sharing=Val(BS), + ) +end +function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode} + set_node!(dest, src) + return dest +end +function branch_copy_into!( + dest::N, src::N, children::Vararg{N,M} +) where {N<:AbstractExpressionNode,M} + dest.degree = M + dest.op = src.op + dest.l = children[1] + if M == 2 + dest.r = children[2] + end + return dest +end + +""" + allocate_container(prototype::AbstractExpressionNode, n=nothing) + +Preallocate an array of `n` empty nodes matching the type of `prototype`. +If `n` is not provided, it will be computed from `length(prototype)`. + +A given return value of this will be passed to `copy_into!` as the first argument, +so it should be compatible. +""" +function allocate_container( + prototype::N, n::Union{Nothing,Integer}=nothing +) where {T,N<:AbstractExpressionNode{T}} + num_nodes = @something(n, length(prototype)) + return N[with_type_parameters(N, T)() for _ in 1:num_nodes] +end + +end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index b2be8440..3ed506cd 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -14,13 +14,12 @@ import ..NodeModule: with_type_parameters, preserve_sharing, leaf_copy, - leaf_copy!, leaf_convert, leaf_hash, leaf_equal, - branch_copy!, - copy_node!, - preallocate_expression + set_node!, + copy_into!, + allocate_container import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, @@ -447,16 +446,16 @@ end return node_type(; val=ex) end end -function preallocate_expression( +function allocate_container( prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing ) return (; - tree=preallocate_expression(get_contents(prototype), n), + tree=allocate_container(get_contents(prototype), n), parameters=similar(get_metadata(prototype).parameters), ) end -function copy_node!(dest::NamedTuple, src::ParametricExpression) - new_tree = copy_node!(dest.tree, get_contents(src)) +function copy_into!(dest::NamedTuple, src::ParametricExpression) + new_tree = copy_into!(dest.tree, get_contents(src)) metadata = get_metadata(src) new_parameters = dest.parameters new_parameters .= metadata.parameters diff --git a/src/base.jl b/src/base.jl index 28d54b93..3d29a404 100644 --- a/src/base.jl +++ b/src/base.jl @@ -485,81 +485,6 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi return constructorof(N)(T; op=t.op, children) end -# In-place versions - -""" - copy_node!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode} - -Copy a node, recursively copying all children nodes, in-place to an -array of pre-allocated nodes. This should result in no extra allocations. -""" -function copy_node!( - dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing -) where {N<:AbstractExpressionNode} - _ref = if ref === nothing - Ref(0) - else - ref.x = 0 - ref - end - return tree_mapreduce( - leaf -> leaf_copy!(@inbounds(dest[_ref.x += 1]), leaf), - identity, - ((p, c::Vararg{Any,M}) where {M}) -> - branch_copy!(@inbounds(dest[_ref.x += 1]), p, c...), - src, - N; - break_sharing=Val(BS), - ) -end -function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}} - dest.degree = 0 - if src.constant - dest.constant = true - dest.val = src.val - else - dest.constant = false - dest.feature = src.feature - end - return dest -end -function branch_copy!( - dest::N, src::N, children::Vararg{N,M} -) where {T,N<:AbstractExpressionNode{T},M} - dest.degree = M - dest.op = src.op - dest.l = children[1] - if M == 2 - dest.r = children[2] - end - return dest -end - -""" - preallocate_expression(prototype::AbstractExpressionNode, n=nothing) - -Preallocate an array of empty nodes matching the type of `prototype`. If `n` is provided, use that length, otherwise use `length(prototype)`. - -A given return value of this will be passed to `copy_node!` as the first argument, -so it should be compatible. -""" -function preallocate_expression( - prototype::N, n::Union{Nothing,Integer}=nothing -) where {T,N<:AbstractExpressionNode{T}} - num_nodes = @something(n, length(prototype)) - return N[with_type_parameters(N, T)() for _ in 1:num_nodes] -end - -function copy_node!(::Nothing, src::AbstractExpression) - return copy(src) -end -function preallocate_expression(::AbstractExpression, ::Union{Nothing,Integer}=nothing) - return nothing -end -# We don't require users to overload this, as it's not part of the required interface. -# Also, there's no way to generally do this from the required interface, so for backwards -# compatibility, we just return nothing. - """ copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false)) diff --git a/test/test_copy_inplace.jl b/test/test_copy_inplace.jl index 4b337aec..9e2b9514 100644 --- a/test/test_copy_inplace.jl +++ b/test/test_copy_inplace.jl @@ -1,6 +1,6 @@ -@testitem "copy_node! - random trees" begin +@testitem "copy_into! - random trees" begin using DynamicExpressions - using DynamicExpressions: copy_node! + using DynamicExpressions: copy_into! include("tree_gen_utils.jl") operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos]) @@ -15,7 +15,7 @@ orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes ref = Ref(0) - result = copy_node!(dest_array, tree; ref) + result = copy_into!(dest_array, tree; ref) @test ref[] == n_nodes # Increment once per node @@ -35,9 +35,9 @@ end end -@testitem "copy_node! - leaf nodes" begin +@testitem "copy_into! - leaf nodes" begin using DynamicExpressions - using DynamicExpressions: copy_node! + using DynamicExpressions: copy_into! leaf_constant = Node{Float64}(; val=1.0) leaf_feature = Node{Float64}(; feature=1) @@ -45,7 +45,7 @@ end for leaf in [leaf_constant, leaf_feature] dest_array = [Node{Float64}() for _ in 1:1] ref = Ref(0) - result = copy_node!(dest_array, leaf; ref=ref) + result = copy_into!(dest_array, leaf; ref=ref) @test ref[] == 1 @test result == leaf @test result === dest_array[1] From facdaaebc64e9db587d2754030210eadb5afee93 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:11:15 -0800 Subject: [PATCH 03/11] feat: avoid creating dummy nodes --- src/Node.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 40667f94..ddceaaa4 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -321,23 +321,12 @@ function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where { return GraphNode{promote_type(T1, T2)} end -# TODO: Verify using this helps with garbage collection -create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N() - """ set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T} Set every field of `tree` equal to the corresponding field of `new_tree`. """ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode) - # First, ensure we free some memory: - if new_tree.degree < 2 && tree.degree == 2 - tree.r = create_dummy_node(typeof(tree)) - end - if new_tree.degree < 1 && tree.degree >= 1 - tree.l = create_dummy_node(typeof(tree)) - end - tree.degree = new_tree.degree if new_tree.degree == 0 tree.constant = new_tree.constant From 34b150bd367d820d3144225e6bab0145d314dc5d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:11:55 -0800 Subject: [PATCH 04/11] fix: add missing `set_node!` for parametric expressions --- src/ParametricExpression.jl | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 3ed506cd..a2ef538a 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -126,21 +126,29 @@ function leaf_copy(t::ParametricNode{T}) where {T} return n end end -function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}} - dest.degree = 0 - if src.constant - dest.constant = true - dest.val = src.val - elseif !src.is_parameter - dest.constant = false - dest.is_parameter = false - dest.feature = src.feature +function set_node!(tree::ParametricNode, new_tree::ParametricNode) + tree.degree = new_tree.degree + if new_tree.degree == 0 + if new_tree.constant + tree.constant = true + tree.val = new_tree.val + elseif !new_tree.is_parameter + tree.constant = false + tree.is_parameter = false + tree.feature = new_tree.feature + else + tree.constant = false + tree.is_parameter = true + tree.parameter = new_tree.parameter + end else - dest.constant = false - dest.is_parameter = true - dest.parameter = src.parameter + tree.op = new_tree.op + tree.l = new_tree.l + if new_tree.degree == 2 + tree.r = new_tree.r + end end - return dest + return nothing end function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}} if t.constant From 95f2bdbce4080ba263832b6b326cd70cd14a6ab3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:12:42 -0800 Subject: [PATCH 05/11] feat: add `copy_into!` for GraphNode --- src/Interfaces.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 52031031..cb819ac2 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -419,7 +419,7 @@ ni_description = ( [Arguments()] ) @implements( - NodeInterface{all_ni_methods_except((:leaf_copy_into!, :branch_copy_into!))}, + NodeInterface{all_ni_methods_except(())}, GraphNode, [Arguments()] ) From b5b40a751ac416732b65eda9a8b8d2eca26fcb5e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:25:29 -0800 Subject: [PATCH 06/11] fix: various issues in preallocation interface --- src/Interfaces.jl | 16 ++++++++++++++-- src/NodePreallocation.jl | 5 ++--- src/ParametricExpression.jl | 10 ++++------ test/test_parametric_expression.jl | 2 +- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index cb819ac2..271f5ad4 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -13,12 +13,10 @@ using ..NodeModule: default_allocator, with_type_parameters, leaf_copy, - leaf_copy_into!, leaf_convert, leaf_hash, leaf_equal, branch_copy, - branch_copy_into!, branch_convert, branch_hash, branch_equal, @@ -38,6 +36,8 @@ using ..NodeUtilsModule: has_constants, get_scalar_constants, set_scalar_constants! +using ..NodePreallocationModule: + copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container using ..StringsModule: string_tree using ..EvaluateModule: eval_tree_array using ..EvaluateDerivativeModule: eval_grad_tree_array @@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression) end ## optional +function _check_copy_into!(ex::AbstractExpression) + container = allocate_container(ex) + prealloc_ex = copy_into!(container, ex) + return container !== nothing && prealloc_ex == ex && prealloc_ex !== container +end function _check_count_nodes(ex::AbstractExpression) return count_nodes(ex) isa Int64 end @@ -156,6 +161,7 @@ ei_components = ( with_metadata = "returns the expression with different metadata" => _check_with_metadata, ), optional = ( + copy_into! = "copies an expression into a preallocated container" => _check_copy_into!, count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes, count_constant_nodes = "counts the number of constant nodes in the expression tree" => _check_count_constant_nodes, count_depth = "calculates the depth of the expression tree" => _check_count_depth, @@ -260,6 +266,11 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode) end ## optional +function _check_copy_into!(tree::AbstractExpressionNode) + container = allocate_container(tree) + prealloc_tree = copy_into!(container, tree) + return container !== nothing && prealloc_tree == tree && prealloc_tree !== container +end function _check_leaf_copy(tree::AbstractExpressionNode) tree.degree != 0 && return true return leaf_copy(tree) isa typeof(tree) @@ -372,6 +383,7 @@ ni_components = ( tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce ), optional = ( + copy_into! = "copies a node into a preallocated container" => _check_copy_into!, leaf_copy = "copies a leaf node" => _check_leaf_copy, leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!, leaf_convert = "converts a leaf node" => _check_leaf_convert, diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl index 6fb1d012..19bc6790 100644 --- a/src/NodePreallocation.jl +++ b/src/NodePreallocation.jl @@ -9,7 +9,7 @@ using ..NodeModule: set_node! """ - copy_into!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode} + copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode} Copy a node, recursively copying all children nodes, in-place to an array of pre-allocated nodes. This should result in no extra allocations. @@ -29,8 +29,7 @@ function copy_into!( ((p, c::Vararg{Any,M}) where {M}) -> branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...), src, - N; - break_sharing=Val(BS), + N, ) end function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode} diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index a2ef538a..16d27254 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -4,9 +4,8 @@ using DispatchDoctor: @stable, @unstable using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum -using ..NodeModule: - AbstractExpressionNode, Node, tree_mapreduce, with_contents, with_metadata -using ..ExpressionModule: AbstractExpression, Metadata +using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce +using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata using ..ChainRulesModule: NodeTangent import ..NodeModule: @@ -17,9 +16,8 @@ import ..NodeModule: leaf_convert, leaf_hash, leaf_equal, - set_node!, - copy_into!, - allocate_container + set_node! +import ..NodePreallocationModule: copy_into!, allocate_container import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, diff --git a/test/test_parametric_expression.jl b/test/test_parametric_expression.jl index a3aceed3..e222765f 100644 --- a/test/test_parametric_expression.jl +++ b/test/test_parametric_expression.jl @@ -26,7 +26,7 @@ end using Interfaces: test ex = @parse_expression( - x + y + p1 * p2, + x + y + p1 * p2 + 1.5, binary_operators = [+, -, *, /], variable_names = ["x", "y"], node_type = ParametricNode, From eecc9da652fe965a4169b4696894cd4829ed022f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:44:00 -0800 Subject: [PATCH 07/11] feat: add preallocation for abstract structured expression --- src/Expression.jl | 1 + src/StructuredExpression.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/Expression.jl b/src/Expression.jl index 32cd2bce..1a47fef9 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -19,6 +19,7 @@ import ..NodeUtilsModule: count_scalar_constants, get_scalar_constants, set_scalar_constants! +import ..NodePreallocationModule: copy_into!, allocate_container import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array import ..EvaluateDerivativeModule: eval_grad_tree_array import ..EvaluationHelpersModule: _grad_evaluator diff --git a/src/StructuredExpression.jl b/src/StructuredExpression.jl index 28ed6cd6..963da0e4 100644 --- a/src/StructuredExpression.jl +++ b/src/StructuredExpression.jl @@ -6,12 +6,14 @@ using ..ExpressionModule: AbstractExpression, Metadata, node_type using ..ChainRulesModule: NodeTangent import ..NodeModule: constructorof +import ..NodePreallocationModule: copy_into!, allocate_container import ..ExpressionModule: get_contents, get_metadata, get_tree, get_operators, get_variable_names, + with_contents, Metadata, _copy, _data, @@ -164,4 +166,16 @@ function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs) return e end +function allocate_container( + e::AbstractStructuredExpression, n::Union{Nothing,Integer}=nothing +) + ts = get_contents(e) + return (; trees=NamedTuple{keys(ts)}(map(t -> allocate_container(t, n), values(ts)))) +end +function copy_into!(dest::NamedTuple, src::AbstractStructuredExpression) + ts = get_contents(src) + new_contents = NamedTuple{keys(ts)}(map(copy_into!, values(dest.trees), values(ts))) + return with_contents(src, new_contents) +end + end From d11adbf4672a2e42cde796468f99973c25dcb625 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 11:46:04 -0800 Subject: [PATCH 08/11] test: fix interface check --- src/Interfaces.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 271f5ad4..c2d44b59 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -99,7 +99,7 @@ end function _check_copy_into!(ex::AbstractExpression) container = allocate_container(ex) prealloc_ex = copy_into!(container, ex) - return container !== nothing && prealloc_ex == ex && prealloc_ex !== container + return container !== nothing && prealloc_ex == ex && prealloc_ex !== ex end function _check_count_nodes(ex::AbstractExpression) return count_nodes(ex) isa Int64 From 09ec6ec51fc3aec3ed8ddd5d4595a858ad54f535 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 12:02:22 -0800 Subject: [PATCH 09/11] test: `copy_into!` on expressions --- test/test_copy_inplace.jl | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/test_copy_inplace.jl b/test/test_copy_inplace.jl index 9e2b9514..839d3d55 100644 --- a/test/test_copy_inplace.jl +++ b/test/test_copy_inplace.jl @@ -51,3 +51,45 @@ end @test result === dest_array[1] end end + +@testitem "copy_into! with expressions" begin + using DynamicExpressions + using DynamicExpressions: + copy_into!, allocate_container, get_operators, get_variable_names + + operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin]) + variable_names = ["x", "y"] + + # Test regular Expression + ex = @parse_expression( + sin(x + 2.0 * y), operators = operators, variable_names = variable_names + ) + container = allocate_container(ex) + result = copy_into!(container, ex) + + @test result == ex + @test result !== ex + @test get_tree(result) !== get_tree(ex) + @test get_operators(result, nothing) === get_operators(ex, nothing) + @test get_variable_names(result, nothing) === get_variable_names(ex, nothing) + + # Test ParametricExpression + parameters = [1.0 2.0; 3.0 4.0] + pex = @parse_expression( + sin(x + p1 * y + p2), + operators = operators, + variable_names = variable_names, + expression_type = ParametricExpression, + extra_metadata = (; parameters=parameters, parameter_names=["p1", "p2"]) + ) + container = allocate_container(pex) + result = copy_into!(container, pex) + + @test result == pex + @test result !== pex + @test get_tree(result) !== get_tree(pex) + @test get_operators(result, nothing) === get_operators(pex, nothing) + @test get_variable_names(result, nothing) === get_variable_names(pex, nothing) + @test result.metadata.parameters !== pex.metadata.parameters + @test result.metadata.parameters == pex.metadata.parameters +end From 64035a5cd470d32aa2fa22d594f144fa3c38f343 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 12:09:54 -0800 Subject: [PATCH 10/11] docs: improve preallocation docstrings --- src/NodePreallocation.jl | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl index 19bc6790..73fd60e8 100644 --- a/src/NodePreallocation.jl +++ b/src/NodePreallocation.jl @@ -8,11 +8,27 @@ using ..NodeModule: branch_copy, set_node! +""" + allocate_container(prototype::AbstractExpressionNode, n=nothing) + +Preallocate an array of `n` empty nodes matching the type of `prototype`. +If `n` is not provided, it will be computed from `length(prototype)`. + +A given return value of this will be passed to `copy_into!` as the first argument, +so it should be compatible. +""" +function allocate_container( + prototype::N, n::Union{Nothing,Integer}=nothing +) where {T,N<:AbstractExpressionNode{T}} + num_nodes = @something(n, length(prototype)) + return N[with_type_parameters(N, T)() for _ in 1:num_nodes] +end + """ copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode} -Copy a node, recursively copying all children nodes, in-place to an -array of pre-allocated nodes. This should result in no extra allocations. +Copy a node, recursively copying all children nodes, in-place to a preallocated container. +This should result in no extra allocations. """ function copy_into!( dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing @@ -48,20 +64,5 @@ function branch_copy_into!( return dest end -""" - allocate_container(prototype::AbstractExpressionNode, n=nothing) - -Preallocate an array of `n` empty nodes matching the type of `prototype`. -If `n` is not provided, it will be computed from `length(prototype)`. - -A given return value of this will be passed to `copy_into!` as the first argument, -so it should be compatible. -""" -function allocate_container( - prototype::N, n::Union{Nothing,Integer}=nothing -) where {T,N<:AbstractExpressionNode{T}} - num_nodes = @something(n, length(prototype)) - return N[with_type_parameters(N, T)() for _ in 1:num_nodes] -end end From 16c5ef08b12d27c47fdb964e8783b8539aae6fb5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 13 Dec 2024 12:17:06 -0800 Subject: [PATCH 11/11] ci: exclude coverage of isbits functions --- src/Expression.jl | 2 ++ src/NodePreallocation.jl | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Expression.jl b/src/Expression.jl index 1a47fef9..68fd8818 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -506,12 +506,14 @@ end # We don't require users to overload this, as it's not part of the required interface. # Also, there's no way to generally do this from the required interface, so for backwards # compatibility, we just return nothing. +# COV_EXCL_START function copy_into!(::Nothing, src::AbstractExpression) return copy(src) end function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing) return nothing end +# COV_EXCL_STOP function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing) return (; tree=allocate_container(get_contents(prototype), n)) end diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl index 73fd60e8..ccce372d 100644 --- a/src/NodePreallocation.jl +++ b/src/NodePreallocation.jl @@ -48,10 +48,12 @@ function copy_into!( N, ) end +# COV_EXCL_START function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode} set_node!(dest, src) return dest end +# COV_EXCL_STOP function branch_copy_into!( dest::N, src::N, children::Vararg{N,M} ) where {N<:AbstractExpressionNode,M} @@ -64,5 +66,4 @@ function branch_copy_into!( return dest end - end