From 2c5fcce333d31fe2f952121df6d5f2c900d88ac2 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Mon, 7 Apr 2025 08:27:38 +0200 Subject: [PATCH 1/2] Fix NamedArrayPartition similar when size is passed --- src/named_array_partition.jl | 8 ++++++-- test/named_array_partition_tests.jl | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index de8fa91a..6bd61a01 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -33,8 +33,12 @@ end # return ArrayPartition when possible, otherwise next best thing of the correct size function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} - NamedArrayPartition( - similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices)) + similar_array_partition = similar(getfield(A, :array_partition), dims) + if similar_array_partition isa ArrayPartition + NamedArrayPartition(similar_array_partition, getfield(A, :names_to_indices)) + else + similar_array_partition + end end # similar array partition of common type diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index d5647bad..2d3ea80a 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -6,6 +6,7 @@ using RecursiveArrayTools, Test @test typeof(x .^ 2) <: NamedArrayPartition @test typeof(similar(x)) <: NamedArrayPartition @test typeof(similar(x, Int)) <: NamedArrayPartition + @test typeof(similar(x, (5, 5))) <: Matrix @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence @test all(x .== x[1:end]) From 5c3489c4c9a816e1e3ab5089de5eb6bc60d582c8 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Mon, 7 Apr 2025 08:55:32 +0200 Subject: [PATCH 2/2] Also fix when dims + T are passed --- src/named_array_partition.jl | 8 ++++++-- test/named_array_partition_tests.jl | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index 6bd61a01..5286f1cf 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -49,8 +49,12 @@ end # return ArrayPartition when possible, otherwise next best thing of the correct size function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N} - NamedArrayPartition( - similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices)) + similar_array_partition = similar(getfield(A, :array_partition), T, dims) + if similar_array_partition isa ArrayPartition + NamedArrayPartition(similar_array_partition, getfield(A, :names_to_indices)) + else + similar_array_partition + end end # similar array partition with different types diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index 2d3ea80a..9a765cd0 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -7,6 +7,7 @@ using RecursiveArrayTools, Test @test typeof(similar(x)) <: NamedArrayPartition @test typeof(similar(x, Int)) <: NamedArrayPartition @test typeof(similar(x, (5, 5))) <: Matrix + @test typeof(similar(x, Int, (5, 5))) <: Matrix @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence @test all(x .== x[1:end])