From e58aa4cdd4faf634cf72e0540c6a7076b04b02df Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Wed, 1 Apr 2026 17:52:17 +0300 Subject: [PATCH 01/10] implemented the generic ensemble subsystem. For now supports DecisionTree, RandomForest and kNN only. 7 new methods, a lot of tests --- src/ensemble/generic_ensemble.rs | 1074 ++++++++++++++++++++++++++++++ src/ensemble/mod.rs | 28 + 2 files changed, 1102 insertions(+) create mode 100644 src/ensemble/generic_ensemble.rs diff --git a/src/ensemble/generic_ensemble.rs b/src/ensemble/generic_ensemble.rs new file mode 100644 index 00000000..8078af75 --- /dev/null +++ b/src/ensemble/generic_ensemble.rs @@ -0,0 +1,1074 @@ +use std::collections::HashMap; +use std::marker::PhantomData; + +use crate::api::Predictor; +use crate::error::Failed; +use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::metrics::accuracy; + +// ----------------------------------------------------------------------------- +// Some basic structures +// ----------------------------------------------------------------------------- + +/// Strategy for aggregating votes from ensemble members. +/// +/// Determines how individual model predictions are combined into +/// a final ensemble prediction. +/// +/// # Usage +/// ``` +/// // Uniform: each model gets 1 vote +/// let ens = Ensemble::with_strategy(VotingStrategy::Uniform); +/// +/// // Weighted: assign confidence scores to models +/// let mut ens = Ensemble::with_strategy(VotingStrategy::Weighted); +/// ens.add_with_params(None, model, Some(2.0), None, vec![])?; +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VotingStrategy { + /// Simple majority voting. Each member contributes 1 vote per prediction. + /// The `weight` field in members is ignored. + Uniform, + /// Weighted voting. Each member's vote is multiplied by its `weight`. + /// Final score for a class = sum of (weight * vote) across members. 
+ /// + /// # Constraints + /// * All members must have `weight: Some(f64)` when using this strategy. + /// * Weights must be finite and non-negative (enforced at insertion). + Weighted, +} + +impl Default for VotingStrategy { + fn default() -> Self { + VotingStrategy::Uniform + } +} + +/// Summary information about the ensemble configuration. +/// +/// Returned by [`Ensemble::get_ensemble_info`]. Use this to inspect +/// the current state of the ensemble without accessing internal fields. +/// +/// # Example +/// ``` +/// let ensemble = Ensemble::, Vec>::new(); +/// let info = ensemble.get_ensemble_info(); +/// assert_eq!(info.total_members, 0); +/// ``` +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq)] +pub struct EnsembleInfo { + pub strategy: VotingStrategy, + pub total_members: usize, + pub enabled_members: usize, + pub uses_weighted_voting: bool, +} + +// ----------------------------------------------------------------------------- +// Ensemble Member +// ----------------------------------------------------------------------------- + +/// Container for a model and its metadata within an ensemble. +/// +/// This struct wraps a predictive model along with voting weight, +/// description, enabled state, and tags. It is managed internally +/// by [`Ensemble`] and not intended for direct construction. +/// +/// # Type Parameters +/// * `X` - Input feature type (must implement `Array2`) +/// * `Y` - Label type (must implement `Array1 + Clone`) +struct EnsembleMember { + /// The underlying predictive model, boxed as a trait object. + pub model: Box>, + + /// Optional weight for voting. Used only if strategy is `Weighted`. + pub weight: Option, + + /// Optional human-readable description for documentation/debugging. + pub description: Option, + + /// Whether the model is enabled for inference. Disabled models + /// are skipped during prediction but retained in the ensemble. + pub is_enabled: bool, + + /// Tags for grouping/filtering models. 
Empty by default. + /// Reserved for future extensibility. + pub tags: Vec, +} + +impl EnsembleMember { + + // TODO We'll use it later, maybe someone + /// Check if this member has a specific tag. + #[allow(dead_code)] + fn has_tag(&self, tag: &str) -> bool { + self.tags.iter().any(|t| t == tag) + } +} + +impl std::fmt::Debug for EnsembleMember +where + X: Array2, + Y: Array1 + Clone, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EnsembleMember") + .field("weight", &self.weight) + .field("description", &self.description) + .field("is_enabled", &self.is_enabled) + .field("tags", &self.tags) + .field("model", &"") // Dummy placeholder + .finish() + } +} + +// ----------------------------------------------------------------------------- +// Ensemble Structure +// ----------------------------------------------------------------------------- + +/// A voting ensemble for classification models. +/// +/// Aggregates predictions from multiple `Predictor` instances using +/// hard voting (majority/weighted) via score aggregation. +/// # Type Parameters +/// * `X` - Input data type (e.g., `Array2` for feature vectors) +/// * `Y` - Label type (e.g., `Array1` for class labels) +/// +/// # Constraints +/// * All models must predict the same label type (`i32`). +/// * Input `x` to `predict` methods should represent a single sample. +pub struct Ensemble +where + X: Array2, + Y: Array1 + Clone, +{ + /// Registered ensemble members. + members: HashMap>, + + /// Active voting strategy. + strategy: VotingStrategy, + + /// Counter for auto-generated names. + counter: usize, + + _phantom: PhantomData<(X, Y)>, +} + +// ----------------------------------------------------------------------------- +// Implementation +// ----------------------------------------------------------------------------- + +impl Ensemble +where + X: Array2, + Y: Array1 + Clone, +{ + /// Creates a new empty ensemble with `Uniform` voting strategy. 
+ /// + /// # Type Inference + /// Rust can usually infer `X` and `Y` from the first model you add: + /// ``` + /// let mut ens = Ensemble::new(); + /// ens.add(knn_model)?; // X, Y inferred from knn_model + /// ``` + /// + /// If inference fails, specify types explicitly: + /// ``` + /// let mut ens: Ensemble, Vec> = Ensemble::new(); + /// ``` + pub fn new() -> Self { + Self { + members: HashMap::new(), + strategy: VotingStrategy::default(), + counter: 0, + _phantom: PhantomData, + } + } + + /// Creates a new ensemble with a specific voting strategy. + /// + /// # Arguments + /// * `strategy` - `Uniform` for simple majority, `Weighted` for confidence-based voting. + /// + /// # Example + /// ``` + /// let mut ens = Ensemble::with_strategy(VotingStrategy::Weighted); + /// // Now you must provide weights when adding models + /// ens.add_with_params(None, model, Some(1.0), None, vec![])?; + /// ``` + pub fn with_strategy(strategy: VotingStrategy) -> Self { + Self { strategy, ..Self::new() } + } + + // ------------------------------------------------------------------------- + // Add + // ------------------------------------------------------------------------- + + /// Convenience method: adds a model with auto-generated name and default metadata. + /// + /// Uses `Uniform` voting weight (1.0). For advanced options, use `add_with_params`. + /// + /// # Returns + /// The auto-generated name of the added member, or error if addition failed. + /// + /// # Example + /// ``` + /// let mut ensemble = Ensemble::new(); + /// let model_name = ensemble.add(knn_model)?; + /// println!("Added: {}", model_name); + /// ``` + pub fn add

(&mut self, model: P) -> Result + where + P: Predictor + 'static, + { + self.add_with_params(None, model, None, None, vec![]) + } + + /// Convenience method: adds a model with a custom name and default metadata. + /// + /// Equivalent to `add_with_params(Some(name), model, None, None, vec![])`. + /// + /// # Arguments + /// * `name` - Unique identifier for the model. Must not already exist in the ensemble. + /// * `model` - Any type implementing `Predictor`. + /// + /// # Returns + /// * `Ok(String)` — The name of the added member (same as input `name`). + /// * `Err(Failed)` — If name already exists or addition failed. + /// + /// # Example + /// ``` + /// let mut ensemble = Ensemble::new(); + /// ensemble.add_named("knn_k3".into(), knn_model)?; + /// ensemble.add_named("rf_depth10".into(), rf_model)?; + /// assert!(ensemble.names().contains(&"knn_k3".to_string())); + /// ``` + pub fn add_named

(&mut self, name: String, model: P) -> Result + where + P: Predictor + 'static, + { + self.add_with_params(Some(name), model, None, None, vec![]) + } + + /// Adds a model to the ensemble with optional metadata. + /// + /// # Arguments + /// * `name` - Optional unique identifier. Auto-generated if None. + /// * `model` - Any type implementing Predictor. + /// * `weight` - Required if strategy is Weighted. + /// + /// # Example + /// ``` + /// ensemble.add_with_params( + /// Some("rf_v1".into()), + /// random_forest_model, + /// Some(0.8), + /// Some("RF with depth=10".into()), + /// vec!["tree".into()] + /// )?; + /// ``` + pub fn add_with_params

( + &mut self, + name: Option, + model: P, + weight: Option, + description: Option, + tags: Vec, + ) -> Result + where P: Predictor + 'static + { + let final_name = name.unwrap_or_else(|| self.generate_auto_name()); + + if self.members.contains_key(&final_name) { + return Err(Failed::input("Duplicate member name")); + } + + if matches!(self.strategy, VotingStrategy::Weighted) { + match weight { + Some(w) if w.is_finite() && w >= 0.0 => {} + _ => return Err(Failed::input("Invalid weight")), + } + } + + let is_enabled = true; + + self.members.insert( + final_name.clone(), + EnsembleMember { + model: Box::new(model), + weight, + description, + is_enabled, + tags, + }, + ); + + Ok(final_name) + } + + // ------------------------------------------------------------------------- + // Metadata Management + // ------------------------------------------------------------------------- + + // Get + + /// Returns the total number of registered members. + /// + /// Includes both enabled and disabled models. + /// Use [`enabled()`](Self::enabled) to count only active models. + /// + /// # Example + /// ``` + /// assert_eq!(ensemble.len(), 0); + /// ensemble.add(model)?; + /// assert_eq!(ensemble.len(), 1); + /// ``` + pub fn len(&self) -> usize { self.members.len() } + + /// Returns names of all registered members (enabled and disabled). + /// + /// Order is arbitrary (HashMap iteration order). + /// + /// # Example + /// ``` + /// let names = ensemble.names(); + /// assert!(names.contains(&"my_model".to_string())); + /// ``` + pub fn names(&self) -> Vec { + self.members.keys().cloned().collect() + } + + /// Returns `true` if the ensemble has no registered members. + /// + /// # Example + /// ``` + /// let ens = Ensemble::<_, _>::new(); + /// assert!(ens.is_empty()); + /// ``` + pub fn is_empty(&self) -> bool { self.members.is_empty() } + + /// Returns the current voting strategy. 
+ /// + /// # Example + /// ``` + /// let ens = Ensemble::with_strategy(VotingStrategy::Weighted); + /// assert_eq!(ens.strategy(), VotingStrategy::Weighted); + /// ``` + pub fn strategy(&self) -> VotingStrategy { self.strategy } + + /// Returns the voting weight for a specific member. + /// + /// # Returns + /// * `Some(weight)` if strategy is `Weighted` and member exists + /// * `None` if strategy is `Uniform` OR member not found + /// + /// # Example + /// ``` + /// let w = ensemble.weight("my_model"); + /// if let Some(weight) = w { + /// println!("Weight: {}", weight); + /// } + /// ``` + pub fn weight(&self, name: &str) -> Option { + if !self.members.contains_key(name) { + return None; + } + + match self.strategy { + VotingStrategy::Uniform => None, + VotingStrategy::Weighted => { self.members.get(name).and_then(|member| member.weight) } + } + } + + /// Returns summary information about the ensemble configuration. + /// + /// Does not include per-model hyperparameters (use [`get_model_metadata`](Self::get_model_metadata) for that). + /// + /// # Example + /// ``` + /// let info = ensemble.get_ensemble_info(); + /// println!("Strategy: {:?}", info.strategy); + /// println!("Active models: {}/{}", info.enabled_members, info.total_members); + /// ``` + pub fn get_ensemble_info(&self) -> EnsembleInfo { + EnsembleInfo { + strategy: self.strategy, + total_members: self.members.len(), + enabled_members: self.enabled().len(), + uses_weighted_voting: matches!(self.strategy, VotingStrategy::Weighted), + } + } + + // Set + + /// Updates the voting weight of an existing member. 
+ /// + /// # Arguments + /// * `name` - Name of the member to update + /// * `weight` - New weight value + /// + /// # Constraints + /// * Member must exist + /// * If strategy is `Weighted`, weight must be finite and non-negative + /// + /// # Errors + /// * `Failed::input` if member not found or weight invalid + /// + /// # Example + /// ``` + /// ensemble.set_weight("strong_model", 2.0)?; + /// ``` + pub fn set_weight(&mut self, name: &str, weight: f64) -> Result<(), Failed> { + if matches!(self.strategy, VotingStrategy::Weighted) { + if !weight.is_finite() || weight < 0.0 { + return Err(Failed::input("Weight must be finite and non-negative")); + } + } + let member = self.members.get_mut(name) + .ok_or_else(|| Failed::input(&format!("Member '{}' not found", name)))?; + member.weight = Some(weight); + Ok(()) + } + + /// Updates the human-readable description of a member. + /// + /// Useful for documentation, debugging, or UI display. + /// + /// # Arguments + /// * `name` - Name of the member + /// * `desc` - New description string + /// + /// # Example + /// ``` + /// ensemble.set_description("rf_v1", "Random Forest, depth=10, trained on Q1 data")?; + /// ``` + pub fn set_description(&mut self, name: &str, desc: String) -> Result<(), Failed> { + let member = self.members.get_mut(name) + .ok_or_else(|| Failed::input(&format!("Member '{}' not found", name)))?; + member.description = Some(desc); + Ok(()) + } + + /// Changes the voting strategy for the ensemble. + /// + /// # Arguments + /// * `strategy` - New strategy (`Uniform` or `Weighted`) + /// + /// # Constraints + /// * If switching to `Weighted`, all members must already have a weight set + /// + /// # Errors + /// * `Failed::input` if switching to `Weighted` and any member lacks a weight + /// + /// # Example + /// ``` + /// // Start with Uniform + /// let mut ens = Ensemble::new(); + /// ens.add(model1)?; + /// ens.add(model2)?; + /// + /// // Switch to Weighted — must set weights first! 
+ /// ens.set_weight("model_0", 1.0)?; + /// ens.set_weight("model_1", 2.0)?; + /// ens.set_voting_strategy(VotingStrategy::Weighted)?; + /// ``` + pub fn set_voting_strategy(&mut self, strategy: VotingStrategy) -> Result<(), Failed> { + if matches!(strategy, VotingStrategy::Weighted) { + for member in self.members.values() { + if member.weight.is_none() { + return Err(Failed::input("All members must have weights when using weighted strategy")); + } + } + self.strategy = strategy; + + return Ok(()); + } + + if matches!(strategy, VotingStrategy::Uniform) { + self.strategy = strategy; + return Ok(()); + } + + return Err(Failed::input("Invalid voting strategy")); + } + + // ------------------------------------------------------------------------- + // Model Management + // ------------------------------------------------------------------------- + + /// Temporarily excludes a member from prediction without removing it. + /// + /// Disabled models are skipped during `predict()` and `score()`, + /// but remain in the ensemble and can be re-enabled later. + /// + /// # Arguments + /// * `name` - Name of the member to disable + /// + /// # Errors + /// * `Failed::input` if member not found or already disabled + /// + /// # Example + /// ``` + /// ensemble.disable("underperforming_model")?; + /// let preds = ensemble.predict(&x)?; // model excluded from voting + /// ``` + pub fn disable(&mut self, name: &str) -> Result<(), Failed> { + if let Some(member) = self.members.get_mut(name) { + if member.is_enabled { + member.is_enabled = false; + return Ok(()); + } + return Err(Failed::input("Model is already disabled")); + } + return Err(Failed::input("Model not found")) + } + + /// Re-includes a previously disabled member in prediction. 
+ /// + /// # Arguments + /// * `name` - Name of the member to enable + /// + /// # Errors + /// * `Failed::input` if member not found or already enabled + /// + /// # Example + /// ``` + /// ensemble.enable("previously_disabled_model")?; + /// ``` + pub fn enable(&mut self, name: &str) -> Result<(), Failed> { + if let Some(member) = self.members.get_mut(name) { + if !member.is_enabled { + member.is_enabled = true; + return Ok(()); + } + return Err(Failed::input("Model is already enabled")); + } + return Err(Failed::input("Model not found")) + } + + /// Returns names of all currently enabled members. + /// + /// Useful for debugging, logging, or selective inspection. + /// + /// # Example + /// ``` + /// let active = ensemble.enabled(); + /// println!("Active models: {:?}", active); + /// ``` + pub fn enabled(&self) -> Vec { + self.members.iter() + .filter(|(_, member)| member.is_enabled) + .map(|(name, _)| name.clone()) + .collect() + } + + // ------------------------------------------------------------------------- + // Predictions, regular and distributed, and a simple scoring + // ------------------------------------------------------------------------- + + /// Predicts class labels for input samples using ensemble voting. + /// + /// All enabled members receive the **same** input `x` and contribute + /// votes according to the active [`VotingStrategy`]. 
+ /// + /// # Arguments + /// * `x` - Input feature matrix (samples × features) + /// + /// # Returns + /// * `Ok(Y)` - Vector of predicted class labels + /// * `Err(Failed)` - If ensemble is empty or prediction fails + /// + /// # Example + /// ``` + /// let predictions = ensemble.predict(&x_test)?; + /// let acc = accuracy(&y_test, &predictions); + /// ``` + /// + /// # See Also + /// * [`predict_using_names()`](Self::predict_using_names) — for feature-slicing scenarios + /// * [`score()`](Self::score) — for quick accuracy evaluation + pub fn predict(&self, x: &X) -> Result { + if self.members.is_empty() { + return Err(Failed::predict("Empty ensemble")); + } + + let enabled: Vec = self.enabled(); + + let mut all_preds: Vec = Vec::with_capacity(enabled.len()); + + for model_name in &enabled { + let model = self.members.get(model_name).unwrap(); // TODO unwrap is safe here, but maybe "if let" will just look better + let pred = model.model.predict(x)?; + all_preds.push(pred); + } + + let n_samples = all_preds[0].shape(); + let mut result = Y::zeros(n_samples); + + for i in 0..n_samples { + let mut scores: HashMap = HashMap::new(); + + for (model_name, preds) in enabled.iter().zip(all_preds.iter()) { + let class = + + +*preds.get(i); + + let vote = match (self.strategy, self.weight(model_name)) { + (VotingStrategy::Uniform, _) => 1.0, + (VotingStrategy::Weighted, Some(w)) => w, + (VotingStrategy::Weighted, None) => { + return Err(Failed::predict("Missing weight")) + } + }; + + *scores.entry(class).or_insert(0.0) += vote; + } + + // argmax + let (best_class, _) = scores + .into_iter() + .max_by(|a, b| a.1.total_cmp(&b.1)) + .ok_or_else(|| Failed::predict("No votes"))?; + + result.set(i, best_class); + } + + Ok(result) + } + + /// Computes accuracy score on given data. + /// + /// Equivalent to sklearn `accuracy(y, self.predict(x))`. 
+ pub fn score(&self, x: &X, y: &Y) -> Result + where + Y: Array1 + { + let preds = self.predict(x)?; + Ok(accuracy(y, &preds)) + } + + /// Predicts using per-model input subsets (feature slicing). + /// + /// Each ensemble member receives its own input from the `inputs` map, + /// keyed by member name. Useful when models are trained on different + /// feature subsets or representations. + /// + /// # Arguments + /// * `inputs` - HashMap mapping member names to their specific input matrices + /// + /// # Returns + /// * `Ok(Y)` - Aggregated predictions via voting + /// * `Err(Failed)` - If any required input is missing or prediction fails + /// + /// # Example + /// ``` + /// // Models trained on different feature slices + /// let mut inputs = HashMap::new(); + /// inputs.insert("slice_A".into(), x_slice_a); + /// inputs.insert("slice_B".into(), x_slice_b); + /// + /// let predictions = ensemble.predict_using_names(&inputs)?; + /// ``` + /// + /// # Constraints + /// * Every enabled member must have a corresponding entry in `inputs` + /// * All input matrices must have the same number of samples + pub fn predict_using_names(&self, inputs: &HashMap) -> Result { + if self.members.is_empty() { + return Err(Failed::predict("Empty ensemble")); + } + + let mut all_preds: Vec<(String, Y)> = Vec::new(); + + for (name, member) in &self.members { + if !member.is_enabled { + continue; + } + + let x = inputs.get(name) + .ok_or_else(|| Failed::input("Missing X set in input"))?; + + let pred = member.model.predict(x)?; + all_preds.push((name.clone(), pred)); + } + + let n_samples = all_preds[0].1.shape(); + let mut result = Y::zeros(n_samples); + + for i in 0..n_samples { + let mut scores: HashMap = HashMap::new(); + + for (name, preds) in &all_preds { + let member = &self.members[name]; + let class = *preds.get(i); + + let vote = match (self.strategy, member.weight) { + (VotingStrategy::Uniform, _) => 1.0, + (VotingStrategy::Weighted, Some(w)) => w, + (VotingStrategy::Weighted, 
None) => { + return Err(Failed::predict("Missing weight")) + } + }; + + *scores.entry(class).or_insert(0.0) += vote; + } + + let (best_class, _) = scores + .into_iter() + .max_by(|a, b| a.1.total_cmp(&b.1)) + .ok_or_else(|| Failed::predict("No votes"))?; + + result.set(i, best_class); + } + + Ok(result) + } + + // ------------------------------------------------------------------------- + // Utilities + // ------------------------------------------------------------------------- + + fn generate_auto_name(&mut self) -> String { + let name = format!("model_{}", self.counter); + self.counter += 1; + name + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::basic::matrix::DenseMatrix; + use crate::neighbors::knn_classifier::{KNNClassifier, KNNClassifierParameters}; + use crate::ensemble::random_forest_classifier::{RandomForestClassifier, RandomForestClassifierParameters}; + use crate::linalg::basic::arrays::Array2; + + fn dummy_data_2class() -> (DenseMatrix, Vec) { + // 6 samples, 2 features, balanced classes (3 vs 3) + let x = DenseMatrix::from_2d_vec(&vec![ + vec![1.0, 1.0], // Class 0 + vec![1.5, 1.2], // Class 0 + vec![2.0, 1.5], // Class 0 + vec![4.0, 4.0], // Class 1 + vec![4.5, 4.2], // Class 1 + vec![5.0, 4.5], // Class 1 + ]).unwrap(); + let y = vec![0, 0, 0, 1, 1, 1]; + (x, y) + } + + // Test 1: Simple add() with auto-generated names + #[test] + fn test_add_simple_knn_models() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::new(); + + let knn1 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + let knn2 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap(); + + let name1 = ensemble.add(knn1).unwrap(); + let name2 = ensemble.add(knn2).unwrap(); + + assert_eq!(ensemble.len(), 2); + assert_eq!(name1, "model_0"); + assert_eq!(name2, "model_1"); + + let names = ensemble.names(); + assert!(names.contains(&"model_0".to_string())); + 
assert!(names.contains(&"model_1".to_string())); + } + + // Test 2: Heterogeneous ensemble via add() — KNN + RF + #[test] + fn test_add_heterogeneous_models() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::new(); + + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(5)).unwrap(); + + let name_knn = ensemble.add(knn).unwrap(); + let name_rf = ensemble.add(rf).unwrap(); + + assert_eq!(ensemble.len(), 2); + let names = ensemble.names(); + assert_eq!(names.len(), 2); + assert!(names.contains(&name_knn)); + assert!(names.contains(&name_rf)); + } + + // Test 3: add_named() with custom names + #[test] + fn test_add_named_custom_names() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::new(); + + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(5)).unwrap(); + + let name1 = ensemble.add_named("my_knn".into(), knn).unwrap(); + let name2 = ensemble.add_named("my_rf".into(), rf).unwrap(); + + assert_eq!(name1, "my_knn"); + assert_eq!(name2, "my_rf"); + assert_eq!(ensemble.len(), 2); + + let names = ensemble.names(); + assert!(names.contains(&"my_knn".to_string())); + assert!(names.contains(&"my_rf".to_string())); + } + + // Test 4: Error on duplicate name via add_named() + #[test] + fn test_error_duplicate_name() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::new(); + + let knn1 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + let knn2 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + + ensemble.add_named("same_name".into(), knn1).unwrap(); + + let result = ensemble.add_named("same_name".into(), knn2); + 
assert!(result.is_err()); + } + + // Test 5: Weighted voting with explicit weights + #[test] + fn test_weighted_voting_with_weights() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); + + let knn_weak = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + let knn_strong = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap(); + + // Add with different weights + ensemble.add_with_params(Some("weak".into()), knn_weak, Some(0.5), None, vec![]).unwrap(); + ensemble.add_with_params(Some("strong".into()), knn_strong, Some(2.0), None, vec![]).unwrap(); + + // Predict should work without error + let preds = ensemble.predict(&x).unwrap(); + assert_eq!(preds.len(), y.len()); + + // Score should be in valid range + let score = ensemble.score(&x, &y).unwrap(); + assert!((0.0..=1.0).contains(&score)); + } + + // Test 6: Error when adding without weight in Weighted mode + #[test] + fn test_error_missing_weight_in_weighted_mode() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); + + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + + // add() does not provide weight → should fail in Weighted mode + let result = ensemble.add(knn); + assert!(result.is_err()); + } + + // Test 7: predict_using_names() with feature slicing + #[test] + fn test_predict_using_names_feature_slicing() { + let (x_full, y) = dummy_data_2class(); // y имеет длину 6 + let mut ensemble = Ensemble::, Vec>::new(); + + // Model A: trained on feature 0 only (must have 6 rows to match y!) 
+ let x_a_train = DenseMatrix::from_2d_vec(&vec![ + vec![1.0], // sample 0, feat 0 + vec![1.5], // sample 1, feat 0 + vec![2.0], // sample 2, feat 0 + vec![4.0], // sample 3, feat 0 + vec![4.5], // sample 4, feat 0 + vec![5.0], // sample 5, feat 0 + ]).unwrap(); + let knn_a = KNNClassifier::fit(&x_a_train, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + + // Model B: trained on feature 1 only (must have 6 rows to match y!) + let x_b_train = DenseMatrix::from_2d_vec(&vec![ + vec![1.0], // sample 0, feat 1 + vec![1.2], // sample 1, feat 1 + vec![1.5], // sample 2, feat 1 + vec![4.0], // sample 3, feat 1 + vec![4.2], // sample 4, feat 1 + vec![4.5], // sample 5, feat 1 + ]).unwrap(); + let knn_b = KNNClassifier::fit(&x_b_train, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + + ensemble.add_named("model_A".into(), knn_a).unwrap(); + ensemble.add_named("model_B".into(), knn_b).unwrap(); + + // Prepare per-model inputs for prediction (2 test samples) + let mut inputs = HashMap::new(); + let x_a_test = DenseMatrix::from_2d_vec(&vec![vec![1.8], vec![4.3]]).unwrap(); + let x_b_test = DenseMatrix::from_2d_vec(&vec![vec![1.6], vec![4.4]]).unwrap(); + inputs.insert("model_A".into(), x_a_test); + inputs.insert("model_B".into(), x_b_test); + + // Predict with per-model inputs + let preds = ensemble.predict_using_names(&inputs).unwrap(); + assert_eq!(preds.len(), 2); // 2 test samples + } + + // Test 8: enable/disable affects prediction + #[test] + fn test_enable_disable_affects_prediction() { + let (x_train, y_train) = dummy_data_2class(); + let (x_test, y_test) = dummy_data_2class(); + + let mut ensemble = Ensemble::, Vec>::new(); + + let knn1 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3)).unwrap(); + let knn2 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(5)).unwrap(); + + ensemble.add_named("k1".into(), knn1).unwrap(); + ensemble.add_named("k3".into(), knn2).unwrap(); + + 
// Score with both models + let score_both = ensemble.score(&x_test, &y_test).unwrap(); + + // Disable one model + ensemble.disable("k3").unwrap(); + let score_one = ensemble.score(&x_test, &y_test).unwrap(); + + // Both scores should be valid; they may differ + assert!((0.0..=1.0).contains(&score_both)); + assert!((0.0..=1.0).contains(&score_one)); + + // Re-enable and verify count + ensemble.enable("k3").unwrap(); + assert_eq!(ensemble.enabled().len(), 2); + } + + // Test 9: get_ensemble_info() reflects state + #[test] + fn test_get_ensemble_info() { + let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); + + let info = ensemble.get_ensemble_info(); + assert_eq!(info.strategy, VotingStrategy::Weighted); + assert_eq!(info.total_members, 0); + assert_eq!(info.enabled_members, 0); + assert!(info.uses_weighted_voting); + + // Add a model + let (x, y) = dummy_data_2class(); + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + ensemble.add_with_params(None, knn, Some(1.0), None, vec![]).unwrap(); + + let info2 = ensemble.get_ensemble_info(); + assert_eq!(info2.total_members, 1); + assert_eq!(info2.enabled_members, 1); + } + + // Test 10: All supported classifiers (smoke test) — 3 models + #[test] + fn test_add_all_classifier_types() { + use crate::tree::decision_tree_classifier::DecisionTreeClassifier; + + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::new(); + + // 1. KNN + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + ensemble.add_named("knn_k3".into(), knn).unwrap(); + + // 2. Random Forest + let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(3)).unwrap(); + ensemble.add_named("rf_3trees".into(), rf).unwrap(); + + // 3. 
Decision Tree + let dt = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); + ensemble.add_named("decision_tree".into(), dt).unwrap(); + + // All 3 models are active + assert_eq!(ensemble.len(), 3); + assert_eq!(ensemble.enabled().len(), 3); + + // Name check + let names = ensemble.names(); + assert!(names.contains(&"knn_k3".to_string())); + assert!(names.contains(&"rf_3trees".to_string())); + assert!(names.contains(&"decision_tree".to_string())); + + // The most beautiful thing - predict() on 3 different models + let preds = ensemble.predict(&x).unwrap(); + assert_eq!(preds.len(), y.len()); + + // Score must be valid + let score = ensemble.score(&x, &y).unwrap(); + assert!((0.0..=1.0).contains(&score)); + } + + // Test 10a: add_with_params() with custom names + #[test] + fn test_add_with_custom_names() { + let (x, y) = dummy_data_2class(); + let mut ensemble = Ensemble::, Vec>::new(); + + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(5)).unwrap(); + + let name1 = ensemble.add_with_params( + Some("my_knn".into()), knn, None, Some("k=1".into()), vec!["fast".into()] + ).unwrap(); + let name2 = ensemble.add_with_params( + Some("my_rf".into()), rf, None, Some("5 trees".into()), vec!["tree".into()] + ).unwrap(); + + assert_eq!(name1, "my_knn"); + assert_eq!(name2, "my_rf"); + assert_eq!(ensemble.len(), 2); + // Check metadata + // TODO To be implemented one day + } + + // Test 10b: score() is still valid after adding a model + #[test] + fn test_score_with_increasing_models() { + let (x_train, y_train) = dummy_data_2class(); + let (x_test, y_test) = dummy_data_2class(); + + let mut ensemble = Ensemble::, Vec>::new(); + + // Add first model + let knn1 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(2)).unwrap(); + ensemble.add(knn1).unwrap(); + let score1 = 
ensemble.score(&x_test, &y_test).unwrap(); + assert!((0.0..=1.0).contains(&score1)); + + // Add second model + let knn2 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3)).unwrap(); + ensemble.add(knn2).unwrap(); + let score2 = ensemble.score(&x_test, &y_test).unwrap(); + + // Score may go up or down — just ensure it's valid + assert!((0.0..=1.0).contains(&score2)); + } + + // Test 10c: score() is still valid after disabling a model + #[test] + fn test_score_after_disable() { + let (x_train, y_train) = dummy_data_2class(); + let (x_test, y_test) = dummy_data_2class(); + + let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Uniform); + + let knn1 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(2)).unwrap(); + let knn2 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3)).unwrap(); + + ensemble.add_with_params(Some("k1".into()), knn1, None, None, vec![]).unwrap(); + ensemble.add_with_params(Some("k3".into()), knn2, None, None, vec![]).unwrap(); + + let score_before = ensemble.score(&x_test, &y_test).unwrap(); + + ensemble.disable("k3").unwrap(); + let score_after = ensemble.score(&x_test, &y_test).unwrap(); + + // Scores can differ; just ensure both are valid + assert!((0.0..=1.0).contains(&score_before)); + assert!((0.0..=1.0).contains(&score_after)); + } +} diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs index 4f5eefc5..c0be2e70 100644 --- a/src/ensemble/mod.rs +++ b/src/ensemble/mod.rs @@ -22,3 +22,31 @@ pub mod extra_trees_regressor; pub mod random_forest_classifier; /// Random forest regressor pub mod random_forest_regressor; + +/// Generic voting ensemble for classification models. +/// +/// This module provides the [`Ensemble`] struct, which aggregates predictions +/// from multiple [`Predictor`] implementations using hard voting (uniform or weighted). 
+///
+/// # Quick Start
+/// ```
+/// use smartcore::ensemble::generic_ensemble::{Ensemble, VotingStrategy};
+///
+/// // Create ensemble
+/// let mut ensemble = Ensemble::new();
+///
+/// // Add models (any type implementing Predictor)
+/// ensemble.add(knn_model)?;
+/// ensemble.add(rf_model)?;
+///
+/// // Predict
+/// let predictions = ensemble.predict(&x_test)?;
+/// ```
+///
+/// # Features
+/// * Heterogeneous ensembles: KNN, Random Forest, Decision Tree are now the only supported models. The rest are on their way
+/// * Uniform or weighted voting strategies
+/// * Dynamic enable/disable of members at runtime
+/// * Meta descriptions, tags, weights
+/// * Feature slicing via `predict_using_names()`
+pub mod generic_ensemble;
\ No newline at end of file

From 8c2cf6dcc4503ecb06c678b16d375da87613a7d0 Mon Sep 17 00:00:00 2001
From: Andrey Shevchenko
Date: Wed, 1 Apr 2026 18:00:43 +0300
Subject: [PATCH 02/10] fmt, clippy, bindgen tests

---
 src/ensemble/generic_ensemble.rs | 532 ++++++++++++++++++-------------
 src/ensemble/mod.rs              |  14 +-
 2 files changed, 326 insertions(+), 220 deletions(-)

diff --git a/src/ensemble/generic_ensemble.rs b/src/ensemble/generic_ensemble.rs
index 8078af75..ffefafed 100644
--- a/src/ensemble/generic_ensemble.rs
+++ b/src/ensemble/generic_ensemble.rs
@@ -11,44 +11,39 @@ use crate::metrics::accuracy;
 // -----------------------------------------------------------------------------
 
 /// Strategy for aggregating votes from ensemble members.
-/// 
+///
 /// Determines how individual model predictions are combined into
 /// a final ensemble prediction.
-/// +/// /// # Usage /// ``` /// // Uniform: each model gets 1 vote /// let ens = Ensemble::with_strategy(VotingStrategy::Uniform); -/// +/// /// // Weighted: assign confidence scores to models /// let mut ens = Ensemble::with_strategy(VotingStrategy::Weighted); /// ens.add_with_params(None, model, Some(2.0), None, vec![])?; /// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum VotingStrategy { /// Simple majority voting. Each member contributes 1 vote per prediction. /// The `weight` field in members is ignored. + #[default] Uniform, /// Weighted voting. Each member's vote is multiplied by its `weight`. /// Final score for a class = sum of (weight * vote) across members. - /// + /// /// # Constraints /// * All members must have `weight: Some(f64)` when using this strategy. /// * Weights must be finite and non-negative (enforced at insertion). Weighted, } -impl Default for VotingStrategy { - fn default() -> Self { - VotingStrategy::Uniform - } -} - /// Summary information about the ensemble configuration. -/// +/// /// Returned by [`Ensemble::get_ensemble_info`]. Use this to inspect /// the current state of the ensemble without accessing internal fields. -/// +/// /// # Example /// ``` /// let ensemble = Ensemble::, Vec>::new(); @@ -69,11 +64,11 @@ pub struct EnsembleInfo { // ----------------------------------------------------------------------------- /// Container for a model and its metadata within an ensemble. -/// +/// /// This struct wraps a predictive model along with voting weight, /// description, enabled state, and tags. It is managed internally /// by [`Ensemble`] and not intended for direct construction. -/// +/// /// # Type Parameters /// * `X` - Input feature type (must implement `Array2`) /// * `Y` - Label type (must implement `Array1 + Clone`) @@ -83,10 +78,10 @@ struct EnsembleMember { /// Optional weight for voting. Used only if strategy is `Weighted`. 
pub weight: Option, - + /// Optional human-readable description for documentation/debugging. pub description: Option, - + /// Whether the model is enabled for inference. Disabled models /// are skipped during prediction but retained in the ensemble. pub is_enabled: bool, @@ -97,7 +92,6 @@ struct EnsembleMember { } impl EnsembleMember { - // TODO We'll use it later, maybe someone /// Check if this member has a specific tag. #[allow(dead_code)] @@ -117,7 +111,7 @@ where .field("description", &self.description) .field("is_enabled", &self.is_enabled) .field("tags", &self.tags) - .field("model", &"") // Dummy placeholder + .field("model", &"") // Dummy placeholder .finish() } } @@ -127,13 +121,13 @@ where // ----------------------------------------------------------------------------- /// A voting ensemble for classification models. -/// +/// /// Aggregates predictions from multiple `Predictor` instances using /// hard voting (majority/weighted) via score aggregation. /// # Type Parameters /// * `X` - Input data type (e.g., `Array2` for feature vectors) /// * `Y` - Label type (e.g., `Array1` for class labels) -/// +/// /// # Constraints /// * All models must predict the same label type (`i32`). /// * Input `x` to `predict` methods should represent a single sample. @@ -164,14 +158,14 @@ where Y: Array1 + Clone, { /// Creates a new empty ensemble with `Uniform` voting strategy. - /// + /// /// # Type Inference /// Rust can usually infer `X` and `Y` from the first model you add: /// ``` /// let mut ens = Ensemble::new(); /// ens.add(knn_model)?; // X, Y inferred from knn_model /// ``` - /// + /// /// If inference fails, specify types explicitly: /// ``` /// let mut ens: Ensemble, Vec> = Ensemble::new(); @@ -186,10 +180,10 @@ where } /// Creates a new ensemble with a specific voting strategy. - /// + /// /// # Arguments /// * `strategy` - `Uniform` for simple majority, `Weighted` for confidence-based voting. 
- /// + /// /// # Example /// ``` /// let mut ens = Ensemble::with_strategy(VotingStrategy::Weighted); @@ -197,7 +191,10 @@ where /// ens.add_with_params(None, model, Some(1.0), None, vec![])?; /// ``` pub fn with_strategy(strategy: VotingStrategy) -> Self { - Self { strategy, ..Self::new() } + Self { + strategy, + ..Self::new() + } } // ------------------------------------------------------------------------- @@ -205,12 +202,12 @@ where // ------------------------------------------------------------------------- /// Convenience method: adds a model with auto-generated name and default metadata. - /// + /// /// Uses `Uniform` voting weight (1.0). For advanced options, use `add_with_params`. - /// + /// /// # Returns /// The auto-generated name of the added member, or error if addition failed. - /// + /// /// # Example /// ``` /// let mut ensemble = Ensemble::new(); @@ -225,17 +222,17 @@ where } /// Convenience method: adds a model with a custom name and default metadata. - /// + /// /// Equivalent to `add_with_params(Some(name), model, None, None, vec![])`. - /// + /// /// # Arguments /// * `name` - Unique identifier for the model. Must not already exist in the ensemble. /// * `model` - Any type implementing `Predictor`. - /// + /// /// # Returns /// * `Ok(String)` — The name of the added member (same as input `name`). /// * `Err(Failed)` — If name already exists or addition failed. - /// + /// /// # Example /// ``` /// let mut ensemble = Ensemble::new(); @@ -251,19 +248,19 @@ where } /// Adds a model to the ensemble with optional metadata. - /// + /// /// # Arguments /// * `name` - Optional unique identifier. Auto-generated if None. /// * `model` - Any type implementing Predictor. /// * `weight` - Required if strategy is Weighted. 
- /// + /// /// # Example /// ``` /// ensemble.add_with_params( - /// Some("rf_v1".into()), - /// random_forest_model, - /// Some(0.8), - /// Some("RF with depth=10".into()), + /// Some("rf_v1".into()), + /// random_forest_model, + /// Some(0.8), + /// Some("RF with depth=10".into()), /// vec!["tree".into()] /// )?; /// ``` @@ -275,7 +272,8 @@ where description: Option, tags: Vec, ) -> Result - where P: Predictor + 'static + where + P: Predictor + 'static, { let final_name = name.unwrap_or_else(|| self.generate_auto_name()); @@ -313,22 +311,24 @@ where // Get /// Returns the total number of registered members. - /// + /// /// Includes both enabled and disabled models. /// Use [`enabled()`](Self::enabled) to count only active models. - /// + /// /// # Example /// ``` /// assert_eq!(ensemble.len(), 0); /// ensemble.add(model)?; /// assert_eq!(ensemble.len(), 1); /// ``` - pub fn len(&self) -> usize { self.members.len() } + pub fn len(&self) -> usize { + self.members.len() + } /// Returns names of all registered members (enabled and disabled). - /// + /// /// Order is arbitrary (HashMap iteration order). - /// + /// /// # Example /// ``` /// let names = ensemble.names(); @@ -339,29 +339,33 @@ where } /// Returns `true` if the ensemble has no registered members. - /// + /// /// # Example /// ``` /// let ens = Ensemble::<_, _>::new(); /// assert!(ens.is_empty()); /// ``` - pub fn is_empty(&self) -> bool { self.members.is_empty() } + pub fn is_empty(&self) -> bool { + self.members.is_empty() + } /// Returns the current voting strategy. - /// + /// /// # Example /// ``` /// let ens = Ensemble::with_strategy(VotingStrategy::Weighted); /// assert_eq!(ens.strategy(), VotingStrategy::Weighted); /// ``` - pub fn strategy(&self) -> VotingStrategy { self.strategy } + pub fn strategy(&self) -> VotingStrategy { + self.strategy + } /// Returns the voting weight for a specific member. 
- /// + /// /// # Returns /// * `Some(weight)` if strategy is `Weighted` and member exists /// * `None` if strategy is `Uniform` OR member not found - /// + /// /// # Example /// ``` /// let w = ensemble.weight("my_model"); @@ -376,14 +380,14 @@ where match self.strategy { VotingStrategy::Uniform => None, - VotingStrategy::Weighted => { self.members.get(name).and_then(|member| member.weight) } + VotingStrategy::Weighted => self.members.get(name).and_then(|member| member.weight), } } /// Returns summary information about the ensemble configuration. - /// + /// /// Does not include per-model hyperparameters (use [`get_model_metadata`](Self::get_model_metadata) for that). - /// + /// /// # Example /// ``` /// let info = ensemble.get_ensemble_info(); @@ -402,18 +406,18 @@ where // Set /// Updates the voting weight of an existing member. - /// + /// /// # Arguments /// * `name` - Name of the member to update /// * `weight` - New weight value - /// + /// /// # Constraints /// * Member must exist /// * If strategy is `Weighted`, weight must be finite and non-negative - /// + /// /// # Errors /// * `Failed::input` if member not found or weight invalid - /// + /// /// # Example /// ``` /// ensemble.set_weight("strong_model", 2.0)?; @@ -424,49 +428,53 @@ where return Err(Failed::input("Weight must be finite and non-negative")); } } - let member = self.members.get_mut(name) + let member = self + .members + .get_mut(name) .ok_or_else(|| Failed::input(&format!("Member '{}' not found", name)))?; member.weight = Some(weight); Ok(()) } /// Updates the human-readable description of a member. - /// + /// /// Useful for documentation, debugging, or UI display. 
- /// + /// /// # Arguments /// * `name` - Name of the member /// * `desc` - New description string - /// + /// /// # Example /// ``` /// ensemble.set_description("rf_v1", "Random Forest, depth=10, trained on Q1 data")?; /// ``` pub fn set_description(&mut self, name: &str, desc: String) -> Result<(), Failed> { - let member = self.members.get_mut(name) + let member = self + .members + .get_mut(name) .ok_or_else(|| Failed::input(&format!("Member '{}' not found", name)))?; member.description = Some(desc); Ok(()) } /// Changes the voting strategy for the ensemble. - /// + /// /// # Arguments /// * `strategy` - New strategy (`Uniform` or `Weighted`) - /// + /// /// # Constraints /// * If switching to `Weighted`, all members must already have a weight set - /// + /// /// # Errors /// * `Failed::input` if switching to `Weighted` and any member lacks a weight - /// + /// /// # Example /// ``` /// // Start with Uniform /// let mut ens = Ensemble::new(); /// ens.add(model1)?; /// ens.add(model2)?; - /// + /// /// // Switch to Weighted — must set weights first! /// ens.set_weight("model_0", 1.0)?; /// ens.set_weight("model_1", 2.0)?; @@ -476,7 +484,9 @@ where if matches!(strategy, VotingStrategy::Weighted) { for member in self.members.values() { if member.weight.is_none() { - return Err(Failed::input("All members must have weights when using weighted strategy")); + return Err(Failed::input( + "All members must have weights when using weighted strategy", + )); } } self.strategy = strategy; @@ -489,7 +499,7 @@ where return Ok(()); } - return Err(Failed::input("Invalid voting strategy")); + Err(Failed::input("Invalid voting strategy")) } // ------------------------------------------------------------------------- @@ -497,16 +507,16 @@ where // ------------------------------------------------------------------------- /// Temporarily excludes a member from prediction without removing it. 
- /// + /// /// Disabled models are skipped during `predict()` and `score()`, /// but remain in the ensemble and can be re-enabled later. - /// + /// /// # Arguments /// * `name` - Name of the member to disable - /// + /// /// # Errors /// * `Failed::input` if member not found or already disabled - /// + /// /// # Example /// ``` /// ensemble.disable("underperforming_model")?; @@ -517,20 +527,20 @@ where if member.is_enabled { member.is_enabled = false; return Ok(()); - } + } return Err(Failed::input("Model is already disabled")); } - return Err(Failed::input("Model not found")) + Err(Failed::input("Model not found")) } /// Re-includes a previously disabled member in prediction. - /// + /// /// # Arguments /// * `name` - Name of the member to enable - /// + /// /// # Errors /// * `Failed::input` if member not found or already enabled - /// + /// /// # Example /// ``` /// ensemble.enable("previously_disabled_model")?; @@ -543,20 +553,21 @@ where } return Err(Failed::input("Model is already enabled")); } - return Err(Failed::input("Model not found")) + Err(Failed::input("Model not found")) } /// Returns names of all currently enabled members. - /// + /// /// Useful for debugging, logging, or selective inspection. - /// + /// /// # Example /// ``` /// let active = ensemble.enabled(); /// println!("Active models: {:?}", active); /// ``` pub fn enabled(&self) -> Vec { - self.members.iter() + self.members + .iter() .filter(|(_, member)| member.is_enabled) .map(|(name, _)| name.clone()) .collect() @@ -567,23 +578,23 @@ where // ------------------------------------------------------------------------- /// Predicts class labels for input samples using ensemble voting. - /// + /// /// All enabled members receive the **same** input `x` and contribute /// votes according to the active [`VotingStrategy`]. 
- /// + /// /// # Arguments /// * `x` - Input feature matrix (samples × features) - /// + /// /// # Returns /// * `Ok(Y)` - Vector of predicted class labels /// * `Err(Failed)` - If ensemble is empty or prediction fails - /// + /// /// # Example /// ``` /// let predictions = ensemble.predict(&x_test)?; /// let acc = accuracy(&y_test, &predictions); /// ``` - /// + /// /// # See Also /// * [`predict_using_names()`](Self::predict_using_names) — for feature-slicing scenarios /// * [`score()`](Self::score) — for quick accuracy evaluation @@ -609,10 +620,7 @@ where let mut scores: HashMap = HashMap::new(); for (model_name, preds) in enabled.iter().zip(all_preds.iter()) { - let class = - - -*preds.get(i); + let class = *preds.get(i); let vote = match (self.strategy, self.weight(model_name)) { (VotingStrategy::Uniform, _) => 1.0, @@ -638,39 +646,39 @@ where } /// Computes accuracy score on given data. - /// + /// /// Equivalent to sklearn `accuracy(y, self.predict(x))`. - pub fn score(&self, x: &X, y: &Y) -> Result + pub fn score(&self, x: &X, y: &Y) -> Result where - Y: Array1 + Y: Array1, { let preds = self.predict(x)?; Ok(accuracy(y, &preds)) } /// Predicts using per-model input subsets (feature slicing). - /// + /// /// Each ensemble member receives its own input from the `inputs` map, /// keyed by member name. Useful when models are trained on different /// feature subsets or representations. 
- /// + /// /// # Arguments /// * `inputs` - HashMap mapping member names to their specific input matrices - /// + /// /// # Returns /// * `Ok(Y)` - Aggregated predictions via voting /// * `Err(Failed)` - If any required input is missing or prediction fails - /// + /// /// # Example /// ``` /// // Models trained on different feature slices /// let mut inputs = HashMap::new(); /// inputs.insert("slice_A".into(), x_slice_a); /// inputs.insert("slice_B".into(), x_slice_b); - /// + /// /// let predictions = ensemble.predict_using_names(&inputs)?; /// ``` - /// + /// /// # Constraints /// * Every enabled member must have a corresponding entry in `inputs` /// * All input matrices must have the same number of samples @@ -686,7 +694,8 @@ where continue; } - let x = inputs.get(name) + let x = inputs + .get(name) .ok_or_else(|| Failed::input("Missing X set in input"))?; let pred = member.model.predict(x)?; @@ -739,41 +748,52 @@ where #[cfg(test)] mod tests { use super::*; + use crate::ensemble::random_forest_classifier::{ + RandomForestClassifier, RandomForestClassifierParameters, + }; + use crate::linalg::basic::arrays::Array2; use crate::linalg::basic::matrix::DenseMatrix; use crate::neighbors::knn_classifier::{KNNClassifier, KNNClassifierParameters}; - use crate::ensemble::random_forest_classifier::{RandomForestClassifier, RandomForestClassifierParameters}; - use crate::linalg::basic::arrays::Array2; fn dummy_data_2class() -> (DenseMatrix, Vec) { // 6 samples, 2 features, balanced classes (3 vs 3) let x = DenseMatrix::from_2d_vec(&vec![ - vec![1.0, 1.0], // Class 0 - vec![1.5, 1.2], // Class 0 - vec![2.0, 1.5], // Class 0 - vec![4.0, 4.0], // Class 1 - vec![4.5, 4.2], // Class 1 - vec![5.0, 4.5], // Class 1 - ]).unwrap(); + vec![1.0, 1.0], // Class 0 + vec![1.5, 1.2], // Class 0 + vec![2.0, 1.5], // Class 0 + vec![4.0, 4.0], // Class 1 + vec![4.5, 4.2], // Class 1 + vec![5.0, 4.5], // Class 1 + ]) + .unwrap(); let y = vec![0, 0, 0, 1, 1, 1]; (x, y) } + // Apply 
wasm_bindgen_test to all tests in this module + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + // Test 1: Simple add() with auto-generated names #[test] fn test_add_simple_knn_models() { let (x, y) = dummy_data_2class(); let mut ensemble = Ensemble::, Vec>::new(); - - let knn1 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); - let knn2 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap(); - + + let knn1 = + KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + let knn2 = + KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap(); + let name1 = ensemble.add(knn1).unwrap(); let name2 = ensemble.add(knn2).unwrap(); - + assert_eq!(ensemble.len(), 2); assert_eq!(name1, "model_0"); assert_eq!(name2, "model_1"); - + let names = ensemble.names(); assert!(names.contains(&"model_0".to_string())); assert!(names.contains(&"model_1".to_string())); @@ -784,13 +804,18 @@ mod tests { fn test_add_heterogeneous_models() { let (x, y) = dummy_data_2class(); let mut ensemble = Ensemble::, Vec>::new(); - + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); - let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(5)).unwrap(); - + let rf = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters::default().with_n_trees(5), + ) + .unwrap(); + let name_knn = ensemble.add(knn).unwrap(); let name_rf = ensemble.add(rf).unwrap(); - + assert_eq!(ensemble.len(), 2); let names = ensemble.names(); assert_eq!(names.len(), 2); @@ -803,17 +828,22 @@ mod tests { fn test_add_named_custom_names() { let (x, y) = dummy_data_2class(); let mut ensemble = Ensemble::, Vec>::new(); - + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); - let rf = 
RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(5)).unwrap(); - + let rf = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters::default().with_n_trees(5), + ) + .unwrap(); + let name1 = ensemble.add_named("my_knn".into(), knn).unwrap(); let name2 = ensemble.add_named("my_rf".into(), rf).unwrap(); - + assert_eq!(name1, "my_knn"); assert_eq!(name2, "my_rf"); assert_eq!(ensemble.len(), 2); - + let names = ensemble.names(); assert!(names.contains(&"my_knn".to_string())); assert!(names.contains(&"my_rf".to_string())); @@ -824,12 +854,14 @@ mod tests { fn test_error_duplicate_name() { let (x, y) = dummy_data_2class(); let mut ensemble = Ensemble::, Vec>::new(); - - let knn1 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); - let knn2 = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); - + + let knn1 = + KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + let knn2 = + KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); + ensemble.add_named("same_name".into(), knn1).unwrap(); - + let result = ensemble.add_named("same_name".into(), knn2); assert!(result.is_err()); } @@ -838,19 +870,26 @@ mod tests { #[test] fn test_weighted_voting_with_weights() { let (x, y) = dummy_data_2class(); - let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); - - let knn_weak = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); - let knn_strong = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap(); - + let mut ensemble = + Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); + + let knn_weak = + KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); + let knn_strong = + KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap(); + // Add with different weights 
- ensemble.add_with_params(Some("weak".into()), knn_weak, Some(0.5), None, vec![]).unwrap(); - ensemble.add_with_params(Some("strong".into()), knn_strong, Some(2.0), None, vec![]).unwrap(); - + ensemble + .add_with_params(Some("weak".into()), knn_weak, Some(0.5), None, vec![]) + .unwrap(); + ensemble + .add_with_params(Some("strong".into()), knn_strong, Some(2.0), None, vec![]) + .unwrap(); + // Predict should work without error let preds = ensemble.predict(&x).unwrap(); assert_eq!(preds.len(), y.len()); - + // Score should be in valid range let score = ensemble.score(&x, &y).unwrap(); assert!((0.0..=1.0).contains(&score)); @@ -860,10 +899,11 @@ mod tests { #[test] fn test_error_missing_weight_in_weighted_mode() { let (x, y) = dummy_data_2class(); - let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); - + let mut ensemble = + Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); - + // add() does not provide weight → should fail in Weighted mode let result = ensemble.add(knn); assert!(result.is_err()); @@ -874,39 +914,45 @@ mod tests { fn test_predict_using_names_feature_slicing() { let (x_full, y) = dummy_data_2class(); // y имеет длину 6 let mut ensemble = Ensemble::, Vec>::new(); - + // Model A: trained on feature 0 only (must have 6 rows to match y!) 
let x_a_train = DenseMatrix::from_2d_vec(&vec![ - vec![1.0], // sample 0, feat 0 - vec![1.5], // sample 1, feat 0 - vec![2.0], // sample 2, feat 0 - vec![4.0], // sample 3, feat 0 - vec![4.5], // sample 4, feat 0 - vec![5.0], // sample 5, feat 0 - ]).unwrap(); - let knn_a = KNNClassifier::fit(&x_a_train, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); - + vec![1.0], // sample 0, feat 0 + vec![1.5], // sample 1, feat 0 + vec![2.0], // sample 2, feat 0 + vec![4.0], // sample 3, feat 0 + vec![4.5], // sample 4, feat 0 + vec![5.0], // sample 5, feat 0 + ]) + .unwrap(); + let knn_a = + KNNClassifier::fit(&x_a_train, &y, KNNClassifierParameters::default().with_k(3)) + .unwrap(); + // Model B: trained on feature 1 only (must have 6 rows to match y!) let x_b_train = DenseMatrix::from_2d_vec(&vec![ - vec![1.0], // sample 0, feat 1 - vec![1.2], // sample 1, feat 1 - vec![1.5], // sample 2, feat 1 - vec![4.0], // sample 3, feat 1 - vec![4.2], // sample 4, feat 1 - vec![4.5], // sample 5, feat 1 - ]).unwrap(); - let knn_b = KNNClassifier::fit(&x_b_train, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); - + vec![1.0], // sample 0, feat 1 + vec![1.2], // sample 1, feat 1 + vec![1.5], // sample 2, feat 1 + vec![4.0], // sample 3, feat 1 + vec![4.2], // sample 4, feat 1 + vec![4.5], // sample 5, feat 1 + ]) + .unwrap(); + let knn_b = + KNNClassifier::fit(&x_b_train, &y, KNNClassifierParameters::default().with_k(3)) + .unwrap(); + ensemble.add_named("model_A".into(), knn_a).unwrap(); ensemble.add_named("model_B".into(), knn_b).unwrap(); - + // Prepare per-model inputs for prediction (2 test samples) let mut inputs = HashMap::new(); let x_a_test = DenseMatrix::from_2d_vec(&vec![vec![1.8], vec![4.3]]).unwrap(); let x_b_test = DenseMatrix::from_2d_vec(&vec![vec![1.6], vec![4.4]]).unwrap(); inputs.insert("model_A".into(), x_a_test); inputs.insert("model_B".into(), x_b_test); - + // Predict with per-model inputs let preds = 
ensemble.predict_using_names(&inputs).unwrap(); assert_eq!(preds.len(), 2); // 2 test samples @@ -917,26 +963,36 @@ mod tests { fn test_enable_disable_affects_prediction() { let (x_train, y_train) = dummy_data_2class(); let (x_test, y_test) = dummy_data_2class(); - + let mut ensemble = Ensemble::, Vec>::new(); - - let knn1 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3)).unwrap(); - let knn2 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(5)).unwrap(); - + + let knn1 = KNNClassifier::fit( + &x_train, + &y_train, + KNNClassifierParameters::default().with_k(3), + ) + .unwrap(); + let knn2 = KNNClassifier::fit( + &x_train, + &y_train, + KNNClassifierParameters::default().with_k(5), + ) + .unwrap(); + ensemble.add_named("k1".into(), knn1).unwrap(); ensemble.add_named("k3".into(), knn2).unwrap(); - + // Score with both models let score_both = ensemble.score(&x_test, &y_test).unwrap(); - + // Disable one model ensemble.disable("k3").unwrap(); let score_one = ensemble.score(&x_test, &y_test).unwrap(); - + // Both scores should be valid; they may differ assert!((0.0..=1.0).contains(&score_both)); assert!((0.0..=1.0).contains(&score_one)); - + // Re-enable and verify count ensemble.enable("k3").unwrap(); assert_eq!(ensemble.enabled().len(), 2); @@ -945,19 +1001,22 @@ mod tests { // Test 9: get_ensemble_info() reflects state #[test] fn test_get_ensemble_info() { - let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); - + let mut ensemble = + Ensemble::, Vec>::with_strategy(VotingStrategy::Weighted); + let info = ensemble.get_ensemble_info(); assert_eq!(info.strategy, VotingStrategy::Weighted); assert_eq!(info.total_members, 0); assert_eq!(info.enabled_members, 0); assert!(info.uses_weighted_voting); - + // Add a model let (x, y) = dummy_data_2class(); let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); - ensemble.add_with_params(None, knn, 
Some(1.0), None, vec![]).unwrap(); - + ensemble + .add_with_params(None, knn, Some(1.0), None, vec![]) + .unwrap(); + let info2 = ensemble.get_ensemble_info(); assert_eq!(info2.total_members, 1); assert_eq!(info2.enabled_members, 1); @@ -970,33 +1029,38 @@ mod tests { let (x, y) = dummy_data_2class(); let mut ensemble = Ensemble::, Vec>::new(); - + // 1. KNN let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap(); ensemble.add_named("knn_k3".into(), knn).unwrap(); - + // 2. Random Forest - let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(3)).unwrap(); + let rf = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters::default().with_n_trees(3), + ) + .unwrap(); ensemble.add_named("rf_3trees".into(), rf).unwrap(); - + // 3. Decision Tree let dt = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); ensemble.add_named("decision_tree".into(), dt).unwrap(); - + // All 3 models are active assert_eq!(ensemble.len(), 3); assert_eq!(ensemble.enabled().len(), 3); - + // Name check let names = ensemble.names(); assert!(names.contains(&"knn_k3".to_string())); assert!(names.contains(&"rf_3trees".to_string())); assert!(names.contains(&"decision_tree".to_string())); - + // The most beautiful thing - predict() on 3 different models let preds = ensemble.predict(&x).unwrap(); assert_eq!(preds.len(), y.len()); - + // Score must be valid let score = ensemble.score(&x, &y).unwrap(); assert!((0.0..=1.0).contains(&score)); @@ -1007,17 +1071,34 @@ mod tests { fn test_add_with_custom_names() { let (x, y) = dummy_data_2class(); let mut ensemble = Ensemble::, Vec>::new(); - + let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(2)).unwrap(); - let rf = RandomForestClassifier::fit(&x, &y, RandomForestClassifierParameters::default().with_n_trees(5)).unwrap(); - - let name1 = ensemble.add_with_params( - Some("my_knn".into()), knn, None, 
Some("k=1".into()), vec!["fast".into()] - ).unwrap(); - let name2 = ensemble.add_with_params( - Some("my_rf".into()), rf, None, Some("5 trees".into()), vec!["tree".into()] - ).unwrap(); - + let rf = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters::default().with_n_trees(5), + ) + .unwrap(); + + let name1 = ensemble + .add_with_params( + Some("my_knn".into()), + knn, + None, + Some("k=1".into()), + vec!["fast".into()], + ) + .unwrap(); + let name2 = ensemble + .add_with_params( + Some("my_rf".into()), + rf, + None, + Some("5 trees".into()), + vec!["tree".into()], + ) + .unwrap(); + assert_eq!(name1, "my_knn"); assert_eq!(name2, "my_rf"); assert_eq!(ensemble.len(), 2); @@ -1030,20 +1111,30 @@ mod tests { fn test_score_with_increasing_models() { let (x_train, y_train) = dummy_data_2class(); let (x_test, y_test) = dummy_data_2class(); - + let mut ensemble = Ensemble::, Vec>::new(); - + // Add first model - let knn1 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(2)).unwrap(); + let knn1 = KNNClassifier::fit( + &x_train, + &y_train, + KNNClassifierParameters::default().with_k(2), + ) + .unwrap(); ensemble.add(knn1).unwrap(); let score1 = ensemble.score(&x_test, &y_test).unwrap(); assert!((0.0..=1.0).contains(&score1)); - + // Add second model - let knn2 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3)).unwrap(); + let knn2 = KNNClassifier::fit( + &x_train, + &y_train, + KNNClassifierParameters::default().with_k(3), + ) + .unwrap(); ensemble.add(knn2).unwrap(); let score2 = ensemble.score(&x_test, &y_test).unwrap(); - + // Score may go up or down — just ensure it's valid assert!((0.0..=1.0).contains(&score2)); } @@ -1053,20 +1144,35 @@ mod tests { fn test_score_after_disable() { let (x_train, y_train) = dummy_data_2class(); let (x_test, y_test) = dummy_data_2class(); - - let mut ensemble = Ensemble::, Vec>::with_strategy(VotingStrategy::Uniform); - - let knn1 = 
KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(2)).unwrap(); - let knn2 = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3)).unwrap(); - - ensemble.add_with_params(Some("k1".into()), knn1, None, None, vec![]).unwrap(); - ensemble.add_with_params(Some("k3".into()), knn2, None, None, vec![]).unwrap(); - + + let mut ensemble = + Ensemble::, Vec>::with_strategy(VotingStrategy::Uniform); + + let knn1 = KNNClassifier::fit( + &x_train, + &y_train, + KNNClassifierParameters::default().with_k(2), + ) + .unwrap(); + let knn2 = KNNClassifier::fit( + &x_train, + &y_train, + KNNClassifierParameters::default().with_k(3), + ) + .unwrap(); + + ensemble + .add_with_params(Some("k1".into()), knn1, None, None, vec![]) + .unwrap(); + ensemble + .add_with_params(Some("k3".into()), knn2, None, None, vec![]) + .unwrap(); + let score_before = ensemble.score(&x_test, &y_test).unwrap(); - + ensemble.disable("k3").unwrap(); let score_after = ensemble.score(&x_test, &y_test).unwrap(); - + // Scores can differ; just ensure both are valid assert!((0.0..=1.0).contains(&score_before)); assert!((0.0..=1.0).contains(&score_after)); diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs index c0be2e70..6febfe38 100644 --- a/src/ensemble/mod.rs +++ b/src/ensemble/mod.rs @@ -24,29 +24,29 @@ pub mod random_forest_classifier; pub mod random_forest_regressor; /// Generic voting ensemble for classification models. -/// +/// /// This module provides the [`Ensemble`] struct, which aggregates predictions /// from multiple [`Predictor`] implementations using hard voting (uniform or weighted). 
-/// +/// /// # Quick Start /// ``` /// use smartcore::ensemble::generic_ensemble::{Ensemble, VotingStrategy}; -/// +/// /// // Create ensemble /// let mut ensemble = Ensemble::new(); -/// +/// /// // Add models (any type implementing Predictor) /// ensemble.add(knn_model)?; /// ensemble.add(rf_model)?; -/// +/// /// // Predict /// let predictions = ensemble.predict(&x_test)?; /// ``` -/// +/// /// # Features /// * Heterogeneous ensembles: KNN, Random Forest, Decision Tree are now the only supported models. The rest are on their way /// * Uniform or weighted voting strategies /// * Dynamic enable/disable of members at runtime /// * Meta descriptions, tags, weights /// * Feature slicing via `predict_using_names()` -pub mod generic_ensemble; \ No newline at end of file +pub mod generic_ensemble; From 0cf841ae5f56be7522959ceedd6ccfc7b2deecc6 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:02:16 +0300 Subject: [PATCH 03/10] updated 8 files to fix #365 issue. Also, around 10 deprecation warnings fixed. Also, the Generic Ensemble documentation examples are now ignored. 
--- Cargo.toml | 10 ++--- src/cluster/kmeans.rs | 4 +- src/ensemble/base_forest_regressor.rs | 2 +- src/ensemble/generic_ensemble.rs | 47 ++++++++++++------------ src/ensemble/mod.rs | 2 +- src/ensemble/random_forest_classifier.rs | 2 +- src/numbers/floatnum.rs | 4 +- src/numbers/realnum.rs | 8 ++-- src/rand_custom.rs | 3 +- src/tree/base_tree_regressor.rs | 2 +- 10 files changed, 42 insertions(+), 42 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 56ab26af..f5a77c7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ cfg-if = "1.0.0" ndarray = { version = "0.15", optional = true } num-traits = "0.2.12" num = "0.4" -rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } -rand_distr = { version = "0.4", optional = true } +rand = { version = "0.9", default-features = false, features = ["small_rng"] } +rand_distr = { version = "0.5", optional = true } serde = { version = "1", features = ["derive"], optional = true } ordered-float = "5.1.0" @@ -38,12 +38,12 @@ default = [] serde = ["dep:serde", "dep:typetag"] ndarray-bindings = ["dep:ndarray"] datasets = ["dep:rand_distr", "std_rand", "serde"] -std_rand = ["rand/std_rng", "rand/std"] +std_rand = ["rand/std_rng", "rand/std", "rand/thread_rng"] # used by wasm32-unknown-unknown for in-browser usage -js = ["getrandom/js"] +js = ["getrandom/wasm_js"] [target.'cfg(target_arch = "wasm32")'.dependencies] -getrandom = { version = "0.2.8", optional = true } +getrandom = { version = "0.3", optional = true } [target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies] wasm-bindgen-test = "0.3" diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 2fade68f..1edc20b3 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -356,7 +356,7 @@ impl, Y: Array1> KMeans let (n, _) = data.shape(); let mut y = vec![0; n]; let mut centroid: Vec = data - .get_row(rng.gen_range(0..n)) + .get_row(rng.random_range(0..n)) .iterator(0) .cloned() .collect(); @@ 
-382,7 +382,7 @@ impl, Y: Array1> KMeans for i in d.iter() { sum += *i; } - let cutoff = rng.gen::() * sum; + let cutoff = rng.random::() * sum; let mut cost = 0f64; let mut index = 0; while index < n { diff --git a/src/ensemble/base_forest_regressor.rs b/src/ensemble/base_forest_regressor.rs index 4209034c..03c5ace2 100644 --- a/src/ensemble/base_forest_regressor.rs +++ b/src/ensemble/base_forest_regressor.rs @@ -212,7 +212,7 @@ impl, Y: Array1 fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec { let mut samples = vec![0; nrows]; for _ in 0..nrows { - let xi = rng.gen_range(0..nrows); + let xi = rng.random_range(0..nrows); samples[xi] += 1; } samples diff --git a/src/ensemble/generic_ensemble.rs b/src/ensemble/generic_ensemble.rs index ffefafed..944b5950 100644 --- a/src/ensemble/generic_ensemble.rs +++ b/src/ensemble/generic_ensemble.rs @@ -16,7 +16,7 @@ use crate::metrics::accuracy; /// a final ensemble prediction. /// /// # Usage -/// ``` +/// ```ignore /// // Uniform: each model gets 1 vote /// let ens = Ensemble::with_strategy(VotingStrategy::Uniform); /// @@ -45,7 +45,7 @@ pub enum VotingStrategy { /// the current state of the ensemble without accessing internal fields. /// /// # Example -/// ``` +/// ```ignore /// let ensemble = Ensemble::, Vec>::new(); /// let info = ensemble.get_ensemble_info(); /// assert_eq!(info.total_members, 0); @@ -161,13 +161,13 @@ where /// /// # Type Inference /// Rust can usually infer `X` and `Y` from the first model you add: - /// ``` + /// ```ignore /// let mut ens = Ensemble::new(); /// ens.add(knn_model)?; // X, Y inferred from knn_model /// ``` /// /// If inference fails, specify types explicitly: - /// ``` + /// ```ignore /// let mut ens: Ensemble, Vec> = Ensemble::new(); /// ``` pub fn new() -> Self { @@ -185,7 +185,7 @@ where /// * `strategy` - `Uniform` for simple majority, `Weighted` for confidence-based voting. 
/// /// # Example - /// ``` + /// ```ignore /// let mut ens = Ensemble::with_strategy(VotingStrategy::Weighted); /// // Now you must provide weights when adding models /// ens.add_with_params(None, model, Some(1.0), None, vec![])?; @@ -209,7 +209,7 @@ where /// The auto-generated name of the added member, or error if addition failed. /// /// # Example - /// ``` + /// ```ignore /// let mut ensemble = Ensemble::new(); /// let model_name = ensemble.add(knn_model)?; /// println!("Added: {}", model_name); @@ -234,7 +234,7 @@ where /// * `Err(Failed)` — If name already exists or addition failed. /// /// # Example - /// ``` + /// ```ignore /// let mut ensemble = Ensemble::new(); /// ensemble.add_named("knn_k3".into(), knn_model)?; /// ensemble.add_named("rf_depth10".into(), rf_model)?; @@ -255,7 +255,7 @@ where /// * `weight` - Required if strategy is Weighted. /// /// # Example - /// ``` + /// ```ignore /// ensemble.add_with_params( /// Some("rf_v1".into()), /// random_forest_model, @@ -316,7 +316,7 @@ where /// Use [`enabled()`](Self::enabled) to count only active models. /// /// # Example - /// ``` + /// ```ignore /// assert_eq!(ensemble.len(), 0); /// ensemble.add(model)?; /// assert_eq!(ensemble.len(), 1); @@ -330,7 +330,7 @@ where /// Order is arbitrary (HashMap iteration order). /// /// # Example - /// ``` + /// ```ignore /// let names = ensemble.names(); /// assert!(names.contains(&"my_model".to_string())); /// ``` @@ -341,7 +341,7 @@ where /// Returns `true` if the ensemble has no registered members. /// /// # Example - /// ``` + /// ```ignore /// let ens = Ensemble::<_, _>::new(); /// assert!(ens.is_empty()); /// ``` @@ -352,7 +352,7 @@ where /// Returns the current voting strategy. 
/// /// # Example - /// ``` + /// ```ignore /// let ens = Ensemble::with_strategy(VotingStrategy::Weighted); /// assert_eq!(ens.strategy(), VotingStrategy::Weighted); /// ``` @@ -367,7 +367,7 @@ where /// * `None` if strategy is `Uniform` OR member not found /// /// # Example - /// ``` + /// ```ignore /// let w = ensemble.weight("my_model"); /// if let Some(weight) = w { /// println!("Weight: {}", weight); @@ -389,7 +389,7 @@ where /// Does not include per-model hyperparameters (use [`get_model_metadata`](Self::get_model_metadata) for that). /// /// # Example - /// ``` + /// ```ignore /// let info = ensemble.get_ensemble_info(); /// println!("Strategy: {:?}", info.strategy); /// println!("Active models: {}/{}", info.enabled_members, info.total_members); @@ -419,7 +419,7 @@ where /// * `Failed::input` if member not found or weight invalid /// /// # Example - /// ``` + /// ```ignore /// ensemble.set_weight("strong_model", 2.0)?; /// ``` pub fn set_weight(&mut self, name: &str, weight: f64) -> Result<(), Failed> { @@ -445,7 +445,7 @@ where /// * `desc` - New description string /// /// # Example - /// ``` + /// ```ignore /// ensemble.set_description("rf_v1", "Random Forest, depth=10, trained on Q1 data")?; /// ``` pub fn set_description(&mut self, name: &str, desc: String) -> Result<(), Failed> { @@ -469,7 +469,7 @@ where /// * `Failed::input` if switching to `Weighted` and any member lacks a weight /// /// # Example - /// ``` + /// ```ignore /// // Start with Uniform /// let mut ens = Ensemble::new(); /// ens.add(model1)?; @@ -518,7 +518,7 @@ where /// * `Failed::input` if member not found or already disabled /// /// # Example - /// ``` + /// ```ignore /// ensemble.disable("underperforming_model")?; /// let preds = ensemble.predict(&x)?; // model excluded from voting /// ``` @@ -542,7 +542,7 @@ where /// * `Failed::input` if member not found or already enabled /// /// # Example - /// ``` + /// ```ignore /// ensemble.enable("previously_disabled_model")?; /// ``` pub fn 
enable(&mut self, name: &str) -> Result<(), Failed> { @@ -561,7 +561,7 @@ where /// Useful for debugging, logging, or selective inspection. /// /// # Example - /// ``` + /// ```ignore /// let active = ensemble.enabled(); /// println!("Active models: {:?}", active); /// ``` @@ -590,7 +590,7 @@ where /// * `Err(Failed)` - If ensemble is empty or prediction fails /// /// # Example - /// ``` + /// ```ignore /// let predictions = ensemble.predict(&x_test)?; /// let acc = accuracy(&y_test, &predictions); /// ``` @@ -670,7 +670,7 @@ where /// * `Err(Failed)` - If any required input is missing or prediction fails /// /// # Example - /// ``` + /// ```ignore /// // Models trained on different feature slices /// let mut inputs = HashMap::new(); /// inputs.insert("slice_A".into(), x_slice_a); @@ -751,7 +751,6 @@ mod tests { use crate::ensemble::random_forest_classifier::{ RandomForestClassifier, RandomForestClassifierParameters, }; - use crate::linalg::basic::arrays::Array2; use crate::linalg::basic::matrix::DenseMatrix; use crate::neighbors::knn_classifier::{KNNClassifier, KNNClassifierParameters}; @@ -912,7 +911,7 @@ mod tests { // Test 7: predict_using_names() with feature slicing #[test] fn test_predict_using_names_feature_slicing() { - let (x_full, y) = dummy_data_2class(); // y имеет длину 6 + let (_, y) = dummy_data_2class(); // y имеет длину 6 let mut ensemble = Ensemble::, Vec>::new(); // Model A: trained on feature 0 only (must have 6 rows to match y!) diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs index 6febfe38..ea480aae 100644 --- a/src/ensemble/mod.rs +++ b/src/ensemble/mod.rs @@ -29,7 +29,7 @@ pub mod random_forest_regressor; /// from multiple [`Predictor`] implementations using hard voting (uniform or weighted). 
/// /// # Quick Start -/// ``` +/// ```ignore /// use smartcore::ensemble::generic_ensemble::{Ensemble, VotingStrategy}; /// /// // Create ensemble diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f4e8db3c..540de705 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -603,7 +603,7 @@ impl, Y: Array1 f64 { use rand::Rng; let mut rng = get_rng_impl(None); - rng.gen() + rng.random() } fn two() -> Self { @@ -100,7 +100,7 @@ impl FloatNumber for f32 { fn rand() -> f32 { use rand::Rng; let mut rng = get_rng_impl(None); - rng.gen() + rng.random() } fn two() -> Self { diff --git a/src/numbers/realnum.rs b/src/numbers/realnum.rs index 8ef71555..d11fe5ea 100644 --- a/src/numbers/realnum.rs +++ b/src/numbers/realnum.rs @@ -70,9 +70,9 @@ impl RealNumber for f64 { let mut small_rng = get_rng_impl(None); let mut rngs: Vec = (0..3) - .map(|_| SmallRng::from_rng(&mut small_rng).unwrap()) + .map(|_| SmallRng::from_rng(&mut small_rng)) .collect(); - rngs[0].gen::() + rngs[0].random::() } fn two() -> Self { @@ -119,9 +119,9 @@ impl RealNumber for f32 { let mut small_rng = get_rng_impl(None); let mut rngs: Vec = (0..3) - .map(|_| SmallRng::from_rng(&mut small_rng).unwrap()) + .map(|_| SmallRng::from_rng(&mut small_rng)) .collect(); - rngs[0].gen::() + rngs[0].random::() } fn two() -> Self { diff --git a/src/rand_custom.rs b/src/rand_custom.rs index 936ec9e9..eb9ec1cd 100644 --- a/src/rand_custom.rs +++ b/src/rand_custom.rs @@ -12,7 +12,8 @@ pub fn get_rng_impl(seed: Option) -> RngImpl { cfg_if::cfg_if! 
{ if #[cfg(feature = "std_rand")] { use rand::RngCore; - RngImpl::seed_from_u64(rand::thread_rng().next_u64()) + // FIX: thread_rng() deprecated in rand 0.9 → use rng() + RngImpl::seed_from_u64(rand::rng().next_u64()) } else { // no std_random feature build, use getrandom #[cfg(feature = "js")] diff --git a/src/tree/base_tree_regressor.rs b/src/tree/base_tree_regressor.rs index 87288947..05e323ec 100644 --- a/src/tree/base_tree_regressor.rs +++ b/src/tree/base_tree_regressor.rs @@ -363,7 +363,7 @@ impl, Y: Array1> return; } - let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap()); + let split_value = rng.random_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap()); let mut true_sum = 0f64; let mut true_count = 0; From 6d8e06b1c9497a92f9fdc79ff2fd0945984023ff Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:08:08 +0300 Subject: [PATCH 04/10] fixed use directives --- src/dataset/generator.rs | 2 +- src/readers/io_testing.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index f8e59443..87299c8f 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -1,6 +1,6 @@ //! # Dataset Generators //! -use rand::distributions::Uniform; +use rand::distr::Uniform; use rand::prelude::*; use rand_distr::Normal; diff --git a/src/readers/io_testing.rs b/src/readers/io_testing.rs index cb0b4b0f..ad5b811f 100644 --- a/src/readers/io_testing.rs +++ b/src/readers/io_testing.rs @@ -1,7 +1,7 @@ //! This module contains functionality to test IO. It has both functions that write //! to the file-system for end-to-end tests, but also abstractions to avoid this by //! reading from strings instead. 
-use rand::distributions::{Alphanumeric, DistString}; +use rand::distr::{Alphanumeric, DistString}; use std::fs; use std::io::Bytes; use std::io::Read; From bc6919d486ee49926426eb2d4a7205b577ec54e6 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:17:57 +0300 Subject: [PATCH 05/10] code updated to satisfy new rand requirements --- src/dataset/generator.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index 87299c8f..99fd36bb 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -12,7 +12,7 @@ pub fn make_blobs( num_features: usize, num_centers: usize, ) -> Dataset { - let center_box = Uniform::from(-10.0..10.0); + let center_box = Uniform::new(-10.0, 10.0).expect("Invalid uniform range"); let cluster_std = 1.0; let mut centers: Vec>> = Vec::with_capacity(num_centers); From 9111c43843c16efaa3022c54f68f34451c720316 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:19:41 +0300 Subject: [PATCH 06/10] deprecation warnings fixed --- src/dataset/generator.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index 99fd36bb..e2449870 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -16,7 +16,7 @@ pub fn make_blobs( let cluster_std = 1.0; let mut centers: Vec>> = Vec::with_capacity(num_centers); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..num_centers { centers.push( (0..num_features) @@ -60,7 +60,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset = Vec::with_capacity(num_samples * 2); let mut y: Vec = Vec::with_capacity(num_samples); @@ -97,7 +97,7 @@ pub fn make_moons(num_samples: usize, noise: f32) -> Dataset { let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in); let noise = Normal::new(0.0, noise).unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = 
rand::rng(); let mut x: Vec = Vec::with_capacity(num_samples * 2); let mut y: Vec = Vec::with_capacity(num_samples); From 0e1aa8cf53962d25b6cc295f817b14b162d77ede Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:30:34 +0300 Subject: [PATCH 07/10] DistString -> SampleString + deprecation warnings fixed --- src/readers/io_testing.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/readers/io_testing.rs b/src/readers/io_testing.rs index ad5b811f..66dc86f4 100644 --- a/src/readers/io_testing.rs +++ b/src/readers/io_testing.rs @@ -1,7 +1,7 @@ //! This module contains functionality to test IO. It has both functions that write //! to the file-system for end-to-end tests, but also abstractions to avoid this by //! reading from strings instead. -use rand::distr::{Alphanumeric, DistString}; +use rand::distr::{Alphanumeric, SampleString}; use std::fs; use std::io::Bytes; use std::io::Read; @@ -16,7 +16,7 @@ pub struct TemporaryTextFile { impl TemporaryTextFile { pub fn new(contents: &str) -> std::io::Result { let test_text_file = TemporaryTextFile { - random_path: Alphanumeric.sample_string(&mut rand::thread_rng(), 16), + random_path: Alphanumeric.sample_string(&mut rand::rng(), 16), }; string_to_file(contents, &test_text_file.random_path)?; Ok(test_text_file) From 2a9fb0a9a6fd1bdb82d5031e5a90c989714dd87d Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:33:45 +0300 Subject: [PATCH 08/10] fmt --- src/ensemble/generic_ensemble.rs | 1 - src/numbers/realnum.rs | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/ensemble/generic_ensemble.rs b/src/ensemble/generic_ensemble.rs index 944b5950..acc2be0b 100644 --- a/src/ensemble/generic_ensemble.rs +++ b/src/ensemble/generic_ensemble.rs @@ -774,7 +774,6 @@ mod tests { all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test )] - // Test 1: Simple add() with auto-generated names #[test] fn 
test_add_simple_knn_models() { diff --git a/src/numbers/realnum.rs b/src/numbers/realnum.rs index d11fe5ea..f7f539e1 100644 --- a/src/numbers/realnum.rs +++ b/src/numbers/realnum.rs @@ -69,9 +69,7 @@ impl RealNumber for f64 { fn rand() -> f64 { let mut small_rng = get_rng_impl(None); - let mut rngs: Vec = (0..3) - .map(|_| SmallRng::from_rng(&mut small_rng)) - .collect(); + let mut rngs: Vec = (0..3).map(|_| SmallRng::from_rng(&mut small_rng)).collect(); rngs[0].random::() } @@ -118,9 +116,7 @@ impl RealNumber for f32 { fn rand() -> f32 { let mut small_rng = get_rng_impl(None); - let mut rngs: Vec = (0..3) - .map(|_| SmallRng::from_rng(&mut small_rng)) - .collect(); + let mut rngs: Vec = (0..3).map(|_| SmallRng::from_rng(&mut small_rng)).collect(); rngs[0].random::() } From 8f277bfc06a4a5fd57b2ef50c4f9f0209f3e970d Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:41:33 +0300 Subject: [PATCH 09/10] clippy --- src/ensemble/generic_ensemble.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ensemble/generic_ensemble.rs b/src/ensemble/generic_ensemble.rs index acc2be0b..5649cfd6 100644 --- a/src/ensemble/generic_ensemble.rs +++ b/src/ensemble/generic_ensemble.rs @@ -423,11 +423,10 @@ where /// ensemble.set_weight("strong_model", 2.0)?; /// ``` pub fn set_weight(&mut self, name: &str, weight: f64) -> Result<(), Failed> { - if matches!(self.strategy, VotingStrategy::Weighted) { - if !weight.is_finite() || weight < 0.0 { + if matches!(self.strategy, VotingStrategy::Weighted) + && (!weight.is_finite() || weight < 0.0) { return Err(Failed::input("Weight must be finite and non-negative")); } - } let member = self .members .get_mut(name) From 5dcbf4035d2e26a6ee87123d4373acf0d9032296 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 5 Apr 2026 23:56:03 +0300 Subject: [PATCH 10/10] so, clippy actually produces code which does not conform with the fmt. facepalm. Fmt again. 
--- src/ensemble/generic_ensemble.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ensemble/generic_ensemble.rs b/src/ensemble/generic_ensemble.rs index 5649cfd6..561e98ed 100644 --- a/src/ensemble/generic_ensemble.rs +++ b/src/ensemble/generic_ensemble.rs @@ -424,9 +424,10 @@ where /// ``` pub fn set_weight(&mut self, name: &str, weight: f64) -> Result<(), Failed> { if matches!(self.strategy, VotingStrategy::Weighted) - && (!weight.is_finite() || weight < 0.0) { - return Err(Failed::input("Weight must be finite and non-negative")); - } + && (!weight.is_finite() || weight < 0.0) + { + return Err(Failed::input("Weight must be finite and non-negative")); + } let member = self .members .get_mut(name)