diff --git a/src/additional_functions/helper.jl b/src/additional_functions/helper.jl index d6a1fc6c8..60365def0 100644 --- a/src/additional_functions/helper.jl +++ b/src/additional_functions/helper.jl @@ -89,3 +89,37 @@ function nonunique(values::AbstractVector) end return res end + +# check that a model only has a single lossfun +function check_single_lossfun(model::AbstractSemSingle; throw_error) + if (length(model.loss.functions) > 1) & throw_error + @error "The model has $(length(sem.loss.functions)) loss functions. + Only a single loss function is supported." + end + return isone(length(model.loss.functions)) +end + +# check that all models use the same single loss function +function check_single_lossfun(models::AbstractSemSingle...; throw_error) + uniform = true + lossfun = models[1].loss.functions[1] + L = typeof(lossfun) + for (i, model) in enumerate(models) + uniform &= check_single_lossfun(model; throw_error = throw_error) + cur_lossfun = model.loss.functions[1] + if !isa(cur_lossfun, L) & throw_error + @error "Loss function for group #$i model is $(typeof(cur_lossfun)), expected $L. + Heterogeneous loss functions are not supported." + end + uniform &= isa(cur_lossfun, L) + end + return uniform +end + +check_single_lossfun(model::SemEnsemble; throw_error) = + check_single_lossfun(model.sems...; throw_error) + +# scaling corrections for multigroup models +mg_correction(::SemFIML) = 0 +mg_correction(::SemML) = 0 +mg_correction(::SemWLS) = -1 diff --git a/src/frontend/fit/fitmeasures/RMSEA.jl b/src/frontend/fit/fitmeasures/RMSEA.jl index f9dae84ed..764b5e116 100644 --- a/src/frontend/fit/fitmeasures/RMSEA.jl +++ b/src/frontend/fit/fitmeasures/RMSEA.jl @@ -7,13 +7,24 @@ function RMSEA end RMSEA(fit::SemFit) = RMSEA(fit, fit.model) -RMSEA(fit::SemFit, model::AbstractSemSingle) = RMSEA(dof(fit), χ²(fit), nsamples(fit)) +function RMSEA(fit::SemFit, model::AbstractSemSingle) + check_single_lossfun(model; throw_error = true) + return RMSEA(dof(fit), χ²(fit), nsamples(fit)+rmsea_correction(model.loss.functions[1])) +end -RMSEA(fit::SemFit, model::SemEnsemble) = - sqrt(length(model.sems)) * RMSEA(dof(fit), χ²(fit), nsamples(fit)) +function RMSEA(fit::SemFit, model::SemEnsemble) + check_single_lossfun(model; throw_error = true) + n = nsamples(fit)+model.n*rmsea_correction(model.sems[1].loss.functions[1]) + return sqrt(length(model.sems)) * RMSEA(dof(fit), χ²(fit), n) +end -function RMSEA(dof, chi2, nsamples) - rmsea = (chi2 - dof) / (nsamples * dof) - rmsea > 0 ? nothing : rmsea = 0 +function RMSEA(dof, chi2, N⁻) + rmsea = (chi2 - dof) / (N⁻ * dof) + rmsea = rmsea > 0 ? rmsea : 0 return sqrt(rmsea) end + +# scaling corrections +rmsea_correction(::SemFIML) = 0 +rmsea_correction(::SemML) = -1 +rmsea_correction(::SemWLS) = -1 diff --git a/src/frontend/fit/fitmeasures/chi2.jl b/src/frontend/fit/fitmeasures/chi2.jl index dc19467fc..9ebb06bd9 100644 --- a/src/frontend/fit/fitmeasures/chi2.jl +++ b/src/frontend/fit/fitmeasures/chi2.jl @@ -9,20 +9,21 @@ Return the χ² value. # Single Models ############################################################################################ -χ²(fit::SemFit, model::AbstractSemSingle) = - sum(loss -> χ²(loss, fit, model), model.loss.functions) +function χ²(fit::SemFit, model::AbstractSemSingle) + check_single_lossfun(model; throw_error = true) + return χ²(model.loss.functions[1], fit::SemFit, model::AbstractSemSingle) +end -# RAM + SemML -χ²(lossfun::SemML, fit::SemFit, model::AbstractSemSingle) = +χ²(::SemML, fit::SemFit, model::AbstractSemSingle) = (nsamples(fit) - 1) * - (fit.minimum - logdet(obs_cov(observed(model))) - nobserved_vars(observed(model))) + (fit.minimum - logdet(obs_cov(observed(model))) - nobserved_vars(model)) # bollen, p. 115, only correct for GLS weight matrix -χ²(lossfun::SemWLS, fit::SemFit, model::AbstractSemSingle) = +χ²(::SemWLS, fit::SemFit, model::AbstractSemSingle) = (nsamples(fit) - 1) * fit.minimum # FIML -function χ²(lossfun::SemFIML, fit::SemFit, model::AbstractSemSingle) +function χ²(::SemFIML, fit::SemFit, model::AbstractSemSingle) ll_H0 = minus2ll(fit) ll_H1 = minus2ll(observed(model)) return ll_H0 - ll_H1 @@ -32,38 +33,28 @@ end # Collections ############################################################################################ -function χ²(fit::SemFit, models::SemEnsemble) - isempty(models.sems) && return 0.0 - - lossfun = models.sems[1].loss.functions[1] - # check that all models use the same single loss function - L = typeof(lossfun) - for (i, sem) in enumerate(models.sems) - if length(sem.loss.functions) > 1 - @error "Model for group #$i has $(length(sem.loss.functions)) loss functions. Only the single one is supported" - end - cur_lossfun = sem.loss.functions[1] - if !isa(cur_lossfun, L) - @error "Loss function for group #$i model is $(typeof(cur_lossfun)), expected $L. Heterogeneous loss functions are not supported" - end - end - - return χ²(lossfun, fit, models) +function χ²(fit::SemFit, model::SemEnsemble) + check_single_lossfun(model; throw_error = true) + lossfun = model.sems[1].loss.functions[1] + return χ²(lossfun, fit, model) end -function χ²(lossfun::SemWLS, fit::SemFit, models::SemEnsemble) - return (nsamples(models) - 1) * fit.minimum +function χ²(::SemWLS, fit::SemFit, models::SemEnsemble) + return (nsamples(models) - models.n) * fit.minimum end -function χ²(lossfun::SemML, fit::SemFit, models::SemEnsemble) - G = sum(zip(models.weights, models.sems)) do (w, model) - data = observed(model) - w * (logdet(obs_cov(data)) + nobserved_vars(data)) +function χ²(::SemML, fit::SemFit, models::SemEnsemble) + F = 0 + for model in models.sems + Fᵢ = objective(model, fit.solution) + Fᵢ -= logdet(obs_cov(observed(model))) + nobserved_vars(model) + Fᵢ *= nsamples(model) - 1 + F += Fᵢ end - return (nsamples(models) - 1) * (fit.minimum - G) + return F end -function χ²(lossfun::SemFIML, fit::SemFit, models::SemEnsemble) +function χ²(::SemFIML, fit::SemFit, models::SemEnsemble) ll_H0 = minus2ll(fit) ll_H1 = sum(minus2ll ∘ observed, models.sems) return ll_H0 - ll_H1 diff --git a/src/frontend/fit/fitmeasures/minus2ll.jl b/src/frontend/fit/fitmeasures/minus2ll.jl index 9b211fb44..961822ef5 100644 --- a/src/frontend/fit/fitmeasures/minus2ll.jl +++ b/src/frontend/fit/fitmeasures/minus2ll.jl @@ -3,36 +3,31 @@ Return the negative 2* log likelihood. """ -function minus2ll end +minus2ll(fit::SemFit) = minus2ll(fit, fit.model) ############################################################################################ # Single Models ############################################################################################ -minus2ll(fit::SemFit) = minus2ll(fit, fit.model) - function minus2ll(fit::SemFit, model::AbstractSemSingle) - minimum = objective(model, fit.solution) - return minus2ll(minimum, model) + check_single_lossfun(model; throw_error = true) + F = objective(model, fit.solution) + return minus2ll(model.loss.functions[1], F, model) end -minus2ll(minimum::Number, model::AbstractSemSingle) = - sum(lossfun -> minus2ll(lossfun, minimum, model), model.loss.functions) - # SemML ------------------------------------------------------------------------------------ -function minus2ll(lossfun::SemML, minimum::Number, model::AbstractSemSingle) - obs = observed(model) - return nsamples(obs) * (minimum + log(2π) * nobserved_vars(obs)) +function minus2ll(::SemML, F, model::AbstractSemSingle) + return nsamples(model) * (F + log(2π) * nobserved_vars(model)) end # WLS -------------------------------------------------------------------------------------- -minus2ll(lossfun::SemWLS, minimum::Number, model::AbstractSemSingle) = missing +minus2ll(::SemWLS, F, ::AbstractSemSingle) = missing # compute likelihood for missing data - H0 ------------------------------------------------- -# -2ll = (∑ log(2π)*(nᵢ + mᵢ)) + F*n -function minus2ll(lossfun::SemFIML, minimum::Number, model::AbstractSemSingle) +# -2ll = (∑ log(2π)*(nᵢ*mᵢ)) + F*n +function minus2ll(::SemFIML, F, model::AbstractSemSingle) obs = observed(model)::SemObservedMissing - F = minimum * nsamples(obs) + F *= nsamples(obs) F += log(2π) * sum(pat -> nsamples(pat) * nmeasured_vars(pat), obs.patterns) return F end @@ -67,4 +62,7 @@ end # Collection ############################################################################################ -minus2ll(fit::SemFit, model::SemEnsemble) = sum(Base.Fix1(minus2ll, fit), model.sems) +function minus2ll(fit::SemFit, model::SemEnsemble) + check_single_lossfun(model; throw_error = true) + return sum(Base.Fix1(minus2ll, fit), model.sems) +end diff --git a/src/types.jl b/src/types.jl index 777165f37..3f695bfa3 100644 --- a/src/types.jl +++ b/src/types.jl @@ -192,10 +192,7 @@ end function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...) n = length(models) # default weights - if isnothing(weights) - nsamples_total = sum(nsamples, models) - weights = [nsamples(model) / nsamples_total for model in models] - end + weights = isnothing(weights) ? multigroup_weights(models, n) : weights # default group labels groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups # check parameters equality @@ -226,7 +223,25 @@ function SemEnsemble(; specification, data, groups, column = :group, kwargs...) model = Sem(; specification = ram_matrices, data = data_group, kwargs...) push!(models, model) end - return SemEnsemble(models...; weights = nothing, groups = groups, kwargs...) + return SemEnsemble(models...; groups = groups, kwargs...) +end + +function multigroup_weights(models, n) + nsamples_total = sum(nsamples, models) + uniform_lossfun = check_single_lossfun(models...; throw_error = false) + if !uniform_lossfun + @info "Your ensemble model contains heterogeneous loss functions. + Default weights of (#samples per group/#total samples) will be used". + return [(nsamples(model)) / (nsamples_total) for model in models] + end + lossfun = models[1].loss.functions[1] + if !applicable(mg_correction, lossfun) + @info "We don't know how to choose group weights for the specified loss function. + Default weights of (#samples per group/#total samples) will be used". + return [(nsamples(model)) / (nsamples_total) for model in models] + end + c = mg_correction(lossfun) + return [(nsamples(model)+c) / (nsamples_total+n*c) for model in models] end param_labels(ensemble::SemEnsemble) = ensemble.param_labels