From 03d565bedf55f995e53e97e5ef14240fcb68ecdd Mon Sep 17 00:00:00 2001 From: piever Date: Fri, 14 Oct 2022 15:50:38 +0200 Subject: [PATCH 1/5] refactor finding consistent value --- src/structarray.jl | 14 ++++++-------- src/utils.jl | 12 ++++++++++++ test/runtests.jl | 20 ++++++++++++++------ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 3fa2029e..38b7d949 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -14,12 +14,10 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N} components::C function StructArray{T, N, C}(c) where {T, N, C<:Tup} - isempty(c) && error("only eltypes with fields are supported") - ax = axes(first(c)) - length(ax) == N || error("wrong number of dimensions") - map(tail(c)) do ci - axes(ci) == ax || error("all field arrays must have same shape") - end + isempty(c) && throw(ArgumentError("only eltypes with fields are supported")) + ax = findconsistentvalue(axes, c) + isnothing(ax) && throw(ArgumentError("all component arrays must have the same shape")) + length(ax) == N || throw(ArgumentError("wrong number of dimensions")) new{T, N, C, index_type(c)}(c) end end @@ -369,8 +367,8 @@ end end function Base.parentindices(s::StructArray) - res = parentindices(component(s, 1)) - all(c -> parentindices(c) == res, components(s)) || throw(ArgumentError("inconsistent parentindices of components")) + res = findconsistentvalue(parentindices, components(s)) + isnothing(res) && throw(ArgumentError("inconsistent parentindices of components")) return res end diff --git a/src/utils.jl b/src/utils.jl index 5f4e9936..f1ade86a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -196,3 +196,15 @@ By default, this calls `convert(T, x)`; however, you can specialize it for other maybe_convert_elt(::Type{T}, vals) where T = convert(T, vals) maybe_convert_elt(::Type{T}, vals::Tuple) where T = T <: Tuple ? convert(T, vals) : vals # assignment of fields by position maybe_convert_elt(::Type{T}, vals::NamedTuple) where T = T<:NamedTuple ? convert(T, vals) : vals # assignment of fields by name + +""" + findconsistentvalue(f, componenents::Union{Tuple, NamedTuple}) + +Compute the unique value that `f` takes on each `component ∈ componenents`. +If not all values are equal, return `nothing`. Otherwise, return the unique value. +""" +function findconsistentvalue(f::F, (col, cols...)::Tup) where F + val = f(col) + isconsistent = mapfoldl(isequal(val)∘f, &, cols; init=true) + return ifelse(isconsistent, val, nothing) +end diff --git a/test/runtests.jl b/test/runtests.jl index 2c11b34f..4693ca1b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,6 +105,14 @@ end @test StructArrays.strip_params(Tuple{Int}) == Tuple @test StructArrays.astuple(NamedTuple{(:a,), Tuple{Float64}}) == Tuple{Float64} @test StructArrays.strip_params(NamedTuple{(:a,), Tuple{Float64}}) == NamedTuple{(:a,)} + + cols = (a=rand(2), b=rand(2), c=rand(2)) + @test StructArrays.findconsistentvalue(length, cols) == 2 + @test StructArrays.findconsistentvalue(length, Tuple(cols)) == 2 + + cols = (a=rand(2), b=rand(2), c=rand(3)) + @test isnothing(StructArrays.findconsistentvalue(length, cols)) + @test isnothing(StructArrays.findconsistentvalue(length, Tuple(cols))) end @testset "indexstyle" begin @@ -439,8 +447,8 @@ end @test isequal(t.a, [1, missing]) @test eltype(t) <: NamedTuple{(:a,)} - @test_throws ErrorException StructArray([nothing]) - @test_throws ErrorException StructArray([1, 2, 3]) + @test_throws ArgumentError StructArray([nothing]) + @test_throws ArgumentError StructArray([1, 2, 3]) end @testset "tuple case" begin @@ -460,10 +468,10 @@ end @test getproperty(t, 1) == [2] @test getproperty(t, 2) == [3.0] - @test_throws ErrorException StructArray(([1, 2], [3])) + @test_throws ArgumentError StructArray(([1, 2], [3])) - @test_throws ErrorException StructArray{Tuple{}}(()) - @test_throws ErrorException StructArray{Tuple{}, 1, Tuple{}}(()) + @test_throws ArgumentError StructArray{Tuple{}}(()) + @test_throws ArgumentError StructArray{Tuple{}, 1, Tuple{}}(()) end @testset "constructor from slices" begin @@ -503,7 +511,7 @@ end @test t1 == StructArray((a=[1.2], b=["test"])) @test t2 == StructArray{Pair{Float64, String}}(([1.2], ["test"])) - @test_throws ErrorException StructArray(a=[1, 2], b=[3]) + @test_throws ArgumentError StructArray(a=[1, 2], b=[3]) end @testset "complex" begin From 4b3e857c271f3bf552454ce8c9fde97db70812ee Mon Sep 17 00:00:00 2001 From: piever Date: Fri, 14 Oct 2022 15:59:28 +0200 Subject: [PATCH 2/5] add internal docs --- docs/src/reference.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/reference.md b/docs/src/reference.md index bf7a456f..d0195e94 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -57,4 +57,6 @@ StructArrays.map_params StructArrays.buildfromschema StructArrays.bypass_constructor StructArrays.iscompatible +StructArrays.maybe_convert_elt +StructArrays.findconsistentvalue ``` \ No newline at end of file From 24bd1bb4f08c1488dcbee3b06f97539233621ab8 Mon Sep 17 00:00:00 2001 From: piever Date: Fri, 14 Oct 2022 16:02:19 +0200 Subject: [PATCH 3/5] add doc compat --- docs/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 78f2a7bf..b741cfe1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +Documenter = "0.27" PooledArrays = "1" From 1d18186adaccf48414fce1310ecde08f25952740 Mon Sep 17 00:00:00 2001 From: piever Date: Fri, 14 Oct 2022 16:04:16 +0200 Subject: [PATCH 4/5] remove outdated docstring --- src/structarray.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 38b7d949..2f97d046 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -117,9 +117,6 @@ Construct a `StructArray` from slices of `A` along `dims`. The `unwrap` keyword argument is a function that determines whether to recursively convert fields of type `FT` to `StructArray`s. -!!! compat "Julia 1.1" - This function requires at least Julia 1.1. - ```julia-repl julia> X = [1.0 2.0; 3.0 4.0] 2×2 Array{Float64,2}: From 985e883e8f1c14bfe3d122392181ef7c61500ae9 Mon Sep 17 00:00:00 2001 From: piever Date: Fri, 14 Oct 2022 16:25:05 +0200 Subject: [PATCH 5/5] fix inferrability on older julia --- src/structarray.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 2f97d046..d4bf529f 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -16,7 +16,7 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N} function StructArray{T, N, C}(c) where {T, N, C<:Tup} isempty(c) && throw(ArgumentError("only eltypes with fields are supported")) ax = findconsistentvalue(axes, c) - isnothing(ax) && throw(ArgumentError("all component arrays must have the same shape")) + (ax === nothing) && throw(ArgumentError("all component arrays must have the same shape")) length(ax) == N || throw(ArgumentError("wrong number of dimensions")) new{T, N, C, index_type(c)}(c) end @@ -365,7 +365,7 @@ end function Base.parentindices(s::StructArray) res = findconsistentvalue(parentindices, components(s)) - isnothing(res) && throw(ArgumentError("inconsistent parentindices of components")) + (res === nothing) && throw(ArgumentError("inconsistent parentindices of components")) return res end