diff --git a/src/array_partition.jl b/src/array_partition.jl index b0325fe0..39e28ffe 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -166,7 +166,7 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x)) function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T} - mapreduce(f, op, (i for i in A); kwargs...) + mapreduce(x->mapreduce(f, op, x; kwargs...), op, (i for i in A.x); kwargs...) end Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x)) Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x)) @@ -430,7 +430,10 @@ end ## Linear Algebra -ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A)) +function ArrayInterface.zeromatrix(A::ArrayPartition) + x = reduce(vcat,vec.(A.x)) + x .* x' .* false +end function __get_subtypes_in_module( mod, supertype; include_supertype = true, all = false, except = []) diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index de8fa91a..99c83512 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -145,6 +145,13 @@ end return dest end +#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq +function ArrayInterface.zeromatrix(A::NamedArrayPartition) + B = ArrayPartition(A) + x = reduce(vcat,vec.(B.x)) + x .* x' .* false +end + # `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) function find_NamedArrayPartition(args::Tuple) diff --git a/test/gpu/arraypartition_gpu.jl b/test/gpu/arraypartition_gpu.jl index 08fdb69a..2c457e79 100644 --- a/test/gpu/arraypartition_gpu.jl +++ b/test/gpu/arraypartition_gpu.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, CUDA, Test +using RecursiveArrayTools, ArrayInterface, CUDA, Test CUDA.allowscalar(false) # Test indexing with colon @@ -21,3 +21,8 @@ fill!(pA, false) a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu)) b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu)) @. a + b + +x = ArrayPartition((CUDA.zeros(2),CUDA.zeros(2))) +@test ArrayInterface.zeromatrix(x) isa CuMatrix +@test size(ArrayInterface.zeromatrix(x)) == (4,4) +@test maximum(abs, x) == 0f0 diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index d5647bad..e6747969 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -9,10 +9,13 @@ using RecursiveArrayTools, Test @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence @test all(x .== x[1:end]) + @test ArrayInterface.zeromatrix(x) isa Matrix + @test size(ArrayInterface.zeromatrix(x)) == (30,30) y = copy(x) @test zero(x, (10, 20)) == zero(x) # test that ignoring dims works @test typeof(zero(x)) <: NamedArrayPartition @test (y .*= 2).a[1] ≈ 2 # test in-place bcast + @test length(Array(x)) == 30 @test typeof(Array(x)) <: Array