diff --git a/examples/ensemble_attack/README.md b/examples/ensemble_attack/README.md index 030433fe..2b7c1b62 100644 --- a/examples/ensemble_attack/README.md +++ b/examples/ensemble_attack/README.md @@ -1,11 +1,11 @@ # Ensemble Attack ## Data Processing -As the first step of the attack, we need to collect and split the data. The input data collected from all the attacks provided by the MIDST Challenge should be stored in `data_paths.data_paths` as defined by `config.yaml`. You can download and unzip the resources from [this Google Drive link](https://drive.google.com/drive/folders/1rmJ_E6IzG25eCL3foYAb2jVmAstXktJ1?usp=drive_link). Note that you can safely remove the provided shadow models with the competition resources since they are not used in this attack. +As the first step of the attack, we need to collect and split the data. The input data collected from all the attacks provided by the MIDST Challenge should be stored in `data_paths.midst_data_path` as defined by [`configs/experiment_config.yaml`](configs/experiment_config.yaml). You can download and unzip the resources from [this Google Drive link](https://drive.google.com/drive/folders/1rmJ_E6IzG25eCL3foYAb2jVmAstXktJ1?usp=drive_link). Note that you can safely remove the provided shadow models with the competition resources since they are not used in this attack. -Make sure directories and JSON files specified in `data_paths` and `data_processing_config` configurations in `examples/ensemble_attack/config.yaml` exist. +Make sure directories and JSON files specified in `data_paths` and `data_processing_config` configurations in `examples/ensemble_attack/configs/experiment_config.yaml` exist. -To run the whole data processing pipeline, run `run_attack.py` and set `pipeline.run_data_processing` to `true` in `config.yaml`. It reads data from `data_paths.midst_data_path` specified in config, and populates `data_paths.population_data` and `data_paths.processed_attack_data_path` directories. 
+To run the whole data processing pipeline, run `run_attack.py` and set `pipeline.run_data_processing` to `true` in [`configs/experiment_config.yaml`](configs/experiment_config.yaml). It reads data from `data_paths.midst_data_path` specified in config, and populates `data_paths.population_path` and `data_paths.processed_attack_data_path` directories. Data processing steps for the MIDST challenge provided resources according to Ensemble attack are as follows: diff --git a/examples/ensemble_attack/compute_attack_success.py b/examples/ensemble_attack/compute_attack_success.py index 56871113..4c50806b 100644 --- a/examples/ensemble_attack/compute_attack_success.py +++ b/examples/ensemble_attack/compute_attack_success.py @@ -41,7 +41,12 @@ def load_target_challenge_labels_and_probabilities( test_prediction_probabilities = np.load(attack_result_file_path) # Challenge labels are the true membership labels for the challenge points. - test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze() + if challenge_label_path.suffix == ".npy": + test_target = np.load(challenge_label_path).squeeze() + elif challenge_label_path.suffix == ".csv": + test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze() + else: + raise ValueError(f"Unsupported challenge label file type: {challenge_label_path}. Must be .npy or .csv.") assert len(test_prediction_probabilities) == len(test_target), ( "Number of challenge labels must match number of prediction probabilities." @@ -71,9 +76,12 @@ def compute_attack_success_for_given_targets( predictions = [] targets = [] for target_id in target_ids: - # Override target model id in config as ``attack_probabilities_result_path`` and - # ``challenge_label_path`` are dependent on it and change in runtime. 
- target_model_config.target_model_id = target_id + # If there is a target model id in the config, override it with the current target id + if "target_model_id" in target_model_config: + # Override target model id in config as ``attack_probabilities_result_path`` and + # ``challenge_label_path`` are dependent on it and change in runtime. + target_model_config.target_model_id = target_id + # Load challenge labels and prediction probabilities log(INFO, f"Loading challenge labels and prediction probabilities for target model ID {target_id}...") test_target, test_prediction_probabilities = load_target_challenge_labels_and_probabilities( diff --git a/examples/ensemble_attack/configs/experiment_config.yaml b/examples/ensemble_attack/configs/experiment_config.yaml index 4216715c..979e1f64 100644 --- a/examples/ensemble_attack/configs/experiment_config.yaml +++ b/examples/ensemble_attack/configs/experiment_config.yaml @@ -32,10 +32,6 @@ data_paths: attack_evaluation_result_path: ${base_experiment_dir}/evaluation_results # Path where the attack (train phase) evaluation results will be stored (output) -model_paths: - metaclassifier_model_path: ${base_experiment_dir}/trained_models # Path where the trained metaclassifier model will be saved - - # Dataset specific information used for processing in this example data_processing_config: midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input) @@ -76,7 +72,7 @@ shadow_training: training_json_config_paths: # Config json files used for tabddpm training on the trans table table_domain_file_path: ${base_data_config_dir}/trans_domain.json dataset_meta_file_path: ${base_data_config_dir}/dataset_meta.json - tabddpm_training_config_path: ${base_data_config_dir}/trans.json + training_config_path: ${base_data_config_dir}/trans.json # Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name # Also, training configs for each shadow model are created under 
shadow_models_data_path. shadow_models_output_path: ${base_experiment_dir}/shadow_models_and_data @@ -112,6 +108,7 @@ metaclassifier: # Temporary. Might remove having an epoch parameter. epochs: 1 meta_classifier_model_name: ${metaclassifier.model_type}_metaclassifier_model + metaclassifier_model_path: ${base_experiment_dir}/trained_models # Path where the trained metaclassifier model will be saved attack_success_computation: target_ids_to_test: [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] # List of target model IDs to compute the attack success for. diff --git a/examples/ensemble_attack/configs/original_attack_config.yaml b/examples/ensemble_attack/configs/original_attack_config.yaml index 4adaa181..5e549b16 100644 --- a/examples/ensemble_attack/configs/original_attack_config.yaml +++ b/examples/ensemble_attack/configs/original_attack_config.yaml @@ -10,9 +10,6 @@ data_paths: processed_attack_data_path: ${base_data_dir}/attack_data # Path where the processed attack real train and evaluation data is stored attack_evaluation_result_path: ${base_example_dir}/attack_results # Path where the attack evaluation results will be stored -model_paths: - metaclassifier_model_path: ${base_example_dir}/trained_models # Path where the trained metaclassifier model will be saved - # Pipeline control pipeline: run_data_processing: true # Set this to false if you have already saved the processed data @@ -58,7 +55,7 @@ shadow_training: training_json_config_paths: # Config json files used for tabddpm training on the trans table table_domain_file_path: ${base_example_dir}/data_configs/trans_domain.json dataset_meta_file_path: ${base_example_dir}/data_configs/dataset_meta.json - tabddpm_training_config_path: ${base_example_dir}/data_configs/trans.json + training_config_path: ${base_example_dir}/data_configs/trans.json # Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name # Also, training configs for each shadow 
model are created under shadow_models_data_path. shadow_models_output_path: ${base_data_dir}/shadow_models_and_data @@ -93,6 +90,7 @@ metaclassifier: # Temporary. Might remove having an epoch parameter. epochs: 1 meta_classifier_model_name: ${metaclassifier.model_type}_metaclassifier_model + metaclassifier_model_path: ${base_example_dir}/trained_models # Path where the trained metaclassifier model will be saved # General settings diff --git a/examples/ensemble_attack/real_data_collection.py b/examples/ensemble_attack/real_data_collection.py index a0a10127..0a2b710e 100644 --- a/examples/ensemble_attack/real_data_collection.py +++ b/examples/ensemble_attack/real_data_collection.py @@ -6,7 +6,6 @@ from enum import Enum from logging import INFO from pathlib import Path -from typing import Literal import pandas as pd from omegaconf import DictConfig @@ -15,6 +14,10 @@ from midst_toolkit.common.logger import log +COLLECTED_DATA_FILE_NAME = "population_all_with_challenge.csv" +COLLECTED_DATA_NO_CHALLENGE_FILE_NAME = "population_all_no_challenge.csv" + + class AttackType(Enum): """Enum for the different attack types.""" @@ -32,7 +35,11 @@ class AttackType(Enum): TABDDPM_100K = "tabddpm_trained_with_100k" -DatasetType = Literal["train", "challenge"] +class AttackDataset(Enum): + """Enum for the different attack datasets.""" + + TRAIN = "train" + CHALLENGE = "challenge" def expand_ranges(ranges: list[tuple[int, int]]) -> list[int]: @@ -56,7 +63,7 @@ def collect_midst_attack_data( attack_type: AttackType, data_dir: Path, split_folder: str, - dataset: DatasetType, + dataset: AttackDataset, data_processing_config: DictConfig, ) -> pd.DataFrame: """ @@ -74,10 +81,9 @@ def collect_midst_attack_data( Returns: pd.DataFrame: The specified dataset in this setting. """ - assert dataset in { - "train", - "challenge", - }, "Only 'train' and 'challenge' collection is supported." 
+ assert dataset in {AttackDataset.TRAIN, AttackDataset.CHALLENGE}, ( + "Only 'train' and 'challenge' collections are supported." + ) # `data_id` is the folder numbering of each training or challenge dataset, # and is defined with the provided config. data_id = expand_ranges(data_processing_config.folder_ranges[split_folder]) @@ -85,9 +91,9 @@ def collect_midst_attack_data( # Get file name based on the kind of dataset to be collected (i.e. train vs challenge). # TODO: Make the below parsing a bit more robust and less brittle generation_name = attack_type.value.split("_")[0] - if dataset == "challenge": + if dataset == AttackDataset.CHALLENGE: file_name = data_processing_config.challenge_data_file_name - else: + else: # dataset == AttackDataset.TRAIN # Multi-table attacks have different file names. file_name = ( data_processing_config.multi_table_train_data_file_name @@ -110,7 +116,7 @@ def collect_midst_data( midst_data_input_dir: Path, attack_types: list[AttackType], split_folders: list[str], - dataset: DatasetType, + dataset: AttackDataset, data_processing_config: DictConfig, ) -> pd.DataFrame: """ @@ -133,7 +139,6 @@ def collect_midst_data( Returns: Collected train or challenge data as a dataframe. """ - assert dataset in {"train", "challenge"}, "Only 'train' and 'challenge' collection is supported." population = [] for attack_type in attack_types: for split_folder in split_folders: @@ -204,7 +209,7 @@ def collect_population_data_ensemble( midst_data_input_dir, population_attack_types, split_folders=population_splits, - dataset="train", + dataset=AttackDataset.TRAIN, data_processing_config=data_processing_config, ) @@ -221,7 +226,8 @@ def collect_population_data_ensemble( ) # Drop ids. 
- df_population_no_id = df_population.drop(columns=["trans_id", "account_id"]) + id_columns = [c for c in df_population.columns if c.endswith("_id")] + df_population_no_id = df_population.drop(columns=id_columns) # Save the population data save_dataframe(df_population, save_dir, "population_all.csv") save_dataframe(df_population_no_id, save_dir, "population_all_no_id.csv") @@ -233,7 +239,7 @@ def collect_population_data_ensemble( midst_data_input_dir, attack_types=challenge_attack_types, split_folders=challenge_splits, - dataset="challenge", + dataset=AttackDataset.CHALLENGE, data_processing_config=data_processing_config, ) log(INFO, f"Collected challenge data length: {len(df_challenge)} from splits: {challenge_splits}") @@ -242,23 +248,23 @@ def collect_population_data_ensemble( # Population data without the challenge points df_population_no_challenge = df_population[~df_population["trans_id"].isin(df_challenge["trans_id"])] - save_dataframe(df_population_no_challenge, save_dir, "population_all_no_challenge.csv") + save_dataframe(df_population_no_challenge, save_dir, COLLECTED_DATA_NO_CHALLENGE_FILE_NAME) # Remove ids df_population_no_challenge_no_id = df_population_no_challenge.drop(columns=["trans_id", "account_id"]) save_dataframe( df_population_no_challenge_no_id, save_dir, - "population_all_no_challenge_no_id.csv", + f"{Path(COLLECTED_DATA_NO_CHALLENGE_FILE_NAME).stem}_no_id.csv", ) # Population data with all the challenge points df_population_with_challenge = pd.concat([df_population_no_challenge, df_challenge]) - save_dataframe(df_population_with_challenge, save_dir, "population_all_with_challenge.csv") + save_dataframe(df_population_with_challenge, save_dir, COLLECTED_DATA_FILE_NAME) # Remove ids df_population_with_challenge_no_id = df_population_with_challenge.drop(columns=["trans_id", "account_id"]) save_dataframe( df_population_with_challenge_no_id, save_dir, - "population_all_with_challenge_no_id.csv", + 
f"{Path(COLLECTED_DATA_FILE_NAME).stem}_no_id.csv", ) return df_population_with_challenge diff --git a/examples/ensemble_attack/run_attack.py b/examples/ensemble_attack/run_attack.py index 4e67fa50..e252be45 100644 --- a/examples/ensemble_attack/run_attack.py +++ b/examples/ensemble_attack/run_attack.py @@ -11,7 +11,7 @@ import examples.ensemble_attack.run_metaclassifier_training as meta_pipeline import examples.ensemble_attack.run_shadow_model_training as shadow_pipeline -from examples.ensemble_attack.real_data_collection import collect_population_data_ensemble +from examples.ensemble_attack.real_data_collection import COLLECTED_DATA_FILE_NAME, collect_population_data_ensemble from midst_toolkit.attacks.ensemble.data_utils import load_dataframe from midst_toolkit.attacks.ensemble.process_split_data import process_split_data from midst_toolkit.common.logger import log @@ -33,7 +33,7 @@ def run_data_processing(config: DictConfig) -> None: # is not enough. original_population_data = load_dataframe( Path(config.data_processing_config.original_population_data_path), - "population_all_with_challenge.csv", + COLLECTED_DATA_FILE_NAME, ) log(INFO, "Running data processing pipeline...") # Collect the real data from the MIDST challenge resources. 
diff --git a/examples/ensemble_attack/run_metaclassifier_training.py b/examples/ensemble_attack/run_metaclassifier_training.py index 47cfdd32..dd79033a 100644 --- a/examples/ensemble_attack/run_metaclassifier_training.py +++ b/examples/ensemble_attack/run_metaclassifier_training.py @@ -6,6 +6,7 @@ import pandas as pd from omegaconf import DictConfig +from examples.ensemble_attack.real_data_collection import COLLECTED_DATA_FILE_NAME from midst_toolkit.attacks.ensemble.blending import BlendingPlusPlus, MetaClassifierType from midst_toolkit.attacks.ensemble.data_utils import load_dataframe from midst_toolkit.common.logger import log @@ -80,9 +81,10 @@ def run_metaclassifier_training( assert target_synthetic_data is not None, "Target model's synthetic data is missing." target_synthetic_data = target_synthetic_data.copy() + data_file_name = config.data_file_name if "data_file_name" in config else COLLECTED_DATA_FILE_NAME df_reference = load_dataframe( Path(config.data_paths.population_path), - "population_all_with_challenge_no_id.csv", + f"{Path(data_file_name).stem}_no_id.csv", ) log( INFO, @@ -125,7 +127,7 @@ def run_metaclassifier_training( ) model_filename = config.metaclassifier.meta_classifier_model_name - model_path = Path(config.model_paths.metaclassifier_model_path) / f"{model_filename}.pkl" + model_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{model_filename}.pkl" model_path.parent.mkdir(parents=True, exist_ok=True) with open(model_path, "wb") as f: pickle.dump(blending_attacker.trained_model, f) diff --git a/examples/ensemble_attack/run_shadow_model_training.py b/examples/ensemble_attack/run_shadow_model_training.py index ae69de80..e9dde456 100644 --- a/examples/ensemble_attack/run_shadow_model_training.py +++ b/examples/ensemble_attack/run_shadow_model_training.py @@ -1,21 +1,32 @@ import shutil from logging import INFO from pathlib import Path +from typing import cast import pandas as pd from omegaconf import DictConfig +from 
examples.ensemble_attack.real_data_collection import COLLECTED_DATA_FILE_NAME from midst_toolkit.attacks.ensemble.data_utils import load_dataframe from midst_toolkit.attacks.ensemble.rmia.shadow_model_training import ( train_three_sets_of_shadow_models, ) from midst_toolkit.attacks.ensemble.shadow_model_utils import ( - save_additional_tabddpm_config, + ModelType, + TrainingResult, + save_additional_training_config, + train_or_fine_tune_and_synthesize_with_ctgan, train_tabddpm_and_synthesize, ) +from midst_toolkit.common.config import ClavaDDPMTrainingConfig, CTGANTrainingConfig from midst_toolkit.common.logger import log +DEFAULT_TABLE_NAME = "trans" +DEFAULT_ID_COLUMN_NAME = "trans_id" +DEFAULT_MODEL_TYPE = ModelType.TABDDPM + + def run_target_model_training(config: DictConfig) -> Path: """ Function to run the target model training for RMIA attack. @@ -39,11 +50,15 @@ def run_target_model_training(config: DictConfig) -> Path: target_model_output_path = Path(config.shadow_training.target_model_output_path) target_training_json_config_paths = config.shadow_training.training_json_config_paths - # TODO: Add this to config or .json files - table_name = "trans" + table_name = config.table_name if "table_name" in config else DEFAULT_TABLE_NAME target_folder = target_model_output_path / "target_model" + model_type = DEFAULT_MODEL_TYPE + if "model_name" in config.shadow_training: + model_type = ModelType(config.shadow_training.model_name) + log(INFO, f"Training target model with model type: {model_type.value}") + target_folder.mkdir(parents=True, exist_ok=True) shutil.copyfile( target_training_json_config_paths.table_domain_file_path, @@ -53,20 +68,30 @@ def run_target_model_training(config: DictConfig) -> Path: target_training_json_config_paths.dataset_meta_file_path, target_folder / "dataset_meta.json", ) - configs, save_dir = save_additional_tabddpm_config( + configs, save_dir = save_additional_training_config( data_dir=target_folder, - 
training_config_json_path=Path(target_training_json_config_paths.tabddpm_training_config_path), + training_config_json_path=Path(target_training_json_config_paths.training_config_path), final_config_json_path=target_folder / f"{table_name}.json", # Path to the new json experiment_name="trained_target_model", + model_type=model_type, ) - train_result = train_tabddpm_and_synthesize( - train_set=df_real_data, - configs=configs, - save_dir=save_dir, - synthesize=True, - number_of_points_to_synthesize=config.shadow_training.number_of_points_to_synthesize, - ) + train_result: TrainingResult + if model_type == ModelType.TABDDPM: + train_result = train_tabddpm_and_synthesize( + train_set=df_real_data, + configs=cast(ClavaDDPMTrainingConfig, configs), + save_dir=save_dir, + synthesize=True, + number_of_points_to_synthesize=config.shadow_training.number_of_points_to_synthesize, + ) + elif model_type == ModelType.CTGAN: + train_result = train_or_fine_tune_and_synthesize_with_ctgan( + dataset=df_real_data, + configs=cast(CTGANTrainingConfig, configs), + save_dir=save_dir, + synthesize=True, + ) # To train the attack model (metaclassifier), we only need to save target's synthetic data, # and not the entire target model's training result object. @@ -94,19 +119,29 @@ def run_shadow_model_training(config: DictConfig, df_challenge_train: pd.DataFra at src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py. """ log(INFO, "Running shadow model training...") + + table_name = config.table_name if "table_name" in config else DEFAULT_TABLE_NAME + id_column_name = config.table_id_column_name if "table_id_column_name" in config else DEFAULT_ID_COLUMN_NAME + data_file_name = config.data_file_name if "data_file_name" in config else COLLECTED_DATA_FILE_NAME + # Load the required dataframes for shadow model training. # For shadow model training we need master_challenge_train and population data. # Master challenge is the main training (or fine-tuning) data for the shadow models. 
# Population data is used to pre-train some of the shadow models. - df_population_with_challenge = load_dataframe( - Path(config.data_paths.population_path), - "population_all_with_challenge.csv", + df_population_with_challenge = load_dataframe(Path(config.data_paths.population_path), data_file_name) + + model_type = DEFAULT_MODEL_TYPE + if "model_name" in config.shadow_training: + model_type = ModelType(config.shadow_training.model_name) + log(INFO, f"Training shadow models with model type: {model_type.value}") + + # Make sure master challenge train and population data have the id column. + assert id_column_name in df_challenge_train.columns, ( + f"{id_column_name} column should be present in master train data for the shadow model pipeline." ) - # Make sure master challenge train and population data have the "trans_id" column. - assert "trans_id" in df_challenge_train.columns, ( - "trans_id column should be present in master train data for the shadow model pipeline." + assert id_column_name in df_population_with_challenge.columns, ( + f"{id_column_name} column should be present in population data for the shadow model pipeline." ) - assert "trans_id" in df_population_with_challenge.columns # ``population_data`` in ensemble attack is used for shadow pre-training, and # ``master_challenge_df`` is used for fine-tuning for half of the shadow models. # For the other half of the shadow models, only ``master_challenge_df`` is used for training. 
@@ -116,14 +151,15 @@ def run_shadow_model_training(config: DictConfig, df_challenge_train: pd.DataFra shadow_models_output_path=Path(config.shadow_training.shadow_models_output_path), training_json_config_paths=config.shadow_training.training_json_config_paths, fine_tuning_config=config.shadow_training.fine_tuning_config, - table_name="trans", - id_column_name="trans_id", + table_name=table_name, + id_column_name=id_column_name, # Number of shadow models to train in each set of shadow training (3 sets total) results in # ``4 * n_models_per_set`` total shadow models. n_models_per_set=4, # 4 based on the original code, must be even n_reps=12, # Number of repetitions of challenge points in each shadow model training set. `12` based on the original code number_of_points_to_synthesize=config.shadow_training.number_of_points_to_synthesize, random_seed=config.random_seed, + model_type=model_type, ) log( INFO, diff --git a/examples/ensemble_attack/test_attack_model.py b/examples/ensemble_attack/test_attack_model.py index 910189bc..f3558c83 100644 --- a/examples/ensemble_attack/test_attack_model.py +++ b/examples/ensemble_attack/test_attack_model.py @@ -12,7 +12,12 @@ import pandas as pd from omegaconf import DictConfig -from examples.ensemble_attack.real_data_collection import AttackType, collect_midst_data +from examples.ensemble_attack.real_data_collection import ( + COLLECTED_DATA_FILE_NAME, + AttackDataset, + AttackType, + collect_midst_data, +) from examples.ensemble_attack.run_shadow_model_training import run_shadow_model_training from midst_toolkit.attacks.ensemble.blending import BlendingPlusPlus, MetaClassifierType from midst_toolkit.attacks.ensemble.data_utils import load_dataframe @@ -182,7 +187,7 @@ def collect_challenge_and_train_data( attack_types=challenge_attack_types, # For ensemble experiments, change to ``test`` for 10k, and change to ``final`` for 20k split_folders=["final"], - dataset="challenge", + dataset=AttackDataset.CHALLENGE, 
data_processing_config=data_processing_config, ) log( @@ -261,11 +266,29 @@ def train_rmia_shadows_for_test_phase(config: DictConfig) -> list[dict[str, list A list containing three dictionaries, each representing a collection of shadow models with their training data IDs and generated synthetic outputs. """ - df_challenge_experiment, df_master_train = collect_challenge_and_train_data( - config.data_processing_config, - processed_attack_data_path=Path(config.data_paths.processed_attack_data_path), - targets_data_path=Path(config.data_processing_config.midst_data_path), - ) + # Checking if challenge data exists + processed_attack_data_path = Path(config.data_paths.processed_attack_data_path) + data_file_name = config.data_file_name if "data_file_name" in config else COLLECTED_DATA_FILE_NAME + challenge_data_file_name = f"{Path(data_file_name).stem}_challenge_data.csv" + + if (processed_attack_data_path / challenge_data_file_name).exists(): + log(INFO, "Skipping data collection for testing phase.") + df_challenge_experiment = load_dataframe( + processed_attack_data_path, + challenge_data_file_name, + ) + df_master_train = load_dataframe( + processed_attack_data_path, + "master_challenge_train.csv", + ) + else: + # If challenge data does not exist, collect it from the cluster + df_challenge_experiment, df_master_train = collect_challenge_and_train_data( + config.data_processing_config, + processed_attack_data_path=Path(config.data_paths.processed_attack_data_path), + targets_data_path=Path(config.data_processing_config.midst_data_path), + ) + # Load the challenge dataframe for training RMIA shadow models. rmia_training_choice = RmiaTrainingDataChoice(config.target_model.attack_rmia_shadow_training_data_choice) df_challenge = select_challenge_data_for_training(rmia_training_choice, df_challenge_experiment, df_master_train) @@ -292,7 +315,10 @@ def run_metaclassifier_testing( Args: config: Configuration object set in ``experiments_config.yaml``. 
""" - log(INFO, f"Running metaclassifier testing on target model {config.target_model.target_model_id}...") + log( + INFO, + f"Running metaclassifier testing on target synthetic data at {config.target_model.target_synthetic_data_path}...", + ) if config.random_seed is not None: set_all_random_seeds(seed=config.random_seed) @@ -302,7 +328,7 @@ def run_metaclassifier_testing( meta_classifier_type = MetaClassifierType(config.metaclassifier.model_type) metaclassifier_model_name = config.metaclassifier.meta_classifier_model_name - mataclassifier_path = Path(config.model_paths.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl" + mataclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl" assert mataclassifier_path.exists(), ( f"No metaclassifier model found at {mataclassifier_path}. Make sure to run the training script first." ) @@ -321,7 +347,13 @@ def run_metaclassifier_testing( test_data = pd.read_csv(challenge_data_path) log(INFO, f"Challenge data loaded from {challenge_data_path} with a size of {len(test_data)}.") - test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze() + if challenge_label_path.suffix == ".npy": + test_target = np.load(challenge_label_path).squeeze() + elif challenge_label_path.suffix == ".csv": + test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze() + else: + raise ValueError(f"Unsupported challenge label file type: {challenge_label_path}. Must be .npy or .csv.") + assert len(test_data) == len(test_target), "Number of challenge labels must match number of challenge data points." target_synthetic_path = Path(config.target_model.target_synthetic_data_path) @@ -378,10 +410,8 @@ def run_metaclassifier_testing( # 5) Get predictions on the challenge data (test set). # Load the reference population data for DOMIAS signals. 
- df_reference = load_dataframe( - Path(config.data_paths.population_path), - "population_all_with_challenge_no_id.csv", - ) + data_file_name = config.data_file_name if "data_file_name" in config else COLLECTED_DATA_FILE_NAME + df_reference = load_dataframe(Path(config.data_paths.population_path), f"{Path(data_file_name).stem}_no_id.csv") probabilities, pred_score = blending_attacker.predict( df_test=test_data, diff --git a/examples/gan/README.md b/examples/gan/README.md index fd7fc530..9764e423 100644 --- a/examples/gan/README.md +++ b/examples/gan/README.md @@ -8,7 +8,7 @@ some data afterwards. ## Downloading data First, we need the data. Download it from this -[Google Drive link](https://drive.google.com/file/d/1J5qDuMHHg4dm9c3ISmb41tcTHSu1SVUC/view?usp=drive_link), +[Google Drive link](https://drive.google.com/file/d/1YbDRVn-fwfdcPnHj5eMhCa6A-YPiGnKr/view?usp=sharing), extract the files and place them in a `/data` folder in within this folder (`examples/gan`). diff --git a/examples/gan/ensemble_attack/README.md b/examples/gan/ensemble_attack/README.md new file mode 100644 index 00000000..8b3295c2 --- /dev/null +++ b/examples/gan/ensemble_attack/README.md @@ -0,0 +1,97 @@ +# CTGAN Ensemble Attack Example + +On this example, we demonstrate how to run the [Ensemble Attack](../../ensemble_attack/README.md) +using the [CTGAN](https://arxiv.org/pdf/1907.00503) model. + +## 1. Downloading data + +First, we need the data. Download it from this +[Google Drive link](https://drive.google.com/file/d/1B9z4vh51mH6ZMj5E0pJitqR8lid3EJKM/view?usp=drive_link), +extract the files and place them in a `/data/ensemble_attack` folder within this folder +(`examples/gan`). + +> [!NOTE] +> If you wish to change the data folder, you can do so by editing the `base_data_dir` attribute +> of the [`config.yaml`](config.yaml) file. + +Here is a description of the files that have been extracted: +- `trans.csv`: The full set of training data. 
+- `dataset_meta.json`: Metadata about the relationship between the tables in the dataset. Since this is a +single table dataset, it will only contain information about the transaction (`trans`) table. +- `trans_domain.json`: Metadata about the columns of the transaction table, such as their size +and type (`continuous` or `discrete`). +- `data_types.json`: Additional metadata about the columns, splitting them into 4 types: + - `numerical`: a list of the columns that contain numerical information + - `categorical`: a list of the columns that contain categorical information + - `variable_to_predict`: the name of the target column that will be predicted + - `id_column_name`: the name of the column in the table that represents the rows' id. + +With the data present in the correct folder, we can proceed with running the attack. + +## 2. Generating target synthetic data to be tested + +The **target model** is the model being attacked, and the **target synthetic data** +is the synthetic data generated by the target model that will be evaluated against +the attack. + +If you already have a set of synthetic data produced by a target model, +you can add its path to the `ensemble_attack.target_model.target_synthetic_data_path` +property in the [`config.yaml`](config.yaml) file and skip this step. + +If you wish to train a new target model and produce the synthetic data that will be the +target of the attack, you can run: + +```bash +python -m examples.gan.synthesize --config-path=./ensemble_attack +``` + +## 3. Producing the challenge points dataset + +The challenge points dataset is composed of real data points where half of them +were used in training the target model and half weren't. It is the dataset we are going +to use to evaluate how good the attack model is in differentiating between +the points used in training and the ones not used in training. 
+ +To produce such dataset, run the following script: + +```bash +python -m examples.gan.ensemble_attack.make_challenge_dataset +``` + +## 4. Training the attack model + +> [!NOTE] +> In the [`config.yaml`](config.yaml) file, the attribute `ensemble_attack.shadow_training.model_name` +> is what determines this attack will be run with the CTGAN model. + +To train the attack models, execute the following command: + +```bash +python -m examples.gan.ensemble_attack.train_attack_model +``` + +This will take a long time to run, so it might be a good idea to execute it as a +background process. If you want to have a quick test run before kicking off the +full process, you can change the number of iterations, epochs, population and +sample sizes to smaller numbers. + +## 5. Testing the attack model + +To test the attack model against the target model and synthetic data produced on +[step 2](#2-generating-target-synthetic-data-to-be-tested), please run: + +```bash +python -m examples.gan.ensemble_attack.test_attack_model +``` + +## 6. 
Compute the attack success
+
+To compute the metrics about the success of the attack against the target
+synthetic data, you can run the following command:
+
+```bash
+python -m examples.gan.ensemble_attack.compute_attack_success
+```
+
+The results will be both printed on the console and saved in the file
+`examples/gan/results/attack_success_for_xgb_metaclassifier_model.txt`
diff --git a/examples/gan/ensemble_attack/compute_attack_success.py b/examples/gan/ensemble_attack/compute_attack_success.py
new file mode 100644
index 00000000..c8c09f99
--- /dev/null
+++ b/examples/gan/ensemble_attack/compute_attack_success.py
@@ -0,0 +1,30 @@
+from logging import INFO
+from pathlib import Path
+
+import hydra
+from omegaconf import DictConfig
+
+from examples.ensemble_attack.compute_attack_success import compute_attack_success_for_given_targets
+from midst_toolkit.common.logger import log
+
+
+@hydra.main(config_path="./", config_name="config", version_base=None)
+def compute_attack_success(config: DictConfig) -> None:
+    """Main function to compute the attack success."""
+    log(
+        INFO,
+        f"Computing attack success for target synthetic data at {config.ensemble_attack.target_model.target_synthetic_data_path}...",
+    )
+
+    compute_attack_success_for_given_targets(
+        target_model_config=config.ensemble_attack.target_model,
+        # TODO: refactor this to work better outside of the challenge context (i.e. no target ID)
+        # No target ID needed for CTGAN, but it needs at least one element in this array. The value does not matter.
+ target_ids=[0], + experiment_directory=Path(config.results_dir), + metaclassifier_model_name=config.ensemble_attack.metaclassifier.meta_classifier_model_name, + ) + + +if __name__ == "__main__": + compute_attack_success() diff --git a/examples/gan/ensemble_attack/config.yaml b/examples/gan/ensemble_attack/config.yaml new file mode 100644 index 00000000..a383a9f9 --- /dev/null +++ b/examples/gan/ensemble_attack/config.yaml @@ -0,0 +1,87 @@ +# Training example configuration +# Base data directory (can be overridden from command line) +base_data_dir: examples/gan/data/ensemble_attack +results_dir: examples/gan/results +data_name: trans +data_file_name: ${data_name}.csv + +training: + epochs: 300 + verbose: True + data_path: ${base_data_dir}/population_all_with_challenge.csv + sample_size: 20000 + +synthesizing: + sample_size: 20000 + +ensemble_attack: + random_seed: null # Set this to a value if you want to set a random seed for reproducibility + table_name: "trans" + table_id_column_name: "trans_id" + data_file_name: ${data_file_name} + + data_paths: + processed_attack_data_path: ${base_data_dir} + population_path: ${base_data_dir} # This is the population data that the attacker has collected or has access to. + attack_evaluation_result_path: ${results_dir}/evaluation_results # Path where the attack evaluation results will be stored + + data_processing_config: + column_to_stratify: "trans_type" # Attention: This value is not documented in the original codebase. + population_sample_size: 40000 # Population size is the total data that your attack has access to. 
+ + pipeline: + # TODO: properly test these + run_data_processing: true # Set this to false if you have already saved the processed data + run_shadow_model_training: true # Set this to false if shadow models are already trained and saved + run_metaclassifier_training: true + + shadow_training: + model_name: ctgan + model_config: # Configurations specific for the CTGAN model + training: + epochs: 300 + verbose: True + synthesizing: + sample_size: 20000 + shadow_models_output_path: ${results_dir}/ensemble_attack/shadow_models + target_model_output_path: ${results_dir}/shadow_target_model_and_data + training_json_config_paths: # Config json files used for tabddpm training on the trans table + table_domain_file_path: ${base_data_dir}/trans_domain.json + dataset_meta_file_path: ${base_data_dir}/dataset_meta.json + training_config_path: ${base_data_dir}/trans.json # if this is not present, it will be created by copying the example config + fine_tuning_config: + fine_tune_diffusion_iterations: 200000 + fine_tune_classifier_iterations: 20000 + pre_train_data_size: 60000 + number_of_points_to_synthesize: 20000 # Number of synthetic data samples to be generated by shadow models. + + final_shadow_models_path: [ + "${ensemble_attack.shadow_training.shadow_models_output_path}/initial_model_rmia_1/shadow_workspace/pre_trained_model/rmia_shadows.pkl", + "${ensemble_attack.shadow_training.shadow_models_output_path}/initial_model_rmia_2/shadow_workspace/pre_trained_model/rmia_shadows.pkl", + "${ensemble_attack.shadow_training.shadow_models_output_path}/shadow_model_rmia_third_set/shadow_workspace/trained_model/rmia_shadows_third_set.pkl", + ] + target_synthetic_data_path: ${ensemble_attack.shadow_training.target_model_output_path}/target_synthetic_data.csv + + # Metaclassifier settings + metaclassifier: + # Data types json file is used for xgboost model training. 
+ data_types_file_path: ${base_data_dir}/data_types.json + model_type: "xgb" + # Model training parameters + num_optuna_trials: 100 # Original code: 100 + num_kfolds: 5 + use_gpu: false + # Temporary. Might remove having an epoch parameter. + epochs: 1 + meta_classifier_model_name: ${ensemble_attack.metaclassifier.model_type}_metaclassifier_model + metaclassifier_model_path: ${results_dir}/trained_models # Path where the trained metaclassifier model will be saved + + target_model: # This is only used for testing the attack on a real target model. + target_synthetic_data_path: ${results_dir}/${data_name}_synthetic.csv + challenge_data_path: ${ensemble_attack.data_paths.processed_attack_data_path}/${data_name}_challenge_data.csv + challenge_label_path: ${ensemble_attack.data_paths.processed_attack_data_path}/${data_name}_challenge_labels.npy + + target_shadow_models_output_path: ${results_dir}/test_all_targets # Sub-directory to store test shadows and results + attack_probabilities_result_path: ${results_dir}/test_probabilities + attack_rmia_shadow_training_data_choice: "combined" # Options: "combined", "only_challenge", "only_train". This determines which data to use for training RMIA attack model in testing phase. + # See select_challenge_data_for_training()'s docstring for more details. 
diff --git a/examples/gan/ensemble_attack/make_challenge_dataset.py b/examples/gan/ensemble_attack/make_challenge_dataset.py new file mode 100644 index 00000000..1ab30ea9 --- /dev/null +++ b/examples/gan/ensemble_attack/make_challenge_dataset.py @@ -0,0 +1,44 @@ +from logging import INFO +from pathlib import Path + +import hydra +import numpy as np +import pandas as pd +from omegaconf import DictConfig + +from examples.gan.utils import get_table_name +from midst_toolkit.common.logger import log + + +@hydra.main(config_path="./", config_name="config", version_base=None) +def make_challenge_dataset(config: DictConfig) -> None: + """Main function to make the challenge dataset.""" + log(INFO, "Making challenge dataset...") + + if config.training.data_path is None: + dataset_name = get_table_name(config.base_data_dir) + real_data = pd.read_csv(Path(config.base_data_dir) / f"{dataset_name}.csv") + else: + dataset_name = Path(config.training.data_path).stem + real_data = pd.read_csv(config.training.data_path) + + training_data = pd.read_csv(Path(config.results_dir) / f"{dataset_name}_sampled.csv") + id_column = config.ensemble_attack.table_id_column_name + untrained_data = real_data[~real_data[id_column].isin(training_data[id_column])].sample(len(training_data)) + + challenge_data = pd.concat([training_data, untrained_data]) + challenge_data_labels = np.concatenate([np.ones(len(training_data)), np.zeros(len(untrained_data))]) + + processed_attack_data_path = Path(config.ensemble_attack.data_paths.processed_attack_data_path) + processed_attack_data_path.mkdir(parents=True, exist_ok=True) + + challenge_data_path = processed_attack_data_path / f"{dataset_name}_challenge_data.csv" + challenge_label_path = processed_attack_data_path / f"{dataset_name}_challenge_labels.npy" + log(INFO, f"Saving challenge data to {challenge_data_path}") + challenge_data.to_csv(challenge_data_path, index=False) + log(INFO, f"Saving challenge labels to {challenge_label_path}") + 
np.save(challenge_label_path, challenge_data_labels) + + +if __name__ == "__main__": + make_challenge_dataset() diff --git a/examples/gan/ensemble_attack/test_attack_model.py b/examples/gan/ensemble_attack/test_attack_model.py new file mode 100644 index 00000000..d684402a --- /dev/null +++ b/examples/gan/ensemble_attack/test_attack_model.py @@ -0,0 +1,21 @@ +from logging import INFO + +import hydra +from omegaconf import DictConfig + +from examples.ensemble_attack.test_attack_model import run_metaclassifier_testing +from midst_toolkit.common.logger import log + + +@hydra.main(config_path="./", config_name="config", version_base=None) +def attack_model_test(config: DictConfig) -> None: + """Main function to test the attack model.""" + log( + INFO, + f"Testing attack model against synthetic data at {config.ensemble_attack.target_model.target_synthetic_data_path}...", + ) + run_metaclassifier_testing(config.ensemble_attack) + + +if __name__ == "__main__": + attack_model_test() diff --git a/examples/gan/ensemble_attack/train_attack_model.py b/examples/gan/ensemble_attack/train_attack_model.py new file mode 100644 index 00000000..84f2b5af --- /dev/null +++ b/examples/gan/ensemble_attack/train_attack_model.py @@ -0,0 +1,105 @@ +import json +from logging import INFO +from pathlib import Path + +import hydra +from omegaconf import DictConfig, OmegaConf + +from examples.ensemble_attack.run_metaclassifier_training import run_metaclassifier_training +from examples.ensemble_attack.run_shadow_model_training import run_shadow_model_training, run_target_model_training +from midst_toolkit.attacks.ensemble.data_utils import load_dataframe, save_dataframe +from midst_toolkit.attacks.ensemble.process_split_data import process_split_data +from midst_toolkit.common.logger import log +from midst_toolkit.common.random import set_all_random_seeds + + +@hydra.main(config_path="./", config_name="config", version_base=None) +def train_attack_model(config: DictConfig) -> None: + """ + Train 
the Ensemble Attack pipeline with CTGAN model. + + As the first step, data processing is done. + Second step is shadow model training used for RMIA attack. + Third step is metaclassifier training and evaluation. + + Args: + config: Attack configuration as an OmegaConf DictConfig object. + """ + if config.ensemble_attack.random_seed is not None: + set_all_random_seeds(seed=config.ensemble_attack.random_seed) + log(INFO, f"Training phase random seed set to {config.ensemble_attack.random_seed}.") + + if config.ensemble_attack.pipeline.run_data_processing: + log(INFO, "Running data processing pipeline...") + # The following function saves the required dataframe splits in the specified processed_attack_data_path path. + population_data = load_dataframe( + Path(config.ensemble_attack.data_paths.population_path), + config.data_file_name, + ) + + # Removing id columns and saving the dataset + id_columns = [c for c in population_data.columns if c.endswith("_id")] + population_data_no_id = population_data.drop(columns=id_columns) + save_dataframe( + population_data_no_id, + Path(config.ensemble_attack.data_paths.population_path), + f"{Path(config.data_file_name).stem}_no_id.csv", + ) + + process_split_data( + all_population_data=population_data, + processed_attack_data_path=Path(config.ensemble_attack.data_paths.processed_attack_data_path), + # TODO: column_to_stratify value is not documented in the original codebase. 
+ column_to_stratify=config.ensemble_attack.data_processing_config.column_to_stratify, + num_total_samples=config.ensemble_attack.data_processing_config.population_sample_size, + random_seed=config.ensemble_attack.random_seed, + ) + + # Saving the model config from the config.yaml into a json file + # because that's what the ensemble attack code will be looking for + training_config_path = Path(config.ensemble_attack.shadow_training.training_json_config_paths.training_config_path) + training_config_path.unlink(missing_ok=True) + with open(training_config_path, "w") as f: + training_config = OmegaConf.to_container(config.ensemble_attack.shadow_training.model_config) + assert isinstance(training_config, dict), "Training config must be a dictionary." + training_config["general"] = { + "test_data_dir": config.base_data_dir, + "sample_prefix": "ctgan", + # The values below will be overriden + "exp_name": "", + "data_dir": "", + "workspace_dir": "", + } + json.dump(training_config, f) + + if config.ensemble_attack.pipeline.run_shadow_model_training: + log(INFO, "Training the shadow models...") + master_challenge_train = load_dataframe( + Path(config.ensemble_attack.data_paths.population_path), + "master_challenge_train.csv", + ) + shadow_data_paths = run_shadow_model_training(config.ensemble_attack, master_challenge_train) + shadow_data_paths = [Path(path) for path in shadow_data_paths] + + log(INFO, "Training the target model...") + target_model_synthetic_path = run_target_model_training(config.ensemble_attack) + + if config.ensemble_attack.pipeline.run_metaclassifier_training: + log(INFO, "Training the metaclassifier...") + if not config.ensemble_attack.pipeline.run_shadow_model_training: + # If shadow model training is skipped, we need to provide the previous shadow model and target model paths. 
+ shadow_data_paths = [ + Path(path) for path in config.ensemble_attack.shadow_training.final_shadow_models_path + ] + target_model_synthetic_path = Path(config.ensemble_attack.shadow_training.target_synthetic_data_path) + + assert len(shadow_data_paths) == 3, "The attack_data_paths list must contain exactly three elements." + assert target_model_synthetic_path is not None, ( + "The target_data_path must be provided for metaclassifier training." + ) + + run_metaclassifier_training(config.ensemble_attack, shadow_data_paths, target_model_synthetic_path) + + +if __name__ == "__main__": + train_attack_model() diff --git a/examples/gan/synthesize.py b/examples/gan/synthesize.py index 413b65e0..53a7b761 100644 --- a/examples/gan/synthesize.py +++ b/examples/gan/synthesize.py @@ -34,8 +34,12 @@ def main(config: DictConfig) -> None: log(INFO, f"Synthesizing data of size {config.synthesizing.sample_size}...") synthetic_data = ctgan.sample(num_rows=config.synthesizing.sample_size) - table_name = get_table_name(config.base_data_dir) - synthetic_data_file = Path(config.results_dir) / f"{table_name}_synthetic.csv" + if config.training.data_path is not None: + dataset_name = Path(config.training.data_path).stem + else: + dataset_name = get_table_name(config.base_data_dir) + + synthetic_data_file = Path(config.results_dir) / f"{dataset_name}_synthetic.csv" log(INFO, f"Saving synthetic data to {synthetic_data_file}...") synthetic_data.to_csv(synthetic_data_file, index=False) diff --git a/examples/gan/train.py b/examples/gan/train.py index 40bbbe02..379c2a7f 100644 --- a/examples/gan/train.py +++ b/examples/gan/train.py @@ -22,15 +22,27 @@ def main(config: DictConfig) -> None: Args: config: Configuration as an OmegaConf DictConfig object. 
""" - log(INFO, "Loading data...") - table_name = get_table_name(config.base_data_dir) + if config.training.data_path is None: + log(INFO, "Loading data with table name...") + dataset_name = table_name + real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv") + + else: + log(INFO, f"Loading data from {config.training.data_path}...") + dataset_name = Path(config.training.data_path).stem + real_data = pd.read_csv(config.training.data_path) + + if config.training.sample_size is not None: + log(INFO, f"Sampling {config.training.sample_size} rows from data...") + real_data = real_data.sample(n=config.training.sample_size) + Path(config.results_dir).mkdir(parents=True, exist_ok=True) + real_data.to_csv(Path(config.results_dir) / f"{dataset_name}_sampled.csv", index=False) + with open(Path(config.base_data_dir) / f"{table_name}_domain.json", "r") as f: domain_info = json.load(f) - real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv") - metadata, real_data_without_ids = get_single_table_svd_metadata(real_data, domain_info) log(INFO, "Fitting CTGAN...") diff --git a/examples/synthesizing/multi_table/README.md b/examples/synthesizing/multi_table/README.md index 737b49ec..0396e010 100644 --- a/examples/synthesizing/multi_table/README.md +++ b/examples/synthesizing/multi_table/README.md @@ -7,7 +7,7 @@ up using the code in this toolkit. ## Downloading data First, we need the data. Download it from this -[Google Drive link](https://drive.google.com/file/d/1Ao222l4AJjG54-HDEGCWkIfzRbl9_IKa/view?usp=drive_link), +[Google Drive link](https://drive.google.com/file/d/1x2yXw824sMUJb9WKUoTkcyfPfx3zS7We/view?usp=sharing), extract the files and place them in a `/data` folder within this folder (`examples/synthesizing/multi_table`). 
diff --git a/examples/synthesizing/multi_table/run_synthesizing.py b/examples/synthesizing/multi_table/run_synthesizing.py index 9d845e2e..ba6916f4 100644 --- a/examples/synthesizing/multi_table/run_synthesizing.py +++ b/examples/synthesizing/multi_table/run_synthesizing.py @@ -7,7 +7,7 @@ from omegaconf import DictConfig from examples.training.multi_table import run_training -from midst_toolkit.common.config import GeneralConfig, MatchingConfig, SamplingConfig +from midst_toolkit.common.config import ClavaDDPMMatchingConfig, ClavaDDPMSamplingConfig, GeneralConfig from midst_toolkit.common.logger import TOOLKIT_LOGGER, log from midst_toolkit.models.clavaddpm.data_loaders import load_tables from midst_toolkit.models.clavaddpm.enumerations import Relation @@ -76,8 +76,8 @@ def main(config: DictConfig) -> None: Path(config.results_dir), models, GeneralConfig(**config.general_config), - SamplingConfig(**config.sampling_config), - MatchingConfig(**config.matching_config), + ClavaDDPMSamplingConfig(**config.sampling_config), + ClavaDDPMMatchingConfig(**config.matching_config), all_group_lengths_prob_dicts, ) diff --git a/examples/synthesizing/single_table/README.md b/examples/synthesizing/single_table/README.md index 5f6f1f51..924b024d 100644 --- a/examples/synthesizing/single_table/README.md +++ b/examples/synthesizing/single_table/README.md @@ -7,7 +7,7 @@ up using the code in this toolkit. ## Downloading data First, we need the data. Download it from this -[Google Drive link](https://drive.google.com/file/d/1J5qDuMHHg4dm9c3ISmb41tcTHSu1SVUC/view?usp=drive_link), +[Google Drive link](https://drive.google.com/file/d/1YbDRVn-fwfdcPnHj5eMhCa6A-YPiGnKr/view?usp=sharing), extract the files and place them in a `/data` folder within this folder (`examples/synthesizing/single_table`). 
diff --git a/examples/synthesizing/single_table/run_synthesizing.py b/examples/synthesizing/single_table/run_synthesizing.py index b9f6a649..fd5341f3 100644 --- a/examples/synthesizing/single_table/run_synthesizing.py +++ b/examples/synthesizing/single_table/run_synthesizing.py @@ -7,7 +7,7 @@ from omegaconf import DictConfig from examples.training.single_table import run_training -from midst_toolkit.common.config import GeneralConfig, MatchingConfig, SamplingConfig +from midst_toolkit.common.config import ClavaDDPMMatchingConfig, ClavaDDPMSamplingConfig, GeneralConfig from midst_toolkit.common.logger import TOOLKIT_LOGGER, log from midst_toolkit.models.clavaddpm.data_loaders import load_tables from midst_toolkit.models.clavaddpm.enumerations import Relation @@ -73,8 +73,8 @@ def main(config: DictConfig) -> None: Path(config.results_dir), models, GeneralConfig(**config.general_config), - SamplingConfig(**config.sampling_config), - MatchingConfig(**config.matching_config), + ClavaDDPMSamplingConfig(**config.sampling_config), + ClavaDDPMMatchingConfig(**config.matching_config), ) log(INFO, "Data synthesized successfully.") diff --git a/examples/training/multi_table/README.md b/examples/training/multi_table/README.md index 31791112..fa374d50 100644 --- a/examples/training/multi_table/README.md +++ b/examples/training/multi_table/README.md @@ -7,7 +7,7 @@ code in this toolkit. ## Downloading data First, we need the data. Download it from this -[Google Drive link](https://drive.google.com/file/d/1Ao222l4AJjG54-HDEGCWkIfzRbl9_IKa/view?usp=drive_link), +[Google Drive link](https://drive.google.com/file/d/1x2yXw824sMUJb9WKUoTkcyfPfx3zS7We/view?usp=sharing), extract the files and place them in a `/data` folder in within this folder (`examples/training/multi_table`). 
diff --git a/examples/training/multi_table/run_training.py b/examples/training/multi_table/run_training.py index 6d5548a5..095f52f5 100644 --- a/examples/training/multi_table/run_training.py +++ b/examples/training/multi_table/run_training.py @@ -5,12 +5,12 @@ import hydra from omegaconf import DictConfig -from midst_toolkit.common.config import ClassifierConfig, ClusteringConfig, DiffusionConfig +from midst_toolkit.common.config import ClavaDDPMClassifierConfig, ClavaDDPMClusteringConfig, ClavaDDPMDiffusionConfig from midst_toolkit.common.logger import TOOLKIT_LOGGER, log from midst_toolkit.common.variables import DEVICE from midst_toolkit.models.clavaddpm.clustering import clava_clustering from midst_toolkit.models.clavaddpm.data_loaders import Table, load_tables -from midst_toolkit.models.clavaddpm.train import ModelArtifacts, clava_training +from midst_toolkit.models.clavaddpm.train import ClavaDDPMModelArtifacts, clava_training # Preventing some excessive logging @@ -32,12 +32,12 @@ def main(config: DictConfig) -> None: tables, relation_order, _ = load_tables(Path(config.base_data_dir)) log(INFO, "Clustering data...") - clustering_config = ClusteringConfig(**config.clustering_config) + clustering_config = ClavaDDPMClusteringConfig(**config.clustering_config) tables, _ = clava_clustering(tables, relation_order, Path(config.results_dir), clustering_config) log(INFO, "Training model...") - diffusion_config = DiffusionConfig(**config.diffusion_config) - classifier_config = ClassifierConfig(**config.classifier_config) + diffusion_config = ClavaDDPMDiffusionConfig(**config.diffusion_config) + classifier_config = ClavaDDPMClassifierConfig(**config.classifier_config) tables, _ = clava_training( tables, @@ -65,7 +65,7 @@ def main(config: DictConfig) -> None: result = pickle.load(f) # Asserting the results are the correct type - assert isinstance(result, ModelArtifacts) + assert isinstance(result, ClavaDDPMModelArtifacts) log(INFO, f"Result size (in bytes): 
{results_file.stat().st_size}") diff --git a/examples/training/single_table/README.md b/examples/training/single_table/README.md index ac6fa12b..b274f733 100644 --- a/examples/training/single_table/README.md +++ b/examples/training/single_table/README.md @@ -7,7 +7,7 @@ code in this toolkit. ## Downloading data First, we need the data. Download it from this -[Google Drive link](https://drive.google.com/file/d/1J5qDuMHHg4dm9c3ISmb41tcTHSu1SVUC/view?usp=drive_link), +[Google Drive link](https://drive.google.com/file/d/1YbDRVn-fwfdcPnHj5eMhCa6A-YPiGnKr/view?usp=sharing), extract the files and place them in a `/data` folder in within this folder (`examples/training/single_table`). diff --git a/examples/training/single_table/run_training.py b/examples/training/single_table/run_training.py index 74897db7..62e6fdb8 100644 --- a/examples/training/single_table/run_training.py +++ b/examples/training/single_table/run_training.py @@ -5,11 +5,11 @@ import hydra from omegaconf import DictConfig -from midst_toolkit.common.config import DiffusionConfig +from midst_toolkit.common.config import ClavaDDPMDiffusionConfig from midst_toolkit.common.logger import TOOLKIT_LOGGER, log from midst_toolkit.common.variables import DEVICE from midst_toolkit.models.clavaddpm.data_loaders import load_tables -from midst_toolkit.models.clavaddpm.train import ModelArtifacts, clava_training +from midst_toolkit.models.clavaddpm.train import ClavaDDPMModelArtifacts, clava_training # Preventing some excessive logging @@ -31,7 +31,7 @@ def main(config: DictConfig) -> None: tables, relation_order, _ = load_tables(Path(config.base_data_dir)) log(INFO, "Training model...") - diffusion_config = DiffusionConfig(**config.diffusion_config) + diffusion_config = ClavaDDPMDiffusionConfig(**config.diffusion_config) tables, _ = clava_training( tables, @@ -49,7 +49,7 @@ def main(config: DictConfig) -> None: result = pickle.load(f) # Asserting the results are the correct type - assert isinstance(result, 
ModelArtifacts) + assert isinstance(result, ClavaDDPMModelArtifacts) log(INFO, f"Result size (in bytes): {results_file.stat().st_size}") diff --git a/src/midst_toolkit/attacks/ensemble/blending.py b/src/midst_toolkit/attacks/ensemble/blending.py index 24104cc1..87417d37 100644 --- a/src/midst_toolkit/attacks/ensemble/blending.py +++ b/src/midst_toolkit/attacks/ensemble/blending.py @@ -252,7 +252,9 @@ def predict( if y_test is not None: score = TprAtFpr.get_tpr_at_fpr( - true_membership=y_test, predicted_membership=probabilities, fpr_threshold=0.1 + true_membership=y_test, + predicted_membership=probabilities, + fpr_threshold=0.1, ) return probabilities, score diff --git a/src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py b/src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py index f4ec8aa2..ec51cf4b 100644 --- a/src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py +++ b/src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py @@ -12,7 +12,7 @@ import torch from torch import optim -from midst_toolkit.common.config import ClassifierConfig, DiffusionConfig +from midst_toolkit.common.config import ClavaDDPMClassifierConfig, ClavaDDPMDiffusionConfig from midst_toolkit.common.enumerations import DataSplit from midst_toolkit.common.logger import KeyValueLogger, log from midst_toolkit.common.variables import DEVICE @@ -37,7 +37,7 @@ ) from midst_toolkit.models.clavaddpm.sampler import ScheduleSamplerType from midst_toolkit.models.clavaddpm.train import ( - ModelArtifacts, + ClavaDDPMModelArtifacts, _numerical_forward_backward_log, get_table_metadata, ) @@ -56,7 +56,7 @@ def fine_tune_model( weight_decay: float, data_split_ratios: list[float], device: torch.device = DEVICE, -) -> ModelArtifacts: +) -> ClavaDDPMModelArtifacts: """ Fine-tune a trained diffusion model on a new dataset. 
@@ -124,7 +124,7 @@ def fine_tune_model( if dataset.numerical_transform is not None: inverse_transform_function = dataset.numerical_transform.inverse_transform - return ModelArtifacts( + return ClavaDDPMModelArtifacts( diffusion=diffusion, label_encoders=label_encoders, dataset=dataset, @@ -241,17 +241,17 @@ def fine_tune_classifier( def child_fine_tuning( - pre_trained_model: ModelArtifacts, + pre_trained_model: ClavaDDPMModelArtifacts, child_df_with_cluster: pd.DataFrame, child_domain_dict: dict[str, Any], parent_name: str | None, child_name: str, - diffusion_config: DiffusionConfig, - classifier_config: ClassifierConfig | None, + diffusion_config: ClavaDDPMDiffusionConfig, + classifier_config: ClavaDDPMClassifierConfig | None, fine_tuning_diffusion_iterations: int, fine_tuning_classifier_iterations: int, device: torch.device = DEVICE, -) -> ModelArtifacts: +) -> ClavaDDPMModelArtifacts: """ Fine-tune a child model based on the parent model. @@ -340,14 +340,14 @@ def child_fine_tuning( def clava_fine_tuning( - trained_models: dict[Relation, ModelArtifacts], + trained_models: dict[Relation, ClavaDDPMModelArtifacts], new_tables: Tables, relation_order: RelationOrder, - diffusion_config: DiffusionConfig, - classifier_config: ClassifierConfig, + diffusion_config: ClavaDDPMDiffusionConfig, + classifier_config: ClavaDDPMClassifierConfig, fine_tuning_diffusion_iterations: int, fine_tuning_classifier_iterations: int, -) -> dict[Relation, ModelArtifacts]: +) -> dict[Relation, ClavaDDPMModelArtifacts]: """ Fine-tune the trained models on new tables data. 
diff --git a/src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py b/src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py index 0f52d61a..92b69088 100644 --- a/src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py +++ b/src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py @@ -3,16 +3,20 @@ import shutil from logging import INFO from pathlib import Path -from typing import Any +from typing import Any, cast import pandas as pd from omegaconf import DictConfig from midst_toolkit.attacks.ensemble.shadow_model_utils import ( + ModelType, + TrainingResult, fine_tune_tabddpm_and_synthesize, - save_additional_tabddpm_config, + save_additional_training_config, + train_or_fine_tune_and_synthesize_with_ctgan, train_tabddpm_and_synthesize, ) +from midst_toolkit.common.config import ClavaDDPMTrainingConfig, CTGANTrainingConfig from midst_toolkit.common.logger import log @@ -32,6 +36,7 @@ def train_fine_tuned_shadow_models( number_of_points_to_synthesize: int = 20000, init_data_seed: int | None = None, random_seed: int | None = None, + model_type: ModelType = ModelType.TABDDPM, ) -> Path: """ Train ``n_models`` shadow models that start from a pre-trained TabDDPM model and are fine-tuned on @@ -65,7 +70,7 @@ def train_fine_tuned_shadow_models( An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are: - table_domain_file_path (str): Path to the table domain json file. - dataset_meta_file_path (str): Path to dataset meta json file. - - tabddpm_training_config_path (str): Path to table's training config json file. + - training_config_path (str): Path to table's training config json file. fine_tuning_config: Configuration dictionary containing shadow model fine-tuning specific information. init_model_id: An ID to assign to the pre-trained initial models. This can be used to save multiple pre-trained models with different IDs. 
@@ -76,6 +81,7 @@ def train_fine_tuned_shadow_models( defaults to 20,000. init_data_seed: Random seed for the initial training set. random_seed: Random seed used for reproducibility, defaults to None. + model_type: Type of model to be used for training the shadow models. Defaults to ModelType.TABDDPM. Returns: The path where the shadow models and their artifacts are saved. @@ -112,21 +118,39 @@ def train_fine_tuned_shadow_models( ) # Train initial model with 60K data without any challenge points - # ``save_additional_tabddpm_config`` makes a personalized copy of the training config for each - # tabddpm model (here the base model). + # ``save_additional_training_config`` makes a personalized copy of the training config for each + # training model (here the base model). # All the shadow models will be saved under the base model data directory. - configs, save_dir = save_additional_tabddpm_config( + configs, save_dir = save_additional_training_config( data_dir=shadow_model_data_folder, - training_config_json_path=Path(training_json_config_paths.tabddpm_training_config_path), + training_config_json_path=Path(training_json_config_paths.training_config_path), final_config_json_path=shadow_model_data_folder / f"{table_name}.json", # Path to the new json experiment_name="pre_trained_model", + model_type=model_type, ) # Train the initial model if it is not already trained and saved. 
initial_model_path = save_dir / f"initial_model_rmia_{init_model_id}.pkl" if not initial_model_path.exists(): - log(INFO, f"Training initial model with ID {init_model_id}...") - initial_model_training_results = train_tabddpm_and_synthesize(train, configs, save_dir, synthesize=False) + log(INFO, f"Training initial {model_type.value} model with ID {init_model_id}...") + + initial_model_training_results: TrainingResult + if model_type == ModelType.TABDDPM: + initial_model_training_results = train_tabddpm_and_synthesize( + train, + cast(ClavaDDPMTrainingConfig, configs), + save_dir, + synthesize=False, + ) + elif model_type == ModelType.CTGAN: + initial_model_training_results = train_or_fine_tune_and_synthesize_with_ctgan( + train, + cast(CTGANTrainingConfig, configs), + save_dir, + synthesize=False, + ) + else: + raise ValueError(f"Invalid model type: {model_type}") # Save the initial model # Pickle dump the results @@ -169,16 +193,28 @@ def train_fine_tuned_shadow_models( # Shuffle the dataset selected_challenges = selected_challenges.sample(frac=1, random_state=random_seed).reset_index(drop=True) - train_result = fine_tune_tabddpm_and_synthesize( - trained_models=initial_model_training_results.models, - fine_tune_set=selected_challenges, - configs=configs, - save_dir=save_dir, - fine_tuning_diffusion_iterations=fine_tuning_config.fine_tune_diffusion_iterations, - fine_tuning_classifier_iterations=fine_tuning_config.fine_tune_classifier_iterations, - synthesize=True, - number_of_points_to_synthesize=number_of_points_to_synthesize, - ) + if model_type == ModelType.TABDDPM: + train_result = fine_tune_tabddpm_and_synthesize( + trained_models=initial_model_training_results.models, + fine_tune_set=selected_challenges, + configs=cast(ClavaDDPMTrainingConfig, configs), + save_dir=save_dir, + fine_tuning_diffusion_iterations=fine_tuning_config.fine_tune_diffusion_iterations, + fine_tuning_classifier_iterations=fine_tuning_config.fine_tune_classifier_iterations, + 
synthesize=True, + number_of_points_to_synthesize=number_of_points_to_synthesize, + ) + elif model_type == ModelType.CTGAN: + train_result = train_or_fine_tune_and_synthesize_with_ctgan( + dataset=selected_challenges, + configs=cast(CTGANTrainingConfig, configs), + save_dir=save_dir, + synthesize=True, + trained_model=initial_model_training_results.models[(None, table_name)].model, + ) + else: + raise ValueError(f"Invalid model type: {model_type}") + assert train_result.synthetic_data is not None, "Fine-tuned models should generate synthetic data." log( INFO, @@ -204,6 +240,7 @@ def train_shadow_on_half_challenge_data( id_column_name: str, number_of_points_to_synthesize: int = 20000, random_seed: int | None = None, + model_type: ModelType = ModelType.TABDDPM, ) -> Path: """ 1. Create eight training sets with exactly half of the observations included in the challenge lists @@ -223,12 +260,13 @@ def train_shadow_on_half_challenge_data( An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are: - table_domain_file_path (str): Path to the table domain json file. - dataset_meta_file_path (str): Path to dataset meta json file. - - tabddpm_training_config_path (str): Path to table's training config json file. + - training_config_path (str): Path to table's training config json file. table_name: Name of the main table to be used for training the TabDDPM model. id_column_name: Name of the ID column in the data. number_of_points_to_synthesize: Size of the synthetic data to be generated by each shadow model, defaults to 20,000. random_seed: Random seed used for reproducibility, defaults to None. + model_type: Type of model to be used for training the shadow models. Defaults to ModelType.TABDDPM. Returns: The path where the shadow models and their artifacts are saved. 
@@ -259,11 +297,12 @@ def train_shadow_on_half_challenge_data( training_json_config_paths.dataset_meta_file_path, shadow_folder / "dataset_meta.json", ) - configs, save_dir = save_additional_tabddpm_config( + configs, save_dir = save_additional_training_config( data_dir=shadow_folder, - training_config_json_path=Path(training_json_config_paths.tabddpm_training_config_path), + training_config_json_path=Path(training_json_config_paths.training_config_path), final_config_json_path=shadow_folder / f"{table_name}.json", # Path to the new json experiment_name="trained_model", + model_type=model_type, ) attack_data: dict[str, Any] = { "selected_sets": selected_id_lists, @@ -283,13 +322,25 @@ def train_shadow_on_half_challenge_data( # Shuffle the dataset selected_challenges = selected_challenges.sample(frac=1, random_state=random_seed).reset_index(drop=True) - train_result = train_tabddpm_and_synthesize( - selected_challenges, - configs, - save_dir, - synthesize=True, - number_of_points_to_synthesize=number_of_points_to_synthesize, - ) + train_result: TrainingResult + if model_type == ModelType.TABDDPM: + train_result = train_tabddpm_and_synthesize( + selected_challenges, + cast(ClavaDDPMTrainingConfig, configs), + save_dir, + synthesize=True, + number_of_points_to_synthesize=number_of_points_to_synthesize, + ) + elif model_type == ModelType.CTGAN: + train_result = train_or_fine_tune_and_synthesize_with_ctgan( + dataset=selected_challenges, + configs=cast(CTGANTrainingConfig, configs), + save_dir=save_dir, + synthesize=True, + ) + else: + raise ValueError(f"Invalid model type: {model_type}") + assert train_result.synthetic_data is not None, "Trained shadow model did not generate synthetic data." 
log( INFO, @@ -318,6 +369,7 @@ def train_three_sets_of_shadow_models( n_reps: int = 12, number_of_points_to_synthesize: int = 20000, random_seed: int | None = None, + model_type: ModelType = ModelType.TABDDPM, ) -> tuple[Path, Path, Path]: """ Runs the shadow model training pipeline of the ensemble attack. This pipeline trains three sets of shadow models. @@ -354,7 +406,7 @@ def train_three_sets_of_shadow_models( An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are: - table_domain_file_path (str): Path to the table domain json file. - dataset_meta_file_path (str): Path to dataset meta json file. - - tabddpm_training_config_path (str): Path to table's training config json file. + - training_config_path (str): Path to table's training config json file. fine_tuning_config: Configuration dictionary containing shadow model fine-tuning specific information. An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are: - fine_tune_diffusion_iterations (int): Number of diffusion fine-tuning iterations. @@ -367,6 +419,7 @@ def train_three_sets_of_shadow_models( number_of_points_to_synthesize: Size of the synthetic data to be generated by each shadow model, defaults to 20,000. random_seed: Random seed used for reproducibility, defaults to None. + model_type: Type of model to be used for training the shadow models. Defaults to ModelType.TABDDPM. 
Returns: Paths where the shadow models and their artifacts including synthetic data are saved for each of @@ -392,6 +445,7 @@ def train_three_sets_of_shadow_models( number_of_points_to_synthesize=number_of_points_to_synthesize, init_data_seed=random_seed, random_seed=random_seed, + model_type=model_type, ) log( INFO, @@ -416,6 +470,7 @@ def train_three_sets_of_shadow_models( # Setting a different seed for the second train set init_data_seed=random_seed + 1 if random_seed is not None else None, random_seed=random_seed, + model_type=model_type, ) log( INFO, @@ -433,6 +488,7 @@ def train_three_sets_of_shadow_models( id_column_name=id_column_name, number_of_points_to_synthesize=number_of_points_to_synthesize, random_seed=random_seed, + model_type=model_type, ) log( INFO, diff --git a/src/midst_toolkit/attacks/ensemble/shadow_model_utils.py b/src/midst_toolkit/attacks/ensemble/shadow_model_utils.py index f1693c8e..c03af364 100644 --- a/src/midst_toolkit/attacks/ensemble/shadow_model_utils.py +++ b/src/midst_toolkit/attacks/ensemble/shadow_model_utils.py @@ -2,13 +2,17 @@ import json import os from dataclasses import dataclass +from enum import Enum from logging import INFO from pathlib import Path +from typing import Any import pandas as pd +from sdv.single_table import CTGANSynthesizer # type: ignore[import-untyped] +from examples.gan.utils import get_single_table_svd_metadata, get_table_name from midst_toolkit.attacks.ensemble.clavaddpm_fine_tuning import clava_fine_tuning -from midst_toolkit.common.config import TrainingConfig +from midst_toolkit.common.config import ClavaDDPMTrainingConfig, CTGANTrainingConfig, TrainingConfig from midst_toolkit.common.logger import log from midst_toolkit.common.variables import DEVICE from midst_toolkit.models.clavaddpm.clustering import clava_clustering @@ -19,26 +23,48 @@ RelationOrder, ) from midst_toolkit.models.clavaddpm.synthesizer import clava_synthesizing -from midst_toolkit.models.clavaddpm.train import ModelArtifacts, 
clava_training +from midst_toolkit.models.clavaddpm.train import ( + ClavaDDPMModelArtifacts, + CTGANModelArtifacts, + clava_training, +) -@dataclass +class ModelType(Enum): + TABDDPM = "tabddpm" + CTGAN = "ctgan" + + +@dataclass(kw_only=True) # Setting kw_only=True avoids an error with default values and inheritance class TrainingResult: save_dir: Path configs: TrainingConfig + models: Any + synthetic_data: pd.DataFrame | None = None + + +@dataclass +class CTGANTrainingResult(TrainingResult): + configs: CTGANTrainingConfig + models: dict[Relation, CTGANModelArtifacts] + + +@dataclass +class TabDDPMTrainingResult(TrainingResult): + configs: ClavaDDPMTrainingConfig + models: dict[Relation, ClavaDDPMModelArtifacts] tables: Tables relation_order: RelationOrder all_group_lengths_probabilities: GroupLengthsProbDicts - models: dict[Relation, ModelArtifacts] - synthetic_data: pd.DataFrame | None = None -def save_additional_tabddpm_config( +def save_additional_training_config( data_dir: Path, training_config_json_path: Path, final_config_json_path: Path, experiment_name: str = "attack_experiment", workspace_name: str = "shadow_workspace", + model_type: ModelType = ModelType.TABDDPM, ) -> tuple[TrainingConfig, Path]: """ Modifies a TabDDPM configuration JSON file with the specified data directory, experiment name and workspace name, @@ -50,14 +76,21 @@ def save_additional_tabddpm_config( final_config_json_path: Path where the modified configuration JSON file will be saved. experiment_name: Name of the experiment, used to create a unique save directory. workspace_name: Name of the workspace, used to create a unique save directory. + model_type: Type of model to be used for training the shadow models. Defaults to ModelType.TABDDPM. Returns: - configs: Loaded configuration dictionary for TabDDPM. + configs: Loaded configuration dictionary for the model type. save_dir: Directory path where results will be saved. 
""" # Modify the config file to give the correct training data and saving directory with open(training_config_json_path, "r") as file: - configs = TrainingConfig(**json.load(file)) + configs: TrainingConfig + if model_type == ModelType.TABDDPM: + configs = ClavaDDPMTrainingConfig(**json.load(file)) + elif model_type == ModelType.CTGAN: + configs = CTGANTrainingConfig(**json.load(file)) + else: + raise ValueError(f"Invalid model type: {model_type}") configs.general.data_dir = data_dir # Save dir is set by joining the workspace_dir and exp_name @@ -79,11 +112,11 @@ def save_additional_tabddpm_config( # TODO: This and the next function should be unified later. def train_tabddpm_and_synthesize( train_set: pd.DataFrame, - configs: TrainingConfig, + configs: ClavaDDPMTrainingConfig, save_dir: Path, synthesize: bool = True, number_of_points_to_synthesize: int = 20000, -) -> TrainingResult: +) -> TabDDPMTrainingResult: """ Train a TabDDPM model on the provided training set and optionally synthesize data using the trained models. 
@@ -120,7 +153,7 @@ def train_tabddpm_and_synthesize( classifier_config=configs.classifier, device=DEVICE, ) - result = TrainingResult( + result = TabDDPMTrainingResult( save_dir=save_dir, configs=configs, tables=tables, @@ -156,9 +189,9 @@ def train_tabddpm_and_synthesize( def fine_tune_tabddpm_and_synthesize( - trained_models: dict[Relation, ModelArtifacts], + trained_models: dict[Relation, ClavaDDPMModelArtifacts], fine_tune_set: pd.DataFrame, - configs: TrainingConfig, + configs: ClavaDDPMTrainingConfig, save_dir: Path, fine_tuning_diffusion_iterations: int = 100, fine_tuning_classifier_iterations: int = 10, @@ -213,7 +246,7 @@ def fine_tune_tabddpm_and_synthesize( fine_tuning_diffusion_iterations=fine_tuning_diffusion_iterations, fine_tuning_classifier_iterations=fine_tuning_classifier_iterations, ) - result = TrainingResult( + result = TabDDPMTrainingResult( save_dir=save_dir, configs=configs, tables=new_tables, @@ -248,6 +281,74 @@ def fine_tune_tabddpm_and_synthesize( return result +def train_or_fine_tune_and_synthesize_with_ctgan( + dataset: pd.DataFrame, + configs: CTGANTrainingConfig, + save_dir: Path, + synthesize: bool = True, + trained_model: CTGANSynthesizer | None = None, +) -> TrainingResult: + """ + Train or fine tune a CTGAN model on the provided dataset and optionally synthesize data. + + If no trained model is provided, a new model will be trained. Otherwise, the + provided model will be fine tuned. + + Args: + dataset: The dataset as a pandas DataFrame. + configs: Configuration dictionary for CTGAN. + save_dir: Directory path where models and results will be saved. + synthesize: Flag indicating whether to generate synthetic data after training. Defaults to True. + trained_model: The trained model to fine tune. If None, a new model will be trained. + + Returns: + A dataclass TrainingResult object containing: + - save_dir: Directory where results are saved. + - configs: Configuration dictionary used for training. + - models: The trained models. 
+ - synthetic_data: The synthesized data as a pandas DataFrame, if synthesis was performed, + otherwise, None. + """ + table_name = get_table_name(configs.general.data_dir) + domain_file_path = configs.general.data_dir / f"{table_name}_domain.json" + with open(domain_file_path, "r") as file: + domain_dictionary = json.load(file) + + metadata, dataset_without_ids = get_single_table_svd_metadata(dataset, domain_dictionary) + + if trained_model is None: + log(INFO, "Training new CTGAN model...") + ctgan = CTGANSynthesizer( + metadata=metadata, + epochs=configs.training.epochs, + verbose=configs.training.verbose, + ) + model_name = "trained_ctgan_model.pkl" + else: + log(INFO, "Fine tuning CTGAN model...") + ctgan = trained_model + model_name = "fine_tuned_ctgan_model.pkl" + + ctgan.fit(dataset_without_ids) + + results_file = Path(save_dir) / model_name + results_file.parent.mkdir(parents=True, exist_ok=True) + + ctgan.save(results_file) + + result = CTGANTrainingResult( + save_dir=save_dir, + configs=configs, + models={(None, table_name): CTGANModelArtifacts(model=ctgan, model_file_path=results_file)}, + ) + + if synthesize: + synthetic_data = ctgan.sample(num_rows=configs.synthesizing.sample_size) + result.synthetic_data = synthetic_data + + return result + + # TODO: The following function is directly copied from the midst reference code since # I need it to run the attack code, but, it should probably be moved to somewhere else # as it is an essential part of a working TabDDPM training pipeline. 
diff --git a/src/midst_toolkit/common/config.py b/src/midst_toolkit/common/config.py index 2f5974f1..268ab567 100644 --- a/src/midst_toolkit/common/config.py +++ b/src/midst_toolkit/common/config.py @@ -18,7 +18,7 @@ class GeneralConfig(BaseModel): sample_prefix: str -class ClusteringConfig(BaseModel): +class ClavaDDPMClusteringConfig(BaseModel): """Configuration for the trainer's clustering model.""" num_clusters: int | dict[str, int] @@ -26,7 +26,7 @@ parent_scale: float -class DiffusionConfig(BaseModel): +class ClavaDDPMDiffusionConfig(BaseModel): """Configuration for the trainer's diffusion model.""" d_layers: list[int] @@ -49,7 +49,7 @@ def validate_data_split_ratios(self) -> Self: return self -class ClassifierConfig(BaseModel): +class ClavaDDPMClassifierConfig(BaseModel): """Configuration for the trainer's classifier model.""" d_layers: list[int] @@ -67,14 +67,14 @@ def validate_data_split_ratios(self) -> Self: return self -class SamplingConfig(BaseModel): +class ClavaDDPMSamplingConfig(BaseModel): """Configuration for the synthesizer's sampling process.""" batch_size: int classifier_scale: float -class MatchingConfig(BaseModel): +class ClavaDDPMMatchingConfig(BaseModel): """Configuration for the synthesizer's matching process.""" num_matching_clusters: int @@ -83,14 +83,39 @@ class MatchingConfig(BaseModel): no_matching: bool +class CTGANModelConfig(BaseModel): + """Configuration for the CTGAN model.""" + + epochs: int + verbose: bool + + +class CTGANSynthesizingConfig(BaseModel): + """Configuration for CTGAN synthesizing.""" + + sample_size: int + + class TrainingConfig(BaseModel): - """All configuration settings for training, synthesizing, and fine tuning.""" + """Base configuration settings for training models.""" model_config = ConfigDict(extra="forbid") # disallow extra fields from config files general: GeneralConfig - clustering: ClusteringConfig - diffusion: DiffusionConfig - classifier: ClassifierConfig - sampling: 
SamplingConfig - matching: MatchingConfig + + +class ClavaDDPMTrainingConfig(TrainingConfig): + """All configuration settings for training, synthesizing, and fine tuning TabDDPM models.""" + + clustering: ClavaDDPMClusteringConfig + diffusion: ClavaDDPMDiffusionConfig + classifier: ClavaDDPMClassifierConfig + sampling: ClavaDDPMSamplingConfig + matching: ClavaDDPMMatchingConfig + + +class CTGANTrainingConfig(TrainingConfig): + """All configuration settings for training, synthesizing, and fine tuning CTGAN models.""" + + training: CTGANModelConfig + synthesizing: CTGANSynthesizingConfig diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py index b841ec6d..66785fc7 100644 --- a/src/midst_toolkit/models/clavaddpm/clustering.py +++ b/src/midst_toolkit/models/clavaddpm/clustering.py @@ -13,7 +13,7 @@ from sklearn.mixture import BayesianGaussianMixture, GaussianMixture from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, QuantileTransformer -from midst_toolkit.common.config import ClusteringConfig +from midst_toolkit.common.config import ClavaDDPMClusteringConfig from midst_toolkit.common.enumerations import DomainDataType from midst_toolkit.common.logger import log from midst_toolkit.models.clavaddpm.data_loaders import NO_PARENT_COLUMN_NAME, Tables @@ -29,7 +29,7 @@ def clava_clustering( tables: Tables, relation_order: RelationOrder, save_dir: Path, - configs: ClusteringConfig, + configs: ClavaDDPMClusteringConfig, ) -> tuple[dict[str, Any], GroupLengthsProbDicts]: """ Clustering function for the multi-table function of the ClavaDDPM model. @@ -96,7 +96,7 @@ def _load_clustering_info_from_checkpoint(save_dir: Path) -> dict[str, Any] | No def _run_clustering( tables: Tables, relation_order: RelationOrder, - configs: ClusteringConfig, + configs: ClavaDDPMClusteringConfig, ) -> tuple[Tables, GroupLengthsProbDicts]: """ Run the clustering process. 
@@ -112,7 +112,7 @@ def _run_clustering( - The tables dictionary. - The dictionary with the group lengths probability for all the parent-child pairs. """ - all_group_lengths_prob_dicts = {} + all_group_lengths_prob_dicts: GroupLengthsProbDicts = {} relation_order_reversed = relation_order[::-1] for parent, child in relation_order_reversed: if parent is not None: diff --git a/src/midst_toolkit/models/clavaddpm/enumerations.py b/src/midst_toolkit/models/clavaddpm/enumerations.py index f8af7830..fc8bb60b 100644 --- a/src/midst_toolkit/models/clavaddpm/enumerations.py +++ b/src/midst_toolkit/models/clavaddpm/enumerations.py @@ -3,7 +3,7 @@ import numpy as np -Relation = tuple[str, str] +Relation = tuple[str | None, str] RelationOrder = list[Relation] GroupLengthProbDict = dict[int, dict[int, float]] GroupLengthsProbDicts = dict[Relation, GroupLengthProbDict] diff --git a/src/midst_toolkit/models/clavaddpm/synthesizer.py b/src/midst_toolkit/models/clavaddpm/synthesizer.py index 44741bcd..c361615a 100644 --- a/src/midst_toolkit/models/clavaddpm/synthesizer.py +++ b/src/midst_toolkit/models/clavaddpm/synthesizer.py @@ -14,7 +14,7 @@ from torch.nn import functional from tqdm import tqdm -from midst_toolkit.common.config import GeneralConfig, MatchingConfig, SamplingConfig +from midst_toolkit.common.config import ClavaDDPMMatchingConfig, ClavaDDPMSamplingConfig, GeneralConfig from midst_toolkit.common.enumerations import DataSplit from midst_toolkit.common.logger import log from midst_toolkit.models.clavaddpm.data_loaders import NO_PARENT_COLUMN_NAME, Tables @@ -32,7 +32,7 @@ GaussianMultinomialDiffusion, ) from midst_toolkit.models.clavaddpm.model import Classifier, ModelParameters -from midst_toolkit.models.clavaddpm.train import ModelArtifacts, get_df_without_id +from midst_toolkit.models.clavaddpm.train import ClavaDDPMModelArtifacts, get_df_without_id def sample_from_diffusion( @@ -675,7 +675,7 @@ def clava_synthesizing_matching_process( synthetic_tables: 
dict[Relation, dict[str, Any]], tables: Tables, relation_order: RelationOrder, - matching_config: MatchingConfig, + matching_config: ClavaDDPMMatchingConfig, ) -> dict[str, pd.DataFrame]: """ Matches synthetic child tables to synthetic parent tables based on clustering information. @@ -711,10 +711,10 @@ def clava_synthesizing( tables: Tables, relation_order: RelationOrder, save_dir: Path, - models: dict[Relation, ModelArtifacts], + models: dict[Relation, ClavaDDPMModelArtifacts], general_config: GeneralConfig, - sampling_config: SamplingConfig, - matching_config: MatchingConfig, + sampling_config: ClavaDDPMSamplingConfig, + matching_config: ClavaDDPMMatchingConfig, all_group_lengths_prob_dicts: GroupLengthsProbDicts | None = None, sample_scale: float = 1.0, ) -> tuple[dict[str, pd.DataFrame], float, float]: @@ -827,7 +827,7 @@ def clava_synthesizing( def _synthesize_single_table( table_name: str, data: pd.DataFrame, - training_results: ModelArtifacts, + training_results: ClavaDDPMModelArtifacts, sample_scale: float, sample_batch_size: int, ) -> tuple[pd.DataFrame, list[int]]: @@ -883,8 +883,8 @@ def _synthesize_single_table( def _synthesize_multi_table( parent_name: str, child_name: str, - parent_training_results: ModelArtifacts, - child_training_results: ModelArtifacts, + parent_training_results: ClavaDDPMModelArtifacts, + child_training_results: ClavaDDPMModelArtifacts, parent_synthetic_data: dict[str, Any], data: pd.DataFrame, group_length_prob_dict: GroupLengthProbDict, diff --git a/src/midst_toolkit/models/clavaddpm/train.py b/src/midst_toolkit/models/clavaddpm/train.py index 3df3398a..cc8ee580 100644 --- a/src/midst_toolkit/models/clavaddpm/train.py +++ b/src/midst_toolkit/models/clavaddpm/train.py @@ -10,10 +10,11 @@ import numpy as np import pandas as pd import torch +from sdv.single_table import CTGANSynthesizer # type: ignore[import-untyped] from sklearn.preprocessing import LabelEncoder from torch import Tensor, optim -from midst_toolkit.common.config 
import ClassifierConfig, DiffusionConfig +from midst_toolkit.common.config import ClavaDDPMClassifierConfig, ClavaDDPMDiffusionConfig from midst_toolkit.common.enumerations import DataSplit, DomainDataType, TaskType from midst_toolkit.common.logger import KeyValueLogger, log from midst_toolkit.common.variables import DEVICE @@ -39,6 +40,17 @@ @dataclass class ModelArtifacts: + pass + + +@dataclass +class CTGANModelArtifacts(ModelArtifacts): + model: CTGANSynthesizer + model_file_path: Path + + +@dataclass +class ClavaDDPMModelArtifacts(ModelArtifacts): diffusion: GaussianMultinomialDiffusion label_encoders: dict[int, LabelEncoder] dataset: Dataset @@ -58,10 +70,10 @@ def clava_training( tables: Tables, relation_order: RelationOrder, save_dir: Path, - diffusion_config: DiffusionConfig, - classifier_config: ClassifierConfig | None = None, + diffusion_config: ClavaDDPMDiffusionConfig, + classifier_config: ClavaDDPMClassifierConfig | None = None, device: torch.device = DEVICE, -) -> tuple[Tables, dict[Relation, ModelArtifacts]]: +) -> tuple[Tables, dict[Relation, ClavaDDPMModelArtifacts]]: """ Training function for the ClavaDDPM model. @@ -123,10 +135,10 @@ def child_training( child_domain: dict[str, Any], parent_name: str | None, child_name: str, - diffusion_config: DiffusionConfig, - classifier_config: ClassifierConfig | None = None, + diffusion_config: ClavaDDPMDiffusionConfig, + classifier_config: ClavaDDPMClassifierConfig | None = None, device: torch.device = DEVICE, -) -> ModelArtifacts: +) -> ClavaDDPMModelArtifacts: """ Training function for a single child table. @@ -205,9 +217,9 @@ def train_model( table_metadata: TableMetadata, model_params: ModelParameters, transformations: Transformations, - diffusion_config: DiffusionConfig, + diffusion_config: ClavaDDPMDiffusionConfig, device: torch.device = DEVICE, -) -> ModelArtifacts: +) -> ClavaDDPMModelArtifacts: """ Training function for the diffusion model. 
@@ -281,7 +293,7 @@ def train_model( if dataset.numerical_transform is not None: inverse_transform_function = dataset.numerical_transform.inverse_transform - return ModelArtifacts( + return ClavaDDPMModelArtifacts( diffusion=diffusion, label_encoders=label_encoders, dataset=dataset, @@ -299,8 +311,8 @@ def train_classifier( table_metadata: TableMetadata, model_params: ModelParameters, transformations: Transformations, - diffusion_config: DiffusionConfig, - classifier_config: ClassifierConfig, + diffusion_config: ClavaDDPMDiffusionConfig, + classifier_config: ClavaDDPMClassifierConfig, device: torch.device = DEVICE, cluster_col: str = "cluster", classifier_evaluation_interval: int = 5, @@ -486,7 +498,7 @@ def get_table_metadata(df: pd.DataFrame, table_domain: dict[str, Any], target_co def save_table_info( tables: Tables, relation_order: RelationOrder, - models: dict[Relation, ModelArtifacts], + models: dict[Relation, ClavaDDPMModelArtifacts], save_dir: Path, ) -> None: """ diff --git a/tests/integration/attacks/ensemble/configs/shadow_training_config.yaml b/tests/integration/attacks/ensemble/configs/shadow_training_config.yaml index 717f3d82..4ece9d95 100644 --- a/tests/integration/attacks/ensemble/configs/shadow_training_config.yaml +++ b/tests/integration/attacks/ensemble/configs/shadow_training_config.yaml @@ -7,7 +7,7 @@ shadow_training: training_json_config_paths: # Config json files used for tabddpm training on the trans table table_domain_file_path: ${base_test_assets_dir}/data_configs/trans_domain.json dataset_meta_file_path: ${base_test_assets_dir}/data_configs/dataset_meta.json - tabddpm_training_config_path: ${base_test_assets_dir}/data_configs/trans.json + training_config_path: ${base_test_assets_dir}/data_configs/trans.json # Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name # Also, training configs for each shadow model are created under shadow_models_data_path. 
shadow_models_output_path: ${base_test_assets_dir}/shadow_models_data diff --git a/tests/integration/attacks/ensemble/test_shadow_model_training.py b/tests/integration/attacks/ensemble/test_shadow_model_training.py index 4e2e6eb9..8008f97d 100644 --- a/tests/integration/attacks/ensemble/test_shadow_model_training.py +++ b/tests/integration/attacks/ensemble/test_shadow_model_training.py @@ -2,6 +2,7 @@ import pickle import shutil from pathlib import Path +from typing import cast import pandas as pd import pytest @@ -15,9 +16,10 @@ ) from midst_toolkit.attacks.ensemble.shadow_model_utils import ( fine_tune_tabddpm_and_synthesize, - save_additional_tabddpm_config, + save_additional_training_config, train_tabddpm_and_synthesize, ) +from midst_toolkit.common.config import ClavaDDPMTrainingConfig POPULATION_DATA = load_dataframe( @@ -122,7 +124,7 @@ def test_train_and_fine_tune_tabddpm(cfg: DictConfig, tmp_path: Path) -> None: "tests/unit/attacks/ensemble/assets/population_data/all_population.csv" ) # For testing purposes only. 
fine_tuning_set = copy.deepcopy(train_set) - tabddpm_config_path = Path(cfg.shadow_training.training_json_config_paths.tabddpm_training_config_path) + training_config_path = Path(cfg.shadow_training.training_json_config_paths.training_config_path) tmp_training_dir = tmp_path # We should move ``dataset_meta.json`` and ``trans_domain.json`` files to the ``tmp_training_dir`` assert Path(cfg.shadow_training.training_json_config_paths.table_domain_file_path).exists() @@ -135,16 +137,20 @@ def test_train_and_fine_tune_tabddpm(cfg: DictConfig, tmp_path: Path) -> None: cfg.shadow_training.training_json_config_paths.dataset_meta_file_path, tmp_training_dir / "dataset_meta.json", ) - configs, save_dir = save_additional_tabddpm_config( + configs, save_dir = save_additional_training_config( data_dir=tmp_training_dir, - training_config_json_path=tabddpm_config_path, + training_config_json_path=training_config_path, final_config_json_path=tmp_training_dir / "trans.json", experiment_name="test_experiment", workspace_name="test_workspace", ) train_result = train_tabddpm_and_synthesize( - train_set, configs, save_dir, synthesize=True, number_of_points_to_synthesize=99 + train_set, + cast(ClavaDDPMTrainingConfig, configs), + save_dir, + synthesize=True, + number_of_points_to_synthesize=99, ) assert train_result.synthetic_data is not None assert type(train_result.synthetic_data) is pd.DataFrame @@ -158,7 +164,7 @@ def test_train_and_fine_tune_tabddpm(cfg: DictConfig, tmp_path: Path) -> None: fine_tuned_results = fine_tune_tabddpm_and_synthesize( trained_models=train_result.models, fine_tune_set=fine_tuning_set, # fine-tuning on the same data for testing purposes - configs=configs, + configs=cast(ClavaDDPMTrainingConfig, configs), save_dir=save_dir, fine_tuning_diffusion_iterations=cfg.shadow_training.fine_tuning_config.fine_tune_diffusion_iterations, fine_tuning_classifier_iterations=cfg.shadow_training.fine_tuning_config.fine_tune_classifier_iterations, diff --git 
a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py index 43689ff0..aec00323 100644 --- a/tests/integration/models/clavaddpm/test_model.py +++ b/tests/integration/models/clavaddpm/test_model.py @@ -10,7 +10,7 @@ import torch from torch.nn import functional -from midst_toolkit.common.config import ClassifierConfig, ClusteringConfig, DiffusionConfig +from midst_toolkit.common.config import ClavaDDPMClassifierConfig, ClavaDDPMClusteringConfig, ClavaDDPMDiffusionConfig from midst_toolkit.common.logger import log from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds from midst_toolkit.common.variables import DEVICE @@ -33,13 +33,13 @@ from tests.integration.utils import is_running_on_ci_environment -CLUSTERING_CONFIG = ClusteringConfig( +CLUSTERING_CONFIG = ClavaDDPMClusteringConfig( parent_scale=1.0, num_clusters=3, clustering_method=ClusteringMethod.KMEANS_AND_GMM, ) -DIFFUSION_CONFIG = DiffusionConfig( +DIFFUSION_CONFIG = ClavaDDPMDiffusionConfig( d_layers=[512, 1024, 1024, 1024, 1024, 512], dropout=0.0, num_timesteps=100, @@ -53,7 +53,7 @@ data_split_ratios=[0.99, 0.005, 0.005], ) -CLASSIFIER_CONFIG = ClassifierConfig( +CLASSIFIER_CONFIG = ClavaDDPMClassifierConfig( d_layers=[128, 256, 512, 1024, 512, 256, 128], lr=0.0001, dim_t=128, diff --git a/tests/integration/models/clavaddpm/test_synthesizer.py b/tests/integration/models/clavaddpm/test_synthesizer.py index ae88477e..46c66030 100644 --- a/tests/integration/models/clavaddpm/test_synthesizer.py +++ b/tests/integration/models/clavaddpm/test_synthesizer.py @@ -5,12 +5,12 @@ import pytest from midst_toolkit.common.config import ( - ClassifierConfig, - ClusteringConfig, - DiffusionConfig, + ClavaDDPMClassifierConfig, + ClavaDDPMClusteringConfig, + ClavaDDPMDiffusionConfig, + ClavaDDPMMatchingConfig, + ClavaDDPMSamplingConfig, GeneralConfig, - MatchingConfig, - SamplingConfig, ) from midst_toolkit.common.logger import log from 
midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds @@ -25,13 +25,13 @@ from tests.integration.utils import is_running_on_ci_environment -CLUSTERING_CONFIG = ClusteringConfig( +CLUSTERING_CONFIG = ClavaDDPMClusteringConfig( parent_scale=1.0, num_clusters=3, clustering_method=ClusteringMethod.KMEANS_AND_GMM, ) -DIFFUSION_CONFIG = DiffusionConfig( +DIFFUSION_CONFIG = ClavaDDPMDiffusionConfig( d_layers=[512, 1024, 1024, 1024, 1024, 512], dropout=0.0, num_timesteps=100, @@ -45,7 +45,7 @@ data_split_ratios=[0.99, 0.005, 0.005], ) -CLASSIFIER_CONFIG = ClassifierConfig( +CLASSIFIER_CONFIG = ClavaDDPMClassifierConfig( d_layers=[128, 256, 512, 1024, 512, 256, 128], lr=0.0001, dim_t=128, @@ -62,12 +62,12 @@ sample_prefix="", ) -SAMPLING_CONFIG = SamplingConfig( +SAMPLING_CONFIG = ClavaDDPMSamplingConfig( batch_size=2, classifier_scale=1.0, ) -MATCHING_CONFIG = MatchingConfig( +MATCHING_CONFIG = ClavaDDPMMatchingConfig( num_matching_clusters=1, matching_batch_size=1, unique_matching=True, diff --git a/tests/unit/attacks/ensemble/configs/shadow_training_config.yaml b/tests/unit/attacks/ensemble/configs/shadow_training_config.yaml index a6319f49..ed8e5a1f 100644 --- a/tests/unit/attacks/ensemble/configs/shadow_training_config.yaml +++ b/tests/unit/attacks/ensemble/configs/shadow_training_config.yaml @@ -7,7 +7,7 @@ shadow_training: training_json_config_paths: # Config json files used for tabddpm training on the trans table table_domain_file_path: ${base_test_assets_dir}/data_configs/trans_domain.json dataset_meta_file_path: ${base_test_assets_dir}/data_configs/dataset_meta.json - tabddpm_training_config_path: ${base_test_assets_dir}/data_configs/trans.json + training_config_path: ${base_test_assets_dir}/data_configs/trans.json # Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name # Also, training configs for each shadow model are created under shadow_models_data_path. 
shadow_models_output_path: ${base_test_assets_dir}/shadow_models_data diff --git a/tests/unit/attacks/ensemble/test_shadow_model_utils.py b/tests/unit/attacks/ensemble/test_shadow_model_utils.py index 7222b3ff..722918ea 100644 --- a/tests/unit/attacks/ensemble/test_shadow_model_utils.py +++ b/tests/unit/attacks/ensemble/test_shadow_model_utils.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig from midst_toolkit.attacks.ensemble.shadow_model_utils import ( - save_additional_tabddpm_config, + save_additional_training_config, ) @@ -18,7 +18,7 @@ def cfg() -> DictConfig: def test_save_additional_tabddpm_config(cfg: DictConfig, tmp_path: Path) -> None: # Input path - tabddpm_config_path = Path(cfg.shadow_training.training_json_config_paths.tabddpm_training_config_path) + tabddpm_config_path = Path(cfg.shadow_training.training_json_config_paths.training_config_path) # Extract original parameters with open(tabddpm_config_path, "r") as file: @@ -33,7 +33,7 @@ def test_save_additional_tabddpm_config(cfg: DictConfig, tmp_path: Path) -> None new_experiment_name = "test_experiment" final_json_path = tmp_path / "modified_config.json" - configs, save_dir = save_additional_tabddpm_config( + configs, save_dir = save_additional_training_config( data_dir=new_data_dir, training_config_json_path=tabddpm_config_path, final_config_json_path=final_json_path,