Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "3.2.0"
version = "3.3.0"

[deps]
GroupedArrays = "6407cd72-fade-4a84-8a1e-56e431fc1533"
Expand Down
8 changes: 7 additions & 1 deletion src/FixedEffect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ end

# Outer constructor: forwards the positional `args` to `GroupedArray(args...; sort = nothing)` and
# wraps the resulting group references, the `interaction` vector and the group count into a
# `FixedEffect`. `interaction` defaults to `uweights(length(args[1]))`.
function FixedEffect(args...; interaction::AbstractVector = uweights(length(args[1])))
    g = GroupedArray(args..., sort = nothing)
    # Store refs as Int32 (refs lie in [0, ngroups]) to halve the dominant memory stream read by the
    # scatter/gather kernels on every solver iteration. GroupedArrays always builds Int64 groups (it
    # needs signed sentinels during construction); narrowing here, where FixedEffect manufactures its
    # own ref representation, lets every backend (CPU/GPU) and solve_coefficients! stream the smaller
    # type. The rare ngroups > typemax(Int32) keeps the original integer type.
    # Only one FixedEffect is constructed: the superseded call that wrapped `g.groups` directly
    # (and whose result was discarded) has been removed.
    refs = g.ngroups > typemax(Int32) ? g.groups : convert(Vector{Int32}, g.groups)
    return FixedEffect{typeof(refs), typeof(interaction)}(refs, interaction, g.ngroups)
end

Base.show(io::IO, ::FixedEffect) = print(io, "Fixed Effects")
Expand Down
27 changes: 22 additions & 5 deletions src/utils/lsmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@ struct ConvergenceHistory{T, R}
residuals::R
end

# Fast 2-norm for the length-N CPU vectors in LSMR.
# `LinearAlgebra.norm` (BLAS nrm2 / generic_norm2) takes an overflow-safe scaled path that is
# ~8x slower than a plain SIMD sum-of-squares. LSMR's u/v are unit-normalized every iteration
# and the residual stays O(data), so the plain sum is the common case. The sum of squares
# is accumulated in Float64 even for Float32 inputs: this is more accurate than the in-precision
# BLAS `snrm2` (avoids ~sqrt(N)*eps(Float32) drift over millions of rows) at negligible cost.
#
# Squares of Float32 values always fit in Float64, so that path cannot overflow or underflow.
# For Float64 inputs the plain sum can overflow to Inf (|x[i]| > ~1e154) or underflow towards 0
# (|x[i]| < ~1e-154); whenever the accumulated sum leaves the safely representable range
# (including 0, Inf and NaN) the result is recomputed with the scaled `norm`, so badly scaled
# data yields the same answer `norm` would give. The check is a single comparison after the loop.
#
# Restricted to concrete `Vector{Float32}` / `Vector{Float64}` so GPU arrays (CuVector/MtlVector)
# and FixedEffectCoefficients keep the generic `norm` (the fast path would trigger scalar GPU
# indexing).
@inline function _norm2(x::Vector{T}) where {T<:Union{Float32,Float64}}
    s = 0.0
    @inbounds @simd for i in eachindex(x)
        s += abs2(Float64(x[i]))
    end
    if T === Float64
        # Below this bound individual squares may have been flushed to zero/subnormals.
        safe_lo = floatmin(Float64) / eps(Float64)
        # `!(lo <= s < Inf)` is also true for NaN, which `norm` then propagates.
        (safe_lo <= s < Inf) || return T(norm(x))
    end
    return T(sqrt(s))
end

# Generic fallback: anything that is not a plain Float32/Float64 `Vector` uses `norm` unchanged.
_norm2(x) = norm(x)



##############################################################################
Expand Down Expand Up @@ -62,10 +79,10 @@ function lsmr!(x, A, b, v, h, hbar;
conlim > 0 ? ctol = convert(Tr, inv(conlim)) : ctol = zero(Tr)
# form the first vectors u and v (satisfy β*u = b, α*v = A'u)
u = mul!(b, A, x, -1, 1)
β = norm(u)
β = _norm2(u)
β > 0 && rmul!(u, inv(β))
mul!(v, A', u, 1, 0)
α = norm(v)
α = _norm2(v)
α > 0 && rmul!(v, inv(α))
# Initialize variables for 1st iteration.
ζbar = α * β
Expand Down Expand Up @@ -103,11 +120,11 @@ function lsmr!(x, A, b, v, h, hbar;
while iter < maxiter
iter += 1
mul!(u, A, v, 1, -α)
β = norm(u)
β = _norm2(u)
if β > 0
rmul!(u, inv(β))
mul!(v, A', u, 1, -β)
α = norm(v)
α = _norm2(v)
α > 0 && rmul!(v, inv(α))
end

Expand Down Expand Up @@ -190,7 +207,7 @@ function lsmr!(x, A, b, v, h, hbar;

# Compute norms for convergence testing.
normAr = abs(ζbar)
normx = norm(x)
normx = _norm2(x)



Expand Down
4 changes: 2 additions & 2 deletions test/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import Base.==
@test sprint(show, fe1) == "Fixed Effects"
@test sprint(show, MIME("text/plain"), fe1) == """
Fixed Effects:
refs (10-element Vector{Int64}):
refs (10-element Vector{Int32}):
[1, 2, 3, 4, 5, ... ]
interaction (UnitWeights):
none"""
fe2 = FixedEffect(1:10, interaction=fill(1.23456789, 10))
@test sprint(show, MIME("text/plain"), fe2) == """
Fixed Effects:
refs (10-element Vector{Int64}):
refs (10-element Vector{Int32}):
[1, 2, 3, 4, 5, ... ]
interaction (10-element Vector{Float64}):
[1.23457, 1.23457, 1.23457, 1.23457, 1.23457, ... ]"""
Expand Down
Loading