diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index de8fa91a..5286f1cf 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -33,8 +33,12 @@ end # return ArrayPartition when possible, otherwise next best thing of the correct size function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} - NamedArrayPartition( - similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices)) + similar_array_partition = similar(getfield(A, :array_partition), dims) + if similar_array_partition isa ArrayPartition + NamedArrayPartition(similar_array_partition, getfield(A, :names_to_indices)) + else + similar_array_partition + end end # similar array partition of common type @@ -45,8 +49,12 @@ end # return ArrayPartition when possible, otherwise next best thing of the correct size function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N} - NamedArrayPartition( - similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices)) + similar_array_partition = similar(getfield(A, :array_partition), T, dims) + if similar_array_partition isa ArrayPartition + NamedArrayPartition(similar_array_partition, getfield(A, :names_to_indices)) + else + similar_array_partition + end end # similar array partition with different types diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index d5647bad..9a765cd0 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -6,6 +6,8 @@ using RecursiveArrayTools, Test @test typeof(x .^ 2) <: NamedArrayPartition @test typeof(similar(x)) <: NamedArrayPartition @test typeof(similar(x, Int)) <: NamedArrayPartition + @test typeof(similar(x, (5, 5))) <: Matrix + @test typeof(similar(x, Int, (5, 5))) <: Matrix @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence @test all(x .== x[1:end])