
Commit 27f8a96

gdalle and lostella authored
Switch from AbstractDifferentiation to DifferentiationInterface (#93)
Following our discussion per email, this PR proposes a switch from AbstractDifferentiation.jl to DifferentiationInterface.jl, which is becoming the new standard in the ecosystem.

- [x] Modify `Project.toml` files and imports
- [x] Replace `SomethingBackend()` with `AutoSomething()`
- [x] Replace `value_and_gradient_closure` with `value_and_gradient` (unclear how performance is affected)
- [x] Update documentation and README
- [ ] Add [preparation mechanism](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/operators/#Preparation): available on another branch, but we may not want it, since preparation is not appropriate when the function contains value-dependent control flow

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>
1 parent b3e667e commit 27f8a96

31 files changed: +126 −138 lines changed
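At the call site, the migration boils down to swapping backend objects and dropping the gradient closure. Below is a minimal before/after sketch based on the changes in this commit, using a toy quadratic and Zygote as an assumed backend; it is illustrative only and not part of the diff.

```julia
using Zygote
using ProximalAlgorithms

# Before (AbstractDifferentiation.jl): backends were objects like `ZygoteBackend()`,
# and `value_and_gradient_closure` returned the value plus a closure:
#
#   using AbstractDifferentiation: ZygoteBackend
#   f = ProximalAlgorithms.AutoDifferentiable(x -> sum(abs2, x) / 2, ZygoteBackend())
#   fx, cl = ProximalAlgorithms.value_and_gradient_closure(f, [1.0, 2.0])
#   grad = cl()

# After (DifferentiationInterface.jl): backends are ADTypes objects like `AutoZygote()`,
# and `value_and_gradient` returns the value and the gradient directly.
using DifferentiationInterface: AutoZygote

f = ProximalAlgorithms.AutoDifferentiable(x -> sum(abs2, x) / 2, AutoZygote())
fx, grad = ProximalAlgorithms.value_and_gradient(f, [1.0, 2.0])
```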

Project.toml

Lines changed: 5 additions & 3 deletions
@@ -1,15 +1,17 @@
 name = "ProximalAlgorithms"
 uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
-version = "0.6.0"
+version = "0.7.0"
 
 [deps]
-AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
 
 [compat]
-AbstractDifferentiation = "0.6"
+ADTypes = "1.5.3"
+DifferentiationInterface = "0.5.8"
 LinearAlgebra = "1.2"
 Printf = "1.2"
 ProximalCore = "0.1"

README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ Implemented algorithms include:
 Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms.
 
 Algorithms rely on:
-- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients)
+- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients)
 - the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).
 
 ## Documentation

benchmark/benchmarks.jl

Lines changed: 4 additions & 4 deletions
@@ -8,12 +8,12 @@ using FileIO
 
 const SUITE = BenchmarkGroup()
 
-function ProximalAlgorithms.value_and_gradient_closure(
+function ProximalAlgorithms.value_and_gradient(
     f::ProximalOperators.LeastSquaresDirect,
     x,
 )
     res = f.A * x - f.b
-    norm(res)^2 / 2, () -> f.A' * res
+    norm(res)^2 / 2, f.A' * res
 end
 
 struct SquaredDistance{Tb}
@@ -22,9 +22,9 @@ end
 
 (f::SquaredDistance)(x) = norm(x - f.b)^2 / 2
 
-function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x)
+function ProximalAlgorithms.value_and_gradient(f::SquaredDistance, x)
     diff = x - f.b
-    norm(diff)^2 / 2, () -> diff
+    norm(diff)^2 / 2, diff
 end
 
 for (benchmark_name, file_name) in [
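The two methods above illustrate how to bring your own gradient under the new API: a custom `value_and_gradient` method returns the gradient itself rather than a closure. A self-contained sketch of the same pattern with a hypothetical `HalfSquaredNorm` type (not part of the package or of this benchmark suite):

```julia
using LinearAlgebra
using ProximalAlgorithms

# Hypothetical type with a known closed-form gradient, bypassing AD entirely.
struct HalfSquaredNorm end

(f::HalfSquaredNorm)(x) = norm(x)^2 / 2

# The gradient of norm(x)^2 / 2 is x itself, so it is returned directly (no closure).
function ProximalAlgorithms.value_and_gradient(::HalfSquaredNorm, x)
    return norm(x)^2 / 2, x
end

fx, grad = ProximalAlgorithms.value_and_gradient(HalfSquaredNorm(), [1.0, -2.0])
# fx == 2.5, grad == [1.0, -2.0]
```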

docs/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [deps]
-AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"

docs/src/examples/sparse_linear_regression.jl

Lines changed: 2 additions & 2 deletions
@@ -53,12 +53,12 @@ end
 mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2
 
 using Zygote
-using AbstractDifferentiation: ZygoteBackend
+using DifferentiationInterface: AutoZygote
 using ProximalAlgorithms
 
 training_loss = ProximalAlgorithms.AutoDifferentiable(
     wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
-    ZygoteBackend(),
+    AutoZygote(),
 )
 
 # As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):

docs/src/guide/custom_objectives.jl

Lines changed: 15 additions & 14 deletions
@@ -12,18 +12,18 @@
 #
 # Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
 #
-# To compute gradients, algorithms use [`value_and_gradient_closure`](@ref):
-# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation
+# To compute gradients, algorithms use [`value_and_gradient`](@ref):
+# this relies on [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl), for automatic differentiation
 # with any of its supported backends, when functions are wrapped in [`AutoDifferentiable`](@ref),
 # as the examples below show.
 #
 # If however you would like to provide your own gradient implementation (e.g. for efficiency reasons),
-# you can simply implement a method for [`value_and_gradient_closure`](@ref) on your own function type.
+# you can simply implement a method for [`value_and_gradient`](@ref) on your own function type.
 #
 # ```@docs
 # ProximalCore.prox
 # ProximalCore.prox!
-# ProximalAlgorithms.value_and_gradient_closure
+# ProximalAlgorithms.value_and_gradient
 # ProximalAlgorithms.AutoDifferentiable
 # ```
 #
@@ -32,12 +32,12 @@
 # Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is
 
 using Zygote
-using AbstractDifferentiation: ZygoteBackend
+using DifferentiationInterface: AutoZygote
 using ProximalAlgorithms
 
 rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
     x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
-    ZygoteBackend(),
+    AutoZygote(),
 )
 
 # To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
@@ -105,16 +105,17 @@ end
 
 Counting(f::T) where {T} = Counting{T}(f, 0, 0, 0)
 
-# Now we only need to intercept any call to [`value_and_gradient_closure`](@ref) and [`prox!`](@ref) and increase counters there:
+function (f::Counting)(x)
+    f.eval_count += 1
+    return f.f(x)
+end
 
-function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x)
+# Now we only need to intercept any call to [`value_and_gradient`](@ref) and [`prox!`](@ref) and increase counters there:
+
+function ProximalAlgorithms.value_and_gradient(f::Counting, x)
     f.eval_count += 1
-    fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x)
-    function counting_pullback()
-        f.gradient_count += 1
-        return pb()
-    end
-    return fx, counting_pullback
+    f.gradient_count += 1
+    return ProximalAlgorithms.value_and_gradient(f.f, x)
 end
 
 function ProximalCore.prox!(y, f::Counting, x, gamma)
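For comparison, here is a self-contained sketch of the counting pattern after this change, using a simplified, hypothetical `GradientCounter` wrapper (the guide's `Counting` type also tracks prox calls); under the new API a single `value_and_gradient` call bumps both counters at once, with no pullback closure to intercept.

```julia
using Zygote
using DifferentiationInterface: AutoZygote
using ProximalAlgorithms

# Simplified stand-in for the guide's Counting wrapper (prox counting omitted).
mutable struct GradientCounter{F}
    f::F
    eval_count::Int
    gradient_count::Int
end

(c::GradientCounter)(x) = (c.eval_count += 1; c.f(x))

function ProximalAlgorithms.value_and_gradient(c::GradientCounter, x)
    c.eval_count += 1
    c.gradient_count += 1
    return ProximalAlgorithms.value_and_gradient(c.f, x)
end

rosenbrock = ProximalAlgorithms.AutoDifferentiable(
    x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
    AutoZygote(),
)
counted = GradientCounter(rosenbrock, 0, 0)
fx, grad = ProximalAlgorithms.value_and_gradient(counted, [0.0, 0.0])
counted.eval_count, counted.gradient_count  # (1, 1)
```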

docs/src/guide/getting_started.jl

Lines changed: 5 additions & 6 deletions
@@ -20,7 +20,7 @@
 # The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
 #
 # To evaluate these first-order primitives, in ProximalAlgorithms:
-# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
+# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) and all of its backends).
 # * ``\operatorname{prox}_{f_i}`` relies on the intereface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
 # Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
 #
@@ -52,13 +52,13 @@
 
 using LinearAlgebra
 using Zygote
-using AbstractDifferentiation: ZygoteBackend
+using DifferentiationInterface: AutoZygote
 using ProximalOperators
 using ProximalAlgorithms
 
 quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
     x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
-    ZygoteBackend(),
+    AutoZygote(),
 )
 box_indicator = ProximalOperators.IndBox(0, 1)
 
@@ -72,10 +72,9 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit = 1000, tol = 1e-5, verbose =
 solution, iterations = ffb(x0 = ones(2), f = quadratic_cost, g = box_indicator)
 
 # We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:
-# for this, we just evaluate the closure `cl` returned as second output of [`value_and_gradient_closure`](@ref).
+# for this, we just evaluate the second output of [`value_and_gradient`](@ref).
 
-v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
--cl()
+last(ProximalAlgorithms.value_and_gradient(quadratic_cost, solution))
 
 # Or by plotting the solution against the cost function and constraint:

docs/src/index.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ Implemented algorithms include:
 Check out [this section](@ref problems_algorithms) for an overview of the available algorithms.
 
 Algorithms rely on:
-- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients),
+- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients),
 - the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).
 
 !!! note

src/ProximalAlgorithms.jl

Lines changed: 12 additions & 14 deletions
@@ -1,6 +1,7 @@
 module ProximalAlgorithms
 
-using AbstractDifferentiation
+using ADTypes: ADTypes
+using DifferentiationInterface: DifferentiationInterface
 using ProximalCore
 using ProximalCore: prox, prox!
 
@@ -12,33 +13,30 @@ const Maybe{T} = Union{T,Nothing}
 
 Callable struct wrapping function `f` to be auto-differentiated using `backend`.
 
-When called, it evaluates the same as `f`, while [`value_and_gradient_closure`](@ref)
+When called, it evaluates the same as `f`, while its gradient
 is implemented using `backend` for automatic differentiation.
-The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
+The backend can be any of those supported by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).
 """
-struct AutoDifferentiable{F,B}
+struct AutoDifferentiable{F,B<:ADTypes.AbstractADType}
     f::F
     backend::B
 end
 
 (f::AutoDifferentiable)(x) = f.f(x)
 
 """
-    value_and_gradient_closure(f, x)
+    value_and_gradient(f, x)
 
-Return a tuple containing the value of `f` at `x`, and a closure `cl`.
-
-Function `cl`, once called, yields the gradient of `f` at `x`.
+Return a tuple containing the value of `f` at `x` and the gradient of `f` at `x`.
 """
-value_and_gradient_closure
+value_and_gradient
 
-function value_and_gradient_closure(f::AutoDifferentiable, x)
-    fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
-    return fx, () -> pb(one(fx))[1]
+function value_and_gradient(f::AutoDifferentiable, x)
+    return DifferentiationInterface.value_and_gradient(f.f, f.backend, x)
 end
 
-function value_and_gradient_closure(f::ProximalCore.Zero, x)
-    f(x), () -> zero(x)
+function value_and_gradient(f::ProximalCore.Zero, x)
+    return f(x), zero(x)
 end
 
 # various utilities
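For reference, the new `value_and_gradient(f::AutoDifferentiable, x)` above is a thin wrapper over DifferentiationInterface. A quick sketch of the underlying call it delegates to, assuming Zygote is loaded as the backend:

```julia
using Zygote
using DifferentiationInterface: AutoZygote, value_and_gradient

g(x) = sum(abs2, x) / 2
fx, grad = value_and_gradient(g, AutoZygote(), [1.0, 2.0, 3.0])
# fx == 7.0, grad == [1.0, 2.0, 3.0]
```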

src/algorithms/davis_yin.jl

Lines changed: 3 additions & 4 deletions
@@ -56,8 +56,7 @@ end
 function Base.iterate(iter::DavisYinIteration)
     z = copy(iter.x0)
     xg, = prox(iter.g, z, iter.gamma)
-    f_xg, cl = value_and_gradient_closure(iter.f, xg)
-    grad_f_xg = cl()
+    f_xg, grad_f_xg = value_and_gradient(iter.f, xg)
     z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
     xh, = prox(iter.h, z_half, iter.gamma)
     res = xh - xg
@@ -68,8 +67,8 @@ end
 
 function Base.iterate(iter::DavisYinIteration, state::DavisYinState)
     prox!(state.xg, iter.g, state.z, iter.gamma)
-    f_xg, cl = value_and_gradient_closure(iter.f, state.xg)
-    state.grad_f_xg .= cl()
+    f_xg, grad_f_xg = value_and_gradient(iter.f, state.xg)
+    state.grad_f_xg .= grad_f_xg
     state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg
     prox!(state.xh, iter.h, state.z_half, iter.gamma)
     state.res .= state.xh .- state.xg

src/algorithms/fast_forward_backward.jl

Lines changed: 3 additions & 4 deletions
@@ -72,8 +72,7 @@ end
 
 function Base.iterate(iter::FastForwardBackwardIteration)
     x = copy(iter.x0)
-    f_x, cl = value_and_gradient_closure(iter.f, x)
-    grad_f_x = cl()
+    f_x, grad_f_x = value_and_gradient(iter.f, x)
     gamma =
         iter.gamma === nothing ?
         1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
@@ -136,8 +135,8 @@ function Base.iterate(
     state.x .= state.z .+ beta .* (state.z .- state.z_prev)
     state.z_prev, state.z = state.z, state.z_prev
 
-    state.f_x, cl = value_and_gradient_closure(iter.f, state.x)
-    state.grad_f_x .= cl()
+    state.f_x, grad_f_x = value_and_gradient(iter.f, state.x)
+    state.grad_f_x .= grad_f_x
     state.y .= state.x .- state.gamma .* state.grad_f_x
     state.g_z = prox!(state.z, iter.g, state.y, state.gamma)
     state.res .= state.x .- state.z

src/algorithms/forward_backward.jl

Lines changed: 3 additions & 4 deletions
@@ -64,8 +64,7 @@ end
 
 function Base.iterate(iter::ForwardBackwardIteration)
     x = copy(iter.x0)
-    f_x, cl = value_and_gradient_closure(iter.f, x)
-    grad_f_x = cl()
+    f_x, grad_f_x = value_and_gradient(iter.f, x)
     gamma =
         iter.gamma === nothing ?
         1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
@@ -111,8 +110,8 @@ function Base.iterate(
         state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x
     else
        state.x, state.z = state.z, state.x
-        state.f_x, cl = value_and_gradient_closure(iter.f, state.x)
-        state.grad_f_x .= cl()
+        state.f_x, grad_f_x = value_and_gradient(iter.f, state.x)
+        state.grad_f_x .= grad_f_x
     end
 
     state.y .= state.x .- state.gamma .* state.grad_f_x

src/algorithms/li_lin.jl

Lines changed: 4 additions & 6 deletions
@@ -62,8 +62,7 @@ end
 
 function Base.iterate(iter::LiLinIteration{R}) where {R}
     y = copy(iter.x0)
-    f_y, cl = value_and_gradient_closure(iter.f, y)
-    grad_f_y = cl()
+    f_y, grad_f_y = value_and_gradient(iter.f, y)
 
     # TODO: initialize gamma if not provided
     # TODO: authors suggest Barzilai-Borwein rule?
@@ -110,8 +109,7 @@ function Base.iterate(iter::LiLinIteration{R}, state::LiLinState{R,Tx}) where {R
     else
         # TODO: re-use available space in state?
         # TODO: backtrack gamma at x
-        f_x, cl = value_and_gradient_closure(iter.f, x)
-        grad_f_x = cl()
+        f_x, grad_f_x = value_and_gradient(iter.f, x)
         x_forward = state.x - state.gamma .* grad_f_x
         v, g_v = prox(iter.g, x_forward, state.gamma)
         Fv = iter.f(v) + g_v
@@ -130,8 +128,8 @@ function Base.iterate(iter::LiLinIteration{R}, state::LiLinState{R,Tx}) where {R
         Fx = Fv
     end
 
-    state.f_y, cl = value_and_gradient_closure(iter.f, state.y)
-    state.grad_f_y .= cl()
+    state.f_y, grad_f_y = value_and_gradient(iter.f, state.y)
+    state.grad_f_y .= grad_f_y
     state.y_forward .= state.y .- state.gamma .* state.grad_f_y
     state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma)

src/algorithms/panoc.jl

Lines changed: 7 additions & 8 deletions
@@ -87,8 +87,7 @@ f_model(iter::PANOCIteration, state::PANOCState) =
 function Base.iterate(iter::PANOCIteration{R}) where {R}
     x = copy(iter.x0)
     Ax = iter.A * x
-    f_Ax, cl = value_and_gradient_closure(iter.f, Ax)
-    grad_f_Ax = cl()
+    f_Ax, grad_f_Ax = value_and_gradient(iter.f, Ax)
     gamma =
         iter.gamma === nothing ?
         iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) :
@@ -182,8 +181,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
 
     state.x_d .= state.x .+ state.d
     state.Ax_d .= state.Ax .+ state.Ad
-    state.f_Ax_d, cl = value_and_gradient_closure(iter.f, state.Ax_d)
-    state.grad_f_Ax_d .= cl()
+    state.f_Ax_d, grad_f_Ax_d = value_and_gradient(iter.f, state.Ax_d)
+    state.grad_f_Ax_d .= grad_f_Ax_d
     mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d)
 
     copyto!(state.x, state.x_d)
@@ -220,8 +219,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
         # along a line using interpolation and linear combinations
         # this allows saving operations
         if isinf(f_Az)
-            f_Az, cl = value_and_gradient_closure(iter.f, state.Az)
-            state.grad_f_Az .= cl()
+            f_Az, grad_f_Az = value_and_gradient(iter.f, state.Az)
+            state.grad_f_Az .= grad_f_Az
         end
         if isinf(c)
            mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az)
@@ -239,8 +238,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
     else
         # otherwise, in the general case where f is only smooth, we compute
         # one gradient and matvec per backtracking step
-        state.f_Ax, cl = value_and_gradient_closure(iter.f, state.Ax)
-        state.grad_f_Ax .= cl()
+        state.f_Ax, grad_f_Ax = value_and_gradient(iter.f, state.Ax)
+        state.grad_f_Ax .= grad_f_Ax
        mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax)
     end

0 commit comments
