|
| 1 | +""" |
| 2 | + module Concatenate |
| 3 | +
|
| 4 | +Alternative implementation for `Base.cat` through `Concatenate.cat(!)`. |
| 5 | +
|
| 6 | +This is mostly a copy of the Base implementation, with the main difference being |
| 7 | +that the destination is chosen based on all inputs instead of just the first. |
| 8 | +
|
| 9 | +Additionally, we have an intermediate representation in terms of a Concatenated object, |
| 10 | +reminiscent of how Broadcast works. |
| 11 | +
|
| 12 | +The various entry points for specializing behavior are: |
| 13 | +
|
| 14 | +* Destination selection can be achieved through: |
| 15 | +
|
| 16 | +```julia |
| 17 | +Base.similar(concat::Concatenated{Style}, ::Type{T}, axes) where {Style} |
| 18 | +``` |
| 19 | +
|
| 20 | +* Custom implementations: |
| 21 | +
|
| 22 | +```julia |
| 23 | +Base.copy(concat::Concatenated{Style}) # custom implementation of cat |
| 24 | +Base.copyto!(dest, concat::Concatenated{Style}) # custom implementation of cat! based on style |
| 25 | +Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest) |
| 26 | +``` |
| 27 | +""" |
| 28 | +module Concatenate |
| 29 | + |
| 30 | +export concatenate |
| 31 | +VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public Concatenated, cat, cat!, concatenated")) |
| 32 | + |
| 33 | +using Base: promote_eltypeof |
| 34 | +import Base.Broadcast as BC |
| 35 | +using ..FunctionImplementations: zero! |
| 36 | + |
| 37 | +unval(::Val{x}) where {x} = x |
| 38 | + |
| 39 | +function _Concatenated end |
| 40 | + |
| 41 | +""" |
| 42 | + Concatenated{Style, Dims, Args <: Tuple} |
| 43 | +
|
| 44 | +Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide |
| 45 | +hooks to customize the implementation. |
| 46 | +""" |
| 47 | +struct Concatenated{Style, Dims, Args <: Tuple} |
| 48 | + style::Style |
| 49 | + dims::Val{Dims} |
| 50 | + args::Args |
| 51 | + global @inline function _Concatenated( |
| 52 | + style::Style, dims::Val{Dims}, args::Args |
| 53 | + ) where {Style, Dims, Args <: Tuple} |
| 54 | + return new{Style, Dims, Args}(style, dims, args) |
| 55 | + end |
| 56 | +end |
| 57 | + |
| 58 | +function Concatenated( |
| 59 | + style::Union{BC.AbstractArrayStyle, Nothing}, dims::Val, args::Tuple |
| 60 | + ) |
| 61 | + return _Concatenated(style, dims, args) |
| 62 | +end |
| 63 | +function Concatenated(dims::Val, args::Tuple) |
| 64 | + return Concatenated(cat_style(dims, args...), dims, args) |
| 65 | +end |
| 66 | +function Concatenated{Style}( |
| 67 | + dims::Val, args::Tuple |
| 68 | + ) where {Style <: Union{BC.AbstractArrayStyle, Nothing}} |
| 69 | + return Concatenated(Style(), dims, args) |
| 70 | +end |
| 71 | + |
| 72 | +dims(::Concatenated{<:Any, D}) where {D} = D |
| 73 | +style(concat::Concatenated) = getfield(concat, :style) |
| 74 | + |
| 75 | +concatenated(dims, args...) = concatenated(Val(dims), args...) |
| 76 | +concatenated(dims::Val, args...) = Concatenated(dims, args) |
| 77 | + |
| 78 | +function Base.convert( |
| 79 | + ::Type{Concatenated{NewStyle}}, concat::Concatenated{<:Any, Dims, Args} |
| 80 | + ) where {NewStyle, Dims, Args} |
| 81 | + return Concatenated{NewStyle}( |
| 82 | + concat.dims, concat.args |
| 83 | + )::Concatenated{NewStyle, Dims, Args} |
| 84 | +end |
| 85 | + |
| 86 | +# allocating the destination container |
| 87 | +# ------------------------------------ |
| 88 | +Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) |
| 89 | +Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) |
| 90 | +function Base.similar(concat::Concatenated, ax) |
| 91 | + return similar(concat, eltype(concat), ax) |
| 92 | +end |
| 93 | + |
| 94 | +function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} |
| 95 | + # Convert to a broadcasted to leverage its similar implementation. |
| 96 | + bc = BC.Broadcasted(style(concat), identity, concat.args, ax) |
| 97 | + return similar(bc, T) |
| 98 | +end |
| 99 | + |
| 100 | +function cat_axis( |
| 101 | + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... |
| 102 | + ) |
| 103 | + return cat_axis(cat_axis(a1, a2), a_rest...) |
| 104 | +end |
| 105 | +function cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) |
| 106 | + first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1")) |
| 107 | + return Base.OneTo(length(a1) + length(a2)) |
| 108 | +end |
| 109 | + |
| 110 | +function cat_ndims(dims, as::AbstractArray...) |
| 111 | + return max(maximum(dims), maximum(ndims, as)) |
| 112 | +end |
| 113 | +function cat_ndims(dims::Val, as::AbstractArray...) |
| 114 | + return cat_ndims(unval(dims), as...) |
| 115 | +end |
| 116 | + |
| 117 | +function cat_axes(dims, a::AbstractArray, as::AbstractArray...) |
| 118 | + return ntuple(cat_ndims(dims, a, as...)) do dim |
| 119 | + return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim) |
| 120 | + end |
| 121 | +end |
| 122 | +function cat_axes(dims::Val, as::AbstractArray...) |
| 123 | + return cat_axes(unval(dims), as...) |
| 124 | +end |
| 125 | + |
| 126 | +function cat_style(dims, as::AbstractArray...) |
| 127 | + N = cat_ndims(dims, as...) |
| 128 | + return typeof(BC.combine_styles(as...))(Val(N)) |
| 129 | +end |
| 130 | + |
| 131 | +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) |
| 132 | +Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) |
| 133 | +Base.size(concat::Concatenated) = length.(axes(concat)) |
| 134 | +Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...) |
| 135 | + |
| 136 | +# Main logic |
| 137 | +# ---------- |
| 138 | +""" |
| 139 | + concatenate(dims, args...) |
| 140 | +
|
| 141 | +Concatenate the supplied `args` along dimensions `dims`. |
| 142 | +
|
| 143 | +See also [`cat`](@ref) and [`cat!`](@ref). |
| 144 | +""" |
| 145 | +concatenate(dims, args...) = Base.materialize(concatenated(dims, args...)) |
| 146 | + |
| 147 | +""" |
| 148 | + Concatenate.cat(args...; dims) |
| 149 | +
|
| 150 | +Concatenate the supplied `args` along dimensions `dims`. |
| 151 | +
|
| 152 | +See also [`concatenate`](@ref) and [`cat!`](@ref). |
| 153 | +""" |
| 154 | +cat(args...; dims) = concatenate(dims, args...) |
| 155 | +Base.materialize(concat::Concatenated) = copy(concat) |
| 156 | + |
| 157 | +""" |
| 158 | + Concatenate.cat!(dest, args...; dims) |
| 159 | +
|
| 160 | +Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`. |
| 161 | +""" |
| 162 | +function cat!(dest, args...; dims) |
| 163 | + Base.materialize!(dest, concatenated(dims, args...)) |
| 164 | + return dest |
| 165 | +end |
| 166 | +Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) |
| 167 | + |
| 168 | +Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) |
| 169 | + |
| 170 | +# The following is largely copied from the Base implementation of `Base.cat`, see: |
| 171 | +# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887 |
| 172 | +_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) |
| 173 | +_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) |
| 174 | + |
| 175 | +cat_size(A) = (1,) |
| 176 | +cat_size(A::AbstractArray) = size(A) |
| 177 | +cat_size(A, d) = 1 |
| 178 | +cat_size(A::AbstractArray, d) = size(A, d) |
| 179 | + |
| 180 | +cat_indices(A, d) = Base.OneTo(1) |
| 181 | +cat_indices(A::AbstractArray, d) = axes(A, d) |
| 182 | + |
| 183 | +function __cat!(A, shape, catdims, X...) |
| 184 | + return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) |
| 185 | +end |
| 186 | +function __cat_offset!(A, shape, catdims, offsets, x, X...) |
| 187 | + # splitting the "work" on x from X... may reduce latency (fewer costly specializations) |
| 188 | + newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) |
| 189 | + return __cat_offset!(A, shape, catdims, newoffsets, X...) |
| 190 | +end |
| 191 | +__cat_offset!(A, shape, catdims, offsets) = A |
| 192 | +function __cat_offset1!(A, shape, catdims, offsets, x) |
| 193 | + inds = ntuple(length(offsets)) do i |
| 194 | + (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i] |
| 195 | + end |
| 196 | + _copy_or_fill!(A, inds, x) |
| 197 | + newoffsets = ntuple(length(offsets)) do i |
| 198 | + (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] |
| 199 | + end |
| 200 | + return newoffsets |
| 201 | +end |
| 202 | + |
| 203 | +dims2cat(dims::Val) = dims2cat(unval(dims)) |
| 204 | +function dims2cat(dims) |
| 205 | + if any(≤(0), dims) |
| 206 | + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) |
| 207 | + end |
| 208 | + return ntuple(in(dims), maximum(dims)) |
| 209 | +end |
| 210 | + |
| 211 | +# default falls back to replacing style with Nothing |
| 212 | +# this permits specializing on typeof(dest) without ambiguities |
| 213 | +# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. |
| 214 | +@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) |
| 215 | + return copyto!(dest, convert(Concatenated{Nothing}, concat)) |
| 216 | +end |
| 217 | + |
| 218 | +function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) |
| 219 | + catdims = dims2cat(dims(concat)) |
| 220 | + shape = size(concat) |
| 221 | + count(!iszero, catdims)::Int > 1 && zero!(dest) |
| 222 | + return __cat!(dest, shape, catdims, concat.args...) |
| 223 | +end |
| 224 | + |
| 225 | +end |
0 commit comments