Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,36 @@ function Dagger.move(from_proc::CuArrayDeviceProc, to_proc::CuArrayDeviceProc, x
end
else
# Different node, use DtoH, serialization, HtoD
return CuArray(remotecall_fetch(from_proc.owner, x) do x
Array(unwrap(x))
end)
host_copy = remotecall_fetch(from_proc.owner, from_proc, x) do from_proc, x
return with_context(from_proc) do
Array(unwrap(x))
end
end
return with_context(to_proc) do
return CuArray(host_copy)
end
end
end

# Move a raw `CuArray` between two CUDA device processors.
#
# Three cases are distinguished, checked in this order:
#   1. `from_proc == to_proc`: synchronize in `from_proc`'s context and hand
#      back `x` itself (no copy is made).
#   2. Both processors share the same root worker id: synchronize the source
#      context, then allocate a `similar` array inside `to_proc`'s context and
#      `copyto!` into it.
#   3. Otherwise: stage through a host `Array` built in `from_proc`'s context
#      and rebuild a `CuArray` from it in `to_proc`'s context.
function Dagger.move(from_proc::CuArrayDeviceProc, to_proc::CuArrayDeviceProc, x::CuArray)
    # Case 1: nothing to transfer, just make sure pending work has finished.
    if from_proc == to_proc
        with_context(CUDA.synchronize, from_proc)
        return x
    end

    # Case 2: same root worker, copy directly into a freshly allocated array.
    if Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)
        with_context(CUDA.synchronize, from_proc)
        return with_context(to_proc) do
            dest = similar(x)
            copyto!(dest, x)
            # Wait for the copy before handing the array to the caller.
            CUDA.synchronize()
            return dest
        end
    end

    # Case 3: round-trip through host memory.
    staged = with_context(() -> Array(x), from_proc)
    return with_context(() -> CuArray(staged), to_proc)
end

Expand Down
28 changes: 25 additions & 3 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,31 @@ function Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::MtlArrayDeviceProc,
# FIXME: elseif Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)
else
# Different node, use DtoH, serialization, HtoD
return MtlArray(remotecall_fetch(from_proc.owner, x) do x
Array(unwrap(x))
end)
host_copy = remotecall_fetch(from_proc.owner, from_proc, x) do from_proc, x
return with_context(from_proc) do
Array(unwrap(x))
end
end
return with_context(to_proc) do
return MtlArray(host_copy)
end
end
end

# Move a raw `MtlArray` between two Metal device processors.
#
# When both processors are equal, the array is returned as-is after a
# `Metal.synchronize` in `from_proc`'s context. In every other case the data
# is staged through a host `Array` (built in `from_proc`'s context) and a new
# `MtlArray` is constructed from it in `to_proc`'s context.
function Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::MtlArrayDeviceProc, x::MtlArray)
    # Same processor: no copy, only synchronize.
    if from_proc == to_proc
        with_context(Metal.synchronize, from_proc)
        return x
    end

    # FIXME: elseif Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)

    # Any other pairing: device -> host -> device.
    staged = with_context(() -> Array(x), from_proc)
    return with_context(() -> MtlArray(staged), to_proc)
end

Expand Down
35 changes: 32 additions & 3 deletions ext/OpenCLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,38 @@ function Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CLArrayDeviceProc, x
end
else
# Different node, use DtoH, serialization, HtoD
return CLArray(remotecall_fetch(from_proc.owner, x) do x
Array(unwrap(x))
end)
host_copy = remotecall_fetch(from_proc.owner, from_proc, x) do from_proc, x
return with_context(from_proc) do
Array(unwrap(x))
end
end
return with_context(to_proc) do
return CLArray(host_copy)
end
end
end
# Move a raw `CLArray` between two OpenCL device processors.
#
# The previous signature carried a trailing `where T<:CLArray` even though `T`
# was never referenced (the argument is plain `x::CLArray`); that unused type
# variable has been dropped. Dispatch is unchanged.
#
# Cases, checked in order:
#   1. `from_proc == to_proc`: synchronize via `_sync_with_context` and return
#      `x` itself.
#   2. Both processors share the same root worker id: synchronize the source,
#      then allocate a `similar` array in `to_proc`'s context, `copyto!` into
#      it, and finish the current queue before returning it.
#   3. Otherwise: stage through a host `Array` built in `from_proc`'s context
#      and construct a new `CLArray` from it in `to_proc`'s context.
function Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CLArrayDeviceProc, x::CLArray)
    if from_proc == to_proc
        # Same process and GPU, no change
        _sync_with_context(from_proc)
        return x
    elseif Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)
        # Same process but different GPUs, use DtoD copy
        _sync_with_context(from_proc)
        return with_context(to_proc) do
            to_arr = similar(x)
            copyto!(to_arr, x)
            # Block until the copy has completed on the active queue.
            cl.finish(cl.queue())
            return to_arr
        end
    else
        # Remaining case: go through host memory (DtoH, then HtoD)
        host_copy = with_context(from_proc) do
            return Array(x)
        end
        return with_context(to_proc) do
            return CLArray(host_copy)
        end
    end
end

Expand Down
36 changes: 33 additions & 3 deletions ext/ROCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,39 @@ function Dagger.move(from_proc::ROCArrayDeviceProc, to_proc::ROCArrayDeviceProc,
end
else
# Different node, use DtoH, serialization, HtoD
return ROCArray(remotecall_fetch(from_proc.owner, x) do x
Array(unwrap(x))
end)
host_copy = remotecall_fetch(from_proc.owner, from_proc, x) do from_proc, x
return with_context(from_proc) do
Array(unwrap(x))
end
end
return with_context(to_proc) do
return ROCArray(host_copy)
end
end
end

# Move a raw `ROCArray` between two ROCm device processors.
#
# Cases, checked in order:
#   1. `from_proc == to_proc`: synchronize in `from_proc`'s context and return
#      `x` itself.
#   2. Both processors share the same root worker id: synchronize using the
#      device id of the device that `x` reports, then allocate a `similar`
#      array in `to_proc`'s context and `copyto!` into it.
#   3. Otherwise: obtain a host `Array` through `remotecall_fetch` on
#      `from_proc.owner`, then build a `ROCArray` from it in `to_proc`'s
#      context.
function Dagger.move(from_proc::ROCArrayDeviceProc, to_proc::ROCArrayDeviceProc, x::ROCArray)
    # Case 1: same processor, no copy.
    if from_proc == to_proc
        with_context(AMDGPU.synchronize, from_proc)
        return x
    end

    # Case 2: same root worker, direct copy into a new array.
    if Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)
        src_dev = AMDGPU.device(x)
        with_context(AMDGPU.synchronize, src_dev.device_id)
        return with_context(to_proc) do
            dest = similar(x)
            copyto!(dest, x)
            # Wait for the copy before handing the array to the caller.
            AMDGPU.synchronize()
            return dest
        end
    end

    # Case 3: fetch a host copy from the owning worker.
    # NOTE(review): the CuArray/MtlArray/CLArray counterparts build the host
    # copy locally with `with_context(from_proc) do Array(x) end`; here `x` is
    # already a ROCArray yet is sent through `remotecall_fetch` and `unwrap`.
    # Confirm that `unwrap` accepts a ROCArray and that this is intended.
    staged = remotecall_fetch(from_proc.owner, from_proc, x) do from_proc, x
        return with_context(() -> Array(unwrap(x)), from_proc)
    end
    return with_context(() -> ROCArray(staged), to_proc)
end

Expand Down
4 changes: 3 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ end

if gpu != :all
local A, B
# ROCArray(rand(...)) not AMDGPU.rand: rocRAND RNG finalizers can error at
# process exit ("handle not managed by cache") when context/cache order differs.
AMDGPU.device!(AMDGPU.devices()[gpu]) do
A = AMDGPU.rand(128)
A = ROCArray(rand(Float32, 128))
B = AMDGPU.zeros(128)
end
Dagger.with_options(;scope=local_scope) do
Expand Down
Loading