Skip to content

Commit 9721b1e

Browse files
committed
Also trial AMD support
1 parent 09d5421 commit 9721b1e

6 files changed

Lines changed: 53 additions & 20 deletions

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@ StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
99
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1010

1111
[weakdeps]
12+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1213
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1314
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1415

1516
[extensions]
17+
StridedAMDGPUExt = "AMDGPU"
1618
StridedGPUArraysExt = "GPUArrays"
1719
StridedCUDAExt = "CUDA"
1820

1921
[compat]
22+
AMDGPU = "2"
2023
Aqua = "0.8"
2124
CUDA = "5"
2225
GPUArrays = "11.4.1"
@@ -28,11 +31,12 @@ TupleTools = "1.6"
2831
julia = "1.6"
2932

3033
[extras]
34+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3135
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3236
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3337
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3438
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3539
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3640

3741
[targets]
38-
test = ["Test", "Random", "Aqua", "CUDA", "GPUArrays"]
42+
test = ["Test", "Random", "Aqua", "AMDGPU", "CUDA", "GPUArrays"]

ext/StridedAMDGPUExt.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module StridedAMDGPUExt
2+
3+
using Strided, StridedViews, AMDGPU
4+
using AMDGPU: Adapt
5+
using AMDGPU: GPUArrays
6+
7+
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
8+
9+
function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD, ND, TAD <: ROCArray{TD}, FD <: ALL_FS, TS, NS, TAS <: ROCArray{TS}, FS <: ALL_FS}
10+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
11+
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
12+
GPUArrays._copyto!(dst, bc)
13+
return dst
14+
end
15+
16+
end

ext/StridedCUDAExt.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
module StridedCUDAExt
22

3-
using Strided, CUDA
3+
using Strided, StridedViews, CUDA
44
using CUDA: Adapt, KernelAdaptor
55
using CUDA: GPUArrays
66

77
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
88

9-
function Adapt.adapt_storage(to::KernelAdaptor, xs::StridedView{T,N,TA,F}) where {T,N,TA<:CuArray{T},F <: ALL_FS}
10-
return StridedView(Adapt.adapt(to, parent(xs)), xs.size, xs.strides, xs.offset, xs.op)
11-
end
12-
139
function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD, ND, TAD <: CuArray{TD}, FD <: ALL_FS, TS, NS, TAS <: CuArray{TS}, FS <: ALL_FS}
14-
bc_style = Base.Broadcast.BroadcastStyle(TAS)
10+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
1511
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
1612
GPUArrays._copyto!(dst, bc)
1713
return dst

test/amd.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
2+
@testset "Copy with ROCStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
3+
for m1 in (0, 16, 32), m2 in (0, 16, 32)
4+
A1 = AMDGPU.randn(T, (m1, m2))
5+
A2 = similar(A1)
6+
A1c = copy(A1)
7+
A2c = copy(A2)
8+
B1 = f1(StridedView(A1c))
9+
B2 = f2(StridedView(A2c))
10+
axes(f1(A1)) == axes(f2(A2)) || continue
11+
@test collect(ROCMatrix(copy!(f2(A2), f1(A1)))) == Adapt.adapt(Vector{T}, copy!(B2, B1))
12+
end
13+
end
14+
end

test/cuda.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
2-
m1 = 32
3-
m2 = 16
42
@testset "Copy with CuStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
5-
A1 = CUDA.randn(T, (m1, m2))
6-
A2 = similar(A1)
7-
A1c = copy(A1)
8-
A2c = copy(A2)
9-
B1 = f1(StridedView(A1c))
10-
B2 = f2(StridedView(A2c))
11-
axes(f1(A1)) == axes(f2(A2)) || continue
12-
@test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == Adapt.adapt(Vector{T}, copy!(B2, B1))
3+
for m1 in (0, 16, 32), m2 in (0, 16, 32)
4+
A1 = CUDA.randn(T, (m1, m2))
5+
A2 = similar(A1)
6+
A1c = copy(A1)
7+
A2c = copy(A2)
8+
B1 = f1(StridedView(A1c))
9+
B2 = f2(StridedView(A2c))
10+
axes(f1(A1)) == axes(f2(A2)) || continue
11+
@test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDA.Adapt.adapt(Vector{T}, copy!(B2, B1))
12+
end
1313
end
1414
end

test/runtests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ using Random
44
using Strided
55
using Strided: StridedView
66
using Aqua
7-
using CUDA, GPUArrays
8-
using CUDA: Adapt
7+
using AMDGPU, CUDA, GPUArrays
98

109
Random.seed!(1234)
1110

@@ -29,9 +28,13 @@ if !is_buildkite
2928
include("blasmultests.jl")
3029
Strided.disable_threaded_mul()
3130

32-
Aqua.test_all(Strided; piracies=false)
31+
Aqua.test_all(Strided; piracies = false)
3332
end
3433

3534
if CUDA.functional()
3635
include("cuda.jl")
3736
end
37+
38+
if AMDGPU.functional()
39+
include("amd.jl")
40+
end

0 commit comments

Comments
 (0)