Skip to content

Commit c4025a3

Browse files
authored
Update Project.toml (#45)
* Update Project.toml Bump version of `Strided` and required version of dependency `StridedViews` * Make the copy to AbstractArray more general * Add test for new GPUArrays copy
1 parent c82c506 commit c4025a3

6 files changed

Lines changed: 18 additions & 9 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Strided"
22
uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
3-
version = "2.3.2"
3+
version = "2.3.3"
44
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
55

66
[deps]
@@ -28,7 +28,7 @@ JLArrays = "0.3.1"
2828
GPUArrays = "11.4.1"
2929
LinearAlgebra = "1.6"
3030
Random = "1.6"
31-
StridedViews = "0.4.5"
31+
StridedViews = "0.4.6"
3232
Test = "1.6"
3333
TupleTools = "1.6"
3434
julia = "1.6"

ext/StridedGPUArraysExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,11 @@ function Base.Broadcast.BroadcastStyle(gpu_sv::StridedView{T, N, TA}) where {T,
1212
return typeof(raw_style)(Val(N)) # sets the dimensionality correctly
1313
end
1414

15+
function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TS <: Number, NS, TAS <: AbstractGPUArray{TS}, FS <: ALL_FS}
16+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
17+
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
18+
GPUArrays._copyto!(dst, bc)
19+
return dst
20+
end
21+
1522
end

ext/StridedJLArraysExt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,4 @@ function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS,
1313
return dst
1414
end
1515

16-
function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS}
17-
bc_style = Base.Broadcast.BroadcastStyle(TAS)
18-
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
19-
GPUArrays._copyto!(dst, bc)
20-
return dst
21-
end
22-
2316
end

test/amd.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
77
A1 = ROCMatrix(randn(T, (m1, m2)))
88
end
99
A2 = similar(A1)
10+
zA1 = ROCMatrix(f1(zeros(T, (m1, m2))))
11+
zA2 = ROCMatrix(f2(zeros(T, (m1, m2))))
1012
A1c = copy(A1)
1113
A2c = copy(A2)
1214
B1 = f1(StridedView(A1c))
1315
B2 = f2(StridedView(A2c))
1416
axes(f1(A1)) == axes(f2(A2)) || continue
1517
@test collect(ROCMatrix(copy!(f2(A2), f1(A1)))) == AMDGPU.Adapt.adapt(Vector{T}, copy!(B2, B1))
18+
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
1619
end
1720
end
1821
end

test/cuda.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
33
for m1 in (0, 16, 32), m2 in (0, 16, 32)
44
A1 = CUDA.randn(T, (m1, m2))
55
A2 = similar(A1)
6+
zA1 = CuMatrix(f1(zeros(T, (m1, m2))))
7+
zA2 = CuMatrix(f2(zeros(T, (m1, m2))))
68
A1c = copy(A1)
79
A2c = copy(A2)
810
B1 = f1(StridedView(A1c))
911
B2 = f2(StridedView(A2c))
1012
axes(f1(A1)) == axes(f2(A2)) || continue
1113
@test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDA.Adapt.adapt(Vector{T}, copy!(B2, B1))
14+
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
1215
end
1316
end
1417
end

test/jlarrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
33
for m1 in (0, 16, 32), m2 in (0, 16, 32)
44
A1 = JLArray(randn(T, (m1, m2)))
55
A2 = similar(A1)
6+
zA1 = JLArray(f1(zeros(T, (m1, m2))))
7+
zA2 = JLArray(f2(zeros(T, (m1, m2))))
68
A1c = copy(A1)
79
A2c = copy(A2)
810
B1 = f1(StridedView(A1c))
911
B2 = f2(StridedView(A2c))
1012
axes(f1(A1)) == axes(f2(A2)) || continue
1113
@test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
14+
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
1215
end
1316
end
1417
end

0 commit comments

Comments
 (0)