Skip to content

Commit 8c83220

Browse files
N5N3piever
andcommitted
Resolve the review comments
1. Update Project.toml. 2. test `backend`'s inferability. Co-Authored-By: Pietro Vertechi <6333339+piever@users.noreply.github.com>
1 parent e711ebe commit 8c83220

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Adapt = "1, 2, 3"
1414
DataAPI = "1"
1515
StaticArraysCore = "1.3"
1616
StaticArrays = "1.5.6"
17-
GPUArraysCore = "~0.1.2"
17+
GPUArraysCore = "0.1.2"
1818
Tables = "1"
1919
julia = "1.6"
2020

src/StructArrays.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x
3232
# for GPU broadcast
3333
import GPUArraysCore
3434
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
35-
backs = map(GPUArraysCore.backend, fieldtypes(array_types(T)))
36-
all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!")
37-
return backs[1]
35+
backends = map_params(GPUArraysCore.backend, array_types(T))
36+
backend, others = backends[1], tail(backends)
37+
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
38+
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
39+
return backend
3840
end
3941

4042
end # module

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12321232
bcmul2(a) = 2 .* a
12331233
a = StructArray(randn(ComplexF32, 10, 10))
12341234
sa = jl(a)
1235+
backend = StructArrays.GPUArraysCore.backend
1236+
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im)
12351237
@test collect(@inferred(bcabs(sa))) == bcabs(a)
12361238
@test @inferred(bcmul2(sa)) isa StructArray
12371239
@test (sa .+= 1) isa StructArray

0 commit comments

Comments
 (0)