Skip to content

Commit 13a676a

Browse files
Merge pull request #1113 from SciML/compathelper/new_version/2024-09-17-22-09-03-606-02207253346
CompatHelper: bump compat for DiffEqFlux to 4 for package gpu, (keep existing compat)
2 parents c4de0b9 + 2255638 commit 13a676a

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

test/gpu/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
77
[compat]
88
CUDA = "3.12, 4, 5"
99
DiffEqCallbacks = "2.24, 3"
10-
DiffEqFlux = "3"
10+
DiffEqFlux = "3, 4"
1111
LuxCUDA = "0.3.1"

test/gpu/mixed_gpu_cpu_adjoint.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using SciMLSensitivity, OrdinaryDiffEq
22
using Lux, LuxCUDA, Test, Zygote, Random, LinearAlgebra, ComponentArrays
33

4+
const gdev = gpu_device()
5+
46
CUDA.allowscalar(false)
57

68
H = CuArray(rand(Float32, 2, 2))
@@ -42,10 +44,10 @@ grad = Zygote.gradient(cost, p)[1]
4244
rng = MersenneTwister(1234)
4345
m = 32
4446
n = 16
45-
Z = randn(rng, Float32, (n, m)) |> gpu
47+
Z = randn(rng, Float32, (n, m)) |> gdev
4648
𝒯 = 2.0f0
4749
Δτ = 1.0f-1
48-
ca_init = [zeros(1); ones(m)] |> gpu
50+
ca_init = [zeros(1); ones(m)] |> gdev
4951

5052
function f(ca, Z, t)
5153
a = ca[2:end]
@@ -54,7 +56,7 @@ function f(ca, Z, t)
5456
Ka_unit = Z' * w_unit
5557
z_unit = dot(abs.(Ka_unit), a_unit)
5658
aKa_over_z = a .* Ka_unit / z_unit
57-
[sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gpu
59+
[sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gdev
5860
end
5961

6062
function c(Z)

0 commit comments

Comments
 (0)