Skip to content

Commit f16af83

Browse files
committed
refactor: use generic getters and setters
1 parent 0514c90 commit f16af83

File tree

5 files changed

+154
-130
lines changed

5 files changed

+154
-130
lines changed

ext/DynamicExpressionsLoopVectorizationExt.jl

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module DynamicExpressionsLoopVectorizationExt
22

33
using LoopVectorization: @turbo
44
using DynamicExpressions: AbstractExpressionNode
5+
using DynamicExpressions.NodeModule: get_child
56
using DynamicExpressions.UtilsModule: ResultOk
67
using DynamicExpressions.EvaluateModule:
78
@return_on_nonfinite_val, EvalOptions, get_array, get_feature_array, get_filled_array
@@ -39,30 +40,31 @@ function deg1_l2_ll0_lr0_eval(
3940
op_l::F2,
4041
eval_options::EvalOptions{true},
4142
) where {T<:Number,F,F2}
42-
if tree.l.l.constant && tree.l.r.constant
43-
val_ll = tree.l.l.val
44-
val_lr = tree.l.r.val
43+
if get_child(get_child(tree, 1), 1).constant &&
44+
get_child(get_child(tree, 1), 2).constant
45+
val_ll = get_child(get_child(tree, 1), 1).val
46+
val_lr = get_child(get_child(tree, 1), 2).val
4547
@return_on_nonfinite_val(eval_options, val_ll, cX)
4648
@return_on_nonfinite_val(eval_options, val_lr, cX)
4749
x_l = op_l(val_ll, val_lr)::T
4850
@return_on_nonfinite_val(eval_options, x_l, cX)
4951
x = op(x_l)::T
5052
@return_on_nonfinite_val(eval_options, x, cX)
5153
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
52-
elseif tree.l.l.constant
53-
val_ll = tree.l.l.val
54+
elseif get_child(get_child(tree, 1), 1).constant
55+
val_ll = get_child(get_child(tree, 1), 1).val
5456
@return_on_nonfinite_val(eval_options, val_ll, cX)
55-
feature_lr = tree.l.r.feature
57+
feature_lr = get_child(get_child(tree, 1), 2).feature
5658
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
5759
@turbo for j in axes(cX, 2)
5860
x_l = op_l(val_ll, cX[feature_lr, j])
5961
x = op(x_l)
6062
cumulator[j] = x
6163
end
6264
return ResultOk(cumulator, true)
63-
elseif tree.l.r.constant
64-
feature_ll = tree.l.l.feature
65-
val_lr = tree.l.r.val
65+
elseif get_child(get_child(tree, 1), 2).constant
66+
feature_ll = get_child(get_child(tree, 1), 1).feature
67+
val_lr = get_child(get_child(tree, 1), 2).val
6668
@return_on_nonfinite_val(eval_options, val_lr, cX)
6769
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
6870
@turbo for j in axes(cX, 2)
@@ -72,8 +74,8 @@ function deg1_l2_ll0_lr0_eval(
7274
end
7375
return ResultOk(cumulator, true)
7476
else
75-
feature_ll = tree.l.l.feature
76-
feature_lr = tree.l.r.feature
77+
feature_ll = get_child(get_child(tree, 1), 1).feature
78+
feature_lr = get_child(get_child(tree, 1), 2).feature
7779
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
7880
@turbo for j in axes(cX, 2)
7981
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])
@@ -91,16 +93,16 @@ function deg1_l1_ll0_eval(
9193
op_l::F2,
9294
eval_options::EvalOptions{true},
9395
) where {T<:Number,F,F2}
94-
if tree.l.l.constant
95-
val_ll = tree.l.l.val
96+
if get_child(get_child(tree, 1), 1).constant
97+
val_ll = get_child(get_child(tree, 1), 1).val
9698
@return_on_nonfinite_val(eval_options, val_ll, cX)
9799
x_l = op_l(val_ll)::T
98100
@return_on_nonfinite_val(eval_options, x_l, cX)
99101
x = op(x_l)::T
100102
@return_on_nonfinite_val(eval_options, x, cX)
101103
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
102104
else
103-
feature_ll = tree.l.l.feature
105+
feature_ll = get_child(get_child(tree, 1), 1).feature
104106
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
105107
@turbo for j in axes(cX, 2)
106108
x_l = op_l(cX[feature_ll, j])
@@ -117,28 +119,28 @@ function deg2_l0_r0_eval(
117119
op::F,
118120
eval_options::EvalOptions{true},
119121
) where {T<:Number,F}
120-
if tree.l.constant && tree.r.constant
121-
val_l = tree.l.val
122+
if get_child(tree, 1).constant && get_child(tree, 2).constant
123+
val_l = get_child(tree, 1).val
122124
@return_on_nonfinite_val(eval_options, val_l, cX)
123-
val_r = tree.r.val
125+
val_r = get_child(tree, 2).val
124126
@return_on_nonfinite_val(eval_options, val_r, cX)
125127
x = op(val_l, val_r)::T
126128
@return_on_nonfinite_val(eval_options, x, cX)
127129
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
128-
elseif tree.l.constant
130+
elseif get_child(tree, 1).constant
129131
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
130-
val_l = tree.l.val
132+
val_l = get_child(tree, 1).val
131133
@return_on_nonfinite_val(eval_options, val_l, cX)
132-
feature_r = tree.r.feature
134+
feature_r = get_child(tree, 2).feature
133135
@turbo for j in axes(cX, 2)
134136
x = op(val_l, cX[feature_r, j])
135137
cumulator[j] = x
136138
end
137139
return ResultOk(cumulator, true)
138-
elseif tree.r.constant
140+
elseif get_child(tree, 2).constant
139141
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
140-
feature_l = tree.l.feature
141-
val_r = tree.r.val
142+
feature_l = get_child(tree, 1).feature
143+
val_r = get_child(tree, 2).val
142144
@return_on_nonfinite_val(eval_options, val_r, cX)
143145
@turbo for j in axes(cX, 2)
144146
x = op(cX[feature_l, j], val_r)
@@ -147,8 +149,8 @@ function deg2_l0_r0_eval(
147149
return ResultOk(cumulator, true)
148150
else
149151
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
150-
feature_l = tree.l.feature
151-
feature_r = tree.r.feature
152+
feature_l = get_child(tree, 1).feature
153+
feature_r = get_child(tree, 2).feature
152154
@turbo for j in axes(cX, 2)
153155
x = op(cX[feature_l, j], cX[feature_r, j])
154156
cumulator[j] = x
@@ -165,16 +167,16 @@ function deg2_l0_eval(
165167
op::F,
166168
eval_options::EvalOptions{true},
167169
) where {T<:Number,F}
168-
if tree.l.constant
169-
val = tree.l.val
170+
if get_child(tree, 1).constant
171+
val = get_child(tree, 1).val
170172
@return_on_nonfinite_val(eval_options, val, cX)
171173
@turbo for j in eachindex(cumulator)
172174
x = op(val, cumulator[j])
173175
cumulator[j] = x
174176
end
175177
return ResultOk(cumulator, true)
176178
else
177-
feature = tree.l.feature
179+
feature = get_child(tree, 1).feature
178180
@turbo for j in eachindex(cumulator)
179181
x = op(cX[feature, j], cumulator[j])
180182
cumulator[j] = x
@@ -190,16 +192,16 @@ function deg2_r0_eval(
190192
op::F,
191193
eval_options::EvalOptions{true},
192194
) where {T<:Number,F}
193-
if tree.r.constant
194-
val = tree.r.val
195+
if get_child(tree, 2).constant
196+
val = get_child(tree, 2).val
195197
@return_on_nonfinite_val(eval_options, val, cX)
196198
@turbo for j in eachindex(cumulator)
197199
x = op(cumulator[j], val)
198200
cumulator[j] = x
199201
end
200202
return ResultOk(cumulator, true)
201203
else
202-
feature = tree.r.feature
204+
feature = get_child(tree, 2).feature
203205
@turbo for j in eachindex(cumulator)
204206
x = op(cumulator[j], cX[feature, j])
205207
cumulator[j] = x

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DynamicExpressionsSymbolicUtilsExt
33
using DynamicExpressions:
44
AbstractExpression, get_tree, get_operators, get_variable_names, default_node_type
55
using DynamicExpressions.NodeModule:
6-
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
6+
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE, get_child
77
using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
88
using DynamicExpressions.UtilsModule: deprecate_varmap
99

@@ -39,7 +39,7 @@ end
3939
subs_bad(x) = is_valid(x) ? x : Inf
4040

4141
function parse_tree_to_eqs(
42-
tree::AbstractExpressionNode{T},
42+
tree::AbstractExpressionNode{T,2},
4343
operators::AbstractOperatorEnum,
4444
index_functions::Bool=false,
4545
) where {T}
@@ -50,7 +50,8 @@ function parse_tree_to_eqs(
5050
end
5151
# Collect the next children
5252
# TODO: Type instability!
53-
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
53+
children =
54+
tree.degree == 2 ? (get_child(tree, 1), get_child(tree, 2)) : (get_child(tree, 1),)
5455
# Get the operation
5556
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
5657
# Create an N tuple of Numbers for each argument
@@ -219,13 +220,13 @@ will generate a symbolic equation in SymbolicUtils.jl format.
219220
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
220221
"""
221222
function node_to_symbolic(
222-
tree::AbstractExpressionNode,
223+
tree::AbstractExpressionNode{T,2},
223224
operators::AbstractOperatorEnum;
224225
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
225226
index_functions::Bool=false,
226227
# Deprecated:
227228
varMap=nothing,
228-
)
229+
) where {T}
229230
variable_names = deprecate_varmap(variable_names, varMap, :node_to_symbolic)
230231
expr = subs_bad(parse_tree_to_eqs(tree, operators, index_functions))
231232
# Check for NaN and Inf

src/DynamicExpressions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ import .ValueInterfaceModule:
4141
AbstractExpressionNode,
4242
GraphNode,
4343
Node,
44+
get_child,
45+
set_child!,
46+
get_children,
47+
set_children!,
4448
copy_node,
4549
set_node!,
4650
tree_mapreduce,

0 commit comments

Comments
 (0)