Skip to content

Commit 1a26dc3

Browse files
Merge pull request #434 from SciML/ChrisRackauckas-patch-1
Recurse adapt
2 parents 8f89f21 + 63be2a1 commit 1a26dc3

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

src/vector_of_array.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,11 @@ end
143143
Base.convert(::Type{AbstractArray}, VA::AbstractVectorOfArray) = stack(VA.u)
144144

145145
function Adapt.adapt_structure(to, VA::AbstractVectorOfArray)
146-
Adapt.adapt(to, Array(VA))
146+
VectorOfArray(Adapt.adapt.(to, VA.u))
147+
end
148+
149+
function Adapt.adapt_structure(to, VA::AbstractDiffEqArray)
150+
DiffEqArray(Adapt.adapt.(to, VA.u), Adapt.adapt(to, VA.t))
147151
end
148152

149153
function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N}

test/gpu.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

test/gpu/arraypartition_gpu.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ RecursiveArrayTools.recursivefill!(pA, true)
1717
# Test that regular filling is done using GPU kernels and not scalar indexing
1818
fill!(pA, false)
1919
@test all(pA .== false)
20+
21+
a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
22+
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
23+
@. a + b

test/gpu/vectorofarray_gpu.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, CUDA, Test, Zygote
1+
using RecursiveArrayTools, CUDA, Test, Zygote, Adapt
22
CUDA.allowscalar(false)
33

44
# Test indexing with colon
@@ -37,3 +37,14 @@ va_cu = convert(AbstractArray, va)
3737

3838
@test va_cu isa CuArray
3939
@test size(va_cu) == size(x)
40+
41+
a = VectorOfArray([ones(2) for i in 1:3])
42+
_a = Adapt.adapt(CuArray,a)
43+
@test _a isa VectorOfArray
44+
@test _a.u isa Vector{<:CuArray}
45+
46+
b = DiffEqArray([ones(2) for i in 1:3],ones(2))
47+
_b = Adapt.adapt(CuArray,b)
48+
@test _b isa DiffEqArray
49+
@test _b.u isa Vector{<:CuArray}
50+
@test _b.t isa CuArray

0 commit comments

Comments
 (0)