Skip to content

Commit 1ad3ea4

Browse files
Merge pull request #431 from SouthEndMusic/master
Use `similar` methods from `ArrayPartion` for `NamedArrayPartition`
2 parents 8765aae + 6a66fc8 commit 1ad3ea4

File tree

6 files changed

+74
-27
lines changed

6 files changed

+74
-27
lines changed

src/named_array_partition.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,36 @@ end
2626
# fields except through `getfield` and accessor functions.
2727
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
2828

29+
function Base.similar(A::NamedArrayPartition)
30+
NamedArrayPartition(
31+
similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
32+
end
33+
34+
# return ArrayPartition when possible, otherwise next best thing of the correct size
35+
function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N}
36+
NamedArrayPartition(
37+
similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices))
38+
end
39+
40+
# similar array partition of common type
41+
@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T}
42+
NamedArrayPartition(
43+
similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices))
44+
end
45+
46+
# return ArrayPartition when possible, otherwise next best thing of the correct size
47+
function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
48+
NamedArrayPartition(
49+
similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices))
50+
end
51+
52+
# similar array partition with different types
53+
function Base.similar(
54+
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
55+
NamedArrayPartition(
56+
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
57+
end
58+
2959
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
3060

3161
function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN}

src/utils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
unrolled_foreach!(f, t::Tuple) = (f(t[1]); unrolled_foreach!(f, Base.tail(t)))
22
unrolled_foreach!(f, ::Tuple{}) = nothing
33

4-
54
"""
65
```julia
76
recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T, N}})
@@ -131,7 +130,6 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N},
131130
end
132131
end
133132

134-
135133
for type in [AbstractArray, AbstractVectorOfArray]
136134
@eval function recursivefill!(b::$type{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N}
137135
fill!(b, a)

src/vector_of_array.jl

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,11 @@ Base.parent(vec::VectorOfArray) = vec.u
177177
#### 2-argument
178178

179179
# first element representative
180-
function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing)
180+
function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothing,
181+
variables = nothing, parameters = nothing, independent_variables = nothing)
181182
sys = SymbolCache(something(variables, []),
182-
something(parameters, []),
183-
something(independent_variables, []))
183+
something(parameters, []),
184+
something(independent_variables, []))
184185
_size = size(vec[1])
185186
T = eltype(vec[1])
186187
return DiffEqArray{
@@ -199,10 +200,12 @@ function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothin
199200
end
200201

201202
# T and N from type
202-
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
203+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector;
204+
discretes = nothing, variables = nothing, parameters = nothing,
205+
independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
203206
sys = SymbolCache(something(variables, []),
204-
something(parameters, []),
205-
something(independent_variables, []))
207+
something(parameters, []),
208+
something(independent_variables, []))
206209
return DiffEqArray{
207210
eltype(eltype(vec)),
208211
N + 1,
@@ -221,7 +224,8 @@ end
221224
#### 3-argument
222225

223226
# NTuple, T from type
224-
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}; discretes = nothing) where {T, N}
227+
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector,
228+
::NTuple{N, Int}; discretes = nothing) where {T, N}
225229
DiffEqArray{
226230
eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, typeof(discretes)}(
227231
vec,
@@ -232,19 +236,23 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int
232236
end
233237

234238
# NTuple parameter
235-
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2}
236-
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
239+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int};
240+
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2}
241+
DiffEqArray{
242+
eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(
243+
vec,
237244
ts,
238245
p,
239246
nothing,
240247
discretes)
241248
end
242249

243250
# first element representative
244-
function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing)
251+
function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = nothing,
252+
variables = nothing, parameters = nothing, independent_variables = nothing)
245253
sys = SymbolCache(something(variables, []),
246-
something(parameters, []),
247-
something(independent_variables, []))
254+
something(parameters, []),
255+
something(independent_variables, []))
248256
_size = size(vec[1])
249257
T = eltype(vec[1])
250258
return DiffEqArray{
@@ -263,11 +271,14 @@ function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = not
263271
end
264272

265273
# T and N from type
266-
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
274+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p;
275+
discretes = nothing, variables = nothing, parameters = nothing,
276+
independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
267277
sys = SymbolCache(something(variables, []),
268-
something(parameters, []),
269-
something(independent_variables, []))
270-
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
278+
something(parameters, []),
279+
something(independent_variables, []))
280+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts),
281+
typeof(p), typeof(sys), typeof(discretes)}(vec,
271282
ts,
272283
p,
273284
sys,
@@ -277,7 +288,8 @@ end
277288
#### 4-argument
278289

279290
# NTuple, T from type
280-
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p; discretes = nothing) where {T, N}
291+
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector,
292+
::NTuple{N, Int}, p; discretes = nothing) where {T, N}
281293
DiffEqArray{
282294
eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(
283295
vec,
@@ -288,8 +300,10 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int
288300
end
289301

290302
# NTuple parameter
291-
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2}
292-
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
303+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}, sys;
304+
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2}
305+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts),
306+
typeof(p), typeof(sys), typeof(discretes)}(vec,
293307
ts,
294308
p,
295309
sys,
@@ -316,8 +330,10 @@ function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p, sys; discretes
316330
end
317331

318332
# T and N from type
319-
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
320-
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
333+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p, sys;
334+
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
335+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts),
336+
typeof(p), typeof(sys), typeof(discretes)}(vec,
321337
ts,
322338
p,
323339
sys,
@@ -327,7 +343,8 @@ end
327343
#### 5-argument
328344

329345
# NTuple, T from type
330-
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p, sys; discretes = nothing) where {T, N}
346+
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector,
347+
::NTuple{N, Int}, p, sys; discretes = nothing) where {T, N}
331348
DiffEqArray{
332349
eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(
333350
vec,
@@ -942,7 +959,7 @@ end
942959
VectorOfArray(rewrap(parent, u))
943960
end
944961

945-
rewrap(::Array,u) = u
962+
rewrap(::Array, u) = u
946963
rewrap(parent, u) = convert(typeof(parent), u)
947964

948965
for (type, N_expr) in [

test/downstream/symbol_indexing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ sol_new = DiffEqArray(sol.u[1:10],
2727
@test all(isequal.(all_variable_symbols(sol), all_variable_symbols(sol_new)))
2828
@test all(isequal.(all_variable_symbols(sol), [x, RHS]))
2929
@test all(isequal.(all_symbols(sol), all_symbols(sol_new)))
30-
@test all([any(isequal(sym), all_symbols(sol)) for sym in [x, RHS, τ, t, Initial(x), Initial(RHS)]])
30+
@test all([any(isequal(sym), all_symbols(sol))
31+
for sym in [x, RHS, τ, t, Initial(x), Initial(RHS)]])
3132
@test sol[solvedvariables] == sol[[x]]
3233
@test sol_new[solvedvariables] == sol_new[[x]]
3334
@test sol[allvariables] == sol[[x, RHS]]

test/gpu/arraypartition_gpu.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using RecursiveArrayTools, CUDA, Test
22
CUDA.allowscalar(false)
33

4-
54
# Test indexing with colon
65
a = (CUDA.zeros(5), CUDA.zeros(5))
76
pA = ArrayPartition(a)

test/named_array_partition_tests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using RecursiveArrayTools, Test
44
x = NamedArrayPartition(a = ones(10), b = rand(20))
55
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
66
@test typeof(x .^ 2) <: NamedArrayPartition
7+
@test typeof(similar(x)) <: NamedArrayPartition
8+
@test typeof(similar(x, Int)) <: NamedArrayPartition
79
@test x.a ones(10)
810
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
911
@test all(x .== x[1:end])

0 commit comments

Comments
 (0)