diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 408d5bf5..3fa52087 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -143,7 +143,11 @@ end Base.convert(::Type{AbstractArray}, VA::AbstractVectorOfArray) = stack(VA.u) function Adapt.adapt_structure(to, VA::AbstractVectorOfArray) - Adapt.adapt(to, Array(VA)) + VectorOfArray(Adapt.adapt.(to, VA.u)) +end + +function Adapt.adapt_structure(to, VA::AbstractDiffEqArray) + DiffEqArray(Adapt.adapt.(to, VA.u), Adapt.adapt(to, VA.t)) end function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} diff --git a/test/gpu.jl b/test/gpu.jl deleted file mode 100644 index 2ed5211c..00000000 --- a/test/gpu.jl +++ /dev/null @@ -1,4 +0,0 @@ -using RecursiveArrayTools, CuArrays -a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu)) -b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu)) -@. a + b diff --git a/test/gpu/arraypartition_gpu.jl b/test/gpu/arraypartition_gpu.jl index 80f9a8d6..08fdb69a 100644 --- a/test/gpu/arraypartition_gpu.jl +++ b/test/gpu/arraypartition_gpu.jl @@ -17,3 +17,7 @@ RecursiveArrayTools.recursivefill!(pA, true) # Test that regular filling is done using GPU kernels and not scalar indexing fill!(pA, false) @test all(pA .== false) + +a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu)) +b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu)) +@. a + b diff --git a/test/gpu/vectorofarray_gpu.jl b/test/gpu/vectorofarray_gpu.jl index c97beca2..7422a6c1 100644 --- a/test/gpu/vectorofarray_gpu.jl +++ b/test/gpu/vectorofarray_gpu.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, CUDA, Test, Zygote +using RecursiveArrayTools, CUDA, Test, Zygote, Adapt CUDA.allowscalar(false) # Test indexing with colon @@ -37,3 +37,14 @@ va_cu = convert(AbstractArray, va) @test va_cu isa CuArray @test size(va_cu) == size(x) + +a = VectorOfArray([ones(2) for i in 1:3]) +_a = Adapt.adapt(CuArray,a) +@test _a isa VectorOfArray +@test _a.u isa Vector{<:CuArray} + +b = DiffEqArray([ones(2) for i in 1:3],ones(2)) +_b = Adapt.adapt(CuArray,b) +@test _b isa DiffEqArray +@test _b.u isa Vector{<:CuArray} +@test _b.t isa CuArray