Skip to content

refactor finding consistent value #252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Documenter = "0.27"
PooledArrays = "1"
2 changes: 2 additions & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ StructArrays.map_params
StructArrays.buildfromschema
StructArrays.bypass_constructor
StructArrays.iscompatible
StructArrays.maybe_convert_elt
StructArrays.findconsistentvalue
```
17 changes: 6 additions & 11 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N}
components::C

function StructArray{T, N, C}(c) where {T, N, C<:Tup}
isempty(c) && error("only eltypes with fields are supported")
ax = axes(first(c))
length(ax) == N || error("wrong number of dimensions")
map(tail(c)) do ci
axes(ci) == ax || error("all field arrays must have same shape")
end
isempty(c) && throw(ArgumentError("only eltypes with fields are supported"))
ax = findconsistentvalue(axes, c)
(ax === nothing) && throw(ArgumentError("all component arrays must have the same shape"))
length(ax) == N || throw(ArgumentError("wrong number of dimensions"))
new{T, N, C, index_type(c)}(c)
end
end
Expand Down Expand Up @@ -119,9 +117,6 @@ Construct a `StructArray` from slices of `A` along `dims`.
The `unwrap` keyword argument is a function that determines whether to
recursively convert fields of type `FT` to `StructArray`s.

!!! compat "Julia 1.1"
This function requires at least Julia 1.1.

```julia-repl
julia> X = [1.0 2.0; 3.0 4.0]
2×2 Array{Float64,2}:
Expand Down Expand Up @@ -369,8 +364,8 @@ end
end

function Base.parentindices(s::StructArray)
res = parentindices(component(s, 1))
all(c -> parentindices(c) == res, components(s)) || throw(ArgumentError("inconsistent parentindices of components"))
res = findconsistentvalue(parentindices, components(s))
(res === nothing) && throw(ArgumentError("inconsistent parentindices of components"))
return res
end

Expand Down
12 changes: 12 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,15 @@ By default, this calls `convert(T, x)`; however, you can specialize it for other
maybe_convert_elt(::Type{T}, vals) where T = convert(T, vals)
maybe_convert_elt(::Type{T}, vals::Tuple) where T = T <: Tuple ? convert(T, vals) : vals # assignment of fields by position
maybe_convert_elt(::Type{T}, vals::NamedTuple) where T = T<:NamedTuple ? convert(T, vals) : vals # assignment of fields by name

"""
findconsistentvalue(f, componenents::Union{Tuple, NamedTuple})

Compute the unique value that `f` takes on each `component ∈ componenents`.
If not all values are equal, return `nothing`. Otherwise, return the unique value.
"""
function findconsistentvalue(f::F, (col, cols...)::Tup) where F
val = f(col)
isconsistent = mapfoldl(isequal(val)∘f, &, cols; init=true)
return ifelse(isconsistent, val, nothing)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val might be nothing? I think we can follow Base._all_match_first style here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val might be nothing?

Not really, we only use this to compute ranges of indices. If we happen to need that use case, we can always return Some(val), but I'd rather avoid the extra wrapping / unwrapping unless it's needed.

end
20 changes: 14 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ end
@test StructArrays.strip_params(Tuple{Int}) == Tuple
@test StructArrays.astuple(NamedTuple{(:a,), Tuple{Float64}}) == Tuple{Float64}
@test StructArrays.strip_params(NamedTuple{(:a,), Tuple{Float64}}) == NamedTuple{(:a,)}

cols = (a=rand(2), b=rand(2), c=rand(2))
@test StructArrays.findconsistentvalue(length, cols) == 2
@test StructArrays.findconsistentvalue(length, Tuple(cols)) == 2

cols = (a=rand(2), b=rand(2), c=rand(3))
@test isnothing(StructArrays.findconsistentvalue(length, cols))
@test isnothing(StructArrays.findconsistentvalue(length, Tuple(cols)))
end

@testset "indexstyle" begin
Expand Down Expand Up @@ -439,8 +447,8 @@ end
@test isequal(t.a, [1, missing])
@test eltype(t) <: NamedTuple{(:a,)}

@test_throws ErrorException StructArray([nothing])
@test_throws ErrorException StructArray([1, 2, 3])
@test_throws ArgumentError StructArray([nothing])
@test_throws ArgumentError StructArray([1, 2, 3])
end

@testset "tuple case" begin
Expand All @@ -460,10 +468,10 @@ end
@test getproperty(t, 1) == [2]
@test getproperty(t, 2) == [3.0]

@test_throws ErrorException StructArray(([1, 2], [3]))
@test_throws ArgumentError StructArray(([1, 2], [3]))

@test_throws ErrorException StructArray{Tuple{}}(())
@test_throws ErrorException StructArray{Tuple{}, 1, Tuple{}}(())
@test_throws ArgumentError StructArray{Tuple{}}(())
@test_throws ArgumentError StructArray{Tuple{}, 1, Tuple{}}(())
end

@testset "constructor from slices" begin
Expand Down Expand Up @@ -503,7 +511,7 @@ end
@test t1 == StructArray((a=[1.2], b=["test"]))
@test t2 == StructArray{Pair{Float64, String}}(([1.2], ["test"]))

@test_throws ErrorException StructArray(a=[1, 2], b=[3])
@test_throws ArgumentError StructArray(a=[1, 2], b=[3])
end

@testset "complex" begin
Expand Down