Skip to content

Commit b99776e

Browse files
committed
Try to resolve style conflict.
Update runtests.jl Update runtests.jl
1 parent ea83c8a commit b99776e

File tree

2 files changed

+81
-21
lines changed

2 files changed

+81
-21
lines changed

src/structarray.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
487487
end
488488

489489
# broadcast
490-
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
490+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
491491

492492
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
493493

@@ -497,6 +497,15 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
497497
return StructArrayStyle{T, N}()
498498
end
499499

500+
# StructArrayStyle is a wrapped style.
501+
# Here we try our best to resolve style conflict.
502+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
503+
N′ = M === Any || N === Any ? Any : max(M, N)
504+
S′ = Broadcast.result_style(S(), b)
505+
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
506+
end
507+
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
508+
500509
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
501510
combine_style_types(BroadcastStyle(A), args...)
502511
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
@@ -507,9 +516,19 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
507516

508517
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
509518

510-
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType}
511-
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
512-
return similar(ContainerType, axes(bc))
519+
# Here we use `similar` defined for `S` to build the dest Array.
520+
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
521+
bc′ = convert(Broadcasted{S}, bc)
522+
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
523+
end
524+
525+
# Unwrapper to recover the behaviour defined by parent style.
526+
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
527+
return copyto!(dest, convert(Broadcasted{S}, bc))
528+
end
529+
530+
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
531+
return Broadcast.materialize!(S(), dest, bc)
513532
end
514533

515534
# for aliasing analysis during broadcast

test/runtests.jl

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,17 +1073,26 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
10731073
@test t.b.d isa Array
10741074
end
10751075

1076-
struct MyArray{T,N} <: AbstractArray{T,N}
1077-
A::Array{T,N}
1076+
for S in (1, 2, 3)
1077+
MyArray = Symbol(:MyArray, S)
1078+
@eval begin
1079+
struct $MyArray{T,N} <: AbstractArray{T,N}
1080+
A::Array{T,N}
1081+
end
1082+
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
1083+
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
1084+
Base.getindex(A::$MyArray, i::Int) = A.A[i]
1085+
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
1086+
Base.size(A::$MyArray) = Base.size(A.A)
1087+
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
1088+
end
10781089
end
1079-
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
1080-
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
1081-
Base.getindex(A::MyArray, i::Int) = A.A[i]
1082-
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
1083-
Base.size(A::MyArray) = Base.size(A.A)
1084-
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
1085-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
1086-
MyArray{ElType}(undef, size(bc))
1090+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
1091+
MyArray1{ElType}(undef, size(bc))
1092+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
1093+
MyArray2{ElType}(undef, size(bc))
1094+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
1095+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S
10871096

10881097
@testset "broadcast" begin
10891098
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -1101,19 +1110,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11011110
# used inside of broadcast but we also test it here explicitly
11021111
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
11031112

1104-
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
1105-
@test_throws MethodError s .+ s
1113+
# Make sure we can handle style with similar defined
1114+
# And we can handle most conflict
1115+
# s1 and s2 has similar defined, but s3 not
1116+
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
1117+
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
1118+
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
1119+
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
1120+
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
1121+
1122+
function _test_similar(a, b, c)
1123+
try
1124+
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
1125+
@test typeof(a .+ b .- c) == typeof(d)
1126+
catch
1127+
@test_throws MethodError a .+ b .- c
1128+
end
1129+
end
1130+
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
1131+
_test_similar(s, s′, s″)
1132+
end
11061133

11071134
# test for dimensionality track
1135+
s = s1
11081136
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11091137
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11101138
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
11111139
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
1112-
1113-
a = StructArray([1;2+im])
1114-
b = StructArray([1;;2+im])
1115-
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
1116-
@test a .+ Any[1] isa StructArray
1140+
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
11171141

11181142
# issue #185
11191143
A = StructArray(randn(ComplexF64, 3, 3))
@@ -1125,6 +1149,23 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11251149
# issue #189
11261150
v = StructArray([(a="s1",), (a="s2",)])
11271151
@test @inferred(broadcast(el -> el.a, v)) == ["s1", "s2"]
1152+
1153+
@testset "ambiguity check" begin
1154+
function _test(a, b, c)
1155+
if a isa StructArray || b isa StructArray || c isa StructArray
1156+
d = @inferred a .+ b .- c
1157+
@test d == collect(a) .+ collect(b) .- collect(c)
1158+
@test d isa StructArray
1159+
end
1160+
end
1161+
testset = Any[StructArray([1;2+im]),
1162+
1:2,
1163+
(1,2),
1164+
]
1165+
for aa in testset, bb in testset, cc in testset
1166+
_test(aa, bb, cc)
1167+
end
1168+
end
11281169
end
11291170

11301171
@testset "map" begin

0 commit comments

Comments
 (0)