Skip to content

Commit f382bbb

Browse files
test: update AD tests, mark some as broken
1 parent 714d823 commit f382bbb

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

test/extensions/ad.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,18 @@ ps = [p => zeros(3, 3),
2222
q => 1.0]
2323
tspan = (0.0, 10.0)
2424
@mtkcompile sys = System(eqs, t)
25-
prob = ODEProblem(sys, u0, tspan, ps)
25+
prob = ODEProblem(sys, [u0; ps], tspan)
2626
sol = solve(prob, Tsit5())
2727

2828
mtkparams = parameter_values(prob)
2929
new_p = rand(14)
30-
gs = gradient(new_p) do new_p
31-
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
32-
new_prob = remake(prob, p = new_params)
33-
new_sol = solve(new_prob, Tsit5())
34-
sum(new_sol)
30+
@test_broken begin
31+
gs = gradient(new_p) do new_p
32+
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
33+
new_prob = remake(prob, p = new_params)
34+
new_sol = solve(new_prob, Tsit5())
35+
sum(new_sol)
36+
end
3537
end
3638

3739
@testset "Issue#2997" begin
@@ -50,7 +52,7 @@ end
5052
sys = mtkcompile(sys)
5153

5254
function x_at_0(θ)
53-
prob = ODEProblem(sys, [sys.x => 1.0], (0.0, 1.0), [sys.ργ0 => θ[1], sys.h => θ[2]])
55+
prob = ODEProblem(sys, [sys.x => 1.0, sys.ργ0 => θ[1], sys.h => θ[2]], (0.0, 1.0))
5456
return prob.u0[1]
5557
end
5658

@@ -61,7 +63,7 @@ end
6163
@named sys = System(
6264
Equation[], t, [], [a, b, c, d, e, f, g, h],
6365
continuous_events = [ModelingToolkit.SymbolicContinuousCallback(
64-
[a ~ 0] => [c ~ 0], discrete_parameters = c)])
66+
[a ~ 0] => [c ~ 0], discrete_parameters = c, iv = t)])
6567
sys = complete(sys)
6668

6769
ivs = Dict(c => 3a, b => ones(3), a => 1.0, d => 4, e => [5.0, 6.0, 7.0],
@@ -116,7 +118,7 @@ fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
116118
sys = mtkcompile(sys)
117119

118120
# Find initial throw velocity that reaches exactly 10 m after 1 s
119-
dprob0 = ODEProblem(sys, [D(y) => NaN], (0.0, 1.0), []; guesses = [y => 0.0])
121+
dprob0 = ODEProblem(sys, [D(y) => NaN], (0.0, 1.0); guesses = [y => 0.0])
120122
function f(ics, _)
121123
dprob = remake(dprob0, u0 = Dict(D(y) => ics[1]))
122124
dsol = solve(dprob, Tsit5())

0 commit comments

Comments
 (0)