Skip to content

Commit 76e6276

Browse files
committed
Update
1 parent ea9cdc4 commit 76e6276

8 files changed

Lines changed: 40 additions & 42 deletions

DESCRIPTION

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,8 @@ Imports:
2727
tidyr,
2828
rlang,
2929
Boruta,
30-
censored,
3130
survival,
32-
survminer,
33-
tidyverse,
34-
tidymodels,
35-
cowplot,
36-
recipes
31+
survminer
3732
Remotes:
3833
VeraPancaldiLab/multideconv,
3934
Suggests:
@@ -45,7 +40,12 @@ Suggests:
4540
glmnet,
4641
xgboost,
4742
kernlab,
48-
rmarkdown
43+
rmarkdown,
44+
tidyverse,
45+
tidymodels,
46+
cowplot,
47+
recipes,
48+
censored
4949
LazyData: true
5050
LazyDataCompression: bzip2
5151
VignetteBuilder: knitr

R/machine_learning.R

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@ utils::globalVariables(c(
1414
"Kappa_resample",
1515
"MAD_AUROC",
1616
"MAD_AUPRC",
17-
"MAD_Accuracy"
17+
"MAD_Accuracy",
18+
"c_index",
19+
"c_index_median",
20+
"Median_CINDEX",
21+
"MAD_CINDEX",
22+
"time",
23+
"event",
24+
".config_id",
25+
".pred",
26+
"n_resamples",
27+
"parameter_i"
1828
))
1929

2030
#' Compute Boruta algorithm
@@ -248,7 +258,7 @@ feature.selection.boruta <- function(data, iterations = NULL, fix = FALSE, tenta
248258
#'
249259
#' This function performs repeated stratified k-fold cross-validation on a dataset to train and tune hyperparameters for 13 machine learning methods. Optionally, it can also perform model stacking and Boruta-based feature selection. Performance is evaluated using user-specified metrics such as Accuracy, AUROC, or AUPRC.
250260
#'
251-
#' @param model A data frame containing features and a target column named 'target' corresponding to the response variable to predict.
261+
#' @param train_data A data frame containing features and a target column named 'target' corresponding to the response variable to predict.
252262
#' @param k_folds Integer. Number of folds for k-fold cross-validation. Default is 5.
253263
#' @param n_rep Integer. Number of repetitions of the k-fold cross-validation. Default is 100.
254264
#' @param stacking Logical. Whether to perform model stacking. Default is FALSE.
@@ -395,8 +405,7 @@ compute_k_fold_CV = function(train_data, k_folds, n_rep, stacking = FALSE, metri
395405
#If both are ON it can slower performance (lead to over-parallelization and CPU contention)
396406
trainControl <- caret::trainControl(index = multifolds, method="repeatedcv", number=k_folds, repeats=n_rep, verboseIter = F, allowParallel = F, classProbs = TRUE, savePredictions=T)
397407

398-
#invisible(utils::capture.output({fit.xgbTree <- caret::train(target~., data=train_data, method="xgbTree", metric = "Accuracy", trControl=trainControl)}, type = "output"))
399-
fit.xgbTree <- caret::train(target~., data=train_data, method="xgbTree", metric = "Accuracy", trControl=trainControl)
408+
invisible(utils::capture.output({fit.xgbTree <- caret::train(target~., data=train_data, method="xgbTree", metric = "Accuracy", trControl=trainControl)}, type = "output"))
400409

401410
parallel::stopCluster(cl) # stop the cluster after parallel execution
402411
unregister_dopar() #Stop Dopar from running in the background
@@ -3401,6 +3410,10 @@ model_boruta_selection <- function(model,
34013410
#' in \code{data}. Default is \code{"target"}.
34023411
#' @param cor_thresh A numeric value between 0 and 1 specifying the correlation
34033412
#' threshold for removing highly correlated features. Default is \code{0.9}.
3413+
#' @param time_var A character string specifying the name of the time-to-event
3414+
#' column in \code{data}. Used only for survival analysis.
3415+
#' @param event_var A character string specifying the name of the event indicator
3416+
#' column in \code{data} (e.g., 0/1). Used only for survival analysis.
34043417
#'
34053418
#' @details
34063419
#' The preprocessing steps include:
@@ -3513,18 +3526,14 @@ preprocess_features <- function(data,
35133526
#'
35143527
#' @param train_data A data frame containing the full training dataset,
35153528
#' including predictors and the target variable.
3516-
#' @param fold_data A list or object containing pre-constructed folds for
3517-
#' cross-validation, typically created by \code{fold_construction_fun}.
3529+
#' @param optimized An object returned by \code{compute_custom_k_fold_CV}
3530+
#' containing the optimized hyperparameters and cross-validation results.
35183531
#' @param ml_method A character string specifying the machine learning method
35193532
#' to be passed to \code{caret::train}.
35203533
#' @param fold_construction_fun A function used to (re)construct training
35213534
#' data partitions given the best hyperparameters.
35223535
#' @param fold_construction_args_fixed A named list of additional fixed arguments
35233536
#' to pass to \code{fold_construction_fun}.
3524-
#' @param tuneGrid (optional) A data frame of hyperparameter values to evaluate.
3525-
#' If \code{NULL}, defaults are used.
3526-
#' @param ncores (optional) Integer specifying the number of cores for parallel
3527-
#' processing during cross-validation. If \code{NULL}, defaults to serial execution.
35283537
#'
35293538
#' @details
35303539
#' The workflow proceeds in the following steps:
@@ -3675,9 +3684,6 @@ wrapper_train_best_hyperparams_classification <- function(train_data, optimized,
36753684
#' function (e.g., CellTFusion outputs or parameter tables).}
36763685
#' }
36773686
#'
3678-
#' @seealso [compute_k_fold_CV_survival()], [aggregate_results_survival()],
3679-
#' [compute_ml_survival()]
3680-
#'
36813687
#' @export
36823688
#'
36833689
wrapper_train_best_hyperparams_survival <- function(train_data,
@@ -3841,7 +3847,6 @@ wrapper_train_best_hyperparams_survival <- function(train_data,
38413847
#' \item{`Resample_matrix`}{Fold-level metrics for the best configuration.}
38423848
#' }
38433849
#'
3844-
#' @seealso [compute_k_fold_CV_survival()], [calculate_accuracy_kappa_resample()]
38453850
#'
38463851
#' @export
38473852
#'
@@ -4191,6 +4196,7 @@ get_tune_grid = function(method, train_data){
41914196
#' sampled at each split in tree-based models.
41924197
#' @param levels Integer specifying how many values to generate per hyperparameter.
41934198
#' Defaults to \code{5}. Must be at least 2.
4199+
#' @param v Integer. Number of folds for K-fold cross-validation (default = 5).
41944200
#'
41954201
#' @return A named list of hyperparameter grids.
41964202
#' Each element is a numeric vector of sampled values for that parameter.
@@ -4554,9 +4560,6 @@ compute_ml_survival <- function(df_train, df_test,
45544560
#' \item{`Custom_output`}{Optional list of custom outputs from fold construction.}
45554561
#' }
45564562
#'
4557-
#' @seealso [aggregate_results()], [compute_cv_CINDEX()],
4558-
#' [wrapper_train_best_hyperparams_survival()]
4559-
#'
45604563
#' @export
45614564
#'
45624565
compute_k_fold_CV_survival <- function(df_features, df_outcome, outcome_col, event_col, k_folds, n_rep, ncores,

man/aggregate_results.Rd

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/compute_k_fold_CV.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/compute_k_fold_CV_survival.Rd

Lines changed: 0 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_default_hyperparams.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/preprocess_features.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/wrapper_train_best_hyperparams_classification.Rd

Lines changed: 3 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)