forked from JuliaParallel/DistributedArrays.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbroadcast.jl
More file actions
152 lines (131 loc) · 6.09 KB
/
broadcast.jl
File metadata and controls
152 lines (131 loc) · 6.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
###
# Distributed broadcast implementation
##
# We define a custom ArrayStyle here since we need to keep track of
# the fact that it is Distributed and what kind of underlying broadcast behaviour
# we will encounter.
#
# `DArrayStyle{Style}` wraps `Style`, the broadcast style of the underlying arrays.
# `Style` may also be `Nothing`: `bcdistribute` wraps a `Broadcasted{Nothing}` as
# `Broadcasted{DArrayStyle{Nothing}}`.
struct DArrayStyle{Style <: Union{Nothing,BroadcastStyle}} <: Broadcast.AbstractArrayStyle{Any} end
# Wrap an existing style instance; the wrapped style is recorded as a type parameter.
DArrayStyle(::S) where {S} = DArrayStyle{S}()
# Wrap the style type `S`, re-instantiated at dimensionality `N` via `S(Val(N))`.
DArrayStyle(::S, ::Val{N}) where {S,N} = DArrayStyle(S(Val(N)))
# Dimensionality-only constructor; the inner style defaults to `DefaultArrayStyle{N}`.
# NOTE(review): any previously wrapped inner style is not carried through here — confirm
# this is intended for callers that re-dimension an existing DArrayStyle.
DArrayStyle(::Val{N}) where N = DArrayStyle{Broadcast.DefaultArrayStyle{N}}()
# A `DArray{T,N,A}` broadcasts with the style of `A` (presumably the local-part array
# type — confirm), re-instantiated at the DArray's dimensionality `N`.
Broadcast.BroadcastStyle(::Type{<:DArray{<:Any, N, A}}) where {N, A} = DArrayStyle(BroadcastStyle(A), Val(N))
# promotion rules
# TODO: test this
# Combining two distributed styles yields a distributed style over the combination of
# their inner styles. `AStyle`/`BStyle` are style *types*, while the two-argument
# `BroadcastStyle` rules operate on style *instances*, so both are instantiated first.
# `Broadcast.result_style` consults the pairwise rules in both argument orders; a single
# `BroadcastStyle(a, b)` call only checks one order and falls back to `Unknown()`.
function Broadcast.BroadcastStyle(::DArrayStyle{AStyle}, ::DArrayStyle{BStyle}) where {AStyle, BStyle}
    return DArrayStyle(Broadcast.result_style(AStyle(), BStyle()))
end
# Lower a broadcast over distributed arrays. The broadcast is first built with the
# wrapped (inner) style. A lazy result is then re-tagged as `DArrayStyle{Style}`, so that
# `copy` dispatches to the distributed implementation. If the inner style evaluated
# eagerly (i.e. did not return a `Broadcasted`), that value is returned untouched.
function Broadcast.broadcasted(::DArrayStyle{Style}, f, args...) where Style
    lowered = Broadcast.broadcasted(Style(), f, args...)
    lowered isa Broadcasted || return lowered
    return Broadcasted{DArrayStyle{Style}}(lowered.f, lowered.args, lowered.axes)
end
# # deal with one layer deep lazy arrays
# BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where T <: DArray = BroadcastStyle(T)
# BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where T <: DArray = BroadcastStyle(T)
# BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,<:T}}) where T <: DArray = BroadcastStyle(T)
# # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
# # and we could define our methods in terms of Union{DArray, WrappedArray{<:Any, <:DArray}}
# const DDestArray = Union{DArray,
# LinearAlgebra.Transpose{<:Any,<:DArray},
# LinearAlgebra.Adjoint{<:Any,<:DArray},
# SubArray{<:Any, <:Any, <:DArray}}
# Destination array type accepted by the distributed `copyto!`. It is currently just
# `DArray`; wrapped DArrays (Transpose/Adjoint/SubArray) are not yet supported.
const DDestArray = DArray
# This method is responsible for selecting the output type of broadcast: the result is a
# DArray of the broadcast's size, and each chunk is allocated by `similar` on the
# underlying (inner) broadcast style.
function Base.similar(bc::Broadcasted{<:DArrayStyle{Style}}, ::Type{ElType}) where {Style, ElType}
    return DArray(map(length, axes(bc))) do I
        # Create a fake Broadcasted for the underlying style, shaped like this chunk.
        # Broadcasted axes must be ranges rather than lengths. `size(bc)` is
        # `map(length, axes(bc))` and `length(::Int) == 1`, so passing Ints would make
        # every dimension look like a singleton to a style that sizes from `size(bc)`.
        chunkaxes = map(i -> Base.OneTo(length(i)), I)
        bc′ = Broadcasted{Style}(identity, (), chunkaxes)
        similar(bc′, ElType)
    end
end
##
# Ref https://docs.julialang.org/en/v1/manual/interfaces/#extending-in-place-broadcast-2
#
# We purposefully only specialise `copyto!`, providing a
# broadcast implementation that defers to the underlying BroadcastStyle. We can't
# assume that `getindex` is fast; furthermore, we can't assume that the distribution of a
# DArray across workers is equal or that the underlying array type is consistent.
#
# Implementation:
# - first distribute all arguments
# - Q: How do we decide on the cuts?
# - then localise arguments on each node
##
# In-place broadcast into a DArray.
#
# Checks that `dest` and `bc` share axes; on mismatch, `Broadcast.throwdm` raises the
# error. The broadcast is then distributed, which turns local AbstractArray arguments
# into DArrays.
#
# On each process in `procs(dest)`, the broadcast is restricted to that process's part of
# `dest` and written into `localpart(dest)`. Localising creates copies of the arguments.
# Waits for all processes to finish, then returns `dest`.
@inline function Base.copyto!(dest::DDestArray, bc::Broadcasted{Nothing})
    if axes(dest) != axes(bc)
        Broadcast.throwdm(axes(dest), axes(bc))
    end
    distributed_bc = bcdistribute(bc)
    @sync for worker in procs(dest)
        @async remotecall_wait(worker) do
            # Index of this process's chunk of `dest`; 0 would mean it owns no part.
            lpidx = localpartindex(dest)
            @assert lpidx != 0
            local_bc = bclocal(distributed_bc, dest.indices[lpidx])
            copyto!(localpart(dest), local_bc)
        end
    end
    return dest
end
# Test
# a = Array
# a .= DArray(x,y)
# Out-of-place broadcast producing a new DArray. The broadcast is distributed once. Each
# chunk `I` of the result is then built by localising the broadcast to those indices and
# materialising it with the underlying style's `copy`.
@inline function Base.copy(bc::Broadcasted{<:DArrayStyle})
    distributed_bc = bcdistribute(bc)
    # TODO: teach DArray about axes since this is wrong for OffsetArrays
    dims = map(length, axes(bc))
    return DArray(I -> copy(bclocal(distributed_bc, I)), dims)
end
# _bcview takes the shape (axes) of a broadcasted argument and the indices of a view into
# the broadcast result. It produces the view over that argument that constitutes its part
# of the broadcast; in a sense it is the inverse of `_bcs` in Base.Broadcast.
#
# Dimensions are paired from the front:
# - When the view has more dimensions than the argument, the extra view dimensions are
#   dropped.
# - When the argument has more dimensions than the view, its remaining axes are kept
#   whole.
_bcview(::Tuple{}, ::Tuple{}) = ()
_bcview(::Tuple{}, view::Tuple) = ()
_bcview(shape::Tuple, ::Tuple{}) = (shape[1], _bcview(tail(shape), ())...)
function _bcview(shape::Tuple, view::Tuple)
    return (_bcview1(shape[1], view[1]), _bcview(tail(shape), tail(view))...)
end

# _bcview1 handles the logic for a single dimension. `a` is the argument's axis and `b`
# is the requested index range.
# - A singleton axis (`1` or `1:1`) is broadcast, so the argument contributes `1:1`.
# - An empty request selects nothing, so it is returned as is rather than bounds-checked.
# - Otherwise both ends of `b` must lie within `a`, and `b` is returned unchanged.
# Throws `DimensionMismatch` when `b` is not contained in `a`.
function _bcview1(a, b)
    if a == 1 || a == 1:1
        return 1:1
    elseif isempty(b)
        return b
    elseif first(a) <= first(b) <= last(a) &&
           first(a) <= last(b) <= last(a)
        return b
    else
        throw(DimensionMismatch("broadcast view could not be constructed"))
    end
end
# Distribute broadcast
# TODO: How to decide on cuts

# Convert a Broadcasted into its distributed form. Every argument is passed through
# `bcdistribute`, and the style is wrapped in `DArrayStyle` unless it already is one.
# The function and stored axes are carried over unchanged.
@inline function bcdistribute(bc::Broadcasted{Style}) where Style<:Union{Nothing,BroadcastStyle}
    return Broadcasted{DArrayStyle{Style}}(bc.f, bcdistribute_args(bc.args), bc.axes)
end
@inline function bcdistribute(bc::Broadcasted{Style}) where Style<:DArrayStyle
    return Broadcasted{Style}(bc.f, bcdistribute_args(bc.args), bc.axes)
end

# For any other argument, let its BroadcastStyle decide whether it needs distributing.
bcdistribute(x::T) where T = _bcdistribute(BroadcastStyle(T), x)
# Already distributed: nothing to do.
_bcdistribute(::DArrayStyle, x) = x
# Zero-dimensional values (e.g. scalars) are not worth distributing.
_bcdistribute(::Broadcast.AbstractArrayStyle{0}, x) = x
# Any other array-style argument is turned into a DArray.
_bcdistribute(::Broadcast.AbstractArrayStyle, x) = distribute(x)
# Non-array styles are left alone.
_bcdistribute(::Any, x) = x

# Apply `bcdistribute` to each argument, left to right, producing a tuple.
@inline bcdistribute_args(args::Tuple) = map(bcdistribute, args)
# Localise a distributed broadcast to the result indices `idxs`. This builds a
# Broadcasted of the underlying style whose arguments are restricted to their matching
# part. Axes are dropped here since recomputing them is easier.
@inline function bclocal(bc::Broadcasted{DArrayStyle{Style}}, idxs) where Style<:Union{Nothing,BroadcastStyle}
    argidxs = _bcview(axes(bc), idxs)
    return Broadcasted{Style}(bc.f, bclocal_args(argidxs, bc.args))
end

# For a DArray argument, translate `idxs` onto this array's own axes and fetch that part
# with `makelocal`. makelocal does a view of the data and then copies it over, except
# when the data already is local.
function bclocal(x::DArray, idxs)
    return makelocal(x, _bcview(axes(x), idxs)...)
end
# Any non-DArray argument is passed through unchanged.
bclocal(x, idxs) = x

# Localise each argument against the same indices, left to right, producing a tuple.
@inline bclocal_args(idxs, args::Tuple) = map(arg -> bclocal(arg, idxs), args)