Skip to content

Commit 97fc15b

Browse files
authored
Merge pull request #116 from SymbolicML/expression-qol-improvements
Some quality of life API updates
2 parents dde9291 + e63317c commit 97fc15b

File tree

5 files changed

+88
-31
lines changed

5 files changed

+88
-31
lines changed

src/Expression.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,24 @@ import ..SimplifyModule: combine_operators, simplify_tree!
3131
struct Metadata{NT<:NamedTuple}
3232
_data::NT
3333
end
34-
_data(x::Metadata) = getfield(x, :_data)
34+
unpack_metadata(x::Metadata) = getfield(x, :_data)
3535

36-
Base.propertynames(x::Metadata) = propertynames(_data(x))
37-
@unstable @inline Base.getproperty(x::Metadata, f::Symbol) = getproperty(_data(x), f)
38-
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", _data(x), ")")
36+
Base.propertynames(x::Metadata) = propertynames(unpack_metadata(x))
37+
@unstable @inline function Base.getproperty(x::Metadata, f::Symbol)
38+
return getproperty(unpack_metadata(x), f)
39+
end
40+
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", unpack_metadata(x), ")")
3941
@inline _copy(x) = copy(x)
4042
@inline _copy(x::NamedTuple) = copy_named_tuple(x)
41-
@inline _copy(x::Nothing) = nothing
43+
@inline _copy(::Nothing) = nothing
4244
@inline function copy_named_tuple(nt::NamedTuple)
4345
return NamedTuple{keys(nt)}(map(_copy, values(nt)))
4446
end
4547
@inline function Base.copy(metadata::Metadata)
46-
return Metadata(_copy(_data(metadata)))
48+
return Metadata(_copy(unpack_metadata(metadata)))
4749
end
48-
@inline Base.:(==)(x::Metadata, y::Metadata) = _data(x) == _data(y)
49-
@inline Base.hash(x::Metadata, h::UInt) = hash(_data(x), h)
50+
@inline Base.:(==)(x::Metadata, y::Metadata) = unpack_metadata(x) == unpack_metadata(y)
51+
@inline Base.hash(x::Metadata, h::UInt) = hash(unpack_metadata(x), h)
5052

5153
"""
5254
AbstractExpression{T,N}
@@ -216,7 +218,9 @@ end
216218
Create a new expression based on `ex` but with a different `metadata`.
217219
"""
218220
function with_metadata(ex::AbstractExpression; metadata...)
219-
return with_metadata(ex, Metadata((; metadata...)))
221+
return with_metadata(
222+
ex, Metadata((; unpack_metadata(get_metadata(ex))..., metadata...))
223+
)
220224
end
221225
function with_metadata(ex::AbstractExpression, metadata::Metadata)
222226
return constructorof(typeof(ex))(get_contents(ex), metadata)
@@ -246,7 +250,13 @@ end
246250
function get_variable_names(
247251
ex::Expression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing
248252
)
249-
return variable_names === nothing ? ex.metadata.variable_names : variable_names
253+
return if variable_names !== nothing
254+
variable_names
255+
elseif hasproperty(ex.metadata, :variable_names)
256+
ex.metadata.variable_names
257+
else
258+
nothing
259+
end
250260
end
251261
function get_tree(ex::Expression)
252262
return ex.tree

src/ParametricExpression.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
55

66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
8-
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
8+
using ..ExpressionModule:
9+
AbstractExpression, Metadata, with_contents, with_metadata, unpack_metadata
910
using ..ChainRulesModule: NodeTangent
1011

1112
import ..NodeModule:
@@ -63,7 +64,6 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T}
6364
return n
6465
end
6566
end
66-
@inline _data(x::Metadata) = getfield(x, :_data)
6767

6868
"""
6969
ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
@@ -79,15 +79,17 @@ struct ParametricExpression{
7979
metadata::Metadata{D}
8080

8181
function ParametricExpression(tree::ParametricNode, metadata::Metadata)
82-
return new{eltype(tree),typeof(tree),typeof(_data(metadata))}(tree, metadata)
82+
return new{eltype(tree),typeof(tree),typeof(unpack_metadata(metadata))}(
83+
tree, metadata
84+
)
8385
end
8486
end
8587
function ParametricExpression(
8688
tree::ParametricNode{T1};
8789
operators::Union{AbstractOperatorEnum,Nothing},
88-
variable_names,
90+
variable_names=nothing,
8991
parameters::AbstractMatrix{T2},
90-
parameter_names,
92+
parameter_names=nothing,
9193
) where {T1,T2}
9294
if !isnothing(parameter_names)
9395
@assert size(parameters, 1) == length(parameter_names)
@@ -200,18 +202,16 @@ function get_variable_names(
200202
ex::ParametricExpression,
201203
variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing,
202204
)
203-
return variable_names === nothing ? ex.metadata.variable_names : variable_names
205+
return if variable_names !== nothing
206+
variable_names
207+
elseif hasproperty(ex.metadata, :variable_names)
208+
ex.metadata.variable_names
209+
else
210+
nothing
211+
end
204212
end
205-
@inline _copy_with_nothing(x) = copy(x)
206-
@inline _copy_with_nothing(::Nothing) = nothing
207213
function Base.copy(ex::ParametricExpression; break_sharing::Val=Val(false))
208-
return ParametricExpression(
209-
copy(ex.tree; break_sharing=break_sharing);
210-
operators=_copy_with_nothing(ex.metadata.operators),
211-
variable_names=_copy_with_nothing(ex.metadata.variable_names),
212-
parameters=_copy_with_nothing(ex.metadata.parameters),
213-
parameter_names=_copy_with_nothing(ex.metadata.parameter_names),
214-
)
214+
return ParametricExpression(copy(ex.tree; break_sharing), copy(ex.metadata))
215215
end
216216
###############################################################################
217217

src/Strings.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function string_tree(
139139
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
140140
f_variable::F1=string_variable,
141141
f_constant::F2=string_constant,
142-
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
142+
variable_names=nothing,
143143
pretty::Union{Bool,Nothing}=nothing, # Not used, but can be used by other types
144144
# Deprecated
145145
raw::Union{Bool,Nothing}=nothing,
@@ -190,7 +190,7 @@ for io in ((), (:(io::IO),))
190190
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
191191
f_variable::F1=string_variable,
192192
f_constant::F2=string_constant,
193-
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
193+
variable_names=nothing,
194194
pretty::Union{Bool,Nothing}=nothing, # Not used, but can be used by other types
195195
# Deprecated
196196
raw::Union{Bool,Nothing}=nothing,

src/StructuredExpression.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import ..ExpressionModule:
1616
with_contents,
1717
Metadata,
1818
_copy,
19-
_data,
19+
unpack_metadata,
2020
default_node_type,
2121
node_type,
2222
get_scalar_constants,
@@ -114,7 +114,7 @@ constructorof(::Type{<:StructuredExpression}) = StructuredExpression
114114
function Base.copy(e::AbstractStructuredExpression)
115115
ts = get_contents(e)
116116
meta = get_metadata(e)
117-
meta_inner = _data(meta)
117+
meta_inner = unpack_metadata(meta)
118118
copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts)))
119119
keys_except_structure = filter(!=(:structure), keys(meta_inner))
120120
copy_metadata = (;
@@ -143,7 +143,13 @@ function get_variable_names(
143143
e::AbstractStructuredExpression,
144144
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
145145
)
146-
return variable_names === nothing ? get_metadata(e).variable_names : variable_names
146+
return if variable_names !== nothing
147+
variable_names
148+
elseif hasproperty(get_metadata(e), :variable_names)
149+
get_metadata(e).variable_names
150+
else
151+
nothing
152+
end
147153
end
148154
function get_scalar_constants(e::AbstractStructuredExpression)
149155
# Get constants for each inner expression

test/test_expressions.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ end
269269

270270
@testitem "Miscellaneous expression calls" begin
271271
using DynamicExpressions
272-
using DynamicExpressions: get_tree, get_operators
272+
using DynamicExpressions: get_tree, get_operators, default_node_type
273273

274274
ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"])
275275
@test DynamicExpressions.ExpressionModule.node_type(ex) <: Node
@@ -278,6 +278,24 @@ end
278278

279279
tree = get_tree(ex)
280280
@test_throws ArgumentError get_operators(tree, nothing)
281+
282+
# We can also define expressions without variable names, and it should work
283+
operators = OperatorEnum(; binary_operators=[+])
284+
for E in (Expression, ParametricExpression)
285+
N = default_node_type(E)
286+
kws = (; operators)
287+
if E === ParametricExpression
288+
kws = (; kws..., parameters=Matrix{Float64}(undef, 0, 0))
289+
end
290+
x1, x2 = (E(N(Float64; feature=i); kws...) for i in 1:2)
291+
x1000 = E(N(Float64; feature=1000); kws...)
292+
@test string(x1 + x2 + x1000) == "(x1 + x2) + x1000"
293+
# And also with structured expressions
294+
x1 = StructuredExpression(
295+
(; x1, x2, x1000); operators, structure=nt -> nt.x1 + nt.x2 + nt.x1000
296+
)
297+
@test string(x1) == "(x1 + x2) + x1000"
298+
end
281299
end
282300

283301
@testitem "Expression Literate examples" begin
@@ -413,3 +431,26 @@ end
413431

414432
#literate_end
415433
end
434+
435+
@testitem "Expression with_metadata partial updates" begin
436+
using DynamicExpressions
437+
using DynamicExpressions: get_operators, get_metadata, with_metadata, get_variable_names
438+
439+
# Create an expression with initial metadata
440+
ex = @parse_expression(
441+
x1 + 1.5,
442+
operators = OperatorEnum(; binary_operators=[+, *]),
443+
variable_names = ["x1"]
444+
)
445+
446+
# Update only the variable_names, keeping the original operators
447+
new_ex = with_metadata(ex; variable_names=["y1"])
448+
@test get_variable_names(new_ex, nothing) == ["y1"]
449+
@test get_operators(new_ex, nothing) == get_operators(ex, nothing)
450+
451+
# Update only the operators, keeping the original variable_names
452+
new_operators = OperatorEnum(; binary_operators=[+])
453+
new_ex2 = with_metadata(ex; operators=new_operators)
454+
@test get_variable_names(new_ex2, nothing) == ["x1"]
455+
@test get_operators(new_ex2, nothing) == new_operators
456+
end

0 commit comments

Comments
 (0)