|
1 |
| -import StaticArraysCore: StaticArray, FieldArray, tuple_prod |
| 1 | +using StaticArraysCore: StaticArray, FieldArray, tuple_prod |
2 | 2 |
|
3 | 3 | """
|
4 | 4 | StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
|
@@ -27,3 +27,50 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
|
27 | 27 | end
|
28 | 28 | StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
|
29 | 29 | StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
|
| 30 | + |
| 31 | +# Broadcast overload |
| 32 | +using StaticArraysCore: StaticArrayStyle, similar_type |
| 33 | +StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N} |
| 34 | +function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} |
| 35 | + bc′ = Broadcast.instantiate(replace_structarray(bc)) |
| 36 | + return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) |
| 37 | +end |
| 38 | +# This looks costy, but compiler should be able to optimize them away |
| 39 | +Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc)) |
| 40 | + |
| 41 | +to_staticstyle(@nospecialize(x::Type)) = x |
| 42 | +to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N} |
| 43 | +function replace_structarray(bc::Broadcasted{Style}) where {Style} |
| 44 | + args = replace_structarray_args(bc.args) |
| 45 | + return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing) |
| 46 | +end |
| 47 | +function replace_structarray(A::StructArray) |
| 48 | + f = createinstance(eltype(A)) |
| 49 | + args = Tuple(components(A)) |
| 50 | + return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing) |
| 51 | +end |
| 52 | +replace_structarray(@nospecialize(A)) = A |
| 53 | + |
| 54 | +replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...) |
| 55 | +replace_structarray_args(::Tuple{}) = () |
| 56 | + |
| 57 | +# StaticArrayStyle has no similar defined. |
| 58 | +# Overload `Base.copy` instead. |
| 59 | +@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} |
| 60 | + sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc)) |
| 61 | + ET = eltype(sa) |
| 62 | + isnonemptystructtype(ET) || return sa |
| 63 | + elements = Tuple(sa) |
| 64 | + arrs = ntuple(Val(fieldcount(ET))) do i |
| 65 | + similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) |
| 66 | + end |
| 67 | + return StructArray{ET}(arrs) |
| 68 | +end |
| 69 | + |
| 70 | +@inline function _getfields(x::Tuple, i::Int) |
| 71 | + if @generated |
| 72 | + return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...) |
| 73 | + else |
| 74 | + return map(Base.Fix2(getfield, i), x) |
| 75 | + end |
| 76 | +end |
0 commit comments