Skip to content

Add support for GPU broadcasting with a StructArray LHS? #150

Closed
@N3N5

Description

@N3N5

Add something like the following:

import GPUArrays: BroadcastGPUArray,backend,@cartesianidx,launch_heuristic,launch_configuration,gpu_call
# Tuple whose elements are all GPUArrays broadcastable arrays.
const GPUStore = Tuple{Vararg{BroadcastGPUArray}}
# NamedTuple (any field names) whose values form a `GPUStore`.
const NamedGPUStore{Name} = NamedTuple{Name, <:GPUStore}
# StructArray whose component storage is a `GPUStore` or a `NamedGPUStore`.
const StructGPUArray{T, N} = StructArray{T, N, <:Union{GPUStore, NamedGPUStore}}
## backend for StructArray
# Backend of a GPU-backed StructArray: resolved from its component arrays.
backend(A::StructGPUArray) = backend(fieldarrays(A))
# Tuple storage: resolve from the tuple's type.
backend(t::GPUStore) = backend(typeof(t))
# NamedTuple storage: its values form a plain tuple of arrays.
backend(nt::NamedGPUStore) = backend(values(nt))

# Return the backend shared by every array type in the tuple type `T`.
# Throws an `ArgumentError` if `T` has no element types, or if the element array
# types do not all resolve to the same (`===`) backend.
function backend(::Type{T}) where {T<:GPUStore}
    arraytypes = fieldtypes(T)
    isempty(arraytypes) &&
        throw(ArgumentError("device error: cannot infer a backend from empty storage"))
    # Compare every component against the first one. A pairwise
    # `mapreduce(backend, ===, ...)` would compare a Bool with a backend once there
    # are three or more components, and would return a non-Bool for a single one.
    reference = backend(first(arraytypes))
    all(S -> backend(S) === reference, arraytypes) ||
        throw(ArgumentError("device error: component arrays use different backends"))
    return reference
end

## copied from GPUArrays' broadcast `copyto!`, specialized for StructGPUArray
# Materialize the broadcast `bc` into `dest`, a StructArray whose components are GPU
# arrays, by launching a GPUArrays kernel that assigns `dest[I] = bc′[I]` per element.
# Throws (via `Broadcast.throwdm`) if the axes of `dest` and `bc` differ; returns `dest`.
# NOTE(review): presumably reached through Base's generic `copyto!(dest, ::Broadcasted)`
# fallback, which converts to the style-free `Broadcasted{Nothing}` — confirm.
@inline function Base.copyto!(dest::StructGPUArray, bc::Broadcast.Broadcasted{Nothing})
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    # Nothing to write; also avoids launching a kernel with zero elements.
    isempty(dest) && return dest
    bc′ = Broadcast.preprocess(dest, bc)

    # grid-stride kernel
    # Each invocation processes `nelem` elements, mapping loop step `i` to an index
    # of `dest` with `@cartesianidx`.
    # NOTE(review): `@cartesianidx` appears to rely on the first argument being named
    # `ctx` and presumably returns early past the end of `dest` — keep the names and
    # verify against GPUArrays before changing this kernel.
    function broadcast_kernel(ctx, dest, bc′, nelem)
        for i in 1:nelem
            I = @cartesianidx(dest, i)
            # Bounds assumed safe: axes checked above, index produced by the macro.
            @inbounds dest[I] = bc′[I]
        end
        return
    end
    # The heuristic is computed with a placeholder `nelem` of 1. The configuration
    # then sizes the launch for `length(dest)` elements; `typemax(Int)` is presumably
    # the upper bound on elements per thread — TODO confirm.
    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1)
    config = launch_configuration(backend(dest), heuristic, length(dest), typemax(Int))
    gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
             threads=config.threads, blocks=config.blocks)

    return dest
end

With these definitions in place, we can write code like:

CUDA.allowscalar(false)
dims = (4000, 4000)
# Separate real/imaginary device buffers; every element is overwritten below.
a_re = CuArray{Float32}(undef, dims)
a_im = CuArray{Float32}(undef, dims)
a = StructArray{ComplexF32}((a_re, a_im))
b = CUDA.randn(dims...)
a .= cis.(b)
# Compare on the host, since scalar indexing of device arrays is disabled above.
isapprox(replace_storage(Array, a), cis.(collect(b)))  # true

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions