Skip to content

Commit 4400fee

Browse files
authored
Merge pull request #317 from pxl-th/pxl-th/gpuarrays
Adapt to GPUArraysCore@0.2
2 parents 83fb3d7 + dcc7780 commit 4400fee

File tree

5 files changed

+35
-20
lines changed

5 files changed

+35
-20
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ jobs:
1515
fail-fast: false
1616
matrix:
1717
version:
18-
- '1.6'
18+
- 'min'
19+
- 'lts'
1920
- '1'
20-
- 'nightly'
21+
- 'pre'
2122
os:
2223
- ubuntu-latest
2324
arch:

Project.toml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
name = "StructArrays"
22
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3-
version = "0.6.19"
3+
version = "0.7.0"
44

55
[deps]
6-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
76
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
87
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
98
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
@@ -16,26 +15,28 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1615
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1716
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1817
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
18+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1919
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2020
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2121

2222
[extensions]
2323
StructArraysAdaptExt = "Adapt"
24-
StructArraysGPUArraysCoreExt = "GPUArraysCore"
24+
StructArraysGPUArraysCoreExt = ["GPUArraysCore", "KernelAbstractions"]
2525
StructArraysLinearAlgebraExt = "LinearAlgebra"
2626
StructArraysSparseArraysExt = "SparseArrays"
2727
StructArraysStaticArraysExt = "StaticArrays"
2828

2929
[compat]
30-
Adapt = "3.4, 4"
30+
Adapt = "4"
3131
Aqua = "0.8"
3232
ConstructionBase = "1"
3333
DataAPI = "1"
3434
Documenter = "1"
35-
GPUArraysCore = "0.1.2, 0.2"
35+
GPUArraysCore = "0.2"
3636
InfiniteArrays = "0.13"
37-
JLArrays = "0.1"
37+
JLArrays = "0.2"
3838
LinearAlgebra = "1"
39+
KernelAbstractions = "0.9"
3940
OffsetArrays = "1"
4041
PooledArrays = "1"
4142
SparseArrays = "1"
@@ -44,7 +45,7 @@ Tables = "1"
4445
Test = "1"
4546
TypedTables = "1"
4647
WeakRefStrings = "1"
47-
julia = "1.6"
48+
julia = "1.10"
4849

4950
[extras]
5051
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -54,6 +55,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
5455
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
5556
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
5657
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
58+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
5759
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
5860
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
5961
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -63,4 +65,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
6365
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
6466

6567
[targets]
66-
test = ["Adapt", "Aqua", "Documenter", "GPUArraysCore", "InfiniteArrays", "JLArrays", "LinearAlgebra", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"]
68+
test = ["Adapt", "Aqua", "Documenter", "GPUArraysCore", "InfiniteArrays", "JLArrays", "LinearAlgebra", "KernelAbstractions", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"]

ext/StructArraysAdaptExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module StructArraysAdaptExt
22
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays
33
using Adapt, StructArrays
4-
Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s)
4+
5+
function Adapt.adapt_structure(to, s::StructArray)
6+
replace_storage(adapt(to), s)
7+
end
58
end

ext/StructArraysGPUArraysCoreExt.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,23 @@ using StructArrays: map_params, array_types
66
using Base: tail
77

88
import GPUArraysCore
9+
import KernelAbstractions as KA
10+
11+
function KA.get_backend(x::T) where {T<:StructArray}
12+
components = StructArrays.components(x)
13+
array_components = filter(
14+
fn -> getfield(components, fn) isa AbstractArray,
15+
fieldnames(typeof(components)))
16+
backends = map(
17+
fn -> KA.get_backend(getfield(components, fn)),
18+
array_components)
919

10-
# for GPU broadcast
11-
import GPUArraysCore
12-
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
13-
backends = map_params(GPUArraysCore.backend, array_types(T))
1420
backend, others = backends[1], tail(backends)
1521
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
1622
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
1723
return backend
1824
end
25+
1926
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
2027

2128
end # module

test/runtests.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ using StructArrays
22
using StructArrays: staticschema, iscompatible, _promote_typejoin, append!!
33
using OffsetArrays: OffsetArray, OffsetVector, OffsetMatrix
44
using StaticArrays
5-
import Tables, PooledArrays, WeakRefStrings
65
using TypedTables: Table
76
using DataAPI: refarray, refvalue
87
using Adapt: adapt, Adapt
98
using JLArrays
10-
using GPUArraysCore: backend
119
using LinearAlgebra
1210
using Test
1311
using SparseArrays
1412
using InfiniteArrays
13+
1514
import Aqua
15+
import KernelAbstractions as KA
16+
import Tables, PooledArrays, WeakRefStrings
1617

1718
using Documenter: doctest
1819
if Base.VERSION == v"1.6" && Int === Int64
@@ -1192,6 +1193,7 @@ end
11921193
struct ArrayConverter end
11931194

11941195
Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
1196+
Adapt.adapt_structure(::ArrayConverter, xs::UnitRange) = convert(Array, xs)
11951197

11961198
@testset "adapt" begin
11971199
s = StructArray(a = 1:10, b = StructArray(c = 1:10, d = 1:10))
@@ -1372,11 +1374,11 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
13721374
a = StructArray(randn(ComplexF32, 10, 10))
13731375
sa = jl(a)
13741376
@test sa isa StructArray
1375-
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im) === backend(jl(a.re))
1377+
@test @inferred(KA.get_backend(sa)) === KA.get_backend(sa.re) === KA.get_backend(sa.im) === KA.get_backend(jl(a.re))
13761378
@test collect(@inferred(bcabs(sa))) == bcabs(a)
1377-
@test backend(bcabs(sa)) === backend(sa)
1379+
@test KA.get_backend(bcabs(sa)) === KA.get_backend(sa)
13781380
@test @inferred(bcmul2(sa)) isa StructArray
1379-
@test backend(bcmul2(sa)) === backend(sa)
1381+
@test KA.get_backend(bcmul2(sa)) === KA.get_backend(sa)
13801382
@test (sa .+= 1) === sa
13811383
end
13821384

0 commit comments

Comments
 (0)