Skip to content

Commit dd729db

Browse files
authored
Add zero! and concatenate (#6)
1 parent 62c92a0 commit dd729db

12 files changed

Lines changed: 347 additions & 3 deletions

File tree

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
name = "FunctionImplementations"
22
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[weakdeps]
7+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810

911
[extensions]
12+
FunctionImplementationsBlockArraysExt = "BlockArrays"
13+
FunctionImplementationsFillArraysExt = "FillArrays"
1014
FunctionImplementationsLinearAlgebraExt = "LinearAlgebra"
1115

1216
[compat]
17+
BlockArrays = "1.4"
18+
FillArrays = "1.15"
1319
LinearAlgebra = "1.10"
1420
julia = "1.10"
1521

docs/src/reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Reference
22

33
```@autodocs
4-
Modules = [FunctionImplementations]
4+
Modules = [FunctionImplementations, FunctionImplementations.Concatenate]
55
```
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module FunctionImplementationsBlockArraysExt
2+
3+
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
4+
using FunctionImplementations.Concatenate: Concatenate
5+
6+
function Concatenate.cat_axis(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
7+
first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1"))
8+
return blockedrange([blocklengths(a1); blocklengths(a2)])
9+
end
10+
11+
end
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module FunctionImplementationsFillArraysExt
2+
3+
using FillArrays: RectDiagonal
4+
using FunctionImplementations: FunctionImplementations
5+
6+
function FunctionImplementations.permuteddims(a::RectDiagonal, perm)
7+
(ndims(a) == length(perm) && isperm(perm)) ||
8+
throw(ArgumentError("no valid permutation of dimensions"))
9+
return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a)))
10+
end
11+
12+
end

src/FunctionImplementations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@ module FunctionImplementations
33
include("implementation.jl")
44
include("style.jl")
55
include("permuteddims.jl")
6+
include("zero.jl")
7+
include("concatenate.jl")
68

79
end

src/concatenate.jl

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
"""
2+
module Concatenate
3+
4+
Alternative implementation for `Base.cat` through `Concatenate.cat(!)`.
5+
6+
This is mostly a copy of the Base implementation, with the main difference being
7+
that the destination is chosen based on all inputs instead of just the first.
8+
9+
Additionally, we have an intermediate representation in terms of a Concatenated object,
10+
reminiscent of how Broadcast works.
11+
12+
The various entry points for specializing behavior are:
13+
14+
* Destination selection can be achieved through:
15+
16+
```julia
17+
Base.similar(concat::Concatenated{Style}, ::Type{T}, axes) where {Style}
18+
```
19+
20+
* Custom implementations:
21+
22+
```julia
23+
Base.copy(concat::Concatenated{Style}) # custom implementation of cat
24+
Base.copyto!(dest, concat::Concatenated{Style}) # custom implementation of cat! based on style
25+
Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest)
26+
```
27+
"""
28+
module Concatenate
29+
30+
export concatenate
31+
VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public Concatenated, cat, cat!, concatenated"))
32+
33+
using Base: promote_eltypeof
34+
import Base.Broadcast as BC
35+
using ..FunctionImplementations: zero!
36+
37+
unval(::Val{x}) where {x} = x
38+
39+
function _Concatenated end
40+
41+
"""
42+
Concatenated{Style, Dims, Args <: Tuple}
43+
44+
Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide
45+
hooks to customize the implementation.
46+
"""
47+
struct Concatenated{Style, Dims, Args <: Tuple}
48+
style::Style
49+
dims::Val{Dims}
50+
args::Args
51+
global @inline function _Concatenated(
52+
style::Style, dims::Val{Dims}, args::Args
53+
) where {Style, Dims, Args <: Tuple}
54+
return new{Style, Dims, Args}(style, dims, args)
55+
end
56+
end
57+
58+
function Concatenated(
59+
style::Union{BC.AbstractArrayStyle, Nothing}, dims::Val, args::Tuple
60+
)
61+
return _Concatenated(style, dims, args)
62+
end
63+
function Concatenated(dims::Val, args::Tuple)
64+
return Concatenated(cat_style(dims, args...), dims, args)
65+
end
66+
function Concatenated{Style}(
67+
dims::Val, args::Tuple
68+
) where {Style <: Union{BC.AbstractArrayStyle, Nothing}}
69+
return Concatenated(Style(), dims, args)
70+
end
71+
72+
dims(::Concatenated{<:Any, D}) where {D} = D
73+
style(concat::Concatenated) = getfield(concat, :style)
74+
75+
concatenated(dims, args...) = concatenated(Val(dims), args...)
76+
concatenated(dims::Val, args...) = Concatenated(dims, args)
77+
78+
function Base.convert(
79+
::Type{Concatenated{NewStyle}}, concat::Concatenated{<:Any, Dims, Args}
80+
) where {NewStyle, Dims, Args}
81+
return Concatenated{NewStyle}(
82+
concat.dims, concat.args
83+
)::Concatenated{NewStyle, Dims, Args}
84+
end
85+
86+
# allocating the destination container
87+
# ------------------------------------
88+
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
89+
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
90+
function Base.similar(concat::Concatenated, ax)
91+
return similar(concat, eltype(concat), ax)
92+
end
93+
94+
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
95+
# Convert to a broadcasted to leverage its similar implementation.
96+
bc = BC.Broadcasted(style(concat), identity, concat.args, ax)
97+
return similar(bc, T)
98+
end
99+
100+
function cat_axis(
101+
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
102+
)
103+
return cat_axis(cat_axis(a1, a2), a_rest...)
104+
end
105+
function cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange)
106+
first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1"))
107+
return Base.OneTo(length(a1) + length(a2))
108+
end
109+
110+
function cat_ndims(dims, as::AbstractArray...)
111+
return max(maximum(dims), maximum(ndims, as))
112+
end
113+
function cat_ndims(dims::Val, as::AbstractArray...)
114+
return cat_ndims(unval(dims), as...)
115+
end
116+
117+
function cat_axes(dims, a::AbstractArray, as::AbstractArray...)
118+
return ntuple(cat_ndims(dims, a, as...)) do dim
119+
return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim)
120+
end
121+
end
122+
function cat_axes(dims::Val, as::AbstractArray...)
123+
return cat_axes(unval(dims), as...)
124+
end
125+
126+
function cat_style(dims, as::AbstractArray...)
127+
N = cat_ndims(dims, as...)
128+
return typeof(BC.combine_styles(as...))(Val(N))
129+
end
130+
131+
Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
132+
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
133+
Base.size(concat::Concatenated) = length.(axes(concat))
134+
Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...)
135+
136+
# Main logic
137+
# ----------
138+
"""
139+
concatenate(dims, args...)
140+
141+
Concatenate the supplied `args` along dimensions `dims`.
142+
143+
See also [`cat`](@ref) and [`cat!`](@ref).
144+
"""
145+
concatenate(dims, args...) = Base.materialize(concatenated(dims, args...))
146+
147+
"""
148+
Concatenate.cat(args...; dims)
149+
150+
Concatenate the supplied `args` along dimensions `dims`.
151+
152+
See also [`concatenate`](@ref) and [`cat!`](@ref).
153+
"""
154+
cat(args...; dims) = concatenate(dims, args...)
155+
Base.materialize(concat::Concatenated) = copy(concat)
156+
157+
"""
158+
Concatenate.cat!(dest, args...; dims)
159+
160+
Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`.
161+
"""
162+
function cat!(dest, args...; dims)
163+
Base.materialize!(dest, concatenated(dims, args...))
164+
return dest
165+
end
166+
Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat)
167+
168+
Base.copy(concat::Concatenated) = copyto!(similar(concat), concat)
169+
170+
# The following is largely copied from the Base implementation of `Base.cat`, see:
171+
# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
172+
_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x)
173+
_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x)
174+
175+
cat_size(A) = (1,)
176+
cat_size(A::AbstractArray) = size(A)
177+
cat_size(A, d) = 1
178+
cat_size(A::AbstractArray, d) = size(A, d)
179+
180+
cat_indices(A, d) = Base.OneTo(1)
181+
cat_indices(A::AbstractArray, d) = axes(A, d)
182+
183+
function __cat!(A, shape, catdims, X...)
184+
return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
185+
end
186+
function __cat_offset!(A, shape, catdims, offsets, x, X...)
187+
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
188+
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
189+
return __cat_offset!(A, shape, catdims, newoffsets, X...)
190+
end
191+
__cat_offset!(A, shape, catdims, offsets) = A
192+
function __cat_offset1!(A, shape, catdims, offsets, x)
193+
inds = ntuple(length(offsets)) do i
194+
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
195+
end
196+
_copy_or_fill!(A, inds, x)
197+
newoffsets = ntuple(length(offsets)) do i
198+
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
199+
end
200+
return newoffsets
201+
end
202+
203+
dims2cat(dims::Val) = dims2cat(unval(dims))
204+
function dims2cat(dims)
205+
if any((0), dims)
206+
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
207+
end
208+
return ntuple(in(dims), maximum(dims))
209+
end
210+
211+
# default falls back to replacing style with Nothing
212+
# this permits specializing on typeof(dest) without ambiguities
213+
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
214+
@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated)
215+
return copyto!(dest, convert(Concatenated{Nothing}, concat))
216+
end
217+
218+
function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
219+
catdims = dims2cat(dims(concat))
220+
shape = size(concat)
221+
count(!iszero, catdims)::Int > 1 && zero!(dest)
222+
return __cat!(dest, shape, catdims, concat.args...)
223+
end
224+
225+
end

src/zero.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
zero!(a::AbstractArray)
3+
4+
In-place version of `zero(a)`, sets all entries of `a` to zero.
5+
"""
6+
zero!(a::AbstractArray) = style(a)(zero!)(a)
7+
function (::Implementation{typeof(zero!)})(a::AbstractArray)
8+
fill!(a, zero(eltype(a)))
9+
return a
10+
end

test/Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
5+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
36
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
7+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
48
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
59
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
610
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
@@ -10,8 +14,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1014
FunctionImplementations = {path = ".."}
1115

1216
[compat]
17+
Adapt = "4"
1318
Aqua = "0.8"
19+
BlockArrays = "1.4"
20+
FillArrays = "1.15"
1421
FunctionImplementations = "0.3"
22+
JLArrays = "0.3"
1523
LinearAlgebra = "1.10"
1624
SafeTestsets = "0.1"
1725
Suppressor = "0.2"

test/test_blockarraysext.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using BlockArrays: BlockArray, blockedrange, blockisequal
2+
using FunctionImplementations.Concatenate: concatenate
3+
using Test: @test, @testset
4+
5+
@testset "BlockArraysExt" begin
6+
a = BlockArray(randn(4, 4), [2, 2], [2, 2])
7+
b = BlockArray(randn(4, 4), [2, 2], [2, 2])
8+
9+
concat = concatenate(1, a, b)
10+
@test axes(concat) == (Base.OneTo(8), Base.OneTo(4))
11+
@test blockisequal(axes(concat, 1), blockedrange([2, 2, 2, 2]))
12+
@test blockisequal(axes(concat, 2), blockedrange([2, 2]))
13+
@test size(concat) == (8, 4)
14+
@test eltype(concat) Float64
15+
@test copy(concat) == cat(a, b; dims = 1)
16+
@test copy(concat) isa BlockArray{Float64, 2}
17+
end

test/test_concatenate.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using Adapt: adapt
2+
using FunctionImplementations.Concatenate: concatenated
3+
using JLArrays: JLArray
4+
using Test: @test, @testset
5+
6+
@testset "Concatenated" for arrayt in (Array, JLArray)
7+
dev = adapt(arrayt)
8+
a = dev(randn(Float32, 2, 2))
9+
b = dev(randn(Float64, 2, 2))
10+
11+
concat = concatenated((1, 2), a, b)
12+
@test axes(concat) == Base.OneTo.((4, 4))
13+
@test size(concat) == (4, 4)
14+
@test eltype(concat) === Float64
15+
@test copy(concat) == cat(a, b; dims = (1, 2))
16+
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 2}
17+
18+
concat = concatenated(1, a, b)
19+
@test axes(concat) == Base.OneTo.((4, 2))
20+
@test size(concat) == (4, 2)
21+
@test eltype(concat) === Float64
22+
@test copy(concat) == cat(a, b; dims = 1)
23+
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 2}
24+
25+
concat = concatenated(3, a, b)
26+
@test axes(concat) == Base.OneTo.((2, 2, 2))
27+
@test size(concat) == (2, 2, 2)
28+
@test eltype(concat) === Float64
29+
@test copy(concat) == cat(a, b; dims = 3)
30+
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 3}
31+
32+
concat = concatenated(4, a, b)
33+
@test axes(concat) == Base.OneTo.((2, 2, 1, 2))
34+
@test size(concat) == (2, 2, 1, 2)
35+
@test eltype(concat) === Float64
36+
@test copy(concat) == cat(a, b; dims = 4)
37+
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 4}
38+
end

0 commit comments

Comments
 (0)