Skip to content

Commit eec1738

Browse files
authored
more broadcast fixes (#213)
* more broadcast fixes * minor refactor
1 parent d662f13 commit eec1738

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

src/structarray.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
449449

450450
# Here we define the dimension tracking behavior of StructArrayStyle
451451
function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
452-
T = S <: AbstractArrayStyle{M} ? typeof(S(Val(N))) : S
452+
T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S
453453
return StructArrayStyle{T, N}()
454454
end
455455

@@ -463,8 +463,10 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
463463

464464
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
465465

466-
Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType} =
467-
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))
466+
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType}
467+
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
468+
return similar(ContainerType, axes(bc))
469+
end
468470

469471
# for aliasing analysis during broadcast
470472
Base.dataids(u::StructArray) = mapreduce(Base.dataids, (a, b) -> (a..., b...), values(components(u)), init=())

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ hasfields(::Type{<:NamedTuple{names}}) where {names} = true
172172
hasfields(::Type{T}) where {T} = !isabstracttype(T)
173173
hasfields(::Union) = false
174174

175+
isnonemptystructtype(::Type{T}) where {T} = isstructtype(T) && fieldcount(T) != 0
176+
175177
"""
176178
StructArrays.bypass_constructor(T, args)
177179

test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,12 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
945945
A = StructArray(randn(ComplexF64, 3, 3))
946946
B = randn(ComplexF64, 3, 3)
947947
c = StructArray(randn(ComplexF64, 3))
948-
@test (A .= B .* c) === A
948+
A .= B .* c
949+
@test @inferred(B .* c) == A == B .* collect(c)
950+
951+
# issue #189
952+
v = StructArray([(a="s1",), (a="s2",)])
953+
@test @inferred(broadcast(el -> el.a, v)) == ["s1", "s2"]
949954
end
950955

951956
@testset "staticarrays" begin

0 commit comments

Comments
 (0)