Skip to content

Commit 2094a78

Browse files
Type-generic zeromatrix for arraypartition
One part of SciML/OrdinaryDiffEq.jl#2703
1 parent ea0e4f0 commit 2094a78

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/array_partition.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,10 @@ end
430430

431431
## Linear Algebra
432432

433-
ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A))
433+
function ArrayInterface.zeromatrix(A::ArrayPartition)
434+
x = reduce(vcat,vec.(A.x))
435+
x .* x' .* false
436+
end
434437

435438
function __get_subtypes_in_module(
436439
mod, supertype; include_supertype = true, all = false, except = [])

test/gpu/arraypartition_gpu.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, CUDA, Test
1+
using RecursiveArrayTools, ArrayInterface, CUDA, Test
22
CUDA.allowscalar(false)
33

44
# Test indexing with colon
@@ -21,3 +21,6 @@ fill!(pA, false)
2121
a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
2222
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
2323
@. a + b
24+
25+
@test ArrayInterface.zeromatrix(ArrayPartition((CUDA.zeros(2),CUDA.zeros(2)))) isa CuMatrix
26+
@test size(ArrayInterface.zeromatrix(ArrayPartition((CUDA.zeros(2),CUDA.zeros(2))))) == (4,4)

0 commit comments

Comments
 (0)