Skip to content

Commit 4f431c9

Browse files
committed
Try to resolve style conflict
1 parent eec1738 commit 4f431c9

File tree

4 files changed

+114
-24
lines changed

4 files changed

+114
-24
lines changed

src/StructArrays.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,16 @@ end
2929
import Adapt
3030
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
3131

32+
# for GPU broadcast
33+
if isdefined(Adapt, :backend)
34+
function Adapt.backend(x::StructArray)
35+
cs = components(x)
36+
back = Adapt.backend(cs[1])
37+
for i in 2:length(cs)
38+
back === Adapt.backend(cs[i]) || error("backend mismatch!")
39+
end
40+
back
41+
end
42+
end
43+
3244
end # module

src/staticarrays_support.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import StaticArrays: StaticArray, FieldArray, tuple_prod
1+
import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle
22

33
"""
44
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -26,4 +26,9 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
2626
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
2727
end
2828
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
29-
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
29+
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
30+
31+
function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N}
32+
B = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc)
33+
copy(B)
34+
end

src/structarray.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
443443
end
444444

445445
# broadcast
446-
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
446+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
447447

448448
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
449449

@@ -453,6 +453,22 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
453453
return StructArrayStyle{T, N}()
454454
end
455455

456+
_dimmax(a::Integer, b::Integer) = max(a, b)
457+
_dimmax(::Type{Any}, ::Integer) = Any
458+
_dimmax(::Integer ,::Type{Any}) = Any
459+
_dimmax(::Type{Any} ,::Type{Any}) = Any
460+
461+
# StructArrayStyle is a wrapped style.
462+
# Here we try our best to resolve style conflict.
463+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
464+
S′ = Broadcast.result_style(S(), b)
465+
if S′ isa StructArrayStyle # avoid double wrap
466+
return typeof(S′)(Val{_dimmax(N, M)}())
467+
end
468+
StructArrayStyle{typeof(S′), _dimmax(N, M)}()
469+
end
470+
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
471+
456472
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
457473
combine_style_types(BroadcastStyle(A), args...)
458474
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
@@ -463,9 +479,20 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
463479

464480
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
465481

466-
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType}
467-
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
468-
return similar(ContainerType, axes(bc))
482+
# Here we use `similar` defined for `S` to build the dest Array.
483+
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
484+
bc′ = convert(Broadcasted{S}, bc)
485+
isnonemptystructtype(ElType) || return similar(bc′, ElType)
486+
return buildfromschema(T -> similar(bc′, T), ElType)
487+
end
488+
489+
# Unwrapper to recover the behaviour defined by parent style.
490+
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
491+
return copyto!(dest, convert(Broadcasted{S}, bc))
492+
end
493+
494+
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
495+
return Broadcast.materialize!(S(), dest, bc)
469496
end
470497

471498
# for aliasing analysis during broadcast

test/runtests.jl

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -899,17 +899,26 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
899899
@test t.b.d isa Array
900900
end
901901

902-
struct MyArray{T,N} <: AbstractArray{T,N}
903-
A::Array{T,N}
904-
end
905-
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
906-
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
907-
Base.getindex(A::MyArray, i::Int) = A.A[i]
908-
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
909-
Base.size(A::MyArray) = Base.size(A.A)
910-
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
911-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
912-
MyArray{ElType}(undef, size(bc))
902+
for S in (1, 2, 3)
903+
MyArray = Symbol(:MyArray, S)
904+
@eval begin
905+
struct $MyArray{T,N} <: AbstractArray{T,N}
906+
A::Array{T,N}
907+
end
908+
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
909+
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
910+
Base.getindex(A::$MyArray, i::Int) = A.A[i]
911+
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
912+
Base.size(A::$MyArray) = Base.size(A.A)
913+
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
914+
end
915+
end
916+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
917+
MyArray1{ElType}(undef, size(bc))
918+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
919+
MyArray2{ElType}(undef, size(bc))
920+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
921+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S
913922

914923
@testset "broadcast" begin
915924
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -927,19 +936,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
927936
# used inside of broadcast but we also test it here explicitly
928937
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
929938

930-
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
931-
@test_throws MethodError s .+ s
939+
# Make sure we can handle style with similar defined
940+
# And we can handle most conflict
941+
# s1 and s2 has similar defined, but s3 not
942+
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
943+
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
944+
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
945+
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
946+
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
947+
948+
function _test_similar(a, b, c)
949+
try
950+
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
951+
@test typeof(a .+ b .- c) == typeof(d)
952+
catch
953+
@test_throws MethodError a .+ b .- c
954+
end
955+
end
956+
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
957+
_test_similar(s, s′, s″)
958+
end
932959

933960
# test for dimensionality track
961+
s = s1
934962
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
935963
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
936964
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
937965
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
938-
939-
a = StructArray([1;2+im])
940-
b = StructArray([1;;2+im])
941-
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
942-
@test a .+ Any[1] isa StructArray
966+
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
943967

944968
# issue #185
945969
A = StructArray(randn(ComplexF64, 3, 3))
@@ -951,6 +975,28 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
951975
# issue #189
952976
v = StructArray([(a="s1",), (a="s2",)])
953977
@test @inferred(broadcast(el -> el.a, v)) == ["s1", "s2"]
978+
979+
# ambiguity check (can we do this better?)
980+
function _test(a, b, c)
981+
if a isa StructArray || b isa StructArray || c isa StructArray
982+
d = @inferred a .+ b .- c
983+
@test d == collect(a) .+ collect(b) .- collect(c)
984+
@test d isa StructArray
985+
end
986+
end
987+
testset = (StructArray([1;2+im]),
988+
StructArray([1 2+im]),
989+
1:2,
990+
(1,2),
991+
(@SArray [1 2]),
992+
StructArray(@SArray [1 1+2im]))
993+
for aa in testset, bb in testset, cc in testset
994+
_test(aa, bb, cc)
995+
end
996+
997+
a = @SArray randn(3,3);
998+
b = StructArray{ComplexF64}((a,a))
999+
@test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix
9541000
end
9551001

9561002
@testset "staticarrays" begin

0 commit comments

Comments
 (0)