diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index a616c592..152513c7 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -150,6 +150,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype, recursive_unitless_bottom_eltype, recursive_unitless_eltype -export ArrayPartition, NamedArrayPartition +export ArrayPartition, NamedArrayPartition, AbstractNamedArrayPartition end # module diff --git a/src/array_partition.jl b/src/array_partition.jl index 39e28ffe..82b93fa0 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -1,3 +1,5 @@ +abstract type AbstractArrayPartition{T} <: AbstractVector{T} end + """ ```julia ArrayPartition(x::AbstractArray...) @@ -23,7 +25,7 @@ A = ArrayPartition(y, z) we would have `A.x[1]==y` and `A.x[2]==z`. Broadcasting like `f.(A)` is efficient. """ -struct ArrayPartition{T, S <: Tuple} <: AbstractVector{T} +struct ArrayPartition{T, S <: Tuple} <: AbstractArrayPartition{T} x::S end diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index 99c83512..3661e58e 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -1,3 +1,21 @@ +""" + AbstractNamedArrayPartition{T, A, NT} + +An abstract type above that of `NamedArrayPartition` that can be used to subtype a +new and seperately named 'NamedArrayPartition'-like structure. This can be done +by defining your new type as: + +```julia +struct foo{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractNamedArrayPartition{T, A, NT} + array_partition::A + names_to_indices::NT +end +``` + +where `foo` is your custom name and then all funcitonalities of NamedArrayPartitions will be inherited. +""" +abstract type AbstractNamedArrayPartition{T, A, NT} <: AbstractArrayPartition{T} end + """ NamedArrayPartition(; kwargs...) NamedArrayPartition(x::NamedTuple) @@ -6,137 +24,155 @@ Similar to an `ArrayPartition` but the individual arrays can be accessed via the constructor-specified names. However, unlike `ArrayPartition`, each individual array must have the same element type. """ -struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractVector{T} +struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractNamedArrayPartition{T, A, NT} array_partition::A names_to_indices::NT end -NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs)) -function NamedArrayPartition(x::NamedTuple) +(::Type{T})(; kwargs...) where {T<:AbstractNamedArrayPartition} = T(NamedTuple(kwargs)) +function (::Type{T})(x::NamedTuple) where {T<:AbstractNamedArrayPartition} names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x))) # enforce homogeneity of eltypes @assert all(eltype.(values(x)) .== eltype(first(x))) - T = eltype(first(x)) + R = eltype(first(x)) S = typeof(values(x)) - return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) + return T(ArrayPartition{R, S}(values(x)), names_to_indices) +end + +function named_partition_constructor(X::T) where {T<:AbstractNamedArrayPartition} + getfield(parentmodule(T), nameof(T)) end # Note: overloading `getproperty` means we cannot access `NamedArrayPartition` # fields except through `getfield` and accessor functions. -ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) +ArrayPartition(x::AbstractNamedArrayPartition) = getfield(x, :array_partition) -function Base.similar(A::NamedArrayPartition) - NamedArrayPartition( +# With new type structure this function does the same as Base.similar(x::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT} +#= function Base.similar(A::T) where {T<:AbstractNamedArrayPartition} + Tconstr = named_partition_constructor(A) + Tconstr( similar(getfield(A, :array_partition)), getfield(A, :names_to_indices)) -end +end =# # return ArrayPartition when possible, otherwise next best thing of the correct size -function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} - NamedArrayPartition( +function Base.similar(A::T, dims::NTuple{N, Int}) where {T<:AbstractNamedArrayPartition, N} + Tconstr = named_partition_constructor(A) + Tconstr( similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices)) end # similar array partition of common type -@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T} - NamedArrayPartition( +@inline function Base.similar(A::S, ::Type{T}) where {S<:AbstractNamedArrayPartition, T} + Tconstr = named_partition_constructor(A) + Tconstr( similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices)) end # return ArrayPartition when possible, otherwise next best thing of the correct size -function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N} - NamedArrayPartition( +function Base.similar(A::S, ::Type{T}, dims::NTuple{N, Int}) where {T, N, S<:AbstractNamedArrayPartition} + Tconstr = named_partition_constructor(A) + Tconstr( similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices)) end # similar array partition with different types function Base.similar( - A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S} - NamedArrayPartition( + A::U, ::Type{T}, ::Type{S}, R::DataType...) where {T, S, U<:AbstractNamedArrayPartition} + Tconstr = named_partition_constructor(A) + Tconstr( similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices)) end -Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) +Base.Array(x::AbstractNamedArrayPartition) = Array(ArrayPartition(x)) -function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} - NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) +function Base.zero(x::R) where {R <: AbstractNamedArrayPartition} + R(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) end -Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors +Base.zero(A::AbstractNamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors -Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) -function Base.getproperty(x::NamedArrayPartition, s::Symbol) +Base.propertynames(x::AbstractNamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) +function Base.getproperty(x::AbstractNamedArrayPartition, s::Symbol) getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) end # this enables x.s = some_array. -@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) +@inline function Base.setproperty!(x::AbstractNamedArrayPartition, s::Symbol, v) index = getproperty(getfield(x, :names_to_indices), s) ArrayPartition(x).x[index] .= v end # print out NamedArrayPartition as a NamedTuple -Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:") -function Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) +Base.summary(x::AbstractNamedArrayPartition) = string(typeof(x), " with arrays:") +function Base.show(io::IO, m::MIME"text/plain", x::AbstractNamedArrayPartition) show( io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) end -Base.size(x::NamedArrayPartition) = size(ArrayPartition(x)) -Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) -Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) +Base.size(x::AbstractNamedArrayPartition) = size(ArrayPartition(x)) +Base.length(x::AbstractNamedArrayPartition) = length(ArrayPartition(x)) +Base.getindex(x::AbstractNamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) -Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) -function Base.map(f, x::NamedArrayPartition) - NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) +Base.setindex!(x::AbstractNamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) +function Base.map(f, x::T) where {T<:AbstractNamedArrayPartition} + Tconstr = named_partition_constructor(x) + Tconstr(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) end -Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) -# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x)) +Base.mapreduce(f, op, x::AbstractNamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) +# Base.filter(f, x::AbstractNamedArrayPartition) = filter(f, ArrayPartition(x)) -function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} - NamedArrayPartition{T, S, NT}( - similar(ArrayPartition(x)), getfield(x, :names_to_indices)) -end +function Base.similar(x::AbstractNamedArrayPartition{T, A, NT}) where {T, A, NT} + # Safely extract the concrete type parameters + Tconstr = named_partition_constructor(x) + return Tconstr{T, A, NT}( + similar(getfield(x, :array_partition)), + getfield(x, :names_to_indices) + ) +end # broadcasting -function Base.BroadcastStyle(::Type{<:NamedArrayPartition}) - Broadcast.ArrayStyle{NamedArrayPartition}() +function Base.BroadcastStyle(::Type{T}) where{T<:AbstractNamedArrayPartition} + Broadcast.ArrayStyle{T}() end -function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, - ::Type{ElType}) where {ElType} +function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}}, + ::Type{ElType}) where {ElType, T<:AbstractNamedArrayPartition} x = find_NamedArrayPartition(bc) - return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) + Tconstr = named_partition_constructor(x) + return Tconstr(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) end # when broadcasting with ArrayPartition + another array type, the output is the other array tupe function Base.BroadcastStyle( - ::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) + ::Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) Broadcast.DefaultArrayStyle{1}() end # hook into ArrayPartition broadcasting routines -@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x)) -@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = Broadcast.Broadcasted( +@inline RecursiveArrayTools.npartitions(x::AbstractNamedArrayPartition) = npartitions(ArrayPartition(x)) +@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}}, i) = Broadcast.Broadcasted( bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) -@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i) +@inline RecursiveArrayTools.unpack(x::AbstractNamedArrayPartition, i) = unpack(ArrayPartition(x), i) -function Base.copy(A::NamedArrayPartition{T, S, NT}) where {T, S, NT} - NamedArrayPartition{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) +function Base.copy(A::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT} + Tconstr = named_partition_constructor(A) + Tconstr{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) end -@inline NamedArrayPartition(f::F, N, names_to_indices) where {F <: Function} = NamedArrayPartition( +@inline (::Type{T})(f::F, N, names_to_indices) where {F <: Function, T<:AbstractNamedArrayPartition} = T( ArrayPartition(ntuple(f, Val(N))), names_to_indices) -@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) +@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}}) where {T<:AbstractNamedArrayPartition} N = npartitions(bc) @inline function f(i) copy(unpack(bc, i)) end x = find_NamedArrayPartition(bc) - NamedArrayPartition(f, N, getfield(x, :names_to_indices)) + Tconstr = named_partition_constructor(x) + Tconstr(f, N, getfield(x, :names_to_indices)) end -@inline function Base.copyto!(dest::NamedArrayPartition, - bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) +@inline function Base.copyto!(dest::AbstractNamedArrayPartition, + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}}) N = npartitions(dest, bc) @inline function f(i) copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) @@ -146,7 +182,7 @@ end end #Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq -function ArrayInterface.zeromatrix(A::NamedArrayPartition) +function ArrayInterface.zeromatrix(A::AbstractNamedArrayPartition) B = ArrayPartition(A) x = reduce(vcat,vec.(B.x)) x .* x' .* false @@ -159,5 +195,5 @@ function find_NamedArrayPartition(args::Tuple) end find_NamedArrayPartition(x) = x find_NamedArrayPartition(::Tuple{}) = nothing -find_NamedArrayPartition(x::NamedArrayPartition, rest) = x +find_NamedArrayPartition(x::AbstractNamedArrayPartition, rest) = x find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest) diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index e6747969..cdb5a403 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, Test +using RecursiveArrayTools, Test, ArrayInterface @testset "NamedArrayPartition tests" begin x = NamedArrayPartition(a = ones(10), b = rand(20))