Skip to content

Commit 9ca8264

Browse files
authored
Forward complex matrix multiplication to components (#294)
* Forward complex matrix multiplication to components * mul for vectors * Convert tabs to spaces * Remove extra end
1 parent 58aba83 commit 9ca8264

File tree

5 files changed

+50
-3
lines changed

5 files changed

+50
-3
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
88
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1213
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1314

1415
[weakdeps]
1516
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1617
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
18+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1719
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1820
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1921

@@ -22,6 +24,7 @@ StructArraysAdaptExt = "Adapt"
2224
StructArraysGPUArraysCoreExt = "GPUArraysCore"
2325
StructArraysSparseArraysExt = "SparseArrays"
2426
StructArraysStaticArraysExt = "StaticArrays"
27+
StructArraysLinearAlgebraExt = "LinearAlgebra"
2528

2629
[compat]
2730
Adapt = "3.4, 4"

ext/StructArraysLinearAlgebraExt.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module StructArraysLinearAlgebraExt
2+
3+
using StructArrays
4+
using LinearAlgebra
5+
import LinearAlgebra: mul!
6+
7+
const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}}
8+
const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}}
9+
10+
function _mul!(C, A, B, alpha, beta)
11+
mul!(C.re, A.re, B.re, alpha, beta)
12+
mul!(C.re, A.im, B.im, -alpha, oneunit(beta))
13+
mul!(C.im, A.re, B.im, alpha, beta)
14+
mul!(C.im, A.im, B.re, alpha, oneunit(beta))
15+
C
16+
end
17+
18+
function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number)
19+
_mul!(C, A, B, alpha, beta)
20+
end
21+
function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number)
22+
_mul!(C, A, B, alpha, beta)
23+
end
24+
25+
end

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ end
3030
include("../ext/StructArraysGPUArraysCoreExt.jl")
3131
include("../ext/StructArraysSparseArraysExt.jl")
3232
include("../ext/StructArraysStaticArraysExt.jl")
33+
include("../ext/StructArraysLinearAlgebraExt.jl")
3334
end
3435

3536
end # module

src/structarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ _structarray(args::Tuple, ::Tuple) = _structarray(args, nothing)
106106
_structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args))
107107

108108
const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I}
109+
const StructMatrix{T, C<:Tup, I} = StructArray{T, 2, C, I}
109110
StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...)
110111
StructVector(args...; kwargs...) = StructArray(args...; kwargs...)
111112

test/runtests.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ end
12061206
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
12071207
# 1. `MyArray1` and `MyArray1` have `similar` defined.
12081208
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
1209-
# 2. `MyArray3` has no `similar` defined.
1209+
# 2. `MyArray3` has no `similar` defined.
12101210
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
12111211
# 3. Their resolved style could be summaryized as (`-` means conflict)
12121212
# | MyArray1 | MyArray2 | MyArray3 | Array
@@ -1302,7 +1302,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
13021302
f(s) = s .+= 1
13031303
f(s)
13041304
@test (@allocated f(s)) == 0
1305-
1305+
13061306
# issue #185
13071307
A = StructArray(randn(ComplexF64, 3, 3))
13081308
B = randn(ComplexF64, 3, 3)
@@ -1321,7 +1321,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
13211321

13221322
@testset "ambiguity check" begin
13231323
test_set = Any[StructArray([1;2+im]),
1324-
1:2,
1324+
1:2,
13251325
(1,2),
13261326
StructArray(@SArray [1;1+2im]),
13271327
(@SArray [1 2]),
@@ -1550,6 +1550,23 @@ end
15501550
@test Base.IteratorSize(S) == Base.IsInfinite()
15511551
end
15521552

1553+
@testset "LinearAlgebra" begin
1554+
@testset "matrix * matrix" begin
1555+
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
1556+
B = StructArray{ComplexF64}((rand(size(A)...), rand(size(A)...)))
1557+
MA, MB = Matrix(A), Matrix(B)
1558+
@test A * B MA * MB
1559+
@test mul!(ones(ComplexF64,size(A)), A, B, 2.0, 3.0) 2 * A * B .+ 3
1560+
end
1561+
@testset "matrix * vector" begin
1562+
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
1563+
v = StructArray{ComplexF64}((rand(size(A,2)), rand(size(A,2))))
1564+
MA, Mv = Matrix(A), Vector(v)
1565+
@test A * v MA * Mv
1566+
@test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) 2 * A * v .+ 3
1567+
end
1568+
end
1569+
15531570
@testset "project quality" begin
15541571
Aqua.test_all(StructArrays, ambiguities=(; broken=true))
15551572
end

0 commit comments

Comments
 (0)