Skip to content

Commit 875b923

Browse files
committed
Add derivative rules in ChainRulesCore
1 parent 79e91be commit 875b923

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ authors = ["Chris Rackauckas and Julia Computing"]
44
version = "2.1.0"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
89
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1112
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1213

1314
[compat]
15+
ChainRulesCore = "1"
1416
DiffEqBase = "6"
1517
IfElse = "0.1"
1618
ModelingToolkit = "8.50"

src/Blocks/sources.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DiffEqBase
2+
import ChainRulesCore
23

34
# Define and register smooth functions
45
# These are "smooth" aka differentiable and avoid Gibbs effect
@@ -575,6 +576,9 @@ function Symbolics.derivative(::typeof(get_sampled_data), args::NTuple{2, Any},
575576
memory = @inbounds args[2]
576577
first_order_backwards_difference(t, memory)
577578
end
579+
function ChainRulesCore.frule((_, ẋ, _), ::typeof(get_sampled_data), t, memory)
580+
first_order_backwards_difference(t, memory) *
581+
end
578582

579583
"""
580584
SampledData(; name, buffer)
@@ -649,6 +653,14 @@ function Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any},
649653
first_order_backwards_difference(t, x, Δt, memory)
650654
end
651655
Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any}, ::Val{3}) = 1 #set_sampled_data returns x, therefore d/dx (x) = 1
656+
function ChainRulesCore.frule((_, _, ṫ, ẋ, _),
657+
::typeof(set_sampled_data!),
658+
memory,
659+
t,
660+
x,
661+
Δt)
662+
first_order_backwards_difference(t, x, Δt, memory) *+
663+
end
652664

653665
function first_order_backwards_difference(t, x, Δt, memory)
654666
x1 = set_sampled_data!(memory, t, x, Δt)

src/Hydraulic/IsothermalCompressible/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ChainRulesCore
2+
13
regPow(x, a, delta = 0.01) = x * (x * x + delta * delta)^((a - 1) / 2);
24
regRoot(x, delta = 0.01) = regPow(x, 0.5, delta)
35

@@ -116,6 +118,7 @@ end
116118
@register_symbolic friction_factor(dm, area, d_h, viscosity, shape_factor)
117119
Symbolics.derivative(::typeof(friction_factor), args, ::Val{1}) = 0
118120
Symbolics.derivative(::typeof(friction_factor), args, ::Val{4}) = 0
121+
ChainRulesCore.frule(_, ::typeof(friction_factor), _...) = 0
119122

120123
function transition(x1, x2, y1, y2, x)
121124
u = (x - x1) / (x2 - x1)

0 commit comments

Comments
 (0)